├── tests ├── __init__.py └── test_env.py ├── nfq ├── __init__.py ├── networks.py └── agents.py ├── paper.pdf ├── environments ├── __init__.py └── cartpole.py ├── cartpole.conf ├── .gitmodules ├── .travis.yml ├── requirements-dev.txt ├── setup.cfg ├── Makefile ├── requirements.txt ├── .pre-commit-config.yaml ├── LICENSE ├── .gitignore ├── README.md └── train_eval.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | -------------------------------------------------------------------------------- /nfq/__init__.py: -------------------------------------------------------------------------------- 1 | """Neural Fitted Q-Iteration.""" 2 | -------------------------------------------------------------------------------- /paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjaeryanlee/implementations-nfq/HEAD/paper.pdf -------------------------------------------------------------------------------- /environments/__init__.py: -------------------------------------------------------------------------------- 1 | """Environments specified in NFQ paper.""" 2 | # flake8: noqa 3 | from .cartpole import CartPoleRegulatorEnv 4 | -------------------------------------------------------------------------------- /cartpole.conf: -------------------------------------------------------------------------------- 1 | EPOCH = 2000 2 | TRAIN_ENV_MAX_STEPS = 100 3 | EVAL_ENV_MAX_STEPS = 3000 4 | DISCOUNT = 0.95 5 | INIT_EXPERIENCE = 0 6 | 7 | INCREMENT_EXPERIENCE 8 | HINT_TO_GOAL 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tests/utils"] 2 | path = tests/utils 3 | url = https://github.com/seungjaeryanlee/implementations-utils-tests.git 4 | [submodule "utils"] 5 | path = utils 6 | url = https://github.com/seungjaeryanlee/implementations-utils.git 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | install: 5 | - pip install -r requirements.txt --quiet 6 | - pip install -r requirements-dev.txt --quiet 7 | script: 8 | - black --check . 
9 | - flake8 10 | - isort **/*.py -c -vb 11 | - pytest 12 | cache: pip 13 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | black>=19.3b0 2 | flake8>=3.7.8 3 | flake8-bugbear>=19.3.0 4 | flake8-docstrings>=1.3.0 5 | isort>=4.3.21 6 | pytest>=5.0.1 7 | seed-isort-config>=1.9.2 8 | pre-commit==1.17.0 9 | 10 | # https://github.com/pytest-dev/pytest-cov/issues/252 11 | pytest-remotedata>=0.3.1 12 | # https://github.com/PyCQA/pydocstyle/issues/375 13 | pydocstyle<4.0.0 14 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | select = C,E,F,W,B,B950,D 4 | ignore = 5 | E203, 6 | E501, 7 | W503, 8 | D101, # Missing docstring in public class 9 | D105, 10 | D107, 11 | D202, # No blank lines allowed after function docstring 12 | exclude = 13 | .git, 14 | __pycache__, 15 | .ipynb_checkpoints, 16 | 17 | [isort] 18 | known_third_party=configargparse,gym,numpy,pytest,torch 19 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Install dependencies 2 | dep: 3 | pip install -r requirements.txt 4 | 5 | # Install developer dependencies 6 | dev: 7 | pip install -r requirements.txt 8 | pip install -r requirements-dev.txt 9 | pre-commit install 10 | 11 | # Format code with black and isort 12 | format: 13 | black . 14 | seed-isort-config 15 | isort -y 16 | 17 | # Test code with black, flake8, isort, mypy, and pytest. 18 | test: 19 | pytest -v 20 | black --check . 21 | isort **/*.py -c 22 | flake8 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # ConfigArgParse allows using config files with argparse. 2 | ConfigArgParse==0.14.0 3 | 4 | # OpenAI Gym is de-facto standard for RL agent-environment interface. 5 | gym==0.14.0 6 | 7 | # PyTorch v1.1+ supports TensorBoard. 8 | torch==1.1.0 9 | 10 | # TensorBoard v1.14+ supports PyTorch. 11 | tensorboard==1.14.0 12 | 13 | # Weights & Biases allows visualizing data online. 14 | wandb==0.8.5 15 | 16 | # coloredlogs provide colors to Python's logger. 17 | coloredlogs==10.0 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.1.0 # This is pre-commit-hooks version, NOT flake8 version! 
4 | hooks: 5 | - id: flake8 6 | - repo: https://github.com/ambv/black 7 | rev: stable 8 | hooks: 9 | - id: black 10 | - repo: https://github.com/asottile/seed-isort-config 11 | rev: v1.5.0 12 | hooks: 13 | - id: seed-isort-config 14 | - repo: https://github.com/pre-commit/mirrors-isort 15 | rev: 'v4.3.21' 16 | hooks: 17 | - id: isort 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Seungjae Ryan Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /nfq/networks.py: -------------------------------------------------------------------------------- 1 | """Networks for NFQ.""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class NFQNetwork(nn.Module): 7 | def __init__(self): 8 | """Networks for NFQ.""" 9 | super().__init__() 10 | self.layers = nn.Sequential( 11 | nn.Linear(5, 5), 12 | nn.Sigmoid(), 13 | nn.Linear(5, 5), 14 | nn.Sigmoid(), 15 | nn.Linear(5, 1), 16 | nn.Sigmoid(), 17 | ) 18 | 19 | # Initialize weights to [-0.5, 0.5] 20 | def init_weights(m): 21 | if type(m) == nn.Linear: 22 | torch.nn.init.uniform_(m.weight, -0.5, 0.5) 23 | # TODO(seungjaeryanlee): What about bias? 24 | 25 | self.layers.apply(init_weights) 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | """ 29 | Forward propagation. 30 | 31 | Parameters 32 | ---------- 33 | x : torch.Tensor 34 | Input tensor of observation and action concatenated. 35 | 36 | Returns 37 | ------- 38 | y : torch.Tensor 39 | Forward-propagated observation predicting Q-value. 
40 | 41 | """ 42 | return self.layers(x) 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Custom 2 | tensorboard_logs/ 3 | wandb/ 4 | saves/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | -------------------------------------------------------------------------------- /tests/test_env.py: -------------------------------------------------------------------------------- 1 | """Stub unit tests.""" 2 | import numpy as np 3 | import pytest 4 | 5 | from environments import CartPoleRegulatorEnv as Env 6 | 7 | 8 | class TestCartPoleRegulatorEnv: 9 | def test_train_mode_reset(self): 10 | """Test reset() in train mode.""" 11 | train_env = Env(mode="train") 12 | x, x_, theta, theta_ = train_env.reset() 13 | 14 | assert abs(x) <= 2.3 15 | assert x_ == 0 16 | assert abs(theta) <= 0.3 17 | assert theta_ == 0 18 | 19 | def test_eval_mode_reset(self): 20 | """Test reset() in eval mode.""" 21 | eval_env = Env(mode="eval") 22 | x, x_, theta, theta_ = eval_env.reset() 23 | 24 | assert abs(x) <= 1.0 25 | assert x_ == 0 26 | assert abs(theta) <= 0.3 27 | assert theta_ == 0 28 | 29 | @pytest.mark.parametrize("env", [Env(mode="train"), Env(mode="eval")]) 30 | def test_get_goal_pattern_set(self, env): 31 | """Test get_goal_pattern_set().""" 32 | goal_state_action_b, goal_target_q_values = env.get_goal_pattern_set() 33 | 34 | for x, _, theta, _, action in goal_state_action_b: 35 | assert abs(x) <= env.x_success_range 36 | assert abs(theta) <= env.theta_success_range 37 | assert action in [0, 1] 38 | for target in goal_target_q_values: 39 | assert target == 0 40 | 41 | @pytest.mark.parametrize("env", [Env(mode="train"), Env(mode="eval")]) 42 | @pytest.mark.parametrize("get_best_action", [None, lambda x: 0]) 43 | def 
test_generate_rollout_next_obs(self, env, get_best_action): 44 | """Test generate_rollout() generates consecutive observations.""" 45 | # Use the parametrized env and get_best_action. 46 | rollout, episode_cost = env.generate_rollout(get_best_action) 47 | 48 | prev_next_obs = rollout[0][3] 49 | for obs, _, _, next_obs, _ in rollout[1:]: 50 | assert np.array_equal(prev_next_obs, obs) 51 | prev_next_obs = next_obs 52 | 
53 | @pytest.mark.parametrize("env", [Env(mode="train"), Env(mode="eval")]) 54 | @pytest.mark.parametrize("get_best_action", [None, lambda x: 0]) 55 | def test_generate_rollout_cost_threshold(self, env, get_best_action): 56 | """Test generate_rollout() does not have a cost over 1.""" 57 | # Use the parametrized env and get_best_action. 58 | rollout, episode_cost = env.generate_rollout(get_best_action) 59 | 60 | for (_, _, cost, _, _) in rollout: 61 | assert 0 <= cost <= 1 62 | 
63 | @pytest.mark.parametrize("env", [Env(mode="train"), Env(mode="eval")]) 64 | @pytest.mark.parametrize("get_best_action", [None, lambda x: 0]) 65 | def test_generate_rollout_episode_cost(self, env, get_best_action): 66 | """Test generate_rollout()'s second return value episode_cost.""" 67 | # Use the parametrized env and get_best_action. 68 | rollout, episode_cost = env.generate_rollout(get_best_action) 69 | 70 | total_cost = 0 71 | for _, _, cost, _, _ in rollout: 72 | total_cost += cost 73 | assert episode_cost == total_cost 74 | 
75 | @pytest.mark.parametrize("env", [Env(mode="train"), Env(mode="eval")]) 76 | @pytest.mark.parametrize("get_best_action", [None, lambda x: 0]) 77 | def test_generate_rollout_with_random_action_done_value(self, env, get_best_action): 78 | """Test done values of generate_rollout().""" 79 | # Use the parametrized env and get_best_action. 80 | rollout, episode_cost = env.generate_rollout(get_best_action) 81 | 82 | for i, (_, _, _, _, done) in enumerate(rollout): 83 | if i + 1 < len(rollout): 84 | assert not done 85 | else: 86 | assert done or len(rollout) == env.max_steps 87 | 
88 | @pytest.mark.parametrize("env", [Env(mode="train"), Env(mode="eval")]) 89 | def test_generate_rollout_get_best_action(self, env): 90 | """Test generate_rollout() uses get_best_action correctly.""" 91 | # Use the parametrized env. 92 | rollout, _ = env.generate_rollout(get_best_action=lambda x: 0) 93 | 94 | for _, action, _, _, _ in rollout: 95 | assert action == 0 96 | 
-------------------------------------------------------------------------------- /nfq/agents.py: -------------------------------------------------------------------------------- 1 | """Reinforcement learning agents.""" 2 | from typing import List, Tuple 3 | 4 | import gym 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | 11 | 
12 | class NFQAgent: 13 | def __init__(self, nfq_net: nn.Module, optimizer: optim.Optimizer): 14 | """ 15 | Neural Fitted Q-Iteration agent. 16 | 17 | Parameters 18 | ---------- 19 | nfq_net : nn.Module 20 | The Q-Network that returns estimated cost given observation and action. 21 | optimizer : optim.Optimizer 22 | Optimizer for training the NFQ network. 23 | 24 | """ 25 | self._nfq_net = nfq_net 26 | self._optimizer = optimizer 27 | 
28 | def get_best_action(self, obs: np.array) -> int: 29 | """ 30 | Return best action for given observation according to the neural network. 31 | 32 | Parameters 33 | ---------- 34 | obs : np.array 35 | An observation to find the best action for. 36 | 37 | Returns 38 | ------- 39 | action : int 40 | The action chosen by greedy selection.
41 | 42 | """ 43 | q_left = self._nfq_net( 44 | torch.cat([torch.FloatTensor(obs), torch.FloatTensor([0])], dim=0) 45 | ) 46 | q_right = self._nfq_net( 47 | torch.cat([torch.FloatTensor(obs), torch.FloatTensor([1])], dim=0) 48 | ) 49 | 50 | # Best action has lower "Q" value since it estimates cumulative cost. 51 | return 1 if q_left >= q_right else 0 52 | 
53 | def generate_pattern_set( 54 | self, 55 | rollouts: List[Tuple[np.array, int, int, np.array, bool]], 56 | gamma: float = 0.95, 57 | ): 58 | """Generate pattern set. 59 | 60 | Parameters 61 | ---------- 62 | rollouts : list of tuple 63 | Generated rollouts, where each element is a tuple of state, action, cost, next state, and done. 64 | gamma : float 65 | Discount factor. Defaults to 0.95. 66 | 67 | Returns 68 | ------- 69 | pattern_set : tuple of torch.Tensor 70 | Pattern set to train the NFQ network. 71 | 72 | """ 73 | # _b denotes batch 74 | state_b, action_b, cost_b, next_state_b, done_b = zip(*rollouts) 75 | state_b = torch.FloatTensor(state_b) 76 | action_b = torch.FloatTensor(action_b) 77 | cost_b = torch.FloatTensor(cost_b) 78 | next_state_b = torch.FloatTensor(next_state_b) 79 | done_b = torch.FloatTensor(done_b) 80 | 
81 | state_action_b = torch.cat([state_b, action_b.unsqueeze(1)], 1) 82 | assert state_action_b.shape == (len(rollouts), state_b.shape[1] + 1) 83 | 84 | # Compute min_a Q(s', a) 85 | q_next_state_left_b = self._nfq_net( 86 | torch.cat([next_state_b, torch.zeros(len(rollouts), 1)], 1) 87 | ).squeeze() 88 | q_next_state_right_b = self._nfq_net( 89 | torch.cat([next_state_b, torch.ones(len(rollouts), 1)], 1) 90 | ).squeeze() 91 | q_next_state_b = torch.min(q_next_state_left_b, q_next_state_right_b) 92 | 
93 | # If goal state (S+): target = 0 + gamma * min Q 94 | # If forbidden state (S-): target = 1 95 | # If neither: target = c_trans + gamma * min Q 96 | # NOTE(seungjaeryanlee): done is True only when the episode terminated 97 | # due to entering forbidden state. It is not 98 | # True if it terminated due to maximum timestep. 99 | with torch.no_grad(): 100 | target_q_values = cost_b + gamma * q_next_state_b * (1 - done_b) 101 | 102 | return state_action_b, target_q_values 103 | 
104 | def train(self, pattern_set: Tuple[torch.Tensor, torch.Tensor]) -> float: 105 | """Train neural network with a given pattern set. 106 | 107 | Parameters 108 | ---------- 109 | pattern_set : tuple of torch.Tensor 110 | Pattern set to train the NFQ network. 111 | 112 | Returns 113 | ------- 114 | loss : float 115 | Training loss. 116 | 117 | """ 118 | state_action_b, target_q_values = pattern_set 119 | predicted_q_values = self._nfq_net(state_action_b).squeeze() 120 | loss = F.mse_loss(predicted_q_values, target_q_values) 121 | 122 | self._optimizer.zero_grad() 123 | loss.backward() 124 | self._optimizer.step() 125 | 126 | return loss.item() 127 | 
128 | def evaluate(self, eval_env: gym.Env, render: bool) -> Tuple[int, bool, float]: 129 | """Evaluate NFQ agent on evaluation environment. 130 | 131 | Parameters 132 | ---------- 133 | eval_env : gym.Env 134 | Environment to evaluate the agent. 135 | render: bool 136 | If true, render environment. 137 | 138 | Returns 139 | ------- 140 | episode_length : int 141 | Number of steps the agent took. 142 | success : bool 143 | True if the agent survived until the maximum timestep and ended within the cart-position success range. 144 | episode_cost : float 145 | Total cost accumulated from the evaluation episode.
146 | 147 | """ 148 | episode_length = 0 149 | obs = eval_env.reset() 150 | done = False 151 | info = {"time_limit": False} 152 | episode_cost = 0 153 | while not done and not info["time_limit"]: 154 | action = self.get_best_action(obs) 155 | obs, cost, done, info = eval_env.step(action) 156 | episode_cost += cost 157 | episode_length += 1 158 | 159 | if render: 160 | eval_env.render() 161 | 162 | success = ( 163 | episode_length == eval_env.max_steps 164 | and abs(obs[0]) <= eval_env.x_success_range 165 | ) 166 | 167 | return episode_length, success, episode_cost 168 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Fitted Q Iteration - First Experiences with a Data Efficient Neural Reinforcement Learning Method 2 | 
3 | [![black Build Status](https://img.shields.io/travis/com/seungjaeryanlee/implementations-nfq.svg?label=black)](https://black.readthedocs.io/en/stable/) 4 | [![flake8 Build Status](https://img.shields.io/travis/com/seungjaeryanlee/implementations-nfq.svg?label=flake8)](http://flake8.pycqa.org/en/latest/) 5 | [![isort Build Status](https://img.shields.io/travis/com/seungjaeryanlee/implementations-nfq.svg?label=isort)](https://pypi.org/project/isort/) 6 | [![pytest Build Status](https://img.shields.io/travis/com/seungjaeryanlee/implementations-nfq.svg?label=pytest)](https://docs.pytest.org/en/latest/) 7 | 
8 | [![numpydoc Docstring Style](https://img.shields.io/badge/docstring-numpydoc-blue.svg)](https://numpydoc.readthedocs.io/en/latest/format.html#docstring-standard) 9 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-blue.svg)](https://pre-commit.com/) 10 | 
11 | This repository is an implementation of the paper [Neural Fitted Q Iteration - First Experiences with a Data Efficient Neural Reinforcement Learning Method (Riedmiller, 2005)](/paper.pdf). 12 | 13 | **Please ⭐ this repository if you found it useful!** 14 | 15 | 16 | --- 17 | 
18 | ### Table of Contents 📜 19 | 20 | - [Summary](#summary-) 21 | - [Installation](#installation-) 22 | - [Running](#running-) 23 | - [Results](#results-) 24 | - [Differences from the Paper](#differences-from-the-paper-) 25 | - [Reproducibility](#reproducibility-) 26 | 27 | For implementations of other deep learning papers, check the **[implementations](https://github.com/seungjaeryanlee/implementations) repository**! 28 | 29 | --- 30 | 
31 | ### Summary 📝 32 | 33 | Neural Fitted Q-Iteration (NFQ) uses a deep neural network as its Q-network, with the observation (s) and action (a) as input and the action value Q(s, a) as output. Instead of online Q-learning, the paper proposes **batch offline updates**: experience is collected throughout the episode, and the network is updated with that batch. The paper also suggests the **hint-to-goal** method, where the neural network is explicitly trained on goal-region states so that it can correctly estimate the value of the goal region. 34 | 
35 | ### Installation 🧱 36 | 37 | First, clone this repository from GitHub. Since this repository contains submodules, you should use the `--recursive` flag.
38 | 39 | ```bash 40 | git clone --recursive https://github.com/seungjaeryanlee/implementations-nfq.git 41 | ``` 42 | 
43 | If you already cloned the repository without the flag, you can download the submodules separately with the `git submodule` command: 44 | 45 | ```bash 46 | git clone https://github.com/seungjaeryanlee/implementations-nfq.git 47 | git submodule update --init --recursive 48 | ``` 49 | 
50 | After cloning the repository, use the [requirements.txt](/requirements.txt) file for a simple installation of the required PyPI packages. 51 | 52 | ```bash 53 | pip install -r requirements.txt 54 | ``` 55 | 56 | You can read more about each package in the comments of the [requirements.txt](/requirements.txt) file! 57 | 
58 | ### Running 🏃 59 | 60 | You can train the NFQ agent on Cartpole Regulator using the included configuration file with the command below: 61 | ``` 62 | python train_eval.py -c cartpole.conf 63 | ``` 64 | 
65 | For a reproducible run, use the `--RANDOM_SEED` flag. 66 | ``` 67 | python train_eval.py -c cartpole.conf --RANDOM_SEED=1 68 | ``` 69 | 
70 | To save a trained agent, use the `--SAVE_PATH` flag. 71 | ``` 72 | python train_eval.py -c cartpole.conf --SAVE_PATH=saves/cartpole.pth 73 | ``` 74 | 
75 | To load a trained agent, use the `--LOAD_PATH` flag. 76 | ``` 77 | python train_eval.py -c cartpole.conf --LOAD_PATH=saves/cartpole.pth 78 | ``` 79 | 
80 | To enable logging to TensorBoard or W&B, use appropriate flags. 81 | ``` 82 | python train_eval.py -c cartpole.conf --USE_TENSORBOARD --USE_WANDB 83 | ``` 84 | 
85 | ### Results 📊 86 | 87 | This repository uses **TensorBoard** for offline logging and **Weights & Biases** for online logging. You can see all the metrics in [my summary report at Weights & Biases](https://app.wandb.ai/seungjaeryanlee/implementations-nfq/reports?view=seungjaeryanlee%2FSummary)! 88 | 89 |

90 | *Result plots (see the Weights & Biases summary report above): Train Episode Length, Evaluation Episode Length, Train Episode Cost, Evaluation Episode Cost, Total Cycles, Total Cost, and Train Loss.* 101 |
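The TensorBoard event files are written to `tensorboard_logs/` when `--USE_TENSORBOARD` is set; as noted in [train_eval.py](/train_eval.py), they can be browsed locally with:

```bash
tensorboard --logdir=tensorboard_logs --port=2223
```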

102 | 103 | ### Differences from the Paper 👥 104 | 
105 | - Of the 3 environments (Pole Balancing, Mountain Car, Cartpole Regulator), only the Cartpole Regulator environment was implemented and tested. It is the most difficult of the three. 106 | - For the Cartpole Regulator, the success state is relaxed so that a state is successful whenever the pole angle is at most 24 degrees away from the upright position. In the original paper, the cart must also be centered, with a 0.05 tolerance. 107 | - Evaluation of the trained policy is only done in 1 evaluation environment, instead of 1000. 108 | 
109 | ### Reproducibility 🎯 110 | 111 | Although the paper has no open-source code, it had sufficient detail to implement NFQ. However, the results were not fully reproducible: we had to relax the definition of goal states and simplify evaluation. Still, the agent was able to learn to balance a CartPole for 3000 steps while training only on 100-step episodes. 112 | 113 | A few nits: 114 | 
115 | - The pole angle ranges for goal and forbidden states are not specified. We require the pole to be within 24 degrees of the upright position for a goal state, and treat any state beyond 90 degrees as forbidden. 116 | - The paper randomly initializes network weights within [−0.5, 0.5], but does not mention bias initialization. 117 | - The goal velocity of the success states is not mentioned. We use a normal distribution to randomly generate velocities for the hint-to-goal variant. 118 | - It is unclear whether to add experience before or after training the agent in each epoch. We assume adding experience before training. 119 | - The learning rate for the Rprop optimizer is not specified. 120 | 121 | 
-------------------------------------------------------------------------------- /train_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Implement Neural Fitted Q-Iteration. 3 | 4 | http://ml.informatik.uni-freiburg.de/former/_media/publications/rieecml05.pdf 5 | 6 | 7 | Running 8 | ------- 9 | You can train the NFQ agent on CartPole Regulator with the included 10 | configuration file using the command below: 11 | ``` 12 | python train_eval.py -c cartpole.conf 13 | ``` 14 | 
15 | For a reproducible run, use the RANDOM_SEED flag. 16 | ``` 17 | python train_eval.py -c cartpole.conf --RANDOM_SEED=1 18 | ``` 19 | 20 | To save a trained agent, use the SAVE_PATH flag. 21 | ``` 22 | python train_eval.py -c cartpole.conf --SAVE_PATH=saves/cartpole.pth 23 | ``` 24 | 
25 | To load a trained agent, use the LOAD_PATH flag. 26 | ``` 27 | python train_eval.py -c cartpole.conf --LOAD_PATH=saves/cartpole.pth 28 | ``` 29 | 30 | To enable logging to TensorBoard or W&B, use appropriate flags. 31 | ``` 32 | python train_eval.py -c cartpole.conf --USE_TENSORBOARD --USE_WANDB 33 | ``` 34 | 35 | 
36 | Logging 37 | ------- 38 | 1. You can view runs online via Weights & Biases (wandb): 39 | https://app.wandb.ai/seungjaeryanlee/implementations-nfq/runs 40 | 41 | 2.
You can use TensorBoard to view runs offline: 42 | ``` 43 | tensorboard --logdir=tensorboard_logs --port=2223 44 | ``` 45 | 46 | 47 | Glossary 48 | -------- 49 | env : Environment 50 | obs : Observation 51 | """ 52 | import configargparse 53 | import torch 54 | import torch.optim as optim 55 | 56 | from environments import CartPoleRegulatorEnv 57 | from nfq.agents import NFQAgent 58 | from nfq.networks import NFQNetwork 59 | from utils import get_logger, load_models, make_reproducible, save_models 60 | 61 | 62 | def main(): 63 | """Run NFQ.""" 64 | # Setup hyperparameters 65 | parser = configargparse.ArgParser() 66 | parser.add("-c", "--config", required=True, is_config_file=True) 67 | parser.add("--EPOCH", type=int) 68 | parser.add("--TRAIN_ENV_MAX_STEPS", type=int) 69 | parser.add("--EVAL_ENV_MAX_STEPS", type=int) 70 | parser.add("--DISCOUNT", type=float) 71 | parser.add("--INIT_EXPERIENCE", type=int) 72 | parser.add("--INCREMENT_EXPERIENCE", action="store_true") 73 | parser.add("--HINT_TO_GOAL", action="store_true") 74 | parser.add("--RANDOM_SEED", type=int) 75 | parser.add("--TRAIN_RENDER", action="store_true") 76 | parser.add("--EVAL_RENDER", action="store_true") 77 | parser.add("--SAVE_PATH", type=str, default="") 78 | parser.add("--LOAD_PATH", type=str, default="") 79 | parser.add("--USE_TENSORBOARD", action="store_true") 80 | parser.add("--USE_WANDB", action="store_true") 81 | CONFIG = parser.parse_args() 82 | if not hasattr(CONFIG, "INCREMENT_EXPERIENCE"): 83 | CONFIG.INCREMENT_EXPERIENCE = False 84 | if not hasattr(CONFIG, "HINT_TO_GOAL"): 85 | CONFIG.HINT_TO_GOAL = False 86 | if not hasattr(CONFIG, "TRAIN_RENDER"): 87 | CONFIG.TRAIN_RENDER = False 88 | if not hasattr(CONFIG, "EVAL_RENDER"): 89 | CONFIG.EVAL_RENDER = False 90 | if not hasattr(CONFIG, "USE_TENSORBOARD"): 91 | CONFIG.USE_TENSORBOARD = False 92 | if not hasattr(CONFIG, "USE_WANDB"): 93 | CONFIG.USE_WANDB = False 94 | 95 | print() 96 | print("+--------------------------------+--------------------------------+") 97 | print("| Hyperparameters | Value |") 98 | print("+--------------------------------+--------------------------------+") 99 | for arg in vars(CONFIG): 100 | print( 101 | "| {:30} | {:<30} |".format( 102 | arg, getattr(CONFIG, arg) if getattr(CONFIG, arg) is not None else "" 103 | ) 104 | ) 105 | print("+--------------------------------+--------------------------------+") 106 | print() 107 | 108 | # Log to File, Console, TensorBoard, W&B 109 | logger = get_logger() 110 | 111 | if CONFIG.USE_TENSORBOARD: 112 | from torch.utils.tensorboard import SummaryWriter 113 | 114 | writer = SummaryWriter(log_dir="tensorboard_logs") 115 | if CONFIG.USE_WANDB: 116 | import wandb 117 | 118 | wandb.init(project="implementations-nfq", config=CONFIG) 119 | 120 | # Setup environment 121 | train_env = CartPoleRegulatorEnv(mode="train") 122 | eval_env = CartPoleRegulatorEnv(mode="eval") 123 | 124 | # Fix random seeds 125 | if CONFIG.RANDOM_SEED is not None: 126 | make_reproducible(CONFIG.RANDOM_SEED, use_numpy=True, use_torch=True) 127 | train_env.seed(CONFIG.RANDOM_SEED) 128 | eval_env.seed(CONFIG.RANDOM_SEED) 129 | else: 130 | logger.warning("Running without a random seed: this run is NOT reproducible.") 131 | 132 | # Setup agent 133 | nfq_net = NFQNetwork() 134 | optimizer = optim.Rprop(nfq_net.parameters()) 135 | nfq_agent = NFQAgent(nfq_net, optimizer) 136 | 137 | # Load trained agent 138 | if CONFIG.LOAD_PATH: 139 | load_models(CONFIG.LOAD_PATH, nfq_net=nfq_net, optimizer=optimizer) 140 | 141 | # NFQ Main loop 142 | # A 
set of transition samples denoted as D 143 | all_rollouts = [] 144 | total_cost = 0 145 | if CONFIG.INIT_EXPERIENCE: 146 | for _ in range(CONFIG.INIT_EXPERIENCE): 147 | rollout, episode_cost = train_env.generate_rollout( 148 | None, render=CONFIG.TRAIN_RENDER 149 | ) 150 | all_rollouts.extend(rollout) 151 | total_cost += episode_cost 152 | for epoch in range(CONFIG.EPOCH + 1): 153 | # Variant 1: Incermentally add transitions (Section 3.4) 154 | # TODO(seungjaeryanlee): Done before or after training? 155 | if CONFIG.INCREMENT_EXPERIENCE: 156 | new_rollout, episode_cost = train_env.generate_rollout( 157 | nfq_agent.get_best_action, render=CONFIG.TRAIN_RENDER 158 | ) 159 | all_rollouts.extend(new_rollout) 160 | total_cost += episode_cost 161 | 162 | state_action_b, target_q_values = nfq_agent.generate_pattern_set(all_rollouts) 163 | 164 | # Variant 2: Clamp function to zero in goal region 165 | # TODO(seungjaeryanlee): Since this is a regulator setting, should it 166 | # not be clamped to zero? 167 | if CONFIG.HINT_TO_GOAL: 168 | goal_state_action_b, goal_target_q_values = train_env.get_goal_pattern_set() 169 | goal_state_action_b = torch.FloatTensor(goal_state_action_b) 170 | goal_target_q_values = torch.FloatTensor(goal_target_q_values) 171 | state_action_b = torch.cat([state_action_b, goal_state_action_b], dim=0) 172 | target_q_values = torch.cat([target_q_values, goal_target_q_values], dim=0) 173 | 174 | loss = nfq_agent.train((state_action_b, target_q_values)) 175 | 176 | # TODO(seungjaeryanlee): Evaluation should be done with 3000 episodes 177 | eval_episode_length, eval_success, eval_episode_cost = nfq_agent.evaluate( 178 | eval_env, CONFIG.EVAL_RENDER 179 | ) 180 | 181 | if CONFIG.INCREMENT_EXPERIENCE: 182 | logger.info( 183 | "Epoch {:4d} | Train {:3d} / {:4.2f} | Eval {:4d} / {:5.2f} | Train Loss {:.4f}".format( # noqa: B950 184 | epoch, 185 | len(new_rollout), 186 | episode_cost, 187 | eval_episode_length, 188 | eval_episode_cost, 189 | loss, 190 | ) 191 | ) 192 | if CONFIG.USE_TENSORBOARD: 193 | writer.add_scalar("train/episode_length", len(new_rollout), epoch) 194 | writer.add_scalar("train/episode_cost", episode_cost, epoch) 195 | writer.add_scalar("train/loss", loss, epoch) 196 | writer.add_scalar("eval/episode_length", eval_episode_length, epoch) 197 | writer.add_scalar("eval/episode_cost", eval_episode_cost, epoch) 198 | if CONFIG.USE_WANDB: 199 | wandb.log({"Train Episode Length": len(new_rollout)}, step=epoch) 200 | wandb.log({"Train Episode Cost": episode_cost}, step=epoch) 201 | wandb.log({"Train Loss": loss}, step=epoch) 202 | wandb.log( 203 | {"Evaluation Episode Length": eval_episode_length}, step=epoch 204 | ) 205 | wandb.log({"Evaluation Episode Cost": eval_episode_cost}, step=epoch) 206 | else: 207 | logger.info( 208 | "Epoch {:4d} | Eval {:4d} / {:5.2f} | Train Loss {:.4f}".format( 209 | epoch, eval_episode_length, eval_episode_cost, loss 210 | ) 211 | ) 212 | if CONFIG.USE_TENSORBOARD: 213 | writer.add_scalar("train/loss", loss, epoch) 214 | writer.add_scalar("eval/episode_length", eval_episode_length, epoch) 215 | writer.add_scalar("eval/episode_cost", eval_episode_cost, epoch) 216 | if CONFIG.USE_WANDB: 217 | wandb.log({"Train Loss": loss}, step=epoch) 218 | wandb.log( 219 | {"Evaluation Episode Length": eval_episode_length}, step=epoch 220 | ) 221 | wandb.log({"Evaluation Episode Cost": eval_episode_cost}, step=epoch) 222 | 223 | if eval_success: 224 | logger.info( 225 | "Epoch {:4d} | Total Cycles {:6d} | Total Cost {:4.2f}".format( 226 | epoch, 
len(all_rollouts), total_cost 227 | ) 228 | ) 229 | if CONFIG.USE_TENSORBOARD: 230 | writer.add_scalar("summary/total_cycles", len(all_rollouts), epoch) 231 | writer.add_scalar("summary/total_cost", total_cost, epoch) 232 | if CONFIG.USE_WANDB: 233 | wandb.log({"Total Cycles": len(all_rollouts)}, step=epoch) 234 | wandb.log({"Total Cost": total_cost}, step=epoch) 235 | break 236 | 237 | # Save trained agent 238 | if CONFIG.SAVE_PATH: 239 | save_models(CONFIG.SAVE_PATH, nfq_net=nfq_net, optimizer=optimizer) 240 | 241 | train_env.close() 242 | eval_env.close() 243 | 244 | 245 | if __name__ == "__main__": 246 | main() 247 | -------------------------------------------------------------------------------- /environments/cartpole.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified version of classic cart-pole system implemented by Rich Sutton et al. 3 | Copied from http://incompleteideas.net/sutton/book/code/pole.c 4 | permalink: https://perma.cc/C9ZM-652R 5 | """ 6 | # flake8: noqa 7 | import math 8 | from typing import Callable, List, Tuple 9 | 10 | import gym 11 | import numpy as np 12 | from gym import logger, spaces 13 | from gym.utils import seeding 14 | 15 | 16 | class CartPoleRegulatorEnv(gym.Env): 17 | """ 18 | Description: 19 | A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The pendulum starts upright, and the goal is to prevent it from falling over by increasing and reducing the cart's velocity. 20 | Source: 21 | This environment corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson 22 | Observation: 23 | Type: Box(4) 24 | Num Observation Min Max 25 | 0 Cart Position -4.8 4.8 26 | 1 Cart Velocity -Inf Inf 27 | 2 Pole Angle -24 deg 24 deg 28 | 3 Pole Velocity At Tip -Inf Inf 29 | 30 | Actions: 31 | Type: Discrete(2) 32 | Num Action 33 | 0 Push cart to the left 34 | 1 Push cart to the right 35 | 36 | Note: The amount the velocity that is reduced or increased is not fixed; it depends on the angle the pole is pointing. This is because the center of gravity of the pole increases the amount of energy needed to move the cart underneath it 37 | Reward: 38 | Reward is 1 for every step taken, including the termination step 39 | Starting State: 40 | All observations are assigned a uniform random value in [-0.05..0.05] 41 | Episode Termination: 42 | Pole Angle is more than 12 degrees 43 | Cart Position is more than 2.4 (center of the cart reaches the edge of the display) 44 | Episode length is greater than 200 45 | Solved Requirements 46 | Considered solved when the average reward is greater than or equal to 195.0 over 100 consecutive trials. 
47 | """ 48 | 49 | metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 50} 50 | 51 | def __init__(self, mode="train"): 52 | self.gravity = 9.8 53 | self.masscart = 1.0 54 | self.masspole = 0.1 55 | self.total_mass = self.masspole + self.masscart 56 | self.length = 0.5 # actually half the pole's length 57 | self.polemass_length = self.masspole * self.length 58 | self.force_mag = 10.0 59 | self.tau = 0.02 # seconds between state updates 60 | self.kinematics_integrator = "euler" 61 | 62 | assert mode in ["train", "eval"] 63 | self.mode = mode 64 | self.max_steps = 100 if mode == "train" else 3000 65 | 66 | # Success state 67 | # TODO(seungjaeryanlee): Verify pole angle success state 68 | # NOTE(seungjaeryanlee): Relaxed definition of success state 69 | # that deviates from paper 70 | self.x_success_range = 2.4 71 | self.theta_success_range = 12 * 2 * math.pi / 360 72 | 73 | # Failure state description 74 | # TODO(seungjaeryanlee): Verify pole angle threshold 75 | self.x_threshold = 2.4 76 | self.theta_threshold_radians = math.pi / 2 77 | 78 | self.c_trans = 0.01 79 | 80 | # Angle limit set to 2 * theta_threshold_radians so failing observation is still within bounds 81 | high = np.array( 82 | [ 83 | self.x_threshold * 2, 84 | np.finfo(np.float32).max, 85 | self.theta_threshold_radians * 2, 86 | np.finfo(np.float32).max, 87 | ] 88 | ) 89 | 90 | self.action_space = spaces.Discrete(2) 91 | self.observation_space = spaces.Box(-high, high, dtype=np.float32) 92 | 93 | self.seed() 94 | self.viewer = None 95 | self.state = None 96 | 97 | def seed(self, seed=None): 98 | self.np_random, seed = seeding.np_random(seed) 99 | return [seed] 100 | 101 | def _compute_next_state(self, state, action): 102 | x, x_dot, theta, theta_dot = state 103 | force = self.force_mag if action == 1 else -self.force_mag 104 | costheta = math.cos(theta) 105 | sintheta = math.sin(theta) 106 | temp = ( 107 | force + self.polemass_length * theta_dot * theta_dot * sintheta 108 | ) / self.total_mass 109 | thetaacc = (self.gravity * sintheta - costheta * temp) / ( 110 | self.length 111 | * (4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass) 112 | ) 113 | xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass 114 | if self.kinematics_integrator == "euler": 115 | x = x + self.tau * x_dot 116 | x_dot = x_dot + self.tau * xacc 117 | theta = theta + self.tau * theta_dot 118 | theta_dot = theta_dot + self.tau * thetaacc 119 | else: # semi-implicit euler 120 | x_dot = x_dot + self.tau * xacc 121 | x = x + self.tau * x_dot 122 | theta_dot = theta_dot + self.tau * thetaacc 123 | theta = theta + self.tau * theta_dot 124 | 125 | return x, x_dot, theta, theta_dot 126 | 127 | # NOTE(seungjaeryanlee): done is True only when the episode terminated due 128 | # to entering forbidden state. It is not True if it 129 | # terminated due to maximum timestep. 
130 | def step(self, action): 131 | assert self.action_space.contains(action), "%r (%s) invalid" % ( 132 | action, 133 | type(action), 134 | ) 135 | self.state = self._compute_next_state(self.state, action) 136 | x, _, theta, _ = self.state 137 | 138 | self.episode_step += 1 139 | 140 | # Forbidden States (S-) 141 | if ( 142 | x < -self.x_threshold 143 | or x > self.x_threshold 144 | or theta < -self.theta_threshold_radians 145 | or theta > self.theta_threshold_radians 146 | ): 147 | done = True 148 | cost = 1 149 | # Goal States (S+) 150 | elif ( 151 | -self.x_success_range < x < self.x_success_range 152 | and -self.theta_success_range < theta < self.theta_success_range 153 | ): 154 | done = False 155 | cost = 0 156 | else: 157 | done = False 158 | cost = self.c_trans 159 | 160 | # Check for time limit 161 | info = {"time_limit": self.episode_step >= self.max_steps} 162 | 163 | return np.array(self.state), cost, done, info 164 | 165 | def reset(self): 166 | if self.mode == "train": 167 | self.state = self.np_random.uniform( 168 | low=[-2.3, 0, -0.3, 0], high=[2.3, 0, 0.3, 0], size=(4,) 169 | ) 170 | else: 171 | self.state = self.np_random.uniform( 172 | low=[-1, 0, -0.3, 0], high=[1, 0, 0.3, 0], size=(4,) 173 | ) 174 | 175 | self.episode_step = 0 176 | 177 | return np.array(self.state) 178 | 179 | def render(self, mode="human"): 180 | screen_width = 600 181 | screen_height = 400 182 | 183 | world_width = self.x_threshold * 2 184 | scale = screen_width / world_width 185 | carty = 100 # TOP OF CART 186 | polewidth = 10.0 187 | polelen = scale * (2 * self.length) 188 | cartwidth = 50.0 189 | cartheight = 30.0 190 | 191 | if self.viewer is None: 192 | from gym.envs.classic_control import rendering 193 | 194 | self.viewer = rendering.Viewer(screen_width, screen_height) 195 | l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2 196 | axleoffset = cartheight / 4.0 197 | cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 198 | self.carttrans = rendering.Transform() 199 | cart.add_attr(self.carttrans) 200 | self.viewer.add_geom(cart) 201 | l, r, t, b = ( 202 | -polewidth / 2, 203 | polewidth / 2, 204 | polelen - polewidth / 2, 205 | -polewidth / 2, 206 | ) 207 | pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)]) 208 | pole.set_color(0.8, 0.6, 0.4) 209 | self.poletrans = rendering.Transform(translation=(0, axleoffset)) 210 | pole.add_attr(self.poletrans) 211 | pole.add_attr(self.carttrans) 212 | self.viewer.add_geom(pole) 213 | self.axle = rendering.make_circle(polewidth / 2) 214 | self.axle.add_attr(self.poletrans) 215 | self.axle.add_attr(self.carttrans) 216 | self.axle.set_color(0.5, 0.5, 0.8) 217 | self.viewer.add_geom(self.axle) 218 | self.track = rendering.Line((0, carty), (screen_width, carty)) 219 | self.track.set_color(0, 0, 0) 220 | self.viewer.add_geom(self.track) 221 | 222 | self._pole_geom = pole 223 | 224 | if self.state is None: 225 | return None 226 | 227 | # Edit the pole polygon vertex 228 | pole = self._pole_geom 229 | l, r, t, b = ( 230 | -polewidth / 2, 231 | polewidth / 2, 232 | polelen - polewidth / 2, 233 | -polewidth / 2, 234 | ) 235 | pole.v = [(l, b), (l, t), (r, t), (r, b)] 236 | 237 | x = self.state 238 | cartx = x[0] * scale + screen_width / 2.0 # MIDDLE OF CART 239 | self.carttrans.set_translation(cartx, carty) 240 | self.poletrans.set_rotation(-x[2]) 241 | 242 | return self.viewer.render(return_rgb_array=mode == "rgb_array") 243 | 244 | def close(self): 245 | if self.viewer: 246 | self.viewer.close() 247 | 
self.viewer = None 248 | 249 | def get_goal_pattern_set(self, size: int = 100): 250 | """Use hint-to-goal heuristic to clamp network output. 251 | 252 | Parameters 253 | ---------- 254 | size : int 255 | The size of the goal pattern set to generate. 256 | 257 | Returns 258 | ------- 259 | pattern_set : tuple of np.ndarray 260 | Pattern set to train the NFQ network. 261 | 262 | """ 263 | goal_state_action_b = [ 264 | np.array( 265 | [ 266 | # NOTE(seungjaeryanlee): The success state in hint-to-goal is not relaxed. 267 | # TODO(seungjaeryanlee): What is goal velocity? 268 | np.random.uniform(-0.05, 0.05), 269 | np.random.normal(), 270 | np.random.uniform( 271 | -self.theta_success_range, self.theta_success_range 272 | ), 273 | np.random.normal(), 274 | np.random.randint(2), 275 | ] 276 | ) 277 | for _ in range(size) 278 | ] 279 | goal_target_q_values = np.zeros(size) 280 | 281 | return goal_state_action_b, goal_target_q_values 282 | 283 | def generate_rollout( 284 | self, get_best_action: Callable = None, render: bool = False 285 | ) -> List[Tuple[np.array, int, int, np.array, bool]]: 286 | """ 287 | Generate rollout using given action selection function. 288 | 289 | If a network is not given, generate random rollout instead. 290 | 291 | Parameters 292 | ---------- 293 | get_best_action : Callable 294 | Greedy policy. 295 | render: bool 296 | If true, render environment. 297 | 298 | Returns 299 | ------- 300 | rollout : List of Tuple 301 | Generated rollout. 302 | episode_cost : float 303 | Cumulative cost throughout the episode. 304 | 305 | """ 306 | rollout = [] 307 | episode_cost = 0 308 | obs = self.reset() 309 | done = False 310 | info = {"time_limit": False} 311 | while not done and not info["time_limit"]: 312 | if get_best_action: 313 | action = get_best_action(obs) 314 | else: 315 | action = self.action_space.sample() 316 | 317 | next_obs, cost, done, info = self.step(action) 318 | rollout.append((obs, action, cost, next_obs, done)) 319 | episode_cost += cost 320 | obs = next_obs 321 | 322 | if render: 323 | self.render() 324 | 325 | return rollout, episode_cost 326 | --------------------------------------------------------------------------------
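
For quick reference, the snippet below is a condensed sketch of how the modules in this repository fit together. It only mirrors the NFQ main loop in train_eval.py with incremental experience; the epoch cap of 500 is an arbitrary illustration (cartpole.conf uses EPOCH = 2000), and the hint-to-goal patterns, random seeding, model saving, and logging of the real script are omitted.

```python
"""Minimal NFQ training sketch (illustrative only, not a file in this repository)."""
import torch.optim as optim

from environments import CartPoleRegulatorEnv
from nfq.agents import NFQAgent
from nfq.networks import NFQNetwork

train_env = CartPoleRegulatorEnv(mode="train")  # 100-step training episodes
eval_env = CartPoleRegulatorEnv(mode="eval")  # 3000-step evaluation episodes

nfq_net = NFQNetwork()
optimizer = optim.Rprop(nfq_net.parameters())
nfq_agent = NFQAgent(nfq_net, optimizer)

all_rollouts = []
for epoch in range(500):  # arbitrary cap for this sketch
    # Variant 1: incrementally grow the transition set with one greedy rollout.
    rollout, _ = train_env.generate_rollout(nfq_agent.get_best_action)
    all_rollouts.extend(rollout)

    # Fit the Q-network on the full pattern set, then evaluate the greedy policy.
    pattern_set = nfq_agent.generate_pattern_set(all_rollouts)
    loss = nfq_agent.train(pattern_set)
    episode_length, success, episode_cost = nfq_agent.evaluate(eval_env, render=False)
    if success:
        break

train_env.close()
eval_env.close()
```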