├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── commitlint.config.js ├── discounting_chain.ipynb ├── discounting_chain ├── __init__.py ├── a2c.py ├── base.py ├── bmg_a2c.py ├── data │ ├── discounting_chain_appendix_histories_array.pickle │ └── discounting_chain_histories_array.pickle ├── data_generation │ ├── __init__.py │ ├── n_step.py │ ├── n_step_data_generator.py │ └── n_step_utils.py ├── envs │ ├── __init__.py │ ├── env_utils.py │ ├── gymnax_dc_wrapper.py │ └── gymnax_wrapper.py ├── list_logger.py ├── meta_a2c.py ├── nets.py ├── nets_split.py └── train_utils.py ├── discounting_chain_train.ipynb ├── mypy.ini ├── plots ├── discounting_chain │ ├── dc_chain_discount_factor.png │ ├── dc_chain_discount_factor_appendix.png │ ├── dc_chain_outer_loss_advantage.png │ ├── dc_chain_outer_loss_advantage_appendix.png │ ├── dc_chain_return.png │ └── dc_chain_return_appendix.png └── snake │ ├── snake_discount_factor.png │ ├── snake_discount_factor_appendix.png │ ├── snake_outer_loss_advantage.png │ ├── snake_outer_loss_advantage_appendix.png │ ├── snake_return.png │ └── snake_return_appendix.png ├── requirements-dev.txt ├── requirements.txt ├── snake.ipynb ├── snake ├── __init__.py ├── agent │ ├── __init__.py │ ├── a2c.py │ ├── actor_critic_agent.py │ ├── gae.py │ └── meta_a2c.py ├── configs │ ├── agent │ │ ├── a2c.yaml │ │ ├── bootstrap.yaml │ │ └── mgrl.yaml │ └── config.yaml ├── data │ ├── a2c_gamma │ │ ├── MET-2150__train_mean_gamma.csv │ │ ├── MET-2156__train_mean_gamma.csv │ │ ├── MET-2158__train_mean_gamma.csv │ │ ├── MET-2162__train_mean_gamma.csv │ │ ├── MET-2165__train_mean_gamma.csv │ │ ├── MET-2168__train_mean_gamma.csv │ │ ├── MET-2171__train_mean_gamma.csv │ │ ├── MET-2175__train_mean_gamma.csv │ │ ├── MET-2177__train_mean_gamma.csv │ │ └── MET-2180__train_mean_gamma.csv │ ├── a2c_return │ │ ├── MET-2150__eval_episode_reward_determinist_policy.csv │ │ ├── MET-2156__eval_episode_reward_determinist_policy.csv │ │ ├── 
MET-2158__eval_episode_reward_determinist_policy.csv │ │ ├── MET-2162__eval_episode_reward_determinist_policy.csv │ │ ├── MET-2165__eval_episode_reward_determinist_policy.csv │ │ ├── MET-2168__eval_episode_reward_determinist_policy.csv │ │ ├── MET-2171__eval_episode_reward_determinist_policy.csv │ │ ├── MET-2175__eval_episode_reward_determinist_policy.csv │ │ ├── MET-2177__eval_episode_reward_determinist_policy.csv │ │ └── MET-2180__eval_episode_reward_determinist_policy.csv │ ├── appendix │ │ ├── bias │ │ │ ├── bias_bootstrap_no_outer_no_norm.csv │ │ │ ├── bias_bootstrap_no_outer_norm.csv │ │ │ ├── bias_bootstrap_outer_no_norm.csv │ │ │ ├── bias_bootstrap_outer_norm.csv │ │ │ ├── bias_mgrl_no_outer_no_norm.csv │ │ │ ├── bias_mgrl_no_outer_norm.csv │ │ │ ├── bias_mgrl_outer_no_norm.csv │ │ │ └── bias_mgrl_outer_norm.csv │ │ ├── gamma │ │ │ ├── gamma_a2c_no_norm.csv │ │ │ ├── gamma_a2c_norm.csv │ │ │ ├── gamma_bootstrap_no_outer_no_norm.csv │ │ │ ├── gamma_bootstrap_no_outer_norm.csv │ │ │ ├── gamma_bootstrap_outer_no_norm.csv │ │ │ ├── gamma_bootstrap_outer_norm.csv │ │ │ ├── gamma_mgrl_no_outer_no_norm.csv │ │ │ ├── gamma_mgrl_no_outer_norm.csv │ │ │ ├── gamma_mgrl_outer_no_norm.csv │ │ │ └── gamma_mgrl_outer_norm.csv │ │ └── return │ │ │ ├── return_a2c_no_norm.csv │ │ │ ├── return_a2c_norm.csv │ │ │ ├── return_bootstrap_no_outer_no_norm.csv │ │ │ ├── return_bootstrap_no_outer_norm.csv │ │ │ ├── return_bootstrap_outer_no_norm.csv │ │ │ ├── return_bootstrap_outer_norm.csv │ │ │ ├── return_mgrl_no_outer_no_norm.csv │ │ │ ├── return_mgrl_no_outer_norm.csv │ │ │ ├── return_mgrl_outer_no_norm.csv │ │ │ └── return_mgrl_outer_norm.csv │ ├── bootstrap_bias │ │ ├── MET-2154__train_mean_advantage_outer_loss.csv │ │ ├── MET-2160__train_mean_advantage_outer_loss.csv │ │ ├── MET-2166__train_mean_advantage_outer_loss.csv │ │ ├── MET-2172__train_mean_advantage_outer_loss.csv │ │ ├── MET-2178__train_mean_advantage_outer_loss.csv │ │ ├── 
MET-2183__train_mean_advantage_outer_loss.csv │ │ ├── MET-2188__train_mean_advantage_outer_loss.csv │ │ ├── MET-2192__train_mean_advantage_outer_loss.csv │ │ ├── MET-2196__train_mean_advantage_outer_loss.csv │ │ └── MET-2199__train_mean_advantage_outer_loss.csv │ ├── bootstrap_gamma │ │ ├── MET-2154__train_mean_gamma.csv │ │ ├── MET-2160__train_mean_gamma.csv │ │ ├── MET-2166__train_mean_gamma.csv │ │ ├── MET-2172__train_mean_gamma.csv │ │ ├── MET-2178__train_mean_gamma.csv │ │ ├── MET-2183__train_mean_gamma.csv │ │ ├── MET-2188__train_mean_gamma.csv │ │ ├── MET-2192__train_mean_gamma.csv │ │ ├── MET-2196__train_mean_gamma.csv │ │ └── MET-2199__train_mean_gamma.csv │ ├── bootstrap_outer_critic_bias │ │ ├── MET-2153__train_mean_advantage_outer_loss.csv │ │ ├── MET-2161__train_mean_advantage_outer_loss.csv │ │ ├── MET-2167__train_mean_advantage_outer_loss.csv │ │ ├── MET-2173__train_mean_advantage_outer_loss.csv │ │ ├── MET-2181__train_mean_advantage_outer_loss.csv │ │ ├── MET-2185__train_mean_advantage_outer_loss.csv │ │ ├── MET-2190__train_mean_advantage_outer_loss.csv │ │ ├── MET-2195__train_mean_advantage_outer_loss.csv │ │ ├── MET-2198__train_mean_advantage_outer_loss.csv │ │ └── MET-2200__train_mean_advantage_outer_loss.csv │ ├── bootstrap_outer_critic_gamma │ │ ├── MET-2153__train_mean_gamma.csv │ │ ├── MET-2161__train_mean_gamma.csv │ │ ├── MET-2167__train_mean_gamma.csv │ │ ├── MET-2173__train_mean_gamma.csv │ │ ├── MET-2181__train_mean_gamma.csv │ │ ├── MET-2185__train_mean_gamma.csv │ │ ├── MET-2190__train_mean_gamma.csv │ │ ├── MET-2195__train_mean_gamma.csv │ │ ├── MET-2198__train_mean_gamma.csv │ │ └── MET-2200__train_mean_gamma.csv │ ├── bootstrap_outer_critic_return │ │ ├── MET-2153__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2161__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2167__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2173__eval_episode_reward_stochastic_policy.csv │ │ ├── 
MET-2181__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2185__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2190__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2195__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2198__eval_episode_reward_stochastic_policy.csv │ │ └── MET-2200__eval_episode_reward_stochastic_policy.csv │ ├── bootstrap_return │ │ ├── MET-2154__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2160__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2166__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2172__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2178__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2183__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2188__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2192__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2196__eval_episode_reward_stochastic_policy.csv │ │ └── MET-2199__eval_episode_reward_stochastic_policy.csv │ ├── mgrl_bias │ │ ├── MET-2151__train_mean_advantage_outer_loss.csv │ │ ├── MET-2157__train_mean_advantage_outer_loss.csv │ │ ├── MET-2163__train_mean_advantage_outer_loss.csv │ │ ├── MET-2169__train_mean_advantage_outer_loss.csv │ │ ├── MET-2174__train_mean_advantage_outer_loss.csv │ │ ├── MET-2179__train_mean_advantage_outer_loss.csv │ │ ├── MET-2184__train_mean_advantage_outer_loss.csv │ │ ├── MET-2187__train_mean_advantage_outer_loss.csv │ │ ├── MET-2191__train_mean_advantage_outer_loss.csv │ │ └── MET-2194__train_mean_advantage_outer_loss.csv │ ├── mgrl_gamma │ │ ├── MET-2151__train_mean_gamma.csv │ │ ├── MET-2157__train_mean_gamma.csv │ │ ├── MET-2163__train_mean_gamma.csv │ │ ├── MET-2169__train_mean_gamma.csv │ │ ├── MET-2174__train_mean_gamma.csv │ │ ├── MET-2179__train_mean_gamma.csv │ │ ├── MET-2184__train_mean_gamma.csv │ │ ├── MET-2187__train_mean_gamma.csv │ │ ├── MET-2191__train_mean_gamma.csv │ │ └── MET-2194__train_mean_gamma.csv │ ├── mgrl_outer_critic_bias │ │ ├── 
MET-2152__train_mean_advantage_outer_loss.csv │ │ ├── MET-2159__train_mean_advantage_outer_loss.csv │ │ ├── MET-2164__train_mean_advantage_outer_loss.csv │ │ ├── MET-2170__train_mean_advantage_outer_loss.csv │ │ ├── MET-2176__train_mean_advantage_outer_loss.csv │ │ ├── MET-2182__train_mean_advantage_outer_loss.csv │ │ ├── MET-2186__train_mean_advantage_outer_loss.csv │ │ ├── MET-2189__train_mean_advantage_outer_loss.csv │ │ ├── MET-2193__train_mean_advantage_outer_loss.csv │ │ └── MET-2197__train_mean_advantage_outer_loss.csv │ ├── mgrl_outer_critic_gamma │ │ ├── MET-2152__train_mean_gamma.csv │ │ ├── MET-2159__train_mean_gamma.csv │ │ ├── MET-2164__train_mean_gamma.csv │ │ ├── MET-2170__train_mean_gamma.csv │ │ ├── MET-2176__train_mean_gamma.csv │ │ ├── MET-2182__train_mean_gamma.csv │ │ ├── MET-2186__train_mean_gamma.csv │ │ ├── MET-2189__train_mean_gamma.csv │ │ ├── MET-2193__train_mean_gamma.csv │ │ └── MET-2197__train_mean_gamma.csv │ ├── mgrl_outer_critic_return │ │ ├── MET-2152__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2159__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2164__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2170__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2176__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2182__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2186__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2189__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2193__eval_episode_reward_stochastic_policy.csv │ │ └── MET-2197__eval_episode_reward_stochastic_policy.csv │ └── mgrl_return │ │ ├── MET-2151__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2157__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2163__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2169__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2174__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2179__eval_episode_reward_stochastic_policy.csv │ │ ├── 
MET-2184__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2187__eval_episode_reward_stochastic_policy.csv │ │ ├── MET-2191__eval_episode_reward_stochastic_policy.csv │ │ └── MET-2194__eval_episode_reward_stochastic_policy.csv ├── env │ ├── __init__.py │ └── snake.py ├── networks │ ├── __init__.py │ ├── actor_critic.py │ ├── cnn.py │ ├── distribution.py │ └── snake.py └── training │ ├── __init__.py │ ├── config.py │ ├── evaluator.py │ ├── logger.py │ ├── setup_run.py │ ├── types.py │ └── utils.py └── snake_train.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = A,B,C,D,E,F,G,I,N,T,W 3 | exclude = 4 | .tox, 5 | .git, 6 | __pycache__, 7 | build, 8 | dist, 9 | proto/*, 10 | *.pyc, 11 | *.egg-info, 12 | .cache, 13 | .eggs 14 | max-line-length=100 15 | import-order-style = google 16 | application-import-names = metal 17 | doctests = True 18 | docstring-convention = google 19 | per-file-ignores = __init__.py:F401 20 | 21 | ignore = 22 | D107 # Do not require docstrings for __init__ 23 | W503 # line break before binary operator (not compatible with black) 24 | E731 # do not assign a lambda expression, use a def 25 | N802 # function names should be lowercase 26 | N803 # arguments should be lowercase 27 | N806 # variables should be lowercase 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | **/.ipynb_checkpoints/ 3 | venv/ 4 | libtpu_lockfile 5 | .idea/ 6 | **videos/ 7 | **/params* 8 | 9 | # wandb 10 | **/logs/ 11 | **/outputs/ 12 | **/multirun/ 13 | **/wandb/ 14 | 15 | # internal scripts 16 | **/internal_* 17 | 18 | *.pyc 19 | *.iml 20 | *.xml 21 | *.log 22 | playground/* 23 | logs/* 24 | .neptune 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 
36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | pip-wheel-metadata/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | src/* 47 | 48 | # Tests 49 | coverage_html_report 50 | .coverage 51 | test-results.xml 52 | 53 | # Functional tests 54 | generated-functional-tests-runs-ci.yml 55 | 56 | # IDE Configs 57 | .vscode 58 | .devcontainer 59 | .idea 60 | 61 | # Logs 62 | /examples/logs/ 63 | 64 | /nul/ 65 | /checkpoints/ 66 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_stages: [ "commit", "commit-msg", "push" ] 2 | default_language_version: 3 | python: python3.8 4 | 5 | 6 | repos: 7 | - repo: https://github.com/timothycrosley/isort 8 | rev: 5.0.4 9 | hooks: 10 | - id: isort 11 | args: ["--profile", "black"] 12 | 13 | - repo: https://github.com/ambv/black 14 | rev: 21.5b1 15 | hooks: 16 | - id: black 17 | 18 | - repo: https://github.com/pre-commit/pre-commit-hooks 19 | rev: v3.4.0 20 | hooks: 21 | # - id: end-of-file-fixer # removed because of notebooks 22 | - id: debug-statements 23 | - id: requirements-txt-fixer 24 | - id: mixed-line-ending 25 | - id: check-yaml 26 | args: [ '--unsafe' ] 27 | - id: trailing-whitespace 28 | 29 | - repo: https://gitlab.com/pycqa/flake8 30 | rev: 3.9.2 31 | hooks: 32 | - id: flake8 33 | args: 34 | - --max-line-length=100 35 | - --max-cognitive-complexity=10 36 | - --ignore=E266,E501,E731,W503 37 | additional_dependencies: 38 | - pep8-naming 39 | - flake8-builtins 40 | - flake8-comprehensions 41 | - flake8-bugbear 42 | - flake8-pytest-style 43 | - flake8-cognitive-complexity 44 | 45 | - repo: https://github.com/pre-commit/mirrors-mypy 46 | rev: v0.812 47 | hooks: 48 | - id: mypy 49 | 50 | - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook 51 | rev: v9.0.0 52 | hooks: 53 | - id: commitlint 54 | stages: [ commit-msg ] 55 | 
additional_dependencies: [ '@commitlint/config-conventional' ] 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Debiasing Meta-Gradient Reinforcement Learning by Learning the Outer Value Function 2 | 3 | This repository contains the code from the paper [_Debiasing Meta-Gradient Reinforcement Learning by 4 | Learning the Outer Value Function_](https://arxiv.org/abs/2211.10550) (Clément Bonnet, Laurence Midgley, 5 | Alexandre Laterre) published at the [6th Workshop on Meta-Learning](https://meta-learn.github.io/2022/) 6 | at NeurIPS 2022, New Orleans. 7 | 8 | ## Abstract 9 | Meta-gradient Reinforcement Learning (RL) allows agents to self-tune their hyper-parameters in an online fashion during training. 10 | In this paper, we identify a bias in the meta-gradient of current meta-gradient RL approaches. 11 | This bias comes from using the critic that is trained using the meta-learned discount factor for the advantage estimation in the outer objective which requires a different discount factor. 12 | Because the meta-learned discount factor is typically lower than the one used in the outer objective, the resulting bias can cause the meta-gradient to favor myopic policies. 13 | We propose a simple solution to this issue: we eliminate this bias by using an alternative, *outer* value function in the estimation of the outer loss. 14 | To obtain this outer value function we add a second head to the critic network and train it alongside the classic critic, using the outer loss discount factor. 15 | On an illustrative toy problem, we show that the bias can cause catastrophic failure of current meta-gradient RL approaches, and show that our proposed solution fixes it. 16 | We then apply our method to a more complex environment and demonstrate that fixing the meta-gradient bias can significantly improve performance. 
17 | 18 | 19 | ## Experiments 20 | 21 | We denote: 22 | - A2C: Advantage Actor Critic 23 | - MG: meta-gradient algorithm from [Xu et al., 2018] 24 | - BMG: bootstrapped meta-gradients from [Flennerhag et al., 2022] 25 | - MG outer-critic: the MG algorithm equipped with an outer-critic that estimates the outer value function used in the outer loss 26 | - BMG outer-critic: the BMG algorithm similarly equipped with an outer-critic 27 | 28 | ### Discounting Chain 29 |

30 | Discounting Chain Return 31 | Discounting Chain Discount Factor 32 |

33 | 34 | ### Snake 35 |

36 | Snake Return 37 | Snake Discount Factor 38 |

39 | 40 | 41 | 42 | ## Reproducibility 43 | 44 | We provide a `requirements.txt` file with all the tagged packages needed to reproduce the 45 | experiments in the paper. 46 | Here is a snippet of commands to set up a virtual environment and install these packages. 47 | Alternatively, one can use a Conda environment or a similar solution. 48 | ```shell 49 | python -m venv venv 50 | source venv/bin/activate 51 | pip install -U pip setuptools wheel 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | ### Discounting Chain 56 | 57 | The Discounting Chain environment is originally from [bsuite](https://github.com/deepmind/bsuite) but is imported from [Gymnax](https://github.com/RobertTLange/gymnax) in this paper to benefit from its JAX implementation. 58 | 59 | To reproduce the experiments on the Discounting Chain environment, you can run the jupyter notebook `discounting_chain_train.ipynb`. 60 | The other notebook `discounting_chain.ipynb` loads the data and plots the figures from the paper. 61 | 62 | ### Snake 63 | 64 | The Snake environment is provided by [Jumanji](https://github.com/instadeepai/jumanji) in JAX. 65 | 66 | To reproduce the experiments on the Snake environment, one can run the following commands. 
67 | 68 | - Advantage Actor Critic (A2C) 69 | ```shell 70 | python snake_train.py -m agent=a2c training.seed=1,2,3,4,5,6,7,8,9,10 71 | ``` 72 | 73 | - Meta-Gradient (MG) 74 | ```shell 75 | python snake_train.py -m agent=mgrl agent.outer_critic=false training.seed=1,2,3,4,5,6,7,8,9,10 76 | python snake_train.py -m agent=mgrl agent.outer_critic=true training.seed=1,2,3,4,5,6,7,8,9,10 77 | ``` 78 | 79 | - Bootstrapped Meta-Gradient (BMG) 80 | ```shell 81 | python snake_train.py -m agent=bootstrap agent.outer_critic=false training.seed=1,2,3,4,5,6,7,8,9,10 82 | python snake_train.py -m agent=bootstrap agent.outer_critic=true training.seed=1,2,3,4,5,6,7,8,9,10 83 | ``` 84 | 85 | - appendix 86 | ```shell 87 | python snake_train.py -m agent=a2c agent.outer_critic=false agent.normalize_advantage=false,true agent.normalize_outer_advantage=false training.seed=1 88 | python snake_train.py -m agent=mgrl agent.outer_critic=false,true agent.normalize_advantage=false,true agent.normalize_outer_advantage=false,true training.seed=1 89 | python snake_train.py -m agent=bootstrap agent.outer_critic=false,true agent.normalize_advantage=false,true agent.normalize_outer_advantage=false,true training.seed=1 90 | ``` 91 | 92 | Note that the default logger is `"terminal"`. If you want to save the data, a Neptune logger is implemented, 93 | and you can enable it by replacing `"terminal"` with `"neptune"` in `snake_train.py`. 94 | 95 | For the paper, the data from these runs was collected and uploaded to `snake/data/`. 96 | The `snake.ipynb` notebook loads this data and makes the plots from the paper. 97 | 98 | 99 | ## Citation 100 | 101 | For attribution in academic contexts, please use the following citation. 
102 | 103 | ``` 104 | @misc{bonnet2022debiasing, 105 | title = {Debiasing Meta-Gradient Reinforcement Learning by Learning the Outer Value Function}, 106 | author = {Bonnet, Clément and Midgley, Laurence and Laterre, Alexandre}, 107 | doi = {10.48550/ARXIV.2211.10550}, 108 | url = {https://arxiv.org/abs/2211.10550}, 109 | year = {2022}, 110 | booktitle={Sixth Workshop on Meta-Learning at the Conference on Neural Information Processing Systems}, 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /commitlint.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | extends: ['@commitlint/config-conventional'], 3 | rules: { 4 | 'subject-case': [1, 'always'], 5 | 'header-max-length': [1, 'always', 100], 6 | 'type-enum': [ 7 | 2, 8 | 'always', 9 | [ 10 | 'build', 11 | 'chore', 12 | 'ci', 13 | 'docs', 14 | 'feat', 15 | 'fix', 16 | 'perf', 17 | 'refactor', 18 | 'revert', 19 | 'style', 20 | 'test', 21 | 'exp', 22 | 'func', 23 | ], 24 | ], 25 | }, 26 | } 27 | -------------------------------------------------------------------------------- /discounting_chain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/discounting_chain/__init__.py -------------------------------------------------------------------------------- /discounting_chain/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple 2 | 3 | import chex 4 | from acme import specs 5 | from acme.jax.networks import Action, Observation 6 | from dm_env import TimeStep 7 | from jax import numpy as jnp 8 | from typing_extensions import Protocol 9 | 10 | ArrayTree = Any 11 | AgentState = ArrayTree 12 | EnvironmentState = ArrayTree 13 | 
DataGeneratorState = ArrayTree 14 | BufferState = ArrayTree 15 | Extras = ArrayTree 16 | Batch = ArrayTree 17 | 18 | Metrics = Dict[str, jnp.ndarray] 19 | SelectAction = Callable[[AgentState, Observation, chex.PRNGKey], Tuple[Action, Extras]] 20 | EvaluateFn = Callable[[AgentState, chex.PRNGKey], Metrics] 21 | 22 | 23 | class DataGenerationInfo(NamedTuple): 24 | """Container for the information returned by data generation.""" 25 | 26 | metrics: Metrics 27 | num_steps: jnp.int_ 28 | num_episodes: jnp.int_ 29 | 30 | 31 | class EnvLoopState(NamedTuple): 32 | """State of the environment loop.""" 33 | 34 | data_generator: DataGeneratorState 35 | buffer: BufferState 36 | agent: AgentState 37 | num_steps: jnp.int_ 38 | num_episodes: jnp.int_ 39 | num_iterations: jnp.int_ 40 | 41 | 42 | class GenerateDataFn(Protocol): 43 | """Callable specification for generation of experience. 44 | 45 | This typically contains multi-step, batched acting within a jittable environment.""" 46 | 47 | def __call__( 48 | self, agent_state: AgentState, data_generator_state: DataGeneratorState 49 | ) -> Tuple[Batch, DataGeneratorState, DataGenerationInfo]: 50 | """ 51 | The `generate_data` function. 52 | 53 | Args: 54 | agent_state: State of the agent itself (i.e. trainable parameters), 55 | used for generating actions. 56 | data_generator_state: State used during acting for management of the environment 57 | (e.g. keeping track of environment state). 58 | 59 | Returns: 60 | batch: Batch of data that will be passed to the agent update. 61 | data_generator_state: Updated data generator state. 62 | data_generation_info: Metrics from the data generation (information on acting). 
63 | """ 64 | 65 | 66 | class DataGenerator(NamedTuple): 67 | """Container specifying pure functions for generation of experience data.""" 68 | 69 | init: Callable[[chex.PRNGKey], DataGeneratorState] 70 | generate_data: GenerateDataFn 71 | 72 | 73 | class Buffer(NamedTuple): 74 | """Buffer for managing the storage of experience generated by the actor, 75 | and the sampling of experience for updating the agent's parameters. 76 | 77 | The `sample` and `add` functions are both jittable.""" 78 | 79 | add: Callable[[BufferState, Batch], BufferState] 80 | sample: Callable[[BufferState], Tuple[Batch, BufferState]] 81 | init: Callable[[chex.PRNGKey], BufferState] 82 | 83 | 84 | class Agent(NamedTuple): 85 | """Pure functions defining the key components of the agent. 86 | 87 | Both the update function (which updates the agent's learnt parameters), and the 88 | select action function (which selects actions within the environment) are jittable. 89 | If it is desired to use multiple devices, then the update function should have the 90 | appropriately placed `jax.lax.pmean` to aggregate gradients across devices. 91 | """ 92 | 93 | init: Callable[[chex.PRNGKey], AgentState] 94 | select_action: SelectAction 95 | update: Callable[[AgentState, Batch], Tuple[AgentState, Metrics]] 96 | # `select_action` will be used for evaluation if `select_action_eval` is None 97 | select_action_eval: Optional[SelectAction] = None 98 | 99 | 100 | class OnlineAgent(NamedTuple): 101 | init: Callable[[chex.PRNGKey], AgentState] 102 | select_action: SelectAction 103 | update: Callable[[AgentState], Tuple[AgentState, Metrics]] 104 | select_action_eval: Optional[SelectAction] = None 105 | 106 | 107 | class StepEnvironment(Protocol): 108 | """ 109 | Environment transition definition. 110 | 111 | Note: The environment step function automatically resets the environment when an 112 | episode is completed. 
This allows for seamless, jitted n-step acting, however it requires some 113 | care when setting up the environment step function. Therefore, if you are creating an instance 114 | of this function, please refer to the details of the `__call__` function below (specifically the 115 | docstring of the returned objects), as well as the example environment in 116 | `ridl.functional.testing.toy_env`. 117 | 118 | """ 119 | 120 | def __call__( 121 | self, state: EnvironmentState, action: Action 122 | ) -> Tuple[EnvironmentState, TimeStep, Metrics]: 123 | """ 124 | Step the environment to the next state in the MDP with automatic resetting of 125 | terminated episodes. 126 | 127 | Args: 128 | state: State of the environment at the current timestep. 129 | action: Action taken at the current timestep. 130 | 131 | Returns: 132 | state: State of the environment at the next state (if the episode is 133 | still underway), or the state from a reset (if the episode terminates). 134 | timestep: Timestep containing information on the transition. 135 | If the episode terminates, then the fields of the timestep all correspond to the 136 | terminal timestep EXCEPT for the observation which corresponds to the observation 137 | from the auto-reset, as this is required for the next acting state, and the 138 | terminal observation is not needed. 139 | metrics: Relevant information for logging. 140 | """ 141 | 142 | 143 | class Environment(NamedTuple): 144 | """Minimal container which stores the key components of a jax environment, which has a jittable 145 | step function. 146 | 147 | Note: 148 | - This container is used within the data generator to generate experience, and is 149 | therefore not explicitly passed to the environment loop. 150 | - The `step` function is assumed to automatically reset the environment if an 151 | episode terminates, please refer to the `StepEnvironment` for details. 
152 | """ 153 | 154 | init: Callable[[chex.PRNGKey], Tuple[EnvironmentState, TimeStep]] 155 | step: StepEnvironment 156 | spec: specs.EnvironmentSpec 157 | -------------------------------------------------------------------------------- /discounting_chain/data/discounting_chain_appendix_histories_array.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/discounting_chain/data/discounting_chain_appendix_histories_array.pickle -------------------------------------------------------------------------------- /discounting_chain/data/discounting_chain_histories_array.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/discounting_chain/data/discounting_chain_histories_array.pickle -------------------------------------------------------------------------------- /discounting_chain/data_generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/discounting_chain/data_generation/__init__.py -------------------------------------------------------------------------------- /discounting_chain/data_generation/n_step.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, NamedTuple, Tuple 2 | 3 | import chex 4 | import jax 5 | import typing_extensions 6 | from acme.types import Transition 7 | from dm_env import TimeStep 8 | from jax import numpy as jnp 9 | 10 | from discounting_chain.base import ( 11 | AgentState, 12 | EnvironmentState, 13 | Metrics, 14 | SelectAction, 15 | StepEnvironment, 16 | ) 17 | 18 | 19 | class NStepGeneratorState(NamedTuple): 20 | 
"""Container used within the `run_n_step` function for managing objects related to the 21 | generation of data through acting in the environment.""" 22 | 23 | environment_state: EnvironmentState 24 | previous_timestep: TimeStep 25 | rng_key: chex.PRNGKey 26 | 27 | 28 | class AccumulateBatchMetricsFn(typing_extensions.Protocol): 29 | """Accumulate metrics over a batch of n-step transitions.""" 30 | 31 | def __call__(self, batch: Transition, metrics: Metrics) -> Metrics: 32 | """ 33 | The function for accumulating the metrics. 34 | 35 | Args: 36 | batch: Batch of transition data with leading axis of (batch_size, n_step). 37 | metrics: Batch of n_step metrics with leading axis of (batch_size, n_step). 38 | 39 | Returns: 40 | metrics: Relevant information for logging. 41 | """ 42 | 43 | 44 | def _run_n_step( 45 | n_step: int, 46 | select_action: SelectAction, 47 | step_environment: StepEnvironment, 48 | agent_state: AgentState, 49 | data_generator_state: NStepGeneratorState, 50 | ) -> Tuple[Transition, NStepGeneratorState, Metrics]: 51 | """ 52 | Run n environment steps (unbatched). 53 | 54 | Args: 55 | n_step: Number of environment steps. 56 | select_action: Function for action selection by the agent. 57 | step_environment: Function for stepping the environment. 58 | agent_state: State of the agent. 59 | data_generator_state: State for managing the environment interaction. 60 | 61 | Returns: 62 | transition: N-steps of experience 63 | data_generator_state: Updated data generator state 64 | metrics: Relevant information for logging. 65 | 66 | This function may be called with vmap for batching (see `make_run_n_step`). 
def accumulate_batch_metrics(batch: Transition, metrics: Metrics) -> Metrics:
    """Accumulate metrics over the batch of n-steps.

    This function assumes that `metrics` uses NaNs to mask values, when an environment step does
    not result in a metric that should be recorded. For example, because we only want to log
    episode returns for completed episodes, values within the batch of n-steps are typically NaN
    for all transitions where the episode is not yet complete.

    Args:
        batch: Batch of transitions; only its `discount` field is read here.
        metrics: Dictionary mapping metric names to arrays of per-step values (NaN-masked).

    Returns:
        A dictionary containing `num_episodes` plus, for every key in `metrics`, a
        `<key>_mean` and a `<key>_max` entry reduced over all non-NaN values.
    """
    # We calculate the number of complete episodes using the discount factor:
    # a discount of exactly 0 marks a transition that ended an episode.
    num_episodes = jnp.sum(batch.discount == 0)
    accumulated_metrics = {"num_episodes": num_episodes}

    # Record the max and mean (masking NaN values) for metrics returned by the environment.
    # `jax.tree_util.tree_map` is used instead of the deprecated `jax.tree_map` alias so the
    # function keeps working on JAX releases where the alias has been removed.
    mean_metrics = jax.tree_util.tree_map(jnp.nanmean, metrics)
    max_metrics = jax.tree_util.tree_map(jnp.nanmax, metrics)
    accumulated_metrics.update(
        {key + "_mean": val for key, val in mean_metrics.items()}
    )
    accumulated_metrics.update({key + "_max": val for key, val in max_metrics.items()})
    return accumulated_metrics
164 | data_generator_state: Updated data_generator_state. 165 | metrics: Metrics accumulated during the generation of experience. 166 | 167 | """ 168 | trajectory_batch, data_generator_state, episode_metrics = jax.vmap( 169 | _run_n_step, in_axes=(None, None, None, None, 0) 170 | )(n_step, select_action, step_environment, agent_state, data_generator_state) 171 | metrics = accumulate_batch_metrics_fn(trajectory_batch, episode_metrics) 172 | return trajectory_batch, data_generator_state, metrics 173 | 174 | return run_n_step 175 | -------------------------------------------------------------------------------- /discounting_chain/data_generation/n_step_data_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import chex 4 | import jax 5 | from acme.types import Transition 6 | from jax import numpy as jnp 7 | 8 | from discounting_chain.base import ( 9 | AgentState, 10 | DataGenerationInfo, 11 | DataGenerator, 12 | Environment, 13 | SelectAction, 14 | ) 15 | from discounting_chain.data_generation.n_step import ( 16 | NStepGeneratorState, 17 | make_run_n_step, 18 | ) 19 | 20 | 21 | def make_n_step_data_generator( 22 | select_action: SelectAction, 23 | environment: Environment, 24 | n_step: int, 25 | batch_size_per_device: int, 26 | ) -> DataGenerator: 27 | """ 28 | Create a data generator that runs batches of n-step interaction with an environment, 29 | using the `run_n_step` function. 30 | 31 | Args: 32 | select_action: Callable specifying action selection by the agent. 33 | environment: Environment container specifying the environment state initialisation and 34 | step functions. 35 | n_step: Numer of environment steps per batch. 36 | batch_size_per_device: Batch size that we vmap over in the `run_n_step` function. 37 | We typically run the `generate_data` method across multiple devices, in which case the 38 | total batch size will be equal to batch_size_per_device*num_devices. 
def create_fake_n_step_batch(
    environment_spec: EnvironmentSpec,
    batch_size: int,
    n_step: int,
    fake_extras: chex.ArrayTree = (),
) -> Transition:
    """
    Create a fake batch of data that matches what is returned by the n-step generator.
    This is useful for initialisation of the buffer, which has a state initialisation that
    depends on the batch it receives from experience generation.

    Args:
        environment_spec: Environment spec. Each field in the `environment_spec` has to implement
            the `generate_value` method for this function to work.
        batch_size: Size of the first leading axis added to every leaf.
        n_step: Size of the second leading axis added to every leaf.
        fake_extras: An object matching the shapes and types of Extras returned by the
            `select_action` method of an agent.

    Returns:
        batch: Batch of data matching what the n-step generator returns, i.e. every leaf has
            shape `(batch_size, n_step, *leaf_shape)`.
    """
    # Build a single un-batched transition from the spec's placeholder values.
    transition = Transition(
        observation=environment_spec.observations.generate_value(),
        action=environment_spec.actions.generate_value(),
        reward=environment_spec.rewards.generate_value(),
        next_observation=environment_spec.observations.generate_value(),
        discount=environment_spec.discounts.generate_value(),
        extras=fake_extras,
    )
    # Broadcast every leaf to the batched n-step shape. `jax.tree_util.tree_map` replaces the
    # deprecated `jax.tree_map` alias, which newer JAX releases no longer provide.
    batch = jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, shape=(batch_size, n_step, *x.shape)), transition
    )
    return batch
additionally keep track of the step count 13 | within episodes, and the cumulative reward.""" 14 | 15 | original_env_state: chex.ArrayTree 16 | episode_step_count: jnp.int32 17 | cumulative_reward: jnp.float32 18 | 19 | 20 | def wrap_env_for_episode_metrics(env: Environment) -> Environment: 21 | """Wrapper for environment that records episode length and return, and adds them to the 22 | environment's metrics. NaN values are used for episode length & return in the metrics for 23 | non-terminal timesteps.""" 24 | 25 | def accumulate_return_and_step_count( 26 | cumulative_reward: jnp.float32, step_count: jnp.int32, timestep: TimeStep 27 | ) -> Tuple[chex.Array, chex.Array]: 28 | """Update the episode's step count and cumulative return, resetting them to 0 if the episode 29 | is completed.""" 30 | step_count = jax.lax.select( 31 | timestep.last(), jnp.array(0, dtype=int), step_count + 1 32 | ) 33 | cumulative_reward = jax.lax.select( 34 | timestep.last(), 35 | jnp.array(0.0, dtype=float), 36 | cumulative_reward + timestep.reward, 37 | ) 38 | return step_count, cumulative_reward 39 | 40 | def get_episode_metrics( 41 | previous_state: EnvStateWithEpisodeMetrics, timestep: TimeStep 42 | ) -> Metrics: 43 | """Get the episode metrics for the current timestep, if an episode is complete we return the 44 | episode return and length, otherwise we return NaN's to indicate the episode is still 45 | underway.""" 46 | episode_return = jax.lax.select( 47 | timestep.last(), 48 | previous_state.cumulative_reward + timestep.reward, 49 | jnp.nan, 50 | ) 51 | episode_length = jax.lax.select( 52 | timestep.last(), 53 | jnp.array(previous_state.episode_step_count + 1, dtype=float), 54 | jnp.nan, 55 | ) 56 | metrics = {"episode_return": episode_return, "episode_length": episode_length} 57 | return metrics 58 | 59 | def init(rng_key: chex.PRNGKey) -> Tuple[EnvStateWithEpisodeMetrics, TimeStep]: 60 | """Initialise the environment's state.""" 61 | original_env_state, timestep = 
def create_dc_gmnax(mapping_seed: int = 3) -> Tuple[Environment, Callable]:
    """Create the gymnax DiscountingChain environment and a matching true-value function.

    Args:
        mapping_seed: Forwarded to the `DiscountingChain` constructor.

    Returns:
        env: The environment, wrapped so that episode return and length are added to the
            metrics returned by `step`.
        get_true_value: Function `(obs, probs, gamma) -> value` built from the environment's
            reward table and reward timesteps; its output carries no gradient.
    """
    gymnax_env = DiscountingChain(mapping_seed=mapping_seed)
    env_params = gymnax_env.default_params

    def get_true_value_non_first_step(obs, probs, gamma):
        """Value for an observation after the first step (`probs` is unused here; it is
        only accepted so both `jax.lax.cond` branches share one signature)."""
        gamma = jax.lax.stop_gradient(gamma)
        context, time = obs
        # NOTE(review): assumes obs[1] is the step count normalised by
        # `max_steps_in_episode` - confirm against the gymnax observation definition.
        time = time * env_params.max_steps_in_episode + 1  # unnormalise time
        reward_step = env_params.reward_timestep[jnp.int_(context)]
        time_till_reward = reward_step - time
        reward_already_received = time_till_reward < 0
        reward = gymnax_env.reward[jnp.int_(context)]
        value = gamma ** time_till_reward * reward
        # Once the reward timestep has passed there is nothing left to collect.
        value = jax.lax.select(reward_already_received, 0.0, value)
        return value

    def get_true_value_first_step_single_action(action, gamma):
        """Value of taking `action` at time 0."""
        time = 0
        reward_step = env_params.reward_timestep[jnp.int_(action)]
        time_till_reward = reward_step - time
        reward_already_received = time_till_reward < 0
        reward = gymnax_env.reward[jnp.int_(action)]
        value = gamma ** time_till_reward * reward
        value = jax.lax.select(reward_already_received, 0.0, value)
        return value

    def get_true_value_first_step(obs, probs, gamma):
        """Expected first-step value: per-action values weighted by the policy `probs`."""
        values = jax.vmap(get_true_value_first_step_single_action, in_axes=(0, None))(
            jnp.arange(probs.shape[0]), gamma
        )
        chex.assert_equal_shape([probs, values])
        return jnp.sum(probs * values)

    def get_true_value(obs, probs, gamma):
        """Dispatch on whether the observation's time component is 0."""
        _, time = obs
        value = jax.lax.cond(
            time == 0,
            get_true_value_first_step,
            get_true_value_non_first_step,
            obs,
            probs,
            gamma,
        )
        return jax.lax.stop_gradient(value)

    def init(key: chex.PRNGKey) -> Tuple[State, TimeStep]:
        """Reset the gymnax environment and build the initial timestep."""
        key, subkey = jax.random.split(key)
        obs, state = gymnax_env.reset(subkey, env_params)
        timestep = TimeStep(
            step_type=jnp.int_(0),
            reward=jnp.float32(0.0),
            discount=jnp.float32(0.0),
            observation=obs,
        )
        state = State(state, key)
        return state, timestep

    def auto_reset(state: State, timestep: TimeStep) -> Tuple[State, TimeStep]:
        """Start a new episode, keeping the terminal timestep's reward/discount/step type."""
        key, subkey = jax.random.split(state.key)
        # Pass `env_params` explicitly so the reset uses the same parameters as `init`.
        obs, state = gymnax_env.reset(subkey, env_params)
        # Replace observation with reset observation.
        timestep = timestep._replace(observation=obs)
        state = State(state, key)
        return state, timestep

    def step(state: State, action: Action) -> Tuple[State, TimeStep, Metrics]:
        """Step the environment, resetting automatically when the episode ends."""
        key, subkey = jax.random.split(state.key)
        n_obs, n_state, reward, done, info = gymnax_env.step_env(
            subkey, state.gymnax_state, action, env_params
        )
        timestep = TimeStep(
            # 2 when the episode is done, otherwise 1.
            step_type=jax.lax.select(done, jnp.int_(2), jnp.int_(1)),
            reward=jnp.float32(reward),
            discount=jnp.float32(1 - done),
            observation=n_obs,
        )
        state = State(n_state, key)

        state, timestep = jax.lax.cond(
            timestep.last(),
            auto_reset,
            lambda new_state, timestep: (new_state, timestep),
            state,
            timestep,
        )
        metrics = info
        return state, timestep, metrics

    gymnax_obs_spec = gymnax_env.observation_space(env_params)
    gymnax_action_spec = gymnax_env.action_space(env_params)

    # Reward and discount specs are left unset.
    spec = specs.EnvironmentSpec(
        observations=specs.Array(
            shape=gymnax_obs_spec.shape, dtype=gymnax_obs_spec.dtype
        ),
        actions=specs.DiscreteArray(
            num_values=gymnax_action_spec.n,
        ),
        rewards=None,
        discounts=None,
    )

    env = Environment(step=step, init=init, spec=spec)
    env = wrap_env_for_episode_metrics(env)
    return env, get_true_value
def create_gymnax_env(env_name="CartPole-v1", **kwargs) -> Environment:
    """Wrap a gymnax environment in the project's `Environment` container.

    Args:
        env_name: Name of the environment. "DiscountingChain-bsuite" and
            "UmbrellaChain-bsuite" are constructed directly; any other name is passed to
            `gymnax.make`.
        **kwargs: For "DiscountingChain-bsuite", forwarded to the `DiscountingChain`
            constructor. For "UmbrellaChain-bsuite", `chain_length` is read when any keyword
            argument is given (defaulting to 20 when none are). Ignored otherwise.

    Returns:
        The environment, wrapped so that episode return and length are added to the metrics
        returned by `step`.
    """
    gymnax_env: GymnaxEnv
    if env_name == "DiscountingChain-bsuite":
        from gymnax.environments.bsuite.discounting_chain import DiscountingChain

        gymnax_env = DiscountingChain(**kwargs)
        env_params = gymnax_env.default_params
    elif env_name == "UmbrellaChain-bsuite":
        from gymnax.environments.bsuite.umbrella_chain import EnvParams, UmbrellaChain

        gymnax_env = UmbrellaChain(n_distractor=0)
        chain_length = 20 if not kwargs else kwargs["chain_length"]
        # Both positional `EnvParams` fields receive the chain length.
        # NOTE(review): presumably (chain_length, max_steps_in_episode) - confirm in gymnax.
        env_params = EnvParams(chain_length, chain_length)
    else:
        gymnax_env, env_params = gymnax.make(env_name)

    def init(key: chex.PRNGKey) -> Tuple[State, TimeStep]:
        """Reset the gymnax environment and build the initial timestep."""
        key, subkey = jax.random.split(key)
        obs, state = gymnax_env.reset(subkey, env_params)
        timestep = TimeStep(
            step_type=jnp.int_(0),
            reward=jnp.float32(0.0),
            discount=jnp.float32(0.0),
            observation=obs,
        )
        state = State(state, key)
        return state, timestep

    def auto_reset(state: State, timestep: TimeStep) -> Tuple[State, TimeStep]:
        """Start a new episode, keeping the terminal timestep's reward/discount/step type."""
        key, subkey = jax.random.split(state.key)
        # Reset with the same `env_params` used by `init` and `step`; omitting them would
        # fall back to the environment's defaults and ignore e.g. a custom chain length.
        obs, state = gymnax_env.reset(subkey, env_params)
        # Only the observation is swapped for the first observation of the new episode.
        timestep = timestep._replace(observation=obs)
        state = State(state, key)
        return state, timestep

    def step(state: State, action: Action) -> Tuple[State, TimeStep, Metrics]:
        """Step the environment, resetting automatically when the episode ends."""
        key, subkey = jax.random.split(state.key)
        n_obs, n_state, reward, done, info = gymnax_env.step_env(
            subkey, state.gymnax_state, action, env_params
        )
        timestep = TimeStep(
            # 2 when the episode is done, otherwise 1.
            step_type=jax.lax.select(done, jnp.int_(2), jnp.int_(1)),
            reward=jnp.float32(reward),
            discount=jnp.float32(1 - done),
            observation=n_obs,
        )
        state = State(n_state, key)

        state, timestep = jax.lax.cond(
            timestep.last(),
            auto_reset,
            lambda new_state, timestep: (new_state, timestep),
            state,
            timestep,
        )
        metrics = info
        return state, timestep, metrics

    gymnax_obs_spec = gymnax_env.observation_space(env_params)
    gymnax_action_spec = gymnax_env.action_space(env_params)

    # Reward and discount specs are left unset.
    spec = specs.EnvironmentSpec(
        observations=specs.Array(
            shape=gymnax_obs_spec.shape, dtype=gymnax_obs_spec.dtype
        ),
        actions=specs.DiscreteArray(
            num_values=gymnax_action_spec.n,
        ),
        rewards=None,
        discounts=None,
    )

    env = Environment(step=step, init=init, spec=spec)
    env = wrap_env_for_episode_metrics(env)
    return env
def create_linear_forward_fn(
    environment_spec: EnvironmentSpec, double_value_head: bool = False
) -> hk.Transformed:
    """Create a minimal forward function where the actor is a single linear layer.

    The critic of this network is a dummy (its output is multiplied by NaN) and is not
    intended to be used.

    Args:
        environment_spec: Environment spec; `actions.num_values` sets the policy head size.
        double_value_head: If True the value output is a pair `(value, value)`, otherwise a
            single array.

    Returns:
        The transformed forward function (without an rng argument on `apply`), mapping an
        observation to `(policy, value)`.
    """

    def forward_fn(
        obs: jnp.ndarray,
    ) -> Tuple[tfp.distributions.Distribution, chex.ArrayTree]:
        actor_network = jax_networks.CategoricalHead(
            num_values=environment_spec.actions.num_values,
            w_init=hk.initializers.VarianceScaling(1e-4),
        )
        policy = actor_network(obs)
        # value function is a dummy if using linear forward fn
        val = hk.Linear(1, w_init=hk.initializers.VarianceScaling(1e-4))(obs) * jnp.nan
        # Parentheses are required: without them the expression parses as
        # `(val, (val if double_value_head else val))` and a tuple is always returned.
        value = (val, val) if double_value_head else val
        return policy, value

    return hk.without_apply_rng(hk.transform(forward_fn))
= (32, 32), 64 | critic_hidden_layers: Tuple[int, ...] = (64, 64), 65 | meta_critic_hidden_layers: Optional = None, 66 | torso_mlp_units=(32, 32), 67 | conv_strides=(2, 1), 68 | double_value_head: bool = False, 69 | ) -> hk.Transformed: 70 | assert len(environment_spec.observations.shape) in [1, 3] 71 | 72 | torso = ( 73 | lambda: MiniAtariTorso(conv_strides) 74 | if len(environment_spec.observations.shape) == 3 75 | else hk.nets.MLP(torso_mlp_units) 76 | ) 77 | 78 | def forward_fn( 79 | obs: jnp.ndarray, 80 | ) -> Tuple[tfp.distributions.Distribution, chex.ArrayTree]: 81 | enc_crit = torso()(obs) 82 | enc_crit_torso = enc_crit 83 | critic_network: hk.Module = hk.nets.MLP( 84 | critic_hidden_layers, activate_final=True 85 | ) 86 | enc_crit = critic_network(enc_crit) 87 | value = hk.Linear(1, w_init=hk.initializers.VarianceScaling(1e-4))(enc_crit) 88 | 89 | enc_act = torso()(obs) 90 | actor_network: hk.Module = hk.Sequential( 91 | [ 92 | hk.nets.MLP(actor_hidden_layers, activate_final=True), 93 | jax_networks.CategoricalHead( 94 | num_values=environment_spec.actions.num_values 95 | ), 96 | ], 97 | name="actor", 98 | ) 99 | policy = actor_network(enc_act) 100 | value = jnp.squeeze(value, axis=-1) 101 | if double_value_head: 102 | if meta_critic_hidden_layers: 103 | critic_network_2 = hk.nets.MLP( 104 | meta_critic_hidden_layers, activate_final=True 105 | ) 106 | value_2 = critic_network_2(jax.lax.stop_gradient(enc_crit_torso)) 107 | else: 108 | value_2 = jax.lax.stop_gradient(enc_crit) 109 | value_2 = hk.Linear(1, w_init=hk.initializers.VarianceScaling(1e-4))( 110 | value_2 111 | ) 112 | value_2 = jnp.squeeze(value_2, axis=-1) 113 | value = value, value_2 114 | return policy, value 115 | 116 | return hk.without_apply_rng(hk.transform(forward_fn)) 117 | -------------------------------------------------------------------------------- /discounting_chain/nets_split.py: -------------------------------------------------------------------------------- 1 | from typing import 
class MiniAtariTorso(jax_networks.base.Module):
    """Small convolutional torso modelled on `acme.jax.networks.AtariTorso`, intended for
    small images such as Snake from Jumanji."""

    def __init__(self) -> None:
        super().__init__(name="atari_torso")
        # Two 2x2 convolutions (strides 2 then 1), each followed by a ReLU.
        conv_stack = [
            hk.Conv2D(32, [2, 2], 2),
            jax.nn.relu,
            hk.Conv2D(32, [2, 2], 1),
            jax.nn.relu,
        ]
        self._network = hk.Sequential(conv_stack)

    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Encode a rank-3 (HWC) or rank-4 (BHWC) input and flatten the result."""
        rank = jnp.ndim(inputs)
        if not 3 <= rank <= 4:
            raise ValueError("Expected input BHWC or HWC. Got rank %d" % rank)

        features = self._network(inputs)

        # Keep the leading batch axis only when the input carried one.
        flat_shape = [features.shape[0], -1] if rank == 4 else [-1]  # [B, D] or [D]
        return jnp.reshape(features, flat_shape)
def outer_iter(state, update_fn, n_updates: int):
    """Apply `update_fn` to `state` `n_updates` times and return the final metrics.

    Args:
        state: Initial state passed to the first call of `update_fn`.
        update_fn: Function mapping a state to `(new_state, metrics)`.
        n_updates: Number of consecutive updates to run inside a single `jax.lax.scan`.

    Returns:
        state: State after the final update.
        metrics: Metrics produced by the final update only.
    """
    state, metrics = jax.lax.scan(
        lambda state, xs: update_fn(state), init=state, xs=None, length=n_updates
    )
    # `scan` stacks the per-update metrics along a new leading axis; keep only the entry from
    # the last update. `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    metrics = jax.tree_util.tree_map(lambda x: x[-1], metrics)
    return state, metrics
ignore_missing_imports = True 35 | 36 | [mypy-optax.*] 37 | ignore_missing_imports = True 38 | -------------------------------------------------------------------------------- /plots/discounting_chain/dc_chain_discount_factor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_discount_factor.png -------------------------------------------------------------------------------- /plots/discounting_chain/dc_chain_discount_factor_appendix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_discount_factor_appendix.png -------------------------------------------------------------------------------- /plots/discounting_chain/dc_chain_outer_loss_advantage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_outer_loss_advantage.png -------------------------------------------------------------------------------- /plots/discounting_chain/dc_chain_outer_loss_advantage_appendix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_outer_loss_advantage_appendix.png -------------------------------------------------------------------------------- /plots/discounting_chain/dc_chain_return.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_return.png -------------------------------------------------------------------------------- /plots/discounting_chain/dc_chain_return_appendix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_return_appendix.png -------------------------------------------------------------------------------- /plots/snake/snake_discount_factor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_discount_factor.png -------------------------------------------------------------------------------- /plots/snake/snake_discount_factor_appendix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_discount_factor_appendix.png -------------------------------------------------------------------------------- /plots/snake/snake_outer_loss_advantage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_outer_loss_advantage.png -------------------------------------------------------------------------------- /plots/snake/snake_outer_loss_advantage_appendix.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_outer_loss_advantage_appendix.png -------------------------------------------------------------------------------- /plots/snake/snake_return.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_return.png -------------------------------------------------------------------------------- /plots/snake/snake_return_appendix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_return_appendix.png -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # All of these dependencies are pinned to specific versions to achieve reproducible builds (CI/Dev) 2 | black==22.10.0 3 | coverage==6.4.4 4 | flake8==5.0.4 5 | isort==5.10.1 6 | mypy==0.971 7 | pre-commit==2.20.0 8 | promise==2.3 9 | pytest==7.1.3 10 | pytest-cov==3.0.0 11 | pytest-mock==3.8.2 12 | pytest-parallel==0.1.1 13 | pytest-xdist==2.5.0 14 | pytype==2022.9.8 15 | testfixtures==7.0.0 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | chex==0.1.5 2 | distrax==0.1.2 3 | dm-haiku==0.0.7 4 | gymnax==0.0.5 5 | hydra-core==1.2.0 6 | jax==0.3.22 7 | jaxlib==0.3.22 8 | jumanji==0.1.3 9 | jupyter==1.0.0 10 | matplotlib==3.6.2 11 | neptune-client==0.16.7 12 | numpy==1.22.4 13 | omegaconf==2.2.3 14 | optax==0.1.3 15 | rlax==0.1.4 16 | scipy==1.9.1 17 | tqdm==4.64.1 18 | 
def compute_td_lambda(
    discount: chex.Array,
    rewards: chex.Array,
    values: chex.Array,
    bootstrap_value: chex.Array,
    lambda_: float,
    discount_factor: float,
) -> Tuple[chex.Array, chex.Array]:
    """Compute TD(lambda) value targets and advantages via `rlax.td_lambda`.

    `rlax.td_lambda` is vmapped over axis 1 of every array argument, so each
    slice along that axis is processed independently. Axis 0 is presumably
    time and axis 1 the batch of environments (TODO confirm against callers).

    Args:
        discount: per-step discount (presumably 0 where an episode ends --
            verify against the caller); multiplied by `discount_factor`.
        rewards: per-step rewards.
        values: value estimates for each step.
        bootstrap_value: value estimate appended after the last entry of
            `values` (along axis 0) to form the successor values.
        lambda_: mixing parameter forwarded to `rlax.td_lambda`.
        discount_factor: multiplier applied to `discount`.

    Returns:
        A `(vs, advantages)` pair, where `advantages` is the output of the
        vmapped `rlax.td_lambda` and `vs` is `advantages + values`.
    """
    # Successor values: drop the first entry along axis 0 and append the
    # bootstrap value at the end.
    bootstrap_row = jnp.expand_dims(bootstrap_value, 0)
    next_values = jnp.concatenate([values[1:], bootstrap_row], axis=0)

    # Targets are explicitly NOT gradient-stopped (stop_target_gradients=False).
    td_lambda_fn = functools.partial(
        rlax.td_lambda, lambda_=lambda_, stop_target_gradients=False
    )
    batched_td_lambda = jax.vmap(td_lambda_fn, in_axes=1, out_axes=1)

    advantages = batched_td_lambda(
        values,
        rewards,
        discount * discount_factor,
        next_values,
    )
    return advantages + values, advantages
l_td_metal: null 8 | l_en_metal: null 9 | gamma_outer: null 10 | lambda_outer: null 11 | l_pg_outer: null 12 | l_td_outer: null 13 | l_en_outer: null 14 | l_kl_outer: null 15 | bootstrap_l: null 16 | bootstrap_pm: null 17 | bootstrap_vm: null 18 | meta_optimizer: null 19 | meta_learning_rate: null 20 | meta_gradient_clip_norm: null 21 | bootstrap_l_optimizer: null 22 | bootstrap_l_learning_rate: null 23 | -------------------------------------------------------------------------------- /snake/configs/agent/bootstrap.yaml: -------------------------------------------------------------------------------- 1 | name: bootstrap 2 | agent: meta_a2c 3 | meta_objective: bootstrap 4 | gamma_metal: true 5 | lambda_metal: false 6 | l_pg_metal: false 7 | l_td_metal: false 8 | l_en_metal: false 9 | gamma_outer: 1.0 10 | lambda_outer: 1.0 11 | l_pg_outer: 1.0 12 | l_td_outer: 0 13 | l_en_outer: 0 14 | l_kl_outer: 0 15 | bootstrap_l: 1 16 | bootstrap_pm: 1 17 | bootstrap_vm: 0.0 18 | meta_optimizer: adam # [sgd, rmsprop, adam] 19 | meta_learning_rate: 6e-3 20 | meta_gradient_clip_norm: 0.1 # [null, ] 21 | bootstrap_l_optimizer: rmsprop 22 | bootstrap_l_learning_rate: ${training.learning_rate} 23 | -------------------------------------------------------------------------------- /snake/configs/agent/mgrl.yaml: -------------------------------------------------------------------------------- 1 | name: mgrl 2 | agent: meta_a2c 3 | meta_objective: meta_gradient 4 | gamma_metal: true 5 | lambda_metal: false 6 | l_pg_metal: false 7 | l_td_metal: false 8 | l_en_metal: false 9 | gamma_outer: 1.0 10 | lambda_outer: 1.0 11 | l_pg_outer: 1.0 12 | l_td_outer: 0 13 | l_en_outer: 0 14 | l_kl_outer: 0 15 | bootstrap_l: null 16 | bootstrap_pm: null 17 | bootstrap_vm: null 18 | meta_optimizer: adam # [sgd, rmsprop, adam] 19 | meta_learning_rate: 3e-3 20 | meta_gradient_clip_norm: 0.1 # [null, ] 21 | bootstrap_l_optimizer: null 22 | bootstrap_l_learning_rate: null 23 | 
-------------------------------------------------------------------------------- /snake/configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - agent: mgrl # [a2c, mgrl, bootstrap] 3 | - _self_ 4 | 5 | training: 6 | name: appendix_norm_advantages_snake_${agent.name}_outer_critic_${agent.outer_critic}_normalize_advantage_${agent.normalize_advantage}_normalize_outer_advantage_${agent.normalize_outer_advantage} 7 | num_timesteps: 75_000_000 8 | num_eval_points: 200 9 | total_batch_size: 512 10 | total_num_envs: 512 11 | n_steps: 5 12 | reward_scaling: 1 13 | optimizer: rmsprop # [sgd, rmsprop, adam] 14 | learning_rate: 5e-4 15 | gradient_clip_norm: null # [null, ] 16 | gamma_init: 0.8 17 | lambda_init: 0.95 18 | l_pg_init: 1.0 19 | l_td_init: 0.5 20 | l_en_init: 1e-2 21 | seed: 1 22 | 23 | 24 | agent: 25 | # Common to all agents 26 | outer_critic: false # [true, false] 27 | normalize_advantage: false # [true, false] 28 | normalize_outer_advantage: false # [true, false] 29 | total_num_eval: 512 30 | deterministic_eval: false # Does not matter since now returns both metrics (determinist and stochastic) 31 | 32 | network: 33 | # Overwrite 34 | type: snake 35 | num_channels: 32 36 | policy_layers: [64, 64] # Tuned to [64, 64] out of [[64, 64], [32, 32]] 37 | value_layers: [128, 128] 38 | embedding_size_actor: null # [, null] 39 | embedding_size_critic: null # [, null] 40 | -------------------------------------------------------------------------------- /snake/data/a2c_return/MET-2150__eval_episode_reward_determinist_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667986130985,0.08984375 2 | 146.0,1667986137977,0.1484375 3 | 292.0,1667986140676,0.1640625 4 | 438.0,1667986143329,1.462890625 5 | 584.0,1667986146010,4.087890625 6 | 730.0,1667986148551,3.978515625 7 | 876.0,1667986151127,3.876953125 8 | 1022.0,1667986153788,4.275390625 9 | 
1168.0,1667986156381,4.640625 10 | 1314.0,1667986158996,4.69140625 11 | 1460.0,1667986161606,5.037109375 12 | 1606.0,1667986164201,5.029296875 13 | 1752.0,1667986166926,5.419921875 14 | 1898.0,1667986169604,5.470703125 15 | 2044.0,1667986172243,5.4296875 16 | 2190.0,1667986174887,5.783203125 17 | 2336.0,1667986177553,5.77734375 18 | 2482.0,1667986180187,5.771484375 19 | 2628.0,1667986182764,5.896484375 20 | 2774.0,1667986185322,6.35546875 21 | 2920.0,1667986187944,6.169921875 22 | 3066.0,1667986190567,6.283203125 23 | 3212.0,1667986193194,6.794921875 24 | 3358.0,1667986195773,6.794921875 25 | 3504.0,1667986198404,6.70703125 26 | 3650.0,1667986201275,7.203125 27 | 3796.0,1667986203893,7.40625 28 | 3942.0,1667986206545,7.66796875 29 | 4088.0,1667986209185,8.009765625 30 | 4234.0,1667986211820,8.2109375 31 | 4380.0,1667986214410,8.775390625 32 | 4526.0,1667986217095,9.494140625 33 | 4672.0,1667986219705,9.607421875 34 | 4818.0,1667986222317,10.248046875 35 | 4964.0,1667986224904,10.298828125 36 | 5110.0,1667986227626,10.349609375 37 | 5256.0,1667986230294,10.75 38 | 5402.0,1667986232953,10.7734375 39 | 5548.0,1667986235499,11.03125 40 | 5694.0,1667986238166,11.609375 41 | 5840.0,1667986240780,11.2734375 42 | 5986.0,1667986243398,11.123046875 43 | 6132.0,1667986245996,11.46875 44 | 6278.0,1667986248671,11.69921875 45 | 6424.0,1667986251336,12.326171875 46 | 6570.0,1667986253957,11.7890625 47 | 6716.0,1667986256607,11.935546875 48 | 6862.0,1667986259283,12.3359375 49 | 7008.0,1667986261919,12.28515625 50 | 7154.0,1667986264568,12.931640625 51 | 7300.0,1667986267282,13.208984375 52 | 7446.0,1667986269879,13.341796875 53 | 7592.0,1667986272482,13.46875 54 | 7738.0,1667986275155,13.52734375 55 | 7884.0,1667986277771,13.603515625 56 | 8030.0,1667986280433,13.591796875 57 | 8176.0,1667986283111,13.65625 58 | 8322.0,1667986285763,14.1953125 59 | 8468.0,1667986288457,13.95703125 60 | 8614.0,1667986291151,13.9765625 61 | 8760.0,1667986293792,14.0625 62 | 
8906.0,1667986296379,14.056640625 63 | 9052.0,1667986299064,14.220703125 64 | 9198.0,1667986301695,13.888671875 65 | 9344.0,1667986304350,14.255859375 66 | 9490.0,1667986307065,14.794921875 67 | 9636.0,1667986309712,14.716796875 68 | 9782.0,1667986312360,14.904296875 69 | 9928.0,1667986315018,14.818359375 70 | 10074.0,1667986317683,14.9609375 71 | 10220.0,1667986320328,14.767578125 72 | 10366.0,1667986323004,14.96875 73 | 10512.0,1667986325707,15.0078125 74 | 10658.0,1667986328283,15.298828125 75 | 10804.0,1667986330916,14.78515625 76 | 10950.0,1667986333553,15.36328125 77 | 11096.0,1667986336144,14.96875 78 | 11242.0,1667986338835,15.060546875 79 | 11388.0,1667986341441,15.337890625 80 | 11534.0,1667986344090,15.3046875 81 | 11680.0,1667986346758,15.625 82 | 11826.0,1667986349418,15.728515625 83 | 11972.0,1667986352023,15.9453125 84 | 12118.0,1667986354605,15.810546875 85 | 12264.0,1667986357232,15.771484375 86 | 12410.0,1667986359841,16.22265625 87 | 12556.0,1667986362468,15.8515625 88 | 12702.0,1667986365058,16.251953125 89 | 12848.0,1667986367773,15.80078125 90 | 12994.0,1667986370387,15.99609375 91 | 13140.0,1667986373078,15.685546875 92 | 13286.0,1667986375699,15.431640625 93 | 13432.0,1667986378359,15.736328125 94 | 13578.0,1667986380960,15.46484375 95 | 13724.0,1667986383614,15.96875 96 | 13870.0,1667986386244,16.0546875 97 | 14016.0,1667986388906,15.94921875 98 | 14162.0,1667986391517,15.9609375 99 | 14308.0,1667986394205,16.2578125 100 | 14454.0,1667986396715,15.931640625 101 | 14600.0,1667986399421,15.90234375 102 | 14746.0,1667986402022,15.87109375 103 | 14892.0,1667986404622,15.9609375 104 | 15038.0,1667986407310,16.0625 105 | 15184.0,1667986409944,16.228515625 106 | 15330.0,1667986412640,16.10546875 107 | 15476.0,1667986415244,16.322265625 108 | 15622.0,1667986417911,16.189453125 109 | 15768.0,1667986421289,16.669921875 110 | 15914.0,1667986423893,16.44140625 111 | 16060.0,1667986426458,16.13671875 112 | 16206.0,1667986429265,16.302734375 113 | 
16352.0,1667986431954,16.529296875 114 | 16498.0,1667986434571,16.525390625 115 | 16644.0,1667986437218,16.517578125 116 | 16790.0,1667986439895,16.33203125 117 | 16936.0,1667986442550,16.0625 118 | 17082.0,1667986445214,16.61328125 119 | 17228.0,1667986447862,15.880859375 120 | 17374.0,1667986450550,16.333984375 121 | 17520.0,1667986453244,16.384765625 122 | 17666.0,1667986455851,16.34375 123 | 17812.0,1667986458457,16.431640625 124 | 17958.0,1667986461001,16.95703125 125 | 18104.0,1667986463506,17.076171875 126 | 18250.0,1667986466044,16.884765625 127 | 18396.0,1667986468622,16.7109375 128 | 18542.0,1667986471195,16.484375 129 | 18688.0,1667986473826,16.580078125 130 | 18834.0,1667986476401,16.873046875 131 | 18980.0,1667986479105,16.5625 132 | 19126.0,1667986481717,17.048828125 133 | 19272.0,1667986484294,17.06640625 134 | 19418.0,1667986486875,16.8046875 135 | 19564.0,1667986489596,16.435546875 136 | 19710.0,1667986492238,16.873046875 137 | 19856.0,1667986494881,16.82421875 138 | 20002.0,1667986497668,16.43359375 139 | 20148.0,1667986500376,16.791015625 140 | 20294.0,1667986503073,16.970703125 141 | 20440.0,1667986505889,16.802734375 142 | 20586.0,1667986508597,16.8671875 143 | 20732.0,1667986511179,16.75 144 | 20878.0,1667986513734,16.87890625 145 | 21024.0,1667986516362,17.078125 146 | 21170.0,1667986519023,17.083984375 147 | 21316.0,1667986521662,17.0390625 148 | 21462.0,1667986524266,17.021484375 149 | 21608.0,1667986526935,17.125 150 | 21754.0,1667986529563,17.123046875 151 | 21900.0,1667986532175,16.818359375 152 | 22046.0,1667986534837,16.58203125 153 | 22192.0,1667986537474,17.1796875 154 | 22338.0,1667986540093,17.087890625 155 | 22484.0,1667986542738,17.26953125 156 | 22630.0,1667986545352,17.53125 157 | 22776.0,1667986547950,16.798828125 158 | 22922.0,1667986550538,16.982421875 159 | 23068.0,1667986553117,16.619140625 160 | 23214.0,1667986555700,17.185546875 161 | 23360.0,1667986558337,17.05078125 162 | 23506.0,1667986560988,17.24609375 163 | 
23652.0,1667986563555,17.3046875 164 | 23798.0,1667986566130,17.11328125 165 | 23944.0,1667986568802,17.34765625 166 | 24090.0,1667986571389,17.361328125 167 | 24236.0,1667986574049,17.197265625 168 | 24382.0,1667986576601,17.439453125 169 | 24528.0,1667986579220,17.458984375 170 | 24674.0,1667986581826,17.376953125 171 | 24820.0,1667986584428,17.41015625 172 | 24966.0,1667986587134,17.28515625 173 | 25112.0,1667986589806,17.205078125 174 | 25258.0,1667986592632,16.98828125 175 | 25404.0,1667986595277,17.021484375 176 | 25550.0,1667986597895,17.35546875 177 | 25696.0,1667986600551,17.25 178 | 25842.0,1667986603118,16.81640625 179 | 25988.0,1667986605742,17.220703125 180 | 26134.0,1667986608357,16.947265625 181 | 26280.0,1667986610945,17.244140625 182 | 26426.0,1667986613635,17.16796875 183 | 26572.0,1667986616194,17.40625 184 | 26718.0,1667986618868,17.49609375 185 | 26864.0,1667986621550,17.68359375 186 | 27010.0,1667986624198,17.5625 187 | 27156.0,1667986626898,17.298828125 188 | 27302.0,1667986629607,17.69921875 189 | 27448.0,1667986632210,17.427734375 190 | 27594.0,1667986634852,17.408203125 191 | 27740.0,1667986637523,17.728515625 192 | 27886.0,1667986640170,17.26171875 193 | 28032.0,1667986642782,17.818359375 194 | 28178.0,1667986645444,17.63671875 195 | 28324.0,1667986648060,17.44140625 196 | 28470.0,1667986650671,17.126953125 197 | 28616.0,1667986653396,17.46875 198 | 28762.0,1667986656046,17.291015625 199 | 28908.0,1667986658737,17.333984375 200 | 29054.0,1667986661386,17.861328125 201 | 29200.0,1667986664068,17.9921875 202 | -------------------------------------------------------------------------------- /snake/data/a2c_return/MET-2175__eval_episode_reward_determinist_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667989976182,0.099609375 2 | 146.0,1667989981915,0.15625 3 | 292.0,1667989984563,0.162109375 4 | 438.0,1667989987274,1.740234375 5 | 584.0,1667989989853,3.791015625 6 | 
730.0,1667989992338,3.96875 7 | 876.0,1667989994992,4.443359375 8 | 1022.0,1667989997599,4.462890625 9 | 1168.0,1667990000225,5.017578125 10 | 1314.0,1667990002815,5.3515625 11 | 1460.0,1667990005479,5.134765625 12 | 1606.0,1667990008073,5.248046875 13 | 1752.0,1667990010690,5.5078125 14 | 1898.0,1667990013326,5.80078125 15 | 2044.0,1667990015889,5.666015625 16 | 2190.0,1667990018450,6.076171875 17 | 2336.0,1667990021077,5.92578125 18 | 2482.0,1667990023655,6.037109375 19 | 2628.0,1667990026315,6.462890625 20 | 2774.0,1667990028932,6.423828125 21 | 2920.0,1667990031518,6.51953125 22 | 3066.0,1667990034139,6.65625 23 | 3212.0,1667990036773,6.45703125 24 | 3358.0,1667990039338,6.154296875 25 | 3504.0,1667990041975,6.44140625 26 | 3650.0,1667990044688,6.484375 27 | 3796.0,1667990047256,6.181640625 28 | 3942.0,1667990049847,6.623046875 29 | 4088.0,1667990052436,6.5234375 30 | 4234.0,1667990055125,6.478515625 31 | 4380.0,1667990057725,6.8984375 32 | 4526.0,1667990060296,6.162109375 33 | 4672.0,1667990062897,7.015625 34 | 4818.0,1667990065537,6.916015625 35 | 4964.0,1667990068103,6.619140625 36 | 5110.0,1667990070666,6.7734375 37 | 5256.0,1667990073287,6.896484375 38 | 5402.0,1667990075844,6.736328125 39 | 5548.0,1667990078438,7.04296875 40 | 5694.0,1667990081023,6.6171875 41 | 5840.0,1667990083628,7.025390625 42 | 5986.0,1667990086138,6.859375 43 | 6132.0,1667990088732,7.025390625 44 | 6278.0,1667990091317,7.517578125 45 | 6424.0,1667990093894,7.19140625 46 | 6570.0,1667990096517,7.4296875 47 | 6716.0,1667990099126,7.734375 48 | 6862.0,1667990101733,7.703125 49 | 7008.0,1667990104459,7.912109375 50 | 7154.0,1667990107053,8.0234375 51 | 7300.0,1667990109614,8.076171875 52 | 7446.0,1667990112304,7.9296875 53 | 7592.0,1667990114964,8.052734375 54 | 7738.0,1667990117590,8.419921875 55 | 7884.0,1667990120192,8.7109375 56 | 8030.0,1667990122856,8.712890625 57 | 8176.0,1667990125531,8.431640625 58 | 8322.0,1667990128100,9.23046875 59 | 8468.0,1667990130664,9.380859375 60 | 
8614.0,1667990133271,9.248046875 61 | 8760.0,1667990135902,9.9921875 62 | 8906.0,1667990138510,10.396484375 63 | 9052.0,1667990141052,10.49609375 64 | 9198.0,1667990143727,10.53515625 65 | 9344.0,1667990146374,10.791015625 66 | 9490.0,1667990149019,10.68359375 67 | 9636.0,1667990151526,10.990234375 68 | 9782.0,1667990154145,11.623046875 69 | 9928.0,1667990156724,11.455078125 70 | 10074.0,1667990159396,11.404296875 71 | 10220.0,1667990162027,11.548828125 72 | 10366.0,1667990164783,11.4296875 73 | 10512.0,1667990167355,11.720703125 74 | 10658.0,1667990169958,12.072265625 75 | 10804.0,1667990172593,12.408203125 76 | 10950.0,1667990175362,12.3359375 77 | 11096.0,1667990177925,12.75390625 78 | 11242.0,1667990180561,12.20703125 79 | 11388.0,1667990183206,12.4453125 80 | 11534.0,1667990185832,12.607421875 81 | 11680.0,1667990188394,13.1640625 82 | 11826.0,1667990191032,13.232421875 83 | 11972.0,1667990193656,13.1484375 84 | 12118.0,1667990196267,13.28515625 85 | 12264.0,1667990198879,13.083984375 86 | 12410.0,1667990201450,13.111328125 87 | 12556.0,1667990204127,13.68359375 88 | 12702.0,1667990206747,13.6640625 89 | 12848.0,1667990209391,13.783203125 90 | 12994.0,1667990212044,13.826171875 91 | 13140.0,1667990214739,14.1875 92 | 13286.0,1667990217321,14.310546875 93 | 13432.0,1667990219915,14.13671875 94 | 13578.0,1667990222592,14.14453125 95 | 13724.0,1667990225338,14.392578125 96 | 13870.0,1667990227927,14.03125 97 | 14016.0,1667990230608,14.626953125 98 | 14162.0,1667990233236,14.578125 99 | 14308.0,1667990235804,14.341796875 100 | 14454.0,1667990238405,14.416015625 101 | 14600.0,1667990240995,14.455078125 102 | 14746.0,1667990243630,15.046875 103 | 14892.0,1667990246192,14.93359375 104 | 15038.0,1667990248815,15.33203125 105 | 15184.0,1667990251419,15.09765625 106 | 15330.0,1667990254048,15.09765625 107 | 15476.0,1667990256642,14.90625 108 | 15622.0,1667990259285,15.240234375 109 | 15768.0,1667990261878,14.9921875 110 | 15914.0,1667990264545,15.28515625 111 | 
16060.0,1667990267159,15.302734375 112 | 16206.0,1667990269730,15.541015625 113 | 16352.0,1667990272338,15.419921875 114 | 16498.0,1667990274999,15.65625 115 | 16644.0,1667990277627,15.697265625 116 | 16790.0,1667990280275,15.85546875 117 | 16936.0,1667990282941,15.44921875 118 | 17082.0,1667990285631,15.8671875 119 | 17228.0,1667990288181,15.548828125 120 | 17374.0,1667990290779,15.591796875 121 | 17520.0,1667990293430,15.875 122 | 17666.0,1667990296013,16.123046875 123 | 17812.0,1667990298631,15.9453125 124 | 17958.0,1667990301206,15.94921875 125 | 18104.0,1667990303936,15.837890625 126 | 18250.0,1667990306518,15.85546875 127 | 18396.0,1667990309156,16.10546875 128 | 18542.0,1667990311762,15.900390625 129 | 18688.0,1667990314434,16.123046875 130 | 18834.0,1667990317080,16.3125 131 | 18980.0,1667990319815,16.142578125 132 | 19126.0,1667990322492,16.5625 133 | 19272.0,1667990325204,16.4609375 134 | 19418.0,1667990327771,16.296875 135 | 19564.0,1667990330386,16.20703125 136 | 19710.0,1667990332997,16.0703125 137 | 19856.0,1667990335621,16.40625 138 | 20002.0,1667990338130,16.330078125 139 | 20148.0,1667990340800,16.236328125 140 | 20294.0,1667990343447,16.41796875 141 | 20440.0,1667990346065,16.443359375 142 | 20586.0,1667990348613,16.724609375 143 | 20732.0,1667990351219,16.556640625 144 | 20878.0,1667990353831,15.95703125 145 | 21024.0,1667990356455,16.396484375 146 | 21170.0,1667990359166,16.0234375 147 | 21316.0,1667990361773,16.205078125 148 | 21462.0,1667990364467,16.166015625 149 | 21608.0,1667990367120,16.36328125 150 | 21754.0,1667990369686,16.224609375 151 | 21900.0,1667990372283,16.47265625 152 | 22046.0,1667990374973,16.724609375 153 | 22192.0,1667990377590,16.7109375 154 | 22338.0,1667990380343,16.41796875 155 | 22484.0,1667990382953,16.314453125 156 | 22630.0,1667990385562,16.775390625 157 | 22776.0,1667990388170,16.94140625 158 | 22922.0,1667990390697,16.68359375 159 | 23068.0,1667990393356,16.806640625 160 | 23214.0,1667990395979,16.451171875 161 | 
23360.0,1667990398612,16.939453125 162 | 23506.0,1667990401159,16.90625 163 | 23652.0,1667990403892,16.6640625 164 | 23798.0,1667990406472,16.990234375 165 | 23944.0,1667990409025,16.5390625 166 | 24090.0,1667990411675,17.267578125 167 | 24236.0,1667990414317,16.310546875 168 | 24382.0,1667990416923,17.2421875 169 | 24528.0,1667990419548,16.6015625 170 | 24674.0,1667990422108,16.8828125 171 | 24820.0,1667990424853,16.673828125 172 | 24966.0,1667990427453,17.021484375 173 | 25112.0,1667990430084,16.404296875 174 | 25258.0,1667990432769,17.14453125 175 | 25404.0,1667990435344,17.162109375 176 | 25550.0,1667990437917,16.986328125 177 | 25696.0,1667990440636,16.916015625 178 | 25842.0,1667990443247,16.466796875 179 | 25988.0,1667990445842,16.845703125 180 | 26134.0,1667990448403,16.966796875 181 | 26280.0,1667990451046,16.44140625 182 | 26426.0,1667990453610,16.693359375 183 | 26572.0,1667990456315,16.947265625 184 | 26718.0,1667990458909,17.033203125 185 | 26864.0,1667990461517,17.20703125 186 | 27010.0,1667990464172,17.125 187 | 27156.0,1667990466755,17.3984375 188 | 27302.0,1667990469321,17.310546875 189 | 27448.0,1667990471876,17.00390625 190 | 27594.0,1667990474523,16.708984375 191 | 27740.0,1667990477134,17.35546875 192 | 27886.0,1667990479801,17.20703125 193 | 28032.0,1667990482322,17.310546875 194 | 28178.0,1667990485028,16.76171875 195 | 28324.0,1667990487645,16.892578125 196 | 28470.0,1667990490252,17.203125 197 | 28616.0,1667990492836,17.19921875 198 | 28762.0,1667990495446,16.828125 199 | 28908.0,1667990498049,17.08984375 200 | 29054.0,1667990500812,17.111328125 201 | 29200.0,1667990503455,17.140625 202 | -------------------------------------------------------------------------------- /snake/data/appendix/return/return_a2c_no_norm.csv: -------------------------------------------------------------------------------- 1 | 3.0,1668722618745,0.1328125 2 | 146.0,1668722625924,0.17578125 3 | 292.0,1668722628787,0.248046875 4 | 438.0,1668722631689,1.306640625 5 | 
584.0,1668722634513,3.3203125 6 | 730.0,1668722637239,3.830078125 7 | 876.0,1668722640071,4.123046875 8 | 1022.0,1668722643095,4.265625 9 | 1168.0,1668722646166,4.673828125 10 | 1314.0,1668722649007,4.82421875 11 | 1460.0,1668722651770,5.0234375 12 | 1606.0,1668722654546,5.16015625 13 | 1752.0,1668722657502,5.390625 14 | 1898.0,1668722660270,5.638671875 15 | 2044.0,1668722663126,5.703125 16 | 2190.0,1668722665935,5.626953125 17 | 2336.0,1668722668631,5.728515625 18 | 2482.0,1668722671420,5.884765625 19 | 2628.0,1668722674190,5.865234375 20 | 2774.0,1668722676928,6.033203125 21 | 2920.0,1668722679738,6.09765625 22 | 3066.0,1668722682509,6.349609375 23 | 3212.0,1668722685376,6.517578125 24 | 3358.0,1668722688201,6.458984375 25 | 3504.0,1668722691091,6.53125 26 | 3650.0,1668722693921,6.966796875 27 | 3796.0,1668722696767,7.638671875 28 | 3942.0,1668722699623,8.21484375 29 | 4088.0,1668722702517,8.09375 30 | 4234.0,1668722705340,8.544921875 31 | 4380.0,1668722708099,9.291015625 32 | 4526.0,1668722710957,9.58203125 33 | 4672.0,1668722713792,9.79296875 34 | 4818.0,1668722717101,9.626953125 35 | 4964.0,1668722719869,10.21484375 36 | 5110.0,1668722722717,10.4453125 37 | 5256.0,1668722725611,10.462890625 38 | 5402.0,1668722728366,10.546875 39 | 5548.0,1668722731141,11.421875 40 | 5694.0,1668722733947,11.341796875 41 | 5840.0,1668722736816,11.14453125 42 | 5986.0,1668722739885,11.376953125 43 | 6132.0,1668722742746,11.9296875 44 | 6278.0,1668722745719,11.630859375 45 | 6424.0,1668722748523,12.109375 46 | 6570.0,1668722751278,12.1484375 47 | 6716.0,1668722754023,12.201171875 48 | 6862.0,1668722756899,12.23046875 49 | 7008.0,1668722759735,12.189453125 50 | 7154.0,1668722762564,12.53515625 51 | 7300.0,1668722765390,12.814453125 52 | 7446.0,1668722768199,12.556640625 53 | 7592.0,1668722771022,13.16015625 54 | 7738.0,1668722773830,12.890625 55 | 7884.0,1668722776824,13.134765625 56 | 8030.0,1668722779608,13.09375 57 | 8176.0,1668722782636,12.8125 58 | 
8322.0,1668722785514,13.759765625 59 | 8468.0,1668722788352,13.64453125 60 | 8614.0,1668722791245,13.568359375 61 | 8760.0,1668722794029,14.412109375 62 | 8906.0,1668722796856,13.87109375 63 | 9052.0,1668722799676,14.154296875 64 | 9198.0,1668722802439,14.294921875 65 | 9344.0,1668722805324,14.009765625 66 | 9490.0,1668722808229,14.669921875 67 | 9636.0,1668722811198,14.671875 68 | 9782.0,1668722814124,14.595703125 69 | 9928.0,1668722817013,14.349609375 70 | 10074.0,1668722819787,14.6484375 71 | 10220.0,1668722822591,14.232421875 72 | 10366.0,1668722825498,14.62109375 73 | 10512.0,1668722828389,14.556640625 74 | 10658.0,1668722831247,15.119140625 75 | 10804.0,1668722834023,14.6796875 76 | 10950.0,1668722836968,15.001953125 77 | 11096.0,1668722839780,14.845703125 78 | 11242.0,1668722842592,14.880859375 79 | 11388.0,1668722845395,14.57421875 80 | 11534.0,1668722848187,15.150390625 81 | 11680.0,1668722851047,15.416015625 82 | 11826.0,1668722853835,15.0546875 83 | 11972.0,1668722856750,15.02734375 84 | 12118.0,1668722859551,15.283203125 85 | 12264.0,1668722862325,15.052734375 86 | 12410.0,1668722865124,15.521484375 87 | 12556.0,1668722868063,14.876953125 88 | 12702.0,1668722870829,15.5703125 89 | 12848.0,1668722873599,15.712890625 90 | 12994.0,1668722876406,15.048828125 91 | 13140.0,1668722879185,15.6796875 92 | 13286.0,1668722881996,15.11328125 93 | 13432.0,1668722884857,16.09375 94 | 13578.0,1668722887671,15.765625 95 | 13724.0,1668722890474,15.298828125 96 | 13870.0,1668722893353,15.76953125 97 | 14016.0,1668722896310,15.470703125 98 | 14162.0,1668722899277,15.794921875 99 | 14308.0,1668722902083,15.4453125 100 | 14454.0,1668722904975,15.8515625 101 | 14600.0,1668722907863,15.51953125 102 | 14746.0,1668722910645,16.197265625 103 | 14892.0,1668722913480,16.1796875 104 | 15038.0,1668722916347,15.724609375 105 | 15184.0,1668722919129,15.78515625 106 | 15330.0,1668722922023,15.998046875 107 | 15476.0,1668722924856,15.859375 108 | 15622.0,1668722927743,16.169921875 109 | 
15768.0,1668722930602,16.171875 110 | 15914.0,1668722933435,16.32421875 111 | 16060.0,1668722936266,16.32421875 112 | 16206.0,1668722939233,16.3125 113 | 16352.0,1668722942041,16.140625 114 | 16498.0,1668722944989,16.05859375 115 | 16644.0,1668722947910,16.212890625 116 | 16790.0,1668722950779,16.005859375 117 | 16936.0,1668722953680,16.47265625 118 | 17082.0,1668722956602,15.984375 119 | 17228.0,1668722959405,16.5234375 120 | 17374.0,1668722962206,16.30859375 121 | 17520.0,1668722965025,16.671875 122 | 17666.0,1668722967794,16.21484375 123 | 17812.0,1668722970580,16.564453125 124 | 17958.0,1668722973367,16.361328125 125 | 18104.0,1668722976233,16.33203125 126 | 18250.0,1668722978995,16.666015625 127 | 18396.0,1668722981858,16.787109375 128 | 18542.0,1668722984660,16.830078125 129 | 18688.0,1668722987523,16.58203125 130 | 18834.0,1668722990306,16.134765625 131 | 18980.0,1668722993061,16.52734375 132 | 19126.0,1668722995874,16.70703125 133 | 19272.0,1668722998644,16.501953125 134 | 19418.0,1668723001472,16.595703125 135 | 19564.0,1668723004334,16.4296875 136 | 19710.0,1668723007222,16.86328125 137 | 19856.0,1668723010009,16.73046875 138 | 20002.0,1668723012831,16.751953125 139 | 20148.0,1668723015635,16.884765625 140 | 20294.0,1668723018499,16.5 141 | 20440.0,1668723021226,16.6328125 142 | 20586.0,1668723024047,16.525390625 143 | 20732.0,1668723026780,16.51171875 144 | 20878.0,1668723029608,16.666015625 145 | 21024.0,1668723032434,16.9375 146 | 21170.0,1668723035307,16.83984375 147 | 21316.0,1668723038150,16.453125 148 | 21462.0,1668723040926,17.25 149 | 21608.0,1668723043703,16.78515625 150 | 21754.0,1668723046622,16.779296875 151 | 21900.0,1668723049502,16.404296875 152 | 22046.0,1668723052430,16.783203125 153 | 22192.0,1668723055266,16.884765625 154 | 22338.0,1668723058103,16.787109375 155 | 22484.0,1668723060992,16.939453125 156 | 22630.0,1668723063866,17.103515625 157 | 22776.0,1668723066826,17.25 158 | 22922.0,1668723069711,16.765625 159 | 
23068.0,1668723072587,16.919921875 160 | 23214.0,1668723075419,16.849609375 161 | 23360.0,1668723078219,17.375 162 | 23506.0,1668723081105,16.9375 163 | 23652.0,1668723083923,16.921875 164 | 23798.0,1668723086729,16.94921875 165 | 23944.0,1668723089558,17.1796875 166 | 24090.0,1668723092344,16.98828125 167 | 24236.0,1668723095182,17.072265625 168 | 24382.0,1668723098064,17.19921875 169 | 24528.0,1668723100929,17.064453125 170 | 24674.0,1668723103813,17.375 171 | 24820.0,1668723106747,17.03125 172 | 24966.0,1668723109744,17.4375 173 | 25112.0,1668723112557,17.0703125 174 | 25258.0,1668723115409,17.28515625 175 | 25404.0,1668723118248,17.1953125 176 | 25550.0,1668723121087,16.943359375 177 | 25696.0,1668723123940,16.875 178 | 25842.0,1668723126839,17.06640625 179 | 25988.0,1668723129723,17.3515625 180 | 26134.0,1668723132585,17.177734375 181 | 26280.0,1668723135450,16.86328125 182 | 26426.0,1668723138352,17.1328125 183 | 26572.0,1668723141229,17.509765625 184 | 26718.0,1668723144063,17.349609375 185 | 26864.0,1668723146850,17.587890625 186 | 27010.0,1668723149608,17.283203125 187 | 27156.0,1668723152440,17.228515625 188 | 27302.0,1668723155309,17.798828125 189 | 27448.0,1668723158106,17.27734375 190 | 27594.0,1668723160924,17.126953125 191 | 27740.0,1668723163828,17.2265625 192 | 27886.0,1668723166770,17.990234375 193 | 28032.0,1668723169609,17.072265625 194 | 28178.0,1668723172508,17.09765625 195 | 28324.0,1668723175368,17.052734375 196 | 28470.0,1668723178253,17.23828125 197 | 28616.0,1668723181075,17.541015625 198 | 28762.0,1668723183835,17.408203125 199 | 28908.0,1668723186657,17.556640625 200 | 29054.0,1668723189500,17.318359375 201 | 29200.0,1668723192319,17.4453125 202 | -------------------------------------------------------------------------------- /snake/data/appendix/return/return_a2c_norm.csv: -------------------------------------------------------------------------------- 1 | 3.0,1668723199785,0.1328125 2 | 146.0,1668723206026,0.453125 3 | 
292.0,1668723208886,3.03125 4 | 438.0,1668723211647,3.771484375 5 | 584.0,1668723214505,4.033203125 6 | 730.0,1668723217425,4.22265625 7 | 876.0,1668723220288,4.453125 8 | 1022.0,1668723223134,4.732421875 9 | 1168.0,1668723226081,5.306640625 10 | 1314.0,1668723228938,5.54296875 11 | 1460.0,1668723231848,5.572265625 12 | 1606.0,1668723234717,5.8125 13 | 1752.0,1668723237594,6.025390625 14 | 1898.0,1668723240485,6.244140625 15 | 2044.0,1668723243281,6.216796875 16 | 2190.0,1668723246223,5.982421875 17 | 2336.0,1668723249199,6.421875 18 | 2482.0,1668723252189,6.41015625 19 | 2628.0,1668723255166,6.265625 20 | 2774.0,1668723258184,6.599609375 21 | 2920.0,1668723261077,6.408203125 22 | 3066.0,1668723264072,6.451171875 23 | 3212.0,1668723267061,6.826171875 24 | 3358.0,1668723269946,6.94921875 25 | 3504.0,1668723272867,6.423828125 26 | 3650.0,1668723275798,6.7265625 27 | 3796.0,1668723279741,6.51171875 28 | 3942.0,1668723282643,7.005859375 29 | 4088.0,1668723285567,6.7265625 30 | 4234.0,1668723288677,6.875 31 | 4380.0,1668723291552,7.146484375 32 | 4526.0,1668723294427,7.115234375 33 | 4672.0,1668723297414,7.189453125 34 | 4818.0,1668723300362,7.126953125 35 | 4964.0,1668723303314,7.7578125 36 | 5110.0,1668723306284,7.6015625 37 | 5256.0,1668723309194,7.517578125 38 | 5402.0,1668723312157,7.802734375 39 | 5548.0,1668723315124,7.798828125 40 | 5694.0,1668723318250,8.259765625 41 | 5840.0,1668723321158,8.380859375 42 | 5986.0,1668723324131,8.1875 43 | 6132.0,1668723327063,8.951171875 44 | 6278.0,1668723329843,8.6484375 45 | 6424.0,1668723332791,8.97265625 46 | 6570.0,1668723335712,9.037109375 47 | 6716.0,1668723338673,9.619140625 48 | 6862.0,1668723341533,9.62890625 49 | 7008.0,1668723344402,10.087890625 50 | 7154.0,1668723347417,10.1875 51 | 7300.0,1668723350321,10.0625 52 | 7446.0,1668723353280,10.076171875 53 | 7592.0,1668723356204,10.912109375 54 | 7738.0,1668723359128,10.685546875 55 | 7884.0,1668723362073,10.974609375 56 | 8030.0,1668723364954,11.37109375 57 | 
8176.0,1668723367896,11.498046875 58 | 8322.0,1668723370858,11.626953125 59 | 8468.0,1668723373782,11.728515625 60 | 8614.0,1668723376815,11.515625 61 | 8760.0,1668723379712,12.392578125 62 | 8906.0,1668723382637,12.296875 63 | 9052.0,1668723385527,12.439453125 64 | 9198.0,1668723388476,12.74609375 65 | 9344.0,1668723391353,12.759765625 66 | 9490.0,1668723394186,12.921875 67 | 9636.0,1668723397082,12.861328125 68 | 9782.0,1668723400028,13.08984375 69 | 9928.0,1668723402958,13.16796875 70 | 10074.0,1668723405823,13.40625 71 | 10220.0,1668723408727,13.669921875 72 | 10366.0,1668723411631,13.703125 73 | 10512.0,1668723414577,13.83203125 74 | 10658.0,1668723417496,13.775390625 75 | 10804.0,1668723420512,13.69921875 76 | 10950.0,1668723423555,13.99609375 77 | 11096.0,1668723426590,14.388671875 78 | 11242.0,1668723429565,14.0234375 79 | 11388.0,1668723432592,14.25390625 80 | 11534.0,1668723435586,14.408203125 81 | 11680.0,1668723438543,14.341796875 82 | 11826.0,1668723441477,14.41015625 83 | 11972.0,1668723444481,14.509765625 84 | 12118.0,1668723447475,14.5625 85 | 12264.0,1668723450431,14.8203125 86 | 12410.0,1668723453351,14.68359375 87 | 12556.0,1668723456319,15.0078125 88 | 12702.0,1668723459301,14.8203125 89 | 12848.0,1668723462271,14.9609375 90 | 12994.0,1668723465203,14.685546875 91 | 13140.0,1668723468190,14.966796875 92 | 13286.0,1668723471099,15.4375 93 | 13432.0,1668723473967,15.154296875 94 | 13578.0,1668723476912,15.56640625 95 | 13724.0,1668723479815,14.88671875 96 | 13870.0,1668723482708,15.30078125 97 | 14016.0,1668723485617,15.419921875 98 | 14162.0,1668723488575,15.322265625 99 | 14308.0,1668723491485,15.03515625 100 | 14454.0,1668723494503,15.60546875 101 | 14600.0,1668723497485,15.576171875 102 | 14746.0,1668723500409,15.388671875 103 | 14892.0,1668723503375,15.26953125 104 | 15038.0,1668723506337,15.81640625 105 | 15184.0,1668723509229,15.255859375 106 | 15330.0,1668723512148,15.59375 107 | 15476.0,1668723515091,15.5625 108 | 
15622.0,1668723518053,15.01953125 109 | 15768.0,1668723521065,15.666015625 110 | 15914.0,1668723524056,15.654296875 111 | 16060.0,1668723527090,15.591796875 112 | 16206.0,1668723529981,15.859375 113 | 16352.0,1668723532928,15.775390625 114 | 16498.0,1668723535841,15.54296875 115 | 16644.0,1668723538784,15.720703125 116 | 16790.0,1668723541790,15.82421875 117 | 16936.0,1668723544819,15.86328125 118 | 17082.0,1668723547705,15.615234375 119 | 17228.0,1668723550588,16.09375 120 | 17374.0,1668723553509,16.263671875 121 | 17520.0,1668723556527,15.994140625 122 | 17666.0,1668723559470,15.546875 123 | 17812.0,1668723562402,15.943359375 124 | 17958.0,1668723565333,15.974609375 125 | 18104.0,1668723568300,16.341796875 126 | 18250.0,1668723571237,16.20703125 127 | 18396.0,1668723574092,16.228515625 128 | 18542.0,1668723577087,15.837890625 129 | 18688.0,1668723580052,16.091796875 130 | 18834.0,1668723583048,16.27734375 131 | 18980.0,1668723586132,16.474609375 132 | 19126.0,1668723589090,16.12890625 133 | 19272.0,1668723592019,16.0703125 134 | 19418.0,1668723594981,16.474609375 135 | 19564.0,1668723597962,16.533203125 136 | 19710.0,1668723600911,16.349609375 137 | 19856.0,1668723603876,16.271484375 138 | 20002.0,1668723606852,16.41015625 139 | 20148.0,1668723609817,16.666015625 140 | 20294.0,1668723612733,15.9921875 141 | 20440.0,1668723615677,16.208984375 142 | 20586.0,1668723618627,16.38671875 143 | 20732.0,1668723621547,16.056640625 144 | 20878.0,1668723624474,16.201171875 145 | 21024.0,1668723627407,16.220703125 146 | 21170.0,1668723630334,16.4765625 147 | 21316.0,1668723633241,16.158203125 148 | 21462.0,1668723636246,16.0 149 | 21608.0,1668723639268,16.228515625 150 | 21754.0,1668723642226,16.4296875 151 | 21900.0,1668723645173,16.263671875 152 | 22046.0,1668723648177,16.443359375 153 | 22192.0,1668723651061,16.509765625 154 | 22338.0,1668723653963,16.892578125 155 | 22484.0,1668723656873,16.4453125 156 | 22630.0,1668723659903,17.046875 157 | 22776.0,1668723662906,16.40625 
158 | 22922.0,1668723665798,16.34375 159 | 23068.0,1668723668717,16.65234375 160 | 23214.0,1668723671714,16.462890625 161 | 23360.0,1668723674653,16.751953125 162 | 23506.0,1668723677650,16.7421875 163 | 23652.0,1668723680603,17.119140625 164 | 23798.0,1668723683547,16.59765625 165 | 23944.0,1668723686475,16.423828125 166 | 24090.0,1668723689408,16.884765625 167 | 24236.0,1668723692326,16.9453125 168 | 24382.0,1668723695270,16.6953125 169 | 24528.0,1668723698284,17.109375 170 | 24674.0,1668723701255,16.861328125 171 | 24820.0,1668723704207,16.98828125 172 | 24966.0,1668723707193,16.962890625 173 | 25112.0,1668723710116,16.658203125 174 | 25258.0,1668723713026,16.90234375 175 | 25404.0,1668723715902,16.59765625 176 | 25550.0,1668723718935,16.9375 177 | 25696.0,1668723721917,16.916015625 178 | 25842.0,1668723724835,16.775390625 179 | 25988.0,1668723727867,16.81640625 180 | 26134.0,1668723730811,16.62890625 181 | 26280.0,1668723733942,16.947265625 182 | 26426.0,1668723736984,16.720703125 183 | 26572.0,1668723740048,16.615234375 184 | 26718.0,1668723742952,16.669921875 185 | 26864.0,1668723745880,16.7421875 186 | 27010.0,1668723748878,17.107421875 187 | 27156.0,1668723751871,16.92578125 188 | 27302.0,1668723754827,17.1171875 189 | 27448.0,1668723757777,16.955078125 190 | 27594.0,1668723760671,16.591796875 191 | 27740.0,1668723763502,16.888671875 192 | 27886.0,1668723766617,16.90625 193 | 28032.0,1668723769501,16.6328125 194 | 28178.0,1668723772364,16.623046875 195 | 28324.0,1668723775217,16.91015625 196 | 28470.0,1668723778204,17.29296875 197 | 28616.0,1668723781057,16.5859375 198 | 28762.0,1668723784024,16.6484375 199 | 28908.0,1668723786984,17.357421875 200 | 29054.0,1668723789946,16.8984375 201 | 29200.0,1668723792895,17.1953125 202 | -------------------------------------------------------------------------------- /snake/data/appendix/return/return_bootstrap_no_outer_no_norm.csv: -------------------------------------------------------------------------------- 1 | 
3.0,1668722619197,0.1328125 2 | 146.0,1668722640028,0.16796875 3 | 292.0,1668722645854,0.21875 4 | 438.0,1668722651650,1.458984375 5 | 584.0,1668722657527,3.625 6 | 730.0,1668722663141,3.990234375 7 | 876.0,1668722668885,4.234375 8 | 1022.0,1668722674617,4.57421875 9 | 1168.0,1668722680358,4.61328125 10 | 1314.0,1668722686093,4.646484375 11 | 1460.0,1668722691785,5.3125 12 | 1606.0,1668722697437,5.103515625 13 | 1752.0,1668722703235,5.53515625 14 | 1898.0,1668722708898,5.640625 15 | 2044.0,1668722714739,5.671875 16 | 2190.0,1668722720660,5.6640625 17 | 2336.0,1668722726364,6.224609375 18 | 2482.0,1668722732005,6.1640625 19 | 2628.0,1668722737803,6.025390625 20 | 2774.0,1668722743800,6.353515625 21 | 2920.0,1668722749562,6.39453125 22 | 3066.0,1668722755350,6.638671875 23 | 3212.0,1668722761124,6.541015625 24 | 3358.0,1668722766927,7.033203125 25 | 3504.0,1668722772705,6.435546875 26 | 3650.0,1668722778522,6.80859375 27 | 3796.0,1668722784271,6.65234375 28 | 3942.0,1668722790034,7.447265625 29 | 4088.0,1668722795789,7.05859375 30 | 4234.0,1668722801481,7.560546875 31 | 4380.0,1668722807333,7.66015625 32 | 4526.0,1668722813076,7.87109375 33 | 4672.0,1668722818834,8.384765625 34 | 4818.0,1668722824472,8.107421875 35 | 4964.0,1668722830204,8.70703125 36 | 5110.0,1668722836025,8.7578125 37 | 5256.0,1668722841726,8.2734375 38 | 5402.0,1668722847339,9.134765625 39 | 5548.0,1668722852929,9.25 40 | 5694.0,1668722858607,9.380859375 41 | 5840.0,1668722864225,9.373046875 42 | 5986.0,1668722869908,9.5 43 | 6132.0,1668722875683,9.91015625 44 | 6278.0,1668722881367,10.482421875 45 | 6424.0,1668722887141,10.7890625 46 | 6570.0,1668722892839,10.759765625 47 | 6716.0,1668722898701,11.0 48 | 6862.0,1668722904312,10.791015625 49 | 7008.0,1668722910087,11.10546875 50 | 7154.0,1668722915758,11.486328125 51 | 7300.0,1668722921420,11.8828125 52 | 7446.0,1668722927185,12.400390625 53 | 7592.0,1668722932867,12.267578125 54 | 7738.0,1668722938582,12.279296875 55 | 
7884.0,1668722944416,12.427734375 56 | 8030.0,1668722950303,12.724609375 57 | 8176.0,1668722956098,12.482421875 58 | 8322.0,1668722961755,12.580078125 59 | 8468.0,1668722967472,13.337890625 60 | 8614.0,1668722973222,13.10546875 61 | 8760.0,1668722979002,13.607421875 62 | 8906.0,1668722984751,13.703125 63 | 9052.0,1668722990496,13.484375 64 | 9198.0,1668722996216,13.451171875 65 | 9344.0,1668723001906,14.0859375 66 | 9490.0,1668723007633,14.357421875 67 | 9636.0,1668723013453,14.505859375 68 | 9782.0,1668723019300,14.357421875 69 | 9928.0,1668723025025,14.5 70 | 10074.0,1668723030696,14.58984375 71 | 10220.0,1668723036496,14.451171875 72 | 10366.0,1668723042147,14.431640625 73 | 10512.0,1668723047937,14.759765625 74 | 10658.0,1668723053729,14.748046875 75 | 10804.0,1668723059509,14.59375 76 | 10950.0,1668723065218,14.919921875 77 | 11096.0,1668723071050,14.951171875 78 | 11242.0,1668723076833,15.271484375 79 | 11388.0,1668723082607,15.20703125 80 | 11534.0,1668723088315,14.990234375 81 | 11680.0,1668723094012,15.1640625 82 | 11826.0,1668723099722,15.47265625 83 | 11972.0,1668723105532,15.541015625 84 | 12118.0,1668723111235,15.7265625 85 | 12264.0,1668723116936,15.697265625 86 | 12410.0,1668723122630,15.26953125 87 | 12556.0,1668723128438,15.72265625 88 | 12702.0,1668723134292,16.083984375 89 | 12848.0,1668723140026,15.923828125 90 | 12994.0,1668723145687,16.046875 91 | 13140.0,1668723151384,16.3203125 92 | 13286.0,1668723157110,16.23046875 93 | 13432.0,1668723162716,16.359375 94 | 13578.0,1668723168363,16.75 95 | 13724.0,1668723174022,16.55078125 96 | 13870.0,1668723179751,16.431640625 97 | 14016.0,1668723185409,16.75 98 | 14162.0,1668723191036,16.263671875 99 | 14308.0,1668723196819,16.861328125 100 | 14454.0,1668723202442,16.787109375 101 | 14600.0,1668723208221,17.412109375 102 | 14746.0,1668723213865,17.044921875 103 | 14892.0,1668723219625,17.08203125 104 | 15038.0,1668723225379,17.294921875 105 | 15184.0,1668723231053,16.6171875 106 | 
15330.0,1668723236781,17.359375 107 | 15476.0,1668723242545,17.640625 108 | 15622.0,1668723248259,16.830078125 109 | 15768.0,1668723254016,17.14453125 110 | 15914.0,1668723259906,17.533203125 111 | 16060.0,1668723265740,17.5703125 112 | 16206.0,1668723271470,17.556640625 113 | 16352.0,1668723277358,17.6015625 114 | 16498.0,1668723283080,17.60546875 115 | 16644.0,1668723288988,17.693359375 116 | 16790.0,1668723294686,17.771484375 117 | 16936.0,1668723300475,17.123046875 118 | 17082.0,1668723306407,17.744140625 119 | 17228.0,1668723312107,18.01953125 120 | 17374.0,1668723317927,17.8984375 121 | 17520.0,1668723323617,17.455078125 122 | 17666.0,1668723329307,17.86328125 123 | 17812.0,1668723334958,17.154296875 124 | 17958.0,1668723340804,18.560546875 125 | 18104.0,1668723346540,18.130859375 126 | 18250.0,1668723352150,18.03515625 127 | 18396.0,1668723357881,18.5 128 | 18542.0,1668723363667,17.99609375 129 | 18688.0,1668723369414,17.68359375 130 | 18834.0,1668723375124,18.001953125 131 | 18980.0,1668723380993,18.318359375 132 | 19126.0,1668723386615,18.185546875 133 | 19272.0,1668723392258,18.5 134 | 19418.0,1668723397953,18.427734375 135 | 19564.0,1668723403594,18.51171875 136 | 19710.0,1668723409217,18.392578125 137 | 19856.0,1668723415006,18.486328125 138 | 20002.0,1668723420733,17.962890625 139 | 20148.0,1668723426617,18.3828125 140 | 20294.0,1668723432389,18.623046875 141 | 20440.0,1668723438213,18.5859375 142 | 20586.0,1668723443997,18.44921875 143 | 20732.0,1668723449726,18.4609375 144 | 20878.0,1668723455467,18.51953125 145 | 21024.0,1668723461158,18.59765625 146 | 21170.0,1668723466870,19.298828125 147 | 21316.0,1668723472513,18.619140625 148 | 21462.0,1668723478178,19.0703125 149 | 21608.0,1668723483915,18.5234375 150 | 21754.0,1668723489649,19.423828125 151 | 21900.0,1668723495350,18.9375 152 | 22046.0,1668723501107,19.021484375 153 | 22192.0,1668723506790,18.94140625 154 | 22338.0,1668723512511,19.171875 155 | 22484.0,1668723518326,19.357421875 156 | 
22630.0,1668723524410,19.34765625 157 | 22776.0,1668723530157,19.197265625 158 | 22922.0,1668723535915,18.9453125 159 | 23068.0,1668723541714,19.548828125 160 | 23214.0,1668723547515,20.001953125 161 | 23360.0,1668723553261,19.470703125 162 | 23506.0,1668723559000,19.6796875 163 | 23652.0,1668723564710,19.603515625 164 | 23798.0,1668723570466,19.767578125 165 | 23944.0,1668723576104,19.845703125 166 | 24090.0,1668723581868,20.306640625 167 | 24236.0,1668723587643,20.60546875 168 | 24382.0,1668723593429,21.0703125 169 | 24528.0,1668723599213,20.18359375 170 | 24674.0,1668723604930,21.396484375 171 | 24820.0,1668723610673,21.44921875 172 | 24966.0,1668723616481,21.0859375 173 | 25112.0,1668723622140,20.41796875 174 | 25258.0,1668723627902,21.591796875 175 | 25404.0,1668723633692,21.63671875 176 | 25550.0,1668723639407,22.3828125 177 | 25696.0,1668723645060,21.3046875 178 | 25842.0,1668723650840,21.978515625 179 | 25988.0,1668723656534,23.203125 180 | 26134.0,1668723662184,22.7109375 181 | 26280.0,1668723667968,23.30078125 182 | 26426.0,1668723673616,23.57421875 183 | 26572.0,1668723679336,23.759765625 184 | 26718.0,1668723685088,24.6953125 185 | 26864.0,1668723690820,25.328125 186 | 27010.0,1668723696535,25.3671875 187 | 27156.0,1668723702282,26.216796875 188 | 27302.0,1668723708087,26.984375 189 | 27448.0,1668723713670,28.15625 190 | 27594.0,1668723719435,27.2421875 191 | 27740.0,1668723725158,28.361328125 192 | 27886.0,1668723731004,29.01953125 193 | 28032.0,1668723736800,28.388671875 194 | 28178.0,1668723742499,29.0625 195 | 28324.0,1668723748207,29.712890625 196 | 28470.0,1668723753972,30.822265625 197 | 28616.0,1668723759729,30.857421875 198 | 28762.0,1668723765471,30.544921875 199 | 28908.0,1668723771156,31.25 200 | 29054.0,1668723776872,31.35546875 201 | 29200.0,1668723782544,31.185546875 202 | -------------------------------------------------------------------------------- /snake/data/appendix/return/return_bootstrap_no_outer_norm.csv: 
-------------------------------------------------------------------------------- 1 | 3.0,1668726176304,0.1328125 2 | 146.0,1668726195787,0.482421875 3 | 292.0,1668726201637,3.060546875 4 | 438.0,1668726207614,3.86328125 5 | 584.0,1668726213336,4.142578125 6 | 730.0,1668726219094,4.36328125 7 | 876.0,1668726224913,5.1015625 8 | 1022.0,1668726230744,4.978515625 9 | 1168.0,1668726236583,5.68359375 10 | 1314.0,1668726242399,4.96875 11 | 1460.0,1668726248214,5.515625 12 | 1606.0,1668726254036,6.123046875 13 | 1752.0,1668726259933,6.259765625 14 | 1898.0,1668726265912,6.41796875 15 | 2044.0,1668726271703,6.2578125 16 | 2190.0,1668726277593,6.05859375 17 | 2336.0,1668726283479,6.634765625 18 | 2482.0,1668726289219,6.552734375 19 | 2628.0,1668726294913,6.28125 20 | 2774.0,1668726300709,6.505859375 21 | 2920.0,1668726306543,6.306640625 22 | 3066.0,1668726312335,6.5078125 23 | 3212.0,1668726318156,6.783203125 24 | 3358.0,1668726324064,6.83984375 25 | 3504.0,1668726329855,6.373046875 26 | 3650.0,1668726335682,6.615234375 27 | 3796.0,1668726341555,6.44140625 28 | 3942.0,1668726347521,6.853515625 29 | 4088.0,1668726353462,6.611328125 30 | 4234.0,1668726359238,6.69921875 31 | 4380.0,1668726365187,6.83203125 32 | 4526.0,1668726371142,6.853515625 33 | 4672.0,1668726377172,7.193359375 34 | 4818.0,1668726383132,6.818359375 35 | 4964.0,1668726388941,7.1171875 36 | 5110.0,1668726394825,6.78125 37 | 5256.0,1668726400580,7.18359375 38 | 5402.0,1668726406397,7.25 39 | 5548.0,1668726412196,7.076171875 40 | 5694.0,1668726418025,7.18359375 41 | 5840.0,1668726423794,7.138671875 42 | 5986.0,1668726429613,7.2109375 43 | 6132.0,1668726435531,7.40625 44 | 6278.0,1668726441367,7.447265625 45 | 6424.0,1668726447299,7.509765625 46 | 6570.0,1668726453290,7.52734375 47 | 6716.0,1668726459307,7.509765625 48 | 6862.0,1668726465173,7.306640625 49 | 7008.0,1668726471197,8.1796875 50 | 7154.0,1668726477248,7.4140625 51 | 7300.0,1668726483092,7.791015625 52 | 7446.0,1668726488945,8.02734375 53 | 
7592.0,1668726494638,8.18359375 54 | 7738.0,1668726500284,8.3203125 55 | 7884.0,1668726506030,8.447265625 56 | 8030.0,1668726511687,8.5703125 57 | 8176.0,1668726517465,8.595703125 58 | 8322.0,1668726523288,8.896484375 59 | 8468.0,1668726529121,8.84765625 60 | 8614.0,1668726534863,9.078125 61 | 8760.0,1668726540646,9.056640625 62 | 8906.0,1668726546472,9.525390625 63 | 9052.0,1668726552178,9.419921875 64 | 9198.0,1668726557821,9.375 65 | 9344.0,1668726563522,9.64453125 66 | 9490.0,1668726569288,9.91015625 67 | 9636.0,1668726575047,10.111328125 68 | 9782.0,1668726580800,9.943359375 69 | 9928.0,1668726586523,10.10546875 70 | 10074.0,1668726592249,10.072265625 71 | 10220.0,1668726597902,10.650390625 72 | 10366.0,1668726603786,10.28125 73 | 10512.0,1668726609549,10.529296875 74 | 10658.0,1668726615355,10.76171875 75 | 10804.0,1668726621214,10.798828125 76 | 10950.0,1668726626992,11.3828125 77 | 11096.0,1668726632753,10.67578125 78 | 11242.0,1668726638599,11.576171875 79 | 11388.0,1668726644345,11.404296875 80 | 11534.0,1668726650056,11.255859375 81 | 11680.0,1668726655842,11.716796875 82 | 11826.0,1668726661702,11.58984375 83 | 11972.0,1668726667535,11.759765625 84 | 12118.0,1668726673359,12.0 85 | 12264.0,1668726679202,11.986328125 86 | 12410.0,1668726684918,11.677734375 87 | 12556.0,1668726690644,12.447265625 88 | 12702.0,1668726696372,12.2109375 89 | 12848.0,1668726702172,12.3203125 90 | 12994.0,1668726708026,12.375 91 | 13140.0,1668726713764,12.568359375 92 | 13286.0,1668726719548,12.490234375 93 | 13432.0,1668726725285,12.505859375 94 | 13578.0,1668726731045,13.189453125 95 | 13724.0,1668726736955,12.740234375 96 | 13870.0,1668726742690,12.779296875 97 | 14016.0,1668726748456,13.095703125 98 | 14162.0,1668726754285,12.990234375 99 | 14308.0,1668726760149,13.087890625 100 | 14454.0,1668726765949,13.177734375 101 | 14600.0,1668726771767,12.640625 102 | 14746.0,1668726777571,13.255859375 103 | 14892.0,1668726783327,13.28515625 104 | 15038.0,1668726789094,13.068359375 
105 | 15184.0,1668726795074,13.361328125 106 | 15330.0,1668726800939,13.34375 107 | 15476.0,1668726806671,12.939453125 108 | 15622.0,1668726812432,12.92578125 109 | 15768.0,1668726818166,12.94140625 110 | 15914.0,1668726823930,12.779296875 111 | 16060.0,1668726829650,13.359375 112 | 16206.0,1668726835530,12.884765625 113 | 16352.0,1668726841307,13.21875 114 | 16498.0,1668726847147,12.97265625 115 | 16644.0,1668726852981,13.734375 116 | 16790.0,1668726858855,13.306640625 117 | 16936.0,1668726864543,13.080078125 118 | 17082.0,1668726870264,13.173828125 119 | 17228.0,1668726876001,13.330078125 120 | 17374.0,1668726881723,13.421875 121 | 17520.0,1668726887686,13.35546875 122 | 17666.0,1668726893504,13.3515625 123 | 17812.0,1668726899332,13.2578125 124 | 17958.0,1668726905220,13.3359375 125 | 18104.0,1668726911033,13.640625 126 | 18250.0,1668726916856,13.44140625 127 | 18396.0,1668726922579,13.447265625 128 | 18542.0,1668726928392,13.01171875 129 | 18688.0,1668726934203,13.388671875 130 | 18834.0,1668726939947,13.01953125 131 | 18980.0,1668726945715,13.322265625 132 | 19126.0,1668726951473,13.37109375 133 | 19272.0,1668726957272,13.46875 134 | 19418.0,1668726963106,13.2734375 135 | 19564.0,1668726968919,12.9375 136 | 19710.0,1668726974705,13.44921875 137 | 19856.0,1668726980479,13.197265625 138 | 20002.0,1668726986303,13.205078125 139 | 20148.0,1668726992186,13.12109375 140 | 20294.0,1668726998108,12.708984375 141 | 20440.0,1668727003907,13.298828125 142 | 20586.0,1668727009881,12.134765625 143 | 20732.0,1668727015758,12.857421875 144 | 20878.0,1668727021683,12.44921875 145 | 21024.0,1668727027577,12.931640625 146 | 21170.0,1668727033537,13.15234375 147 | 21316.0,1668727039497,12.505859375 148 | 21462.0,1668727045431,12.1484375 149 | 21608.0,1668727051179,13.0078125 150 | 21754.0,1668727057083,12.876953125 151 | 21900.0,1668727062937,12.806640625 152 | 22046.0,1668727068822,12.92578125 153 | 22192.0,1668727074650,13.0390625 154 | 22338.0,1668727080481,12.98828125 155 | 
22484.0,1668727086427,13.26171875 156 | 22630.0,1668727092241,13.197265625 157 | 22776.0,1668727098143,12.86328125 158 | 22922.0,1668727103974,12.908203125 159 | 23068.0,1668727109757,13.291015625 160 | 23214.0,1668727115639,12.90625 161 | 23360.0,1668727121612,13.12890625 162 | 23506.0,1668727127593,13.26171875 163 | 23652.0,1668727133444,13.259765625 164 | 23798.0,1668727139309,13.515625 165 | 23944.0,1668727145383,13.4921875 166 | 24090.0,1668727151326,13.51171875 167 | 24236.0,1668727157297,13.52734375 168 | 24382.0,1668727163195,13.865234375 169 | 24528.0,1668727169206,13.69140625 170 | 24674.0,1668727175068,13.32421875 171 | 24820.0,1668727180935,13.33203125 172 | 24966.0,1668727186800,13.4140625 173 | 25112.0,1668727192658,13.23828125 174 | 25258.0,1668727198503,13.634765625 175 | 25404.0,1668727204475,12.8984375 176 | 25550.0,1668727210328,13.3125 177 | 25696.0,1668727216147,13.70703125 178 | 25842.0,1668727221931,13.796875 179 | 25988.0,1668727227675,13.443359375 180 | 26134.0,1668727233400,13.623046875 181 | 26280.0,1668727239274,13.486328125 182 | 26426.0,1668727245186,13.78125 183 | 26572.0,1668727250984,13.88671875 184 | 26718.0,1668727256776,13.939453125 185 | 26864.0,1668727262592,13.734375 186 | 27010.0,1668727268466,13.75390625 187 | 27156.0,1668727274248,13.5546875 188 | 27302.0,1668727280098,13.26171875 189 | 27448.0,1668727285757,13.546875 190 | 27594.0,1668727291492,13.529296875 191 | 27740.0,1668727297352,13.740234375 192 | 27886.0,1668727303050,13.951171875 193 | 28032.0,1668727308907,13.984375 194 | 28178.0,1668727314686,13.212890625 195 | 28324.0,1668727320550,13.724609375 196 | 28470.0,1668727326383,13.73828125 197 | 28616.0,1668727332393,13.666015625 198 | 28762.0,1668727338229,13.390625 199 | 28908.0,1668727344047,14.0546875 200 | 29054.0,1668727349833,13.919921875 201 | 29200.0,1668727355880,14.05078125 202 | -------------------------------------------------------------------------------- 
/snake/data/bootstrap_outer_critic_return/MET-2195__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667994813080,0.1484375 2 | 146.0,1667994833793,0.126953125 3 | 292.0,1667994839668,0.212890625 4 | 438.0,1667994845709,1.52734375 5 | 584.0,1667994851582,3.611328125 6 | 730.0,1667994857340,3.849609375 7 | 876.0,1667994863289,4.201171875 8 | 1022.0,1667994869143,4.740234375 9 | 1168.0,1667994875019,5.052734375 10 | 1314.0,1667994880867,5.59375 11 | 1460.0,1667994886513,5.349609375 12 | 1606.0,1667994892262,5.4453125 13 | 1752.0,1667994898097,5.880859375 14 | 1898.0,1667994904052,5.958984375 15 | 2044.0,1667994909875,6.05859375 16 | 2190.0,1667994915733,6.005859375 17 | 2336.0,1667994921625,6.037109375 18 | 2482.0,1667994927434,6.42578125 19 | 2628.0,1667994933147,6.380859375 20 | 2774.0,1667994938926,6.396484375 21 | 2920.0,1667994944691,6.423828125 22 | 3066.0,1667994950507,6.54296875 23 | 3212.0,1667994956285,6.501953125 24 | 3358.0,1667994962115,6.21484375 25 | 3504.0,1667994967933,6.328125 26 | 3650.0,1667994973731,6.3828125 27 | 3796.0,1667994979591,6.353515625 28 | 3942.0,1667994985608,6.58203125 29 | 4088.0,1667994991503,6.4921875 30 | 4234.0,1667994997258,6.4375 31 | 4380.0,1667995003073,6.998046875 32 | 4526.0,1667995008857,6.40625 33 | 4672.0,1667995014681,7.193359375 34 | 4818.0,1667995020445,7.173828125 35 | 4964.0,1667995026258,7.15234375 36 | 5110.0,1667995032111,7.119140625 37 | 5256.0,1667995037843,7.732421875 38 | 5402.0,1667995043840,7.37109375 39 | 5548.0,1667995049690,7.76953125 40 | 5694.0,1667995055531,7.44140625 41 | 5840.0,1667995061415,7.83984375 42 | 5986.0,1667995067242,7.98046875 43 | 6132.0,1667995072900,8.375 44 | 6278.0,1667995078562,8.810546875 45 | 6424.0,1667995084444,8.556640625 46 | 6570.0,1667995090169,8.87109375 47 | 6716.0,1667995095998,9.251953125 48 | 6862.0,1667995101806,9.734375 49 | 7008.0,1667995107636,9.765625 50 | 
7154.0,1667995113440,9.94921875 51 | 7300.0,1667995119170,10.49609375 52 | 7446.0,1667995125049,10.69140625 53 | 7592.0,1667995130983,10.990234375 54 | 7738.0,1667995136954,11.427734375 55 | 7884.0,1667995142796,11.916015625 56 | 8030.0,1667995148612,11.9296875 57 | 8176.0,1667995154334,12.25 58 | 8322.0,1667995160132,12.615234375 59 | 8468.0,1667995165970,12.7578125 60 | 8614.0,1667995171794,12.8828125 61 | 8760.0,1667995177629,13.70703125 62 | 8906.0,1667995183430,13.87890625 63 | 9052.0,1667995189389,14.150390625 64 | 9198.0,1667995195211,14.55859375 65 | 9344.0,1667995201166,15.099609375 66 | 9490.0,1667995207160,14.92578125 67 | 9636.0,1667995213240,15.380859375 68 | 9782.0,1667995219066,15.79296875 69 | 9928.0,1667995225622,15.953125 70 | 10074.0,1667995231599,16.103515625 71 | 10220.0,1667995237550,16.302734375 72 | 10366.0,1667995243444,16.3203125 73 | 10512.0,1667995249140,16.234375 74 | 10658.0,1667995255036,16.16015625 75 | 10804.0,1667995260930,16.265625 76 | 10950.0,1667995266944,16.134765625 77 | 11096.0,1667995272956,16.94921875 78 | 11242.0,1667995278848,17.03515625 79 | 11388.0,1667995284901,17.591796875 80 | 11534.0,1667995290930,16.916015625 81 | 11680.0,1667995296926,17.4453125 82 | 11826.0,1667995302850,16.880859375 83 | 11972.0,1667995308762,17.30078125 84 | 12118.0,1667995314730,17.732421875 85 | 12264.0,1667995320661,17.912109375 86 | 12410.0,1667995326638,17.587890625 87 | 12556.0,1667995332542,18.1171875 88 | 12702.0,1667995338459,17.90234375 89 | 12848.0,1667995344430,17.927734375 90 | 12994.0,1667995350351,17.6328125 91 | 13140.0,1667995356317,17.9453125 92 | 13286.0,1667995362287,18.291015625 93 | 13432.0,1667995368168,17.65234375 94 | 13578.0,1667995374071,18.21875 95 | 13724.0,1667995380123,18.28125 96 | 13870.0,1667995386023,18.1875 97 | 14016.0,1667995391939,18.130859375 98 | 14162.0,1667995397974,18.974609375 99 | 14308.0,1667995403885,18.6171875 100 | 14454.0,1667995409772,18.75 101 | 14600.0,1667995415753,18.625 102 | 
14746.0,1667995421752,18.876953125 103 | 14892.0,1667995427663,19.267578125 104 | 15038.0,1667995433605,19.34765625 105 | 15184.0,1667995439535,19.41796875 106 | 15330.0,1667995445407,19.83984375 107 | 15476.0,1667995451416,19.79296875 108 | 15622.0,1667995457167,20.18359375 109 | 15768.0,1667995462998,19.564453125 110 | 15914.0,1667995468767,20.078125 111 | 16060.0,1667995474635,20.568359375 112 | 16206.0,1667995480424,20.16015625 113 | 16352.0,1667995486275,20.755859375 114 | 16498.0,1667995492313,20.703125 115 | 16644.0,1667995498120,20.38671875 116 | 16790.0,1667995504181,21.5703125 117 | 16936.0,1667995510054,20.720703125 118 | 17082.0,1667995515917,21.12890625 119 | 17228.0,1667995521742,21.681640625 120 | 17374.0,1667995527611,22.0 121 | 17520.0,1667995533541,22.8828125 122 | 17666.0,1667995539381,22.4921875 123 | 17812.0,1667995545272,22.68359375 124 | 17958.0,1667995551131,23.044921875 125 | 18104.0,1667995556973,23.62890625 126 | 18250.0,1667995562927,24.1796875 127 | 18396.0,1667995568679,24.24609375 128 | 18542.0,1667995574573,24.10546875 129 | 18688.0,1667995580510,24.322265625 130 | 18834.0,1667995586366,25.462890625 131 | 18980.0,1667995592325,25.7265625 132 | 19126.0,1667995598196,27.662109375 133 | 19272.0,1667995604231,27.09765625 134 | 19418.0,1667995610194,28.6640625 135 | 19564.0,1667995616033,29.3046875 136 | 19710.0,1667995621869,30.00390625 137 | 19856.0,1667995627806,29.5546875 138 | 20002.0,1667995633722,29.63671875 139 | 20148.0,1667995639557,30.845703125 140 | 20294.0,1667995645636,30.8671875 141 | 20440.0,1667995651476,30.66015625 142 | 20586.0,1667995657363,31.708984375 143 | 20732.0,1667995663219,32.171875 144 | 20878.0,1667995668993,32.44140625 145 | 21024.0,1667995674930,31.77734375 146 | 21170.0,1667995680800,33.005859375 147 | 21316.0,1667995686902,32.884765625 148 | 21462.0,1667995692776,33.498046875 149 | 21608.0,1667995698724,33.087890625 150 | 21754.0,1667995704785,33.486328125 151 | 21900.0,1667995710755,33.15234375 152 | 
22046.0,1667995716719,33.376953125 153 | 22192.0,1667995722735,33.4609375 154 | 22338.0,1667995728700,33.8828125 155 | 22484.0,1667995734668,33.349609375 156 | 22630.0,1667995740594,33.75390625 157 | 22776.0,1667995746579,33.111328125 158 | 22922.0,1667995752950,32.490234375 159 | 23068.0,1667995759828,33.837890625 160 | 23214.0,1667995766824,33.9296875 161 | 23360.0,1667995773799,33.63671875 162 | 23506.0,1667995780744,33.52734375 163 | 23652.0,1667995787983,33.5546875 164 | 23798.0,1667995794921,33.697265625 165 | 23944.0,1667995801292,34.21875 166 | 24090.0,1667995807280,33.951171875 167 | 24236.0,1667995813292,33.677734375 168 | 24382.0,1667995819288,33.859375 169 | 24528.0,1667995825341,33.853515625 170 | 24674.0,1667995831300,33.537109375 171 | 24820.0,1667995837385,34.171875 172 | 24966.0,1667995843389,34.580078125 173 | 25112.0,1667995849396,34.0234375 174 | 25258.0,1667995855246,33.9765625 175 | 25404.0,1667995861101,33.689453125 176 | 25550.0,1667995867022,33.69921875 177 | 25696.0,1667995872920,34.1171875 178 | 25842.0,1667995878835,33.640625 179 | 25988.0,1667995884899,32.494140625 180 | 26134.0,1667995890730,33.984375 181 | 26280.0,1667995896584,34.189453125 182 | 26426.0,1667995902459,33.87109375 183 | 26572.0,1667995908499,33.392578125 184 | 26718.0,1667995914492,33.392578125 185 | 26864.0,1667995920445,33.568359375 186 | 27010.0,1667995926292,34.341796875 187 | 27156.0,1667995932242,33.759765625 188 | 27302.0,1667995938174,34.267578125 189 | 27448.0,1667995944157,34.447265625 190 | 27594.0,1667995949918,33.857421875 191 | 27740.0,1667995955812,33.857421875 192 | 27886.0,1667995961815,34.212890625 193 | 28032.0,1667995967779,33.83203125 194 | 28178.0,1667995973678,33.677734375 195 | 28324.0,1667995979587,34.337890625 196 | 28470.0,1667995985513,33.505859375 197 | 28616.0,1667995991466,34.1875 198 | 28762.0,1667995997360,33.986328125 199 | 28908.0,1667996003288,34.248046875 200 | 29054.0,1667996009218,34.5 201 | 29200.0,1667996015216,34.625 202 | 
-------------------------------------------------------------------------------- /snake/data/bootstrap_return/MET-2166__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667988472883,0.158203125 2 | 146.0,1667988491997,0.150390625 3 | 292.0,1667988497527,0.228515625 4 | 438.0,1667988503102,1.978515625 5 | 584.0,1667988508577,3.443359375 6 | 730.0,1667988514213,3.65234375 7 | 876.0,1667988519766,4.08203125 8 | 1022.0,1667988525253,4.197265625 9 | 1168.0,1667988530896,4.51953125 10 | 1314.0,1667988536374,4.685546875 11 | 1460.0,1667988542058,4.74609375 12 | 1606.0,1667988547740,5.15625 13 | 1752.0,1667988553696,4.541015625 14 | 1898.0,1667988559209,4.966796875 15 | 2044.0,1667988564737,4.95703125 16 | 2190.0,1667988570263,5.501953125 17 | 2336.0,1667988575738,5.712890625 18 | 2482.0,1667988581225,5.7421875 19 | 2628.0,1667988586661,5.6640625 20 | 2774.0,1667988592131,6.125 21 | 2920.0,1667988597730,6.025390625 22 | 3066.0,1667988603340,6.017578125 23 | 3212.0,1667988608949,6.205078125 24 | 3358.0,1667988614494,6.4375 25 | 3504.0,1667988619953,6.5625 26 | 3650.0,1667988625421,6.8125 27 | 3796.0,1667988630924,6.568359375 28 | 3942.0,1667988636336,6.466796875 29 | 4088.0,1667988641750,6.498046875 30 | 4234.0,1667988647158,6.625 31 | 4380.0,1667988652654,6.708984375 32 | 4526.0,1667988658121,7.296875 33 | 4672.0,1667988663742,7.40234375 34 | 4818.0,1667988669165,7.515625 35 | 4964.0,1667988674675,7.49609375 36 | 5110.0,1667988680164,7.59375 37 | 5256.0,1667988685636,7.58203125 38 | 5402.0,1667988691076,7.779296875 39 | 5548.0,1667988696603,8.201171875 40 | 5694.0,1667988702133,8.20703125 41 | 5840.0,1667988707728,8.18359375 42 | 5986.0,1667988713172,8.697265625 43 | 6132.0,1667988718647,8.4609375 44 | 6278.0,1667988724117,9.16015625 45 | 6424.0,1667988729721,8.8828125 46 | 6570.0,1667988735243,9.30078125 47 | 6716.0,1667988740826,9.080078125 48 | 6862.0,1667988746237,9.6171875 49 | 
7008.0,1667988751634,9.609375 50 | 7154.0,1667988757093,9.990234375 51 | 7300.0,1667988762770,9.978515625 52 | 7446.0,1667988768278,10.06640625 53 | 7592.0,1667988773813,10.177734375 54 | 7738.0,1667988779365,9.74609375 55 | 7884.0,1667988784885,9.9765625 56 | 8030.0,1667988790435,9.99609375 57 | 8176.0,1667988796066,9.94140625 58 | 8322.0,1667988801546,10.048828125 59 | 8468.0,1667988807276,9.818359375 60 | 8614.0,1667988812786,9.896484375 61 | 8760.0,1667988818223,10.158203125 62 | 8906.0,1667988823737,10.36328125 63 | 9052.0,1667988829212,9.873046875 64 | 9198.0,1667988834723,9.87890625 65 | 9344.0,1667988840216,10.15625 66 | 9490.0,1667988845697,9.58203125 67 | 9636.0,1667988851268,9.853515625 68 | 9782.0,1667988856677,9.7890625 69 | 9928.0,1667988862281,9.94140625 70 | 10074.0,1667988867854,10.06640625 71 | 10220.0,1667988873353,9.93359375 72 | 10366.0,1667988878806,9.9296875 73 | 10512.0,1667988884313,9.3515625 74 | 10658.0,1667988889787,9.3671875 75 | 10804.0,1667988895266,9.541015625 76 | 10950.0,1667988900847,9.173828125 77 | 11096.0,1667988906491,9.400390625 78 | 11242.0,1667988911988,9.33984375 79 | 11388.0,1667988917483,9.18359375 80 | 11534.0,1667988922955,9.1953125 81 | 11680.0,1667988928562,9.0 82 | 11826.0,1667988934090,9.076171875 83 | 11972.0,1667988939646,8.666015625 84 | 12118.0,1667988945037,9.283203125 85 | 12264.0,1667988950624,8.849609375 86 | 12410.0,1667988956106,8.712890625 87 | 12556.0,1667988961631,8.818359375 88 | 12702.0,1667988967223,9.005859375 89 | 12848.0,1667988973075,8.93359375 90 | 12994.0,1667988978559,8.890625 91 | 13140.0,1667988984243,9.1015625 92 | 13286.0,1667988989794,8.763671875 93 | 13432.0,1667988995424,8.68359375 94 | 13578.0,1667989001056,8.765625 95 | 13724.0,1667989006550,8.580078125 96 | 13870.0,1667989012109,8.181640625 97 | 14016.0,1667989017569,7.755859375 98 | 14162.0,1667989023090,8.154296875 99 | 14308.0,1667989028609,7.970703125 100 | 14454.0,1667989034207,7.896484375 101 | 
14600.0,1667989039737,7.865234375 102 | 14746.0,1667989045202,7.986328125 103 | 14892.0,1667989050731,7.5234375 104 | 15038.0,1667989056168,7.685546875 105 | 15184.0,1667989061724,7.140625 106 | 15330.0,1667989067337,7.57421875 107 | 15476.0,1667989072853,7.130859375 108 | 15622.0,1667989078338,7.033203125 109 | 15768.0,1667989084104,7.19140625 110 | 15914.0,1667989089604,7.482421875 111 | 16060.0,1667989095122,6.5703125 112 | 16206.0,1667989100730,7.103515625 113 | 16352.0,1667989106342,7.1015625 114 | 16498.0,1667989111813,6.470703125 115 | 16644.0,1667989117288,6.80078125 116 | 16790.0,1667989122829,6.41015625 117 | 16936.0,1667989128349,6.74609375 118 | 17082.0,1667989133816,6.134765625 119 | 17228.0,1667989139393,6.515625 120 | 17374.0,1667989144943,6.34765625 121 | 17520.0,1667989150451,6.1328125 122 | 17666.0,1667989155881,5.939453125 123 | 17812.0,1667989161596,5.95703125 124 | 17958.0,1667989167227,5.798828125 125 | 18104.0,1667989172709,6.5390625 126 | 18250.0,1667989178376,6.373046875 127 | 18396.0,1667989183921,6.25390625 128 | 18542.0,1667989189437,5.814453125 129 | 18688.0,1667989195024,5.7890625 130 | 18834.0,1667989200615,5.76953125 131 | 18980.0,1667989206254,6.01171875 132 | 19126.0,1667989211877,5.70703125 133 | 19272.0,1667989217439,5.865234375 134 | 19418.0,1667989222961,6.25390625 135 | 19564.0,1667989228552,5.8125 136 | 19710.0,1667989233963,5.8828125 137 | 19856.0,1667989239593,6.177734375 138 | 20002.0,1667989245103,6.26953125 139 | 20148.0,1667989250840,5.724609375 140 | 20294.0,1667989256420,5.44921875 141 | 20440.0,1667989262128,6.06640625 142 | 20586.0,1667989267722,6.20703125 143 | 20732.0,1667989273464,6.068359375 144 | 20878.0,1667989279148,6.0 145 | 21024.0,1667989284818,5.76953125 146 | 21170.0,1667989290358,5.859375 147 | 21316.0,1667989296065,5.580078125 148 | 21462.0,1667989301627,5.5859375 149 | 21608.0,1667989307144,5.900390625 150 | 21754.0,1667989312782,5.625 151 | 21900.0,1667989318230,6.21484375 152 | 
22046.0,1667989323876,6.861328125 153 | 22192.0,1667989329537,6.830078125 154 | 22338.0,1667989335045,7.0390625 155 | 22484.0,1667989340555,7.083984375 156 | 22630.0,1667989346356,7.009765625 157 | 22776.0,1667989351924,6.953125 158 | 22922.0,1667989357409,7.044921875 159 | 23068.0,1667989363104,6.93359375 160 | 23214.0,1667989368709,6.919921875 161 | 23360.0,1667989374276,6.32421875 162 | 23506.0,1667989379790,6.232421875 163 | 23652.0,1667989385458,6.30078125 164 | 23798.0,1667989391103,6.37890625 165 | 23944.0,1667989396680,6.1640625 166 | 24090.0,1667989402418,6.001953125 167 | 24236.0,1667989407991,5.697265625 168 | 24382.0,1667989413623,5.623046875 169 | 24528.0,1667989419143,5.5859375 170 | 24674.0,1667989424840,5.666015625 171 | 24820.0,1667989430399,5.22265625 172 | 24966.0,1667989435963,5.513671875 173 | 25112.0,1667989441769,5.283203125 174 | 25258.0,1667989447356,5.384765625 175 | 25404.0,1667989452914,5.4921875 176 | 25550.0,1667989458435,5.556640625 177 | 25696.0,1667989463979,6.072265625 178 | 25842.0,1667989469541,5.740234375 179 | 25988.0,1667989474971,6.935546875 180 | 26134.0,1667989480532,5.990234375 181 | 26280.0,1667989486146,5.98046875 182 | 26426.0,1667989491672,6.02734375 183 | 26572.0,1667989497182,5.798828125 184 | 26718.0,1667989502845,5.71875 185 | 26864.0,1667989508428,5.279296875 186 | 27010.0,1667989514035,5.85546875 187 | 27156.0,1667989519472,5.970703125 188 | 27302.0,1667989525010,5.8671875 189 | 27448.0,1667989530454,5.63671875 190 | 27594.0,1667989535756,6.10546875 191 | 27740.0,1667989541487,6.11328125 192 | 27886.0,1667989546947,5.62890625 193 | 28032.0,1667989552507,5.939453125 194 | 28178.0,1667989558083,6.1796875 195 | 28324.0,1667989563702,5.787109375 196 | 28470.0,1667989569260,5.765625 197 | 28616.0,1667989574740,5.84765625 198 | 28762.0,1667989580270,5.916015625 199 | 28908.0,1667989585951,5.83984375 200 | 29054.0,1667989591436,5.615234375 201 | 29200.0,1667989596990,5.74609375 202 | 
-------------------------------------------------------------------------------- /snake/data/bootstrap_return/MET-2172__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667989606621,0.166015625 2 | 146.0,1667989625138,0.158203125 3 | 292.0,1667989630579,0.24609375 4 | 438.0,1667989635914,1.88671875 5 | 584.0,1667989641271,3.87890625 6 | 730.0,1667989646730,4.14453125 7 | 876.0,1667989652025,4.419921875 8 | 1022.0,1667989657482,4.892578125 9 | 1168.0,1667989662785,5.154296875 10 | 1314.0,1667989668110,5.3125 11 | 1460.0,1667989673459,5.236328125 12 | 1606.0,1667989678906,5.49609375 13 | 1752.0,1667989684313,5.83203125 14 | 1898.0,1667989689652,5.89453125 15 | 2044.0,1667989694912,6.21484375 16 | 2190.0,1667989700263,6.6875 17 | 2336.0,1667989705752,6.501953125 18 | 2482.0,1667989711098,6.703125 19 | 2628.0,1667989716563,7.486328125 20 | 2774.0,1667989721923,8.080078125 21 | 2920.0,1667989727393,8.16015625 22 | 3066.0,1667989732774,8.640625 23 | 3212.0,1667989738178,9.16015625 24 | 3358.0,1667989743588,9.669921875 25 | 3504.0,1667989748982,9.61328125 26 | 3650.0,1667989754313,10.1484375 27 | 3796.0,1667989759652,10.625 28 | 3942.0,1667989765050,9.859375 29 | 4088.0,1667989770346,11.04296875 30 | 4234.0,1667989775754,11.564453125 31 | 4380.0,1667989781054,11.720703125 32 | 4526.0,1667989786426,11.671875 33 | 4672.0,1667989791720,12.06640625 34 | 4818.0,1667989797083,12.37109375 35 | 4964.0,1667989802727,12.7421875 36 | 5110.0,1667989808253,12.75 37 | 5256.0,1667989813739,12.67578125 38 | 5402.0,1667989819118,12.6171875 39 | 5548.0,1667989824623,13.6484375 40 | 5694.0,1667989830006,12.931640625 41 | 5840.0,1667989835450,13.296875 42 | 5986.0,1667989840833,13.810546875 43 | 6132.0,1667989846274,13.958984375 44 | 6278.0,1667989851593,13.72265625 45 | 6424.0,1667989857027,14.03515625 46 | 6570.0,1667989862412,14.58203125 47 | 6716.0,1667989867866,14.72265625 48 | 
6862.0,1667989873262,14.265625 49 | 7008.0,1667989878649,14.60546875 50 | 7154.0,1667989884117,14.341796875 51 | 7300.0,1667989889449,14.833984375 52 | 7446.0,1667989894976,14.98828125 53 | 7592.0,1667989900818,15.12109375 54 | 7738.0,1667989906292,15.158203125 55 | 7884.0,1667989911704,15.8828125 56 | 8030.0,1667989917173,15.306640625 57 | 8176.0,1667989922585,15.513671875 58 | 8322.0,1667989928011,15.248046875 59 | 8468.0,1667989933387,15.78125 60 | 8614.0,1667989938798,15.453125 61 | 8760.0,1667989944248,15.92578125 62 | 8906.0,1667989949684,15.60546875 63 | 9052.0,1667989955094,16.140625 64 | 9198.0,1667989960539,16.1171875 65 | 9344.0,1667989965889,15.857421875 66 | 9490.0,1667989971366,16.23046875 67 | 9636.0,1667989976720,16.150390625 68 | 9782.0,1667989982204,16.203125 69 | 9928.0,1667989987803,16.169921875 70 | 10074.0,1667989993088,16.033203125 71 | 10220.0,1667989998467,16.390625 72 | 10366.0,1667990003907,16.013671875 73 | 10512.0,1667990009300,16.060546875 74 | 10658.0,1667990014707,16.3828125 75 | 10804.0,1667990020076,16.646484375 76 | 10950.0,1667990025489,16.203125 77 | 11096.0,1667990030793,16.787109375 78 | 11242.0,1667990036211,16.95703125 79 | 11388.0,1667990041631,16.41015625 80 | 11534.0,1667990047186,16.92578125 81 | 11680.0,1667990052685,16.951171875 82 | 11826.0,1667990058039,16.4765625 83 | 11972.0,1667990063356,16.916015625 84 | 12118.0,1667990068752,16.923828125 85 | 12264.0,1667990074197,16.873046875 86 | 12410.0,1667990079543,17.0390625 87 | 12556.0,1667990084850,17.0546875 88 | 12702.0,1667990090240,17.25 89 | 12848.0,1667990095578,16.765625 90 | 12994.0,1667990101108,17.25390625 91 | 13140.0,1667990106533,17.310546875 92 | 13286.0,1667990111870,17.41015625 93 | 13432.0,1667990117274,17.42578125 94 | 13578.0,1667990122597,17.07421875 95 | 13724.0,1667990127980,17.373046875 96 | 13870.0,1667990133290,17.4765625 97 | 14016.0,1667990138604,17.44921875 98 | 14162.0,1667990144034,17.310546875 99 | 14308.0,1667990149408,17.2734375 100 | 
14454.0,1667990154736,17.916015625 101 | 14600.0,1667990160076,17.416015625 102 | 14746.0,1667990165533,17.705078125 103 | 14892.0,1667990170829,17.7421875 104 | 15038.0,1667990176217,17.505859375 105 | 15184.0,1667990181580,17.5234375 106 | 15330.0,1667990186930,17.0703125 107 | 15476.0,1667990192323,18.234375 108 | 15622.0,1667990197788,17.72265625 109 | 15768.0,1667990203161,17.671875 110 | 15914.0,1667990208603,17.828125 111 | 16060.0,1667990213950,17.927734375 112 | 16206.0,1667990219250,17.9765625 113 | 16352.0,1667990224932,18.119140625 114 | 16498.0,1667990230243,18.474609375 115 | 16644.0,1667990235568,18.05078125 116 | 16790.0,1667990240902,18.7421875 117 | 16936.0,1667990246305,18.50390625 118 | 17082.0,1667990251584,18.51953125 119 | 17228.0,1667990257138,18.576171875 120 | 17374.0,1667990262554,18.330078125 121 | 17520.0,1667990267910,18.796875 122 | 17666.0,1667990273304,18.43359375 123 | 17812.0,1667990278692,18.5625 124 | 17958.0,1667990284065,18.7109375 125 | 18104.0,1667990289432,18.677734375 126 | 18250.0,1667990294816,18.4140625 127 | 18396.0,1667990300234,18.1484375 128 | 18542.0,1667990305654,17.92578125 129 | 18688.0,1667990311055,18.50390625 130 | 18834.0,1667990316605,17.794921875 131 | 18980.0,1667990321929,18.5 132 | 19126.0,1667990327356,18.7109375 133 | 19272.0,1667990332764,17.62109375 134 | 19418.0,1667990338212,18.712890625 135 | 19564.0,1667990343759,18.46875 136 | 19710.0,1667990349210,18.11328125 137 | 19856.0,1667990354619,18.712890625 138 | 20002.0,1667990359979,18.791015625 139 | 20148.0,1667990365439,18.859375 140 | 20294.0,1667990370813,18.552734375 141 | 20440.0,1667990376304,18.34765625 142 | 20586.0,1667990381686,18.08203125 143 | 20732.0,1667990387051,17.908203125 144 | 20878.0,1667990392411,18.294921875 145 | 21024.0,1667990397789,18.166015625 146 | 21170.0,1667990403297,18.359375 147 | 21316.0,1667990408685,18.369140625 148 | 21462.0,1667990414183,18.53515625 149 | 21608.0,1667990419459,18.322265625 150 | 
21754.0,1667990424928,17.890625 151 | 21900.0,1667990430289,18.515625 152 | 22046.0,1667990435639,17.994140625 153 | 22192.0,1667990441001,18.34765625 154 | 22338.0,1667990446412,18.138671875 155 | 22484.0,1667990451711,18.443359375 156 | 22630.0,1667990457136,18.4453125 157 | 22776.0,1667990462520,18.904296875 158 | 22922.0,1667990467963,18.912109375 159 | 23068.0,1667990473357,18.771484375 160 | 23214.0,1667990478959,18.89453125 161 | 23360.0,1667990484434,18.359375 162 | 23506.0,1667990489837,17.98828125 163 | 23652.0,1667990495244,18.228515625 164 | 23798.0,1667990500609,18.583984375 165 | 23944.0,1667990505954,18.564453125 166 | 24090.0,1667990511297,18.83203125 167 | 24236.0,1667990516630,18.681640625 168 | 24382.0,1667990522000,18.509765625 169 | 24528.0,1667990527382,18.39453125 170 | 24674.0,1667990532768,17.955078125 171 | 24820.0,1667990538141,18.390625 172 | 24966.0,1667990543530,18.7109375 173 | 25112.0,1667990548883,18.572265625 174 | 25258.0,1667990554369,18.26953125 175 | 25404.0,1667990559688,18.263671875 176 | 25550.0,1667990565010,18.841796875 177 | 25696.0,1667990570410,18.47265625 178 | 25842.0,1667990575764,18.6015625 179 | 25988.0,1667990581064,18.552734375 180 | 26134.0,1667990586521,18.556640625 181 | 26280.0,1667990591880,18.966796875 182 | 26426.0,1667990597346,18.970703125 183 | 26572.0,1667990602702,19.12109375 184 | 26718.0,1667990608107,19.291015625 185 | 26864.0,1667990613428,19.15234375 186 | 27010.0,1667990618771,19.134765625 187 | 27156.0,1667990624143,19.447265625 188 | 27302.0,1667990629471,19.724609375 189 | 27448.0,1667990634988,19.34765625 190 | 27594.0,1667990640307,19.568359375 191 | 27740.0,1667990645789,19.40234375 192 | 27886.0,1667990651073,19.40625 193 | 28032.0,1667990656348,19.400390625 194 | 28178.0,1667990661691,19.5078125 195 | 28324.0,1667990667109,19.77734375 196 | 28470.0,1667990672484,20.2421875 197 | 28616.0,1667990677866,20.83984375 198 | 28762.0,1667990683181,20.037109375 199 | 
28908.0,1667990688545,20.05078125 200 | 29054.0,1667990693925,20.044921875 201 | 29200.0,1667990699295,20.578125 202 | -------------------------------------------------------------------------------- /snake/data/bootstrap_return/MET-2178__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667990709114,0.134765625 2 | 146.0,1667990728218,0.15625 3 | 292.0,1667990733655,0.212890625 4 | 438.0,1667990739155,2.224609375 5 | 584.0,1667990744637,3.521484375 6 | 730.0,1667990750032,3.59375 7 | 876.0,1667990755436,4.0546875 8 | 1022.0,1667990760943,4.29296875 9 | 1168.0,1667990766410,4.701171875 10 | 1314.0,1667990771839,4.751953125 11 | 1460.0,1667990777430,5.1953125 12 | 1606.0,1667990782974,5.21484375 13 | 1752.0,1667990788579,5.82421875 14 | 1898.0,1667990793918,5.951171875 15 | 2044.0,1667990799510,5.734375 16 | 2190.0,1667990805018,5.76953125 17 | 2336.0,1667990810468,5.833984375 18 | 2482.0,1667990815847,6.12890625 19 | 2628.0,1667990821318,6.1328125 20 | 2774.0,1667990826836,6.345703125 21 | 2920.0,1667990832242,6.48828125 22 | 3066.0,1667990837686,6.419921875 23 | 3212.0,1667990843047,5.818359375 24 | 3358.0,1667990848501,6.296875 25 | 3504.0,1667990853924,6.544921875 26 | 3650.0,1667990859443,6.27734375 27 | 3796.0,1667990864846,6.6171875 28 | 3942.0,1667990870252,6.568359375 29 | 4088.0,1667990875790,6.771484375 30 | 4234.0,1667990881343,6.48046875 31 | 4380.0,1667990886799,6.73828125 32 | 4526.0,1667990892201,6.810546875 33 | 4672.0,1667990897621,6.271484375 34 | 4818.0,1667990903101,6.693359375 35 | 4964.0,1667990908580,7.09375 36 | 5110.0,1667990914054,7.0390625 37 | 5256.0,1667990919710,7.00390625 38 | 5402.0,1667990925180,7.4296875 39 | 5548.0,1667990930729,6.97265625 40 | 5694.0,1667990936280,7.548828125 41 | 5840.0,1667990941859,7.072265625 42 | 5986.0,1667990947416,7.173828125 43 | 6132.0,1667990952860,7.724609375 44 | 6278.0,1667990958277,7.91796875 45 | 
6424.0,1667990963751,8.001953125 46 | 6570.0,1667990969279,8.59375 47 | 6716.0,1667990974829,8.65234375 48 | 6862.0,1667990980398,8.609375 49 | 7008.0,1667990985796,9.376953125 50 | 7154.0,1667990991411,9.478515625 51 | 7300.0,1667990996917,9.7265625 52 | 7446.0,1667991002575,10.205078125 53 | 7592.0,1667991008252,10.490234375 54 | 7738.0,1667991013757,10.724609375 55 | 7884.0,1667991019217,10.728515625 56 | 8030.0,1667991024852,11.078125 57 | 8176.0,1667991030410,11.43359375 58 | 8322.0,1667991036027,11.591796875 59 | 8468.0,1667991041592,12.046875 60 | 8614.0,1667991047137,11.982421875 61 | 8760.0,1667991052676,12.59765625 62 | 8906.0,1667991058115,13.041015625 63 | 9052.0,1667991063690,13.23046875 64 | 9198.0,1667991069295,13.60546875 65 | 9344.0,1667991074715,13.361328125 66 | 9490.0,1667991080245,14.01171875 67 | 9636.0,1667991085823,14.09375 68 | 9782.0,1667991091332,14.275390625 69 | 9928.0,1667991096912,14.923828125 70 | 10074.0,1667991102434,15.0703125 71 | 10220.0,1667991107992,14.986328125 72 | 10366.0,1667991113373,15.771484375 73 | 10512.0,1667991118839,15.19140625 74 | 10658.0,1667991124317,15.9921875 75 | 10804.0,1667991129912,15.80859375 76 | 10950.0,1667991135585,15.724609375 77 | 11096.0,1667991141285,15.759765625 78 | 11242.0,1667991146762,16.119140625 79 | 11388.0,1667991152253,15.916015625 80 | 11534.0,1667991157783,16.732421875 81 | 11680.0,1667991163267,16.498046875 82 | 11826.0,1667991168697,16.77734375 83 | 11972.0,1667991174091,16.265625 84 | 12118.0,1667991179539,16.978515625 85 | 12264.0,1667991184980,16.423828125 86 | 12410.0,1667991190511,16.82421875 87 | 12556.0,1667991196093,16.6015625 88 | 12702.0,1667991201558,17.12109375 89 | 12848.0,1667991207103,17.15234375 90 | 12994.0,1667991212643,16.828125 91 | 13140.0,1667991218271,16.716796875 92 | 13286.0,1667991223695,17.44140625 93 | 13432.0,1667991229269,17.564453125 94 | 13578.0,1667991234762,17.58203125 95 | 13724.0,1667991240403,18.15625 96 | 13870.0,1667991245965,17.853515625 97 | 
14016.0,1667991251604,17.216796875 98 | 14162.0,1667991257177,17.978515625 99 | 14308.0,1667991262719,17.72265625 100 | 14454.0,1667991268211,17.783203125 101 | 14600.0,1667991273684,17.9765625 102 | 14746.0,1667991279325,17.55859375 103 | 14892.0,1667991284851,18.138671875 104 | 15038.0,1667991290372,18.337890625 105 | 15184.0,1667991296126,18.07421875 106 | 15330.0,1667991301771,18.169921875 107 | 15476.0,1667991307475,18.259765625 108 | 15622.0,1667991313059,18.646484375 109 | 15768.0,1667991318667,18.333984375 110 | 15914.0,1667991324142,18.640625 111 | 16060.0,1667991329767,18.84765625 112 | 16206.0,1667991335359,18.3671875 113 | 16352.0,1667991341032,19.107421875 114 | 16498.0,1667991346604,19.548828125 115 | 16644.0,1667991352162,19.611328125 116 | 16790.0,1667991357648,19.376953125 117 | 16936.0,1667991363082,18.0859375 118 | 17082.0,1667991368620,19.458984375 119 | 17228.0,1667991374171,19.623046875 120 | 17374.0,1667991379602,19.439453125 121 | 17520.0,1667991385127,19.6171875 122 | 17666.0,1667991390663,19.78515625 123 | 17812.0,1667991396204,19.798828125 124 | 17958.0,1667991401745,20.029296875 125 | 18104.0,1667991407312,20.541015625 126 | 18250.0,1667991412797,20.7578125 127 | 18396.0,1667991418374,20.80078125 128 | 18542.0,1667991423977,20.82421875 129 | 18688.0,1667991429571,21.083984375 130 | 18834.0,1667991435099,20.439453125 131 | 18980.0,1667991440712,20.69921875 132 | 19126.0,1667991446769,21.390625 133 | 19272.0,1667991452318,21.98828125 134 | 19418.0,1667991457895,22.34765625 135 | 19564.0,1667991463328,21.798828125 136 | 19710.0,1667991468848,23.21484375 137 | 19856.0,1667991474295,23.150390625 138 | 20002.0,1667991479883,24.0390625 139 | 20148.0,1667991485393,23.26953125 140 | 20294.0,1667991490904,23.0625 141 | 20440.0,1667991496376,24.2890625 142 | 20586.0,1667991501872,25.779296875 143 | 20732.0,1667991507603,26.154296875 144 | 20878.0,1667991513032,27.578125 145 | 21024.0,1667991518584,26.671875 146 | 21170.0,1667991524019,27.03125 147 
| 21316.0,1667991529689,28.181640625 148 | 21462.0,1667991535161,28.251953125 149 | 21608.0,1667991540741,29.208984375 150 | 21754.0,1667991546338,29.025390625 151 | 21900.0,1667991551845,29.63671875 152 | 22046.0,1667991557403,28.42578125 153 | 22192.0,1667991563021,30.96875 154 | 22338.0,1667991568662,30.7890625 155 | 22484.0,1667991574127,30.58984375 156 | 22630.0,1667991579822,32.3046875 157 | 22776.0,1667991585233,30.9609375 158 | 22922.0,1667991590737,31.330078125 159 | 23068.0,1667991596370,31.138671875 160 | 23214.0,1667991602144,31.279296875 161 | 23360.0,1667991608013,33.0703125 162 | 23506.0,1667991613533,32.494140625 163 | 23652.0,1667991619103,32.30859375 164 | 23798.0,1667991624785,33.203125 165 | 23944.0,1667991630381,32.912109375 166 | 24090.0,1667991635977,32.734375 167 | 24236.0,1667991641615,32.07421875 168 | 24382.0,1667991647176,33.0703125 169 | 24528.0,1667991652850,33.375 170 | 24674.0,1667991658481,34.138671875 171 | 24820.0,1667991664195,33.275390625 172 | 24966.0,1667991669944,33.818359375 173 | 25112.0,1667991675499,33.5390625 174 | 25258.0,1667991681230,34.0703125 175 | 25404.0,1667991686761,34.05078125 176 | 25550.0,1667991692367,33.75390625 177 | 25696.0,1667991698058,34.15234375 178 | 25842.0,1667991703681,34.3671875 179 | 25988.0,1667991709379,32.48828125 180 | 26134.0,1667991714817,34.2734375 181 | 26280.0,1667991720625,33.94921875 182 | 26426.0,1667991726310,34.23828125 183 | 26572.0,1667991731858,34.27734375 184 | 26718.0,1667991737348,34.228515625 185 | 26864.0,1667991742833,33.93359375 186 | 27010.0,1667991748282,34.19921875 187 | 27156.0,1667991753795,34.3359375 188 | 27302.0,1667991759494,34.28515625 189 | 27448.0,1667991765049,34.314453125 190 | 27594.0,1667991770668,34.0859375 191 | 27740.0,1667991776053,33.642578125 192 | 27886.0,1667991781701,34.37890625 193 | 28032.0,1667991787293,34.2265625 194 | 28178.0,1667991792781,34.490234375 195 | 28324.0,1667991798294,34.244140625 196 | 28470.0,1667991803821,33.763671875 197 | 
28616.0,1667991809269,34.328125 198 | 28762.0,1667991814799,34.66015625 199 | 28908.0,1667991820279,34.59375 200 | 29054.0,1667991825799,34.3125 201 | 29200.0,1667991831400,34.734375 202 | -------------------------------------------------------------------------------- /snake/data/bootstrap_return/MET-2188__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667992944929,0.1640625 2 | 146.0,1667992962921,0.177734375 3 | 292.0,1667992968277,0.212890625 4 | 438.0,1667992973585,0.853515625 5 | 584.0,1667992978985,3.140625 6 | 730.0,1667992984488,3.951171875 7 | 876.0,1667992989743,3.943359375 8 | 1022.0,1667992995036,3.9921875 9 | 1168.0,1667993000415,4.107421875 10 | 1314.0,1667993005795,4.10546875 11 | 1460.0,1667993011041,4.486328125 12 | 1606.0,1667993016447,4.568359375 13 | 1752.0,1667993021790,4.794921875 14 | 1898.0,1667993027068,4.6953125 15 | 2044.0,1667993032371,5.21875 16 | 2190.0,1667993037661,5.421875 17 | 2336.0,1667993043028,5.44140625 18 | 2482.0,1667993048329,5.345703125 19 | 2628.0,1667993053729,5.482421875 20 | 2774.0,1667993059113,5.658203125 21 | 2920.0,1667993064490,5.724609375 22 | 3066.0,1667993069782,5.28125 23 | 3212.0,1667993075172,5.521484375 24 | 3358.0,1667993080573,5.5078125 25 | 3504.0,1667993085962,5.314453125 26 | 3650.0,1667993091364,6.2421875 27 | 3796.0,1667993096699,5.384765625 28 | 3942.0,1667993102169,5.392578125 29 | 4088.0,1667993107501,5.728515625 30 | 4234.0,1667993112871,6.080078125 31 | 4380.0,1667993118240,5.908203125 32 | 4526.0,1667993123578,5.869140625 33 | 4672.0,1667993128847,6.044921875 34 | 4818.0,1667993134243,6.52734375 35 | 4964.0,1667993139541,6.26953125 36 | 5110.0,1667993144899,6.23828125 37 | 5256.0,1667993150286,5.7734375 38 | 5402.0,1667993155648,6.55078125 39 | 5548.0,1667993161015,6.326171875 40 | 5694.0,1667993166460,6.69921875 41 | 5840.0,1667993171809,6.3828125 42 | 5986.0,1667993177159,6.11328125 43 | 
6132.0,1667993182520,6.4765625 44 | 6278.0,1667993187944,6.36328125 45 | 6424.0,1667993193190,6.73046875 46 | 6570.0,1667993198483,6.939453125 47 | 6716.0,1667993203844,7.220703125 48 | 6862.0,1667993209126,7.55078125 49 | 7008.0,1667993214451,8.09765625 50 | 7154.0,1667993219682,8.044921875 51 | 7300.0,1667993225026,7.97265625 52 | 7446.0,1667993230323,8.447265625 53 | 7592.0,1667993235696,8.751953125 54 | 7738.0,1667993241028,8.794921875 55 | 7884.0,1667993246457,8.826171875 56 | 8030.0,1667993251876,8.890625 57 | 8176.0,1667993257367,8.9609375 58 | 8322.0,1667993262694,9.357421875 59 | 8468.0,1667993268023,9.44140625 60 | 8614.0,1667993273422,9.224609375 61 | 8760.0,1667993278755,9.33203125 62 | 8906.0,1667993284077,9.830078125 63 | 9052.0,1667993289417,9.744140625 64 | 9198.0,1667993294751,9.90625 65 | 9344.0,1667993300066,10.263671875 66 | 9490.0,1667993305513,10.57421875 67 | 9636.0,1667993310878,10.037109375 68 | 9782.0,1667993316297,9.91796875 69 | 9928.0,1667993321681,10.5625 70 | 10074.0,1667993327058,10.298828125 71 | 10220.0,1667993332494,10.716796875 72 | 10366.0,1667993337724,10.75390625 73 | 10512.0,1667993343072,10.81640625 74 | 10658.0,1667993348386,10.6484375 75 | 10804.0,1667993353650,10.6796875 76 | 10950.0,1667993359050,10.708984375 77 | 11096.0,1667993364413,10.564453125 78 | 11242.0,1667993369750,10.697265625 79 | 11388.0,1667993375085,10.982421875 80 | 11534.0,1667993380337,10.734375 81 | 11680.0,1667993385690,11.009765625 82 | 11826.0,1667993391044,10.962890625 83 | 11972.0,1667993396480,10.896484375 84 | 12118.0,1667993402009,10.923828125 85 | 12264.0,1667993407433,11.080078125 86 | 12410.0,1667993412838,10.974609375 87 | 12556.0,1667993418273,11.05078125 88 | 12702.0,1667993423722,10.646484375 89 | 12848.0,1667993429193,10.27734375 90 | 12994.0,1667993434686,11.18359375 91 | 13140.0,1667993440044,10.78125 92 | 13286.0,1667993445460,11.111328125 93 | 13432.0,1667993450818,10.841796875 94 | 13578.0,1667993456167,10.755859375 95 | 
13724.0,1667993461796,11.203125 96 | 13870.0,1667993467377,11.111328125 97 | 14016.0,1667993472828,10.904296875 98 | 14162.0,1667993478249,11.033203125 99 | 14308.0,1667993483731,11.26171875 100 | 14454.0,1667993489211,11.384765625 101 | 14600.0,1667993494683,11.451171875 102 | 14746.0,1667993500149,10.943359375 103 | 14892.0,1667993505561,11.23046875 104 | 15038.0,1667993511139,10.822265625 105 | 15184.0,1667993516599,11.181640625 106 | 15330.0,1667993522226,11.333984375 107 | 15476.0,1667993527698,11.076171875 108 | 15622.0,1667993533221,11.33984375 109 | 15768.0,1667993538682,11.478515625 110 | 15914.0,1667993544034,11.041015625 111 | 16060.0,1667993549457,11.22265625 112 | 16206.0,1667993554894,11.337890625 113 | 16352.0,1667993560238,10.8359375 114 | 16498.0,1667993565735,11.275390625 115 | 16644.0,1667993571126,11.572265625 116 | 16790.0,1667993576658,11.4140625 117 | 16936.0,1667993582096,11.4921875 118 | 17082.0,1667993587535,11.2109375 119 | 17228.0,1667993592993,11.609375 120 | 17374.0,1667993598436,11.59375 121 | 17520.0,1667993603915,11.8984375 122 | 17666.0,1667993609231,11.546875 123 | 17812.0,1667993614642,11.701171875 124 | 17958.0,1667993620104,11.66015625 125 | 18104.0,1667993625680,11.46875 126 | 18250.0,1667993631090,12.087890625 127 | 18396.0,1667993636587,11.896484375 128 | 18542.0,1667993642000,12.373046875 129 | 18688.0,1667993647433,12.275390625 130 | 18834.0,1667993652810,12.060546875 131 | 18980.0,1667993658306,12.51171875 132 | 19126.0,1667993663829,12.810546875 133 | 19272.0,1667993669122,12.560546875 134 | 19418.0,1667993674428,12.689453125 135 | 19564.0,1667993679783,12.671875 136 | 19710.0,1667993685065,12.7890625 137 | 19856.0,1667993690489,13.072265625 138 | 20002.0,1667993695813,13.0 139 | 20148.0,1667993701227,12.771484375 140 | 20294.0,1667993706658,12.83203125 141 | 20440.0,1667993712211,12.8359375 142 | 20586.0,1667993717521,12.857421875 143 | 20732.0,1667993722963,12.69140625 144 | 20878.0,1667993728240,12.765625 145 | 
21024.0,1667993733872,12.943359375 146 | 21170.0,1667993739152,12.71875 147 | 21316.0,1667993744589,12.37890625 148 | 21462.0,1667993749888,12.3828125 149 | 21608.0,1667993755275,12.353515625 150 | 21754.0,1667993760613,12.30859375 151 | 21900.0,1667993765967,12.07421875 152 | 22046.0,1667993771296,12.4296875 153 | 22192.0,1667993776996,12.27734375 154 | 22338.0,1667993782314,12.20703125 155 | 22484.0,1667993787753,12.349609375 156 | 22630.0,1667993793108,12.30859375 157 | 22776.0,1667993798620,12.5625 158 | 22922.0,1667993803942,12.556640625 159 | 23068.0,1667993809237,12.40625 160 | 23214.0,1667993814703,12.611328125 161 | 23360.0,1667993820120,13.2109375 162 | 23506.0,1667993825535,12.83203125 163 | 23652.0,1667993830966,12.75390625 164 | 23798.0,1667993836509,12.896484375 165 | 23944.0,1667993841875,12.9375 166 | 24090.0,1667993847241,12.458984375 167 | 24236.0,1667993852647,12.376953125 168 | 24382.0,1667993858043,12.705078125 169 | 24528.0,1667993863400,12.7890625 170 | 24674.0,1667993868751,12.810546875 171 | 24820.0,1667993874245,12.802734375 172 | 24966.0,1667993879710,12.607421875 173 | 25112.0,1667993885182,12.9140625 174 | 25258.0,1667993890630,12.751953125 175 | 25404.0,1667993896244,12.83203125 176 | 25550.0,1667993901670,12.787109375 177 | 25696.0,1667993907167,12.603515625 178 | 25842.0,1667993912678,12.755859375 179 | 25988.0,1667993918109,12.884765625 180 | 26134.0,1667993923525,13.087890625 181 | 26280.0,1667993928895,13.072265625 182 | 26426.0,1667993934480,13.078125 183 | 26572.0,1667993939813,13.06640625 184 | 26718.0,1667993945204,12.994140625 185 | 26864.0,1667993950495,13.4765625 186 | 27010.0,1667993955897,13.47265625 187 | 27156.0,1667993961229,13.84375 188 | 27302.0,1667993966706,13.833984375 189 | 27448.0,1667993972189,13.55078125 190 | 27594.0,1667993977524,13.822265625 191 | 27740.0,1667993982856,13.75 192 | 27886.0,1667993988220,13.759765625 193 | 28032.0,1667993993651,13.42578125 194 | 28178.0,1667993999080,13.966796875 195 | 
28324.0,1667994004560,14.015625 196 | 28470.0,1667994009970,14.388671875 197 | 28616.0,1667994015385,14.3046875 198 | 28762.0,1667994020780,13.765625 199 | 28908.0,1667994026252,14.001953125 200 | 29054.0,1667994031763,14.1484375 201 | 29200.0,1667994037287,13.931640625 202 | -------------------------------------------------------------------------------- /snake/data/bootstrap_return/MET-2196__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667995177742,0.154296875 2 | 146.0,1667995196189,0.171875 3 | 292.0,1667995201910,0.259765625 4 | 438.0,1667995207664,2.080078125 5 | 584.0,1667995213249,3.546875 6 | 730.0,1667995218727,3.908203125 7 | 876.0,1667995224273,4.0625 8 | 1022.0,1667995229766,4.765625 9 | 1168.0,1667995235306,4.904296875 10 | 1314.0,1667995240820,5.01953125 11 | 1460.0,1667995246473,5.666015625 12 | 1606.0,1667995251894,5.173828125 13 | 1752.0,1667995257430,5.857421875 14 | 1898.0,1667995262848,5.7265625 15 | 2044.0,1667995268300,6.267578125 16 | 2190.0,1667995273762,6.478515625 17 | 2336.0,1667995279295,5.875 18 | 2482.0,1667995284885,6.15234375 19 | 2628.0,1667995290417,5.96484375 20 | 2774.0,1667995296031,6.1796875 21 | 2920.0,1667995301523,5.88671875 22 | 3066.0,1667995307082,5.966796875 23 | 3212.0,1667995312629,6.408203125 24 | 3358.0,1667995318165,6.078125 25 | 3504.0,1667995323714,6.4375 26 | 3650.0,1667995329275,6.34765625 27 | 3796.0,1667995334939,6.609375 28 | 3942.0,1667995340534,6.572265625 29 | 4088.0,1667995346065,6.77734375 30 | 4234.0,1667995351548,5.92578125 31 | 4380.0,1667995357096,6.53515625 32 | 4526.0,1667995362476,6.640625 33 | 4672.0,1667995367953,7.029296875 34 | 4818.0,1667995373428,6.525390625 35 | 4964.0,1667995378960,6.98046875 36 | 5110.0,1667995384416,6.30078125 37 | 5256.0,1667995389941,6.55859375 38 | 5402.0,1667995395435,6.31640625 39 | 5548.0,1667995400874,7.025390625 40 | 5694.0,1667995406493,7.2109375 41 | 
5840.0,1667995411890,7.283203125 42 | 5986.0,1667995417430,7.20703125 43 | 6132.0,1667995422939,7.78125 44 | 6278.0,1667995428384,7.9765625 45 | 6424.0,1667995433832,7.662109375 46 | 6570.0,1667995439283,8.1875 47 | 6716.0,1667995444804,8.064453125 48 | 6862.0,1667995450269,7.93359375 49 | 7008.0,1667995455735,8.212890625 50 | 7154.0,1667995461203,8.4296875 51 | 7300.0,1667995466776,8.49609375 52 | 7446.0,1667995472212,8.27734375 53 | 7592.0,1667995477682,8.4296875 54 | 7738.0,1667995483287,8.302734375 55 | 7884.0,1667995488828,8.720703125 56 | 8030.0,1667995494248,8.7421875 57 | 8176.0,1667995499596,8.40625 58 | 8322.0,1667995505230,9.0703125 59 | 8468.0,1667995510762,8.966796875 60 | 8614.0,1667995516288,8.794921875 61 | 8760.0,1667995521808,9.015625 62 | 8906.0,1667995527362,9.443359375 63 | 9052.0,1667995532917,9.294921875 64 | 9198.0,1667995538490,9.2421875 65 | 9344.0,1667995544078,9.3359375 66 | 9490.0,1667995549658,9.60546875 67 | 9636.0,1667995555161,9.466796875 68 | 9782.0,1667995560627,9.615234375 69 | 9928.0,1667995566178,9.537109375 70 | 10074.0,1667995571648,9.689453125 71 | 10220.0,1667995577215,9.703125 72 | 10366.0,1667995582681,9.533203125 73 | 10512.0,1667995588325,9.828125 74 | 10658.0,1667995593809,9.86328125 75 | 10804.0,1667995599709,9.7109375 76 | 10950.0,1667995605371,10.015625 77 | 11096.0,1667995610980,10.138671875 78 | 11242.0,1667995616555,10.005859375 79 | 11388.0,1667995622032,10.361328125 80 | 11534.0,1667995627674,10.29296875 81 | 11680.0,1667995633149,10.443359375 82 | 11826.0,1667995638650,10.6796875 83 | 11972.0,1667995644253,10.515625 84 | 12118.0,1667995649786,10.765625 85 | 12264.0,1667995655316,10.822265625 86 | 12410.0,1667995660805,10.888671875 87 | 12556.0,1667995666438,10.5546875 88 | 12702.0,1667995672003,10.861328125 89 | 12848.0,1667995677525,10.47265625 90 | 12994.0,1667995683035,11.04296875 91 | 13140.0,1667995688539,10.591796875 92 | 13286.0,1667995694057,10.70703125 93 | 13432.0,1667995699542,10.759765625 94 | 
13578.0,1667995705068,10.66015625 95 | 13724.0,1667995710623,10.76171875 96 | 13870.0,1667995716154,11.18359375 97 | 14016.0,1667995721642,10.689453125 98 | 14162.0,1667995727200,10.88671875 99 | 14308.0,1667995732750,10.642578125 100 | 14454.0,1667995738272,11.025390625 101 | 14600.0,1667995743765,10.9453125 102 | 14746.0,1667995749387,10.78515625 103 | 14892.0,1667995755349,10.8828125 104 | 15038.0,1667995762059,11.2265625 105 | 15184.0,1667995768892,10.95703125 106 | 15330.0,1667995775030,11.150390625 107 | 15476.0,1667995781023,11.51953125 108 | 15622.0,1667995787787,11.03125 109 | 15768.0,1667995794518,11.521484375 110 | 15914.0,1667995800316,11.36328125 111 | 16060.0,1667995805929,11.65625 112 | 16206.0,1667995811430,11.642578125 113 | 16352.0,1667995817099,11.64453125 114 | 16498.0,1667995822591,11.59765625 115 | 16644.0,1667995828121,11.8046875 116 | 16790.0,1667995833648,11.712890625 117 | 16936.0,1667995839121,12.109375 118 | 17082.0,1667995844789,11.974609375 119 | 17228.0,1667995850227,12.1875 120 | 17374.0,1667995855815,12.26953125 121 | 17520.0,1667995861373,12.283203125 122 | 17666.0,1667995866957,12.12109375 123 | 17812.0,1667995872446,12.083984375 124 | 17958.0,1667995878051,12.208984375 125 | 18104.0,1667995883533,12.45703125 126 | 18250.0,1667995889097,12.111328125 127 | 18396.0,1667995894568,12.3984375 128 | 18542.0,1667995900122,12.20703125 129 | 18688.0,1667995905579,12.66015625 130 | 18834.0,1667995911110,12.25390625 131 | 18980.0,1667995916617,12.62890625 132 | 19126.0,1667995922151,12.7421875 133 | 19272.0,1667995927696,12.345703125 134 | 19418.0,1667995933239,12.685546875 135 | 19564.0,1667995938829,12.490234375 136 | 19710.0,1667995944444,12.583984375 137 | 19856.0,1667995950103,12.5390625 138 | 20002.0,1667995955757,12.259765625 139 | 20148.0,1667995961372,12.462890625 140 | 20294.0,1667995967070,12.337890625 141 | 20440.0,1667995972610,12.203125 142 | 20586.0,1667995978219,12.53125 143 | 20732.0,1667995983772,12.736328125 144 | 
20878.0,1667995989442,12.09765625 145 | 21024.0,1667995995065,12.421875 146 | 21170.0,1667996000631,12.4765625 147 | 21316.0,1667996006272,12.68359375 148 | 21462.0,1667996011856,12.58984375 149 | 21608.0,1667996017369,11.994140625 150 | 21754.0,1667996022918,12.40234375 151 | 21900.0,1667996028588,12.865234375 152 | 22046.0,1667996034130,12.318359375 153 | 22192.0,1667996039722,12.189453125 154 | 22338.0,1667996045386,12.107421875 155 | 22484.0,1667996050939,12.802734375 156 | 22630.0,1667996056606,12.326171875 157 | 22776.0,1667996062196,12.384765625 158 | 22922.0,1667996067921,12.69921875 159 | 23068.0,1667996073514,12.541015625 160 | 23214.0,1667996079102,12.6484375 161 | 23360.0,1667996084753,12.453125 162 | 23506.0,1667996090303,12.126953125 163 | 23652.0,1667996095988,12.716796875 164 | 23798.0,1667996101594,12.84375 165 | 23944.0,1667996107263,12.59765625 166 | 24090.0,1667996112881,12.873046875 167 | 24236.0,1667996118538,12.626953125 168 | 24382.0,1667996124187,12.83984375 169 | 24528.0,1667996129852,12.912109375 170 | 24674.0,1667996135519,12.591796875 171 | 24820.0,1667996141137,12.55078125 172 | 24966.0,1667996146911,13.37109375 173 | 25112.0,1667996152521,12.583984375 174 | 25258.0,1667996158310,12.947265625 175 | 25404.0,1667996163991,13.447265625 176 | 25550.0,1667996169709,13.375 177 | 25696.0,1667996175396,13.376953125 178 | 25842.0,1667996181153,13.333984375 179 | 25988.0,1667996186938,13.013671875 180 | 26134.0,1667996192637,13.166015625 181 | 26280.0,1667996198388,13.16015625 182 | 26426.0,1667996204060,12.791015625 183 | 26572.0,1667996209672,13.283203125 184 | 26718.0,1667996215382,13.20703125 185 | 26864.0,1667996221046,13.208984375 186 | 27010.0,1667996226638,13.23046875 187 | 27156.0,1667996232242,13.169921875 188 | 27302.0,1667996237908,13.146484375 189 | 27448.0,1667996243515,13.056640625 190 | 27594.0,1667996249132,12.8828125 191 | 27740.0,1667996254767,12.837890625 192 | 27886.0,1667996260363,13.283203125 193 | 
28032.0,1667996266032,13.228515625 194 | 28178.0,1667996271686,12.69140625 195 | 28324.0,1667996277311,13.38671875 196 | 28470.0,1667996282935,12.416015625 197 | 28616.0,1667996288513,12.861328125 198 | 28762.0,1667996294050,12.55859375 199 | 28908.0,1667996299727,12.611328125 200 | 29054.0,1667996305475,12.71484375 201 | 29200.0,1667996311015,12.556640625 202 | -------------------------------------------------------------------------------- /snake/data/bootstrap_return/MET-2199__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667996320975,0.14453125 2 | 146.0,1667996339763,0.150390625 3 | 292.0,1667996345348,0.271484375 4 | 438.0,1667996350930,2.166015625 5 | 584.0,1667996356459,3.521484375 6 | 730.0,1667996361918,4.017578125 7 | 876.0,1667996367450,4.380859375 8 | 1022.0,1667996373017,4.8125 9 | 1168.0,1667996378612,5.072265625 10 | 1314.0,1667996384064,5.275390625 11 | 1460.0,1667996389582,5.787109375 12 | 1606.0,1667996395028,5.82421875 13 | 1752.0,1667996400554,5.822265625 14 | 1898.0,1667996406193,6.40625 15 | 2044.0,1667996411765,6.095703125 16 | 2190.0,1667996417488,6.41796875 17 | 2336.0,1667996423097,6.166015625 18 | 2482.0,1667996428667,6.474609375 19 | 2628.0,1667996434164,6.630859375 20 | 2774.0,1667996439601,6.529296875 21 | 2920.0,1667996445130,6.703125 22 | 3066.0,1667996450646,6.634765625 23 | 3212.0,1667996456199,6.7734375 24 | 3358.0,1667996461826,6.623046875 25 | 3504.0,1667996467384,6.6171875 26 | 3650.0,1667996472841,7.103515625 27 | 3796.0,1667996478566,6.966796875 28 | 3942.0,1667996484102,7.072265625 29 | 4088.0,1667996489624,7.53515625 30 | 4234.0,1667996495201,7.095703125 31 | 4380.0,1667996500787,7.384765625 32 | 4526.0,1667996506286,7.53125 33 | 4672.0,1667996511822,7.98046875 34 | 4818.0,1667996517389,8.115234375 35 | 4964.0,1667996523057,8.19140625 36 | 5110.0,1667996528634,8.19921875 37 | 5256.0,1667996534139,8.740234375 38 | 
5402.0,1667996539660,8.875 39 | 5548.0,1667996545315,8.72265625 40 | 5694.0,1667996550844,8.982421875 41 | 5840.0,1667996556397,9.12890625 42 | 5986.0,1667996561941,9.080078125 43 | 6132.0,1667996567503,8.994140625 44 | 6278.0,1667996573006,9.509765625 45 | 6424.0,1667996578522,9.345703125 46 | 6570.0,1667996584093,9.8203125 47 | 6716.0,1667996589597,10.267578125 48 | 6862.0,1667996595246,9.681640625 49 | 7008.0,1667996600806,9.80859375 50 | 7154.0,1667996606398,9.890625 51 | 7300.0,1667996611924,10.69140625 52 | 7446.0,1667996617556,10.486328125 53 | 7592.0,1667996623095,10.794921875 54 | 7738.0,1667996628576,10.841796875 55 | 7884.0,1667996634213,10.99609375 56 | 8030.0,1667996639836,11.15625 57 | 8176.0,1667996645354,11.306640625 58 | 8322.0,1667996650809,11.205078125 59 | 8468.0,1667996656300,11.681640625 60 | 8614.0,1667996661951,11.66796875 61 | 8760.0,1667996667391,12.078125 62 | 8906.0,1667996672918,12.22265625 63 | 9052.0,1667996678604,12.37890625 64 | 9198.0,1667996684185,12.609375 65 | 9344.0,1667996689865,13.037109375 66 | 9490.0,1667996695395,12.68359375 67 | 9636.0,1667996701056,13.23046875 68 | 9782.0,1667996706683,12.82421875 69 | 9928.0,1667996712195,13.3125 70 | 10074.0,1667996717750,13.451171875 71 | 10220.0,1667996723353,13.099609375 72 | 10366.0,1667996729021,13.732421875 73 | 10512.0,1667996734603,14.037109375 74 | 10658.0,1667996740221,13.470703125 75 | 10804.0,1667996745771,13.50390625 76 | 10950.0,1667996751432,14.1953125 77 | 11096.0,1667996756942,13.90625 78 | 11242.0,1667996762540,13.951171875 79 | 11388.0,1667996768165,14.26171875 80 | 11534.0,1667996773707,14.015625 81 | 11680.0,1667996779269,13.759765625 82 | 11826.0,1667996784821,14.03515625 83 | 11972.0,1667996790398,14.3125 84 | 12118.0,1667996795850,14.10546875 85 | 12264.0,1667996801483,14.9453125 86 | 12410.0,1667996807086,14.58984375 87 | 12556.0,1667996812759,15.47265625 88 | 12702.0,1667996818301,15.0390625 89 | 12848.0,1667996823849,15.3984375 90 | 
12994.0,1667996829514,15.5078125 91 | 13140.0,1667996835227,15.9609375 92 | 13286.0,1667996840854,15.515625 93 | 13432.0,1667996846668,15.796875 94 | 13578.0,1667996852480,15.783203125 95 | 13724.0,1667996858303,16.138671875 96 | 13870.0,1667996864131,15.83984375 97 | 14016.0,1667996870354,15.48828125 98 | 14162.0,1667996876167,16.0390625 99 | 14308.0,1667996882447,16.376953125 100 | 14454.0,1667996888244,16.482421875 101 | 14600.0,1667996893862,16.25390625 102 | 14746.0,1667996899449,16.708984375 103 | 14892.0,1667996905031,16.591796875 104 | 15038.0,1667996910581,16.44921875 105 | 15184.0,1667996916193,16.78515625 106 | 15330.0,1667996921865,17.05859375 107 | 15476.0,1667996927563,16.96875 108 | 15622.0,1667996933282,16.984375 109 | 15768.0,1667996938949,16.224609375 110 | 15914.0,1667996944516,16.859375 111 | 16060.0,1667996950255,16.841796875 112 | 16206.0,1667996955807,17.119140625 113 | 16352.0,1667996961359,17.068359375 114 | 16498.0,1667996966957,16.826171875 115 | 16644.0,1667996972611,16.978515625 116 | 16790.0,1667996978353,17.33203125 117 | 16936.0,1667996984076,17.064453125 118 | 17082.0,1667996990314,16.931640625 119 | 17228.0,1667996996427,17.2109375 120 | 17374.0,1667997002739,17.267578125 121 | 17520.0,1667997008933,17.521484375 122 | 17666.0,1667997014742,17.4375 123 | 17812.0,1667997020399,17.341796875 124 | 17958.0,1667997026027,17.142578125 125 | 18104.0,1667997031680,17.357421875 126 | 18250.0,1667997037368,17.455078125 127 | 18396.0,1667997043321,17.33984375 128 | 18542.0,1667997048944,17.396484375 129 | 18688.0,1667997054564,17.52734375 130 | 18834.0,1667997060116,17.3984375 131 | 18980.0,1667997065668,17.43359375 132 | 19126.0,1667997071360,17.34765625 133 | 19272.0,1667997077026,18.087890625 134 | 19418.0,1667997082633,18.14453125 135 | 19564.0,1667997088183,17.9609375 136 | 19710.0,1667997093703,18.212890625 137 | 19856.0,1667997099385,18.287109375 138 | 20002.0,1667997104944,18.501953125 139 | 20148.0,1667997110569,18.490234375 140 | 
20294.0,1667997116066,18.490234375 141 | 20440.0,1667997121854,18.11328125 142 | 20586.0,1667997127482,19.0 143 | 20732.0,1667997133184,18.9140625 144 | 20878.0,1667997138901,18.90625 145 | 21024.0,1667997144549,18.751953125 146 | 21170.0,1667997150128,19.375 147 | 21316.0,1667997155724,19.419921875 148 | 21462.0,1667997161319,19.12890625 149 | 21608.0,1667997166991,19.7265625 150 | 21754.0,1667997172671,19.8046875 151 | 21900.0,1667997178352,20.01953125 152 | 22046.0,1667997184020,20.080078125 153 | 22192.0,1667997189634,19.525390625 154 | 22338.0,1667997195271,20.103515625 155 | 22484.0,1667997200972,20.361328125 156 | 22630.0,1667997206621,20.5078125 157 | 22776.0,1667997212201,20.525390625 158 | 22922.0,1667997217811,20.990234375 159 | 23068.0,1667997223375,20.2109375 160 | 23214.0,1667997228978,20.791015625 161 | 23360.0,1667997234756,20.80078125 162 | 23506.0,1667997240303,20.80859375 163 | 23652.0,1667997245870,21.244140625 164 | 23798.0,1667997251593,21.505859375 165 | 23944.0,1667997257342,21.033203125 166 | 24090.0,1667997262995,22.150390625 167 | 24236.0,1667997268741,22.046875 168 | 24382.0,1667997274387,21.609375 169 | 24528.0,1667997280162,21.73828125 170 | 24674.0,1667997285678,21.90234375 171 | 24820.0,1667997291393,21.92578125 172 | 24966.0,1667997296968,22.265625 173 | 25112.0,1667997302741,22.4375 174 | 25258.0,1667997308508,22.384765625 175 | 25404.0,1667997314099,22.306640625 176 | 25550.0,1667997319683,22.802734375 177 | 25696.0,1667997325431,22.740234375 178 | 25842.0,1667997331029,23.19921875 179 | 25988.0,1667997336681,23.111328125 180 | 26134.0,1667997342383,24.279296875 181 | 26280.0,1667997348005,24.12109375 182 | 26426.0,1667997353521,24.28515625 183 | 26572.0,1667997359275,25.828125 184 | 26718.0,1667997364839,26.453125 185 | 26864.0,1667997370640,27.3203125 186 | 27010.0,1667997376386,29.033203125 187 | 27156.0,1667997382016,29.91796875 188 | 27302.0,1667997387597,30.78125 189 | 27448.0,1667997393223,31.267578125 190 | 
27594.0,1667997398863,32.009765625 191 | 27740.0,1667997404452,30.61328125 192 | 27886.0,1667997410178,32.748046875 193 | 28032.0,1667997415869,32.1484375 194 | 28178.0,1667997421659,32.271484375 195 | 28324.0,1667997427224,32.677734375 196 | 28470.0,1667997432872,32.66796875 197 | 28616.0,1667997438558,32.48828125 198 | 28762.0,1667997444117,33.21484375 199 | 28908.0,1667997449795,32.892578125 200 | 29054.0,1667997455441,33.40625 201 | 29200.0,1667997461124,33.4921875 202 | -------------------------------------------------------------------------------- /snake/data/mgrl_return/MET-2157__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667987108012,0.140625 2 | 146.0,1667987120064,0.185546875 3 | 292.0,1667987124772,0.3125 4 | 438.0,1667987129323,2.08984375 5 | 584.0,1667987133938,3.650390625 6 | 730.0,1667987138527,3.86328125 7 | 876.0,1667987143127,4.072265625 8 | 1022.0,1667987147765,4.703125 9 | 1168.0,1667987152404,4.888671875 10 | 1314.0,1667987157001,5.099609375 11 | 1460.0,1667987161673,5.595703125 12 | 1606.0,1667987166268,5.345703125 13 | 1752.0,1667987170802,5.81640625 14 | 1898.0,1667987175408,5.958984375 15 | 2044.0,1667987180042,5.98828125 16 | 2190.0,1667987184611,6.02734375 17 | 2336.0,1667987189125,5.720703125 18 | 2482.0,1667987193708,5.484375 19 | 2628.0,1667987198419,5.93359375 20 | 2774.0,1667987202946,6.28515625 21 | 2920.0,1667987207607,6.33203125 22 | 3066.0,1667987212231,6.5078125 23 | 3212.0,1667987216907,6.48046875 24 | 3358.0,1667987221578,6.201171875 25 | 3504.0,1667987226224,6.2734375 26 | 3650.0,1667987230906,6.744140625 27 | 3796.0,1667987235586,6.6640625 28 | 3942.0,1667987240163,6.951171875 29 | 4088.0,1667987244849,6.82421875 30 | 4234.0,1667987249371,6.73046875 31 | 4380.0,1667987253956,6.84765625 32 | 4526.0,1667987258524,6.94140625 33 | 4672.0,1667987263124,7.107421875 34 | 4818.0,1667987267668,6.8671875 35 | 4964.0,1667987272331,7.025390625 36 
| 5110.0,1667987276955,7.4296875 37 | 5256.0,1667987281580,7.43359375 38 | 5402.0,1667987286271,7.685546875 39 | 5548.0,1667987290845,7.55078125 40 | 5694.0,1667987295444,7.607421875 41 | 5840.0,1667987300092,7.517578125 42 | 5986.0,1667987304703,7.962890625 43 | 6132.0,1667987309277,7.875 44 | 6278.0,1667987313897,8.220703125 45 | 6424.0,1667987318542,8.2734375 46 | 6570.0,1667987323153,9.078125 47 | 6716.0,1667987327816,8.87890625 48 | 6862.0,1667987332430,9.361328125 49 | 7008.0,1667987337070,9.193359375 50 | 7154.0,1667987341634,9.9453125 51 | 7300.0,1667987346310,9.998046875 52 | 7446.0,1667987350948,10.365234375 53 | 7592.0,1667987355576,10.615234375 54 | 7738.0,1667987360129,10.828125 55 | 7884.0,1667987364807,11.09765625 56 | 8030.0,1667987369398,11.923828125 57 | 8176.0,1667987373990,11.61328125 58 | 8322.0,1667987378669,12.142578125 59 | 8468.0,1667987383251,12.181640625 60 | 8614.0,1667987387851,12.80078125 61 | 8760.0,1667987392464,12.642578125 62 | 8906.0,1667987397090,12.68359375 63 | 9052.0,1667987401866,13.1171875 64 | 9198.0,1667987406589,13.140625 65 | 9344.0,1667987411224,13.64453125 66 | 9490.0,1667987415898,13.53125 67 | 9636.0,1667987420488,13.9140625 68 | 9782.0,1667987425182,13.943359375 69 | 9928.0,1667987429966,14.123046875 70 | 10074.0,1667987434623,13.9765625 71 | 10220.0,1667987439212,14.3984375 72 | 10366.0,1667987443831,14.58203125 73 | 10512.0,1667987448516,14.59375 74 | 10658.0,1667987453239,14.751953125 75 | 10804.0,1667987457836,15.32421875 76 | 10950.0,1667987462499,15.5703125 77 | 11096.0,1667987467187,15.294921875 78 | 11242.0,1667987471923,15.65234375 79 | 11388.0,1667987476543,15.439453125 80 | 11534.0,1667987481160,14.95703125 81 | 11680.0,1667987485785,15.392578125 82 | 11826.0,1667987490388,15.720703125 83 | 11972.0,1667987494961,15.73828125 84 | 12118.0,1667987499539,15.919921875 85 | 12264.0,1667987504118,15.99609375 86 | 12410.0,1667987508760,15.75390625 87 | 12556.0,1667987513306,15.9921875 88 | 
12702.0,1667987517943,16.439453125 89 | 12848.0,1667987522627,16.080078125 90 | 12994.0,1667987527335,15.875 91 | 13140.0,1667987531963,15.8046875 92 | 13286.0,1667987536632,15.84375 93 | 13432.0,1667987541173,16.21484375 94 | 13578.0,1667987546179,15.78515625 95 | 13724.0,1667987550727,16.41015625 96 | 13870.0,1667987555364,15.896484375 97 | 14016.0,1667987560010,16.16796875 98 | 14162.0,1667987564686,16.49609375 99 | 14308.0,1667987569339,16.009765625 100 | 14454.0,1667987573974,16.615234375 101 | 14600.0,1667987578595,16.62890625 102 | 14746.0,1667987583284,16.498046875 103 | 14892.0,1667987587935,16.515625 104 | 15038.0,1667987592539,16.458984375 105 | 15184.0,1667987597298,17.0546875 106 | 15330.0,1667987601849,16.736328125 107 | 15476.0,1667987606573,17.232421875 108 | 15622.0,1667987611179,17.44140625 109 | 15768.0,1667987615902,17.19921875 110 | 15914.0,1667987620555,16.966796875 111 | 16060.0,1667987625183,17.140625 112 | 16206.0,1667987629803,16.78515625 113 | 16352.0,1667987634461,16.646484375 114 | 16498.0,1667987639139,17.078125 115 | 16644.0,1667987643786,17.111328125 116 | 16790.0,1667987648494,17.2890625 117 | 16936.0,1667987653106,17.05078125 118 | 17082.0,1667987657773,17.244140625 119 | 17228.0,1667987662376,17.5078125 120 | 17374.0,1667987666989,17.458984375 121 | 17520.0,1667987671551,17.67578125 122 | 17666.0,1667987676268,17.466796875 123 | 17812.0,1667987680837,17.265625 124 | 17958.0,1667987685424,16.97265625 125 | 18104.0,1667987690016,17.392578125 126 | 18250.0,1667987694737,17.693359375 127 | 18396.0,1667987699419,17.52734375 128 | 18542.0,1667987704134,17.650390625 129 | 18688.0,1667987708752,17.892578125 130 | 18834.0,1667987713280,17.294921875 131 | 18980.0,1667987718004,17.6328125 132 | 19126.0,1667987722681,17.43359375 133 | 19272.0,1667987727429,17.88671875 134 | 19418.0,1667987732055,17.66015625 135 | 19564.0,1667987736700,17.79296875 136 | 19710.0,1667987741313,17.73046875 137 | 19856.0,1667987746039,17.853515625 138 | 
20002.0,1667987750707,18.048828125 139 | 20148.0,1667987755305,17.537109375 140 | 20294.0,1667987759969,17.50390625 141 | 20440.0,1667987764553,17.390625 142 | 20586.0,1667987769184,17.984375 143 | 20732.0,1667987773792,17.7265625 144 | 20878.0,1667987778468,17.412109375 145 | 21024.0,1667987783135,17.3359375 146 | 21170.0,1667987787761,18.080078125 147 | 21316.0,1667987792508,18.18359375 148 | 21462.0,1667987797190,18.029296875 149 | 21608.0,1667987801859,18.017578125 150 | 21754.0,1667987806527,17.916015625 151 | 21900.0,1667987811110,17.650390625 152 | 22046.0,1667987815788,18.01171875 153 | 22192.0,1667987820369,17.8515625 154 | 22338.0,1667987825008,18.16796875 155 | 22484.0,1667987829661,17.78125 156 | 22630.0,1667987834278,18.0 157 | 22776.0,1667987838959,18.07421875 158 | 22922.0,1667987843678,18.24609375 159 | 23068.0,1667987848347,17.986328125 160 | 23214.0,1667987852920,17.796875 161 | 23360.0,1667987857475,17.712890625 162 | 23506.0,1667987862136,18.28125 163 | 23652.0,1667987866840,17.931640625 164 | 23798.0,1667987871456,18.375 165 | 23944.0,1667987876013,18.81640625 166 | 24090.0,1667987880626,18.2109375 167 | 24236.0,1667987885316,17.98046875 168 | 24382.0,1667987889948,18.0078125 169 | 24528.0,1667987894520,18.46484375 170 | 24674.0,1667987899133,18.263671875 171 | 24820.0,1667987903797,18.06640625 172 | 24966.0,1667987908413,18.412109375 173 | 25112.0,1667987913109,18.02734375 174 | 25258.0,1667987917860,18.6796875 175 | 25404.0,1667987922499,19.0 176 | 25550.0,1667987927232,18.302734375 177 | 25696.0,1667987931845,18.24609375 178 | 25842.0,1667987936430,18.484375 179 | 25988.0,1667987941029,18.923828125 180 | 26134.0,1667987945734,18.609375 181 | 26280.0,1667987950294,18.8359375 182 | 26426.0,1667987954857,18.625 183 | 26572.0,1667987959594,18.626953125 184 | 26718.0,1667987964182,18.763671875 185 | 26864.0,1667987968747,18.884765625 186 | 27010.0,1667987973393,18.833984375 187 | 27156.0,1667987978091,19.306640625 188 | 
27302.0,1667987982658,19.16015625 189 | 27448.0,1667987987309,18.974609375 190 | 27594.0,1667987991945,19.337890625 191 | 27740.0,1667987996592,19.16015625 192 | 27886.0,1667988001536,19.478515625 193 | 28032.0,1667988006429,19.591796875 194 | 28178.0,1667988011140,19.4140625 195 | 28324.0,1667988015794,19.6015625 196 | 28470.0,1667988020489,19.478515625 197 | 28616.0,1667988025271,19.244140625 198 | 28762.0,1667988029880,19.578125 199 | 28908.0,1667988034607,20.52734375 200 | 29054.0,1667988039362,20.111328125 201 | 29200.0,1667988044047,19.568359375 202 | -------------------------------------------------------------------------------- /snake/data/mgrl_return/MET-2179__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667990937414,0.1640625 2 | 146.0,1667990949593,0.150390625 3 | 292.0,1667990954319,0.201171875 4 | 438.0,1667990958988,2.099609375 5 | 584.0,1667990963797,3.716796875 6 | 730.0,1667990968479,4.158203125 7 | 876.0,1667990973189,4.669921875 8 | 1022.0,1667990977883,5.017578125 9 | 1168.0,1667990982511,5.109375 10 | 1314.0,1667990987194,5.46484375 11 | 1460.0,1667990991982,5.666015625 12 | 1606.0,1667990996737,5.923828125 13 | 1752.0,1667991001599,6.361328125 14 | 1898.0,1667991006212,6.271484375 15 | 2044.0,1667991010965,6.537109375 16 | 2190.0,1667991015780,6.369140625 17 | 2336.0,1667991020556,6.42578125 18 | 2482.0,1667991025266,6.423828125 19 | 2628.0,1667991029884,6.501953125 20 | 2774.0,1667991034668,6.990234375 21 | 2920.0,1667991039341,6.5703125 22 | 3066.0,1667991044152,6.99609375 23 | 3212.0,1667991048785,6.8515625 24 | 3358.0,1667991053333,7.271484375 25 | 3504.0,1667991058045,6.9765625 26 | 3650.0,1667991062715,7.658203125 27 | 3796.0,1667991067373,7.18359375 28 | 3942.0,1667991072033,7.66796875 29 | 4088.0,1667991076771,7.533203125 30 | 4234.0,1667991081420,8.12109375 31 | 4380.0,1667991086149,7.935546875 32 | 4526.0,1667991090892,8.572265625 33 | 
4672.0,1667991095723,8.48046875 34 | 4818.0,1667991100498,8.603515625 35 | 4964.0,1667991105240,9.189453125 36 | 5110.0,1667991109971,9.40234375 37 | 5256.0,1667991114783,9.806640625 38 | 5402.0,1667991119575,10.375 39 | 5548.0,1667991124433,9.970703125 40 | 5694.0,1667991129171,10.189453125 41 | 5840.0,1667991133833,9.828125 42 | 5986.0,1667991138583,10.44140625 43 | 6132.0,1667991143252,10.896484375 44 | 6278.0,1667991148014,10.52734375 45 | 6424.0,1667991152701,10.849609375 46 | 6570.0,1667991157413,11.193359375 47 | 6716.0,1667991162146,11.25 48 | 6862.0,1667991166847,11.5625 49 | 7008.0,1667991171562,11.326171875 50 | 7154.0,1667991176224,11.890625 51 | 7300.0,1667991180981,11.90625 52 | 7446.0,1667991185732,11.916015625 53 | 7592.0,1667991190399,12.271484375 54 | 7738.0,1667991195114,11.97265625 55 | 7884.0,1667991199826,12.419921875 56 | 8030.0,1667991204493,12.728515625 57 | 8176.0,1667991209205,13.078125 58 | 8322.0,1667991214038,12.76171875 59 | 8468.0,1667991218773,13.279296875 60 | 8614.0,1667991223385,13.30859375 61 | 8760.0,1667991228138,13.70703125 62 | 8906.0,1667991232843,13.744140625 63 | 9052.0,1667991237629,14.0625 64 | 9198.0,1667991242342,13.998046875 65 | 9344.0,1667991247069,13.9296875 66 | 9490.0,1667991251826,14.623046875 67 | 9636.0,1667991256593,14.494140625 68 | 9782.0,1667991261376,14.705078125 69 | 9928.0,1667991266131,14.1328125 70 | 10074.0,1667991270872,14.677734375 71 | 10220.0,1667991275662,14.765625 72 | 10366.0,1667991280410,15.052734375 73 | 10512.0,1667991285244,15.21875 74 | 10658.0,1667991290031,15.06640625 75 | 10804.0,1667991294825,15.044921875 76 | 10950.0,1667991299609,15.283203125 77 | 11096.0,1667991304553,15.49609375 78 | 11242.0,1667991309331,15.607421875 79 | 11388.0,1667991314096,15.947265625 80 | 11534.0,1667991318770,15.580078125 81 | 11680.0,1667991323407,15.923828125 82 | 11826.0,1667991328133,15.4453125 83 | 11972.0,1667991332868,15.7109375 84 | 12118.0,1667991337699,15.78125 85 | 
12264.0,1667991342398,16.42578125 86 | 12410.0,1667991347339,16.203125 87 | 12556.0,1667991352083,15.92578125 88 | 12702.0,1667991356872,15.810546875 89 | 12848.0,1667991361622,16.322265625 90 | 12994.0,1667991366434,16.486328125 91 | 13140.0,1667991371214,16.478515625 92 | 13286.0,1667991375937,15.986328125 93 | 13432.0,1667991380720,16.865234375 94 | 13578.0,1667991385572,16.5546875 95 | 13724.0,1667991390283,16.390625 96 | 13870.0,1667991395102,16.92578125 97 | 14016.0,1667991399841,16.8203125 98 | 14162.0,1667991404631,16.826171875 99 | 14308.0,1667991409331,16.689453125 100 | 14454.0,1667991414071,17.181640625 101 | 14600.0,1667991418798,16.630859375 102 | 14746.0,1667991423585,16.93359375 103 | 14892.0,1667991428374,17.185546875 104 | 15038.0,1667991433264,17.234375 105 | 15184.0,1667991438044,17.3359375 106 | 15330.0,1667991442932,17.123046875 107 | 15476.0,1667991447693,17.09765625 108 | 15622.0,1667991452475,17.470703125 109 | 15768.0,1667991457262,17.697265625 110 | 15914.0,1667991461967,16.7265625 111 | 16060.0,1667991466679,16.845703125 112 | 16206.0,1667991471428,17.474609375 113 | 16352.0,1667991476122,17.671875 114 | 16498.0,1667991480853,17.515625 115 | 16644.0,1667991485580,17.759765625 116 | 16790.0,1667991490366,17.287109375 117 | 16936.0,1667991495134,17.759765625 118 | 17082.0,1667991499978,18.0390625 119 | 17228.0,1667991504833,17.359375 120 | 17374.0,1667991509654,17.771484375 121 | 17520.0,1667991514421,18.08203125 122 | 17666.0,1667991519116,17.939453125 123 | 17812.0,1667991523916,17.93359375 124 | 17958.0,1667991528698,17.90625 125 | 18104.0,1667991533480,17.890625 126 | 18250.0,1667991538217,18.158203125 127 | 18396.0,1667991542963,18.001953125 128 | 18542.0,1667991547667,18.44140625 129 | 18688.0,1667991552512,18.70703125 130 | 18834.0,1667991557345,18.220703125 131 | 18980.0,1667991562149,18.453125 132 | 19126.0,1667991566940,18.51953125 133 | 19272.0,1667991571817,18.578125 134 | 19418.0,1667991576569,18.4375 135 | 
19564.0,1667991581447,18.845703125 136 | 19710.0,1667991586289,18.8046875 137 | 19856.0,1667991591062,18.31640625 138 | 20002.0,1667991595848,18.169921875 139 | 20148.0,1667991600880,18.64453125 140 | 20294.0,1667991605803,18.5 141 | 20440.0,1667991610744,18.392578125 142 | 20586.0,1667991615568,18.87109375 143 | 20732.0,1667991620384,18.869140625 144 | 20878.0,1667991625163,18.3046875 145 | 21024.0,1667991630036,18.70703125 146 | 21170.0,1667991634834,18.951171875 147 | 21316.0,1667991639600,18.755859375 148 | 21462.0,1667991644588,18.5390625 149 | 21608.0,1667991649342,18.91015625 150 | 21754.0,1667991654198,18.625 151 | 21900.0,1667991659022,19.29296875 152 | 22046.0,1667991663911,18.865234375 153 | 22192.0,1667991668695,18.947265625 154 | 22338.0,1667991673446,18.875 155 | 22484.0,1667991678281,18.80859375 156 | 22630.0,1667991683035,18.8203125 157 | 22776.0,1667991687862,18.505859375 158 | 22922.0,1667991692626,18.64453125 159 | 23068.0,1667991697456,18.912109375 160 | 23214.0,1667991702198,19.240234375 161 | 23360.0,1667991706974,18.98828125 162 | 23506.0,1667991711699,18.845703125 163 | 23652.0,1667991716493,18.904296875 164 | 23798.0,1667991721318,19.40234375 165 | 23944.0,1667991726112,19.21875 166 | 24090.0,1667991730873,19.494140625 167 | 24236.0,1667991735641,18.853515625 168 | 24382.0,1667991740418,19.58203125 169 | 24528.0,1667991745359,19.25390625 170 | 24674.0,1667991750079,19.5546875 171 | 24820.0,1667991754934,19.34765625 172 | 24966.0,1667991759776,19.4296875 173 | 25112.0,1667991764663,19.888671875 174 | 25258.0,1667991769463,19.423828125 175 | 25404.0,1667991774231,19.443359375 176 | 25550.0,1667991778976,19.234375 177 | 25696.0,1667991783714,19.61328125 178 | 25842.0,1667991788490,19.671875 179 | 25988.0,1667991793263,19.51171875 180 | 26134.0,1667991798020,19.921875 181 | 26280.0,1667991802799,19.724609375 182 | 26426.0,1667991807516,19.607421875 183 | 26572.0,1667991812315,19.85546875 184 | 26718.0,1667991817102,19.48046875 185 | 
26864.0,1667991821846,19.583984375 186 | 27010.0,1667991826575,20.162109375 187 | 27156.0,1667991831361,19.720703125 188 | 27302.0,1667991836128,20.16015625 189 | 27448.0,1667991840871,20.2109375 190 | 27594.0,1667991845610,20.15625 191 | 27740.0,1667991850416,20.48828125 192 | 27886.0,1667991855193,20.53125 193 | 28032.0,1667991859974,20.6640625 194 | 28178.0,1667991864791,20.453125 195 | 28324.0,1667991869531,20.693359375 196 | 28470.0,1667991874321,20.392578125 197 | 28616.0,1667991879030,20.24609375 198 | 28762.0,1667991883687,21.11328125 199 | 28908.0,1667991888340,20.8515625 200 | 29054.0,1667991893075,20.951171875 201 | 29200.0,1667991897900,20.712890625 202 | -------------------------------------------------------------------------------- /snake/data/mgrl_return/MET-2194__eval_episode_reward_stochastic_policy.csv: -------------------------------------------------------------------------------- 1 | 3.0,1667994808583,0.14453125 2 | 146.0,1667994820732,0.150390625 3 | 292.0,1667994825389,0.302734375 4 | 438.0,1667994830210,2.18359375 5 | 584.0,1667994834899,3.470703125 6 | 730.0,1667994839442,3.947265625 7 | 876.0,1667994844002,4.34765625 8 | 1022.0,1667994848501,4.857421875 9 | 1168.0,1667994853026,5.02734375 10 | 1314.0,1667994857578,5.376953125 11 | 1460.0,1667994862142,5.427734375 12 | 1606.0,1667994866708,5.953125 13 | 1752.0,1667994871247,5.830078125 14 | 1898.0,1667994875855,6.421875 15 | 2044.0,1667994880478,6.12890625 16 | 2190.0,1667994884998,6.6171875 17 | 2336.0,1667994889598,6.470703125 18 | 2482.0,1667994894185,6.49609375 19 | 2628.0,1667994898789,6.759765625 20 | 2774.0,1667994903484,6.6953125 21 | 2920.0,1667994908027,6.931640625 22 | 3066.0,1667994912699,6.7734375 23 | 3212.0,1667994917287,6.951171875 24 | 3358.0,1667994921867,6.73046875 25 | 3504.0,1667994926409,7.064453125 26 | 3650.0,1667994931051,7.1484375 27 | 3796.0,1667994935605,7.65625 28 | 3942.0,1667994940155,7.44921875 29 | 4088.0,1667994944649,7.28125 30 | 
4234.0,1667994949259,7.166015625 31 | 4380.0,1667994953758,7.685546875 32 | 4526.0,1667994958333,7.607421875 33 | 4672.0,1667994962874,8.34765625 34 | 4818.0,1667994967449,8.4453125 35 | 4964.0,1667994972072,8.509765625 36 | 5110.0,1667994976735,8.7890625 37 | 5256.0,1667994981287,8.990234375 38 | 5402.0,1667994986004,9.361328125 39 | 5548.0,1667994990537,9.1875 40 | 5694.0,1667994995321,9.34765625 41 | 5840.0,1667994999783,9.8828125 42 | 5986.0,1667995004410,9.734375 43 | 6132.0,1667995008975,10.173828125 44 | 6278.0,1667995013582,10.51953125 45 | 6424.0,1667995018223,10.58203125 46 | 6570.0,1667995022808,10.67578125 47 | 6716.0,1667995027392,10.90234375 48 | 6862.0,1667995032014,11.21484375 49 | 7008.0,1667995036534,11.455078125 50 | 7154.0,1667995041153,11.66796875 51 | 7300.0,1667995046018,11.7421875 52 | 7446.0,1667995050607,11.865234375 53 | 7592.0,1667995055266,11.87109375 54 | 7738.0,1667995059875,12.90625 55 | 7884.0,1667995064422,12.857421875 56 | 8030.0,1667995069057,12.43359375 57 | 8176.0,1667995073595,13.041015625 58 | 8322.0,1667995078140,12.9609375 59 | 8468.0,1667995082727,13.52734375 60 | 8614.0,1667995087354,13.013671875 61 | 8760.0,1667995091935,13.8671875 62 | 8906.0,1667995096597,14.09765625 63 | 9052.0,1667995101178,14.234375 64 | 9198.0,1667995105894,14.79296875 65 | 9344.0,1667995110459,14.681640625 66 | 9490.0,1667995115098,14.140625 67 | 9636.0,1667995119742,15.1484375 68 | 9782.0,1667995124307,15.029296875 69 | 9928.0,1667995128966,15.4609375 70 | 10074.0,1667995133552,15.04296875 71 | 10220.0,1667995138192,15.287109375 72 | 10366.0,1667995142819,15.2890625 73 | 10512.0,1667995147453,15.56640625 74 | 10658.0,1667995151992,15.841796875 75 | 10804.0,1667995156603,15.7265625 76 | 10950.0,1667995161173,15.96875 77 | 11096.0,1667995165774,15.666015625 78 | 11242.0,1667995170411,16.15234375 79 | 11388.0,1667995175034,16.220703125 80 | 11534.0,1667995179664,15.693359375 81 | 11680.0,1667995184208,15.755859375 82 | 
11826.0,1667995188836,15.849609375 83 | 11972.0,1667995193411,16.2890625 84 | 12118.0,1667995197964,15.87109375 85 | 12264.0,1667995202965,16.388671875 86 | 12410.0,1667995207637,16.205078125 87 | 12556.0,1667995212352,16.326171875 88 | 12702.0,1667995216930,16.37890625 89 | 12848.0,1667995221510,16.822265625 90 | 12994.0,1667995226120,16.48828125 91 | 13140.0,1667995230762,15.78125 92 | 13286.0,1667995235455,16.77734375 93 | 13432.0,1667995240079,17.0 94 | 13578.0,1667995244599,16.650390625 95 | 13724.0,1667995249332,16.953125 96 | 13870.0,1667995253854,17.123046875 97 | 14016.0,1667995258470,16.673828125 98 | 14162.0,1667995263117,17.328125 99 | 14308.0,1667995267839,17.111328125 100 | 14454.0,1667995272407,17.916015625 101 | 14600.0,1667995277022,17.333984375 102 | 14746.0,1667995281712,17.625 103 | 14892.0,1667995286490,17.6953125 104 | 15038.0,1667995291249,17.47265625 105 | 15184.0,1667995296017,18.083984375 106 | 15330.0,1667995300639,17.287109375 107 | 15476.0,1667995305343,17.9296875 108 | 15622.0,1667995309914,17.29296875 109 | 15768.0,1667995314820,17.646484375 110 | 15914.0,1667995319458,17.8203125 111 | 16060.0,1667995324154,17.931640625 112 | 16206.0,1667995328786,17.521484375 113 | 16352.0,1667995333387,18.01953125 114 | 16498.0,1667995338089,18.126953125 115 | 16644.0,1667995342742,17.509765625 116 | 16790.0,1667995347353,17.64453125 117 | 16936.0,1667995351964,18.166015625 118 | 17082.0,1667995356618,17.48828125 119 | 17228.0,1667995361183,17.9453125 120 | 17374.0,1667995365758,18.189453125 121 | 17520.0,1667995370432,17.77734375 122 | 17666.0,1667995374936,18.22265625 123 | 17812.0,1667995379710,17.84765625 124 | 17958.0,1667995384314,18.435546875 125 | 18104.0,1667995389019,17.84765625 126 | 18250.0,1667995393656,18.71484375 127 | 18396.0,1667995398405,18.6875 128 | 18542.0,1667995402898,18.646484375 129 | 18688.0,1667995407610,18.90625 130 | 18834.0,1667995412158,18.720703125 131 | 18980.0,1667995416868,19.25 132 | 
19126.0,1667995421547,18.716796875 133 | 19272.0,1667995426235,19.15625 134 | 19418.0,1667995430822,19.349609375 135 | 19564.0,1667995435537,18.59765625 136 | 19710.0,1667995440203,19.478515625 137 | 19856.0,1667995444915,19.84765625 138 | 20002.0,1667995449530,19.03515625 139 | 20148.0,1667995454159,19.533203125 140 | 20294.0,1667995458744,19.044921875 141 | 20440.0,1667995463423,19.830078125 142 | 20586.0,1667995468000,19.46875 143 | 20732.0,1667995472602,20.056640625 144 | 20878.0,1667995477213,19.70703125 145 | 21024.0,1667995481849,19.2265625 146 | 21170.0,1667995486567,19.734375 147 | 21316.0,1667995491295,19.56640625 148 | 21462.0,1667995495967,19.7734375 149 | 21608.0,1667995500636,19.705078125 150 | 21754.0,1667995505261,19.97265625 151 | 21900.0,1667995509963,20.078125 152 | 22046.0,1667995514559,19.912109375 153 | 22192.0,1667995519341,20.06640625 154 | 22338.0,1667995524086,19.703125 155 | 22484.0,1667995528676,19.763671875 156 | 22630.0,1667995533240,20.02734375 157 | 22776.0,1667995537932,20.322265625 158 | 22922.0,1667995542657,21.021484375 159 | 23068.0,1667995547380,20.265625 160 | 23214.0,1667995552062,20.8828125 161 | 23360.0,1667995556843,20.6640625 162 | 23506.0,1667995561419,20.44140625 163 | 23652.0,1667995566063,20.302734375 164 | 23798.0,1667995570685,20.529296875 165 | 23944.0,1667995575210,20.388671875 166 | 24090.0,1667995579931,20.1484375 167 | 24236.0,1667995584488,20.498046875 168 | 24382.0,1667995589118,20.29296875 169 | 24528.0,1667995593779,20.76171875 170 | 24674.0,1667995598459,20.83203125 171 | 24820.0,1667995603161,20.748046875 172 | 24966.0,1667995607887,20.955078125 173 | 25112.0,1667995612613,21.248046875 174 | 25258.0,1667995617384,20.912109375 175 | 25404.0,1667995622055,21.388671875 176 | 25550.0,1667995626741,21.650390625 177 | 25696.0,1667995631324,21.18359375 178 | 25842.0,1667995635950,21.33984375 179 | 25988.0,1667995640567,21.564453125 180 | 26134.0,1667995645224,21.91015625 181 | 26280.0,1667995649868,21.7578125 
182 | 26426.0,1667995654537,21.501953125 183 | 26572.0,1667995659204,22.232421875 184 | 26718.0,1667995663877,22.193359375 185 | 26864.0,1667995668570,21.900390625 186 | 27010.0,1667995673218,22.49609375 187 | 27156.0,1667995677887,22.32421875 188 | 27302.0,1667995682651,22.583984375 189 | 27448.0,1667995687232,22.490234375 190 | 27594.0,1667995691810,22.90625 191 | 27740.0,1667995696507,22.841796875 192 | 27886.0,1667995701109,22.845703125 193 | 28032.0,1667995705835,21.994140625 194 | 28178.0,1667995710488,22.865234375 195 | 28324.0,1667995715128,23.298828125 196 | 28470.0,1667995719694,23.322265625 197 | 28616.0,1667995724292,23.078125 198 | 28762.0,1667995728951,23.833984375 199 | 28908.0,1667995733568,23.427734375 200 | 29054.0,1667995738181,23.6015625 201 | 29200.0,1667995742754,23.5625 202 | -------------------------------------------------------------------------------- /snake/env/__init__.py: -------------------------------------------------------------------------------- 1 | from snake.env.snake import make_snake_env 2 | -------------------------------------------------------------------------------- /snake/env/snake.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import jumanji 4 | import jumanji.wrappers 5 | 6 | 7 | def make_snake_env() -> Tuple[jumanji.Environment, jumanji.Environment]: 8 | snake = jumanji.make("Snake-6x6-v0") 9 | eval_snake = snake 10 | snake = jumanji.wrappers.AutoResetWrapper(snake) 11 | snake = jumanji.wrappers.VmapWrapper(snake) 12 | return snake, eval_snake 13 | -------------------------------------------------------------------------------- /snake/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from snake.networks.snake import make_actor_critic_networks_snake 2 | -------------------------------------------------------------------------------- /snake/networks/actor_critic.py: 
-------------------------------------------------------------------------------- 1 | from typing import Callable, NamedTuple, Optional 2 | 3 | import chex 4 | import haiku as hk 5 | from jax import numpy as jnp 6 | 7 | from snake.networks.distribution import ParametricDistribution 8 | 9 | 10 | class FeedForwardNetwork(NamedTuple): 11 | init: Callable[ 12 | [chex.PRNGKey, chex.Array, Optional[jnp.float_]], 13 | hk.Params, 14 | ] 15 | apply: Callable[ 16 | [hk.Params, chex.Array, Optional[jnp.float_]], 17 | chex.Array, 18 | ] 19 | 20 | 21 | class ActorCriticNetworks(NamedTuple): 22 | """Defines the actor-critic networks, which outputs the logits of a policy, and a value given 23 | an observation. 24 | """ 25 | 26 | policy_network: FeedForwardNetwork 27 | value_network: FeedForwardNetwork 28 | outer_value_network: Optional[FeedForwardNetwork] 29 | parametric_action_distribution: ParametricDistribution 30 | -------------------------------------------------------------------------------- /snake/networks/cnn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | import chex 4 | import haiku as hk 5 | import jax 6 | from jax import numpy as jnp 7 | 8 | from snake.networks.actor_critic import ActorCriticNetworks, FeedForwardNetwork 9 | from snake.networks.distribution import CategoricalParametricDistribution 10 | 11 | 12 | def make_actor_critic_networks_cnn( 13 | num_actions: int, 14 | num_channels: int, 15 | policy_layers: Sequence[int], 16 | value_layers: Sequence[int], 17 | outer_critic: bool, 18 | embedding_size_actor: Optional[int], 19 | embedding_size_critic: Optional[int], 20 | ) -> ActorCriticNetworks: 21 | """Make actor-critic networks for Snake.""" 22 | 23 | parametric_action_distribution = CategoricalParametricDistribution( 24 | num_actions=num_actions 25 | ) 26 | policy_network = make_network_cnn( 27 | num_outputs=num_actions, 28 | mlp_units=policy_layers, 29 | 
conv_n_channels=num_channels, 30 | embedding_size=embedding_size_actor, 31 | ) 32 | value_network = make_network_cnn( 33 | num_outputs=1, 34 | mlp_units=value_layers, 35 | conv_n_channels=num_channels, 36 | embedding_size=embedding_size_critic, 37 | ) 38 | if outer_critic: 39 | outer_value_network: Optional[FeedForwardNetwork] = make_network_cnn( 40 | num_outputs=1, 41 | mlp_units=value_layers, 42 | conv_n_channels=num_channels, 43 | embedding_size=embedding_size_critic, 44 | ) 45 | else: 46 | outer_value_network = None 47 | return ActorCriticNetworks( 48 | policy_network=policy_network, 49 | value_network=value_network, 50 | outer_value_network=outer_value_network, 51 | parametric_action_distribution=parametric_action_distribution, 52 | ) 53 | 54 | 55 | def make_network_cnn( 56 | num_outputs: int, 57 | mlp_units: Sequence[int], 58 | conv_n_channels: int, 59 | embedding_size: Optional[int], 60 | ) -> FeedForwardNetwork: 61 | def network_fn( 62 | observation: chex.Array, 63 | discount_factor: Optional[jnp.float_], 64 | ) -> chex.Array: 65 | torso = hk.Sequential( 66 | [ 67 | hk.Conv2D(conv_n_channels, (2, 2), 2), 68 | jax.nn.relu, 69 | hk.Conv2D(conv_n_channels, (2, 2), 1), 70 | jax.nn.relu, 71 | hk.Flatten(), 72 | ] 73 | ) 74 | if observation.ndim == 5: 75 | torso = jax.vmap(torso) 76 | x = torso(observation) 77 | if embedding_size: 78 | assert discount_factor is not None 79 | discount_factor = jnp.broadcast_to(discount_factor, (*x.shape[:-1], 1)) 80 | embedding = hk.Linear(embedding_size)(discount_factor) 81 | x = jnp.concatenate([x, embedding], axis=-1) 82 | head = hk.nets.MLP((*mlp_units, num_outputs), activate_final=False) 83 | if num_outputs == 1: 84 | return jnp.squeeze(head(x), axis=-1) 85 | else: 86 | return head(x) 87 | 88 | init, apply = hk.without_apply_rng(hk.transform(network_fn)) 89 | return FeedForwardNetwork(init=init, apply=apply) 90 | -------------------------------------------------------------------------------- /snake/networks/distribution.py: 
-------------------------------------------------------------------------------- 1 | """Copied from Brax and adapted with typing.""" 2 | import abc 3 | from typing import Any 4 | 5 | import chex 6 | import jax 7 | from jax import numpy as jnp 8 | 9 | 10 | class Postprocessor(abc.ABC): 11 | def forward(self, x: chex.Array) -> chex.Array: 12 | raise NotImplementedError 13 | 14 | def inverse(self, y: chex.Array) -> chex.Array: 15 | raise NotImplementedError 16 | 17 | def forward_log_det_jacobian(self, x: chex.Array) -> chex.Array: 18 | raise NotImplementedError 19 | 20 | 21 | class ParametricDistribution(abc.ABC): 22 | """Abstract class for parametric (action) distribution.""" 23 | 24 | def __init__( 25 | self, 26 | param_size: int, 27 | postprocessor: Postprocessor, 28 | event_ndims: int, 29 | reparametrizable: bool, 30 | ): 31 | """Abstract class for parametric (action) distribution. 32 | Specifies how to transform distribution parameters (i.e. actor output) 33 | into a distribution over actions. 34 | Args: 35 | param_size: size of the parameters for the distribution 36 | postprocessor: bijector which is applied after sampling (in practice, it's 37 | tanh or identity) 38 | event_ndims: rank of the distribution sample (i.e. 
action) 39 | reparametrizable: is the distribution reparametrizable 40 | """ 41 | self._param_size = param_size 42 | self._postprocessor = postprocessor 43 | self._event_ndims = event_ndims # rank of events 44 | self._reparametrizable = reparametrizable 45 | assert event_ndims in [0, 1] 46 | 47 | @abc.abstractmethod 48 | def create_dist(self, parameters: chex.Array) -> Any: 49 | """Creates distribution from parameters.""" 50 | 51 | @property 52 | def param_size(self) -> int: 53 | return self._param_size 54 | 55 | @property 56 | def reparametrizable(self) -> bool: 57 | return self._reparametrizable 58 | 59 | def postprocess(self, event: chex.Array) -> chex.Array: 60 | return self._postprocessor.forward(event) 61 | 62 | def inverse_postprocess(self, event: chex.Array) -> chex.Array: 63 | return self._postprocessor.inverse(event) 64 | 65 | def sample_no_postprocessing( 66 | self, parameters: chex.Array, seed: chex.PRNGKey 67 | ) -> Any: 68 | return self.create_dist(parameters).sample(seed=seed) 69 | 70 | def sample(self, parameters: chex.Array, seed: chex.PRNGKey) -> chex.Array: 71 | """Returns a sample from the postprocessed distribution.""" 72 | return self.postprocess(self.sample_no_postprocessing(parameters, seed)) 73 | 74 | def mode(self, parameters: chex.Array) -> chex.Array: 75 | """Returns the mode of the postprocessed distribution.""" 76 | return self.postprocess(self.create_dist(parameters).mode()) 77 | 78 | def log_prob(self, parameters: chex.Array, raw_actions: chex.Array) -> chex.Array: 79 | """Compute the log probability of actions.""" 80 | dist = self.create_dist(parameters) 81 | log_probs = dist.log_prob(raw_actions) 82 | log_probs -= self._postprocessor.forward_log_det_jacobian(raw_actions) 83 | if self._event_ndims == 1: 84 | log_probs = jnp.sum(log_probs, axis=-1) # sum over action dimension 85 | return log_probs 86 | 87 | def entropy(self, parameters: chex.Array, seed: chex.PRNGKey) -> chex.Array: 88 | """Return the entropy of the given 
distribution.""" 89 | dist = self.create_dist(parameters) 90 | entropy = dist.entropy() 91 | entropy += self._postprocessor.forward_log_det_jacobian(dist.sample(seed=seed)) 92 | if self._event_ndims == 1: 93 | entropy = jnp.sum(entropy, axis=-1) 94 | return entropy 95 | 96 | def kl_divergence( 97 | self, parameters: chex.Array, other_parameters: chex.Array 98 | ) -> chex.Array: 99 | """KL divergence is invariant with respect to transformation by the same bijector.""" 100 | dist = self.create_dist(parameters) 101 | other_dist = self.create_dist(other_parameters) 102 | return dist.kl_divergence(other_dist) 103 | 104 | 105 | class IdentityBijector(Postprocessor): 106 | """Identity Bijector.""" 107 | 108 | def forward(self, x: chex.Array) -> chex.Array: 109 | return x 110 | 111 | def inverse(self, y: chex.Array) -> chex.Array: 112 | return y 113 | 114 | def forward_log_det_jacobian(self, x: chex.Array) -> chex.Array: 115 | return jnp.zeros_like(x, x.dtype) 116 | 117 | 118 | class CategoricalDistribution: 119 | """Categorical distribution.""" 120 | 121 | def __init__(self, logits: chex.Array): 122 | self.logits = logits 123 | self.num_actions = jnp.shape(logits)[-1] 124 | 125 | def sample(self, seed: chex.PRNGKey) -> chex.Array: 126 | return jax.random.categorical(seed, self.logits) 127 | 128 | def mode(self) -> chex.Array: 129 | return jnp.argmax(self.logits, axis=-1) 130 | 131 | def log_prob(self, x: chex.Array) -> chex.Array: 132 | value_one_hot = jax.nn.one_hot(x, self.num_actions) 133 | mask_outside_domain = jnp.logical_or(x < 0, x > self.num_actions - 1) 134 | safe_log_probs = jnp.where( 135 | value_one_hot == 0, 136 | jnp.zeros((), dtype=self.logits.dtype), 137 | jax.nn.log_softmax(self.logits) * value_one_hot, 138 | ) 139 | return jnp.where( 140 | mask_outside_domain, 141 | -jnp.inf, 142 | jnp.sum(safe_log_probs, axis=-1), 143 | ) 144 | 145 | def entropy(self) -> chex.Array: 146 | log_probs = jax.nn.log_softmax(self.logits) 147 | probs = jnp.exp(log_probs) 148 | 
return -jnp.sum(jnp.where(probs == 0, 0.0, probs * log_probs), axis=-1) 149 | 150 | def kl_divergence(self, other: "CategoricalDistribution") -> chex.Array: 151 | log_probs = jax.nn.log_softmax(self.logits) 152 | probs = jnp.exp(log_probs) 153 | log_probs_other = jax.nn.log_softmax(other.logits) 154 | return jnp.sum( 155 | jnp.where(probs == 0, 0.0, probs * (log_probs - log_probs_other)), axis=-1 156 | ) 157 | 158 | 159 | class CategoricalParametricDistribution(ParametricDistribution): 160 | """Categorical distribution for discrete action spaces.""" 161 | 162 | def __init__(self, num_actions: int): 163 | """Initialize the distribution. 164 | Args: 165 | num_actions: the number of actions. 166 | """ 167 | postprocessor: Postprocessor 168 | postprocessor = IdentityBijector() 169 | super().__init__( 170 | param_size=num_actions, 171 | postprocessor=postprocessor, 172 | event_ndims=0, 173 | reparametrizable=True, 174 | ) 175 | 176 | def create_dist(self, parameters: chex.Array) -> CategoricalDistribution: 177 | return CategoricalDistribution(logits=parameters) 178 | -------------------------------------------------------------------------------- /snake/networks/snake.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence 2 | 3 | from snake.networks import cnn 4 | from snake.networks.actor_critic import ActorCriticNetworks 5 | 6 | 7 | def make_actor_critic_networks_snake( 8 | num_channels: int, 9 | policy_layers: Sequence[int], 10 | value_layers: Sequence[int], 11 | outer_critic: bool, 12 | embedding_size_actor: Optional[int], 13 | embedding_size_critic: Optional[int], 14 | ) -> ActorCriticNetworks: 15 | """Make actor-critic networks for Snake.""" 16 | 17 | return cnn.make_actor_critic_networks_cnn( 18 | num_actions=4, 19 | num_channels=num_channels, 20 | policy_layers=policy_layers, 21 | value_layers=value_layers, 22 | outer_critic=outer_critic, 23 | embedding_size_actor=embedding_size_actor, 24 | 
embedding_size_critic=embedding_size_critic, 25 | ) 26 | -------------------------------------------------------------------------------- /snake/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/snake/training/__init__.py -------------------------------------------------------------------------------- /snake/training/config.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Optional, Sequence 2 | 3 | import omegaconf 4 | 5 | 6 | class Config(NamedTuple): 7 | name: str 8 | agent: str 9 | meta_objective: Optional[str] 10 | outer_critic: bool 11 | num_timesteps: int 12 | num_eval_points: int 13 | total_batch_size: int 14 | total_num_envs: int 15 | n_steps: int 16 | total_num_eval: int 17 | deterministic_eval: bool 18 | network_type: str 19 | num_channels: int 20 | policy_layers: Sequence[int] 21 | value_layers: Sequence[int] 22 | embedding_size_actor: Optional[int] 23 | embedding_size_critic: Optional[int] 24 | optimizer: str 25 | learning_rate: float 26 | gradient_clip_norm: Optional[float] 27 | bootstrap_l_optimizer: Optional[str] 28 | bootstrap_l_learning_rate: Optional[float] 29 | meta_optimizer: Optional[str] 30 | meta_learning_rate: Optional[float] 31 | meta_gradient_clip_norm: Optional[float] 32 | normalize_advantage: bool 33 | normalize_outer_advantage: Optional[bool] 34 | reward_scaling: float 35 | gamma_init: float 36 | gamma_outer: Optional[float] 37 | gamma_metal: Optional[bool] 38 | lambda_init: float 39 | lambda_outer: Optional[float] 40 | lambda_metal: Optional[bool] 41 | l_pg_init: float 42 | l_pg_outer: Optional[float] 43 | l_pg_metal: Optional[bool] 44 | l_td_init: float 45 | l_td_outer: Optional[float] 46 | l_td_metal: Optional[bool] 47 | l_en_init: float 48 | l_en_outer: Optional[float] 49 | l_en_metal: 
Optional[bool] 50 | l_kl_outer: Optional[float] 51 | bootstrap_l: Optional[int] 52 | bootstrap_pm: Optional[float] 53 | bootstrap_vm: Optional[float] 54 | seed: int 55 | 56 | 57 | def convert_config(cfg: omegaconf.DictConfig) -> Config: 58 | return Config( 59 | name=cfg.training.name, 60 | agent=cfg.agent.agent, 61 | meta_objective=cfg.agent.meta_objective, 62 | num_timesteps=cfg.training.num_timesteps, 63 | num_eval_points=cfg.training.num_eval_points, 64 | reward_scaling=cfg.training.reward_scaling, 65 | n_steps=cfg.training.n_steps, 66 | total_batch_size=cfg.training.total_batch_size, 67 | total_num_envs=cfg.training.total_num_envs, 68 | optimizer=cfg.training.optimizer, 69 | learning_rate=float(cfg.training.learning_rate), 70 | bootstrap_l_optimizer=cfg.agent.bootstrap_l_optimizer, 71 | bootstrap_l_learning_rate=cfg.agent.bootstrap_l_learning_rate, 72 | meta_optimizer=cfg.agent.meta_optimizer, 73 | meta_learning_rate=cfg.agent.meta_learning_rate, 74 | gradient_clip_norm=cfg.training.gradient_clip_norm, 75 | meta_gradient_clip_norm=cfg.agent.meta_gradient_clip_norm, 76 | gamma_init=float(cfg.training.gamma_init), 77 | gamma_outer=float(cfg.agent.gamma_outer) 78 | if cfg.agent.gamma_outer is not None 79 | else None, 80 | gamma_metal=cfg.agent.gamma_metal, 81 | lambda_init=float(cfg.training.lambda_init), 82 | lambda_outer=float(cfg.agent.lambda_outer) 83 | if cfg.agent.lambda_outer is not None 84 | else None, 85 | lambda_metal=cfg.agent.lambda_metal, 86 | l_pg_init=float(cfg.training.l_pg_init), 87 | l_pg_outer=float(cfg.agent.l_pg_outer) 88 | if cfg.agent.l_pg_outer is not None 89 | else None, 90 | l_pg_metal=cfg.agent.l_pg_metal, 91 | l_td_init=float(cfg.training.l_td_init), 92 | l_td_outer=float(cfg.agent.l_td_outer) 93 | if cfg.agent.l_td_outer is not None 94 | else None, 95 | l_td_metal=cfg.agent.l_td_metal, 96 | l_en_init=float(cfg.training.l_en_init), 97 | l_en_outer=float(cfg.agent.l_en_outer) 98 | if cfg.agent.l_en_outer is not None 99 | else None, 100 | 
l_en_metal=cfg.agent.l_en_metal, 101 | l_kl_outer=float(cfg.agent.l_kl_outer) 102 | if cfg.agent.l_kl_outer is not None 103 | else None, 104 | bootstrap_l=cfg.agent.bootstrap_l, 105 | bootstrap_pm=cfg.agent.bootstrap_pm, 106 | bootstrap_vm=cfg.agent.bootstrap_vm, 107 | normalize_advantage=cfg.agent.normalize_advantage, 108 | normalize_outer_advantage=cfg.agent.normalize_outer_advantage, 109 | total_num_eval=cfg.agent.total_num_eval, 110 | deterministic_eval=cfg.agent.deterministic_eval, 111 | policy_layers=cfg.network.policy_layers, 112 | value_layers=cfg.network.value_layers, 113 | outer_critic=cfg.agent.outer_critic, 114 | seed=cfg.training.seed, 115 | embedding_size_actor=cfg.network.embedding_size_actor, 116 | embedding_size_critic=cfg.network.embedding_size_critic, 117 | network_type=cfg.network.type, 118 | num_channels=cfg.network.num_channels, 119 | ) 120 | -------------------------------------------------------------------------------- /snake/training/evaluator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Callable, Dict, Optional, Tuple 3 | 4 | import chex 5 | import haiku as hk 6 | import jax 7 | import jumanji.types 8 | from jax import numpy as jnp 9 | 10 | from snake.agent import ActorCriticAgent 11 | from snake.training.types import TrainingState 12 | 13 | 14 | class Evaluator: 15 | """Class to run evaluations.""" 16 | 17 | def __init__( 18 | self, 19 | eval_env: jumanji.Environment, 20 | actor_critic_agent: ActorCriticAgent, 21 | total_num_eval: int, 22 | key: chex.PRNGKey, 23 | deterministic: bool, 24 | ): 25 | self._eval_env = eval_env 26 | self._actor_critic_agent = actor_critic_agent 27 | 28 | num_devices = jax.local_device_count() 29 | self._num_devices = num_devices 30 | assert total_num_eval % num_devices == 0 31 | self._total_num_eval = total_num_eval 32 | self._num_eval_per_device = total_num_eval // num_devices 33 | 34 | self._key = key 35 | 
self._deterministic = deterministic 36 | self.generate_eval_unroll = jax.pmap( 37 | self._generate_eval_unroll, axis_name="devices" 38 | ) 39 | 40 | def _unroll_one_step( 41 | self, 42 | policy: Callable[[chex.Array, chex.PRNGKey], Tuple[chex.Array, Dict]], 43 | carry: Tuple[jumanji.env.State, jumanji.types.TimeStep], 44 | key: chex.PRNGKey, 45 | ) -> Tuple[Tuple[jumanji.env.State, jumanji.types.TimeStep], None]: 46 | state, timestep = carry 47 | observation = jax.tree_util.tree_map( 48 | lambda x: x[None, ...], timestep.observation 49 | ) 50 | action, _ = policy(observation, key) 51 | next_state, next_timestep = self._eval_env.step(state, jnp.squeeze(action)) 52 | return (next_state, next_timestep), None 53 | 54 | def _generate_eval_one_episode( 55 | self, 56 | policy_params: hk.Params, 57 | discount_factor: Optional[jnp.float_], 58 | key: chex.PRNGKey, 59 | ) -> Tuple[Dict, Dict]: 60 | stochastic_policy = self._actor_critic_agent.make_policy( 61 | policy_params, 62 | discount_factor, 63 | deterministic=False, 64 | ) 65 | determinist_policy = self._actor_critic_agent.make_policy( 66 | policy_params, 67 | discount_factor, 68 | deterministic=True, 69 | ) 70 | 71 | def cond_fun( 72 | carry: Tuple[ 73 | jumanji.env.State, 74 | jumanji.types.TimeStep, 75 | chex.PRNGKey, 76 | jnp.float32, 77 | jnp.int32, 78 | ] 79 | ) -> jnp.bool_: 80 | _, timestep, *_ = carry 81 | return ~timestep.last() 82 | 83 | def body_fun( 84 | policy: Callable[[chex.Array, chex.PRNGKey], Tuple[chex.Array, Dict]], 85 | carry: Tuple[ 86 | jumanji.env.State, 87 | jumanji.types.TimeStep, 88 | chex.PRNGKey, 89 | jnp.float32, 90 | jnp.int32, 91 | ], 92 | ) -> Tuple[ 93 | jumanji.env.State, 94 | jumanji.types.TimeStep, 95 | chex.PRNGKey, 96 | jnp.float32, 97 | jnp.int32, 98 | ]: 99 | state, timestep, key, return_, count = carry 100 | key, step_key = jax.random.split(key) 101 | (state, timestep), _ = self._unroll_one_step( 102 | policy, (state, timestep), step_key 103 | ) 104 | return_ += timestep.reward 
105 | count += 1 106 | return state, timestep, key, return_, count 107 | 108 | ( 109 | reset_key_stochastic, 110 | reset_key_determinist, 111 | init_key_stochastic, 112 | init_key_determinist, 113 | ) = jax.random.split(key, 4) 114 | state_stochastic, timestep_stochastic = self._eval_env.reset( 115 | reset_key_stochastic 116 | ) 117 | _, _, _, return_stochastic, count_stochastic = jax.lax.while_loop( 118 | cond_fun, 119 | functools.partial(body_fun, stochastic_policy), 120 | ( 121 | state_stochastic, 122 | timestep_stochastic, 123 | init_key_stochastic, 124 | jnp.float32(0), 125 | jnp.int32(0), 126 | ), 127 | ) 128 | eval_metrics_stochastic = { 129 | "episode_reward": return_stochastic, 130 | "episode_length": count_stochastic, 131 | } 132 | state_determinist, timestep_determinist = self._eval_env.reset( 133 | reset_key_stochastic 134 | ) 135 | _, _, _, return_deterministic, count_deterministic = jax.lax.while_loop( 136 | cond_fun, 137 | functools.partial(body_fun, determinist_policy), 138 | ( 139 | state_determinist, 140 | timestep_determinist, 141 | init_key_determinist, 142 | jnp.float32(0), 143 | jnp.int32(0), 144 | ), 145 | ) 146 | eval_metrics_deterministic = { 147 | "episode_reward": return_deterministic, 148 | "episode_length": count_deterministic, 149 | } 150 | return eval_metrics_stochastic, eval_metrics_deterministic 151 | 152 | def _generate_eval_unroll( 153 | self, 154 | policy_params: hk.Params, 155 | discount_factor: Optional[jnp.float_], 156 | key: chex.PRNGKey, 157 | ) -> Dict: 158 | 159 | keys = jax.random.split(key, self._num_eval_per_device) 160 | eval_metrics = jax.vmap( 161 | self._generate_eval_one_episode, in_axes=(None, None, 0) 162 | )( 163 | policy_params, 164 | discount_factor, 165 | keys, 166 | ) 167 | eval_metrics: Dict = jax.lax.pmean( 168 | jax.tree_util.tree_map(jnp.mean, eval_metrics), 169 | axis_name="devices", 170 | ) 171 | 172 | return eval_metrics 173 | 174 | def run_evaluation(self, training_state: TrainingState) -> Dict: 175 
| """Run one epoch of evaluation.""" 176 | self._key, unroll_key = jax.random.split(self._key) 177 | 178 | unroll_keys = jax.random.split(unroll_key, self._num_devices) 179 | if training_state.meta_params is not None: 180 | discount_factor = jax.nn.sigmoid(training_state.meta_params.gamma) 181 | else: 182 | discount_factor = None 183 | eval_metrics_stochastic, eval_metrics_determinist = self.generate_eval_unroll( 184 | training_state.params.actor, 185 | discount_factor, 186 | unroll_keys, 187 | ) 188 | stochastic_metrics = eval_metrics_stochastic 189 | determinist_metrics = eval_metrics_determinist 190 | metrics = { 191 | **{ 192 | key + "_stochastic_policy": value 193 | for key, value in stochastic_metrics.items() 194 | }, 195 | **{ 196 | key + "_determinist_policy": value 197 | for key, value in determinist_metrics.items() 198 | }, 199 | } 200 | return metrics 201 | -------------------------------------------------------------------------------- /snake/training/logger.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Any, Callable, Dict, Mapping, Optional 3 | 4 | import chex 5 | import numpy as np 6 | from neptune import new as neptune 7 | 8 | 9 | class Logger(abc.ABC): 10 | # copied from Acme 11 | """A logger has a `write` method.""" 12 | 13 | @abc.abstractmethod 14 | def write(self, data: Mapping[str, Any], *args: Any, **kwargs: Any) -> None: 15 | """Writes `data` to destination (file, terminal, database, etc).""" 16 | 17 | @abc.abstractmethod 18 | def save_checkpoint(self, file_name: str) -> None: 19 | """Saves a checkpoint.""" 20 | 21 | @abc.abstractmethod 22 | def close(self) -> None: 23 | """Closes the logger, not expecting any further write.""" 24 | 25 | def __enter__(self) -> "Logger": 26 | return self 27 | 28 | def __exit__( 29 | self, exc_type: Exception, exc_val: Exception, exc_tb: Exception 30 | ) -> None: 31 | self.close() 32 | 33 | 34 | class NeptuneLogger(Logger): 35 | def 
__init__( 36 | self, 37 | config: Dict[str, Any], 38 | aggregation_fn: Callable[[chex.Array], chex.Array], 39 | seed: Optional[int] = None, 40 | **kwargs: Any, 41 | ): 42 | self.run = neptune.init( 43 | project="", # Change the project name to your Neptune project 44 | name=config["name"] + (f"_seed_{seed}" if seed is not None else ""), 45 | **kwargs, 46 | ) 47 | self.run["config"] = config 48 | self._t: int = 0 49 | del aggregation_fn # Moved to logging mean and last sample instead. 50 | self.mean_fn = np.mean 51 | 52 | def downsample(x: np.ndarray) -> np.ndarray: 53 | return x[-1] # type: ignore 54 | 55 | self.downsample_fn = downsample 56 | 57 | def write( # noqa: CCR001 58 | self, 59 | data: Mapping[str, Any], 60 | label: str = "", 61 | timestep: Optional[int] = None, 62 | *args: Any, 63 | **kwargs: Any, 64 | ) -> None: 65 | self._t = timestep or self._t + 1 66 | prefix = label and f"{label}/" 67 | for key, metric in data.items(): 68 | if np.ndim(metric) == 0: 69 | if not np.isnan(metric): 70 | self.run[f"{prefix}/{key}"].log( 71 | float(metric), 72 | step=self._t, 73 | wait=True, 74 | ) 75 | elif np.ndim(metric) == 1: 76 | metric_value_mean = self.mean_fn(metric) 77 | if not np.isnan(metric_value_mean): 78 | self.run[f"{prefix}mean/{key}"].log( 79 | metric_value_mean.item(), 80 | step=self._t, 81 | wait=True, 82 | ) 83 | metric_value_sample = self.downsample_fn(metric) 84 | if not np.isnan(metric_value_sample): 85 | self.run[f"{prefix}sample/{key}"].log( 86 | metric_value_sample.item(), 87 | step=self._t, 88 | wait=True, 89 | ) 90 | else: 91 | raise ValueError( 92 | f"Expected metric to be 0 or 1 dimension, got {metric}." 
93 | ) 94 | 95 | def save_checkpoint(self, file_name: str) -> None: 96 | self.run[f"checkpoints/{file_name}"].upload(file_name) 97 | 98 | def close(self) -> None: 99 | self.run.stop() 100 | 101 | 102 | class TerminalLogger(Logger): 103 | def __init__( 104 | self, aggregation_fn: Callable[[chex.Array], chex.Array], **kwargs: Any 105 | ): 106 | self.aggregation_fn = aggregation_fn 107 | print(">>> Terminal Logger") 108 | 109 | def write( 110 | self, 111 | data: Mapping[str, Any], 112 | label: str = "", 113 | timestep: Optional[int] = None, 114 | *args: Any, 115 | **kwargs: Any, 116 | ) -> None: 117 | gamma = data.get("gamma", None) 118 | return_ = data.get("episode_reward_stochastic_policy", None) 119 | if timestep is not None: 120 | print_str = f"\nTimestep {timestep:.2e} >>> " 121 | else: 122 | print_str = "\n" 123 | if return_ is not None: 124 | print_str += f"mean_return: {self.aggregation_fn(return_):.2f} " 125 | if gamma is not None: 126 | print_str += f"discount_factor: {self.aggregation_fn(gamma):.5f}" 127 | print(print_str) 128 | 129 | def save_checkpoint(self, file_name: str) -> None: 130 | pass 131 | 132 | def close(self) -> None: 133 | pass 134 | 135 | 136 | def make_logger_factory( 137 | logger: str, 138 | config_dict: Dict[str, Any], 139 | aggregation_behaviour: str = "mean", 140 | seed: Optional[int] = None, 141 | ) -> Callable[[], Logger]: 142 | if aggregation_behaviour == "mean": 143 | aggregation_fn = np.mean 144 | else: 145 | raise ValueError( 146 | f"aggregation_behaviour is expected to be 'mean', got {aggregation_behaviour} instead." 147 | ) 148 | 149 | def make_logger() -> Logger: 150 | if logger == "neptune": 151 | return NeptuneLogger(config_dict, aggregation_fn, seed) 152 | elif logger == "terminal": 153 | return TerminalLogger(aggregation_fn) 154 | else: 155 | raise ValueError( 156 | f"expected logger in ['neptune', 'terminal'], got {logger}." 
def _make_optimizer(
    name: str, learning_rate: float, field_name: str
) -> optax.GradientTransformation:
    """Build the optax optimizer selected by ``name``.

    Args:
        name: one of ``"sgd"``, ``"rmsprop"`` or ``"adam"``.
        learning_rate: forwarded unchanged to the optax constructor.
        field_name: name of the config field ``name`` was read from; only used
            to build the error message.

    Returns:
        The corresponding optax gradient transformation.

    Raises:
        ValueError: if ``name`` is not one of the supported optimizers.
    """
    if name == "sgd":
        return optax.sgd(learning_rate)
    if name == "rmsprop":
        return optax.rmsprop(learning_rate, decay=0.9)
    if name == "adam":
        return optax.adam(learning_rate, eps_root=1e-8)
    raise ValueError(
        f"Expected {field_name} to be in ['sgd', 'rmsprop', 'adam'], got "
        f"{name} instead."
    )


def make_agent(  # noqa: CCR001
    config: Config, env: jumanji.Environment
) -> ActorCriticAgent:
    """Instantiate the agent described by ``config``.

    The networks are built first, then the inner optimizer (optionally wrapped
    with global-norm gradient clipping), and finally either an ``A2C`` or a
    ``MetaA2C`` agent depending on ``config.agent``.

    Args:
        config: run configuration; the fields read here are the network sizes,
            the optimizer settings and the initial / outer hyper-parameters.
        env: environment handed to the agent.

    Returns:
        An ``A2C`` agent when ``config.agent == "a2c"``, a ``MetaA2C`` agent
        when ``config.agent == "meta_a2c"``.

    Raises:
        ValueError: if ``config.network_type``, ``config.agent`` or any of the
            optimizer names is not supported.
    """
    actor_critic_networks: ActorCriticNetworks
    if config.network_type == "snake":
        actor_critic_networks = make_actor_critic_networks_snake(
            num_channels=config.num_channels,
            policy_layers=config.policy_layers,
            value_layers=config.value_layers,
            outer_critic=config.outer_critic,
            embedding_size_actor=config.embedding_size_actor,
            embedding_size_critic=config.embedding_size_critic,
        )
    else:
        raise ValueError(
            f"Expected network_type to be 'snake', got "
            f"{config.network_type} instead."
        )

    optimizer = _make_optimizer(config.optimizer, config.learning_rate, "optimizer")
    if config.gradient_clip_norm is not None:
        # Clip first so that the optimizer only ever sees clipped gradients.
        optimizer = optax.chain(
            optax.clip_by_global_norm(config.gradient_clip_norm),
            optimizer,
        )

    # Both agents start from the same ``*_init`` hyper-parameters.
    hyper_params_init = HyperParams(
        gamma=config.gamma_init,
        lambda_=config.lambda_init,
        l_pg=config.l_pg_init,
        l_td=config.l_td_init,
        l_en=config.l_en_init,
    )

    agent: ActorCriticAgent
    if config.agent == "a2c":
        agent = A2C(
            n_steps=config.n_steps,
            total_batch_size=config.total_batch_size,
            total_num_envs=config.total_num_envs,
            env=env,
            actor_critic_networks=actor_critic_networks,
            optimizer=optimizer,
            normalize_advantage=config.normalize_advantage,
            reward_scaling=config.reward_scaling,
            hyper_params=hyper_params_init,
            env_type=config.network_type,
        )
    elif config.agent == "meta_a2c":
        assert config.meta_learning_rate is not None
        meta_optimizer = _make_optimizer(
            config.meta_optimizer, config.meta_learning_rate, "meta_optimizer"
        )
        if config.meta_gradient_clip_norm is not None:
            meta_optimizer = optax.chain(
                optax.clip_by_global_norm(config.meta_gradient_clip_norm),
                meta_optimizer,
            )

        # The bootstrap optimizer is optional and is never wrapped with clipping.
        if config.bootstrap_l_optimizer is not None:
            bootstrap_l_optimizer = _make_optimizer(
                config.bootstrap_l_optimizer,
                config.bootstrap_l_learning_rate,
                "bootstrap_l_optimizer",
            )
        else:
            bootstrap_l_optimizer = None

        outer_hyper_params = HyperParams(
            gamma=config.gamma_outer,
            lambda_=config.lambda_outer,
            l_pg=config.l_pg_outer,
            l_td=config.l_td_outer,
            l_en=config.l_en_outer,
        )
        assert config.normalize_outer_advantage is not None
        # Per hyper-parameter gate: the identity when the ``*_metal`` flag is
        # set, ``jax.lax.stop_gradient`` (which blocks gradients) otherwise.
        metal = Metal(
            gamma=(lambda x: x) if config.gamma_metal else jax.lax.stop_gradient,
            lambda_=(lambda x: x) if config.lambda_metal else jax.lax.stop_gradient,
            l_pg=(lambda x: x) if config.l_pg_metal else jax.lax.stop_gradient,
            l_td=(lambda x: x) if config.l_td_metal else jax.lax.stop_gradient,
            l_en=(lambda x: x) if config.l_en_metal else jax.lax.stop_gradient,
        )
        agent = MetaA2C(
            n_steps=config.n_steps,
            total_batch_size=config.total_batch_size,
            total_num_envs=config.total_num_envs,
            env=env,
            actor_critic_networks=actor_critic_networks,
            optimizer=optimizer,
            bootstrap_l_optimizer=bootstrap_l_optimizer,
            meta_optimizer=meta_optimizer,
            normalize_advantage=config.normalize_advantage,
            normalize_outer_advantage=config.normalize_outer_advantage,
            reward_scaling=config.reward_scaling,
            hyper_params_init=hyper_params_init,
            outer_hyper_params=outer_hyper_params,
            meta_objective=config.meta_objective,
            bootstrap_l=config.bootstrap_l,
            bootstrap_pm=config.bootstrap_pm,
            bootstrap_vm=config.bootstrap_vm,
            metal=metal,
            l_kl_outer=config.l_kl_outer,
            env_type=config.network_type,
        )
    else:
        raise ValueError(
            f"Expected agent in ['a2c', 'meta_a2c'], got {config.agent} instead."
        )
    return agent
@hydra.main(config_path="snake/configs", config_name="config.yaml")
def run(cfg: omegaconf.DictConfig) -> None:
    """Train a snake agent, alternating checkpointing, evaluation and training.

    The training budget ``config.num_timesteps`` is split into
    ``config.num_eval_points`` epochs. Each loop iteration pickles the current
    state, evaluates it, then runs one epoch of learner updates. The loop runs
    ``config.num_eval_points + 1`` times so that the initial state is
    checkpointed and evaluated as well.

    Args:
        cfg: hydra configuration, converted with ``convert_config``.
    """
    print(omegaconf.OmegaConf.to_yaml(cfg))
    logging.getLogger().setLevel(logging.INFO)
    logging.info({"devices": jax.local_devices()})

    config = convert_config(cfg)

    agent, evaluator = setup(config)
    # Can be either 'neptune' or 'terminal'.
    make_logger = make_logger_factory(
        "terminal", config._asdict(), aggregation_behaviour="mean"
    )

    num_timesteps_per_learner_update = config.total_batch_size * config.n_steps
    num_timesteps_per_epoch = config.num_timesteps // config.num_eval_points
    num_learner_steps_per_epoch = (
        num_timesteps_per_epoch // num_timesteps_per_learner_update
    )
    if num_learner_steps_per_epoch < 1:
        # Both floor divisions can round down to zero, in which case every
        # "training" epoch below is asked to run zero learner steps.
        logging.warning(
            "num_timesteps=%s gives %s learner steps per epoch "
            "(num_eval_points=%s, total_batch_size=%s, n_steps=%s): "
            "no training update will be performed.",
            config.num_timesteps,
            num_learner_steps_per_epoch,
            config.num_eval_points,
            config.total_batch_size,
            config.n_steps,
        )

    state = agent.init(jax.random.PRNGKey(config.seed))

    @functools.partial(jax.pmap, axis_name="devices")
    def epoch_fn(state: State) -> Tuple[State, Dict]:
        """Run one epoch of ``agent.update`` calls and stack their metrics."""
        state, metrics = jax.lax.scan(
            lambda s, _: agent.update(s), state, None, num_learner_steps_per_epoch
        )
        return state, metrics

    with make_logger() as logger, trange(
        config.num_eval_points + 1
    ) as epochs, jax.log_compiles():
        num_learner_updates = 0

        # Log first hyper_params
        if state.training_state.meta_params is not None:
            initial_hyper_params = meta_params_to_hyper_params(
                utils.first_from_device(state.training_state.meta_params)
            )
        else:
            assert isinstance(agent, A2C)
            initial_hyper_params = agent.hyper_params
        for prefix in ("train/mean/", "train/sample/"):
            logger.write(initial_hyper_params._asdict(), prefix, num_learner_updates)

        # NOTE(review): the last iteration still trains for one more epoch whose
        # resulting state is never evaluated nor checkpointed - confirm that
        # this is intended.
        for _ in epochs:

            # Checkpoint
            checkpoint_path = f"state_{num_learner_updates:.2e}.pickle"
            with open(checkpoint_path, "wb") as file_:
                # first_from_device keeps index 0 of the leading axis of
                # every leaf.
                pickle.dump(utils.first_from_device(state), file_)
            # The file is closed (hence fully written) before the logger sees it.
            logger.save_checkpoint(checkpoint_path)

            # Validation
            metrics = evaluator.run_evaluation(state.training_state)
            logger.write(utils.first_from_device(metrics), "eval", num_learner_updates)

            # Training steps
            state, metrics = epoch_fn(state)
            num_learner_updates += num_learner_steps_per_epoch
            logger.write(utils.first_from_device(metrics), "train", num_learner_updates)