├── .flake8
├── .gitignore
├── .pre-commit-config.yaml
├── README.md
├── commitlint.config.js
├── discounting_chain.ipynb
├── discounting_chain
├── __init__.py
├── a2c.py
├── base.py
├── bmg_a2c.py
├── data
│ ├── discounting_chain_appendix_histories_array.pickle
│ └── discounting_chain_histories_array.pickle
├── data_generation
│ ├── __init__.py
│ ├── n_step.py
│ ├── n_step_data_generator.py
│ └── n_step_utils.py
├── envs
│ ├── __init__.py
│ ├── env_utils.py
│ ├── gymnax_dc_wrapper.py
│ └── gymnax_wrapper.py
├── list_logger.py
├── meta_a2c.py
├── nets.py
├── nets_split.py
└── train_utils.py
├── discounting_chain_train.ipynb
├── mypy.ini
├── plots
├── discounting_chain
│ ├── dc_chain_discount_factor.png
│ ├── dc_chain_discount_factor_appendix.png
│ ├── dc_chain_outer_loss_advantage.png
│ ├── dc_chain_outer_loss_advantage_appendix.png
│ ├── dc_chain_return.png
│ └── dc_chain_return_appendix.png
└── snake
│ ├── snake_discount_factor.png
│ ├── snake_discount_factor_appendix.png
│ ├── snake_outer_loss_advantage.png
│ ├── snake_outer_loss_advantage_appendix.png
│ ├── snake_return.png
│ └── snake_return_appendix.png
├── requirements-dev.txt
├── requirements.txt
├── snake.ipynb
├── snake
├── __init__.py
├── agent
│ ├── __init__.py
│ ├── a2c.py
│ ├── actor_critic_agent.py
│ ├── gae.py
│ └── meta_a2c.py
├── configs
│ ├── agent
│ │ ├── a2c.yaml
│ │ ├── bootstrap.yaml
│ │ └── mgrl.yaml
│ └── config.yaml
├── data
│ ├── a2c_gamma
│ │ ├── MET-2150__train_mean_gamma.csv
│ │ ├── MET-2156__train_mean_gamma.csv
│ │ ├── MET-2158__train_mean_gamma.csv
│ │ ├── MET-2162__train_mean_gamma.csv
│ │ ├── MET-2165__train_mean_gamma.csv
│ │ ├── MET-2168__train_mean_gamma.csv
│ │ ├── MET-2171__train_mean_gamma.csv
│ │ ├── MET-2175__train_mean_gamma.csv
│ │ ├── MET-2177__train_mean_gamma.csv
│ │ └── MET-2180__train_mean_gamma.csv
│ ├── a2c_return
│ │ ├── MET-2150__eval_episode_reward_determinist_policy.csv
│ │ ├── MET-2156__eval_episode_reward_determinist_policy.csv
│ │ ├── MET-2158__eval_episode_reward_determinist_policy.csv
│ │ ├── MET-2162__eval_episode_reward_determinist_policy.csv
│ │ ├── MET-2165__eval_episode_reward_determinist_policy.csv
│ │ ├── MET-2168__eval_episode_reward_determinist_policy.csv
│ │ ├── MET-2171__eval_episode_reward_determinist_policy.csv
│ │ ├── MET-2175__eval_episode_reward_determinist_policy.csv
│ │ ├── MET-2177__eval_episode_reward_determinist_policy.csv
│ │ └── MET-2180__eval_episode_reward_determinist_policy.csv
│ ├── appendix
│ │ ├── bias
│ │ │ ├── bias_bootstrap_no_outer_no_norm.csv
│ │ │ ├── bias_bootstrap_no_outer_norm.csv
│ │ │ ├── bias_bootstrap_outer_no_norm.csv
│ │ │ ├── bias_bootstrap_outer_norm.csv
│ │ │ ├── bias_mgrl_no_outer_no_norm.csv
│ │ │ ├── bias_mgrl_no_outer_norm.csv
│ │ │ ├── bias_mgrl_outer_no_norm.csv
│ │ │ └── bias_mgrl_outer_norm.csv
│ │ ├── gamma
│ │ │ ├── gamma_a2c_no_norm.csv
│ │ │ ├── gamma_a2c_norm.csv
│ │ │ ├── gamma_bootstrap_no_outer_no_norm.csv
│ │ │ ├── gamma_bootstrap_no_outer_norm.csv
│ │ │ ├── gamma_bootstrap_outer_no_norm.csv
│ │ │ ├── gamma_bootstrap_outer_norm.csv
│ │ │ ├── gamma_mgrl_no_outer_no_norm.csv
│ │ │ ├── gamma_mgrl_no_outer_norm.csv
│ │ │ ├── gamma_mgrl_outer_no_norm.csv
│ │ │ └── gamma_mgrl_outer_norm.csv
│ │ └── return
│ │ │ ├── return_a2c_no_norm.csv
│ │ │ ├── return_a2c_norm.csv
│ │ │ ├── return_bootstrap_no_outer_no_norm.csv
│ │ │ ├── return_bootstrap_no_outer_norm.csv
│ │ │ ├── return_bootstrap_outer_no_norm.csv
│ │ │ ├── return_bootstrap_outer_norm.csv
│ │ │ ├── return_mgrl_no_outer_no_norm.csv
│ │ │ ├── return_mgrl_no_outer_norm.csv
│ │ │ ├── return_mgrl_outer_no_norm.csv
│ │ │ └── return_mgrl_outer_norm.csv
│ ├── bootstrap_bias
│ │ ├── MET-2154__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2160__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2166__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2172__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2178__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2183__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2188__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2192__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2196__train_mean_advantage_outer_loss.csv
│ │ └── MET-2199__train_mean_advantage_outer_loss.csv
│ ├── bootstrap_gamma
│ │ ├── MET-2154__train_mean_gamma.csv
│ │ ├── MET-2160__train_mean_gamma.csv
│ │ ├── MET-2166__train_mean_gamma.csv
│ │ ├── MET-2172__train_mean_gamma.csv
│ │ ├── MET-2178__train_mean_gamma.csv
│ │ ├── MET-2183__train_mean_gamma.csv
│ │ ├── MET-2188__train_mean_gamma.csv
│ │ ├── MET-2192__train_mean_gamma.csv
│ │ ├── MET-2196__train_mean_gamma.csv
│ │ └── MET-2199__train_mean_gamma.csv
│ ├── bootstrap_outer_critic_bias
│ │ ├── MET-2153__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2161__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2167__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2173__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2181__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2185__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2190__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2195__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2198__train_mean_advantage_outer_loss.csv
│ │ └── MET-2200__train_mean_advantage_outer_loss.csv
│ ├── bootstrap_outer_critic_gamma
│ │ ├── MET-2153__train_mean_gamma.csv
│ │ ├── MET-2161__train_mean_gamma.csv
│ │ ├── MET-2167__train_mean_gamma.csv
│ │ ├── MET-2173__train_mean_gamma.csv
│ │ ├── MET-2181__train_mean_gamma.csv
│ │ ├── MET-2185__train_mean_gamma.csv
│ │ ├── MET-2190__train_mean_gamma.csv
│ │ ├── MET-2195__train_mean_gamma.csv
│ │ ├── MET-2198__train_mean_gamma.csv
│ │ └── MET-2200__train_mean_gamma.csv
│ ├── bootstrap_outer_critic_return
│ │ ├── MET-2153__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2161__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2167__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2173__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2181__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2185__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2190__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2195__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2198__eval_episode_reward_stochastic_policy.csv
│ │ └── MET-2200__eval_episode_reward_stochastic_policy.csv
│ ├── bootstrap_return
│ │ ├── MET-2154__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2160__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2166__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2172__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2178__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2183__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2188__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2192__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2196__eval_episode_reward_stochastic_policy.csv
│ │ └── MET-2199__eval_episode_reward_stochastic_policy.csv
│ ├── mgrl_bias
│ │ ├── MET-2151__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2157__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2163__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2169__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2174__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2179__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2184__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2187__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2191__train_mean_advantage_outer_loss.csv
│ │ └── MET-2194__train_mean_advantage_outer_loss.csv
│ ├── mgrl_gamma
│ │ ├── MET-2151__train_mean_gamma.csv
│ │ ├── MET-2157__train_mean_gamma.csv
│ │ ├── MET-2163__train_mean_gamma.csv
│ │ ├── MET-2169__train_mean_gamma.csv
│ │ ├── MET-2174__train_mean_gamma.csv
│ │ ├── MET-2179__train_mean_gamma.csv
│ │ ├── MET-2184__train_mean_gamma.csv
│ │ ├── MET-2187__train_mean_gamma.csv
│ │ ├── MET-2191__train_mean_gamma.csv
│ │ └── MET-2194__train_mean_gamma.csv
│ ├── mgrl_outer_critic_bias
│ │ ├── MET-2152__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2159__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2164__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2170__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2176__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2182__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2186__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2189__train_mean_advantage_outer_loss.csv
│ │ ├── MET-2193__train_mean_advantage_outer_loss.csv
│ │ └── MET-2197__train_mean_advantage_outer_loss.csv
│ ├── mgrl_outer_critic_gamma
│ │ ├── MET-2152__train_mean_gamma.csv
│ │ ├── MET-2159__train_mean_gamma.csv
│ │ ├── MET-2164__train_mean_gamma.csv
│ │ ├── MET-2170__train_mean_gamma.csv
│ │ ├── MET-2176__train_mean_gamma.csv
│ │ ├── MET-2182__train_mean_gamma.csv
│ │ ├── MET-2186__train_mean_gamma.csv
│ │ ├── MET-2189__train_mean_gamma.csv
│ │ ├── MET-2193__train_mean_gamma.csv
│ │ └── MET-2197__train_mean_gamma.csv
│ ├── mgrl_outer_critic_return
│ │ ├── MET-2152__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2159__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2164__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2170__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2176__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2182__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2186__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2189__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2193__eval_episode_reward_stochastic_policy.csv
│ │ └── MET-2197__eval_episode_reward_stochastic_policy.csv
│ └── mgrl_return
│ │ ├── MET-2151__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2157__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2163__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2169__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2174__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2179__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2184__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2187__eval_episode_reward_stochastic_policy.csv
│ │ ├── MET-2191__eval_episode_reward_stochastic_policy.csv
│ │ └── MET-2194__eval_episode_reward_stochastic_policy.csv
├── env
│ ├── __init__.py
│ └── snake.py
├── networks
│ ├── __init__.py
│ ├── actor_critic.py
│ ├── cnn.py
│ ├── distribution.py
│ └── snake.py
└── training
│ ├── __init__.py
│ ├── config.py
│ ├── evaluator.py
│ ├── logger.py
│ ├── setup_run.py
│ ├── types.py
│ └── utils.py
└── snake_train.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | select = A,B,C,D,E,F,G,I,N,T,W
3 | exclude =
4 | .tox,
5 | .git,
6 | __pycache__,
7 | build,
8 | dist,
9 | proto/*,
10 | *.pyc,
11 | *.egg-info,
12 | .cache,
13 | .eggs
14 | max-line-length=100
15 | import-order-style = google
16 | application-import-names = metal
17 | doctests = True
18 | docstring-convention = google
19 | per-file-ignores = __init__.py:F401
20 |
21 | ignore =
22 | D107 # Do not require docstrings for __init__
23 | W503 # line break before binary operator (not compatible with black)
24 | E731 # do not assign a lambda expression, use a def
25 | N802 # function names should be lowercase
26 | N803 # arguments should be lowercase
27 | N806 # variables should be lowercase
28 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 | **/.ipynb_checkpoints/
3 | venv/
4 | libtpu_lockfile
5 | .idea/
6 | **videos/
7 | **/params*
8 |
9 | # wandb
10 | **/logs/
11 | **/outputs/
12 | **/multirun/
13 | **/wandb/
14 |
15 | # internal scripts
16 | **/internal_*
17 |
18 | *.pyc
19 | *.iml
20 | *.xml
21 | *.log
22 | playground/*
23 | logs/*
24 | .neptune
25 |
26 | # Distribution / packaging
27 | .Python
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | pip-wheel-metadata/
41 | share/python-wheels/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 | MANIFEST
46 | src/*
47 |
48 | # Tests
49 | coverage_html_report
50 | .coverage
51 | test-results.xml
52 |
53 | # Functional tests
54 | generated-functional-tests-runs-ci.yml
55 |
56 | # IDE Configs
57 | .vscode
58 | .devcontainer
59 | .idea
60 |
61 | # Logs
62 | /examples/logs/
63 |
64 | /nul/
65 | /checkpoints/
66 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_stages: [ "commit", "commit-msg", "push" ]
2 | default_language_version:
3 | python: python3.8
4 |
5 |
6 | repos:
7 | - repo: https://github.com/timothycrosley/isort
8 | rev: 5.0.4
9 | hooks:
10 | - id: isort
11 | args: ["--profile", "black"]
12 |
13 | - repo: https://github.com/ambv/black
14 | rev: 21.5b1
15 | hooks:
16 | - id: black
17 |
18 | - repo: https://github.com/pre-commit/pre-commit-hooks
19 | rev: v3.4.0
20 | hooks:
21 | # - id: end-of-file-fixer # removed because of notebooks
22 | - id: debug-statements
23 | - id: requirements-txt-fixer
24 | - id: mixed-line-ending
25 | - id: check-yaml
26 | args: [ '--unsafe' ]
27 | - id: trailing-whitespace
28 |
29 | - repo: https://gitlab.com/pycqa/flake8
30 | rev: 3.9.2
31 | hooks:
32 | - id: flake8
33 | args:
34 | - --max-line-length=100
35 | - --max-cognitive-complexity=10
36 | - --ignore=E266,E501,E731,W503
37 | additional_dependencies:
38 | - pep8-naming
39 | - flake8-builtins
40 | - flake8-comprehensions
41 | - flake8-bugbear
42 | - flake8-pytest-style
43 | - flake8-cognitive-complexity
44 |
45 | - repo: https://github.com/pre-commit/mirrors-mypy
46 | rev: v0.812
47 | hooks:
48 | - id: mypy
49 |
50 | - repo: https://github.com/alessandrojcm/commitlint-pre-commit-hook
51 | rev: v9.0.0
52 | hooks:
53 | - id: commitlint
54 | stages: [ commit-msg ]
55 | additional_dependencies: [ '@commitlint/config-conventional' ]
56 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Debiasing Meta-Gradient Reinforcement Learning by Learning the Outer Value Function
2 |
3 | This repository contains the code from the paper [_Debiasing Meta-Gradient Reinforcement Learning by
4 | Learning the Outer Value Function_](https://arxiv.org/abs/2211.10550) (Clément Bonnet, Laurence Midgley,
5 | Alexandre Laterre) published at the [6th Workshop on Meta-Learning](https://meta-learn.github.io/2022/)
6 | at NeurIPS 2022, New Orleans.
7 |
8 | ## Abstract
9 | Meta-gradient Reinforcement Learning (RL) allows agents to self-tune their hyper-parameters in an online fashion during training.
10 | In this paper, we identify a bias in the meta-gradient of current meta-gradient RL approaches.
11 | This bias comes from using the critic that is trained using the meta-learned discount factor for the advantage estimation in the outer objective which requires a different discount factor.
12 | Because the meta-learned discount factor is typically lower than the one used in the outer objective, the resulting bias can cause the meta-gradient to favor myopic policies.
13 | We propose a simple solution to this issue: we eliminate this bias by using an alternative, *outer* value function in the estimation of the outer loss.
14 | To obtain this outer value function we add a second head to the critic network and train it alongside the classic critic, using the outer loss discount factor.
15 | On an illustrative toy problem, we show that the bias can cause catastrophic failure of current meta-gradient RL approaches, and show that our proposed solution fixes it.
16 | We then apply our method to a more complex environment and demonstrate that fixing the meta-gradient bias can significantly improve performance.
17 |
18 |
19 | ## Experiments
20 |
21 | We denote:
22 | - A2C: Advantage Actor Critic
23 | - MG: meta-gradient algorithm from [Xu et al., 2018]
24 | - BMG: bootstrapped meta-gradients from [Flennerhag et al., 2022]
25 | - MG outer-critic: the MG algorithm equipped with an outer-critic that estimates the outer value function used in the outer loss
26 | - BMG outer-critic: the BMG algorithm similarly equipped with an outer-critic
27 |
28 | ### Discounting Chain
29 |
30 |
31 |
32 |
33 |
34 | ### Snake
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 | ## Reproducibility
43 |
44 | We provide a `requirements.txt` file with pinned versions of all the packages needed to reproduce the
45 | experiments in the paper.
46 | Here is a snippet of commands to set up a virtual environment and install these packages.
47 | Alternatively, one can use a Conda environment or a similar solution.
48 | ```shell
49 | python -m venv venv
50 | source venv/bin/activate
51 | pip install -U pip setuptools wheel
52 | pip install -r requirements.txt
53 | ```
54 |
55 | ### Discounting Chain
56 |
57 | The Discounting Chain environment is originally from [bsuite](https://github.com/deepmind/bsuite) but is imported from [Gymnax](https://github.com/RobertTLange/gymnax) in this paper to benefit from its JAX implementation.
58 |
59 | To reproduce the experiments on the Discounting Chain environment, you can run the Jupyter notebook `discounting_chain_train.ipynb`.
60 | The other notebook `discounting_chain.ipynb` loads the data and plots the figures from the paper.
61 |
62 | ### Snake
63 |
64 | The Snake environment is provided by [Jumanji](https://github.com/instadeepai/jumanji) in JAX.
65 |
66 | To reproduce the experiments on the Snake environment, one can run the following commands.
67 |
68 | - Advantage Actor Critic (A2C)
69 | ```shell
70 | python snake_train.py -m agent=a2c training.seed=1,2,3,4,5,6,7,8,9,10
71 | ```
72 |
73 | - Meta-Gradient (MG)
74 | ```shell
75 | python snake_train.py -m agent=mgrl agent.outer_critic=false training.seed=1,2,3,4,5,6,7,8,9,10
76 | python snake_train.py -m agent=mgrl agent.outer_critic=true training.seed=1,2,3,4,5,6,7,8,9,10
77 | ```
78 |
79 | - Bootstrapped Meta-Gradient (BMG)
80 | ```shell
81 | python snake_train.py -m agent=bootstrap agent.outer_critic=false training.seed=1,2,3,4,5,6,7,8,9,10
82 | python snake_train.py -m agent=bootstrap agent.outer_critic=true training.seed=1,2,3,4,5,6,7,8,9,10
83 | ```
84 |
85 | - Appendix (ablations over the outer critic and advantage normalization)
86 | ```shell
87 | python snake_train.py -m agent=a2c agent.outer_critic=false agent.normalize_advantage=false,true agent.normalize_outer_advantage=false training.seed=1
88 | python snake_train.py -m agent=mgrl agent.outer_critic=false,true agent.normalize_advantage=false,true agent.normalize_outer_advantage=false,true training.seed=1
89 | python snake_train.py -m agent=bootstrap agent.outer_critic=false,true agent.normalize_advantage=false,true agent.normalize_outer_advantage=false,true training.seed=1
90 | ```
91 |
92 | Note that the default logger is `"terminal"`. If you want to save the data, a Neptune logger is implemented,
93 | and you can enable it by replacing `"terminal"` with `"neptune"` in `snake_train.py`.
94 |
95 | For the paper, the data from these runs was collected and uploaded in `snake/data/`.
96 | The `snake.ipynb` notebook loads this data and makes the plots from the paper.
97 |
98 |
99 | ## Citation
100 |
101 | For attribution in academic contexts, please use the following citation.
102 |
103 | ```
104 | @inproceedings{bonnet2022debiasing,
105 | title = {Debiasing Meta-Gradient Reinforcement Learning by Learning the Outer Value Function},
106 | author = {Bonnet, Clément and Midgley, Laurence and Laterre, Alexandre},
107 | doi = {10.48550/ARXIV.2211.10550},
108 | url = {https://arxiv.org/abs/2211.10550},
109 | year = {2022},
110 | booktitle={Sixth Workshop on Meta-Learning at the Conference on Neural Information Processing Systems},
111 | }
112 | ```
113 |
--------------------------------------------------------------------------------
/commitlint.config.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 | extends: ['@commitlint/config-conventional'],
3 | rules: {
4 | 'subject-case': [1, 'always'],
5 | 'header-max-length': [1, 'always', 100],
6 | 'type-enum': [
7 | 2,
8 | 'always',
9 | [
10 | 'build',
11 | 'chore',
12 | 'ci',
13 | 'docs',
14 | 'feat',
15 | 'fix',
16 | 'perf',
17 | 'refactor',
18 | 'revert',
19 | 'style',
20 | 'test',
21 | 'exp',
22 | 'func',
23 | ],
24 | ],
25 | },
26 | }
27 |
--------------------------------------------------------------------------------
/discounting_chain/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/discounting_chain/__init__.py
--------------------------------------------------------------------------------
/discounting_chain/base.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple
2 |
3 | import chex
4 | from acme import specs
5 | from acme.jax.networks import Action, Observation
6 | from dm_env import TimeStep
7 | from jax import numpy as jnp
8 | from typing_extensions import Protocol
9 |
# Generic pytree aliases. At runtime every one of them is simply `Any`; the distinct names
# exist only to document which role a given pytree plays in the signatures below.
ArrayTree = Any
AgentState = ArrayTree
EnvironmentState = ArrayTree
DataGeneratorState = ArrayTree
BufferState = ArrayTree
Extras = ArrayTree
Batch = ArrayTree

# Named quantities reported for logging.
Metrics = Dict[str, jnp.ndarray]
# (agent_state, observation, rng_key) -> (action, extras)
SelectAction = Callable[[AgentState, Observation, chex.PRNGKey], Tuple[Action, Extras]]
# (agent_state, rng_key) -> metrics
EvaluateFn = Callable[[AgentState, chex.PRNGKey], Metrics]
21 |
22 |
class DataGenerationInfo(NamedTuple):
    """Information about the acting performed during one call to data generation."""

    # Logging metrics gathered while acting.
    metrics: Metrics
    # Step / episode counters reported by the data generator.
    num_steps: jnp.int_
    num_episodes: jnp.int_
29 |
30 |
class EnvLoopState(NamedTuple):
    """State threaded through the environment loop."""

    # Per-component states.
    data_generator: DataGeneratorState
    buffer: BufferState
    agent: AgentState
    # Counters maintained by the loop (presumably running totals — confirm in the loop code).
    num_steps: jnp.int_
    num_episodes: jnp.int_
    num_iterations: jnp.int_
40 |
41 |
class GenerateDataFn(Protocol):
    """Signature of a function that produces a batch of experience.

    Implementations usually perform multi-step, batched acting within a jittable
    environment.
    """

    def __call__(
        self, agent_state: AgentState, data_generator_state: DataGeneratorState
    ) -> Tuple[Batch, DataGeneratorState, DataGenerationInfo]:
        """Generate a batch of experience.

        Args:
            agent_state: The agent's own state (i.e. its trainable parameters), used to
                choose actions.
            data_generator_state: State needed while acting to manage the environment
                (e.g. keeping track of the environment state).

        Returns:
            batch: Experience to be passed to the agent update.
            data_generator_state: The data generator state after acting.
            data_generation_info: Metrics describing the acting that took place.
        """
64 |
65 |
class DataGenerator(NamedTuple):
    """Pair of pure functions used to generate experience data."""

    # rng_key -> initial data generator state
    init: Callable[[chex.PRNGKey], DataGeneratorState]
    # Produces a batch of experience; see `GenerateDataFn` for the contract.
    generate_data: GenerateDataFn
71 |
72 |
class Buffer(NamedTuple):
    """Stores experience produced by the actor and samples it for agent parameter updates.

    Both `add` and `sample` are jittable.
    """

    # (buffer_state, batch) -> updated buffer_state
    add: Callable[[BufferState, Batch], BufferState]
    # buffer_state -> (sampled batch, updated buffer_state)
    sample: Callable[[BufferState], Tuple[Batch, BufferState]]
    # rng_key -> initial buffer_state
    init: Callable[[chex.PRNGKey], BufferState]
82 |
83 |
class Agent(NamedTuple):
    """Pure functions that make up an agent.

    `update` (which updates the learnt parameters) and `select_action` (which chooses
    actions within the environment) are both jittable. For multi-device training, `update`
    should contain the appropriately placed `jax.lax.pmean` so that gradients are aggregated
    across devices.
    """

    init: Callable[[chex.PRNGKey], AgentState]
    select_action: SelectAction
    # (agent_state, batch) -> (updated agent_state, metrics)
    update: Callable[[AgentState, Batch], Tuple[AgentState, Metrics]]
    # Optional evaluation policy; `select_action` is used for evaluation when this is None.
    select_action_eval: Optional[SelectAction] = None
98 |
99 |
class OnlineAgent(NamedTuple):
    """Pure functions that make up an agent whose update takes no external batch.

    The layout mirrors `Agent`, except that `update` receives only the agent state and
    returns the updated agent state together with metrics. (Presumably such an agent
    gathers its own experience inside `update` — TODO confirm against the implementations.)
    """

    init: Callable[[chex.PRNGKey], AgentState]
    select_action: SelectAction
    # agent_state -> (updated agent_state, metrics)
    update: Callable[[AgentState], Tuple[AgentState, Metrics]]
    # Optional evaluation policy; presumably `select_action` is used when this is None,
    # mirroring `Agent` — confirm with the evaluation code.
    select_action_eval: Optional[SelectAction] = None
105 |
106 |
class StepEnvironment(Protocol):
    """Environment transition function with automatic resetting.

    The step function resets the environment by itself whenever an episode is completed.
    This allows seamless, jitted n-step acting, but implementations must be written with
    care. Follow the contract documented on `__call__` below, in particular what the returned
    state and timestep contain when an episode terminates.
    """

    def __call__(
        self, state: EnvironmentState, action: Action
    ) -> Tuple[EnvironmentState, TimeStep, Metrics]:
        """Advance the MDP by one step, auto-resetting terminated episodes.

        Args:
            state: State of the environment at the current timestep.
            action: Action taken at the current timestep.

        Returns:
            state: The next environment state while the episode is still underway, or the
                state produced by a reset if the episode terminated.
            timestep: Information on the transition. If the episode terminates, every field
                corresponds to the terminal timestep EXCEPT `observation`, which is the
                observation from the auto-reset: it is what the next acting step needs, and
                the terminal observation is not required.
            metrics: Relevant information for logging.
        """
141 |
142 |
class Environment(NamedTuple):
    """Minimal bundle of the key pieces of a JAX environment with a jittable step function.

    Note:
        - The data generator uses this container to produce experience, so it is not
          passed to the environment loop directly.
        - `step` is assumed to reset the environment automatically when an episode
          terminates; see `StepEnvironment` for the exact contract.
    """

    # rng_key -> (initial environment state, first timestep)
    init: Callable[[chex.PRNGKey], Tuple[EnvironmentState, TimeStep]]
    step: StepEnvironment
    # acme environment specification.
    spec: specs.EnvironmentSpec
157 |
--------------------------------------------------------------------------------
/discounting_chain/data/discounting_chain_appendix_histories_array.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/discounting_chain/data/discounting_chain_appendix_histories_array.pickle
--------------------------------------------------------------------------------
/discounting_chain/data/discounting_chain_histories_array.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/discounting_chain/data/discounting_chain_histories_array.pickle
--------------------------------------------------------------------------------
/discounting_chain/data_generation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/discounting_chain/data_generation/__init__.py
--------------------------------------------------------------------------------
/discounting_chain/data_generation/n_step.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple, Tuple
2 |
3 | import chex
4 | import jax
5 | import typing_extensions
6 | from acme.types import Transition
7 | from dm_env import TimeStep
8 | from jax import numpy as jnp
9 |
10 | from discounting_chain.base import (
11 | AgentState,
12 | EnvironmentState,
13 | Metrics,
14 | SelectAction,
15 | StepEnvironment,
16 | )
17 |
18 |
class NStepGeneratorState(NamedTuple):
    """Acting state carried through `run_n_step` while generating environment data."""

    # Environment state handed to the step function on the next step.
    environment_state: EnvironmentState
    # Timestep from the most recent step; its observation is what the agent acts on next.
    previous_timestep: TimeStep
    # PRNG key, split once per step to draw the action-selection key.
    rng_key: chex.PRNGKey
26 |
27 |
class AccumulateBatchMetricsFn(typing_extensions.Protocol):
    """Signature of a reducer for metrics collected over a batch of n-step transitions."""

    def __call__(self, batch: Transition, metrics: Metrics) -> Metrics:
        """Reduce the metrics collected over a batch of n-step acting.

        Args:
            batch: Transition data whose leading axes are (batch_size, n_step).
            metrics: Per-step metrics whose leading axes are (batch_size, n_step).

        Returns:
            metrics: Relevant information for logging.
        """
42 |
43 |
def _run_n_step(
    n_step: int,
    select_action: SelectAction,
    step_environment: StepEnvironment,
    agent_state: AgentState,
    data_generator_state: NStepGeneratorState,
) -> Tuple[Transition, NStepGeneratorState, Metrics]:
    """Act for `n_step` consecutive steps in a single (unbatched) environment.

    The loop is expressed with `jax.lax.scan`, so per-step outputs are stacked along a new
    leading axis of length `n_step`. The function may be wrapped in `jax.vmap` for batching
    (see `make_run_n_step`).

    Args:
        n_step: Number of environment steps.
        select_action: Function for action selection by the agent.
        step_environment: Function for stepping the environment.
        agent_state: State of the agent (held fixed for the whole rollout).
        data_generator_state: State for managing the environment interaction.

    Returns:
        transition: The `n_step` transitions, stacked along the leading axis.
        data_generator_state: Data generator state after the final step.
        metrics: Per-step environment metrics, stacked along the leading axis.
    """

    def scan_body(
        carry: NStepGeneratorState, _: None
    ) -> Tuple[NStepGeneratorState, Tuple[Transition, Metrics]]:
        """One acting step: choose an action, step the environment, record the transition."""
        next_rng_key, action_key = jax.random.split(carry.rng_key)
        current_observation = carry.previous_timestep.observation
        action, extras = select_action(agent_state, current_observation, action_key)
        env_state, timestep, step_metrics = step_environment(
            carry.environment_state, action
        )
        # On termination `timestep.observation` is the auto-reset observation
        # (see `StepEnvironment`), so `next_observation` follows that contract.
        transition = Transition(
            observation=current_observation,
            action=action,
            discount=timestep.discount,
            reward=timestep.reward,
            next_observation=timestep.observation,
            extras=extras,
        )
        new_carry = NStepGeneratorState(
            environment_state=env_state,
            previous_timestep=timestep,
            rng_key=next_rng_key,
        )
        return new_carry, (transition, step_metrics)

    final_state, (transitions, stacked_metrics) = jax.lax.scan(
        scan_body, data_generator_state, None, length=n_step
    )
    return transitions, final_state, stacked_metrics
104 |
105 |
def accumulate_batch_metrics(batch: Transition, metrics: Metrics) -> Metrics:
    """Accumulate metrics over the batch of n-steps.

    NaN is assumed to mask entries of `metrics` that should not be recorded for a given
    environment step. For example, because only returns of completed episodes should be
    logged, those values are typically NaN for every transition whose episode is not yet
    complete. NaN-aware reductions are therefore used.

    Args:
        batch: Batch of n-step transitions; only `batch.discount` is read.
        metrics: Mapping from metric name to per-step values for the batch.

    Returns:
        A flat dict holding `"num_episodes"` followed by, for each metric `key`,
        `key + "_mean"` (NaN-masked mean) and then `key + "_max"` (NaN-masked max).
    """
    # Completed episodes are counted via the discount: a terminal transition carries a
    # discount of 0 (dm_env termination convention).
    num_episodes = jnp.sum(batch.discount == 0)
    accumulated_metrics = {"num_episodes": num_episodes}

    # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the equivalent spelling
    # that works across all JAX versions.
    mean_metrics = jax.tree_util.tree_map(jnp.nanmean, metrics)
    max_metrics = jax.tree_util.tree_map(jnp.nanmax, metrics)
    accumulated_metrics.update(
        {key + "_mean": val for key, val in mean_metrics.items()}
    )
    accumulated_metrics.update({key + "_max": val for key, val in max_metrics.items()})
    return accumulated_metrics
126 |
127 |
def make_run_n_step(
    n_step: int,
    select_action: SelectAction,
    step_environment: StepEnvironment,
    accumulate_batch_metrics_fn: AccumulateBatchMetricsFn = accumulate_batch_metrics,
) -> Callable[
    [AgentState, NStepGeneratorState], Tuple[Transition, NStepGeneratorState, Metrics]
]:
    """
    Build a function that collects a batch of n-step rollouts and summarises their metrics.

    Args:
        n_step: How many environment steps each rollout contains.
        select_action: Agent action-selection function, forwarded to `_run_n_step`.
        step_environment: Environment step function, forwarded to `_run_n_step`.
        accumulate_batch_metrics_fn: Reduces the per-step metrics of the whole batch into the
            metrics that are returned.

    Returns:
        run_n_step: Callable taking `(agent_state, data_generator_state)` and returning
            `(trajectory_batch, next_data_generator_state, metrics)`.

    """
    # Only the generator state (last positional argument) is mapped over its leading batch
    # axis; the step count, both callables and the agent state are shared by every element.
    batched_rollout = jax.vmap(_run_n_step, in_axes=(None, None, None, None, 0))

    def run_n_step(
        agent_state: AgentState, data_generator_state: NStepGeneratorState
    ) -> Tuple[Transition, NStepGeneratorState, Metrics]:
        """
        Run one batch of n-step environment interaction.

        Args:
            agent_state: Agent state, shared by every element of the batch.
            data_generator_state: Environment-interaction state whose pytree leaves carry the
                batch dimension on their leading axis.

        Returns:
            trajectory_batch: Experience with leading axes [batch_size, n_step].
            data_generator_state: The generator state after the n steps.
            metrics: Output of `accumulate_batch_metrics_fn` for this batch.

        """
        rollout = batched_rollout(
            n_step, select_action, step_environment, agent_state, data_generator_state
        )
        trajectory_batch, next_generator_state, per_step_metrics = rollout
        summary = accumulate_batch_metrics_fn(trajectory_batch, per_step_metrics)
        return trajectory_batch, next_generator_state, summary

    return run_n_step
175 |
--------------------------------------------------------------------------------
/discounting_chain/data_generation/n_step_data_generator.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import chex
4 | import jax
5 | from acme.types import Transition
6 | from jax import numpy as jnp
7 |
8 | from discounting_chain.base import (
9 | AgentState,
10 | DataGenerationInfo,
11 | DataGenerator,
12 | Environment,
13 | SelectAction,
14 | )
15 | from discounting_chain.data_generation.n_step import (
16 | NStepGeneratorState,
17 | make_run_n_step,
18 | )
19 |
20 |
def make_n_step_data_generator(
    select_action: SelectAction,
    environment: Environment,
    n_step: int,
    batch_size_per_device: int,
) -> DataGenerator:
    """
    Build a data generator that collects batches of n-step environment interaction through
    the function returned by `make_run_n_step`.

    Args:
        select_action: Callable implementing the agent's action selection.
        environment: Environment container providing the state-initialisation (`init`) and
            `step` functions.
        n_step: Number of environment steps per batch.
        batch_size_per_device: Number of environments vmapped over inside `run_n_step`.
            When `generate_data` runs on several devices, the overall batch size is
            batch_size_per_device*num_devices.

    Returns:
        The n-step data generator.

    """
    run_n_step_fn = make_run_n_step(n_step, select_action, environment.step)
    # Environment steps consumed by a single call to `generate_data` on one device.
    steps_per_batch = batch_size_per_device * n_step

    def init(rng_key: chex.PRNGKey) -> NStepGeneratorState:
        """Create per-environment states, first timesteps and RNG keys for the batch."""
        env_key, generator_key = jax.random.split(rng_key)
        env_states, first_timesteps = jax.vmap(environment.init)(
            jax.random.split(env_key, batch_size_per_device)
        )
        return NStepGeneratorState(
            environment_state=env_states,
            previous_timestep=first_timesteps,
            rng_key=jax.random.split(generator_key, batch_size_per_device),
        )

    def generate_data(
        agent_state: AgentState, data_generator_state: NStepGeneratorState
    ) -> Tuple[Transition, NStepGeneratorState, DataGenerationInfo]:
        """Collect one batch with `run_n_step_fn` and package its bookkeeping info."""
        batch, next_state, metrics = run_n_step_fn(agent_state, data_generator_state)
        info = DataGenerationInfo(
            metrics=metrics,
            num_episodes=metrics["num_episodes"],
            num_steps=jnp.array(steps_per_batch, "int"),
        )
        return batch, next_state, info

    return DataGenerator(init, generate_data)
76 |
--------------------------------------------------------------------------------
/discounting_chain/data_generation/n_step_utils.py:
--------------------------------------------------------------------------------
1 | import chex
2 | import jax
3 | from acme.specs import EnvironmentSpec
4 | from acme.types import Transition
5 | from jax import numpy as jnp
6 |
7 |
def create_fake_n_step_batch(
    environment_spec: EnvironmentSpec,
    batch_size: int,
    n_step: int,
    fake_extras: chex.ArrayTree = (),
) -> Transition:
    """
    Create a fake batch of data that matches what is returned by the n-step generator.
    This is useful for initialisation of the buffer, which has a state initialisation that
    depends on the batch it receives from experience generation.

    Args:
        environment_spec: Environment spec. Each field in the `environment_spec` has to implement
            the `generate_value` method for this function to work.
        batch_size: Size of the leading (batch) axis of every leaf in the returned batch.
        n_step: Size of the second (step) axis of every leaf in the returned batch.
        fake_extras: An object matching the shapes and types of Extras returned by the
            `select_action` method of an agent.

    Returns:
        batch: Batch of data matching what the n-step generator returns: every leaf of a single
            fake transition broadcast to shape [batch_size, n_step, *leaf_shape].
    """
    transition = Transition(
        observation=environment_spec.observations.generate_value(),
        action=environment_spec.actions.generate_value(),
        reward=environment_spec.rewards.generate_value(),
        next_observation=environment_spec.observations.generate_value(),
        discount=environment_spec.discounts.generate_value(),
        extras=fake_extras,
    )
    # Prepend the [batch_size, n_step] axes to every leaf. `jax.tree_util.tree_map` is used
    # because `jax.tree_map` is deprecated and removed in newer JAX releases.
    batch = jax.tree_util.tree_map(
        lambda x: jnp.broadcast_to(x, shape=(batch_size, n_step, *x.shape)), transition
    )
    return batch
39 |
--------------------------------------------------------------------------------
/discounting_chain/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/discounting_chain/envs/__init__.py
--------------------------------------------------------------------------------
/discounting_chain/envs/env_utils.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, Tuple
2 |
3 | import chex
4 | import jax
5 | from dm_env import TimeStep
6 | from jax import numpy as jnp
7 |
8 | from discounting_chain.base import Action, Environment, Metrics
9 |
10 |
class EnvStateWithEpisodeMetrics(NamedTuple):
    """Environment state augmented with per-episode bookkeeping.

    Holds the wrapped environment's own state together with the number of steps taken and the
    reward accumulated so far in the current episode.
    """

    # State of the wrapped environment, carried through untouched.
    original_env_state: chex.ArrayTree
    # Steps taken so far in the current episode.
    episode_step_count: jnp.int32
    # Sum of the rewards received so far in the current episode.
    cumulative_reward: jnp.float32
18 |
19 |
def wrap_env_for_episode_metrics(env: Environment) -> Environment:
    """Wrap an environment so that its metrics also report episode return and length.

    The wrapped environment's state is an `EnvStateWithEpisodeMetrics` that tracks the step
    count and cumulative reward of the current episode. Every step adds two entries to the
    metrics dict, `episode_return` and `episode_length`: they hold the finished episode's values
    on terminal timesteps and NaN on all other timesteps.
    """

    def accumulate_return_and_step_count(
        cumulative_reward: jnp.float32, step_count: jnp.int32, timestep: TimeStep
    ) -> Tuple[chex.Array, chex.Array]:
        """Advance the episode's step count and cumulative reward by one timestep.

        Both counters are reset to 0 if `timestep` is the last step of an episode.

        Returns:
            (step_count, cumulative_reward) -- note the order differs from the arguments.
        """
        new_step_count = step_count + 1
        new_cumulative_reward = cumulative_reward + timestep.reward
        # The reset values are built with `zeros_like` so their dtype always matches the running
        # values: `jax.lax.select` rejects mismatched dtypes, and hard-coded `dtype=int` /
        # `dtype=float` zeros become 64-bit when `jax_enable_x64` is on.
        step_count = jax.lax.select(
            timestep.last(), jnp.zeros_like(new_step_count), new_step_count
        )
        cumulative_reward = jax.lax.select(
            timestep.last(), jnp.zeros_like(new_cumulative_reward), new_cumulative_reward
        )
        return step_count, cumulative_reward

    def get_episode_metrics(
        previous_state: EnvStateWithEpisodeMetrics, timestep: TimeStep
    ) -> Metrics:
        """Return the episode return and length if `timestep` completes an episode, otherwise
        NaN for both to indicate the episode is still underway."""
        is_last = timestep.last()
        episode_return = previous_state.cumulative_reward + timestep.reward
        episode_length = jnp.array(previous_state.episode_step_count + 1, dtype=float)
        # `jnp.where` promotes the NaN scalar to the other branch's dtype, so no dtype
        # mismatch can arise between the two branches.
        return {
            "episode_return": jnp.where(is_last, episode_return, jnp.nan),
            "episode_length": jnp.where(is_last, episode_length, jnp.nan),
        }

    def init(rng_key: chex.PRNGKey) -> Tuple[EnvStateWithEpisodeMetrics, TimeStep]:
        """Initialise the wrapped environment with zeroed episode counters."""
        original_env_state, timestep = env.init(rng_key)
        env_state = EnvStateWithEpisodeMetrics(
            original_env_state=original_env_state,
            episode_step_count=jnp.int32(0),
            cumulative_reward=jnp.float32(0),
        )
        return env_state, timestep

    def step(
        state: EnvStateWithEpisodeMetrics, action: Action
    ) -> Tuple[EnvStateWithEpisodeMetrics, TimeStep, Metrics]:
        """Step the wrapped environment and attach the episode metrics."""
        original_env_state, timestep, metrics = env.step(
            state.original_env_state, action
        )

        # Episode metrics use the counters from *before* this timestep's update/reset. A new
        # dict is built so the wrapped environment's metrics object is not mutated.
        metrics = {**metrics, **get_episode_metrics(state, timestep)}

        episode_step_count, cumulative_reward = accumulate_return_and_step_count(
            state.cumulative_reward, state.episode_step_count, timestep
        )
        state = EnvStateWithEpisodeMetrics(
            original_env_state=original_env_state,
            cumulative_reward=cumulative_reward,
            episode_step_count=episode_step_count,
        )

        return state, timestep, metrics

    return Environment(init=init, step=step, spec=env.spec)
92 |
--------------------------------------------------------------------------------
/discounting_chain/envs/gymnax_dc_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple, Tuple
2 |
3 | import chex
4 | import jax.random
5 | from acme.types import specs
6 | from dm_env import TimeStep
7 | from gymnax.environments.bsuite.discounting_chain import DiscountingChain
8 | from jax import numpy as jnp
9 |
10 | from discounting_chain.base import Environment, Metrics
11 | from discounting_chain.envs.env_utils import wrap_env_for_episode_metrics
12 |
# Actions are passed unchanged to `gymnax_env.step_env`, so any array pytree is accepted here.
Action = chex.ArrayTree
14 |
15 |
class State(NamedTuple):
    """Wrapper state: the raw gymnax state plus the PRNG key consumed by steps and resets."""

    # Raw gymnax environment state.
    gymnax_state: chex.ArrayTree
    # Key that is split on every step / auto-reset to obtain fresh randomness.
    key: chex.PRNGKey
19 |
20 |
def create_dc_gmnax(mapping_seed=3) -> Tuple[Environment, Callable]:
    """Create the bsuite DiscountingChain environment (via gymnax) and its true-value function.

    Args:
        mapping_seed: Passed to gymnax's `DiscountingChain` constructor.

    Returns:
        env: The environment, auto-reset on terminal steps and wrapped with
            `wrap_env_for_episode_metrics`.
        get_true_value: Callable `(obs, probs, gamma) -> value` giving the gamma-discounted value
            of the chain reward from `obs`. `probs` (action probabilities) is only used at the
            first step, where the value is the probability-weighted sum over actions (each
            action indexes a chain). Gradients are stopped through the result.
    """
    gymnax_env = DiscountingChain(mapping_seed=mapping_seed)
    env_params = gymnax_env.default_params

    def discounted_reward(chain, time, gamma):
        """Value at `time` of chain `chain`'s reward, discounted by `gamma`.

        Zero if the chain's reward timestep has already passed.
        """
        chain = jnp.int_(chain)
        time_till_reward = env_params.reward_timestep[chain] - time
        reward_already_received = time_till_reward < 0
        value = gamma ** time_till_reward * gymnax_env.reward[chain]
        return jax.lax.select(reward_already_received, 0.0, value)

    def get_true_value_non_first_step(obs, probs, gamma):
        context, time = obs
        # Unnormalise time (presumably the observation stores time / max_steps_in_episode).
        time = time * env_params.max_steps_in_episode + 1
        return discounted_reward(context, time, jax.lax.stop_gradient(gamma))

    def get_true_value_first_step_single_action(action, gamma):
        # At time 0 the action selects which chain (and hence which reward) is followed.
        return discounted_reward(action, 0, gamma)

    def get_true_value_first_step(obs, probs, gamma):
        values = jax.vmap(get_true_value_first_step_single_action, in_axes=(0, None))(
            jnp.arange(probs.shape[0]), gamma
        )
        chex.assert_equal_shape([probs, values])
        return jnp.sum(probs * values)

    def get_true_value(obs, probs, gamma):
        context, time = obs
        value = jax.lax.cond(
            time == 0,
            get_true_value_first_step,
            get_true_value_non_first_step,
            obs,
            probs,
            gamma,
        )
        return jax.lax.stop_gradient(value)

    def init(key: chex.PRNGKey):
        key, subkey = jax.random.split(key)
        obs, state = gymnax_env.reset(subkey, env_params)
        timestep = TimeStep(
            step_type=jnp.int_(0),
            reward=jnp.float32(0.0),
            discount=jnp.float32(0.0),
            observation=obs,
        )
        state = State(state, key)
        return state, timestep

    def auto_reset(state: State, timestep: TimeStep) -> Tuple[State, TimeStep]:
        key, subkey = jax.random.split(state.key)
        # Reset with the same params as `init` rather than relying on gymnax's fallback.
        obs, state = gymnax_env.reset(subkey, env_params)
        # Replace observation with reset observation.
        timestep = timestep._replace(observation=obs)
        state = State(state, key)
        return state, timestep

    def step(state: State, action: Action) -> Tuple[State, TimeStep, Metrics]:
        key, subkey = jax.random.split(state.key)
        n_obs, n_state, reward, done, info = gymnax_env.step_env(
            subkey, state.gymnax_state, action, env_params
        )
        timestep = TimeStep(
            step_type=jax.lax.select(done, jnp.int_(2), jnp.int_(1)),
            reward=jnp.float32(reward),
            discount=jnp.float32(1 - done),
            observation=n_obs,
        )
        state = State(n_state, key)

        state, timestep = jax.lax.cond(
            timestep.last(),
            auto_reset,
            lambda new_state, timestep: (new_state, timestep),
            state,
            timestep,
        )
        metrics = info
        return state, timestep, metrics

    gymnax_obs_spec = gymnax_env.observation_space(env_params)
    gymnax_action_spec = gymnax_env.action_space(env_params)

    spec = specs.EnvironmentSpec(
        observations=specs.Array(
            shape=gymnax_obs_spec.shape, dtype=gymnax_obs_spec.dtype
        ),
        actions=specs.DiscreteArray(
            num_values=gymnax_action_spec.n,
        ),
        rewards=None,
        discounts=None,
    )

    env = Environment(step=step, init=init, spec=spec)
    env = wrap_env_for_episode_metrics(env)
    return env, get_true_value
126 |
--------------------------------------------------------------------------------
/discounting_chain/envs/gymnax_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, Tuple
2 |
3 | import chex
4 | import gymnax
5 | import jax.random
6 | from acme.types import specs
7 | from dm_env import TimeStep
8 | from gymnax.environments.environment import Environment as GymnaxEnv
9 | from jax import numpy as jnp
10 |
11 | from discounting_chain.base import Environment, Metrics
12 | from discounting_chain.envs.env_utils import wrap_env_for_episode_metrics
13 |
# Actions are passed unchanged to `gymnax_env.step_env`, so any array pytree is accepted here.
Action = chex.ArrayTree
15 |
16 |
class State(NamedTuple):
    """Wrapper state: the raw gymnax state plus the PRNG key consumed by steps and resets."""

    # Raw gymnax environment state.
    gymnax_state: chex.ArrayTree
    # Key that is split on every step / auto-reset to obtain fresh randomness.
    key: chex.PRNGKey
20 |
21 |
def create_gymnax_env(env_name="CartPole-v1", **kwargs) -> Environment:
    """Create a gymnax environment exposed through this package's `Environment` interface.

    The returned environment auto-resets after terminal steps and is wrapped with
    `wrap_env_for_episode_metrics`.

    Args:
        env_name: "DiscountingChain-bsuite", "UmbrellaChain-bsuite", or any name accepted by
            `gymnax.make`.
        **kwargs: For "DiscountingChain-bsuite", forwarded to the `DiscountingChain`
            constructor. For "UmbrellaChain-bsuite", only `chain_length` is read (default 20)
            and passed as both positional arguments of UmbrellaChain's `EnvParams`. Ignored for
            every other environment.

    Returns:
        The wrapped environment.
    """
    gymnax_env: GymnaxEnv
    if env_name == "DiscountingChain-bsuite":
        from gymnax.environments.bsuite.discounting_chain import DiscountingChain

        gymnax_env = DiscountingChain(**kwargs)
        env_params = gymnax_env.default_params
    elif env_name == "UmbrellaChain-bsuite":
        from gymnax.environments.bsuite.umbrella_chain import EnvParams, UmbrellaChain

        gymnax_env = UmbrellaChain(n_distractor=0)
        chain_length = kwargs.get("chain_length", 20)
        env_params = EnvParams(chain_length, chain_length)
    else:
        gymnax_env, env_params = gymnax.make(env_name)

    def init(key: chex.PRNGKey):
        key, subkey = jax.random.split(key)
        obs, state = gymnax_env.reset(subkey, env_params)
        timestep = TimeStep(
            step_type=jnp.int_(0),
            reward=jnp.float32(0.0),
            discount=jnp.float32(0.0),
            observation=obs,
        )
        state = State(state, key)
        return state, timestep

    def auto_reset(state: State, timestep: TimeStep) -> Tuple[State, TimeStep]:
        key, subkey = jax.random.split(state.key)
        # Reset with the same `env_params` as `init`: without them gymnax falls back to the
        # environment's default params, which need not match the custom params built above.
        obs, state = gymnax_env.reset(subkey, env_params)
        # Keep reward/discount/step_type of the terminal step, but show the reset observation.
        timestep = timestep._replace(observation=obs)
        state = State(state, key)
        return state, timestep

    def step(state: State, action: Action) -> Tuple[State, TimeStep, Metrics]:
        key, subkey = jax.random.split(state.key)
        n_obs, n_state, reward, done, info = gymnax_env.step_env(
            subkey, state.gymnax_state, action, env_params
        )
        timestep = TimeStep(
            step_type=jax.lax.select(done, jnp.int_(2), jnp.int_(1)),
            reward=jnp.float32(reward),
            discount=jnp.float32(1 - done),
            observation=n_obs,
        )
        state = State(n_state, key)

        state, timestep = jax.lax.cond(
            timestep.last(),
            auto_reset,
            lambda new_state, timestep: (new_state, timestep),
            state,
            timestep,
        )
        metrics = info
        return state, timestep, metrics

    gymnax_obs_spec = gymnax_env.observation_space(env_params)
    gymnax_action_spec = gymnax_env.action_space(env_params)

    spec = specs.EnvironmentSpec(
        observations=specs.Array(
            shape=gymnax_obs_spec.shape, dtype=gymnax_obs_spec.dtype
        ),
        actions=specs.DiscreteArray(
            num_values=gymnax_action_spec.n,
        ),
        rewards=None,
        discounts=None,
    )

    env = Environment(step=step, init=init, spec=spec)
    env = wrap_env_for_episode_metrics(env)
    return env
98 |
--------------------------------------------------------------------------------
/discounting_chain/list_logger.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Dict, List, Union
3 |
4 | import numpy as np
5 |
6 |
class ListLogger:
    """In-memory logger that stores every written value in a per-key list.

    Attributes:
        history: Maps each key to the list of values written for it, in write order.
        iter_n: Counter initialised to 0 (this class never updates it).
        start_time: Wall-clock timestamp recorded by `init_time`; 0.0 until then.
    """

    def __init__(self):
        self.history: Dict[str, List[Union[np.ndarray, float, int]]] = dict()
        self.iter_n = 0
        self.start_time = 0.0

    def write(self, data: Dict) -> None:
        """Append each value in `data` to the history list of its key."""
        for name, entry in data.items():
            self.history.setdefault(name, []).append(entry)

    def close(self) -> None:
        """Do nothing; nothing is held open by this logger."""
        return None

    def init_time(self) -> None:
        """Record the current wall-clock time in `start_time`."""
        now = time.time()
        self.start_time = now
25 |
--------------------------------------------------------------------------------
/discounting_chain/nets.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import chex
4 | import haiku as hk
5 | import jax
6 | from acme.jax import networks as jax_networks
7 | from acme.specs import EnvironmentSpec
8 | from jax import numpy as jnp
9 | from tensorflow_probability.substrates import jax as tfp
10 |
11 |
class MiniAtariTorso(jax_networks.base.Module):
    """Convolutional torso in the style of `acme.jax.networks.AtariTorso`, sized for small
    images such as Snake from Jumanji.

    Two 32-channel 2x2 convolutions (strides taken from `conv_strides`), each followed by a
    ReLU; the output is flattened per example.
    """

    def __init__(self, conv_strides) -> None:
        super().__init__(name="atari_torso")
        first_stride, second_stride = conv_strides[0], conv_strides[1]
        layers = [
            hk.Conv2D(32, [2, 2], first_stride),
            jax.nn.relu,
            hk.Conv2D(32, [2, 2], second_stride),
            jax.nn.relu,
        ]
        self._network = hk.Sequential(layers)

    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Encode an HWC image to [D] features, or a BHWC batch to [B, D] features."""
        rank = jnp.ndim(inputs)
        if not 3 <= rank <= 4:
            raise ValueError("Expected input BHWC or HWC. Got rank %d" % rank)

        features = self._network(inputs)

        # Keep the batch axis for BHWC input; flatten everything for a single HWC image.
        target_shape = [features.shape[0], -1] if rank == 4 else [-1]
        return jnp.reshape(features, target_shape)
38 |
39 |
def create_linear_forward_fn(
    environment_spec: EnvironmentSpec, double_value_head: bool = False
) -> hk.Transformed:
    """Create a minimal network whose actor is a single linear categorical head.

    The critic output is only a placeholder: it is multiplied by NaN and is not intended to be
    used.

    Args:
        environment_spec: Environment spec; `environment_spec.actions.num_values` sets the
            number of action logits.
        double_value_head: If True the value output is the pair `(value, value)`; otherwise a
            single value.

    Returns:
        A Haiku transformed (rng-free) function mapping observations to `(policy, value)`.
    """

    def forward_fn(
        obs: jnp.ndarray,
    ) -> Tuple[tfp.distributions.Distribution, chex.ArrayTree]:
        actor_network = jax_networks.CategoricalHead(
            num_values=environment_spec.actions.num_values,
            w_init=hk.initializers.VarianceScaling(1e-4),
        )
        policy = actor_network(obs)
        # Dummy value function: the NaN factor makes the output unusable by design.
        val = hk.Linear(1, w_init=hk.initializers.VarianceScaling(1e-4))(obs) * jnp.nan
        # Parenthesised explicitly: `val, val if c else val` would parse as
        # `(val, (val if c else val))`, i.e. always a tuple.
        value = (val, val) if double_value_head else val
        return policy, value

    return hk.without_apply_rng(hk.transform(forward_fn))
59 |
60 |
def create_forward_fn(
    environment_spec: EnvironmentSpec,
    actor_hidden_layers: Tuple[int, ...] = (32, 32),
    critic_hidden_layers: Tuple[int, ...] = (64, 64),
    meta_critic_hidden_layers: Optional[Tuple[int, ...]] = None,
    torso_mlp_units: Tuple[int, ...] = (32, 32),
    conv_strides: Tuple[int, int] = (2, 1),
    double_value_head: bool = False,
) -> hk.Transformed:
    """Create the actor-critic network.

    The actor and the critic each build their own torso (separate parameters): a
    `MiniAtariTorso` for rank-3 observations, otherwise an MLP with `torso_mlp_units`.

    Args:
        environment_spec: Environment spec; observations must be rank 1 or rank 3, and
            `actions.num_values` sets the number of action logits.
        actor_hidden_layers: Hidden sizes of the actor MLP before the categorical head.
        critic_hidden_layers: Hidden sizes of the critic MLP before the linear value output.
        meta_critic_hidden_layers: Only used when `double_value_head` is True. If truthy, the
            second value head gets its own MLP on the gradient-stopped critic torso features;
            otherwise it reads the gradient-stopped critic MLP features directly.
        torso_mlp_units: Layer sizes of the MLP torso used for rank-1 observations.
        conv_strides: Strides of the two conv layers of the torso for rank-3 observations.
        double_value_head: If True, also compute a second value estimate.

    Returns:
        A Haiku transformed (rng-free) function mapping observations to `(policy, value)`,
        where `value` is the tuple `(value, value_2)` when `double_value_head` is True.
    """
    assert len(environment_spec.observations.shape) in [1, 3]

    # Factory rather than a single module: every call builds a new torso, so the actor and the
    # critic below do not share torso parameters.
    torso = (
        lambda: MiniAtariTorso(conv_strides)
        if len(environment_spec.observations.shape) == 3
        else hk.nets.MLP(torso_mlp_units)
    )

    def forward_fn(
        obs: jnp.ndarray,
    ) -> Tuple[tfp.distributions.Distribution, chex.ArrayTree]:
        # NOTE: Haiku parameter names depend on the order modules are created below; reordering
        # these statements would change the parameter tree.
        # Critic: own torso -> MLP -> scalar value.
        enc_crit = torso()(obs)
        enc_crit_torso = enc_crit  # torso features, reused by the optional meta-critic MLP
        critic_network: hk.Module = hk.nets.MLP(
            critic_hidden_layers, activate_final=True
        )
        enc_crit = critic_network(enc_crit)
        value = hk.Linear(1, w_init=hk.initializers.VarianceScaling(1e-4))(enc_crit)

        # Actor: own torso -> MLP -> categorical policy over the discrete actions.
        enc_act = torso()(obs)
        actor_network: hk.Module = hk.Sequential(
            [
                hk.nets.MLP(actor_hidden_layers, activate_final=True),
                jax_networks.CategoricalHead(
                    num_values=environment_spec.actions.num_values
                ),
            ],
            name="actor",
        )
        policy = actor_network(enc_act)
        value = jnp.squeeze(value, axis=-1)
        if double_value_head:
            # Second value head. Its inputs are gradient-stopped, so training it sends no
            # gradient into the critic's torso/MLP.
            if meta_critic_hidden_layers:
                critic_network_2 = hk.nets.MLP(
                    meta_critic_hidden_layers, activate_final=True
                )
                value_2 = critic_network_2(jax.lax.stop_gradient(enc_crit_torso))
            else:
                value_2 = jax.lax.stop_gradient(enc_crit)
            value_2 = hk.Linear(1, w_init=hk.initializers.VarianceScaling(1e-4))(
                value_2
            )
            value_2 = jnp.squeeze(value_2, axis=-1)
            value = value, value_2
        return policy, value

    return hk.without_apply_rng(hk.transform(forward_fn))
117 |
--------------------------------------------------------------------------------
/discounting_chain/nets_split.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import chex
4 | import haiku as hk
5 | import jax
6 | from acme.jax import networks as jax_networks
7 | from acme.specs import EnvironmentSpec
8 | from jax import numpy as jnp
9 | from tensorflow_probability.substrates import jax as tfp
10 |
11 |
class MiniAtariTorso(jax_networks.base.Module):
    """Convolutional torso in the style of `acme.jax.networks.AtariTorso`, sized for small
    images such as Snake from Jumanji.

    Two 32-channel 2x2 convolutions (strides 2 then 1), each followed by a ReLU; the output is
    flattened per example.
    """

    def __init__(self) -> None:
        super().__init__(name="atari_torso")
        layers = [
            hk.Conv2D(32, [2, 2], 2),
            jax.nn.relu,
            hk.Conv2D(32, [2, 2], 1),
            jax.nn.relu,
        ]
        self._network = hk.Sequential(layers)

    def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
        """Encode an HWC image to [D] features, or a BHWC batch to [B, D] features."""
        rank = jnp.ndim(inputs)
        if not 3 <= rank <= 4:
            raise ValueError("Expected input BHWC or HWC. Got rank %d" % rank)

        features = self._network(inputs)

        # Keep the batch axis for BHWC input; flatten everything for a single HWC image.
        target_shape = [features.shape[0], -1] if rank == 4 else [-1]
        return jnp.reshape(features, target_shape)
38 |
39 |
def create_forward_fn(
    environment_spec: EnvironmentSpec,
    actor_hidden_layers: Tuple[int, ...] = (64, 64),
    critic_hidden_layers: Tuple[int, ...] = (64, 64),
    torso_mlp_units: Tuple[int, ...] = (32, 32),
) -> Tuple[hk.Transformed, hk.Transformed]:
    """Create separate policy and critic networks as two independent Haiku transforms.

    Each network builds its own torso: a `MiniAtariTorso` for rank-3 observations, otherwise
    an MLP with `torso_mlp_units`.

    Args:
        environment_spec: Environment spec; observations must be rank 1 or rank 3, and
            `actions.num_values` sets the number of action logits.
        actor_hidden_layers: Hidden sizes of the actor MLP before the categorical head.
        critic_hidden_layers: Hidden sizes of the critic MLP before the linear value output.
        torso_mlp_units: Layer sizes of the MLP torso used for rank-1 observations.

    Returns:
        (policy_network, critic_network): rng-free Haiku transforms mapping observations to a
        categorical policy and to a value with the trailing unit axis squeezed, respectively.
    """
    assert len(environment_spec.observations.shape) in [1, 3]

    # Factory: every call builds a new torso module.
    torso = (
        lambda: MiniAtariTorso()
        if len(environment_spec.observations.shape) == 3
        else hk.nets.MLP(torso_mlp_units)
    )

    def policy_forward_fn(
        obs: jnp.ndarray,
    ) -> tfp.distributions.Distribution:
        # Torso -> MLP -> categorical head over the discrete actions.
        enc_act = torso()(obs)
        actor_network: hk.Module = hk.Sequential(
            [
                hk.nets.MLP(actor_hidden_layers, activate_final=True),
                jax_networks.CategoricalHead(
                    num_values=environment_spec.actions.num_values
                ),
            ],
            name="actor",
        )
        policy = actor_network(enc_act)
        return policy

    def critic_forward_fn(
        obs: jnp.ndarray,
    ) -> chex.Array:
        # Torso -> MLP -> single linear output, squeezed to drop the unit axis.
        enc_crit = torso()(obs)
        critic_network: hk.Module = hk.nets.MLP(
            critic_hidden_layers, activate_final=True
        )
        enc_crit = critic_network(enc_crit)
        value = hk.Linear(1, w_init=hk.initializers.VarianceScaling(1e-4))(enc_crit)
        value = jnp.squeeze(value, axis=-1)
        return value

    return hk.without_apply_rng(hk.transform(policy_forward_fn)), hk.without_apply_rng(
        hk.transform(critic_forward_fn)
    )
85 |
--------------------------------------------------------------------------------
/discounting_chain/train_utils.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax.random
4 | from acme.jax import utils as acme_utils
5 | from jax import numpy as jnp
6 | from tqdm import tqdm
7 |
8 | from discounting_chain.base import OnlineAgent
9 | from discounting_chain.list_logger import ListLogger
10 |
11 |
def outer_iter(state, update_fn, n_updates):
    """Apply `update_fn` to `state` `n_updates` times inside one `jax.lax.scan`.

    Args:
        state: Carry passed to `update_fn`.
        update_fn: Callable `state -> (new_state, metrics)`.
        n_updates: Number of consecutive updates to run.

    Returns:
        The state after the final update, and the metrics of the final update only.
    """
    state, metrics = jax.lax.scan(
        lambda carry, _: update_fn(carry), init=state, xs=None, length=n_updates
    )
    # `scan` stacks the per-update metrics along a new leading axis; keep only the last update's
    # entry. (Under `pmap` this runs per device, so there is no device axis at this point.)
    # `jax.tree_util.tree_map` is used because `jax.tree_map` is deprecated/removed in newer JAX.
    metrics = jax.tree_util.tree_map(lambda x: x[-1], metrics)
    return state, metrics
18 |
19 |
def run(
    num_iterations: int,
    n_updates_per_iter: int,
    agent: OnlineAgent,
    logger: ListLogger,
    seed: int = 0,
):
    """Train `agent` on every local device, logging metrics after each outer iteration.

    Each outer iteration runs `outer_iter` under `jax.pmap`, i.e. `n_updates_per_iter` calls
    of `agent.update` per device. The metrics from the first device are then written to
    `logger`.

    Args:
        num_iterations: Number of outer iterations (one `logger.write` each).
        n_updates_per_iter: Agent updates scanned over within one outer iteration.
        agent: Online agent providing `init` and `update`.
        logger: Receives the first device's metrics after every iteration.
        seed: Seed for the PRNG keys passed to `agent.init`.

    Returns:
        The final agent state, taken from the first device.
    """
    key1, key2 = jax.random.split(jax.random.PRNGKey(seed))
    devices = jax.devices()
    n_devices = len(devices)

    step_all_devices = jax.pmap(
        partial(outer_iter, update_fn=agent.update, n_updates=n_updates_per_iter),
        axis_name="num_devices",
        devices=devices,
    )
    # `agent.init` receives the same key replicated on every device as its first argument and
    # a distinct key per device as its second.
    replicated_key = jnp.stack([key1] * n_devices)
    per_device_keys = jax.random.split(key2, n_devices)
    state = jax.pmap(agent.init)(replicated_key, per_device_keys)

    for _ in tqdm(range(num_iterations)):
        state, device_metrics = step_all_devices(state)
        logger.write(acme_utils.get_from_first_device(device_metrics))

    return acme_utils.get_from_first_device(state)
46 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | python_version = 3.8
3 | namespace_packages = True
4 | incremental = False
5 | cache_dir = nul
6 | warn_redundant_casts = True
7 | warn_return_any = True
8 | warn_unused_configs = True
9 | warn_unused_ignores = False
10 | allow_redefinition = True
11 | disallow_untyped_calls = True
12 | disallow_untyped_defs = True
13 | disallow_incomplete_defs = True
14 | check_untyped_defs = True
15 | disallow_untyped_decorators = False
16 | strict_optional = True
17 | strict_equality = True
18 | explicit_package_bases = True
19 | follow_imports = skip
20 |
21 | [mypy-numpy.*]
22 | ignore_missing_imports = True
23 |
24 | [mypy-tree.*]
25 | ignore_missing_imports = True
26 |
27 | [mypy-pytest.*]
28 | ignore_missing_imports = True
29 |
30 | [mypy-hydra.*]
31 | ignore_missing_imports = True
32 |
33 | [mypy-omegaconf.*]
34 | ignore_missing_imports = True
35 |
36 | [mypy-optax.*]
37 | ignore_missing_imports = True
38 |
--------------------------------------------------------------------------------
/plots/discounting_chain/dc_chain_discount_factor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_discount_factor.png
--------------------------------------------------------------------------------
/plots/discounting_chain/dc_chain_discount_factor_appendix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_discount_factor_appendix.png
--------------------------------------------------------------------------------
/plots/discounting_chain/dc_chain_outer_loss_advantage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_outer_loss_advantage.png
--------------------------------------------------------------------------------
/plots/discounting_chain/dc_chain_outer_loss_advantage_appendix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_outer_loss_advantage_appendix.png
--------------------------------------------------------------------------------
/plots/discounting_chain/dc_chain_return.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_return.png
--------------------------------------------------------------------------------
/plots/discounting_chain/dc_chain_return_appendix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/discounting_chain/dc_chain_return_appendix.png
--------------------------------------------------------------------------------
/plots/snake/snake_discount_factor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_discount_factor.png
--------------------------------------------------------------------------------
/plots/snake/snake_discount_factor_appendix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_discount_factor_appendix.png
--------------------------------------------------------------------------------
/plots/snake/snake_outer_loss_advantage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_outer_loss_advantage.png
--------------------------------------------------------------------------------
/plots/snake/snake_outer_loss_advantage_appendix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_outer_loss_advantage_appendix.png
--------------------------------------------------------------------------------
/plots/snake/snake_return.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_return.png
--------------------------------------------------------------------------------
/plots/snake/snake_return_appendix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/plots/snake/snake_return_appendix.png
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | # All of these dependencies are pinned to specific versions to achieve reproducible builds (CI/Dev)
2 | black==22.10.0
3 | coverage==6.4.4
4 | flake8==5.0.4
5 | isort==5.10.1
6 | mypy==0.971
7 | pre-commit==2.20.0
8 | promise==2.3
9 | pytest==7.1.3
10 | pytest-cov==3.0.0
11 | pytest-mock==3.8.2
12 | pytest-parallel==0.1.1
13 | pytest-xdist==2.5.0
14 | pytype==2022.9.8
15 | testfixtures==7.0.0
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | chex==0.1.5
2 | distrax==0.1.2
3 | dm-haiku==0.0.7
4 | gymnax==0.0.5
5 | hydra-core==1.2.0
6 | jax==0.3.22
7 | jaxlib==0.3.22
8 | jumanji==0.1.3
9 | jupyter==1.0.0
10 | matplotlib==3.6.2
11 | neptune-client==0.16.7
12 | numpy==1.22.4
13 | omegaconf==2.2.3
14 | optax==0.1.3
15 | rlax==0.1.4
16 | scipy==1.9.1
17 | tqdm==4.64.1
18 |
--------------------------------------------------------------------------------
/snake/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/snake/__init__.py
--------------------------------------------------------------------------------
/snake/agent/__init__.py:
--------------------------------------------------------------------------------
1 | from snake.agent.a2c import A2C
2 | from snake.agent.actor_critic_agent import ActorCriticAgent
3 | from snake.agent.meta_a2c import MetaA2C, meta_params_to_hyper_params
4 |
--------------------------------------------------------------------------------
/snake/agent/gae.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from typing import Tuple
3 |
4 | import chex
5 | import jax
6 | import rlax
7 | from jax import numpy as jnp
8 |
9 |
def compute_td_lambda(
    discount: chex.Array,
    rewards: chex.Array,
    values: chex.Array,
    bootstrap_value: chex.Array,
    lambda_: float,
    discount_factor: float,
) -> Tuple[chex.Array, chex.Array]:
    """Return TD(lambda) value targets and advantages for a batch of rollouts.

    The leading axis of ``discount``, ``rewards`` and ``values`` is treated as
    time: successor values are built by shifting ``values`` along axis 0.
    ``rlax.td_lambda`` is vmapped over axis 1, which is presumably the
    batch/environment axis (TODO confirm against callers).

    Args:
        discount: per-step multiplier that is scaled by ``discount_factor``.
            It is presumably a 0/1 episode-continuation mask (verify with
            the caller).
        rewards: per-step rewards, aligned step-for-step with ``values``.
        values: value estimates for every step of the rollout.
        bootstrap_value: value used as the successor of the final step. Its
            shape must match a single time slice of ``values``, since it is
            concatenated onto ``values[1:]`` along axis 0.
        lambda_: TD(lambda) mixing coefficient, forwarded to
            ``rlax.td_lambda``.
        discount_factor: discount (gamma) multiplied into ``discount``.

    Returns:
        A ``(vs, advantages)`` tuple. ``advantages`` is the per-step output
        of ``rlax.td_lambda``, and ``vs = values + advantages`` are the
        corresponding value targets.
    """
    # Successor values: values[t + 1] for each step t, with the bootstrap
    # value standing in as the successor of the last step.
    next_values = jnp.concatenate(
        [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0
    )
    step_discounts = discount * discount_factor

    # stop_target_gradients=False lets gradients flow through the targets.
    # This is presumably so meta-gradients can reach discount_factor /
    # lambda_ (confirm with the meta-learning caller).
    batched_td_lambda = jax.vmap(
        functools.partial(
            rlax.td_lambda, lambda_=lambda_, stop_target_gradients=False
        ),
        in_axes=1,
        out_axes=1,
    )
    advantages = batched_td_lambda(values, rewards, step_discounts, next_values)

    # Adding the value baseline back onto the TD(lambda) output yields the
    # value targets.
    vs = values + advantages
    return vs, advantages
34 |
--------------------------------------------------------------------------------
/snake/configs/agent/a2c.yaml:
--------------------------------------------------------------------------------
1 | name: a2c
2 | agent: a2c
3 | meta_objective: null
4 | gamma_metal: null
5 | lambda_metal: null
6 | l_pg_metal: null
7 | l_td_metal: null
8 | l_en_metal: null
9 | gamma_outer: null
10 | lambda_outer: null
11 | l_pg_outer: null
12 | l_td_outer: null
13 | l_en_outer: null
14 | l_kl_outer: null
15 | bootstrap_l: null
16 | bootstrap_pm: null
17 | bootstrap_vm: null
18 | meta_optimizer: null
19 | meta_learning_rate: null
20 | meta_gradient_clip_norm: null
21 | bootstrap_l_optimizer: null
22 | bootstrap_l_learning_rate: null
23 |
--------------------------------------------------------------------------------
/snake/configs/agent/bootstrap.yaml:
--------------------------------------------------------------------------------
1 | name: bootstrap
2 | agent: meta_a2c
3 | meta_objective: bootstrap
4 | gamma_metal: true
5 | lambda_metal: false
6 | l_pg_metal: false
7 | l_td_metal: false
8 | l_en_metal: false
9 | gamma_outer: 1.0
10 | lambda_outer: 1.0
11 | l_pg_outer: 1.0
12 | l_td_outer: 0
13 | l_en_outer: 0
14 | l_kl_outer: 0
15 | bootstrap_l: 1
16 | bootstrap_pm: 1
17 | bootstrap_vm: 0.0
18 | meta_optimizer: adam # [sgd, rmsprop, adam]
19 | meta_learning_rate: 6e-3
20 | meta_gradient_clip_norm: 0.1 # [null, ]
21 | bootstrap_l_optimizer: rmsprop
22 | bootstrap_l_learning_rate: ${training.learning_rate}
23 |
--------------------------------------------------------------------------------
/snake/configs/agent/mgrl.yaml:
--------------------------------------------------------------------------------
1 | name: mgrl
2 | agent: meta_a2c
3 | meta_objective: meta_gradient
4 | gamma_metal: true
5 | lambda_metal: false
6 | l_pg_metal: false
7 | l_td_metal: false
8 | l_en_metal: false
9 | gamma_outer: 1.0
10 | lambda_outer: 1.0
11 | l_pg_outer: 1.0
12 | l_td_outer: 0
13 | l_en_outer: 0
14 | l_kl_outer: 0
15 | bootstrap_l: null
16 | bootstrap_pm: null
17 | bootstrap_vm: null
18 | meta_optimizer: adam # [sgd, rmsprop, adam]
19 | meta_learning_rate: 3e-3
20 | meta_gradient_clip_norm: 0.1 # [null, ]
21 | bootstrap_l_optimizer: null
22 | bootstrap_l_learning_rate: null
23 |
--------------------------------------------------------------------------------
/snake/configs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - agent: mgrl # [a2c, mgrl, bootstrap]
3 | - _self_
4 |
5 | training:
6 | name: appendix_norm_advantages_snake_${agent.name}_outer_critic_${agent.outer_critic}_normalize_advantage_${agent.normalize_advantage}_normalize_outer_advantage_${agent.normalize_outer_advantage}
7 | num_timesteps: 75_000_000
8 | num_eval_points: 200
9 | total_batch_size: 512
10 | total_num_envs: 512
11 | n_steps: 5
12 | reward_scaling: 1
13 | optimizer: rmsprop # [sgd, rmsprop, adam]
14 | learning_rate: 5e-4
15 | gradient_clip_norm: null # [null, ]
16 | gamma_init: 0.8
17 | lambda_init: 0.95
18 | l_pg_init: 1.0
19 | l_td_init: 0.5
20 | l_en_init: 1e-2
21 | seed: 1
22 |
23 |
24 | agent:
25 | # Common to all agents
26 | outer_critic: false # [true, false]
27 | normalize_advantage: false # [true, false]
28 | normalize_outer_advantage: false # [true, false]
29 | total_num_eval: 512
30 | deterministic_eval: false # Has no effect: evaluation now reports both deterministic and stochastic metrics
31 |
32 | network:
33 | # Overwrite
34 | type: snake
35 | num_channels: 32
36 | policy_layers: [64, 64] # Tuned to [64, 64] out of [[64, 64], [32, 32]]
37 | value_layers: [128, 128]
38 | embedding_size_actor: null # [, null]
39 | embedding_size_critic: null # [, null]
40 |
--------------------------------------------------------------------------------
/snake/data/a2c_return/MET-2150__eval_episode_reward_determinist_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667986130985,0.08984375
2 | 146.0,1667986137977,0.1484375
3 | 292.0,1667986140676,0.1640625
4 | 438.0,1667986143329,1.462890625
5 | 584.0,1667986146010,4.087890625
6 | 730.0,1667986148551,3.978515625
7 | 876.0,1667986151127,3.876953125
8 | 1022.0,1667986153788,4.275390625
9 | 1168.0,1667986156381,4.640625
10 | 1314.0,1667986158996,4.69140625
11 | 1460.0,1667986161606,5.037109375
12 | 1606.0,1667986164201,5.029296875
13 | 1752.0,1667986166926,5.419921875
14 | 1898.0,1667986169604,5.470703125
15 | 2044.0,1667986172243,5.4296875
16 | 2190.0,1667986174887,5.783203125
17 | 2336.0,1667986177553,5.77734375
18 | 2482.0,1667986180187,5.771484375
19 | 2628.0,1667986182764,5.896484375
20 | 2774.0,1667986185322,6.35546875
21 | 2920.0,1667986187944,6.169921875
22 | 3066.0,1667986190567,6.283203125
23 | 3212.0,1667986193194,6.794921875
24 | 3358.0,1667986195773,6.794921875
25 | 3504.0,1667986198404,6.70703125
26 | 3650.0,1667986201275,7.203125
27 | 3796.0,1667986203893,7.40625
28 | 3942.0,1667986206545,7.66796875
29 | 4088.0,1667986209185,8.009765625
30 | 4234.0,1667986211820,8.2109375
31 | 4380.0,1667986214410,8.775390625
32 | 4526.0,1667986217095,9.494140625
33 | 4672.0,1667986219705,9.607421875
34 | 4818.0,1667986222317,10.248046875
35 | 4964.0,1667986224904,10.298828125
36 | 5110.0,1667986227626,10.349609375
37 | 5256.0,1667986230294,10.75
38 | 5402.0,1667986232953,10.7734375
39 | 5548.0,1667986235499,11.03125
40 | 5694.0,1667986238166,11.609375
41 | 5840.0,1667986240780,11.2734375
42 | 5986.0,1667986243398,11.123046875
43 | 6132.0,1667986245996,11.46875
44 | 6278.0,1667986248671,11.69921875
45 | 6424.0,1667986251336,12.326171875
46 | 6570.0,1667986253957,11.7890625
47 | 6716.0,1667986256607,11.935546875
48 | 6862.0,1667986259283,12.3359375
49 | 7008.0,1667986261919,12.28515625
50 | 7154.0,1667986264568,12.931640625
51 | 7300.0,1667986267282,13.208984375
52 | 7446.0,1667986269879,13.341796875
53 | 7592.0,1667986272482,13.46875
54 | 7738.0,1667986275155,13.52734375
55 | 7884.0,1667986277771,13.603515625
56 | 8030.0,1667986280433,13.591796875
57 | 8176.0,1667986283111,13.65625
58 | 8322.0,1667986285763,14.1953125
59 | 8468.0,1667986288457,13.95703125
60 | 8614.0,1667986291151,13.9765625
61 | 8760.0,1667986293792,14.0625
62 | 8906.0,1667986296379,14.056640625
63 | 9052.0,1667986299064,14.220703125
64 | 9198.0,1667986301695,13.888671875
65 | 9344.0,1667986304350,14.255859375
66 | 9490.0,1667986307065,14.794921875
67 | 9636.0,1667986309712,14.716796875
68 | 9782.0,1667986312360,14.904296875
69 | 9928.0,1667986315018,14.818359375
70 | 10074.0,1667986317683,14.9609375
71 | 10220.0,1667986320328,14.767578125
72 | 10366.0,1667986323004,14.96875
73 | 10512.0,1667986325707,15.0078125
74 | 10658.0,1667986328283,15.298828125
75 | 10804.0,1667986330916,14.78515625
76 | 10950.0,1667986333553,15.36328125
77 | 11096.0,1667986336144,14.96875
78 | 11242.0,1667986338835,15.060546875
79 | 11388.0,1667986341441,15.337890625
80 | 11534.0,1667986344090,15.3046875
81 | 11680.0,1667986346758,15.625
82 | 11826.0,1667986349418,15.728515625
83 | 11972.0,1667986352023,15.9453125
84 | 12118.0,1667986354605,15.810546875
85 | 12264.0,1667986357232,15.771484375
86 | 12410.0,1667986359841,16.22265625
87 | 12556.0,1667986362468,15.8515625
88 | 12702.0,1667986365058,16.251953125
89 | 12848.0,1667986367773,15.80078125
90 | 12994.0,1667986370387,15.99609375
91 | 13140.0,1667986373078,15.685546875
92 | 13286.0,1667986375699,15.431640625
93 | 13432.0,1667986378359,15.736328125
94 | 13578.0,1667986380960,15.46484375
95 | 13724.0,1667986383614,15.96875
96 | 13870.0,1667986386244,16.0546875
97 | 14016.0,1667986388906,15.94921875
98 | 14162.0,1667986391517,15.9609375
99 | 14308.0,1667986394205,16.2578125
100 | 14454.0,1667986396715,15.931640625
101 | 14600.0,1667986399421,15.90234375
102 | 14746.0,1667986402022,15.87109375
103 | 14892.0,1667986404622,15.9609375
104 | 15038.0,1667986407310,16.0625
105 | 15184.0,1667986409944,16.228515625
106 | 15330.0,1667986412640,16.10546875
107 | 15476.0,1667986415244,16.322265625
108 | 15622.0,1667986417911,16.189453125
109 | 15768.0,1667986421289,16.669921875
110 | 15914.0,1667986423893,16.44140625
111 | 16060.0,1667986426458,16.13671875
112 | 16206.0,1667986429265,16.302734375
113 | 16352.0,1667986431954,16.529296875
114 | 16498.0,1667986434571,16.525390625
115 | 16644.0,1667986437218,16.517578125
116 | 16790.0,1667986439895,16.33203125
117 | 16936.0,1667986442550,16.0625
118 | 17082.0,1667986445214,16.61328125
119 | 17228.0,1667986447862,15.880859375
120 | 17374.0,1667986450550,16.333984375
121 | 17520.0,1667986453244,16.384765625
122 | 17666.0,1667986455851,16.34375
123 | 17812.0,1667986458457,16.431640625
124 | 17958.0,1667986461001,16.95703125
125 | 18104.0,1667986463506,17.076171875
126 | 18250.0,1667986466044,16.884765625
127 | 18396.0,1667986468622,16.7109375
128 | 18542.0,1667986471195,16.484375
129 | 18688.0,1667986473826,16.580078125
130 | 18834.0,1667986476401,16.873046875
131 | 18980.0,1667986479105,16.5625
132 | 19126.0,1667986481717,17.048828125
133 | 19272.0,1667986484294,17.06640625
134 | 19418.0,1667986486875,16.8046875
135 | 19564.0,1667986489596,16.435546875
136 | 19710.0,1667986492238,16.873046875
137 | 19856.0,1667986494881,16.82421875
138 | 20002.0,1667986497668,16.43359375
139 | 20148.0,1667986500376,16.791015625
140 | 20294.0,1667986503073,16.970703125
141 | 20440.0,1667986505889,16.802734375
142 | 20586.0,1667986508597,16.8671875
143 | 20732.0,1667986511179,16.75
144 | 20878.0,1667986513734,16.87890625
145 | 21024.0,1667986516362,17.078125
146 | 21170.0,1667986519023,17.083984375
147 | 21316.0,1667986521662,17.0390625
148 | 21462.0,1667986524266,17.021484375
149 | 21608.0,1667986526935,17.125
150 | 21754.0,1667986529563,17.123046875
151 | 21900.0,1667986532175,16.818359375
152 | 22046.0,1667986534837,16.58203125
153 | 22192.0,1667986537474,17.1796875
154 | 22338.0,1667986540093,17.087890625
155 | 22484.0,1667986542738,17.26953125
156 | 22630.0,1667986545352,17.53125
157 | 22776.0,1667986547950,16.798828125
158 | 22922.0,1667986550538,16.982421875
159 | 23068.0,1667986553117,16.619140625
160 | 23214.0,1667986555700,17.185546875
161 | 23360.0,1667986558337,17.05078125
162 | 23506.0,1667986560988,17.24609375
163 | 23652.0,1667986563555,17.3046875
164 | 23798.0,1667986566130,17.11328125
165 | 23944.0,1667986568802,17.34765625
166 | 24090.0,1667986571389,17.361328125
167 | 24236.0,1667986574049,17.197265625
168 | 24382.0,1667986576601,17.439453125
169 | 24528.0,1667986579220,17.458984375
170 | 24674.0,1667986581826,17.376953125
171 | 24820.0,1667986584428,17.41015625
172 | 24966.0,1667986587134,17.28515625
173 | 25112.0,1667986589806,17.205078125
174 | 25258.0,1667986592632,16.98828125
175 | 25404.0,1667986595277,17.021484375
176 | 25550.0,1667986597895,17.35546875
177 | 25696.0,1667986600551,17.25
178 | 25842.0,1667986603118,16.81640625
179 | 25988.0,1667986605742,17.220703125
180 | 26134.0,1667986608357,16.947265625
181 | 26280.0,1667986610945,17.244140625
182 | 26426.0,1667986613635,17.16796875
183 | 26572.0,1667986616194,17.40625
184 | 26718.0,1667986618868,17.49609375
185 | 26864.0,1667986621550,17.68359375
186 | 27010.0,1667986624198,17.5625
187 | 27156.0,1667986626898,17.298828125
188 | 27302.0,1667986629607,17.69921875
189 | 27448.0,1667986632210,17.427734375
190 | 27594.0,1667986634852,17.408203125
191 | 27740.0,1667986637523,17.728515625
192 | 27886.0,1667986640170,17.26171875
193 | 28032.0,1667986642782,17.818359375
194 | 28178.0,1667986645444,17.63671875
195 | 28324.0,1667986648060,17.44140625
196 | 28470.0,1667986650671,17.126953125
197 | 28616.0,1667986653396,17.46875
198 | 28762.0,1667986656046,17.291015625
199 | 28908.0,1667986658737,17.333984375
200 | 29054.0,1667986661386,17.861328125
201 | 29200.0,1667986664068,17.9921875
202 |
--------------------------------------------------------------------------------
/snake/data/a2c_return/MET-2175__eval_episode_reward_determinist_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667989976182,0.099609375
2 | 146.0,1667989981915,0.15625
3 | 292.0,1667989984563,0.162109375
4 | 438.0,1667989987274,1.740234375
5 | 584.0,1667989989853,3.791015625
6 | 730.0,1667989992338,3.96875
7 | 876.0,1667989994992,4.443359375
8 | 1022.0,1667989997599,4.462890625
9 | 1168.0,1667990000225,5.017578125
10 | 1314.0,1667990002815,5.3515625
11 | 1460.0,1667990005479,5.134765625
12 | 1606.0,1667990008073,5.248046875
13 | 1752.0,1667990010690,5.5078125
14 | 1898.0,1667990013326,5.80078125
15 | 2044.0,1667990015889,5.666015625
16 | 2190.0,1667990018450,6.076171875
17 | 2336.0,1667990021077,5.92578125
18 | 2482.0,1667990023655,6.037109375
19 | 2628.0,1667990026315,6.462890625
20 | 2774.0,1667990028932,6.423828125
21 | 2920.0,1667990031518,6.51953125
22 | 3066.0,1667990034139,6.65625
23 | 3212.0,1667990036773,6.45703125
24 | 3358.0,1667990039338,6.154296875
25 | 3504.0,1667990041975,6.44140625
26 | 3650.0,1667990044688,6.484375
27 | 3796.0,1667990047256,6.181640625
28 | 3942.0,1667990049847,6.623046875
29 | 4088.0,1667990052436,6.5234375
30 | 4234.0,1667990055125,6.478515625
31 | 4380.0,1667990057725,6.8984375
32 | 4526.0,1667990060296,6.162109375
33 | 4672.0,1667990062897,7.015625
34 | 4818.0,1667990065537,6.916015625
35 | 4964.0,1667990068103,6.619140625
36 | 5110.0,1667990070666,6.7734375
37 | 5256.0,1667990073287,6.896484375
38 | 5402.0,1667990075844,6.736328125
39 | 5548.0,1667990078438,7.04296875
40 | 5694.0,1667990081023,6.6171875
41 | 5840.0,1667990083628,7.025390625
42 | 5986.0,1667990086138,6.859375
43 | 6132.0,1667990088732,7.025390625
44 | 6278.0,1667990091317,7.517578125
45 | 6424.0,1667990093894,7.19140625
46 | 6570.0,1667990096517,7.4296875
47 | 6716.0,1667990099126,7.734375
48 | 6862.0,1667990101733,7.703125
49 | 7008.0,1667990104459,7.912109375
50 | 7154.0,1667990107053,8.0234375
51 | 7300.0,1667990109614,8.076171875
52 | 7446.0,1667990112304,7.9296875
53 | 7592.0,1667990114964,8.052734375
54 | 7738.0,1667990117590,8.419921875
55 | 7884.0,1667990120192,8.7109375
56 | 8030.0,1667990122856,8.712890625
57 | 8176.0,1667990125531,8.431640625
58 | 8322.0,1667990128100,9.23046875
59 | 8468.0,1667990130664,9.380859375
60 | 8614.0,1667990133271,9.248046875
61 | 8760.0,1667990135902,9.9921875
62 | 8906.0,1667990138510,10.396484375
63 | 9052.0,1667990141052,10.49609375
64 | 9198.0,1667990143727,10.53515625
65 | 9344.0,1667990146374,10.791015625
66 | 9490.0,1667990149019,10.68359375
67 | 9636.0,1667990151526,10.990234375
68 | 9782.0,1667990154145,11.623046875
69 | 9928.0,1667990156724,11.455078125
70 | 10074.0,1667990159396,11.404296875
71 | 10220.0,1667990162027,11.548828125
72 | 10366.0,1667990164783,11.4296875
73 | 10512.0,1667990167355,11.720703125
74 | 10658.0,1667990169958,12.072265625
75 | 10804.0,1667990172593,12.408203125
76 | 10950.0,1667990175362,12.3359375
77 | 11096.0,1667990177925,12.75390625
78 | 11242.0,1667990180561,12.20703125
79 | 11388.0,1667990183206,12.4453125
80 | 11534.0,1667990185832,12.607421875
81 | 11680.0,1667990188394,13.1640625
82 | 11826.0,1667990191032,13.232421875
83 | 11972.0,1667990193656,13.1484375
84 | 12118.0,1667990196267,13.28515625
85 | 12264.0,1667990198879,13.083984375
86 | 12410.0,1667990201450,13.111328125
87 | 12556.0,1667990204127,13.68359375
88 | 12702.0,1667990206747,13.6640625
89 | 12848.0,1667990209391,13.783203125
90 | 12994.0,1667990212044,13.826171875
91 | 13140.0,1667990214739,14.1875
92 | 13286.0,1667990217321,14.310546875
93 | 13432.0,1667990219915,14.13671875
94 | 13578.0,1667990222592,14.14453125
95 | 13724.0,1667990225338,14.392578125
96 | 13870.0,1667990227927,14.03125
97 | 14016.0,1667990230608,14.626953125
98 | 14162.0,1667990233236,14.578125
99 | 14308.0,1667990235804,14.341796875
100 | 14454.0,1667990238405,14.416015625
101 | 14600.0,1667990240995,14.455078125
102 | 14746.0,1667990243630,15.046875
103 | 14892.0,1667990246192,14.93359375
104 | 15038.0,1667990248815,15.33203125
105 | 15184.0,1667990251419,15.09765625
106 | 15330.0,1667990254048,15.09765625
107 | 15476.0,1667990256642,14.90625
108 | 15622.0,1667990259285,15.240234375
109 | 15768.0,1667990261878,14.9921875
110 | 15914.0,1667990264545,15.28515625
111 | 16060.0,1667990267159,15.302734375
112 | 16206.0,1667990269730,15.541015625
113 | 16352.0,1667990272338,15.419921875
114 | 16498.0,1667990274999,15.65625
115 | 16644.0,1667990277627,15.697265625
116 | 16790.0,1667990280275,15.85546875
117 | 16936.0,1667990282941,15.44921875
118 | 17082.0,1667990285631,15.8671875
119 | 17228.0,1667990288181,15.548828125
120 | 17374.0,1667990290779,15.591796875
121 | 17520.0,1667990293430,15.875
122 | 17666.0,1667990296013,16.123046875
123 | 17812.0,1667990298631,15.9453125
124 | 17958.0,1667990301206,15.94921875
125 | 18104.0,1667990303936,15.837890625
126 | 18250.0,1667990306518,15.85546875
127 | 18396.0,1667990309156,16.10546875
128 | 18542.0,1667990311762,15.900390625
129 | 18688.0,1667990314434,16.123046875
130 | 18834.0,1667990317080,16.3125
131 | 18980.0,1667990319815,16.142578125
132 | 19126.0,1667990322492,16.5625
133 | 19272.0,1667990325204,16.4609375
134 | 19418.0,1667990327771,16.296875
135 | 19564.0,1667990330386,16.20703125
136 | 19710.0,1667990332997,16.0703125
137 | 19856.0,1667990335621,16.40625
138 | 20002.0,1667990338130,16.330078125
139 | 20148.0,1667990340800,16.236328125
140 | 20294.0,1667990343447,16.41796875
141 | 20440.0,1667990346065,16.443359375
142 | 20586.0,1667990348613,16.724609375
143 | 20732.0,1667990351219,16.556640625
144 | 20878.0,1667990353831,15.95703125
145 | 21024.0,1667990356455,16.396484375
146 | 21170.0,1667990359166,16.0234375
147 | 21316.0,1667990361773,16.205078125
148 | 21462.0,1667990364467,16.166015625
149 | 21608.0,1667990367120,16.36328125
150 | 21754.0,1667990369686,16.224609375
151 | 21900.0,1667990372283,16.47265625
152 | 22046.0,1667990374973,16.724609375
153 | 22192.0,1667990377590,16.7109375
154 | 22338.0,1667990380343,16.41796875
155 | 22484.0,1667990382953,16.314453125
156 | 22630.0,1667990385562,16.775390625
157 | 22776.0,1667990388170,16.94140625
158 | 22922.0,1667990390697,16.68359375
159 | 23068.0,1667990393356,16.806640625
160 | 23214.0,1667990395979,16.451171875
161 | 23360.0,1667990398612,16.939453125
162 | 23506.0,1667990401159,16.90625
163 | 23652.0,1667990403892,16.6640625
164 | 23798.0,1667990406472,16.990234375
165 | 23944.0,1667990409025,16.5390625
166 | 24090.0,1667990411675,17.267578125
167 | 24236.0,1667990414317,16.310546875
168 | 24382.0,1667990416923,17.2421875
169 | 24528.0,1667990419548,16.6015625
170 | 24674.0,1667990422108,16.8828125
171 | 24820.0,1667990424853,16.673828125
172 | 24966.0,1667990427453,17.021484375
173 | 25112.0,1667990430084,16.404296875
174 | 25258.0,1667990432769,17.14453125
175 | 25404.0,1667990435344,17.162109375
176 | 25550.0,1667990437917,16.986328125
177 | 25696.0,1667990440636,16.916015625
178 | 25842.0,1667990443247,16.466796875
179 | 25988.0,1667990445842,16.845703125
180 | 26134.0,1667990448403,16.966796875
181 | 26280.0,1667990451046,16.44140625
182 | 26426.0,1667990453610,16.693359375
183 | 26572.0,1667990456315,16.947265625
184 | 26718.0,1667990458909,17.033203125
185 | 26864.0,1667990461517,17.20703125
186 | 27010.0,1667990464172,17.125
187 | 27156.0,1667990466755,17.3984375
188 | 27302.0,1667990469321,17.310546875
189 | 27448.0,1667990471876,17.00390625
190 | 27594.0,1667990474523,16.708984375
191 | 27740.0,1667990477134,17.35546875
192 | 27886.0,1667990479801,17.20703125
193 | 28032.0,1667990482322,17.310546875
194 | 28178.0,1667990485028,16.76171875
195 | 28324.0,1667990487645,16.892578125
196 | 28470.0,1667990490252,17.203125
197 | 28616.0,1667990492836,17.19921875
198 | 28762.0,1667990495446,16.828125
199 | 28908.0,1667990498049,17.08984375
200 | 29054.0,1667990500812,17.111328125
201 | 29200.0,1667990503455,17.140625
202 |
--------------------------------------------------------------------------------
/snake/data/appendix/return/return_a2c_no_norm.csv:
--------------------------------------------------------------------------------
1 | 3.0,1668722618745,0.1328125
2 | 146.0,1668722625924,0.17578125
3 | 292.0,1668722628787,0.248046875
4 | 438.0,1668722631689,1.306640625
5 | 584.0,1668722634513,3.3203125
6 | 730.0,1668722637239,3.830078125
7 | 876.0,1668722640071,4.123046875
8 | 1022.0,1668722643095,4.265625
9 | 1168.0,1668722646166,4.673828125
10 | 1314.0,1668722649007,4.82421875
11 | 1460.0,1668722651770,5.0234375
12 | 1606.0,1668722654546,5.16015625
13 | 1752.0,1668722657502,5.390625
14 | 1898.0,1668722660270,5.638671875
15 | 2044.0,1668722663126,5.703125
16 | 2190.0,1668722665935,5.626953125
17 | 2336.0,1668722668631,5.728515625
18 | 2482.0,1668722671420,5.884765625
19 | 2628.0,1668722674190,5.865234375
20 | 2774.0,1668722676928,6.033203125
21 | 2920.0,1668722679738,6.09765625
22 | 3066.0,1668722682509,6.349609375
23 | 3212.0,1668722685376,6.517578125
24 | 3358.0,1668722688201,6.458984375
25 | 3504.0,1668722691091,6.53125
26 | 3650.0,1668722693921,6.966796875
27 | 3796.0,1668722696767,7.638671875
28 | 3942.0,1668722699623,8.21484375
29 | 4088.0,1668722702517,8.09375
30 | 4234.0,1668722705340,8.544921875
31 | 4380.0,1668722708099,9.291015625
32 | 4526.0,1668722710957,9.58203125
33 | 4672.0,1668722713792,9.79296875
34 | 4818.0,1668722717101,9.626953125
35 | 4964.0,1668722719869,10.21484375
36 | 5110.0,1668722722717,10.4453125
37 | 5256.0,1668722725611,10.462890625
38 | 5402.0,1668722728366,10.546875
39 | 5548.0,1668722731141,11.421875
40 | 5694.0,1668722733947,11.341796875
41 | 5840.0,1668722736816,11.14453125
42 | 5986.0,1668722739885,11.376953125
43 | 6132.0,1668722742746,11.9296875
44 | 6278.0,1668722745719,11.630859375
45 | 6424.0,1668722748523,12.109375
46 | 6570.0,1668722751278,12.1484375
47 | 6716.0,1668722754023,12.201171875
48 | 6862.0,1668722756899,12.23046875
49 | 7008.0,1668722759735,12.189453125
50 | 7154.0,1668722762564,12.53515625
51 | 7300.0,1668722765390,12.814453125
52 | 7446.0,1668722768199,12.556640625
53 | 7592.0,1668722771022,13.16015625
54 | 7738.0,1668722773830,12.890625
55 | 7884.0,1668722776824,13.134765625
56 | 8030.0,1668722779608,13.09375
57 | 8176.0,1668722782636,12.8125
58 | 8322.0,1668722785514,13.759765625
59 | 8468.0,1668722788352,13.64453125
60 | 8614.0,1668722791245,13.568359375
61 | 8760.0,1668722794029,14.412109375
62 | 8906.0,1668722796856,13.87109375
63 | 9052.0,1668722799676,14.154296875
64 | 9198.0,1668722802439,14.294921875
65 | 9344.0,1668722805324,14.009765625
66 | 9490.0,1668722808229,14.669921875
67 | 9636.0,1668722811198,14.671875
68 | 9782.0,1668722814124,14.595703125
69 | 9928.0,1668722817013,14.349609375
70 | 10074.0,1668722819787,14.6484375
71 | 10220.0,1668722822591,14.232421875
72 | 10366.0,1668722825498,14.62109375
73 | 10512.0,1668722828389,14.556640625
74 | 10658.0,1668722831247,15.119140625
75 | 10804.0,1668722834023,14.6796875
76 | 10950.0,1668722836968,15.001953125
77 | 11096.0,1668722839780,14.845703125
78 | 11242.0,1668722842592,14.880859375
79 | 11388.0,1668722845395,14.57421875
80 | 11534.0,1668722848187,15.150390625
81 | 11680.0,1668722851047,15.416015625
82 | 11826.0,1668722853835,15.0546875
83 | 11972.0,1668722856750,15.02734375
84 | 12118.0,1668722859551,15.283203125
85 | 12264.0,1668722862325,15.052734375
86 | 12410.0,1668722865124,15.521484375
87 | 12556.0,1668722868063,14.876953125
88 | 12702.0,1668722870829,15.5703125
89 | 12848.0,1668722873599,15.712890625
90 | 12994.0,1668722876406,15.048828125
91 | 13140.0,1668722879185,15.6796875
92 | 13286.0,1668722881996,15.11328125
93 | 13432.0,1668722884857,16.09375
94 | 13578.0,1668722887671,15.765625
95 | 13724.0,1668722890474,15.298828125
96 | 13870.0,1668722893353,15.76953125
97 | 14016.0,1668722896310,15.470703125
98 | 14162.0,1668722899277,15.794921875
99 | 14308.0,1668722902083,15.4453125
100 | 14454.0,1668722904975,15.8515625
101 | 14600.0,1668722907863,15.51953125
102 | 14746.0,1668722910645,16.197265625
103 | 14892.0,1668722913480,16.1796875
104 | 15038.0,1668722916347,15.724609375
105 | 15184.0,1668722919129,15.78515625
106 | 15330.0,1668722922023,15.998046875
107 | 15476.0,1668722924856,15.859375
108 | 15622.0,1668722927743,16.169921875
109 | 15768.0,1668722930602,16.171875
110 | 15914.0,1668722933435,16.32421875
111 | 16060.0,1668722936266,16.32421875
112 | 16206.0,1668722939233,16.3125
113 | 16352.0,1668722942041,16.140625
114 | 16498.0,1668722944989,16.05859375
115 | 16644.0,1668722947910,16.212890625
116 | 16790.0,1668722950779,16.005859375
117 | 16936.0,1668722953680,16.47265625
118 | 17082.0,1668722956602,15.984375
119 | 17228.0,1668722959405,16.5234375
120 | 17374.0,1668722962206,16.30859375
121 | 17520.0,1668722965025,16.671875
122 | 17666.0,1668722967794,16.21484375
123 | 17812.0,1668722970580,16.564453125
124 | 17958.0,1668722973367,16.361328125
125 | 18104.0,1668722976233,16.33203125
126 | 18250.0,1668722978995,16.666015625
127 | 18396.0,1668722981858,16.787109375
128 | 18542.0,1668722984660,16.830078125
129 | 18688.0,1668722987523,16.58203125
130 | 18834.0,1668722990306,16.134765625
131 | 18980.0,1668722993061,16.52734375
132 | 19126.0,1668722995874,16.70703125
133 | 19272.0,1668722998644,16.501953125
134 | 19418.0,1668723001472,16.595703125
135 | 19564.0,1668723004334,16.4296875
136 | 19710.0,1668723007222,16.86328125
137 | 19856.0,1668723010009,16.73046875
138 | 20002.0,1668723012831,16.751953125
139 | 20148.0,1668723015635,16.884765625
140 | 20294.0,1668723018499,16.5
141 | 20440.0,1668723021226,16.6328125
142 | 20586.0,1668723024047,16.525390625
143 | 20732.0,1668723026780,16.51171875
144 | 20878.0,1668723029608,16.666015625
145 | 21024.0,1668723032434,16.9375
146 | 21170.0,1668723035307,16.83984375
147 | 21316.0,1668723038150,16.453125
148 | 21462.0,1668723040926,17.25
149 | 21608.0,1668723043703,16.78515625
150 | 21754.0,1668723046622,16.779296875
151 | 21900.0,1668723049502,16.404296875
152 | 22046.0,1668723052430,16.783203125
153 | 22192.0,1668723055266,16.884765625
154 | 22338.0,1668723058103,16.787109375
155 | 22484.0,1668723060992,16.939453125
156 | 22630.0,1668723063866,17.103515625
157 | 22776.0,1668723066826,17.25
158 | 22922.0,1668723069711,16.765625
159 | 23068.0,1668723072587,16.919921875
160 | 23214.0,1668723075419,16.849609375
161 | 23360.0,1668723078219,17.375
162 | 23506.0,1668723081105,16.9375
163 | 23652.0,1668723083923,16.921875
164 | 23798.0,1668723086729,16.94921875
165 | 23944.0,1668723089558,17.1796875
166 | 24090.0,1668723092344,16.98828125
167 | 24236.0,1668723095182,17.072265625
168 | 24382.0,1668723098064,17.19921875
169 | 24528.0,1668723100929,17.064453125
170 | 24674.0,1668723103813,17.375
171 | 24820.0,1668723106747,17.03125
172 | 24966.0,1668723109744,17.4375
173 | 25112.0,1668723112557,17.0703125
174 | 25258.0,1668723115409,17.28515625
175 | 25404.0,1668723118248,17.1953125
176 | 25550.0,1668723121087,16.943359375
177 | 25696.0,1668723123940,16.875
178 | 25842.0,1668723126839,17.06640625
179 | 25988.0,1668723129723,17.3515625
180 | 26134.0,1668723132585,17.177734375
181 | 26280.0,1668723135450,16.86328125
182 | 26426.0,1668723138352,17.1328125
183 | 26572.0,1668723141229,17.509765625
184 | 26718.0,1668723144063,17.349609375
185 | 26864.0,1668723146850,17.587890625
186 | 27010.0,1668723149608,17.283203125
187 | 27156.0,1668723152440,17.228515625
188 | 27302.0,1668723155309,17.798828125
189 | 27448.0,1668723158106,17.27734375
190 | 27594.0,1668723160924,17.126953125
191 | 27740.0,1668723163828,17.2265625
192 | 27886.0,1668723166770,17.990234375
193 | 28032.0,1668723169609,17.072265625
194 | 28178.0,1668723172508,17.09765625
195 | 28324.0,1668723175368,17.052734375
196 | 28470.0,1668723178253,17.23828125
197 | 28616.0,1668723181075,17.541015625
198 | 28762.0,1668723183835,17.408203125
199 | 28908.0,1668723186657,17.556640625
200 | 29054.0,1668723189500,17.318359375
201 | 29200.0,1668723192319,17.4453125
202 |
--------------------------------------------------------------------------------
/snake/data/appendix/return/return_a2c_norm.csv:
--------------------------------------------------------------------------------
1 | 3.0,1668723199785,0.1328125
2 | 146.0,1668723206026,0.453125
3 | 292.0,1668723208886,3.03125
4 | 438.0,1668723211647,3.771484375
5 | 584.0,1668723214505,4.033203125
6 | 730.0,1668723217425,4.22265625
7 | 876.0,1668723220288,4.453125
8 | 1022.0,1668723223134,4.732421875
9 | 1168.0,1668723226081,5.306640625
10 | 1314.0,1668723228938,5.54296875
11 | 1460.0,1668723231848,5.572265625
12 | 1606.0,1668723234717,5.8125
13 | 1752.0,1668723237594,6.025390625
14 | 1898.0,1668723240485,6.244140625
15 | 2044.0,1668723243281,6.216796875
16 | 2190.0,1668723246223,5.982421875
17 | 2336.0,1668723249199,6.421875
18 | 2482.0,1668723252189,6.41015625
19 | 2628.0,1668723255166,6.265625
20 | 2774.0,1668723258184,6.599609375
21 | 2920.0,1668723261077,6.408203125
22 | 3066.0,1668723264072,6.451171875
23 | 3212.0,1668723267061,6.826171875
24 | 3358.0,1668723269946,6.94921875
25 | 3504.0,1668723272867,6.423828125
26 | 3650.0,1668723275798,6.7265625
27 | 3796.0,1668723279741,6.51171875
28 | 3942.0,1668723282643,7.005859375
29 | 4088.0,1668723285567,6.7265625
30 | 4234.0,1668723288677,6.875
31 | 4380.0,1668723291552,7.146484375
32 | 4526.0,1668723294427,7.115234375
33 | 4672.0,1668723297414,7.189453125
34 | 4818.0,1668723300362,7.126953125
35 | 4964.0,1668723303314,7.7578125
36 | 5110.0,1668723306284,7.6015625
37 | 5256.0,1668723309194,7.517578125
38 | 5402.0,1668723312157,7.802734375
39 | 5548.0,1668723315124,7.798828125
40 | 5694.0,1668723318250,8.259765625
41 | 5840.0,1668723321158,8.380859375
42 | 5986.0,1668723324131,8.1875
43 | 6132.0,1668723327063,8.951171875
44 | 6278.0,1668723329843,8.6484375
45 | 6424.0,1668723332791,8.97265625
46 | 6570.0,1668723335712,9.037109375
47 | 6716.0,1668723338673,9.619140625
48 | 6862.0,1668723341533,9.62890625
49 | 7008.0,1668723344402,10.087890625
50 | 7154.0,1668723347417,10.1875
51 | 7300.0,1668723350321,10.0625
52 | 7446.0,1668723353280,10.076171875
53 | 7592.0,1668723356204,10.912109375
54 | 7738.0,1668723359128,10.685546875
55 | 7884.0,1668723362073,10.974609375
56 | 8030.0,1668723364954,11.37109375
57 | 8176.0,1668723367896,11.498046875
58 | 8322.0,1668723370858,11.626953125
59 | 8468.0,1668723373782,11.728515625
60 | 8614.0,1668723376815,11.515625
61 | 8760.0,1668723379712,12.392578125
62 | 8906.0,1668723382637,12.296875
63 | 9052.0,1668723385527,12.439453125
64 | 9198.0,1668723388476,12.74609375
65 | 9344.0,1668723391353,12.759765625
66 | 9490.0,1668723394186,12.921875
67 | 9636.0,1668723397082,12.861328125
68 | 9782.0,1668723400028,13.08984375
69 | 9928.0,1668723402958,13.16796875
70 | 10074.0,1668723405823,13.40625
71 | 10220.0,1668723408727,13.669921875
72 | 10366.0,1668723411631,13.703125
73 | 10512.0,1668723414577,13.83203125
74 | 10658.0,1668723417496,13.775390625
75 | 10804.0,1668723420512,13.69921875
76 | 10950.0,1668723423555,13.99609375
77 | 11096.0,1668723426590,14.388671875
78 | 11242.0,1668723429565,14.0234375
79 | 11388.0,1668723432592,14.25390625
80 | 11534.0,1668723435586,14.408203125
81 | 11680.0,1668723438543,14.341796875
82 | 11826.0,1668723441477,14.41015625
83 | 11972.0,1668723444481,14.509765625
84 | 12118.0,1668723447475,14.5625
85 | 12264.0,1668723450431,14.8203125
86 | 12410.0,1668723453351,14.68359375
87 | 12556.0,1668723456319,15.0078125
88 | 12702.0,1668723459301,14.8203125
89 | 12848.0,1668723462271,14.9609375
90 | 12994.0,1668723465203,14.685546875
91 | 13140.0,1668723468190,14.966796875
92 | 13286.0,1668723471099,15.4375
93 | 13432.0,1668723473967,15.154296875
94 | 13578.0,1668723476912,15.56640625
95 | 13724.0,1668723479815,14.88671875
96 | 13870.0,1668723482708,15.30078125
97 | 14016.0,1668723485617,15.419921875
98 | 14162.0,1668723488575,15.322265625
99 | 14308.0,1668723491485,15.03515625
100 | 14454.0,1668723494503,15.60546875
101 | 14600.0,1668723497485,15.576171875
102 | 14746.0,1668723500409,15.388671875
103 | 14892.0,1668723503375,15.26953125
104 | 15038.0,1668723506337,15.81640625
105 | 15184.0,1668723509229,15.255859375
106 | 15330.0,1668723512148,15.59375
107 | 15476.0,1668723515091,15.5625
108 | 15622.0,1668723518053,15.01953125
109 | 15768.0,1668723521065,15.666015625
110 | 15914.0,1668723524056,15.654296875
111 | 16060.0,1668723527090,15.591796875
112 | 16206.0,1668723529981,15.859375
113 | 16352.0,1668723532928,15.775390625
114 | 16498.0,1668723535841,15.54296875
115 | 16644.0,1668723538784,15.720703125
116 | 16790.0,1668723541790,15.82421875
117 | 16936.0,1668723544819,15.86328125
118 | 17082.0,1668723547705,15.615234375
119 | 17228.0,1668723550588,16.09375
120 | 17374.0,1668723553509,16.263671875
121 | 17520.0,1668723556527,15.994140625
122 | 17666.0,1668723559470,15.546875
123 | 17812.0,1668723562402,15.943359375
124 | 17958.0,1668723565333,15.974609375
125 | 18104.0,1668723568300,16.341796875
126 | 18250.0,1668723571237,16.20703125
127 | 18396.0,1668723574092,16.228515625
128 | 18542.0,1668723577087,15.837890625
129 | 18688.0,1668723580052,16.091796875
130 | 18834.0,1668723583048,16.27734375
131 | 18980.0,1668723586132,16.474609375
132 | 19126.0,1668723589090,16.12890625
133 | 19272.0,1668723592019,16.0703125
134 | 19418.0,1668723594981,16.474609375
135 | 19564.0,1668723597962,16.533203125
136 | 19710.0,1668723600911,16.349609375
137 | 19856.0,1668723603876,16.271484375
138 | 20002.0,1668723606852,16.41015625
139 | 20148.0,1668723609817,16.666015625
140 | 20294.0,1668723612733,15.9921875
141 | 20440.0,1668723615677,16.208984375
142 | 20586.0,1668723618627,16.38671875
143 | 20732.0,1668723621547,16.056640625
144 | 20878.0,1668723624474,16.201171875
145 | 21024.0,1668723627407,16.220703125
146 | 21170.0,1668723630334,16.4765625
147 | 21316.0,1668723633241,16.158203125
148 | 21462.0,1668723636246,16.0
149 | 21608.0,1668723639268,16.228515625
150 | 21754.0,1668723642226,16.4296875
151 | 21900.0,1668723645173,16.263671875
152 | 22046.0,1668723648177,16.443359375
153 | 22192.0,1668723651061,16.509765625
154 | 22338.0,1668723653963,16.892578125
155 | 22484.0,1668723656873,16.4453125
156 | 22630.0,1668723659903,17.046875
157 | 22776.0,1668723662906,16.40625
158 | 22922.0,1668723665798,16.34375
159 | 23068.0,1668723668717,16.65234375
160 | 23214.0,1668723671714,16.462890625
161 | 23360.0,1668723674653,16.751953125
162 | 23506.0,1668723677650,16.7421875
163 | 23652.0,1668723680603,17.119140625
164 | 23798.0,1668723683547,16.59765625
165 | 23944.0,1668723686475,16.423828125
166 | 24090.0,1668723689408,16.884765625
167 | 24236.0,1668723692326,16.9453125
168 | 24382.0,1668723695270,16.6953125
169 | 24528.0,1668723698284,17.109375
170 | 24674.0,1668723701255,16.861328125
171 | 24820.0,1668723704207,16.98828125
172 | 24966.0,1668723707193,16.962890625
173 | 25112.0,1668723710116,16.658203125
174 | 25258.0,1668723713026,16.90234375
175 | 25404.0,1668723715902,16.59765625
176 | 25550.0,1668723718935,16.9375
177 | 25696.0,1668723721917,16.916015625
178 | 25842.0,1668723724835,16.775390625
179 | 25988.0,1668723727867,16.81640625
180 | 26134.0,1668723730811,16.62890625
181 | 26280.0,1668723733942,16.947265625
182 | 26426.0,1668723736984,16.720703125
183 | 26572.0,1668723740048,16.615234375
184 | 26718.0,1668723742952,16.669921875
185 | 26864.0,1668723745880,16.7421875
186 | 27010.0,1668723748878,17.107421875
187 | 27156.0,1668723751871,16.92578125
188 | 27302.0,1668723754827,17.1171875
189 | 27448.0,1668723757777,16.955078125
190 | 27594.0,1668723760671,16.591796875
191 | 27740.0,1668723763502,16.888671875
192 | 27886.0,1668723766617,16.90625
193 | 28032.0,1668723769501,16.6328125
194 | 28178.0,1668723772364,16.623046875
195 | 28324.0,1668723775217,16.91015625
196 | 28470.0,1668723778204,17.29296875
197 | 28616.0,1668723781057,16.5859375
198 | 28762.0,1668723784024,16.6484375
199 | 28908.0,1668723786984,17.357421875
200 | 29054.0,1668723789946,16.8984375
201 | 29200.0,1668723792895,17.1953125
202 |
--------------------------------------------------------------------------------
/snake/data/appendix/return/return_bootstrap_no_outer_no_norm.csv:
--------------------------------------------------------------------------------
1 | 3.0,1668722619197,0.1328125
2 | 146.0,1668722640028,0.16796875
3 | 292.0,1668722645854,0.21875
4 | 438.0,1668722651650,1.458984375
5 | 584.0,1668722657527,3.625
6 | 730.0,1668722663141,3.990234375
7 | 876.0,1668722668885,4.234375
8 | 1022.0,1668722674617,4.57421875
9 | 1168.0,1668722680358,4.61328125
10 | 1314.0,1668722686093,4.646484375
11 | 1460.0,1668722691785,5.3125
12 | 1606.0,1668722697437,5.103515625
13 | 1752.0,1668722703235,5.53515625
14 | 1898.0,1668722708898,5.640625
15 | 2044.0,1668722714739,5.671875
16 | 2190.0,1668722720660,5.6640625
17 | 2336.0,1668722726364,6.224609375
18 | 2482.0,1668722732005,6.1640625
19 | 2628.0,1668722737803,6.025390625
20 | 2774.0,1668722743800,6.353515625
21 | 2920.0,1668722749562,6.39453125
22 | 3066.0,1668722755350,6.638671875
23 | 3212.0,1668722761124,6.541015625
24 | 3358.0,1668722766927,7.033203125
25 | 3504.0,1668722772705,6.435546875
26 | 3650.0,1668722778522,6.80859375
27 | 3796.0,1668722784271,6.65234375
28 | 3942.0,1668722790034,7.447265625
29 | 4088.0,1668722795789,7.05859375
30 | 4234.0,1668722801481,7.560546875
31 | 4380.0,1668722807333,7.66015625
32 | 4526.0,1668722813076,7.87109375
33 | 4672.0,1668722818834,8.384765625
34 | 4818.0,1668722824472,8.107421875
35 | 4964.0,1668722830204,8.70703125
36 | 5110.0,1668722836025,8.7578125
37 | 5256.0,1668722841726,8.2734375
38 | 5402.0,1668722847339,9.134765625
39 | 5548.0,1668722852929,9.25
40 | 5694.0,1668722858607,9.380859375
41 | 5840.0,1668722864225,9.373046875
42 | 5986.0,1668722869908,9.5
43 | 6132.0,1668722875683,9.91015625
44 | 6278.0,1668722881367,10.482421875
45 | 6424.0,1668722887141,10.7890625
46 | 6570.0,1668722892839,10.759765625
47 | 6716.0,1668722898701,11.0
48 | 6862.0,1668722904312,10.791015625
49 | 7008.0,1668722910087,11.10546875
50 | 7154.0,1668722915758,11.486328125
51 | 7300.0,1668722921420,11.8828125
52 | 7446.0,1668722927185,12.400390625
53 | 7592.0,1668722932867,12.267578125
54 | 7738.0,1668722938582,12.279296875
55 | 7884.0,1668722944416,12.427734375
56 | 8030.0,1668722950303,12.724609375
57 | 8176.0,1668722956098,12.482421875
58 | 8322.0,1668722961755,12.580078125
59 | 8468.0,1668722967472,13.337890625
60 | 8614.0,1668722973222,13.10546875
61 | 8760.0,1668722979002,13.607421875
62 | 8906.0,1668722984751,13.703125
63 | 9052.0,1668722990496,13.484375
64 | 9198.0,1668722996216,13.451171875
65 | 9344.0,1668723001906,14.0859375
66 | 9490.0,1668723007633,14.357421875
67 | 9636.0,1668723013453,14.505859375
68 | 9782.0,1668723019300,14.357421875
69 | 9928.0,1668723025025,14.5
70 | 10074.0,1668723030696,14.58984375
71 | 10220.0,1668723036496,14.451171875
72 | 10366.0,1668723042147,14.431640625
73 | 10512.0,1668723047937,14.759765625
74 | 10658.0,1668723053729,14.748046875
75 | 10804.0,1668723059509,14.59375
76 | 10950.0,1668723065218,14.919921875
77 | 11096.0,1668723071050,14.951171875
78 | 11242.0,1668723076833,15.271484375
79 | 11388.0,1668723082607,15.20703125
80 | 11534.0,1668723088315,14.990234375
81 | 11680.0,1668723094012,15.1640625
82 | 11826.0,1668723099722,15.47265625
83 | 11972.0,1668723105532,15.541015625
84 | 12118.0,1668723111235,15.7265625
85 | 12264.0,1668723116936,15.697265625
86 | 12410.0,1668723122630,15.26953125
87 | 12556.0,1668723128438,15.72265625
88 | 12702.0,1668723134292,16.083984375
89 | 12848.0,1668723140026,15.923828125
90 | 12994.0,1668723145687,16.046875
91 | 13140.0,1668723151384,16.3203125
92 | 13286.0,1668723157110,16.23046875
93 | 13432.0,1668723162716,16.359375
94 | 13578.0,1668723168363,16.75
95 | 13724.0,1668723174022,16.55078125
96 | 13870.0,1668723179751,16.431640625
97 | 14016.0,1668723185409,16.75
98 | 14162.0,1668723191036,16.263671875
99 | 14308.0,1668723196819,16.861328125
100 | 14454.0,1668723202442,16.787109375
101 | 14600.0,1668723208221,17.412109375
102 | 14746.0,1668723213865,17.044921875
103 | 14892.0,1668723219625,17.08203125
104 | 15038.0,1668723225379,17.294921875
105 | 15184.0,1668723231053,16.6171875
106 | 15330.0,1668723236781,17.359375
107 | 15476.0,1668723242545,17.640625
108 | 15622.0,1668723248259,16.830078125
109 | 15768.0,1668723254016,17.14453125
110 | 15914.0,1668723259906,17.533203125
111 | 16060.0,1668723265740,17.5703125
112 | 16206.0,1668723271470,17.556640625
113 | 16352.0,1668723277358,17.6015625
114 | 16498.0,1668723283080,17.60546875
115 | 16644.0,1668723288988,17.693359375
116 | 16790.0,1668723294686,17.771484375
117 | 16936.0,1668723300475,17.123046875
118 | 17082.0,1668723306407,17.744140625
119 | 17228.0,1668723312107,18.01953125
120 | 17374.0,1668723317927,17.8984375
121 | 17520.0,1668723323617,17.455078125
122 | 17666.0,1668723329307,17.86328125
123 | 17812.0,1668723334958,17.154296875
124 | 17958.0,1668723340804,18.560546875
125 | 18104.0,1668723346540,18.130859375
126 | 18250.0,1668723352150,18.03515625
127 | 18396.0,1668723357881,18.5
128 | 18542.0,1668723363667,17.99609375
129 | 18688.0,1668723369414,17.68359375
130 | 18834.0,1668723375124,18.001953125
131 | 18980.0,1668723380993,18.318359375
132 | 19126.0,1668723386615,18.185546875
133 | 19272.0,1668723392258,18.5
134 | 19418.0,1668723397953,18.427734375
135 | 19564.0,1668723403594,18.51171875
136 | 19710.0,1668723409217,18.392578125
137 | 19856.0,1668723415006,18.486328125
138 | 20002.0,1668723420733,17.962890625
139 | 20148.0,1668723426617,18.3828125
140 | 20294.0,1668723432389,18.623046875
141 | 20440.0,1668723438213,18.5859375
142 | 20586.0,1668723443997,18.44921875
143 | 20732.0,1668723449726,18.4609375
144 | 20878.0,1668723455467,18.51953125
145 | 21024.0,1668723461158,18.59765625
146 | 21170.0,1668723466870,19.298828125
147 | 21316.0,1668723472513,18.619140625
148 | 21462.0,1668723478178,19.0703125
149 | 21608.0,1668723483915,18.5234375
150 | 21754.0,1668723489649,19.423828125
151 | 21900.0,1668723495350,18.9375
152 | 22046.0,1668723501107,19.021484375
153 | 22192.0,1668723506790,18.94140625
154 | 22338.0,1668723512511,19.171875
155 | 22484.0,1668723518326,19.357421875
156 | 22630.0,1668723524410,19.34765625
157 | 22776.0,1668723530157,19.197265625
158 | 22922.0,1668723535915,18.9453125
159 | 23068.0,1668723541714,19.548828125
160 | 23214.0,1668723547515,20.001953125
161 | 23360.0,1668723553261,19.470703125
162 | 23506.0,1668723559000,19.6796875
163 | 23652.0,1668723564710,19.603515625
164 | 23798.0,1668723570466,19.767578125
165 | 23944.0,1668723576104,19.845703125
166 | 24090.0,1668723581868,20.306640625
167 | 24236.0,1668723587643,20.60546875
168 | 24382.0,1668723593429,21.0703125
169 | 24528.0,1668723599213,20.18359375
170 | 24674.0,1668723604930,21.396484375
171 | 24820.0,1668723610673,21.44921875
172 | 24966.0,1668723616481,21.0859375
173 | 25112.0,1668723622140,20.41796875
174 | 25258.0,1668723627902,21.591796875
175 | 25404.0,1668723633692,21.63671875
176 | 25550.0,1668723639407,22.3828125
177 | 25696.0,1668723645060,21.3046875
178 | 25842.0,1668723650840,21.978515625
179 | 25988.0,1668723656534,23.203125
180 | 26134.0,1668723662184,22.7109375
181 | 26280.0,1668723667968,23.30078125
182 | 26426.0,1668723673616,23.57421875
183 | 26572.0,1668723679336,23.759765625
184 | 26718.0,1668723685088,24.6953125
185 | 26864.0,1668723690820,25.328125
186 | 27010.0,1668723696535,25.3671875
187 | 27156.0,1668723702282,26.216796875
188 | 27302.0,1668723708087,26.984375
189 | 27448.0,1668723713670,28.15625
190 | 27594.0,1668723719435,27.2421875
191 | 27740.0,1668723725158,28.361328125
192 | 27886.0,1668723731004,29.01953125
193 | 28032.0,1668723736800,28.388671875
194 | 28178.0,1668723742499,29.0625
195 | 28324.0,1668723748207,29.712890625
196 | 28470.0,1668723753972,30.822265625
197 | 28616.0,1668723759729,30.857421875
198 | 28762.0,1668723765471,30.544921875
199 | 28908.0,1668723771156,31.25
200 | 29054.0,1668723776872,31.35546875
201 | 29200.0,1668723782544,31.185546875
202 |
--------------------------------------------------------------------------------
/snake/data/appendix/return/return_bootstrap_no_outer_norm.csv:
--------------------------------------------------------------------------------
1 | 3.0,1668726176304,0.1328125
2 | 146.0,1668726195787,0.482421875
3 | 292.0,1668726201637,3.060546875
4 | 438.0,1668726207614,3.86328125
5 | 584.0,1668726213336,4.142578125
6 | 730.0,1668726219094,4.36328125
7 | 876.0,1668726224913,5.1015625
8 | 1022.0,1668726230744,4.978515625
9 | 1168.0,1668726236583,5.68359375
10 | 1314.0,1668726242399,4.96875
11 | 1460.0,1668726248214,5.515625
12 | 1606.0,1668726254036,6.123046875
13 | 1752.0,1668726259933,6.259765625
14 | 1898.0,1668726265912,6.41796875
15 | 2044.0,1668726271703,6.2578125
16 | 2190.0,1668726277593,6.05859375
17 | 2336.0,1668726283479,6.634765625
18 | 2482.0,1668726289219,6.552734375
19 | 2628.0,1668726294913,6.28125
20 | 2774.0,1668726300709,6.505859375
21 | 2920.0,1668726306543,6.306640625
22 | 3066.0,1668726312335,6.5078125
23 | 3212.0,1668726318156,6.783203125
24 | 3358.0,1668726324064,6.83984375
25 | 3504.0,1668726329855,6.373046875
26 | 3650.0,1668726335682,6.615234375
27 | 3796.0,1668726341555,6.44140625
28 | 3942.0,1668726347521,6.853515625
29 | 4088.0,1668726353462,6.611328125
30 | 4234.0,1668726359238,6.69921875
31 | 4380.0,1668726365187,6.83203125
32 | 4526.0,1668726371142,6.853515625
33 | 4672.0,1668726377172,7.193359375
34 | 4818.0,1668726383132,6.818359375
35 | 4964.0,1668726388941,7.1171875
36 | 5110.0,1668726394825,6.78125
37 | 5256.0,1668726400580,7.18359375
38 | 5402.0,1668726406397,7.25
39 | 5548.0,1668726412196,7.076171875
40 | 5694.0,1668726418025,7.18359375
41 | 5840.0,1668726423794,7.138671875
42 | 5986.0,1668726429613,7.2109375
43 | 6132.0,1668726435531,7.40625
44 | 6278.0,1668726441367,7.447265625
45 | 6424.0,1668726447299,7.509765625
46 | 6570.0,1668726453290,7.52734375
47 | 6716.0,1668726459307,7.509765625
48 | 6862.0,1668726465173,7.306640625
49 | 7008.0,1668726471197,8.1796875
50 | 7154.0,1668726477248,7.4140625
51 | 7300.0,1668726483092,7.791015625
52 | 7446.0,1668726488945,8.02734375
53 | 7592.0,1668726494638,8.18359375
54 | 7738.0,1668726500284,8.3203125
55 | 7884.0,1668726506030,8.447265625
56 | 8030.0,1668726511687,8.5703125
57 | 8176.0,1668726517465,8.595703125
58 | 8322.0,1668726523288,8.896484375
59 | 8468.0,1668726529121,8.84765625
60 | 8614.0,1668726534863,9.078125
61 | 8760.0,1668726540646,9.056640625
62 | 8906.0,1668726546472,9.525390625
63 | 9052.0,1668726552178,9.419921875
64 | 9198.0,1668726557821,9.375
65 | 9344.0,1668726563522,9.64453125
66 | 9490.0,1668726569288,9.91015625
67 | 9636.0,1668726575047,10.111328125
68 | 9782.0,1668726580800,9.943359375
69 | 9928.0,1668726586523,10.10546875
70 | 10074.0,1668726592249,10.072265625
71 | 10220.0,1668726597902,10.650390625
72 | 10366.0,1668726603786,10.28125
73 | 10512.0,1668726609549,10.529296875
74 | 10658.0,1668726615355,10.76171875
75 | 10804.0,1668726621214,10.798828125
76 | 10950.0,1668726626992,11.3828125
77 | 11096.0,1668726632753,10.67578125
78 | 11242.0,1668726638599,11.576171875
79 | 11388.0,1668726644345,11.404296875
80 | 11534.0,1668726650056,11.255859375
81 | 11680.0,1668726655842,11.716796875
82 | 11826.0,1668726661702,11.58984375
83 | 11972.0,1668726667535,11.759765625
84 | 12118.0,1668726673359,12.0
85 | 12264.0,1668726679202,11.986328125
86 | 12410.0,1668726684918,11.677734375
87 | 12556.0,1668726690644,12.447265625
88 | 12702.0,1668726696372,12.2109375
89 | 12848.0,1668726702172,12.3203125
90 | 12994.0,1668726708026,12.375
91 | 13140.0,1668726713764,12.568359375
92 | 13286.0,1668726719548,12.490234375
93 | 13432.0,1668726725285,12.505859375
94 | 13578.0,1668726731045,13.189453125
95 | 13724.0,1668726736955,12.740234375
96 | 13870.0,1668726742690,12.779296875
97 | 14016.0,1668726748456,13.095703125
98 | 14162.0,1668726754285,12.990234375
99 | 14308.0,1668726760149,13.087890625
100 | 14454.0,1668726765949,13.177734375
101 | 14600.0,1668726771767,12.640625
102 | 14746.0,1668726777571,13.255859375
103 | 14892.0,1668726783327,13.28515625
104 | 15038.0,1668726789094,13.068359375
105 | 15184.0,1668726795074,13.361328125
106 | 15330.0,1668726800939,13.34375
107 | 15476.0,1668726806671,12.939453125
108 | 15622.0,1668726812432,12.92578125
109 | 15768.0,1668726818166,12.94140625
110 | 15914.0,1668726823930,12.779296875
111 | 16060.0,1668726829650,13.359375
112 | 16206.0,1668726835530,12.884765625
113 | 16352.0,1668726841307,13.21875
114 | 16498.0,1668726847147,12.97265625
115 | 16644.0,1668726852981,13.734375
116 | 16790.0,1668726858855,13.306640625
117 | 16936.0,1668726864543,13.080078125
118 | 17082.0,1668726870264,13.173828125
119 | 17228.0,1668726876001,13.330078125
120 | 17374.0,1668726881723,13.421875
121 | 17520.0,1668726887686,13.35546875
122 | 17666.0,1668726893504,13.3515625
123 | 17812.0,1668726899332,13.2578125
124 | 17958.0,1668726905220,13.3359375
125 | 18104.0,1668726911033,13.640625
126 | 18250.0,1668726916856,13.44140625
127 | 18396.0,1668726922579,13.447265625
128 | 18542.0,1668726928392,13.01171875
129 | 18688.0,1668726934203,13.388671875
130 | 18834.0,1668726939947,13.01953125
131 | 18980.0,1668726945715,13.322265625
132 | 19126.0,1668726951473,13.37109375
133 | 19272.0,1668726957272,13.46875
134 | 19418.0,1668726963106,13.2734375
135 | 19564.0,1668726968919,12.9375
136 | 19710.0,1668726974705,13.44921875
137 | 19856.0,1668726980479,13.197265625
138 | 20002.0,1668726986303,13.205078125
139 | 20148.0,1668726992186,13.12109375
140 | 20294.0,1668726998108,12.708984375
141 | 20440.0,1668727003907,13.298828125
142 | 20586.0,1668727009881,12.134765625
143 | 20732.0,1668727015758,12.857421875
144 | 20878.0,1668727021683,12.44921875
145 | 21024.0,1668727027577,12.931640625
146 | 21170.0,1668727033537,13.15234375
147 | 21316.0,1668727039497,12.505859375
148 | 21462.0,1668727045431,12.1484375
149 | 21608.0,1668727051179,13.0078125
150 | 21754.0,1668727057083,12.876953125
151 | 21900.0,1668727062937,12.806640625
152 | 22046.0,1668727068822,12.92578125
153 | 22192.0,1668727074650,13.0390625
154 | 22338.0,1668727080481,12.98828125
155 | 22484.0,1668727086427,13.26171875
156 | 22630.0,1668727092241,13.197265625
157 | 22776.0,1668727098143,12.86328125
158 | 22922.0,1668727103974,12.908203125
159 | 23068.0,1668727109757,13.291015625
160 | 23214.0,1668727115639,12.90625
161 | 23360.0,1668727121612,13.12890625
162 | 23506.0,1668727127593,13.26171875
163 | 23652.0,1668727133444,13.259765625
164 | 23798.0,1668727139309,13.515625
165 | 23944.0,1668727145383,13.4921875
166 | 24090.0,1668727151326,13.51171875
167 | 24236.0,1668727157297,13.52734375
168 | 24382.0,1668727163195,13.865234375
169 | 24528.0,1668727169206,13.69140625
170 | 24674.0,1668727175068,13.32421875
171 | 24820.0,1668727180935,13.33203125
172 | 24966.0,1668727186800,13.4140625
173 | 25112.0,1668727192658,13.23828125
174 | 25258.0,1668727198503,13.634765625
175 | 25404.0,1668727204475,12.8984375
176 | 25550.0,1668727210328,13.3125
177 | 25696.0,1668727216147,13.70703125
178 | 25842.0,1668727221931,13.796875
179 | 25988.0,1668727227675,13.443359375
180 | 26134.0,1668727233400,13.623046875
181 | 26280.0,1668727239274,13.486328125
182 | 26426.0,1668727245186,13.78125
183 | 26572.0,1668727250984,13.88671875
184 | 26718.0,1668727256776,13.939453125
185 | 26864.0,1668727262592,13.734375
186 | 27010.0,1668727268466,13.75390625
187 | 27156.0,1668727274248,13.5546875
188 | 27302.0,1668727280098,13.26171875
189 | 27448.0,1668727285757,13.546875
190 | 27594.0,1668727291492,13.529296875
191 | 27740.0,1668727297352,13.740234375
192 | 27886.0,1668727303050,13.951171875
193 | 28032.0,1668727308907,13.984375
194 | 28178.0,1668727314686,13.212890625
195 | 28324.0,1668727320550,13.724609375
196 | 28470.0,1668727326383,13.73828125
197 | 28616.0,1668727332393,13.666015625
198 | 28762.0,1668727338229,13.390625
199 | 28908.0,1668727344047,14.0546875
200 | 29054.0,1668727349833,13.919921875
201 | 29200.0,1668727355880,14.05078125
202 |
--------------------------------------------------------------------------------
/snake/data/bootstrap_outer_critic_return/MET-2195__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667994813080,0.1484375
2 | 146.0,1667994833793,0.126953125
3 | 292.0,1667994839668,0.212890625
4 | 438.0,1667994845709,1.52734375
5 | 584.0,1667994851582,3.611328125
6 | 730.0,1667994857340,3.849609375
7 | 876.0,1667994863289,4.201171875
8 | 1022.0,1667994869143,4.740234375
9 | 1168.0,1667994875019,5.052734375
10 | 1314.0,1667994880867,5.59375
11 | 1460.0,1667994886513,5.349609375
12 | 1606.0,1667994892262,5.4453125
13 | 1752.0,1667994898097,5.880859375
14 | 1898.0,1667994904052,5.958984375
15 | 2044.0,1667994909875,6.05859375
16 | 2190.0,1667994915733,6.005859375
17 | 2336.0,1667994921625,6.037109375
18 | 2482.0,1667994927434,6.42578125
19 | 2628.0,1667994933147,6.380859375
20 | 2774.0,1667994938926,6.396484375
21 | 2920.0,1667994944691,6.423828125
22 | 3066.0,1667994950507,6.54296875
23 | 3212.0,1667994956285,6.501953125
24 | 3358.0,1667994962115,6.21484375
25 | 3504.0,1667994967933,6.328125
26 | 3650.0,1667994973731,6.3828125
27 | 3796.0,1667994979591,6.353515625
28 | 3942.0,1667994985608,6.58203125
29 | 4088.0,1667994991503,6.4921875
30 | 4234.0,1667994997258,6.4375
31 | 4380.0,1667995003073,6.998046875
32 | 4526.0,1667995008857,6.40625
33 | 4672.0,1667995014681,7.193359375
34 | 4818.0,1667995020445,7.173828125
35 | 4964.0,1667995026258,7.15234375
36 | 5110.0,1667995032111,7.119140625
37 | 5256.0,1667995037843,7.732421875
38 | 5402.0,1667995043840,7.37109375
39 | 5548.0,1667995049690,7.76953125
40 | 5694.0,1667995055531,7.44140625
41 | 5840.0,1667995061415,7.83984375
42 | 5986.0,1667995067242,7.98046875
43 | 6132.0,1667995072900,8.375
44 | 6278.0,1667995078562,8.810546875
45 | 6424.0,1667995084444,8.556640625
46 | 6570.0,1667995090169,8.87109375
47 | 6716.0,1667995095998,9.251953125
48 | 6862.0,1667995101806,9.734375
49 | 7008.0,1667995107636,9.765625
50 | 7154.0,1667995113440,9.94921875
51 | 7300.0,1667995119170,10.49609375
52 | 7446.0,1667995125049,10.69140625
53 | 7592.0,1667995130983,10.990234375
54 | 7738.0,1667995136954,11.427734375
55 | 7884.0,1667995142796,11.916015625
56 | 8030.0,1667995148612,11.9296875
57 | 8176.0,1667995154334,12.25
58 | 8322.0,1667995160132,12.615234375
59 | 8468.0,1667995165970,12.7578125
60 | 8614.0,1667995171794,12.8828125
61 | 8760.0,1667995177629,13.70703125
62 | 8906.0,1667995183430,13.87890625
63 | 9052.0,1667995189389,14.150390625
64 | 9198.0,1667995195211,14.55859375
65 | 9344.0,1667995201166,15.099609375
66 | 9490.0,1667995207160,14.92578125
67 | 9636.0,1667995213240,15.380859375
68 | 9782.0,1667995219066,15.79296875
69 | 9928.0,1667995225622,15.953125
70 | 10074.0,1667995231599,16.103515625
71 | 10220.0,1667995237550,16.302734375
72 | 10366.0,1667995243444,16.3203125
73 | 10512.0,1667995249140,16.234375
74 | 10658.0,1667995255036,16.16015625
75 | 10804.0,1667995260930,16.265625
76 | 10950.0,1667995266944,16.134765625
77 | 11096.0,1667995272956,16.94921875
78 | 11242.0,1667995278848,17.03515625
79 | 11388.0,1667995284901,17.591796875
80 | 11534.0,1667995290930,16.916015625
81 | 11680.0,1667995296926,17.4453125
82 | 11826.0,1667995302850,16.880859375
83 | 11972.0,1667995308762,17.30078125
84 | 12118.0,1667995314730,17.732421875
85 | 12264.0,1667995320661,17.912109375
86 | 12410.0,1667995326638,17.587890625
87 | 12556.0,1667995332542,18.1171875
88 | 12702.0,1667995338459,17.90234375
89 | 12848.0,1667995344430,17.927734375
90 | 12994.0,1667995350351,17.6328125
91 | 13140.0,1667995356317,17.9453125
92 | 13286.0,1667995362287,18.291015625
93 | 13432.0,1667995368168,17.65234375
94 | 13578.0,1667995374071,18.21875
95 | 13724.0,1667995380123,18.28125
96 | 13870.0,1667995386023,18.1875
97 | 14016.0,1667995391939,18.130859375
98 | 14162.0,1667995397974,18.974609375
99 | 14308.0,1667995403885,18.6171875
100 | 14454.0,1667995409772,18.75
101 | 14600.0,1667995415753,18.625
102 | 14746.0,1667995421752,18.876953125
103 | 14892.0,1667995427663,19.267578125
104 | 15038.0,1667995433605,19.34765625
105 | 15184.0,1667995439535,19.41796875
106 | 15330.0,1667995445407,19.83984375
107 | 15476.0,1667995451416,19.79296875
108 | 15622.0,1667995457167,20.18359375
109 | 15768.0,1667995462998,19.564453125
110 | 15914.0,1667995468767,20.078125
111 | 16060.0,1667995474635,20.568359375
112 | 16206.0,1667995480424,20.16015625
113 | 16352.0,1667995486275,20.755859375
114 | 16498.0,1667995492313,20.703125
115 | 16644.0,1667995498120,20.38671875
116 | 16790.0,1667995504181,21.5703125
117 | 16936.0,1667995510054,20.720703125
118 | 17082.0,1667995515917,21.12890625
119 | 17228.0,1667995521742,21.681640625
120 | 17374.0,1667995527611,22.0
121 | 17520.0,1667995533541,22.8828125
122 | 17666.0,1667995539381,22.4921875
123 | 17812.0,1667995545272,22.68359375
124 | 17958.0,1667995551131,23.044921875
125 | 18104.0,1667995556973,23.62890625
126 | 18250.0,1667995562927,24.1796875
127 | 18396.0,1667995568679,24.24609375
128 | 18542.0,1667995574573,24.10546875
129 | 18688.0,1667995580510,24.322265625
130 | 18834.0,1667995586366,25.462890625
131 | 18980.0,1667995592325,25.7265625
132 | 19126.0,1667995598196,27.662109375
133 | 19272.0,1667995604231,27.09765625
134 | 19418.0,1667995610194,28.6640625
135 | 19564.0,1667995616033,29.3046875
136 | 19710.0,1667995621869,30.00390625
137 | 19856.0,1667995627806,29.5546875
138 | 20002.0,1667995633722,29.63671875
139 | 20148.0,1667995639557,30.845703125
140 | 20294.0,1667995645636,30.8671875
141 | 20440.0,1667995651476,30.66015625
142 | 20586.0,1667995657363,31.708984375
143 | 20732.0,1667995663219,32.171875
144 | 20878.0,1667995668993,32.44140625
145 | 21024.0,1667995674930,31.77734375
146 | 21170.0,1667995680800,33.005859375
147 | 21316.0,1667995686902,32.884765625
148 | 21462.0,1667995692776,33.498046875
149 | 21608.0,1667995698724,33.087890625
150 | 21754.0,1667995704785,33.486328125
151 | 21900.0,1667995710755,33.15234375
152 | 22046.0,1667995716719,33.376953125
153 | 22192.0,1667995722735,33.4609375
154 | 22338.0,1667995728700,33.8828125
155 | 22484.0,1667995734668,33.349609375
156 | 22630.0,1667995740594,33.75390625
157 | 22776.0,1667995746579,33.111328125
158 | 22922.0,1667995752950,32.490234375
159 | 23068.0,1667995759828,33.837890625
160 | 23214.0,1667995766824,33.9296875
161 | 23360.0,1667995773799,33.63671875
162 | 23506.0,1667995780744,33.52734375
163 | 23652.0,1667995787983,33.5546875
164 | 23798.0,1667995794921,33.697265625
165 | 23944.0,1667995801292,34.21875
166 | 24090.0,1667995807280,33.951171875
167 | 24236.0,1667995813292,33.677734375
168 | 24382.0,1667995819288,33.859375
169 | 24528.0,1667995825341,33.853515625
170 | 24674.0,1667995831300,33.537109375
171 | 24820.0,1667995837385,34.171875
172 | 24966.0,1667995843389,34.580078125
173 | 25112.0,1667995849396,34.0234375
174 | 25258.0,1667995855246,33.9765625
175 | 25404.0,1667995861101,33.689453125
176 | 25550.0,1667995867022,33.69921875
177 | 25696.0,1667995872920,34.1171875
178 | 25842.0,1667995878835,33.640625
179 | 25988.0,1667995884899,32.494140625
180 | 26134.0,1667995890730,33.984375
181 | 26280.0,1667995896584,34.189453125
182 | 26426.0,1667995902459,33.87109375
183 | 26572.0,1667995908499,33.392578125
184 | 26718.0,1667995914492,33.392578125
185 | 26864.0,1667995920445,33.568359375
186 | 27010.0,1667995926292,34.341796875
187 | 27156.0,1667995932242,33.759765625
188 | 27302.0,1667995938174,34.267578125
189 | 27448.0,1667995944157,34.447265625
190 | 27594.0,1667995949918,33.857421875
191 | 27740.0,1667995955812,33.857421875
192 | 27886.0,1667995961815,34.212890625
193 | 28032.0,1667995967779,33.83203125
194 | 28178.0,1667995973678,33.677734375
195 | 28324.0,1667995979587,34.337890625
196 | 28470.0,1667995985513,33.505859375
197 | 28616.0,1667995991466,34.1875
198 | 28762.0,1667995997360,33.986328125
199 | 28908.0,1667996003288,34.248046875
200 | 29054.0,1667996009218,34.5
201 | 29200.0,1667996015216,34.625
202 |
--------------------------------------------------------------------------------
/snake/data/bootstrap_return/MET-2166__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667988472883,0.158203125
2 | 146.0,1667988491997,0.150390625
3 | 292.0,1667988497527,0.228515625
4 | 438.0,1667988503102,1.978515625
5 | 584.0,1667988508577,3.443359375
6 | 730.0,1667988514213,3.65234375
7 | 876.0,1667988519766,4.08203125
8 | 1022.0,1667988525253,4.197265625
9 | 1168.0,1667988530896,4.51953125
10 | 1314.0,1667988536374,4.685546875
11 | 1460.0,1667988542058,4.74609375
12 | 1606.0,1667988547740,5.15625
13 | 1752.0,1667988553696,4.541015625
14 | 1898.0,1667988559209,4.966796875
15 | 2044.0,1667988564737,4.95703125
16 | 2190.0,1667988570263,5.501953125
17 | 2336.0,1667988575738,5.712890625
18 | 2482.0,1667988581225,5.7421875
19 | 2628.0,1667988586661,5.6640625
20 | 2774.0,1667988592131,6.125
21 | 2920.0,1667988597730,6.025390625
22 | 3066.0,1667988603340,6.017578125
23 | 3212.0,1667988608949,6.205078125
24 | 3358.0,1667988614494,6.4375
25 | 3504.0,1667988619953,6.5625
26 | 3650.0,1667988625421,6.8125
27 | 3796.0,1667988630924,6.568359375
28 | 3942.0,1667988636336,6.466796875
29 | 4088.0,1667988641750,6.498046875
30 | 4234.0,1667988647158,6.625
31 | 4380.0,1667988652654,6.708984375
32 | 4526.0,1667988658121,7.296875
33 | 4672.0,1667988663742,7.40234375
34 | 4818.0,1667988669165,7.515625
35 | 4964.0,1667988674675,7.49609375
36 | 5110.0,1667988680164,7.59375
37 | 5256.0,1667988685636,7.58203125
38 | 5402.0,1667988691076,7.779296875
39 | 5548.0,1667988696603,8.201171875
40 | 5694.0,1667988702133,8.20703125
41 | 5840.0,1667988707728,8.18359375
42 | 5986.0,1667988713172,8.697265625
43 | 6132.0,1667988718647,8.4609375
44 | 6278.0,1667988724117,9.16015625
45 | 6424.0,1667988729721,8.8828125
46 | 6570.0,1667988735243,9.30078125
47 | 6716.0,1667988740826,9.080078125
48 | 6862.0,1667988746237,9.6171875
49 | 7008.0,1667988751634,9.609375
50 | 7154.0,1667988757093,9.990234375
51 | 7300.0,1667988762770,9.978515625
52 | 7446.0,1667988768278,10.06640625
53 | 7592.0,1667988773813,10.177734375
54 | 7738.0,1667988779365,9.74609375
55 | 7884.0,1667988784885,9.9765625
56 | 8030.0,1667988790435,9.99609375
57 | 8176.0,1667988796066,9.94140625
58 | 8322.0,1667988801546,10.048828125
59 | 8468.0,1667988807276,9.818359375
60 | 8614.0,1667988812786,9.896484375
61 | 8760.0,1667988818223,10.158203125
62 | 8906.0,1667988823737,10.36328125
63 | 9052.0,1667988829212,9.873046875
64 | 9198.0,1667988834723,9.87890625
65 | 9344.0,1667988840216,10.15625
66 | 9490.0,1667988845697,9.58203125
67 | 9636.0,1667988851268,9.853515625
68 | 9782.0,1667988856677,9.7890625
69 | 9928.0,1667988862281,9.94140625
70 | 10074.0,1667988867854,10.06640625
71 | 10220.0,1667988873353,9.93359375
72 | 10366.0,1667988878806,9.9296875
73 | 10512.0,1667988884313,9.3515625
74 | 10658.0,1667988889787,9.3671875
75 | 10804.0,1667988895266,9.541015625
76 | 10950.0,1667988900847,9.173828125
77 | 11096.0,1667988906491,9.400390625
78 | 11242.0,1667988911988,9.33984375
79 | 11388.0,1667988917483,9.18359375
80 | 11534.0,1667988922955,9.1953125
81 | 11680.0,1667988928562,9.0
82 | 11826.0,1667988934090,9.076171875
83 | 11972.0,1667988939646,8.666015625
84 | 12118.0,1667988945037,9.283203125
85 | 12264.0,1667988950624,8.849609375
86 | 12410.0,1667988956106,8.712890625
87 | 12556.0,1667988961631,8.818359375
88 | 12702.0,1667988967223,9.005859375
89 | 12848.0,1667988973075,8.93359375
90 | 12994.0,1667988978559,8.890625
91 | 13140.0,1667988984243,9.1015625
92 | 13286.0,1667988989794,8.763671875
93 | 13432.0,1667988995424,8.68359375
94 | 13578.0,1667989001056,8.765625
95 | 13724.0,1667989006550,8.580078125
96 | 13870.0,1667989012109,8.181640625
97 | 14016.0,1667989017569,7.755859375
98 | 14162.0,1667989023090,8.154296875
99 | 14308.0,1667989028609,7.970703125
100 | 14454.0,1667989034207,7.896484375
101 | 14600.0,1667989039737,7.865234375
102 | 14746.0,1667989045202,7.986328125
103 | 14892.0,1667989050731,7.5234375
104 | 15038.0,1667989056168,7.685546875
105 | 15184.0,1667989061724,7.140625
106 | 15330.0,1667989067337,7.57421875
107 | 15476.0,1667989072853,7.130859375
108 | 15622.0,1667989078338,7.033203125
109 | 15768.0,1667989084104,7.19140625
110 | 15914.0,1667989089604,7.482421875
111 | 16060.0,1667989095122,6.5703125
112 | 16206.0,1667989100730,7.103515625
113 | 16352.0,1667989106342,7.1015625
114 | 16498.0,1667989111813,6.470703125
115 | 16644.0,1667989117288,6.80078125
116 | 16790.0,1667989122829,6.41015625
117 | 16936.0,1667989128349,6.74609375
118 | 17082.0,1667989133816,6.134765625
119 | 17228.0,1667989139393,6.515625
120 | 17374.0,1667989144943,6.34765625
121 | 17520.0,1667989150451,6.1328125
122 | 17666.0,1667989155881,5.939453125
123 | 17812.0,1667989161596,5.95703125
124 | 17958.0,1667989167227,5.798828125
125 | 18104.0,1667989172709,6.5390625
126 | 18250.0,1667989178376,6.373046875
127 | 18396.0,1667989183921,6.25390625
128 | 18542.0,1667989189437,5.814453125
129 | 18688.0,1667989195024,5.7890625
130 | 18834.0,1667989200615,5.76953125
131 | 18980.0,1667989206254,6.01171875
132 | 19126.0,1667989211877,5.70703125
133 | 19272.0,1667989217439,5.865234375
134 | 19418.0,1667989222961,6.25390625
135 | 19564.0,1667989228552,5.8125
136 | 19710.0,1667989233963,5.8828125
137 | 19856.0,1667989239593,6.177734375
138 | 20002.0,1667989245103,6.26953125
139 | 20148.0,1667989250840,5.724609375
140 | 20294.0,1667989256420,5.44921875
141 | 20440.0,1667989262128,6.06640625
142 | 20586.0,1667989267722,6.20703125
143 | 20732.0,1667989273464,6.068359375
144 | 20878.0,1667989279148,6.0
145 | 21024.0,1667989284818,5.76953125
146 | 21170.0,1667989290358,5.859375
147 | 21316.0,1667989296065,5.580078125
148 | 21462.0,1667989301627,5.5859375
149 | 21608.0,1667989307144,5.900390625
150 | 21754.0,1667989312782,5.625
151 | 21900.0,1667989318230,6.21484375
152 | 22046.0,1667989323876,6.861328125
153 | 22192.0,1667989329537,6.830078125
154 | 22338.0,1667989335045,7.0390625
155 | 22484.0,1667989340555,7.083984375
156 | 22630.0,1667989346356,7.009765625
157 | 22776.0,1667989351924,6.953125
158 | 22922.0,1667989357409,7.044921875
159 | 23068.0,1667989363104,6.93359375
160 | 23214.0,1667989368709,6.919921875
161 | 23360.0,1667989374276,6.32421875
162 | 23506.0,1667989379790,6.232421875
163 | 23652.0,1667989385458,6.30078125
164 | 23798.0,1667989391103,6.37890625
165 | 23944.0,1667989396680,6.1640625
166 | 24090.0,1667989402418,6.001953125
167 | 24236.0,1667989407991,5.697265625
168 | 24382.0,1667989413623,5.623046875
169 | 24528.0,1667989419143,5.5859375
170 | 24674.0,1667989424840,5.666015625
171 | 24820.0,1667989430399,5.22265625
172 | 24966.0,1667989435963,5.513671875
173 | 25112.0,1667989441769,5.283203125
174 | 25258.0,1667989447356,5.384765625
175 | 25404.0,1667989452914,5.4921875
176 | 25550.0,1667989458435,5.556640625
177 | 25696.0,1667989463979,6.072265625
178 | 25842.0,1667989469541,5.740234375
179 | 25988.0,1667989474971,6.935546875
180 | 26134.0,1667989480532,5.990234375
181 | 26280.0,1667989486146,5.98046875
182 | 26426.0,1667989491672,6.02734375
183 | 26572.0,1667989497182,5.798828125
184 | 26718.0,1667989502845,5.71875
185 | 26864.0,1667989508428,5.279296875
186 | 27010.0,1667989514035,5.85546875
187 | 27156.0,1667989519472,5.970703125
188 | 27302.0,1667989525010,5.8671875
189 | 27448.0,1667989530454,5.63671875
190 | 27594.0,1667989535756,6.10546875
191 | 27740.0,1667989541487,6.11328125
192 | 27886.0,1667989546947,5.62890625
193 | 28032.0,1667989552507,5.939453125
194 | 28178.0,1667989558083,6.1796875
195 | 28324.0,1667989563702,5.787109375
196 | 28470.0,1667989569260,5.765625
197 | 28616.0,1667989574740,5.84765625
198 | 28762.0,1667989580270,5.916015625
199 | 28908.0,1667989585951,5.83984375
200 | 29054.0,1667989591436,5.615234375
201 | 29200.0,1667989596990,5.74609375
202 |
--------------------------------------------------------------------------------
/snake/data/bootstrap_return/MET-2172__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667989606621,0.166015625
2 | 146.0,1667989625138,0.158203125
3 | 292.0,1667989630579,0.24609375
4 | 438.0,1667989635914,1.88671875
5 | 584.0,1667989641271,3.87890625
6 | 730.0,1667989646730,4.14453125
7 | 876.0,1667989652025,4.419921875
8 | 1022.0,1667989657482,4.892578125
9 | 1168.0,1667989662785,5.154296875
10 | 1314.0,1667989668110,5.3125
11 | 1460.0,1667989673459,5.236328125
12 | 1606.0,1667989678906,5.49609375
13 | 1752.0,1667989684313,5.83203125
14 | 1898.0,1667989689652,5.89453125
15 | 2044.0,1667989694912,6.21484375
16 | 2190.0,1667989700263,6.6875
17 | 2336.0,1667989705752,6.501953125
18 | 2482.0,1667989711098,6.703125
19 | 2628.0,1667989716563,7.486328125
20 | 2774.0,1667989721923,8.080078125
21 | 2920.0,1667989727393,8.16015625
22 | 3066.0,1667989732774,8.640625
23 | 3212.0,1667989738178,9.16015625
24 | 3358.0,1667989743588,9.669921875
25 | 3504.0,1667989748982,9.61328125
26 | 3650.0,1667989754313,10.1484375
27 | 3796.0,1667989759652,10.625
28 | 3942.0,1667989765050,9.859375
29 | 4088.0,1667989770346,11.04296875
30 | 4234.0,1667989775754,11.564453125
31 | 4380.0,1667989781054,11.720703125
32 | 4526.0,1667989786426,11.671875
33 | 4672.0,1667989791720,12.06640625
34 | 4818.0,1667989797083,12.37109375
35 | 4964.0,1667989802727,12.7421875
36 | 5110.0,1667989808253,12.75
37 | 5256.0,1667989813739,12.67578125
38 | 5402.0,1667989819118,12.6171875
39 | 5548.0,1667989824623,13.6484375
40 | 5694.0,1667989830006,12.931640625
41 | 5840.0,1667989835450,13.296875
42 | 5986.0,1667989840833,13.810546875
43 | 6132.0,1667989846274,13.958984375
44 | 6278.0,1667989851593,13.72265625
45 | 6424.0,1667989857027,14.03515625
46 | 6570.0,1667989862412,14.58203125
47 | 6716.0,1667989867866,14.72265625
48 | 6862.0,1667989873262,14.265625
49 | 7008.0,1667989878649,14.60546875
50 | 7154.0,1667989884117,14.341796875
51 | 7300.0,1667989889449,14.833984375
52 | 7446.0,1667989894976,14.98828125
53 | 7592.0,1667989900818,15.12109375
54 | 7738.0,1667989906292,15.158203125
55 | 7884.0,1667989911704,15.8828125
56 | 8030.0,1667989917173,15.306640625
57 | 8176.0,1667989922585,15.513671875
58 | 8322.0,1667989928011,15.248046875
59 | 8468.0,1667989933387,15.78125
60 | 8614.0,1667989938798,15.453125
61 | 8760.0,1667989944248,15.92578125
62 | 8906.0,1667989949684,15.60546875
63 | 9052.0,1667989955094,16.140625
64 | 9198.0,1667989960539,16.1171875
65 | 9344.0,1667989965889,15.857421875
66 | 9490.0,1667989971366,16.23046875
67 | 9636.0,1667989976720,16.150390625
68 | 9782.0,1667989982204,16.203125
69 | 9928.0,1667989987803,16.169921875
70 | 10074.0,1667989993088,16.033203125
71 | 10220.0,1667989998467,16.390625
72 | 10366.0,1667990003907,16.013671875
73 | 10512.0,1667990009300,16.060546875
74 | 10658.0,1667990014707,16.3828125
75 | 10804.0,1667990020076,16.646484375
76 | 10950.0,1667990025489,16.203125
77 | 11096.0,1667990030793,16.787109375
78 | 11242.0,1667990036211,16.95703125
79 | 11388.0,1667990041631,16.41015625
80 | 11534.0,1667990047186,16.92578125
81 | 11680.0,1667990052685,16.951171875
82 | 11826.0,1667990058039,16.4765625
83 | 11972.0,1667990063356,16.916015625
84 | 12118.0,1667990068752,16.923828125
85 | 12264.0,1667990074197,16.873046875
86 | 12410.0,1667990079543,17.0390625
87 | 12556.0,1667990084850,17.0546875
88 | 12702.0,1667990090240,17.25
89 | 12848.0,1667990095578,16.765625
90 | 12994.0,1667990101108,17.25390625
91 | 13140.0,1667990106533,17.310546875
92 | 13286.0,1667990111870,17.41015625
93 | 13432.0,1667990117274,17.42578125
94 | 13578.0,1667990122597,17.07421875
95 | 13724.0,1667990127980,17.373046875
96 | 13870.0,1667990133290,17.4765625
97 | 14016.0,1667990138604,17.44921875
98 | 14162.0,1667990144034,17.310546875
99 | 14308.0,1667990149408,17.2734375
100 | 14454.0,1667990154736,17.916015625
101 | 14600.0,1667990160076,17.416015625
102 | 14746.0,1667990165533,17.705078125
103 | 14892.0,1667990170829,17.7421875
104 | 15038.0,1667990176217,17.505859375
105 | 15184.0,1667990181580,17.5234375
106 | 15330.0,1667990186930,17.0703125
107 | 15476.0,1667990192323,18.234375
108 | 15622.0,1667990197788,17.72265625
109 | 15768.0,1667990203161,17.671875
110 | 15914.0,1667990208603,17.828125
111 | 16060.0,1667990213950,17.927734375
112 | 16206.0,1667990219250,17.9765625
113 | 16352.0,1667990224932,18.119140625
114 | 16498.0,1667990230243,18.474609375
115 | 16644.0,1667990235568,18.05078125
116 | 16790.0,1667990240902,18.7421875
117 | 16936.0,1667990246305,18.50390625
118 | 17082.0,1667990251584,18.51953125
119 | 17228.0,1667990257138,18.576171875
120 | 17374.0,1667990262554,18.330078125
121 | 17520.0,1667990267910,18.796875
122 | 17666.0,1667990273304,18.43359375
123 | 17812.0,1667990278692,18.5625
124 | 17958.0,1667990284065,18.7109375
125 | 18104.0,1667990289432,18.677734375
126 | 18250.0,1667990294816,18.4140625
127 | 18396.0,1667990300234,18.1484375
128 | 18542.0,1667990305654,17.92578125
129 | 18688.0,1667990311055,18.50390625
130 | 18834.0,1667990316605,17.794921875
131 | 18980.0,1667990321929,18.5
132 | 19126.0,1667990327356,18.7109375
133 | 19272.0,1667990332764,17.62109375
134 | 19418.0,1667990338212,18.712890625
135 | 19564.0,1667990343759,18.46875
136 | 19710.0,1667990349210,18.11328125
137 | 19856.0,1667990354619,18.712890625
138 | 20002.0,1667990359979,18.791015625
139 | 20148.0,1667990365439,18.859375
140 | 20294.0,1667990370813,18.552734375
141 | 20440.0,1667990376304,18.34765625
142 | 20586.0,1667990381686,18.08203125
143 | 20732.0,1667990387051,17.908203125
144 | 20878.0,1667990392411,18.294921875
145 | 21024.0,1667990397789,18.166015625
146 | 21170.0,1667990403297,18.359375
147 | 21316.0,1667990408685,18.369140625
148 | 21462.0,1667990414183,18.53515625
149 | 21608.0,1667990419459,18.322265625
150 | 21754.0,1667990424928,17.890625
151 | 21900.0,1667990430289,18.515625
152 | 22046.0,1667990435639,17.994140625
153 | 22192.0,1667990441001,18.34765625
154 | 22338.0,1667990446412,18.138671875
155 | 22484.0,1667990451711,18.443359375
156 | 22630.0,1667990457136,18.4453125
157 | 22776.0,1667990462520,18.904296875
158 | 22922.0,1667990467963,18.912109375
159 | 23068.0,1667990473357,18.771484375
160 | 23214.0,1667990478959,18.89453125
161 | 23360.0,1667990484434,18.359375
162 | 23506.0,1667990489837,17.98828125
163 | 23652.0,1667990495244,18.228515625
164 | 23798.0,1667990500609,18.583984375
165 | 23944.0,1667990505954,18.564453125
166 | 24090.0,1667990511297,18.83203125
167 | 24236.0,1667990516630,18.681640625
168 | 24382.0,1667990522000,18.509765625
169 | 24528.0,1667990527382,18.39453125
170 | 24674.0,1667990532768,17.955078125
171 | 24820.0,1667990538141,18.390625
172 | 24966.0,1667990543530,18.7109375
173 | 25112.0,1667990548883,18.572265625
174 | 25258.0,1667990554369,18.26953125
175 | 25404.0,1667990559688,18.263671875
176 | 25550.0,1667990565010,18.841796875
177 | 25696.0,1667990570410,18.47265625
178 | 25842.0,1667990575764,18.6015625
179 | 25988.0,1667990581064,18.552734375
180 | 26134.0,1667990586521,18.556640625
181 | 26280.0,1667990591880,18.966796875
182 | 26426.0,1667990597346,18.970703125
183 | 26572.0,1667990602702,19.12109375
184 | 26718.0,1667990608107,19.291015625
185 | 26864.0,1667990613428,19.15234375
186 | 27010.0,1667990618771,19.134765625
187 | 27156.0,1667990624143,19.447265625
188 | 27302.0,1667990629471,19.724609375
189 | 27448.0,1667990634988,19.34765625
190 | 27594.0,1667990640307,19.568359375
191 | 27740.0,1667990645789,19.40234375
192 | 27886.0,1667990651073,19.40625
193 | 28032.0,1667990656348,19.400390625
194 | 28178.0,1667990661691,19.5078125
195 | 28324.0,1667990667109,19.77734375
196 | 28470.0,1667990672484,20.2421875
197 | 28616.0,1667990677866,20.83984375
198 | 28762.0,1667990683181,20.037109375
199 | 28908.0,1667990688545,20.05078125
200 | 29054.0,1667990693925,20.044921875
201 | 29200.0,1667990699295,20.578125
202 |
--------------------------------------------------------------------------------
/snake/data/bootstrap_return/MET-2178__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667990709114,0.134765625
2 | 146.0,1667990728218,0.15625
3 | 292.0,1667990733655,0.212890625
4 | 438.0,1667990739155,2.224609375
5 | 584.0,1667990744637,3.521484375
6 | 730.0,1667990750032,3.59375
7 | 876.0,1667990755436,4.0546875
8 | 1022.0,1667990760943,4.29296875
9 | 1168.0,1667990766410,4.701171875
10 | 1314.0,1667990771839,4.751953125
11 | 1460.0,1667990777430,5.1953125
12 | 1606.0,1667990782974,5.21484375
13 | 1752.0,1667990788579,5.82421875
14 | 1898.0,1667990793918,5.951171875
15 | 2044.0,1667990799510,5.734375
16 | 2190.0,1667990805018,5.76953125
17 | 2336.0,1667990810468,5.833984375
18 | 2482.0,1667990815847,6.12890625
19 | 2628.0,1667990821318,6.1328125
20 | 2774.0,1667990826836,6.345703125
21 | 2920.0,1667990832242,6.48828125
22 | 3066.0,1667990837686,6.419921875
23 | 3212.0,1667990843047,5.818359375
24 | 3358.0,1667990848501,6.296875
25 | 3504.0,1667990853924,6.544921875
26 | 3650.0,1667990859443,6.27734375
27 | 3796.0,1667990864846,6.6171875
28 | 3942.0,1667990870252,6.568359375
29 | 4088.0,1667990875790,6.771484375
30 | 4234.0,1667990881343,6.48046875
31 | 4380.0,1667990886799,6.73828125
32 | 4526.0,1667990892201,6.810546875
33 | 4672.0,1667990897621,6.271484375
34 | 4818.0,1667990903101,6.693359375
35 | 4964.0,1667990908580,7.09375
36 | 5110.0,1667990914054,7.0390625
37 | 5256.0,1667990919710,7.00390625
38 | 5402.0,1667990925180,7.4296875
39 | 5548.0,1667990930729,6.97265625
40 | 5694.0,1667990936280,7.548828125
41 | 5840.0,1667990941859,7.072265625
42 | 5986.0,1667990947416,7.173828125
43 | 6132.0,1667990952860,7.724609375
44 | 6278.0,1667990958277,7.91796875
45 | 6424.0,1667990963751,8.001953125
46 | 6570.0,1667990969279,8.59375
47 | 6716.0,1667990974829,8.65234375
48 | 6862.0,1667990980398,8.609375
49 | 7008.0,1667990985796,9.376953125
50 | 7154.0,1667990991411,9.478515625
51 | 7300.0,1667990996917,9.7265625
52 | 7446.0,1667991002575,10.205078125
53 | 7592.0,1667991008252,10.490234375
54 | 7738.0,1667991013757,10.724609375
55 | 7884.0,1667991019217,10.728515625
56 | 8030.0,1667991024852,11.078125
57 | 8176.0,1667991030410,11.43359375
58 | 8322.0,1667991036027,11.591796875
59 | 8468.0,1667991041592,12.046875
60 | 8614.0,1667991047137,11.982421875
61 | 8760.0,1667991052676,12.59765625
62 | 8906.0,1667991058115,13.041015625
63 | 9052.0,1667991063690,13.23046875
64 | 9198.0,1667991069295,13.60546875
65 | 9344.0,1667991074715,13.361328125
66 | 9490.0,1667991080245,14.01171875
67 | 9636.0,1667991085823,14.09375
68 | 9782.0,1667991091332,14.275390625
69 | 9928.0,1667991096912,14.923828125
70 | 10074.0,1667991102434,15.0703125
71 | 10220.0,1667991107992,14.986328125
72 | 10366.0,1667991113373,15.771484375
73 | 10512.0,1667991118839,15.19140625
74 | 10658.0,1667991124317,15.9921875
75 | 10804.0,1667991129912,15.80859375
76 | 10950.0,1667991135585,15.724609375
77 | 11096.0,1667991141285,15.759765625
78 | 11242.0,1667991146762,16.119140625
79 | 11388.0,1667991152253,15.916015625
80 | 11534.0,1667991157783,16.732421875
81 | 11680.0,1667991163267,16.498046875
82 | 11826.0,1667991168697,16.77734375
83 | 11972.0,1667991174091,16.265625
84 | 12118.0,1667991179539,16.978515625
85 | 12264.0,1667991184980,16.423828125
86 | 12410.0,1667991190511,16.82421875
87 | 12556.0,1667991196093,16.6015625
88 | 12702.0,1667991201558,17.12109375
89 | 12848.0,1667991207103,17.15234375
90 | 12994.0,1667991212643,16.828125
91 | 13140.0,1667991218271,16.716796875
92 | 13286.0,1667991223695,17.44140625
93 | 13432.0,1667991229269,17.564453125
94 | 13578.0,1667991234762,17.58203125
95 | 13724.0,1667991240403,18.15625
96 | 13870.0,1667991245965,17.853515625
97 | 14016.0,1667991251604,17.216796875
98 | 14162.0,1667991257177,17.978515625
99 | 14308.0,1667991262719,17.72265625
100 | 14454.0,1667991268211,17.783203125
101 | 14600.0,1667991273684,17.9765625
102 | 14746.0,1667991279325,17.55859375
103 | 14892.0,1667991284851,18.138671875
104 | 15038.0,1667991290372,18.337890625
105 | 15184.0,1667991296126,18.07421875
106 | 15330.0,1667991301771,18.169921875
107 | 15476.0,1667991307475,18.259765625
108 | 15622.0,1667991313059,18.646484375
109 | 15768.0,1667991318667,18.333984375
110 | 15914.0,1667991324142,18.640625
111 | 16060.0,1667991329767,18.84765625
112 | 16206.0,1667991335359,18.3671875
113 | 16352.0,1667991341032,19.107421875
114 | 16498.0,1667991346604,19.548828125
115 | 16644.0,1667991352162,19.611328125
116 | 16790.0,1667991357648,19.376953125
117 | 16936.0,1667991363082,18.0859375
118 | 17082.0,1667991368620,19.458984375
119 | 17228.0,1667991374171,19.623046875
120 | 17374.0,1667991379602,19.439453125
121 | 17520.0,1667991385127,19.6171875
122 | 17666.0,1667991390663,19.78515625
123 | 17812.0,1667991396204,19.798828125
124 | 17958.0,1667991401745,20.029296875
125 | 18104.0,1667991407312,20.541015625
126 | 18250.0,1667991412797,20.7578125
127 | 18396.0,1667991418374,20.80078125
128 | 18542.0,1667991423977,20.82421875
129 | 18688.0,1667991429571,21.083984375
130 | 18834.0,1667991435099,20.439453125
131 | 18980.0,1667991440712,20.69921875
132 | 19126.0,1667991446769,21.390625
133 | 19272.0,1667991452318,21.98828125
134 | 19418.0,1667991457895,22.34765625
135 | 19564.0,1667991463328,21.798828125
136 | 19710.0,1667991468848,23.21484375
137 | 19856.0,1667991474295,23.150390625
138 | 20002.0,1667991479883,24.0390625
139 | 20148.0,1667991485393,23.26953125
140 | 20294.0,1667991490904,23.0625
141 | 20440.0,1667991496376,24.2890625
142 | 20586.0,1667991501872,25.779296875
143 | 20732.0,1667991507603,26.154296875
144 | 20878.0,1667991513032,27.578125
145 | 21024.0,1667991518584,26.671875
146 | 21170.0,1667991524019,27.03125
147 | 21316.0,1667991529689,28.181640625
148 | 21462.0,1667991535161,28.251953125
149 | 21608.0,1667991540741,29.208984375
150 | 21754.0,1667991546338,29.025390625
151 | 21900.0,1667991551845,29.63671875
152 | 22046.0,1667991557403,28.42578125
153 | 22192.0,1667991563021,30.96875
154 | 22338.0,1667991568662,30.7890625
155 | 22484.0,1667991574127,30.58984375
156 | 22630.0,1667991579822,32.3046875
157 | 22776.0,1667991585233,30.9609375
158 | 22922.0,1667991590737,31.330078125
159 | 23068.0,1667991596370,31.138671875
160 | 23214.0,1667991602144,31.279296875
161 | 23360.0,1667991608013,33.0703125
162 | 23506.0,1667991613533,32.494140625
163 | 23652.0,1667991619103,32.30859375
164 | 23798.0,1667991624785,33.203125
165 | 23944.0,1667991630381,32.912109375
166 | 24090.0,1667991635977,32.734375
167 | 24236.0,1667991641615,32.07421875
168 | 24382.0,1667991647176,33.0703125
169 | 24528.0,1667991652850,33.375
170 | 24674.0,1667991658481,34.138671875
171 | 24820.0,1667991664195,33.275390625
172 | 24966.0,1667991669944,33.818359375
173 | 25112.0,1667991675499,33.5390625
174 | 25258.0,1667991681230,34.0703125
175 | 25404.0,1667991686761,34.05078125
176 | 25550.0,1667991692367,33.75390625
177 | 25696.0,1667991698058,34.15234375
178 | 25842.0,1667991703681,34.3671875
179 | 25988.0,1667991709379,32.48828125
180 | 26134.0,1667991714817,34.2734375
181 | 26280.0,1667991720625,33.94921875
182 | 26426.0,1667991726310,34.23828125
183 | 26572.0,1667991731858,34.27734375
184 | 26718.0,1667991737348,34.228515625
185 | 26864.0,1667991742833,33.93359375
186 | 27010.0,1667991748282,34.19921875
187 | 27156.0,1667991753795,34.3359375
188 | 27302.0,1667991759494,34.28515625
189 | 27448.0,1667991765049,34.314453125
190 | 27594.0,1667991770668,34.0859375
191 | 27740.0,1667991776053,33.642578125
192 | 27886.0,1667991781701,34.37890625
193 | 28032.0,1667991787293,34.2265625
194 | 28178.0,1667991792781,34.490234375
195 | 28324.0,1667991798294,34.244140625
196 | 28470.0,1667991803821,33.763671875
197 | 28616.0,1667991809269,34.328125
198 | 28762.0,1667991814799,34.66015625
199 | 28908.0,1667991820279,34.59375
200 | 29054.0,1667991825799,34.3125
201 | 29200.0,1667991831400,34.734375
202 |
--------------------------------------------------------------------------------
/snake/data/bootstrap_return/MET-2188__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667992944929,0.1640625
2 | 146.0,1667992962921,0.177734375
3 | 292.0,1667992968277,0.212890625
4 | 438.0,1667992973585,0.853515625
5 | 584.0,1667992978985,3.140625
6 | 730.0,1667992984488,3.951171875
7 | 876.0,1667992989743,3.943359375
8 | 1022.0,1667992995036,3.9921875
9 | 1168.0,1667993000415,4.107421875
10 | 1314.0,1667993005795,4.10546875
11 | 1460.0,1667993011041,4.486328125
12 | 1606.0,1667993016447,4.568359375
13 | 1752.0,1667993021790,4.794921875
14 | 1898.0,1667993027068,4.6953125
15 | 2044.0,1667993032371,5.21875
16 | 2190.0,1667993037661,5.421875
17 | 2336.0,1667993043028,5.44140625
18 | 2482.0,1667993048329,5.345703125
19 | 2628.0,1667993053729,5.482421875
20 | 2774.0,1667993059113,5.658203125
21 | 2920.0,1667993064490,5.724609375
22 | 3066.0,1667993069782,5.28125
23 | 3212.0,1667993075172,5.521484375
24 | 3358.0,1667993080573,5.5078125
25 | 3504.0,1667993085962,5.314453125
26 | 3650.0,1667993091364,6.2421875
27 | 3796.0,1667993096699,5.384765625
28 | 3942.0,1667993102169,5.392578125
29 | 4088.0,1667993107501,5.728515625
30 | 4234.0,1667993112871,6.080078125
31 | 4380.0,1667993118240,5.908203125
32 | 4526.0,1667993123578,5.869140625
33 | 4672.0,1667993128847,6.044921875
34 | 4818.0,1667993134243,6.52734375
35 | 4964.0,1667993139541,6.26953125
36 | 5110.0,1667993144899,6.23828125
37 | 5256.0,1667993150286,5.7734375
38 | 5402.0,1667993155648,6.55078125
39 | 5548.0,1667993161015,6.326171875
40 | 5694.0,1667993166460,6.69921875
41 | 5840.0,1667993171809,6.3828125
42 | 5986.0,1667993177159,6.11328125
43 | 6132.0,1667993182520,6.4765625
44 | 6278.0,1667993187944,6.36328125
45 | 6424.0,1667993193190,6.73046875
46 | 6570.0,1667993198483,6.939453125
47 | 6716.0,1667993203844,7.220703125
48 | 6862.0,1667993209126,7.55078125
49 | 7008.0,1667993214451,8.09765625
50 | 7154.0,1667993219682,8.044921875
51 | 7300.0,1667993225026,7.97265625
52 | 7446.0,1667993230323,8.447265625
53 | 7592.0,1667993235696,8.751953125
54 | 7738.0,1667993241028,8.794921875
55 | 7884.0,1667993246457,8.826171875
56 | 8030.0,1667993251876,8.890625
57 | 8176.0,1667993257367,8.9609375
58 | 8322.0,1667993262694,9.357421875
59 | 8468.0,1667993268023,9.44140625
60 | 8614.0,1667993273422,9.224609375
61 | 8760.0,1667993278755,9.33203125
62 | 8906.0,1667993284077,9.830078125
63 | 9052.0,1667993289417,9.744140625
64 | 9198.0,1667993294751,9.90625
65 | 9344.0,1667993300066,10.263671875
66 | 9490.0,1667993305513,10.57421875
67 | 9636.0,1667993310878,10.037109375
68 | 9782.0,1667993316297,9.91796875
69 | 9928.0,1667993321681,10.5625
70 | 10074.0,1667993327058,10.298828125
71 | 10220.0,1667993332494,10.716796875
72 | 10366.0,1667993337724,10.75390625
73 | 10512.0,1667993343072,10.81640625
74 | 10658.0,1667993348386,10.6484375
75 | 10804.0,1667993353650,10.6796875
76 | 10950.0,1667993359050,10.708984375
77 | 11096.0,1667993364413,10.564453125
78 | 11242.0,1667993369750,10.697265625
79 | 11388.0,1667993375085,10.982421875
80 | 11534.0,1667993380337,10.734375
81 | 11680.0,1667993385690,11.009765625
82 | 11826.0,1667993391044,10.962890625
83 | 11972.0,1667993396480,10.896484375
84 | 12118.0,1667993402009,10.923828125
85 | 12264.0,1667993407433,11.080078125
86 | 12410.0,1667993412838,10.974609375
87 | 12556.0,1667993418273,11.05078125
88 | 12702.0,1667993423722,10.646484375
89 | 12848.0,1667993429193,10.27734375
90 | 12994.0,1667993434686,11.18359375
91 | 13140.0,1667993440044,10.78125
92 | 13286.0,1667993445460,11.111328125
93 | 13432.0,1667993450818,10.841796875
94 | 13578.0,1667993456167,10.755859375
95 | 13724.0,1667993461796,11.203125
96 | 13870.0,1667993467377,11.111328125
97 | 14016.0,1667993472828,10.904296875
98 | 14162.0,1667993478249,11.033203125
99 | 14308.0,1667993483731,11.26171875
100 | 14454.0,1667993489211,11.384765625
101 | 14600.0,1667993494683,11.451171875
102 | 14746.0,1667993500149,10.943359375
103 | 14892.0,1667993505561,11.23046875
104 | 15038.0,1667993511139,10.822265625
105 | 15184.0,1667993516599,11.181640625
106 | 15330.0,1667993522226,11.333984375
107 | 15476.0,1667993527698,11.076171875
108 | 15622.0,1667993533221,11.33984375
109 | 15768.0,1667993538682,11.478515625
110 | 15914.0,1667993544034,11.041015625
111 | 16060.0,1667993549457,11.22265625
112 | 16206.0,1667993554894,11.337890625
113 | 16352.0,1667993560238,10.8359375
114 | 16498.0,1667993565735,11.275390625
115 | 16644.0,1667993571126,11.572265625
116 | 16790.0,1667993576658,11.4140625
117 | 16936.0,1667993582096,11.4921875
118 | 17082.0,1667993587535,11.2109375
119 | 17228.0,1667993592993,11.609375
120 | 17374.0,1667993598436,11.59375
121 | 17520.0,1667993603915,11.8984375
122 | 17666.0,1667993609231,11.546875
123 | 17812.0,1667993614642,11.701171875
124 | 17958.0,1667993620104,11.66015625
125 | 18104.0,1667993625680,11.46875
126 | 18250.0,1667993631090,12.087890625
127 | 18396.0,1667993636587,11.896484375
128 | 18542.0,1667993642000,12.373046875
129 | 18688.0,1667993647433,12.275390625
130 | 18834.0,1667993652810,12.060546875
131 | 18980.0,1667993658306,12.51171875
132 | 19126.0,1667993663829,12.810546875
133 | 19272.0,1667993669122,12.560546875
134 | 19418.0,1667993674428,12.689453125
135 | 19564.0,1667993679783,12.671875
136 | 19710.0,1667993685065,12.7890625
137 | 19856.0,1667993690489,13.072265625
138 | 20002.0,1667993695813,13.0
139 | 20148.0,1667993701227,12.771484375
140 | 20294.0,1667993706658,12.83203125
141 | 20440.0,1667993712211,12.8359375
142 | 20586.0,1667993717521,12.857421875
143 | 20732.0,1667993722963,12.69140625
144 | 20878.0,1667993728240,12.765625
145 | 21024.0,1667993733872,12.943359375
146 | 21170.0,1667993739152,12.71875
147 | 21316.0,1667993744589,12.37890625
148 | 21462.0,1667993749888,12.3828125
149 | 21608.0,1667993755275,12.353515625
150 | 21754.0,1667993760613,12.30859375
151 | 21900.0,1667993765967,12.07421875
152 | 22046.0,1667993771296,12.4296875
153 | 22192.0,1667993776996,12.27734375
154 | 22338.0,1667993782314,12.20703125
155 | 22484.0,1667993787753,12.349609375
156 | 22630.0,1667993793108,12.30859375
157 | 22776.0,1667993798620,12.5625
158 | 22922.0,1667993803942,12.556640625
159 | 23068.0,1667993809237,12.40625
160 | 23214.0,1667993814703,12.611328125
161 | 23360.0,1667993820120,13.2109375
162 | 23506.0,1667993825535,12.83203125
163 | 23652.0,1667993830966,12.75390625
164 | 23798.0,1667993836509,12.896484375
165 | 23944.0,1667993841875,12.9375
166 | 24090.0,1667993847241,12.458984375
167 | 24236.0,1667993852647,12.376953125
168 | 24382.0,1667993858043,12.705078125
169 | 24528.0,1667993863400,12.7890625
170 | 24674.0,1667993868751,12.810546875
171 | 24820.0,1667993874245,12.802734375
172 | 24966.0,1667993879710,12.607421875
173 | 25112.0,1667993885182,12.9140625
174 | 25258.0,1667993890630,12.751953125
175 | 25404.0,1667993896244,12.83203125
176 | 25550.0,1667993901670,12.787109375
177 | 25696.0,1667993907167,12.603515625
178 | 25842.0,1667993912678,12.755859375
179 | 25988.0,1667993918109,12.884765625
180 | 26134.0,1667993923525,13.087890625
181 | 26280.0,1667993928895,13.072265625
182 | 26426.0,1667993934480,13.078125
183 | 26572.0,1667993939813,13.06640625
184 | 26718.0,1667993945204,12.994140625
185 | 26864.0,1667993950495,13.4765625
186 | 27010.0,1667993955897,13.47265625
187 | 27156.0,1667993961229,13.84375
188 | 27302.0,1667993966706,13.833984375
189 | 27448.0,1667993972189,13.55078125
190 | 27594.0,1667993977524,13.822265625
191 | 27740.0,1667993982856,13.75
192 | 27886.0,1667993988220,13.759765625
193 | 28032.0,1667993993651,13.42578125
194 | 28178.0,1667993999080,13.966796875
195 | 28324.0,1667994004560,14.015625
196 | 28470.0,1667994009970,14.388671875
197 | 28616.0,1667994015385,14.3046875
198 | 28762.0,1667994020780,13.765625
199 | 28908.0,1667994026252,14.001953125
200 | 29054.0,1667994031763,14.1484375
201 | 29200.0,1667994037287,13.931640625
202 |
--------------------------------------------------------------------------------
/snake/data/bootstrap_return/MET-2196__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667995177742,0.154296875
2 | 146.0,1667995196189,0.171875
3 | 292.0,1667995201910,0.259765625
4 | 438.0,1667995207664,2.080078125
5 | 584.0,1667995213249,3.546875
6 | 730.0,1667995218727,3.908203125
7 | 876.0,1667995224273,4.0625
8 | 1022.0,1667995229766,4.765625
9 | 1168.0,1667995235306,4.904296875
10 | 1314.0,1667995240820,5.01953125
11 | 1460.0,1667995246473,5.666015625
12 | 1606.0,1667995251894,5.173828125
13 | 1752.0,1667995257430,5.857421875
14 | 1898.0,1667995262848,5.7265625
15 | 2044.0,1667995268300,6.267578125
16 | 2190.0,1667995273762,6.478515625
17 | 2336.0,1667995279295,5.875
18 | 2482.0,1667995284885,6.15234375
19 | 2628.0,1667995290417,5.96484375
20 | 2774.0,1667995296031,6.1796875
21 | 2920.0,1667995301523,5.88671875
22 | 3066.0,1667995307082,5.966796875
23 | 3212.0,1667995312629,6.408203125
24 | 3358.0,1667995318165,6.078125
25 | 3504.0,1667995323714,6.4375
26 | 3650.0,1667995329275,6.34765625
27 | 3796.0,1667995334939,6.609375
28 | 3942.0,1667995340534,6.572265625
29 | 4088.0,1667995346065,6.77734375
30 | 4234.0,1667995351548,5.92578125
31 | 4380.0,1667995357096,6.53515625
32 | 4526.0,1667995362476,6.640625
33 | 4672.0,1667995367953,7.029296875
34 | 4818.0,1667995373428,6.525390625
35 | 4964.0,1667995378960,6.98046875
36 | 5110.0,1667995384416,6.30078125
37 | 5256.0,1667995389941,6.55859375
38 | 5402.0,1667995395435,6.31640625
39 | 5548.0,1667995400874,7.025390625
40 | 5694.0,1667995406493,7.2109375
41 | 5840.0,1667995411890,7.283203125
42 | 5986.0,1667995417430,7.20703125
43 | 6132.0,1667995422939,7.78125
44 | 6278.0,1667995428384,7.9765625
45 | 6424.0,1667995433832,7.662109375
46 | 6570.0,1667995439283,8.1875
47 | 6716.0,1667995444804,8.064453125
48 | 6862.0,1667995450269,7.93359375
49 | 7008.0,1667995455735,8.212890625
50 | 7154.0,1667995461203,8.4296875
51 | 7300.0,1667995466776,8.49609375
52 | 7446.0,1667995472212,8.27734375
53 | 7592.0,1667995477682,8.4296875
54 | 7738.0,1667995483287,8.302734375
55 | 7884.0,1667995488828,8.720703125
56 | 8030.0,1667995494248,8.7421875
57 | 8176.0,1667995499596,8.40625
58 | 8322.0,1667995505230,9.0703125
59 | 8468.0,1667995510762,8.966796875
60 | 8614.0,1667995516288,8.794921875
61 | 8760.0,1667995521808,9.015625
62 | 8906.0,1667995527362,9.443359375
63 | 9052.0,1667995532917,9.294921875
64 | 9198.0,1667995538490,9.2421875
65 | 9344.0,1667995544078,9.3359375
66 | 9490.0,1667995549658,9.60546875
67 | 9636.0,1667995555161,9.466796875
68 | 9782.0,1667995560627,9.615234375
69 | 9928.0,1667995566178,9.537109375
70 | 10074.0,1667995571648,9.689453125
71 | 10220.0,1667995577215,9.703125
72 | 10366.0,1667995582681,9.533203125
73 | 10512.0,1667995588325,9.828125
74 | 10658.0,1667995593809,9.86328125
75 | 10804.0,1667995599709,9.7109375
76 | 10950.0,1667995605371,10.015625
77 | 11096.0,1667995610980,10.138671875
78 | 11242.0,1667995616555,10.005859375
79 | 11388.0,1667995622032,10.361328125
80 | 11534.0,1667995627674,10.29296875
81 | 11680.0,1667995633149,10.443359375
82 | 11826.0,1667995638650,10.6796875
83 | 11972.0,1667995644253,10.515625
84 | 12118.0,1667995649786,10.765625
85 | 12264.0,1667995655316,10.822265625
86 | 12410.0,1667995660805,10.888671875
87 | 12556.0,1667995666438,10.5546875
88 | 12702.0,1667995672003,10.861328125
89 | 12848.0,1667995677525,10.47265625
90 | 12994.0,1667995683035,11.04296875
91 | 13140.0,1667995688539,10.591796875
92 | 13286.0,1667995694057,10.70703125
93 | 13432.0,1667995699542,10.759765625
94 | 13578.0,1667995705068,10.66015625
95 | 13724.0,1667995710623,10.76171875
96 | 13870.0,1667995716154,11.18359375
97 | 14016.0,1667995721642,10.689453125
98 | 14162.0,1667995727200,10.88671875
99 | 14308.0,1667995732750,10.642578125
100 | 14454.0,1667995738272,11.025390625
101 | 14600.0,1667995743765,10.9453125
102 | 14746.0,1667995749387,10.78515625
103 | 14892.0,1667995755349,10.8828125
104 | 15038.0,1667995762059,11.2265625
105 | 15184.0,1667995768892,10.95703125
106 | 15330.0,1667995775030,11.150390625
107 | 15476.0,1667995781023,11.51953125
108 | 15622.0,1667995787787,11.03125
109 | 15768.0,1667995794518,11.521484375
110 | 15914.0,1667995800316,11.36328125
111 | 16060.0,1667995805929,11.65625
112 | 16206.0,1667995811430,11.642578125
113 | 16352.0,1667995817099,11.64453125
114 | 16498.0,1667995822591,11.59765625
115 | 16644.0,1667995828121,11.8046875
116 | 16790.0,1667995833648,11.712890625
117 | 16936.0,1667995839121,12.109375
118 | 17082.0,1667995844789,11.974609375
119 | 17228.0,1667995850227,12.1875
120 | 17374.0,1667995855815,12.26953125
121 | 17520.0,1667995861373,12.283203125
122 | 17666.0,1667995866957,12.12109375
123 | 17812.0,1667995872446,12.083984375
124 | 17958.0,1667995878051,12.208984375
125 | 18104.0,1667995883533,12.45703125
126 | 18250.0,1667995889097,12.111328125
127 | 18396.0,1667995894568,12.3984375
128 | 18542.0,1667995900122,12.20703125
129 | 18688.0,1667995905579,12.66015625
130 | 18834.0,1667995911110,12.25390625
131 | 18980.0,1667995916617,12.62890625
132 | 19126.0,1667995922151,12.7421875
133 | 19272.0,1667995927696,12.345703125
134 | 19418.0,1667995933239,12.685546875
135 | 19564.0,1667995938829,12.490234375
136 | 19710.0,1667995944444,12.583984375
137 | 19856.0,1667995950103,12.5390625
138 | 20002.0,1667995955757,12.259765625
139 | 20148.0,1667995961372,12.462890625
140 | 20294.0,1667995967070,12.337890625
141 | 20440.0,1667995972610,12.203125
142 | 20586.0,1667995978219,12.53125
143 | 20732.0,1667995983772,12.736328125
144 | 20878.0,1667995989442,12.09765625
145 | 21024.0,1667995995065,12.421875
146 | 21170.0,1667996000631,12.4765625
147 | 21316.0,1667996006272,12.68359375
148 | 21462.0,1667996011856,12.58984375
149 | 21608.0,1667996017369,11.994140625
150 | 21754.0,1667996022918,12.40234375
151 | 21900.0,1667996028588,12.865234375
152 | 22046.0,1667996034130,12.318359375
153 | 22192.0,1667996039722,12.189453125
154 | 22338.0,1667996045386,12.107421875
155 | 22484.0,1667996050939,12.802734375
156 | 22630.0,1667996056606,12.326171875
157 | 22776.0,1667996062196,12.384765625
158 | 22922.0,1667996067921,12.69921875
159 | 23068.0,1667996073514,12.541015625
160 | 23214.0,1667996079102,12.6484375
161 | 23360.0,1667996084753,12.453125
162 | 23506.0,1667996090303,12.126953125
163 | 23652.0,1667996095988,12.716796875
164 | 23798.0,1667996101594,12.84375
165 | 23944.0,1667996107263,12.59765625
166 | 24090.0,1667996112881,12.873046875
167 | 24236.0,1667996118538,12.626953125
168 | 24382.0,1667996124187,12.83984375
169 | 24528.0,1667996129852,12.912109375
170 | 24674.0,1667996135519,12.591796875
171 | 24820.0,1667996141137,12.55078125
172 | 24966.0,1667996146911,13.37109375
173 | 25112.0,1667996152521,12.583984375
174 | 25258.0,1667996158310,12.947265625
175 | 25404.0,1667996163991,13.447265625
176 | 25550.0,1667996169709,13.375
177 | 25696.0,1667996175396,13.376953125
178 | 25842.0,1667996181153,13.333984375
179 | 25988.0,1667996186938,13.013671875
180 | 26134.0,1667996192637,13.166015625
181 | 26280.0,1667996198388,13.16015625
182 | 26426.0,1667996204060,12.791015625
183 | 26572.0,1667996209672,13.283203125
184 | 26718.0,1667996215382,13.20703125
185 | 26864.0,1667996221046,13.208984375
186 | 27010.0,1667996226638,13.23046875
187 | 27156.0,1667996232242,13.169921875
188 | 27302.0,1667996237908,13.146484375
189 | 27448.0,1667996243515,13.056640625
190 | 27594.0,1667996249132,12.8828125
191 | 27740.0,1667996254767,12.837890625
192 | 27886.0,1667996260363,13.283203125
193 | 28032.0,1667996266032,13.228515625
194 | 28178.0,1667996271686,12.69140625
195 | 28324.0,1667996277311,13.38671875
196 | 28470.0,1667996282935,12.416015625
197 | 28616.0,1667996288513,12.861328125
198 | 28762.0,1667996294050,12.55859375
199 | 28908.0,1667996299727,12.611328125
200 | 29054.0,1667996305475,12.71484375
201 | 29200.0,1667996311015,12.556640625
202 |
--------------------------------------------------------------------------------
/snake/data/bootstrap_return/MET-2199__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667996320975,0.14453125
2 | 146.0,1667996339763,0.150390625
3 | 292.0,1667996345348,0.271484375
4 | 438.0,1667996350930,2.166015625
5 | 584.0,1667996356459,3.521484375
6 | 730.0,1667996361918,4.017578125
7 | 876.0,1667996367450,4.380859375
8 | 1022.0,1667996373017,4.8125
9 | 1168.0,1667996378612,5.072265625
10 | 1314.0,1667996384064,5.275390625
11 | 1460.0,1667996389582,5.787109375
12 | 1606.0,1667996395028,5.82421875
13 | 1752.0,1667996400554,5.822265625
14 | 1898.0,1667996406193,6.40625
15 | 2044.0,1667996411765,6.095703125
16 | 2190.0,1667996417488,6.41796875
17 | 2336.0,1667996423097,6.166015625
18 | 2482.0,1667996428667,6.474609375
19 | 2628.0,1667996434164,6.630859375
20 | 2774.0,1667996439601,6.529296875
21 | 2920.0,1667996445130,6.703125
22 | 3066.0,1667996450646,6.634765625
23 | 3212.0,1667996456199,6.7734375
24 | 3358.0,1667996461826,6.623046875
25 | 3504.0,1667996467384,6.6171875
26 | 3650.0,1667996472841,7.103515625
27 | 3796.0,1667996478566,6.966796875
28 | 3942.0,1667996484102,7.072265625
29 | 4088.0,1667996489624,7.53515625
30 | 4234.0,1667996495201,7.095703125
31 | 4380.0,1667996500787,7.384765625
32 | 4526.0,1667996506286,7.53125
33 | 4672.0,1667996511822,7.98046875
34 | 4818.0,1667996517389,8.115234375
35 | 4964.0,1667996523057,8.19140625
36 | 5110.0,1667996528634,8.19921875
37 | 5256.0,1667996534139,8.740234375
38 | 5402.0,1667996539660,8.875
39 | 5548.0,1667996545315,8.72265625
40 | 5694.0,1667996550844,8.982421875
41 | 5840.0,1667996556397,9.12890625
42 | 5986.0,1667996561941,9.080078125
43 | 6132.0,1667996567503,8.994140625
44 | 6278.0,1667996573006,9.509765625
45 | 6424.0,1667996578522,9.345703125
46 | 6570.0,1667996584093,9.8203125
47 | 6716.0,1667996589597,10.267578125
48 | 6862.0,1667996595246,9.681640625
49 | 7008.0,1667996600806,9.80859375
50 | 7154.0,1667996606398,9.890625
51 | 7300.0,1667996611924,10.69140625
52 | 7446.0,1667996617556,10.486328125
53 | 7592.0,1667996623095,10.794921875
54 | 7738.0,1667996628576,10.841796875
55 | 7884.0,1667996634213,10.99609375
56 | 8030.0,1667996639836,11.15625
57 | 8176.0,1667996645354,11.306640625
58 | 8322.0,1667996650809,11.205078125
59 | 8468.0,1667996656300,11.681640625
60 | 8614.0,1667996661951,11.66796875
61 | 8760.0,1667996667391,12.078125
62 | 8906.0,1667996672918,12.22265625
63 | 9052.0,1667996678604,12.37890625
64 | 9198.0,1667996684185,12.609375
65 | 9344.0,1667996689865,13.037109375
66 | 9490.0,1667996695395,12.68359375
67 | 9636.0,1667996701056,13.23046875
68 | 9782.0,1667996706683,12.82421875
69 | 9928.0,1667996712195,13.3125
70 | 10074.0,1667996717750,13.451171875
71 | 10220.0,1667996723353,13.099609375
72 | 10366.0,1667996729021,13.732421875
73 | 10512.0,1667996734603,14.037109375
74 | 10658.0,1667996740221,13.470703125
75 | 10804.0,1667996745771,13.50390625
76 | 10950.0,1667996751432,14.1953125
77 | 11096.0,1667996756942,13.90625
78 | 11242.0,1667996762540,13.951171875
79 | 11388.0,1667996768165,14.26171875
80 | 11534.0,1667996773707,14.015625
81 | 11680.0,1667996779269,13.759765625
82 | 11826.0,1667996784821,14.03515625
83 | 11972.0,1667996790398,14.3125
84 | 12118.0,1667996795850,14.10546875
85 | 12264.0,1667996801483,14.9453125
86 | 12410.0,1667996807086,14.58984375
87 | 12556.0,1667996812759,15.47265625
88 | 12702.0,1667996818301,15.0390625
89 | 12848.0,1667996823849,15.3984375
90 | 12994.0,1667996829514,15.5078125
91 | 13140.0,1667996835227,15.9609375
92 | 13286.0,1667996840854,15.515625
93 | 13432.0,1667996846668,15.796875
94 | 13578.0,1667996852480,15.783203125
95 | 13724.0,1667996858303,16.138671875
96 | 13870.0,1667996864131,15.83984375
97 | 14016.0,1667996870354,15.48828125
98 | 14162.0,1667996876167,16.0390625
99 | 14308.0,1667996882447,16.376953125
100 | 14454.0,1667996888244,16.482421875
101 | 14600.0,1667996893862,16.25390625
102 | 14746.0,1667996899449,16.708984375
103 | 14892.0,1667996905031,16.591796875
104 | 15038.0,1667996910581,16.44921875
105 | 15184.0,1667996916193,16.78515625
106 | 15330.0,1667996921865,17.05859375
107 | 15476.0,1667996927563,16.96875
108 | 15622.0,1667996933282,16.984375
109 | 15768.0,1667996938949,16.224609375
110 | 15914.0,1667996944516,16.859375
111 | 16060.0,1667996950255,16.841796875
112 | 16206.0,1667996955807,17.119140625
113 | 16352.0,1667996961359,17.068359375
114 | 16498.0,1667996966957,16.826171875
115 | 16644.0,1667996972611,16.978515625
116 | 16790.0,1667996978353,17.33203125
117 | 16936.0,1667996984076,17.064453125
118 | 17082.0,1667996990314,16.931640625
119 | 17228.0,1667996996427,17.2109375
120 | 17374.0,1667997002739,17.267578125
121 | 17520.0,1667997008933,17.521484375
122 | 17666.0,1667997014742,17.4375
123 | 17812.0,1667997020399,17.341796875
124 | 17958.0,1667997026027,17.142578125
125 | 18104.0,1667997031680,17.357421875
126 | 18250.0,1667997037368,17.455078125
127 | 18396.0,1667997043321,17.33984375
128 | 18542.0,1667997048944,17.396484375
129 | 18688.0,1667997054564,17.52734375
130 | 18834.0,1667997060116,17.3984375
131 | 18980.0,1667997065668,17.43359375
132 | 19126.0,1667997071360,17.34765625
133 | 19272.0,1667997077026,18.087890625
134 | 19418.0,1667997082633,18.14453125
135 | 19564.0,1667997088183,17.9609375
136 | 19710.0,1667997093703,18.212890625
137 | 19856.0,1667997099385,18.287109375
138 | 20002.0,1667997104944,18.501953125
139 | 20148.0,1667997110569,18.490234375
140 | 20294.0,1667997116066,18.490234375
141 | 20440.0,1667997121854,18.11328125
142 | 20586.0,1667997127482,19.0
143 | 20732.0,1667997133184,18.9140625
144 | 20878.0,1667997138901,18.90625
145 | 21024.0,1667997144549,18.751953125
146 | 21170.0,1667997150128,19.375
147 | 21316.0,1667997155724,19.419921875
148 | 21462.0,1667997161319,19.12890625
149 | 21608.0,1667997166991,19.7265625
150 | 21754.0,1667997172671,19.8046875
151 | 21900.0,1667997178352,20.01953125
152 | 22046.0,1667997184020,20.080078125
153 | 22192.0,1667997189634,19.525390625
154 | 22338.0,1667997195271,20.103515625
155 | 22484.0,1667997200972,20.361328125
156 | 22630.0,1667997206621,20.5078125
157 | 22776.0,1667997212201,20.525390625
158 | 22922.0,1667997217811,20.990234375
159 | 23068.0,1667997223375,20.2109375
160 | 23214.0,1667997228978,20.791015625
161 | 23360.0,1667997234756,20.80078125
162 | 23506.0,1667997240303,20.80859375
163 | 23652.0,1667997245870,21.244140625
164 | 23798.0,1667997251593,21.505859375
165 | 23944.0,1667997257342,21.033203125
166 | 24090.0,1667997262995,22.150390625
167 | 24236.0,1667997268741,22.046875
168 | 24382.0,1667997274387,21.609375
169 | 24528.0,1667997280162,21.73828125
170 | 24674.0,1667997285678,21.90234375
171 | 24820.0,1667997291393,21.92578125
172 | 24966.0,1667997296968,22.265625
173 | 25112.0,1667997302741,22.4375
174 | 25258.0,1667997308508,22.384765625
175 | 25404.0,1667997314099,22.306640625
176 | 25550.0,1667997319683,22.802734375
177 | 25696.0,1667997325431,22.740234375
178 | 25842.0,1667997331029,23.19921875
179 | 25988.0,1667997336681,23.111328125
180 | 26134.0,1667997342383,24.279296875
181 | 26280.0,1667997348005,24.12109375
182 | 26426.0,1667997353521,24.28515625
183 | 26572.0,1667997359275,25.828125
184 | 26718.0,1667997364839,26.453125
185 | 26864.0,1667997370640,27.3203125
186 | 27010.0,1667997376386,29.033203125
187 | 27156.0,1667997382016,29.91796875
188 | 27302.0,1667997387597,30.78125
189 | 27448.0,1667997393223,31.267578125
190 | 27594.0,1667997398863,32.009765625
191 | 27740.0,1667997404452,30.61328125
192 | 27886.0,1667997410178,32.748046875
193 | 28032.0,1667997415869,32.1484375
194 | 28178.0,1667997421659,32.271484375
195 | 28324.0,1667997427224,32.677734375
196 | 28470.0,1667997432872,32.66796875
197 | 28616.0,1667997438558,32.48828125
198 | 28762.0,1667997444117,33.21484375
199 | 28908.0,1667997449795,32.892578125
200 | 29054.0,1667997455441,33.40625
201 | 29200.0,1667997461124,33.4921875
202 |
--------------------------------------------------------------------------------
/snake/data/mgrl_return/MET-2157__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667987108012,0.140625
2 | 146.0,1667987120064,0.185546875
3 | 292.0,1667987124772,0.3125
4 | 438.0,1667987129323,2.08984375
5 | 584.0,1667987133938,3.650390625
6 | 730.0,1667987138527,3.86328125
7 | 876.0,1667987143127,4.072265625
8 | 1022.0,1667987147765,4.703125
9 | 1168.0,1667987152404,4.888671875
10 | 1314.0,1667987157001,5.099609375
11 | 1460.0,1667987161673,5.595703125
12 | 1606.0,1667987166268,5.345703125
13 | 1752.0,1667987170802,5.81640625
14 | 1898.0,1667987175408,5.958984375
15 | 2044.0,1667987180042,5.98828125
16 | 2190.0,1667987184611,6.02734375
17 | 2336.0,1667987189125,5.720703125
18 | 2482.0,1667987193708,5.484375
19 | 2628.0,1667987198419,5.93359375
20 | 2774.0,1667987202946,6.28515625
21 | 2920.0,1667987207607,6.33203125
22 | 3066.0,1667987212231,6.5078125
23 | 3212.0,1667987216907,6.48046875
24 | 3358.0,1667987221578,6.201171875
25 | 3504.0,1667987226224,6.2734375
26 | 3650.0,1667987230906,6.744140625
27 | 3796.0,1667987235586,6.6640625
28 | 3942.0,1667987240163,6.951171875
29 | 4088.0,1667987244849,6.82421875
30 | 4234.0,1667987249371,6.73046875
31 | 4380.0,1667987253956,6.84765625
32 | 4526.0,1667987258524,6.94140625
33 | 4672.0,1667987263124,7.107421875
34 | 4818.0,1667987267668,6.8671875
35 | 4964.0,1667987272331,7.025390625
36 | 5110.0,1667987276955,7.4296875
37 | 5256.0,1667987281580,7.43359375
38 | 5402.0,1667987286271,7.685546875
39 | 5548.0,1667987290845,7.55078125
40 | 5694.0,1667987295444,7.607421875
41 | 5840.0,1667987300092,7.517578125
42 | 5986.0,1667987304703,7.962890625
43 | 6132.0,1667987309277,7.875
44 | 6278.0,1667987313897,8.220703125
45 | 6424.0,1667987318542,8.2734375
46 | 6570.0,1667987323153,9.078125
47 | 6716.0,1667987327816,8.87890625
48 | 6862.0,1667987332430,9.361328125
49 | 7008.0,1667987337070,9.193359375
50 | 7154.0,1667987341634,9.9453125
51 | 7300.0,1667987346310,9.998046875
52 | 7446.0,1667987350948,10.365234375
53 | 7592.0,1667987355576,10.615234375
54 | 7738.0,1667987360129,10.828125
55 | 7884.0,1667987364807,11.09765625
56 | 8030.0,1667987369398,11.923828125
57 | 8176.0,1667987373990,11.61328125
58 | 8322.0,1667987378669,12.142578125
59 | 8468.0,1667987383251,12.181640625
60 | 8614.0,1667987387851,12.80078125
61 | 8760.0,1667987392464,12.642578125
62 | 8906.0,1667987397090,12.68359375
63 | 9052.0,1667987401866,13.1171875
64 | 9198.0,1667987406589,13.140625
65 | 9344.0,1667987411224,13.64453125
66 | 9490.0,1667987415898,13.53125
67 | 9636.0,1667987420488,13.9140625
68 | 9782.0,1667987425182,13.943359375
69 | 9928.0,1667987429966,14.123046875
70 | 10074.0,1667987434623,13.9765625
71 | 10220.0,1667987439212,14.3984375
72 | 10366.0,1667987443831,14.58203125
73 | 10512.0,1667987448516,14.59375
74 | 10658.0,1667987453239,14.751953125
75 | 10804.0,1667987457836,15.32421875
76 | 10950.0,1667987462499,15.5703125
77 | 11096.0,1667987467187,15.294921875
78 | 11242.0,1667987471923,15.65234375
79 | 11388.0,1667987476543,15.439453125
80 | 11534.0,1667987481160,14.95703125
81 | 11680.0,1667987485785,15.392578125
82 | 11826.0,1667987490388,15.720703125
83 | 11972.0,1667987494961,15.73828125
84 | 12118.0,1667987499539,15.919921875
85 | 12264.0,1667987504118,15.99609375
86 | 12410.0,1667987508760,15.75390625
87 | 12556.0,1667987513306,15.9921875
88 | 12702.0,1667987517943,16.439453125
89 | 12848.0,1667987522627,16.080078125
90 | 12994.0,1667987527335,15.875
91 | 13140.0,1667987531963,15.8046875
92 | 13286.0,1667987536632,15.84375
93 | 13432.0,1667987541173,16.21484375
94 | 13578.0,1667987546179,15.78515625
95 | 13724.0,1667987550727,16.41015625
96 | 13870.0,1667987555364,15.896484375
97 | 14016.0,1667987560010,16.16796875
98 | 14162.0,1667987564686,16.49609375
99 | 14308.0,1667987569339,16.009765625
100 | 14454.0,1667987573974,16.615234375
101 | 14600.0,1667987578595,16.62890625
102 | 14746.0,1667987583284,16.498046875
103 | 14892.0,1667987587935,16.515625
104 | 15038.0,1667987592539,16.458984375
105 | 15184.0,1667987597298,17.0546875
106 | 15330.0,1667987601849,16.736328125
107 | 15476.0,1667987606573,17.232421875
108 | 15622.0,1667987611179,17.44140625
109 | 15768.0,1667987615902,17.19921875
110 | 15914.0,1667987620555,16.966796875
111 | 16060.0,1667987625183,17.140625
112 | 16206.0,1667987629803,16.78515625
113 | 16352.0,1667987634461,16.646484375
114 | 16498.0,1667987639139,17.078125
115 | 16644.0,1667987643786,17.111328125
116 | 16790.0,1667987648494,17.2890625
117 | 16936.0,1667987653106,17.05078125
118 | 17082.0,1667987657773,17.244140625
119 | 17228.0,1667987662376,17.5078125
120 | 17374.0,1667987666989,17.458984375
121 | 17520.0,1667987671551,17.67578125
122 | 17666.0,1667987676268,17.466796875
123 | 17812.0,1667987680837,17.265625
124 | 17958.0,1667987685424,16.97265625
125 | 18104.0,1667987690016,17.392578125
126 | 18250.0,1667987694737,17.693359375
127 | 18396.0,1667987699419,17.52734375
128 | 18542.0,1667987704134,17.650390625
129 | 18688.0,1667987708752,17.892578125
130 | 18834.0,1667987713280,17.294921875
131 | 18980.0,1667987718004,17.6328125
132 | 19126.0,1667987722681,17.43359375
133 | 19272.0,1667987727429,17.88671875
134 | 19418.0,1667987732055,17.66015625
135 | 19564.0,1667987736700,17.79296875
136 | 19710.0,1667987741313,17.73046875
137 | 19856.0,1667987746039,17.853515625
138 | 20002.0,1667987750707,18.048828125
139 | 20148.0,1667987755305,17.537109375
140 | 20294.0,1667987759969,17.50390625
141 | 20440.0,1667987764553,17.390625
142 | 20586.0,1667987769184,17.984375
143 | 20732.0,1667987773792,17.7265625
144 | 20878.0,1667987778468,17.412109375
145 | 21024.0,1667987783135,17.3359375
146 | 21170.0,1667987787761,18.080078125
147 | 21316.0,1667987792508,18.18359375
148 | 21462.0,1667987797190,18.029296875
149 | 21608.0,1667987801859,18.017578125
150 | 21754.0,1667987806527,17.916015625
151 | 21900.0,1667987811110,17.650390625
152 | 22046.0,1667987815788,18.01171875
153 | 22192.0,1667987820369,17.8515625
154 | 22338.0,1667987825008,18.16796875
155 | 22484.0,1667987829661,17.78125
156 | 22630.0,1667987834278,18.0
157 | 22776.0,1667987838959,18.07421875
158 | 22922.0,1667987843678,18.24609375
159 | 23068.0,1667987848347,17.986328125
160 | 23214.0,1667987852920,17.796875
161 | 23360.0,1667987857475,17.712890625
162 | 23506.0,1667987862136,18.28125
163 | 23652.0,1667987866840,17.931640625
164 | 23798.0,1667987871456,18.375
165 | 23944.0,1667987876013,18.81640625
166 | 24090.0,1667987880626,18.2109375
167 | 24236.0,1667987885316,17.98046875
168 | 24382.0,1667987889948,18.0078125
169 | 24528.0,1667987894520,18.46484375
170 | 24674.0,1667987899133,18.263671875
171 | 24820.0,1667987903797,18.06640625
172 | 24966.0,1667987908413,18.412109375
173 | 25112.0,1667987913109,18.02734375
174 | 25258.0,1667987917860,18.6796875
175 | 25404.0,1667987922499,19.0
176 | 25550.0,1667987927232,18.302734375
177 | 25696.0,1667987931845,18.24609375
178 | 25842.0,1667987936430,18.484375
179 | 25988.0,1667987941029,18.923828125
180 | 26134.0,1667987945734,18.609375
181 | 26280.0,1667987950294,18.8359375
182 | 26426.0,1667987954857,18.625
183 | 26572.0,1667987959594,18.626953125
184 | 26718.0,1667987964182,18.763671875
185 | 26864.0,1667987968747,18.884765625
186 | 27010.0,1667987973393,18.833984375
187 | 27156.0,1667987978091,19.306640625
188 | 27302.0,1667987982658,19.16015625
189 | 27448.0,1667987987309,18.974609375
190 | 27594.0,1667987991945,19.337890625
191 | 27740.0,1667987996592,19.16015625
192 | 27886.0,1667988001536,19.478515625
193 | 28032.0,1667988006429,19.591796875
194 | 28178.0,1667988011140,19.4140625
195 | 28324.0,1667988015794,19.6015625
196 | 28470.0,1667988020489,19.478515625
197 | 28616.0,1667988025271,19.244140625
198 | 28762.0,1667988029880,19.578125
199 | 28908.0,1667988034607,20.52734375
200 | 29054.0,1667988039362,20.111328125
201 | 29200.0,1667988044047,19.568359375
202 |
--------------------------------------------------------------------------------
/snake/data/mgrl_return/MET-2179__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667990937414,0.1640625
2 | 146.0,1667990949593,0.150390625
3 | 292.0,1667990954319,0.201171875
4 | 438.0,1667990958988,2.099609375
5 | 584.0,1667990963797,3.716796875
6 | 730.0,1667990968479,4.158203125
7 | 876.0,1667990973189,4.669921875
8 | 1022.0,1667990977883,5.017578125
9 | 1168.0,1667990982511,5.109375
10 | 1314.0,1667990987194,5.46484375
11 | 1460.0,1667990991982,5.666015625
12 | 1606.0,1667990996737,5.923828125
13 | 1752.0,1667991001599,6.361328125
14 | 1898.0,1667991006212,6.271484375
15 | 2044.0,1667991010965,6.537109375
16 | 2190.0,1667991015780,6.369140625
17 | 2336.0,1667991020556,6.42578125
18 | 2482.0,1667991025266,6.423828125
19 | 2628.0,1667991029884,6.501953125
20 | 2774.0,1667991034668,6.990234375
21 | 2920.0,1667991039341,6.5703125
22 | 3066.0,1667991044152,6.99609375
23 | 3212.0,1667991048785,6.8515625
24 | 3358.0,1667991053333,7.271484375
25 | 3504.0,1667991058045,6.9765625
26 | 3650.0,1667991062715,7.658203125
27 | 3796.0,1667991067373,7.18359375
28 | 3942.0,1667991072033,7.66796875
29 | 4088.0,1667991076771,7.533203125
30 | 4234.0,1667991081420,8.12109375
31 | 4380.0,1667991086149,7.935546875
32 | 4526.0,1667991090892,8.572265625
33 | 4672.0,1667991095723,8.48046875
34 | 4818.0,1667991100498,8.603515625
35 | 4964.0,1667991105240,9.189453125
36 | 5110.0,1667991109971,9.40234375
37 | 5256.0,1667991114783,9.806640625
38 | 5402.0,1667991119575,10.375
39 | 5548.0,1667991124433,9.970703125
40 | 5694.0,1667991129171,10.189453125
41 | 5840.0,1667991133833,9.828125
42 | 5986.0,1667991138583,10.44140625
43 | 6132.0,1667991143252,10.896484375
44 | 6278.0,1667991148014,10.52734375
45 | 6424.0,1667991152701,10.849609375
46 | 6570.0,1667991157413,11.193359375
47 | 6716.0,1667991162146,11.25
48 | 6862.0,1667991166847,11.5625
49 | 7008.0,1667991171562,11.326171875
50 | 7154.0,1667991176224,11.890625
51 | 7300.0,1667991180981,11.90625
52 | 7446.0,1667991185732,11.916015625
53 | 7592.0,1667991190399,12.271484375
54 | 7738.0,1667991195114,11.97265625
55 | 7884.0,1667991199826,12.419921875
56 | 8030.0,1667991204493,12.728515625
57 | 8176.0,1667991209205,13.078125
58 | 8322.0,1667991214038,12.76171875
59 | 8468.0,1667991218773,13.279296875
60 | 8614.0,1667991223385,13.30859375
61 | 8760.0,1667991228138,13.70703125
62 | 8906.0,1667991232843,13.744140625
63 | 9052.0,1667991237629,14.0625
64 | 9198.0,1667991242342,13.998046875
65 | 9344.0,1667991247069,13.9296875
66 | 9490.0,1667991251826,14.623046875
67 | 9636.0,1667991256593,14.494140625
68 | 9782.0,1667991261376,14.705078125
69 | 9928.0,1667991266131,14.1328125
70 | 10074.0,1667991270872,14.677734375
71 | 10220.0,1667991275662,14.765625
72 | 10366.0,1667991280410,15.052734375
73 | 10512.0,1667991285244,15.21875
74 | 10658.0,1667991290031,15.06640625
75 | 10804.0,1667991294825,15.044921875
76 | 10950.0,1667991299609,15.283203125
77 | 11096.0,1667991304553,15.49609375
78 | 11242.0,1667991309331,15.607421875
79 | 11388.0,1667991314096,15.947265625
80 | 11534.0,1667991318770,15.580078125
81 | 11680.0,1667991323407,15.923828125
82 | 11826.0,1667991328133,15.4453125
83 | 11972.0,1667991332868,15.7109375
84 | 12118.0,1667991337699,15.78125
85 | 12264.0,1667991342398,16.42578125
86 | 12410.0,1667991347339,16.203125
87 | 12556.0,1667991352083,15.92578125
88 | 12702.0,1667991356872,15.810546875
89 | 12848.0,1667991361622,16.322265625
90 | 12994.0,1667991366434,16.486328125
91 | 13140.0,1667991371214,16.478515625
92 | 13286.0,1667991375937,15.986328125
93 | 13432.0,1667991380720,16.865234375
94 | 13578.0,1667991385572,16.5546875
95 | 13724.0,1667991390283,16.390625
96 | 13870.0,1667991395102,16.92578125
97 | 14016.0,1667991399841,16.8203125
98 | 14162.0,1667991404631,16.826171875
99 | 14308.0,1667991409331,16.689453125
100 | 14454.0,1667991414071,17.181640625
101 | 14600.0,1667991418798,16.630859375
102 | 14746.0,1667991423585,16.93359375
103 | 14892.0,1667991428374,17.185546875
104 | 15038.0,1667991433264,17.234375
105 | 15184.0,1667991438044,17.3359375
106 | 15330.0,1667991442932,17.123046875
107 | 15476.0,1667991447693,17.09765625
108 | 15622.0,1667991452475,17.470703125
109 | 15768.0,1667991457262,17.697265625
110 | 15914.0,1667991461967,16.7265625
111 | 16060.0,1667991466679,16.845703125
112 | 16206.0,1667991471428,17.474609375
113 | 16352.0,1667991476122,17.671875
114 | 16498.0,1667991480853,17.515625
115 | 16644.0,1667991485580,17.759765625
116 | 16790.0,1667991490366,17.287109375
117 | 16936.0,1667991495134,17.759765625
118 | 17082.0,1667991499978,18.0390625
119 | 17228.0,1667991504833,17.359375
120 | 17374.0,1667991509654,17.771484375
121 | 17520.0,1667991514421,18.08203125
122 | 17666.0,1667991519116,17.939453125
123 | 17812.0,1667991523916,17.93359375
124 | 17958.0,1667991528698,17.90625
125 | 18104.0,1667991533480,17.890625
126 | 18250.0,1667991538217,18.158203125
127 | 18396.0,1667991542963,18.001953125
128 | 18542.0,1667991547667,18.44140625
129 | 18688.0,1667991552512,18.70703125
130 | 18834.0,1667991557345,18.220703125
131 | 18980.0,1667991562149,18.453125
132 | 19126.0,1667991566940,18.51953125
133 | 19272.0,1667991571817,18.578125
134 | 19418.0,1667991576569,18.4375
135 | 19564.0,1667991581447,18.845703125
136 | 19710.0,1667991586289,18.8046875
137 | 19856.0,1667991591062,18.31640625
138 | 20002.0,1667991595848,18.169921875
139 | 20148.0,1667991600880,18.64453125
140 | 20294.0,1667991605803,18.5
141 | 20440.0,1667991610744,18.392578125
142 | 20586.0,1667991615568,18.87109375
143 | 20732.0,1667991620384,18.869140625
144 | 20878.0,1667991625163,18.3046875
145 | 21024.0,1667991630036,18.70703125
146 | 21170.0,1667991634834,18.951171875
147 | 21316.0,1667991639600,18.755859375
148 | 21462.0,1667991644588,18.5390625
149 | 21608.0,1667991649342,18.91015625
150 | 21754.0,1667991654198,18.625
151 | 21900.0,1667991659022,19.29296875
152 | 22046.0,1667991663911,18.865234375
153 | 22192.0,1667991668695,18.947265625
154 | 22338.0,1667991673446,18.875
155 | 22484.0,1667991678281,18.80859375
156 | 22630.0,1667991683035,18.8203125
157 | 22776.0,1667991687862,18.505859375
158 | 22922.0,1667991692626,18.64453125
159 | 23068.0,1667991697456,18.912109375
160 | 23214.0,1667991702198,19.240234375
161 | 23360.0,1667991706974,18.98828125
162 | 23506.0,1667991711699,18.845703125
163 | 23652.0,1667991716493,18.904296875
164 | 23798.0,1667991721318,19.40234375
165 | 23944.0,1667991726112,19.21875
166 | 24090.0,1667991730873,19.494140625
167 | 24236.0,1667991735641,18.853515625
168 | 24382.0,1667991740418,19.58203125
169 | 24528.0,1667991745359,19.25390625
170 | 24674.0,1667991750079,19.5546875
171 | 24820.0,1667991754934,19.34765625
172 | 24966.0,1667991759776,19.4296875
173 | 25112.0,1667991764663,19.888671875
174 | 25258.0,1667991769463,19.423828125
175 | 25404.0,1667991774231,19.443359375
176 | 25550.0,1667991778976,19.234375
177 | 25696.0,1667991783714,19.61328125
178 | 25842.0,1667991788490,19.671875
179 | 25988.0,1667991793263,19.51171875
180 | 26134.0,1667991798020,19.921875
181 | 26280.0,1667991802799,19.724609375
182 | 26426.0,1667991807516,19.607421875
183 | 26572.0,1667991812315,19.85546875
184 | 26718.0,1667991817102,19.48046875
185 | 26864.0,1667991821846,19.583984375
186 | 27010.0,1667991826575,20.162109375
187 | 27156.0,1667991831361,19.720703125
188 | 27302.0,1667991836128,20.16015625
189 | 27448.0,1667991840871,20.2109375
190 | 27594.0,1667991845610,20.15625
191 | 27740.0,1667991850416,20.48828125
192 | 27886.0,1667991855193,20.53125
193 | 28032.0,1667991859974,20.6640625
194 | 28178.0,1667991864791,20.453125
195 | 28324.0,1667991869531,20.693359375
196 | 28470.0,1667991874321,20.392578125
197 | 28616.0,1667991879030,20.24609375
198 | 28762.0,1667991883687,21.11328125
199 | 28908.0,1667991888340,20.8515625
200 | 29054.0,1667991893075,20.951171875
201 | 29200.0,1667991897900,20.712890625
202 |
--------------------------------------------------------------------------------
/snake/data/mgrl_return/MET-2194__eval_episode_reward_stochastic_policy.csv:
--------------------------------------------------------------------------------
1 | 3.0,1667994808583,0.14453125
2 | 146.0,1667994820732,0.150390625
3 | 292.0,1667994825389,0.302734375
4 | 438.0,1667994830210,2.18359375
5 | 584.0,1667994834899,3.470703125
6 | 730.0,1667994839442,3.947265625
7 | 876.0,1667994844002,4.34765625
8 | 1022.0,1667994848501,4.857421875
9 | 1168.0,1667994853026,5.02734375
10 | 1314.0,1667994857578,5.376953125
11 | 1460.0,1667994862142,5.427734375
12 | 1606.0,1667994866708,5.953125
13 | 1752.0,1667994871247,5.830078125
14 | 1898.0,1667994875855,6.421875
15 | 2044.0,1667994880478,6.12890625
16 | 2190.0,1667994884998,6.6171875
17 | 2336.0,1667994889598,6.470703125
18 | 2482.0,1667994894185,6.49609375
19 | 2628.0,1667994898789,6.759765625
20 | 2774.0,1667994903484,6.6953125
21 | 2920.0,1667994908027,6.931640625
22 | 3066.0,1667994912699,6.7734375
23 | 3212.0,1667994917287,6.951171875
24 | 3358.0,1667994921867,6.73046875
25 | 3504.0,1667994926409,7.064453125
26 | 3650.0,1667994931051,7.1484375
27 | 3796.0,1667994935605,7.65625
28 | 3942.0,1667994940155,7.44921875
29 | 4088.0,1667994944649,7.28125
30 | 4234.0,1667994949259,7.166015625
31 | 4380.0,1667994953758,7.685546875
32 | 4526.0,1667994958333,7.607421875
33 | 4672.0,1667994962874,8.34765625
34 | 4818.0,1667994967449,8.4453125
35 | 4964.0,1667994972072,8.509765625
36 | 5110.0,1667994976735,8.7890625
37 | 5256.0,1667994981287,8.990234375
38 | 5402.0,1667994986004,9.361328125
39 | 5548.0,1667994990537,9.1875
40 | 5694.0,1667994995321,9.34765625
41 | 5840.0,1667994999783,9.8828125
42 | 5986.0,1667995004410,9.734375
43 | 6132.0,1667995008975,10.173828125
44 | 6278.0,1667995013582,10.51953125
45 | 6424.0,1667995018223,10.58203125
46 | 6570.0,1667995022808,10.67578125
47 | 6716.0,1667995027392,10.90234375
48 | 6862.0,1667995032014,11.21484375
49 | 7008.0,1667995036534,11.455078125
50 | 7154.0,1667995041153,11.66796875
51 | 7300.0,1667995046018,11.7421875
52 | 7446.0,1667995050607,11.865234375
53 | 7592.0,1667995055266,11.87109375
54 | 7738.0,1667995059875,12.90625
55 | 7884.0,1667995064422,12.857421875
56 | 8030.0,1667995069057,12.43359375
57 | 8176.0,1667995073595,13.041015625
58 | 8322.0,1667995078140,12.9609375
59 | 8468.0,1667995082727,13.52734375
60 | 8614.0,1667995087354,13.013671875
61 | 8760.0,1667995091935,13.8671875
62 | 8906.0,1667995096597,14.09765625
63 | 9052.0,1667995101178,14.234375
64 | 9198.0,1667995105894,14.79296875
65 | 9344.0,1667995110459,14.681640625
66 | 9490.0,1667995115098,14.140625
67 | 9636.0,1667995119742,15.1484375
68 | 9782.0,1667995124307,15.029296875
69 | 9928.0,1667995128966,15.4609375
70 | 10074.0,1667995133552,15.04296875
71 | 10220.0,1667995138192,15.287109375
72 | 10366.0,1667995142819,15.2890625
73 | 10512.0,1667995147453,15.56640625
74 | 10658.0,1667995151992,15.841796875
75 | 10804.0,1667995156603,15.7265625
76 | 10950.0,1667995161173,15.96875
77 | 11096.0,1667995165774,15.666015625
78 | 11242.0,1667995170411,16.15234375
79 | 11388.0,1667995175034,16.220703125
80 | 11534.0,1667995179664,15.693359375
81 | 11680.0,1667995184208,15.755859375
82 | 11826.0,1667995188836,15.849609375
83 | 11972.0,1667995193411,16.2890625
84 | 12118.0,1667995197964,15.87109375
85 | 12264.0,1667995202965,16.388671875
86 | 12410.0,1667995207637,16.205078125
87 | 12556.0,1667995212352,16.326171875
88 | 12702.0,1667995216930,16.37890625
89 | 12848.0,1667995221510,16.822265625
90 | 12994.0,1667995226120,16.48828125
91 | 13140.0,1667995230762,15.78125
92 | 13286.0,1667995235455,16.77734375
93 | 13432.0,1667995240079,17.0
94 | 13578.0,1667995244599,16.650390625
95 | 13724.0,1667995249332,16.953125
96 | 13870.0,1667995253854,17.123046875
97 | 14016.0,1667995258470,16.673828125
98 | 14162.0,1667995263117,17.328125
99 | 14308.0,1667995267839,17.111328125
100 | 14454.0,1667995272407,17.916015625
101 | 14600.0,1667995277022,17.333984375
102 | 14746.0,1667995281712,17.625
103 | 14892.0,1667995286490,17.6953125
104 | 15038.0,1667995291249,17.47265625
105 | 15184.0,1667995296017,18.083984375
106 | 15330.0,1667995300639,17.287109375
107 | 15476.0,1667995305343,17.9296875
108 | 15622.0,1667995309914,17.29296875
109 | 15768.0,1667995314820,17.646484375
110 | 15914.0,1667995319458,17.8203125
111 | 16060.0,1667995324154,17.931640625
112 | 16206.0,1667995328786,17.521484375
113 | 16352.0,1667995333387,18.01953125
114 | 16498.0,1667995338089,18.126953125
115 | 16644.0,1667995342742,17.509765625
116 | 16790.0,1667995347353,17.64453125
117 | 16936.0,1667995351964,18.166015625
118 | 17082.0,1667995356618,17.48828125
119 | 17228.0,1667995361183,17.9453125
120 | 17374.0,1667995365758,18.189453125
121 | 17520.0,1667995370432,17.77734375
122 | 17666.0,1667995374936,18.22265625
123 | 17812.0,1667995379710,17.84765625
124 | 17958.0,1667995384314,18.435546875
125 | 18104.0,1667995389019,17.84765625
126 | 18250.0,1667995393656,18.71484375
127 | 18396.0,1667995398405,18.6875
128 | 18542.0,1667995402898,18.646484375
129 | 18688.0,1667995407610,18.90625
130 | 18834.0,1667995412158,18.720703125
131 | 18980.0,1667995416868,19.25
132 | 19126.0,1667995421547,18.716796875
133 | 19272.0,1667995426235,19.15625
134 | 19418.0,1667995430822,19.349609375
135 | 19564.0,1667995435537,18.59765625
136 | 19710.0,1667995440203,19.478515625
137 | 19856.0,1667995444915,19.84765625
138 | 20002.0,1667995449530,19.03515625
139 | 20148.0,1667995454159,19.533203125
140 | 20294.0,1667995458744,19.044921875
141 | 20440.0,1667995463423,19.830078125
142 | 20586.0,1667995468000,19.46875
143 | 20732.0,1667995472602,20.056640625
144 | 20878.0,1667995477213,19.70703125
145 | 21024.0,1667995481849,19.2265625
146 | 21170.0,1667995486567,19.734375
147 | 21316.0,1667995491295,19.56640625
148 | 21462.0,1667995495967,19.7734375
149 | 21608.0,1667995500636,19.705078125
150 | 21754.0,1667995505261,19.97265625
151 | 21900.0,1667995509963,20.078125
152 | 22046.0,1667995514559,19.912109375
153 | 22192.0,1667995519341,20.06640625
154 | 22338.0,1667995524086,19.703125
155 | 22484.0,1667995528676,19.763671875
156 | 22630.0,1667995533240,20.02734375
157 | 22776.0,1667995537932,20.322265625
158 | 22922.0,1667995542657,21.021484375
159 | 23068.0,1667995547380,20.265625
160 | 23214.0,1667995552062,20.8828125
161 | 23360.0,1667995556843,20.6640625
162 | 23506.0,1667995561419,20.44140625
163 | 23652.0,1667995566063,20.302734375
164 | 23798.0,1667995570685,20.529296875
165 | 23944.0,1667995575210,20.388671875
166 | 24090.0,1667995579931,20.1484375
167 | 24236.0,1667995584488,20.498046875
168 | 24382.0,1667995589118,20.29296875
169 | 24528.0,1667995593779,20.76171875
170 | 24674.0,1667995598459,20.83203125
171 | 24820.0,1667995603161,20.748046875
172 | 24966.0,1667995607887,20.955078125
173 | 25112.0,1667995612613,21.248046875
174 | 25258.0,1667995617384,20.912109375
175 | 25404.0,1667995622055,21.388671875
176 | 25550.0,1667995626741,21.650390625
177 | 25696.0,1667995631324,21.18359375
178 | 25842.0,1667995635950,21.33984375
179 | 25988.0,1667995640567,21.564453125
180 | 26134.0,1667995645224,21.91015625
181 | 26280.0,1667995649868,21.7578125
182 | 26426.0,1667995654537,21.501953125
183 | 26572.0,1667995659204,22.232421875
184 | 26718.0,1667995663877,22.193359375
185 | 26864.0,1667995668570,21.900390625
186 | 27010.0,1667995673218,22.49609375
187 | 27156.0,1667995677887,22.32421875
188 | 27302.0,1667995682651,22.583984375
189 | 27448.0,1667995687232,22.490234375
190 | 27594.0,1667995691810,22.90625
191 | 27740.0,1667995696507,22.841796875
192 | 27886.0,1667995701109,22.845703125
193 | 28032.0,1667995705835,21.994140625
194 | 28178.0,1667995710488,22.865234375
195 | 28324.0,1667995715128,23.298828125
196 | 28470.0,1667995719694,23.322265625
197 | 28616.0,1667995724292,23.078125
198 | 28762.0,1667995728951,23.833984375
199 | 28908.0,1667995733568,23.427734375
200 | 29054.0,1667995738181,23.6015625
201 | 29200.0,1667995742754,23.5625
202 |
--------------------------------------------------------------------------------
/snake/env/__init__.py:
--------------------------------------------------------------------------------
1 | from snake.env.snake import make_snake_env
2 |
--------------------------------------------------------------------------------
/snake/env/snake.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import jumanji
4 | import jumanji.wrappers
5 |
6 |
def make_snake_env() -> Tuple[jumanji.Environment, jumanji.Environment]:
    """Build the Snake environments used for training and for evaluation.

    Returns:
        A pair ``(train_env, eval_env)``:
        - ``train_env``: the "Snake-6x6-v0" environment wrapped first in
          ``AutoResetWrapper`` and then in ``VmapWrapper``.
        - ``eval_env``: the same base environment with no wrappers.
    """
    eval_env = jumanji.make("Snake-6x6-v0")
    # Wrapper order matters: auto-reset is applied first, vmap on top of it.
    auto_reset_env = jumanji.wrappers.AutoResetWrapper(eval_env)
    train_env = jumanji.wrappers.VmapWrapper(auto_reset_env)
    return train_env, eval_env
13 |
--------------------------------------------------------------------------------
/snake/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from snake.networks.snake import make_actor_critic_networks_snake
2 |
--------------------------------------------------------------------------------
/snake/networks/actor_critic.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple, Optional
2 |
3 | import chex
4 | import haiku as hk
5 | from jax import numpy as jnp
6 |
7 | from snake.networks.distribution import ParametricDistribution
8 |
9 |
class FeedForwardNetwork(NamedTuple):
    """Pair of pure ``init``/``apply`` functions describing one network.

    In this package instances are built from
    ``hk.without_apply_rng(hk.transform(network_fn))``, so ``apply`` takes no
    RNG key.
    """

    # init(rng_key, observation, discount_factor) -> params.
    # discount_factor may be None when the network does not embed it.
    init: Callable[
        [chex.PRNGKey, chex.Array, Optional[jnp.float_]],
        hk.Params,
    ]
    # apply(params, observation, discount_factor) -> network output.
    apply: Callable[
        [hk.Params, chex.Array, Optional[jnp.float_]],
        chex.Array,
    ]
19 |
20 |
class ActorCriticNetworks(NamedTuple):
    """Defines the actor-critic networks, which output the logits of a policy and a
    value given an observation.
    """

    # Produces the parameters (logits) consumed by parametric_action_distribution.
    policy_network: FeedForwardNetwork
    # Critic with a scalar output per observation.
    value_network: FeedForwardNetwork
    # Optional second critic; None when the networks are built without an outer critic.
    outer_value_network: Optional[FeedForwardNetwork]
    # Turns policy-network outputs into an action distribution (sample/mode/log_prob/...).
    parametric_action_distribution: ParametricDistribution
30 |
--------------------------------------------------------------------------------
/snake/networks/cnn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence
2 |
3 | import chex
4 | import haiku as hk
5 | import jax
6 | from jax import numpy as jnp
7 |
8 | from snake.networks.actor_critic import ActorCriticNetworks, FeedForwardNetwork
9 | from snake.networks.distribution import CategoricalParametricDistribution
10 |
11 |
def make_actor_critic_networks_cnn(
    num_actions: int,
    num_channels: int,
    policy_layers: Sequence[int],
    value_layers: Sequence[int],
    outer_critic: bool,
    embedding_size_actor: Optional[int],
    embedding_size_critic: Optional[int],
) -> ActorCriticNetworks:
    """Make actor-critic networks for Snake.

    Args:
        num_actions: number of discrete actions; width of the policy output and
            size of the categorical action distribution.
        num_channels: channels of the convolutional torso of every network.
        policy_layers: hidden sizes of the policy MLP head.
        value_layers: hidden sizes of every critic's MLP head.
        outer_critic: when truthy, also build a second critic with the same
            architecture as the main one; otherwise it is ``None``.
        embedding_size_actor: discount-factor embedding size for the policy.
        embedding_size_critic: discount-factor embedding size for the critic(s).

    Returns:
        The assembled ``ActorCriticNetworks``.
    """

    def build_critic() -> FeedForwardNetwork:
        # Main and outer critics share exactly the same architecture.
        return make_network_cnn(
            num_outputs=1,
            mlp_units=value_layers,
            conv_n_channels=num_channels,
            embedding_size=embedding_size_critic,
        )

    actor = make_network_cnn(
        num_outputs=num_actions,
        mlp_units=policy_layers,
        conv_n_channels=num_channels,
        embedding_size=embedding_size_actor,
    )
    critic = build_critic()
    extra_critic: Optional[FeedForwardNetwork] = (
        build_critic() if outer_critic else None
    )
    return ActorCriticNetworks(
        policy_network=actor,
        value_network=critic,
        outer_value_network=extra_critic,
        parametric_action_distribution=CategoricalParametricDistribution(
            num_actions=num_actions
        ),
    )
53 |
54 |
def make_network_cnn(
    num_outputs: int,
    mlp_units: Sequence[int],
    conv_n_channels: int,
    embedding_size: Optional[int],
) -> FeedForwardNetwork:
    """Build a conv-torso + MLP-head network as a ``FeedForwardNetwork``.

    Args:
        num_outputs: width of the final MLP layer. When it is 1 the trailing
            size-1 axis is squeezed away from the output.
        mlp_units: hidden-layer sizes of the MLP head; a final layer of size
            ``num_outputs`` (no activation) is appended after them.
        conv_n_channels: output channels of both convolutional layers.
        embedding_size: if truthy, ``discount_factor`` is passed through a
            linear layer of this size and concatenated to the torso features;
            ``discount_factor`` must then be non-None (asserted).

    Returns:
        ``FeedForwardNetwork`` holding ``init``/``apply`` from
        ``hk.without_apply_rng(hk.transform(network_fn))``.
    """

    def network_fn(
        observation: chex.Array,
        discount_factor: Optional[jnp.float_],
    ) -> chex.Array:
        # NOTE(review): Haiku derives default parameter names from the order in
        # which modules are constructed; reordering the constructions below would
        # presumably rename parameters and break existing checkpoints.
        torso = hk.Sequential(
            [
                hk.Conv2D(conv_n_channels, (2, 2), 2),  # 2x2 kernel, stride 2
                jax.nn.relu,
                hk.Conv2D(conv_n_channels, (2, 2), 1),  # 2x2 kernel, stride 1
                jax.nn.relu,
                hk.Flatten(),
            ]
        )
        # A 5-D observation carries one extra leading axis, which is mapped over
        # so the torso still receives 4-D inputs.
        # NOTE(review): presumably (time, batch, H, W, C) vs (batch, H, W, C) —
        # confirm. Haiku recommends hk.vmap over jax.vmap inside a transformed
        # function; this is only safe if params are never created under this vmap
        # (e.g. init is always called with 4-D observations) — verify.
        if observation.ndim == 5:
            torso = jax.vmap(torso)
        x = torso(observation)
        if embedding_size:
            assert discount_factor is not None
            # One discount-factor value per feature row, so it can be embedded
            # and concatenated along the last (feature) axis.
            discount_factor = jnp.broadcast_to(discount_factor, (*x.shape[:-1], 1))
            embedding = hk.Linear(embedding_size)(discount_factor)
            x = jnp.concatenate([x, embedding], axis=-1)
        # activate_final=False: the output layer has no activation.
        head = hk.nets.MLP((*mlp_units, num_outputs), activate_final=False)
        if num_outputs == 1:
            # Scalar head: drop the trailing size-1 axis.
            return jnp.squeeze(head(x), axis=-1)
        else:
            return head(x)

    # without_apply_rng: apply(params, observation, discount_factor) needs no key.
    init, apply = hk.without_apply_rng(hk.transform(network_fn))
    return FeedForwardNetwork(init=init, apply=apply)
90 |
--------------------------------------------------------------------------------
/snake/networks/distribution.py:
--------------------------------------------------------------------------------
1 | """Copied from Brax and adapted with typing."""
2 | import abc
3 | from typing import Any
4 |
5 | import chex
6 | import jax
7 | from jax import numpy as jnp
8 |
9 |
class Postprocessor(abc.ABC):
    """Interface for a bijector applied to samples of a parametric distribution.

    The base implementations only raise ``NotImplementedError``; concrete
    subclasses (e.g. ``IdentityBijector``) override all three methods.
    """

    def forward(self, x: chex.Array) -> chex.Array:
        """Map a raw sample ``x`` into the postprocessed space."""
        raise NotImplementedError()

    def inverse(self, y: chex.Array) -> chex.Array:
        """Map a postprocessed value ``y`` back into the raw space."""
        raise NotImplementedError()

    def forward_log_det_jacobian(self, x: chex.Array) -> chex.Array:
        """Log-determinant of the Jacobian of ``forward`` evaluated at ``x``."""
        raise NotImplementedError()
19 |
20 |
class ParametricDistribution(abc.ABC):
    """Abstract class for parametric (action) distribution.

    Subclasses only implement ``create_dist``; every other method is expressed
    in terms of the distribution object it returns and of the postprocessor.
    """

    def __init__(
        self,
        param_size: int,
        postprocessor: Postprocessor,
        event_ndims: int,
        reparametrizable: bool,
    ):
        """Abstract class for parametric (action) distribution.

        Specifies how to transform distribution parameters (i.e. actor output)
        into a distribution over actions.

        Args:
            param_size: size of the parameters for the distribution
            postprocessor: bijector applied after sampling (identity here; tanh
                is the other common choice)
            event_ndims: rank of the distribution sample (i.e. action); must be
                0 or 1
            reparametrizable: is the distribution reparametrizable
        """
        self._param_size = param_size
        self._postprocessor = postprocessor
        self._event_ndims = event_ndims  # rank of events
        self._reparametrizable = reparametrizable
        # log_prob / entropy only know how to reduce scalar or vector events.
        assert event_ndims in [0, 1]

    @abc.abstractmethod
    def create_dist(self, parameters: chex.Array) -> Any:
        """Creates distribution from parameters."""

    @property
    def param_size(self) -> int:
        """Number of distribution parameters expected from the network."""
        return self._param_size

    @property
    def reparametrizable(self) -> bool:
        """Whether this distribution was declared reparametrizable."""
        return self._reparametrizable

    def postprocess(self, event: chex.Array) -> chex.Array:
        """Apply the postprocessor's forward map."""
        return self._postprocessor.forward(event)

    def inverse_postprocess(self, event: chex.Array) -> chex.Array:
        """Apply the postprocessor's inverse map."""
        return self._postprocessor.inverse(event)

    def sample_no_postprocessing(
        self, parameters: chex.Array, seed: chex.PRNGKey
    ) -> Any:
        """Sample from the raw distribution, skipping the postprocessor."""
        raw_dist = self.create_dist(parameters)
        return raw_dist.sample(seed=seed)

    def sample(self, parameters: chex.Array, seed: chex.PRNGKey) -> chex.Array:
        """Returns a sample from the postprocessed distribution."""
        raw_sample = self.sample_no_postprocessing(parameters, seed)
        return self.postprocess(raw_sample)

    def mode(self, parameters: chex.Array) -> chex.Array:
        """Returns the mode of the postprocessed distribution."""
        raw_mode = self.create_dist(parameters).mode()
        return self.postprocess(raw_mode)

    def log_prob(self, parameters: chex.Array, raw_actions: chex.Array) -> chex.Array:
        """Compute the log probability of actions.

        ``raw_actions`` are un-postprocessed samples. The postprocessor's
        log-det-Jacobian is subtracted (change of variables), and for vector
        events the result is summed over the last (action) axis.
        """
        dist = self.create_dist(parameters)
        log_probs = dist.log_prob(
            raw_actions
        ) - self._postprocessor.forward_log_det_jacobian(raw_actions)
        if self._event_ndims != 1:
            return log_probs
        return jnp.sum(log_probs, axis=-1)  # sum over action dimension

    def entropy(self, parameters: chex.Array, seed: chex.PRNGKey) -> chex.Array:
        """Return the entropy of the given distribution.

        The postprocessor's log-det-Jacobian is evaluated at a single sample
        drawn with ``seed`` and added to the raw entropy.
        """
        dist = self.create_dist(parameters)
        entropy = dist.entropy() + self._postprocessor.forward_log_det_jacobian(
            dist.sample(seed=seed)
        )
        return jnp.sum(entropy, axis=-1) if self._event_ndims == 1 else entropy

    def kl_divergence(
        self, parameters: chex.Array, other_parameters: chex.Array
    ) -> chex.Array:
        """KL divergence is invariant with respect to transformation by the same bijector."""
        return self.create_dist(parameters).kl_divergence(
            self.create_dist(other_parameters)
        )
103 |
104 |
class IdentityBijector(Postprocessor):
    """Identity Bijector.

    Leaves values unchanged in both directions and contributes a zero
    log-det-Jacobian correction.
    """

    def forward(self, x: chex.Array) -> chex.Array:
        """Return ``x`` unchanged."""
        return x

    def inverse(self, y: chex.Array) -> chex.Array:
        """Return ``y`` unchanged."""
        return y

    def forward_log_det_jacobian(self, x: chex.Array) -> chex.Array:
        """Zeros with the shape and dtype of ``x``."""
        return jnp.zeros_like(x, dtype=x.dtype)
116 |
117 |
class CategoricalDistribution:
    """Categorical distribution parameterised by unnormalised logits.

    The last axis of ``logits`` indexes the categories (actions).
    """

    def __init__(self, logits: chex.Array):
        self.logits = logits
        self.num_actions = jnp.shape(logits)[-1]

    def sample(self, seed: chex.PRNGKey) -> chex.Array:
        """Draw category indices from the logits using ``seed``."""
        return jax.random.categorical(seed, self.logits)

    def mode(self) -> chex.Array:
        """Index of the largest logit along the last axis."""
        return jnp.argmax(self.logits, axis=-1)

    def log_prob(self, x: chex.Array) -> chex.Array:
        """Log-probability of category ``x``; ``-inf`` when ``x`` is out of range."""
        hot = jax.nn.one_hot(x, self.num_actions)
        outside = jnp.logical_or(x < 0, x > self.num_actions - 1)
        zero = jnp.zeros((), dtype=self.logits.dtype)
        # Entries for categories other than x are forced to exactly 0, so a
        # -inf log-probability times a 0 one-hot weight (NaN) cannot reach the sum.
        picked = jnp.where(hot == 0, zero, jax.nn.log_softmax(self.logits) * hot)
        in_range_log_prob = jnp.sum(picked, axis=-1)
        return jnp.where(outside, -jnp.inf, in_range_log_prob)

    def entropy(self) -> chex.Array:
        """Shannon entropy over the last axis; zero-probability terms count as 0."""
        log_p = jax.nn.log_softmax(self.logits)
        p = jnp.exp(log_p)
        terms = jnp.where(p == 0, 0.0, p * log_p)
        return -jnp.sum(terms, axis=-1)

    def kl_divergence(self, other: "CategoricalDistribution") -> chex.Array:
        """KL(self || other) over the last axis; zero-probability terms count as 0."""
        log_p = jax.nn.log_softmax(self.logits)
        p = jnp.exp(log_p)
        log_q = jax.nn.log_softmax(other.logits)
        terms = jnp.where(p == 0, 0.0, p * (log_p - log_q))
        return jnp.sum(terms, axis=-1)
157 |
158 |
class CategoricalParametricDistribution(ParametricDistribution):
    """Categorical distribution for discrete action spaces.

    The parameters are ``num_actions`` logits; samples are not transformed
    (identity postprocessor) and each event is a scalar (``event_ndims=0``).
    """

    def __init__(self, num_actions: int):
        """Initialize the distribution.

        Args:
            num_actions: the number of actions.
        """
        identity: Postprocessor = IdentityBijector()
        super().__init__(
            param_size=num_actions,
            postprocessor=identity,
            event_ndims=0,
            reparametrizable=True,
        )

    def create_dist(self, parameters: chex.Array) -> CategoricalDistribution:
        """Interpret ``parameters`` as the logits of a ``CategoricalDistribution``."""
        return CategoricalDistribution(logits=parameters)
178 |
--------------------------------------------------------------------------------
/snake/networks/snake.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence
2 |
3 | from snake.networks import cnn
4 | from snake.networks.actor_critic import ActorCriticNetworks
5 |
6 |
def make_actor_critic_networks_snake(
    num_channels: int,
    policy_layers: Sequence[int],
    value_layers: Sequence[int],
    outer_critic: bool,
    embedding_size_actor: Optional[int],
    embedding_size_critic: Optional[int],
) -> ActorCriticNetworks:
    """Make actor-critic networks for Snake.

    Thin wrapper around ``cnn.make_actor_critic_networks_cnn``. It fixes the
    number of actions to 4 and forwards every other argument unchanged.
    """
    snake_num_actions = 4
    shared_kwargs = dict(
        num_channels=num_channels,
        policy_layers=policy_layers,
        value_layers=value_layers,
        outer_critic=outer_critic,
        embedding_size_actor=embedding_size_actor,
        embedding_size_critic=embedding_size_critic,
    )
    return cnn.make_actor_critic_networks_cnn(
        num_actions=snake_num_actions, **shared_kwargs
    )
26 |
--------------------------------------------------------------------------------
/snake/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/instadeepai/outer-value-function-meta-rl/eda5463f51daa1adeb6cc8e8621d6b5bb4da26d3/snake/training/__init__.py
--------------------------------------------------------------------------------
/snake/training/config.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple, Optional, Sequence
2 |
3 | import omegaconf
4 |
5 |
class Config(NamedTuple):
    """Flat, typed view of the experiment configuration.

    Instances are built from an ``omegaconf.DictConfig`` by ``convert_config``,
    which reads the ``training``, ``agent`` and ``network`` config groups.
    Field order is part of the positional interface of this NamedTuple.
    """

    # Run identification and agent selection.
    name: str
    agent: str
    meta_objective: Optional[str]
    outer_critic: bool
    # Training budget, batching and evaluation settings.
    num_timesteps: int
    num_eval_points: int
    total_batch_size: int
    total_num_envs: int
    n_steps: int
    total_num_eval: int
    deterministic_eval: bool
    # Network architecture. embedding_size_* is the size of the discount-factor
    # embedding in each network; a falsy value disables the embedding.
    network_type: str
    num_channels: int
    policy_layers: Sequence[int]
    value_layers: Sequence[int]
    embedding_size_actor: Optional[int]
    embedding_size_critic: Optional[int]
    # Optimisation settings.
    optimizer: str
    learning_rate: float
    gradient_clip_norm: Optional[float]
    bootstrap_l_optimizer: Optional[str]
    bootstrap_l_learning_rate: Optional[float]
    meta_optimizer: Optional[str]
    meta_learning_rate: Optional[float]
    meta_gradient_clip_norm: Optional[float]
    normalize_advantage: bool
    normalize_outer_advantage: Optional[bool]
    reward_scaling: float
    # Hyperparameters grouped as `<x>_init` / `<x>_outer` / `<x>_metal`.
    # NOTE(review): presumably the initial (inner) value, the value used by the
    # outer/meta objective, and whether the hyperparameter is meta-learned —
    # confirm against the agent implementations.
    gamma_init: float
    gamma_outer: Optional[float]
    gamma_metal: Optional[bool]
    lambda_init: float
    lambda_outer: Optional[float]
    lambda_metal: Optional[bool]
    l_pg_init: float
    l_pg_outer: Optional[float]
    l_pg_metal: Optional[bool]
    l_td_init: float
    l_td_outer: Optional[float]
    l_td_metal: Optional[bool]
    l_en_init: float
    l_en_outer: Optional[float]
    l_en_metal: Optional[bool]
    l_kl_outer: Optional[float]
    # Settings read from the agent group, prefixed "bootstrap_"
    # (NOTE(review): presumably only used by the bootstrap agent — confirm).
    bootstrap_l: Optional[int]
    bootstrap_pm: Optional[float]
    bootstrap_vm: Optional[float]
    seed: int
55 |
56 |
def _optional_float(value: Optional[float]) -> Optional[float]:
    """Coerce a config value to ``float``, passing ``None`` through unchanged.

    Numeric config values may arrive as ints or numeric strings depending on
    how they were written/loaded; ``float()`` normalises them.
    """
    return None if value is None else float(value)


def convert_config(cfg: omegaconf.DictConfig) -> Config:
    """Flatten an OmegaConf config into a typed ``Config``.

    Reads the ``training``, ``agent`` and ``network`` groups of ``cfg``. Every
    field declared as ``float`` on ``Config`` goes through ``float()``. Every
    field declared as ``Optional[float]`` goes through ``float()`` unless it is
    ``None``. So integer or string-encoded numbers do not leak into the config
    with the wrong type.

    Args:
        cfg: the composed configuration.

    Returns:
        The corresponding ``Config``.
    """
    training = cfg.training
    agent = cfg.agent
    network = cfg.network
    return Config(
        name=training.name,
        agent=agent.agent,
        meta_objective=agent.meta_objective,
        num_timesteps=training.num_timesteps,
        num_eval_points=training.num_eval_points,
        reward_scaling=float(training.reward_scaling),
        n_steps=training.n_steps,
        total_batch_size=training.total_batch_size,
        total_num_envs=training.total_num_envs,
        optimizer=training.optimizer,
        learning_rate=float(training.learning_rate),
        bootstrap_l_optimizer=agent.bootstrap_l_optimizer,
        bootstrap_l_learning_rate=_optional_float(agent.bootstrap_l_learning_rate),
        meta_optimizer=agent.meta_optimizer,
        meta_learning_rate=_optional_float(agent.meta_learning_rate),
        gradient_clip_norm=_optional_float(training.gradient_clip_norm),
        meta_gradient_clip_norm=_optional_float(agent.meta_gradient_clip_norm),
        gamma_init=float(training.gamma_init),
        gamma_outer=_optional_float(agent.gamma_outer),
        gamma_metal=agent.gamma_metal,
        lambda_init=float(training.lambda_init),
        lambda_outer=_optional_float(agent.lambda_outer),
        lambda_metal=agent.lambda_metal,
        l_pg_init=float(training.l_pg_init),
        l_pg_outer=_optional_float(agent.l_pg_outer),
        l_pg_metal=agent.l_pg_metal,
        l_td_init=float(training.l_td_init),
        l_td_outer=_optional_float(agent.l_td_outer),
        l_td_metal=agent.l_td_metal,
        l_en_init=float(training.l_en_init),
        l_en_outer=_optional_float(agent.l_en_outer),
        l_en_metal=agent.l_en_metal,
        l_kl_outer=_optional_float(agent.l_kl_outer),
        bootstrap_l=agent.bootstrap_l,
        bootstrap_pm=_optional_float(agent.bootstrap_pm),
        bootstrap_vm=_optional_float(agent.bootstrap_vm),
        normalize_advantage=agent.normalize_advantage,
        normalize_outer_advantage=agent.normalize_outer_advantage,
        total_num_eval=agent.total_num_eval,
        deterministic_eval=agent.deterministic_eval,
        policy_layers=network.policy_layers,
        value_layers=network.value_layers,
        outer_critic=agent.outer_critic,
        seed=training.seed,
        embedding_size_actor=network.embedding_size_actor,
        embedding_size_critic=network.embedding_size_critic,
        network_type=network.type,
        num_channels=network.num_channels,
    )
120 |
--------------------------------------------------------------------------------
/snake/training/evaluator.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from typing import Callable, Dict, Optional, Tuple
3 |
4 | import chex
5 | import haiku as hk
6 | import jax
7 | import jumanji.types
8 | from jax import numpy as jnp
9 |
10 | from snake.agent import ActorCriticAgent
11 | from snake.training.types import TrainingState
12 |
13 |
class Evaluator:
    """Class to run evaluations.

    Episodes are played on ``eval_env`` inside ``jax.lax.while_loop`` and spread
    across the local devices with ``jax.pmap``. Every evaluated episode index is
    played twice: once with a stochastic policy and once with a deterministic
    one. Both come from ``actor_critic_agent.make_policy``. The evaluation
    reports the mean episode reward and length of each variant.
    """

    def __init__(
        self,
        eval_env: jumanji.Environment,
        actor_critic_agent: ActorCriticAgent,
        total_num_eval: int,
        key: chex.PRNGKey,
        deterministic: bool,
    ):
        """Set up the evaluator and its pmapped unroll function.

        Args:
            eval_env: environment used for evaluation; an episode runs until
                ``timestep.last()`` is true.
            actor_critic_agent: agent whose ``make_policy`` builds the policies.
            total_num_eval: number of episodes per evaluation, summed over local
                devices; must be divisible by ``jax.local_device_count()``.
            key: PRNG key, split on every ``run_evaluation`` call.
            deterministic: stored as ``self._deterministic`` but not read by
                this class; both policy variants are always evaluated.
        """
        self._eval_env = eval_env
        self._actor_critic_agent = actor_critic_agent

        num_devices = jax.local_device_count()
        self._num_devices = num_devices
        assert total_num_eval % num_devices == 0
        self._total_num_eval = total_num_eval
        self._num_eval_per_device = total_num_eval // num_devices

        self._key = key
        self._deterministic = deterministic
        self.generate_eval_unroll = jax.pmap(
            self._generate_eval_unroll, axis_name="devices"
        )

    def _unroll_one_step(
        self,
        policy: Callable[[chex.Array, chex.PRNGKey], Tuple[chex.Array, Dict]],
        carry: Tuple[jumanji.env.State, jumanji.types.TimeStep],
        key: chex.PRNGKey,
    ) -> Tuple[Tuple[jumanji.env.State, jumanji.types.TimeStep], None]:
        """Step the evaluation env once with an action chosen by ``policy``.

        Returns ``((next_state, next_timestep), None)``.
        """
        state, timestep = carry
        # Give every observation leaf a leading axis of size 1 for the policy;
        # the returned action is squeezed back before stepping the env.
        observation = jax.tree_util.tree_map(
            lambda x: x[None, ...], timestep.observation
        )
        action, _ = policy(observation, key)
        next_state, next_timestep = self._eval_env.step(state, jnp.squeeze(action))
        return (next_state, next_timestep), None

    def _generate_eval_one_episode(
        self,
        policy_params: hk.Params,
        discount_factor: Optional[jnp.float_],
        key: chex.PRNGKey,
    ) -> Tuple[Dict, Dict]:
        """Play one episode with each policy variant.

        Returns:
            ``(stochastic_metrics, deterministic_metrics)``. Each is a dict with
            the episode's summed reward (``"episode_reward"``) and its number
            of steps (``"episode_length"``).
        """
        stochastic_policy = self._actor_critic_agent.make_policy(
            policy_params,
            discount_factor,
            deterministic=False,
        )
        determinist_policy = self._actor_critic_agent.make_policy(
            policy_params,
            discount_factor,
            deterministic=True,
        )

        def cond_fun(
            carry: Tuple[
                jumanji.env.State,
                jumanji.types.TimeStep,
                chex.PRNGKey,
                jnp.float32,
                jnp.int32,
            ]
        ) -> jnp.bool_:
            # Keep stepping until the environment signals the last timestep.
            _, timestep, *_ = carry
            return ~timestep.last()

        def body_fun(
            policy: Callable[[chex.Array, chex.PRNGKey], Tuple[chex.Array, Dict]],
            carry: Tuple[
                jumanji.env.State,
                jumanji.types.TimeStep,
                chex.PRNGKey,
                jnp.float32,
                jnp.int32,
            ],
        ) -> Tuple[
            jumanji.env.State,
            jumanji.types.TimeStep,
            chex.PRNGKey,
            jnp.float32,
            jnp.int32,
        ]:
            # One env step; accumulate the reward and count the step.
            state, timestep, key, return_, count = carry
            key, step_key = jax.random.split(key)
            (state, timestep), _ = self._unroll_one_step(
                policy, (state, timestep), step_key
            )
            return_ += timestep.reward
            count += 1
            return state, timestep, key, return_, count

        def run_episode(
            policy: Callable[[chex.Array, chex.PRNGKey], Tuple[chex.Array, Dict]],
            reset_key: chex.PRNGKey,
            loop_key: chex.PRNGKey,
        ) -> Dict:
            """Reset with ``reset_key`` and play ``policy`` until the episode ends."""
            state, timestep = self._eval_env.reset(reset_key)
            _, _, _, episode_return, episode_length = jax.lax.while_loop(
                cond_fun,
                functools.partial(body_fun, policy),
                (state, timestep, loop_key, jnp.float32(0), jnp.int32(0)),
            )
            return {
                "episode_reward": episode_return,
                "episode_length": episode_length,
            }

        (
            reset_key_stochastic,
            reset_key_determinist,
            init_key_stochastic,
            init_key_determinist,
        ) = jax.random.split(key, 4)
        eval_metrics_stochastic = run_episode(
            stochastic_policy, reset_key_stochastic, init_key_stochastic
        )
        # Each policy variant resets with its own key from the 4-way split.
        eval_metrics_deterministic = run_episode(
            determinist_policy, reset_key_determinist, init_key_determinist
        )
        return eval_metrics_stochastic, eval_metrics_deterministic

    def _generate_eval_unroll(
        self,
        policy_params: hk.Params,
        discount_factor: Optional[jnp.float_],
        key: chex.PRNGKey,
    ) -> Tuple[Dict, Dict]:
        """Per-device body of the pmapped evaluation.

        Plays ``self._num_eval_per_device`` episode pairs in parallel (vmap over
        keys). It then averages each metric over those episodes and across
        devices (``pmean`` over the "devices" axis).

        Returns:
            ``(stochastic_metrics, deterministic_metrics)``.
        """
        keys = jax.random.split(key, self._num_eval_per_device)
        per_episode_metrics = jax.vmap(
            self._generate_eval_one_episode, in_axes=(None, None, 0)
        )(
            policy_params,
            discount_factor,
            keys,
        )
        averaged_metrics: Tuple[Dict, Dict] = jax.lax.pmean(
            jax.tree_util.tree_map(jnp.mean, per_episode_metrics),
            axis_name="devices",
        )
        return averaged_metrics

    def run_evaluation(self, training_state: TrainingState) -> Dict:
        """Run one epoch of evaluation.

        Returns:
            A flat dict. Its keys are the episode metric names suffixed with
            "_stochastic_policy" or "_determinist_policy". Its values are the
            pmapped outputs, which after ``pmean`` hold one identical copy per
            device.
        """
        self._key, unroll_key = jax.random.split(self._key)

        unroll_keys = jax.random.split(unroll_key, self._num_devices)
        # With meta-parameters, the policy is conditioned on the current
        # sigmoid-squashed discount factor; otherwise no discount is passed.
        if training_state.meta_params is not None:
            discount_factor = jax.nn.sigmoid(training_state.meta_params.gamma)
        else:
            discount_factor = None
        stochastic_metrics, determinist_metrics = self.generate_eval_unroll(
            training_state.params.actor,
            discount_factor,
            unroll_keys,
        )
        metrics = {
            **{
                key + "_stochastic_policy": value
                for key, value in stochastic_metrics.items()
            },
            **{
                key + "_determinist_policy": value
                for key, value in determinist_metrics.items()
            },
        }
        return metrics
201 |
--------------------------------------------------------------------------------
/snake/training/logger.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Any, Callable, Dict, Mapping, Optional
3 |
4 | import chex
5 | import numpy as np
6 | from neptune import new as neptune
7 |
8 |
class Logger(abc.ABC):
    """A logger has a `write` method.

    Interface copied from Acme. Subclasses implement `write`,
    `save_checkpoint` and `close`. A logger is also a context manager:
    leaving a `with` block calls `close()` whether or not the block raised,
    and any exception raised inside the block propagates unchanged.
    """

    @abc.abstractmethod
    def write(self, data: Mapping[str, Any], *args: Any, **kwargs: Any) -> None:
        """Writes `data` to destination (file, terminal, database, etc)."""

    @abc.abstractmethod
    def save_checkpoint(self, file_name: str) -> None:
        """Saves a checkpoint."""

    @abc.abstractmethod
    def close(self) -> None:
        """Closes the logger, not expecting any further write."""

    def __enter__(self) -> "Logger":
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Any,
    ) -> None:
        # exc_type is the exception *class* (or None); all three are None on
        # a clean exit. Returning None (falsy) lets exceptions propagate.
        self.close()
32 |
33 |
class NeptuneLogger(Logger):
    """Logger that streams metrics and checkpoints to a Neptune run.

    Scalar metrics are logged as-is. 1-D metrics are summarised twice: their
    mean under ``<label>/mean/<key>`` and their last element under
    ``<label>/sample/<key>``. NaN values are skipped.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        aggregation_fn: Callable[[chex.Array], chex.Array],
        seed: Optional[int] = None,
        **kwargs: Any,
    ):
        """Open a Neptune run and store `config` under ``config``.

        Args:
            config: run configuration; its ``"name"`` entry (plus an optional
                ``_seed_<seed>`` suffix) becomes the run name.
            aggregation_fn: accepted for interface compatibility but unused;
                1-D metrics are logged as mean and last sample instead.
            seed: optional seed appended to the run name.
            **kwargs: forwarded to `neptune.init`.
        """
        self.run = neptune.init(
            project="",  # Change the project name to your Neptune project
            name=config["name"] + (f"_seed_{seed}" if seed is not None else ""),
            **kwargs,
        )
        self.run["config"] = config
        self._t: int = 0
        del aggregation_fn  # Moved to logging mean and last sample instead.
        self.mean_fn = np.mean

        def downsample(x: np.ndarray) -> np.ndarray:
            return x[-1]  # type: ignore

        self.downsample_fn = downsample

    def write(  # noqa: CCR001
        self,
        data: Mapping[str, Any],
        label: str = "",
        timestep: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Log every entry of `data` at step `timestep`.

        Args:
            data: mapping from metric name to a scalar or a 1-D array.
            label: namespace prefix. Trailing slashes are ignored, so
                ``"train"`` and ``"train/"`` map to the same paths.
            timestep: step to log at. An explicit 0 is honoured; when None,
                the previous step plus one is used.

        Raises:
            ValueError: if a metric has more than one dimension.
        """
        # `timestep or ...` would silently turn an explicit step 0 into 1.
        self._t = timestep if timestep is not None else self._t + 1
        namespace = label.rstrip("/")
        # Scalars and 1-D summaries must share the same "<label>/" prefix;
        # previously scalars got an extra "/" ("eval//key", "/key").
        prefix = f"{namespace}/" if namespace else ""
        for key, metric in data.items():
            ndim = np.ndim(metric)
            if ndim == 0:
                self._log_unless_nan(f"{prefix}{key}", metric)
            elif ndim == 1:
                if np.size(metric) == 0:
                    continue  # Nothing to summarise (and no last sample).
                self._log_unless_nan(f"{prefix}mean/{key}", self.mean_fn(metric))
                self._log_unless_nan(
                    f"{prefix}sample/{key}", self.downsample_fn(metric)
                )
            else:
                raise ValueError(
                    f"Expected metric to be 0 or 1 dimension, got {metric}."
                )

    def _log_unless_nan(self, path: str, value: Any) -> None:
        """Append `value` as a float to the series at `path`, skipping NaN."""
        if not np.isnan(value):
            self.run[path].log(float(value), step=self._t, wait=True)

    def save_checkpoint(self, file_name: str) -> None:
        """Upload the file at `file_name` under ``checkpoints/<file_name>``."""
        self.run[f"checkpoints/{file_name}"].upload(file_name)

    def close(self) -> None:
        """Stop the Neptune run."""
        self.run.stop()
100 |
101 |
class TerminalLogger(Logger):
    """Logger that prints a one-line progress summary to stdout.

    Only two entries of the written data are reported:
    ``episode_reward_stochastic_policy`` (printed as ``mean_return``) and
    ``gamma`` (printed as ``discount_factor``), each reduced with
    `aggregation_fn`. Checkpointing and closing are no-ops.
    """

    def __init__(
        self, aggregation_fn: Callable[[chex.Array], chex.Array], **kwargs: Any
    ):
        # Extra keyword arguments are accepted and ignored.
        self.aggregation_fn = aggregation_fn
        print(">>> Terminal Logger")

    def write(
        self,
        data: Mapping[str, Any],
        label: str = "",
        timestep: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Print the aggregated return and discount factor found in `data`.

        `label` and any extra arguments are ignored; missing entries are
        simply left out of the printed line.
        """
        if timestep is None:
            header = "\n"
        else:
            header = f"\nTimestep {timestep:.2e} >>> "
        pieces = [header]

        episode_return = data.get("episode_reward_stochastic_policy", None)
        discount = data.get("gamma", None)
        if episode_return is not None:
            pieces.append(f"mean_return: {self.aggregation_fn(episode_return):.2f} ")
        if discount is not None:
            pieces.append(f"discount_factor: {self.aggregation_fn(discount):.5f}")
        print("".join(pieces))

    def save_checkpoint(self, file_name: str) -> None:
        """No-op: the terminal logger does not persist checkpoints."""

    def close(self) -> None:
        """No-op: there is nothing to release."""
134 |
135 |
def make_logger_factory(
    logger: str,
    config_dict: Dict[str, Any],
    aggregation_behaviour: str = "mean",
    seed: Optional[int] = None,
) -> Callable[[], Logger]:
    """Return a zero-argument callable that builds the requested logger.

    Args:
        logger: ``"neptune"`` or ``"terminal"``. It is only validated when
            the returned callable is invoked.
        config_dict: run configuration, forwarded to `NeptuneLogger`.
        aggregation_behaviour: reduction handed to the logger; only
            ``"mean"`` (`np.mean`) is supported.
        seed: optional seed, forwarded to `NeptuneLogger`.

    Raises:
        ValueError: immediately for an unsupported `aggregation_behaviour`,
            and when the returned callable is invoked with an unknown
            `logger`.
    """
    if aggregation_behaviour != "mean":
        raise ValueError(
            f"aggregation_behaviour is expected to be 'mean', got {aggregation_behaviour} instead."
        )
    aggregation_fn = np.mean

    def make_logger() -> Logger:
        if logger == "terminal":
            return TerminalLogger(aggregation_fn)
        if logger == "neptune":
            return NeptuneLogger(config_dict, aggregation_fn, seed)
        raise ValueError(
            f"expected logger in ['neptune', 'terminal'], got {logger}."
        )

    return make_logger
160 |
--------------------------------------------------------------------------------
/snake/training/setup_run.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import jax
4 | import jumanji
5 | import optax
6 |
7 | from snake.agent import A2C, ActorCriticAgent, MetaA2C
8 | from snake.env import make_snake_env
9 | from snake.networks import make_actor_critic_networks_snake
10 | from snake.networks.actor_critic import ActorCriticNetworks
11 | from snake.training.config import Config
12 | from snake.training.evaluator import Evaluator
13 | from snake.training.types import HyperParams, Metal
14 |
15 |
def setup(config: Config) -> Tuple[ActorCriticAgent, Evaluator]:
    """Build the training agent and its evaluator from `config`.

    `make_snake_env` provides separate training and evaluation environments.
    The agent is built on the training one, and the evaluator on the
    evaluation one, seeded with `config.seed`.
    """
    train_env, evaluation_env = make_snake_env()
    actor_critic_agent = make_agent(config, train_env)
    return actor_critic_agent, Evaluator(
        eval_env=evaluation_env,
        actor_critic_agent=actor_critic_agent,
        total_num_eval=config.total_num_eval,
        key=jax.random.PRNGKey(config.seed),
        deterministic=config.deterministic_eval,
    )
28 |
29 |
def _make_optimizer(
    name: str, learning_rate: float, label: str
) -> optax.GradientTransformation:
    """Return the optax optimizer called `name` with the given learning rate.

    Supported names are 'sgd', 'rmsprop' (decay 0.9) and 'adam'
    (eps_root 1e-8). `label` is the config field being resolved and only
    appears in the error message.

    Raises:
        ValueError: if `name` is not a supported optimizer.
    """
    if name == "sgd":
        return optax.sgd(learning_rate)
    if name == "rmsprop":
        return optax.rmsprop(learning_rate, decay=0.9)
    if name == "adam":
        return optax.adam(learning_rate, eps_root=1e-8)
    raise ValueError(
        f"Expected {label} to be in ['sgd', 'rmsprop', 'adam'], got "
        f"{name} instead."
    )


def make_agent(  # noqa: CCR001
    config: Config, env: jumanji.Environment
) -> ActorCriticAgent:
    """Build the actor-critic agent described by `config` for `env`.

    Builds the snake actor-critic networks and the (optionally
    global-norm-clipped) optimizer, then an `A2C` agent
    (``config.agent == "a2c"``) or a `MetaA2C` agent
    (``config.agent == "meta_a2c"``). The meta agent additionally gets a
    meta-optimizer, an optional bootstrap optimizer, outer hyper-parameters
    and a `Metal` gradient gate per hyper-parameter.

    Raises:
        ValueError: on an unknown network type, optimizer or agent, or when a
            setting required by 'meta_a2c' is missing.
    """
    actor_critic_networks: ActorCriticNetworks
    if config.network_type == "snake":
        actor_critic_networks = make_actor_critic_networks_snake(
            num_channels=config.num_channels,
            policy_layers=config.policy_layers,
            value_layers=config.value_layers,
            outer_critic=config.outer_critic,
            embedding_size_actor=config.embedding_size_actor,
            embedding_size_critic=config.embedding_size_critic,
        )
    else:
        raise ValueError(
            f"Expected network_type to be in ['snake'], got "
            f"{config.network_type} instead."
        )

    optimizer = _make_optimizer(config.optimizer, config.learning_rate, "optimizer")
    if config.gradient_clip_norm is not None:
        optimizer = optax.chain(
            optax.clip_by_global_norm(config.gradient_clip_norm),
            optimizer,
        )

    # Used as `hyper_params` by A2C and as `hyper_params_init` by MetaA2C.
    hyper_params_init = HyperParams(
        gamma=config.gamma_init,
        lambda_=config.lambda_init,
        l_pg=config.l_pg_init,
        l_td=config.l_td_init,
        l_en=config.l_en_init,
    )

    agent: ActorCriticAgent
    if config.agent == "a2c":
        agent = A2C(
            n_steps=config.n_steps,
            total_batch_size=config.total_batch_size,
            total_num_envs=config.total_num_envs,
            env=env,
            actor_critic_networks=actor_critic_networks,
            optimizer=optimizer,
            normalize_advantage=config.normalize_advantage,
            reward_scaling=config.reward_scaling,
            hyper_params=hyper_params_init,
            env_type=config.network_type,
        )
    elif config.agent == "meta_a2c":
        # Explicit checks instead of `assert`, which is stripped under -O.
        if config.meta_learning_rate is None:
            raise ValueError(
                "Expected meta_learning_rate to be set when agent is 'meta_a2c'."
            )
        meta_optimizer = _make_optimizer(
            config.meta_optimizer, config.meta_learning_rate, "meta_optimizer"
        )
        if config.meta_gradient_clip_norm is not None:
            meta_optimizer = optax.chain(
                optax.clip_by_global_norm(config.meta_gradient_clip_norm),
                meta_optimizer,
            )

        bootstrap_l_optimizer = (
            None
            if config.bootstrap_l_optimizer is None
            else _make_optimizer(
                config.bootstrap_l_optimizer,
                config.bootstrap_l_learning_rate,
                "bootstrap_l_optimizer",
            )
        )

        outer_hyper_params = HyperParams(
            gamma=config.gamma_outer,
            lambda_=config.lambda_outer,
            l_pg=config.l_pg_outer,
            l_td=config.l_td_outer,
            l_en=config.l_en_outer,
        )
        if config.normalize_outer_advantage is None:
            raise ValueError(
                "Expected normalize_outer_advantage to be set when agent is "
                "'meta_a2c'."
            )
        # Per hyper-parameter: identity lets gradients flow through it,
        # jax.lax.stop_gradient blocks them, depending on the *_metal flag.
        metal = Metal(
            gamma=(lambda _: _) if config.gamma_metal else jax.lax.stop_gradient,
            lambda_=(lambda _: _) if config.lambda_metal else jax.lax.stop_gradient,
            l_pg=(lambda _: _) if config.l_pg_metal else jax.lax.stop_gradient,
            l_td=(lambda _: _) if config.l_td_metal else jax.lax.stop_gradient,
            l_en=(lambda _: _) if config.l_en_metal else jax.lax.stop_gradient,
        )
        agent = MetaA2C(
            n_steps=config.n_steps,
            total_batch_size=config.total_batch_size,
            total_num_envs=config.total_num_envs,
            env=env,
            actor_critic_networks=actor_critic_networks,
            optimizer=optimizer,
            bootstrap_l_optimizer=bootstrap_l_optimizer,
            meta_optimizer=meta_optimizer,
            normalize_advantage=config.normalize_advantage,
            normalize_outer_advantage=config.normalize_outer_advantage,
            reward_scaling=config.reward_scaling,
            hyper_params_init=hyper_params_init,
            outer_hyper_params=outer_hyper_params,
            meta_objective=config.meta_objective,
            bootstrap_l=config.bootstrap_l,
            bootstrap_pm=config.bootstrap_pm,
            bootstrap_vm=config.bootstrap_vm,
            metal=metal,
            l_kl_outer=config.l_kl_outer,
            env_type=config.network_type,
        )
    else:
        raise ValueError(
            f"Expected agent in ['a2c', 'meta_a2c'], got {config.agent} instead."
        )
    return agent
171 |
--------------------------------------------------------------------------------
/snake/training/types.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, NamedTuple, Optional
2 |
3 | import chex
4 | import haiku as hk
5 | import jumanji.types
6 | import optax
7 | from jax import numpy as jnp
8 |
9 |
class HyperParams(NamedTuple):
    """Scalar coefficients of the actor-critic objective.

    `gamma` is the discount factor: the terminal logger prints it as
    ``discount_factor``. NOTE(review): `lambda_` is presumably the
    GAE / TD(lambda) mixing coefficient — confirm in the agent code.
    """

    gamma: jnp.float_
    lambda_: jnp.float_
    l_pg: jnp.float_  # policy gradient loss cost
    l_td: jnp.float_  # temporal difference loss cost
    l_en: jnp.float_  # entropy loss cost
15 | l_en: jnp.float_ # entropy loss cost
16 |
17 |
# Meta-learned hyper-parameters share the HyperParams layout.
# NOTE(review): values may be stored unconstrained — the evaluator applies
# jax.nn.sigmoid to `meta_params.gamma` before using it as a discount factor.
MetaParams = HyperParams
19 |
20 |
class Metal(NamedTuple):
    """Per-hyper-parameter gradient gate used by the meta-learning agent.

    `make_agent` fills each field with either the identity function (the
    matching ``*_metal`` config flag is set, so gradients flow through that
    hyper-parameter) or `jax.lax.stop_gradient` (gradients are blocked).
    """

    gamma: Callable[[chex.Array], chex.Array]
    lambda_: Callable[[chex.Array], chex.Array]
    l_pg: Callable[[chex.Array], chex.Array]
    l_td: Callable[[chex.Array], chex.Array]
    l_en: Callable[[chex.Array], chex.Array]
27 |
28 |
class ActorCriticParams(NamedTuple):
    """Haiku parameters of the actor-critic networks."""

    actor: hk.Params  # policy parameters (passed to evaluation as policy_params)
    critic: hk.Params
    # NOTE(review): presumably only set when the networks are built with
    # `outer_critic=True` — confirm in the network factory.
    outer_critic: Optional[hk.Params]
33 |
34 |
class TrainingState(NamedTuple):
    """Contains training state for the learner.

    `meta_params` is None for agents without meta-learned hyper-parameters;
    in that case the training script reads `A2C.hyper_params` instead.
    """

    params: ActorCriticParams
    meta_params: Optional[MetaParams]
    optimizer_state: optax.OptState
    # NOTE(review): presumably None whenever `meta_params` is None — confirm.
    meta_optimizer_state: Optional[optax.OptState]
    env_steps: jnp.int32  # presumably the number of environment steps taken so far
43 |
44 |
class ActingState(NamedTuple):
    """Container for data used during the acting in the environment."""

    env_state: jumanji.env.State  # current (Jumanji) environment state
    timestep: jumanji.types.TimeStep  # latest timestep returned by the env
    acting_key: chex.PRNGKey  # PRNG key used for acting
51 |
52 |
class Transition(NamedTuple):
    """Container for a transition."""

    observation: chex.Array
    action: chex.Array
    reward: jnp.float_
    discount: jnp.float_
    # NOTE(review): presumably True when the episode was cut short rather
    # than terminated — confirm against where transitions are built.
    truncation: jnp.bool_
    next_observation: chex.Array
    extras: Dict  # free-form extra data attached to the transition
63 |
64 |
class State(NamedTuple):
    """Container for TrainingState and ActingState."""

    training_state: TrainingState  # learner side: params and optimizer states
    acting_state: ActingState  # actor side: env state, timestep and acting key
70 |
--------------------------------------------------------------------------------
/snake/training/utils.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | import jax
4 |
T = TypeVar("T")


def first_from_device(tree: T) -> T:
    """Return `tree` with every leaf replaced by its first entry along axis 0.

    The training script uses this to pull a single replica out of
    device-stacked (`jax.pmap`-ed) state and metrics.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    return jax.tree_util.tree_unflatten(  # type: ignore
        treedef, [leaf[0] for leaf in leaves]
    )
10 |
--------------------------------------------------------------------------------
/snake_train.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import pickle
4 | from typing import Dict, Tuple
5 |
6 | import hydra
7 | import jax
8 | import omegaconf
9 | from tqdm import trange
10 |
11 | from snake.agent import A2C, meta_params_to_hyper_params
12 | from snake.training import utils
13 | from snake.training.config import convert_config
14 | from snake.training.logger import make_logger_factory
15 | from snake.training.setup_run import setup
16 | from snake.training.types import State
17 |
18 |
@hydra.main(config_path="snake/configs", config_name="config.yaml")
def run(cfg: omegaconf.DictConfig) -> None:
    """Train and periodically evaluate the snake agent described by `cfg`.

    The `config.num_timesteps` budget is split into `config.num_eval_points`
    epochs. Each of the `num_eval_points + 1` loop iterations pickles the
    first-device state, evaluates it, then runs one pmapped epoch of
    learner updates.

    Raises:
        ValueError: if an epoch would contain no learner update.
    """
    print(omegaconf.OmegaConf.to_yaml(cfg))
    logging.getLogger().setLevel(logging.INFO)
    logging.info({"devices": jax.local_devices()})

    config = convert_config(cfg)

    agent, evaluator = setup(config)
    # Can be either 'neptune' or 'terminal'.
    make_logger = make_logger_factory(
        "terminal", config._asdict(), aggregation_behaviour="mean"
    )

    num_timesteps_per_learner_inner_update_step = (
        config.total_batch_size * config.n_steps
    )
    num_timesteps_per_epoch = config.num_timesteps // config.num_eval_points
    num_learner_steps_per_epoch = (
        num_timesteps_per_epoch // num_timesteps_per_learner_inner_update_step
    )
    if num_learner_steps_per_epoch < 1:
        # Otherwise every epoch silently trains for zero steps.
        raise ValueError(
            "Each epoch must contain at least one learner step, got "
            f"num_timesteps={config.num_timesteps}, "
            f"num_eval_points={config.num_eval_points}, "
            "total_batch_size * n_steps="
            f"{num_timesteps_per_learner_inner_update_step}."
        )

    state = agent.init(jax.random.PRNGKey(config.seed))

    @functools.partial(jax.pmap, axis_name="devices")
    def epoch_fn(state: State) -> Tuple[State, Dict]:
        # One epoch of learner updates per device; scan stacks the per-step
        # metrics along a new leading axis.
        state, metrics = jax.lax.scan(
            lambda s, _: agent.update(s), state, None, num_learner_steps_per_epoch
        )
        return state, metrics

    with make_logger() as logger, trange(
        config.num_eval_points + 1
    ) as epochs, jax.log_compiles():
        num_learner_updates = 0

        # Log first hyper_params
        if state.training_state.meta_params is not None:
            initial_hyper_params = meta_params_to_hyper_params(
                utils.first_from_device(state.training_state.meta_params)
            )
        else:
            assert isinstance(agent, A2C)
            initial_hyper_params = agent.hyper_params
        logger.write(initial_hyper_params._asdict(), "train/mean/", num_learner_updates)
        logger.write(
            initial_hyper_params._asdict(), "train/sample/", num_learner_updates
        )

        for _ in epochs:

            # Checkpoint. The file is closed (and therefore flushed) before
            # it is handed to the logger; uploading it from inside the
            # `with` block could send a truncated pickle.
            checkpoint_path = f"state_{num_learner_updates:.2e}.pickle"
            with open(checkpoint_path, "wb") as file_:
                pickle.dump(utils.first_from_device(state), file_)
            logger.save_checkpoint(checkpoint_path)

            # Validation
            metrics = evaluator.run_evaluation(state.training_state)
            logger.write(utils.first_from_device(metrics), "eval", num_learner_updates)

            # Training steps
            state, metrics = epoch_fn(state)
            num_learner_updates += num_learner_steps_per_epoch
            logger.write(utils.first_from_device(metrics), "train", num_learner_updates)
84 |
if __name__ == "__main__":
    # `run` is called without arguments: the `@hydra.main` decorator builds
    # the `cfg` argument from the config files and command line.
    run()
87 |
--------------------------------------------------------------------------------