├── images
│   ├── good.gif
│   ├── random.gif
│   └── tensorboard.png
├── 06-08-18-20-log
│   ├── 06-08-18-20-100000.h5
│   ├── 06-08-18-20-150000.h5
│   ├── 06-08-18-20-200000.h5
│   ├── 06-08-18-20-50000.h5
│   └── events.out.tfevents.1528474841.adam-ThinkPad-T520
├── 06-08-18-42-log
│   ├── 06-08-18-42-100000.h5
│   ├── 06-08-18-42-150000.h5
│   ├── 06-08-18-42-200000.h5
│   ├── 06-08-18-42-50000.h5
│   └── events.out.tfevents.1528476165.adam-ThinkPad-T520
├── 06-08-19-15-log
│   ├── 06-08-19-15-100000.h5
│   ├── 06-08-19-15-150000.h5
│   ├── 06-08-19-15-200000.h5
│   ├── 06-08-19-15-50000.h5
│   └── events.out.tfevents.1528478105.adam-ThinkPad-T520
├── 06-08-20-23-log
│   ├── 06-08-20-23-100000.h5
│   ├── 06-08-20-23-150000.h5
│   ├── 06-08-20-23-200000.h5
│   ├── 06-08-20-23-50000.h5
│   └── events.out.tfevents.1528482213.adam-ThinkPad-T520
├── 06-08-20-57-log
│   ├── 06-08-20-57-100000.h5
│   ├── 06-08-20-57-150000.h5
│   ├── 06-08-20-57-200000.h5
│   ├── 06-08-20-57-50000.h5
│   └── events.out.tfevents.1528484236.adam-ThinkPad-T520
├── loggers.py
├── .travis.yml
├── requirements.txt
├── see.py
├── LICENSE
├── .gitignore
├── README.md
├── replay_buffer.py
└── run.py

/images/good.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/images/good.gif
--------------------------------------------------------------------------------
/images/random.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/images/random.gif
--------------------------------------------------------------------------------
/images/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/images/tensorboard.png
--------------------------------------------------------------------------------
/06-08-18-20-log/06-08-18-20-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/06-08-18-20-100000.h5
--------------------------------------------------------------------------------
/06-08-18-20-log/06-08-18-20-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/06-08-18-20-150000.h5
--------------------------------------------------------------------------------
/06-08-18-20-log/06-08-18-20-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/06-08-18-20-200000.h5
--------------------------------------------------------------------------------
/06-08-18-20-log/06-08-18-20-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/06-08-18-20-50000.h5
--------------------------------------------------------------------------------
/06-08-18-42-log/06-08-18-42-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/06-08-18-42-100000.h5
--------------------------------------------------------------------------------
/06-08-18-42-log/06-08-18-42-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/06-08-18-42-150000.h5
--------------------------------------------------------------------------------
/06-08-18-42-log/06-08-18-42-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/06-08-18-42-200000.h5
--------------------------------------------------------------------------------
/06-08-18-42-log/06-08-18-42-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/06-08-18-42-50000.h5
--------------------------------------------------------------------------------
/06-08-19-15-log/06-08-19-15-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/06-08-19-15-100000.h5
--------------------------------------------------------------------------------
/06-08-19-15-log/06-08-19-15-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/06-08-19-15-150000.h5
--------------------------------------------------------------------------------
/06-08-19-15-log/06-08-19-15-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/06-08-19-15-200000.h5
--------------------------------------------------------------------------------
/06-08-19-15-log/06-08-19-15-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/06-08-19-15-50000.h5
--------------------------------------------------------------------------------
/06-08-20-23-log/06-08-20-23-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/06-08-20-23-100000.h5
--------------------------------------------------------------------------------
/06-08-20-23-log/06-08-20-23-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/06-08-20-23-150000.h5
--------------------------------------------------------------------------------
/06-08-20-23-log/06-08-20-23-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/06-08-20-23-200000.h5
--------------------------------------------------------------------------------
/06-08-20-23-log/06-08-20-23-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/06-08-20-23-50000.h5
--------------------------------------------------------------------------------
/06-08-20-57-log/06-08-20-57-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/06-08-20-57-100000.h5
--------------------------------------------------------------------------------
/06-08-20-57-log/06-08-20-57-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/06-08-20-57-150000.h5
--------------------------------------------------------------------------------
/06-08-20-57-log/06-08-20-57-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/06-08-20-57-200000.h5
--------------------------------------------------------------------------------
/06-08-20-57-log/06-08-20-57-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/06-08-20-57-50000.h5
--------------------------------------------------------------------------------
/06-08-18-20-log/events.out.tfevents.1528474841.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/events.out.tfevents.1528474841.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/06-08-18-42-log/events.out.tfevents.1528476165.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/events.out.tfevents.1528476165.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/06-08-19-15-log/events.out.tfevents.1528478105.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/events.out.tfevents.1528478105.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/06-08-20-23-log/events.out.tfevents.1528482213.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/events.out.tfevents.1528482213.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/06-08-20-57-log/events.out.tfevents.1528484236.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/events.out.tfevents.1528484236.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/loggers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | class TensorBoardLogger(object):
5 |     def __init__(self, log_dir):
6 |         self.writer = tf.summary.FileWriter(log_dir)
7 | 
8 |     def log_scalar(self, name, value, step):
9 |         summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=value)])
10 |         self.writer.add_summary(summary, step)
11 | 
--------------------------------------------------------------------------------
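`TensorBoardLogger` is a thin TF1 wrapper: it opens a `tf.summary.FileWriter` and writes one scalar `Summary` per call. A minimal usage sketch (the `demo-log` directory name is made up for illustration):

```python
from loggers import TensorBoardLogger

# Writes scalars that `tensorboard --logdir demo-log` can plot.
logger = TensorBoardLogger('demo-log')
for step in range(1, 4):
    logger.log_scalar('episode_return', -200.0 + step, step)
logger.writer.flush()  # FileWriter buffers; flush to see the events immediately
```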
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | 
3 | install:
4 |   - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
5 |   - bash Miniconda3-latest-Linux-x86_64.sh -b -p $HOME/miniconda3
6 |   - export PATH="$HOME/miniconda3/bin:$PATH"
7 |   - conda config --set always_yes yes --set changeps1 no
8 |   - conda create -n tutorial python=3.6.5 -y
9 |   - source activate tutorial
10 |   - git clone https://github.com/AdamStelmaszczyk/rl-tutorial.git
11 |   - pip install -r requirements.txt
12 | 
13 | script:
14 |   - cd rl-tutorial
15 |   - python run.py --test
16 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.8.0
2 | astor==0.8.0
3 | certifi==2019.9.11
4 | cloudpickle==1.2.2
5 | future==0.18.1
6 | gast==0.3.2
7 | google-pasta==0.1.7
8 | grpcio==1.16.1
9 | gym==0.15.3
10 | h5py==2.8.0
11 | Keras-Applications==1.0.8
12 | Keras-Preprocessing==1.1.0
13 | Markdown==3.1.1
14 | numpy==1.17.2
15 | opencv-python==4.1.1.26
16 | protobuf==3.9.2
17 | psutil==5.6.3
18 | pyglet==1.3.2
19 | scipy==1.3.1
20 | six==1.12.0
21 | tensorboard==1.14.0
22 | tensorflow==1.14.0
23 | tensorflow-estimator==1.14.0
24 | termcolor==1.1.0
25 | tqdm==4.36.1
26 | Werkzeug==0.16.0
27 | wrapt==1.11.2
28 | 
--------------------------------------------------------------------------------
/see.py:
--------------------------------------------------------------------------------
1 | import gym
2 | 
3 | env = gym.make('MountainCar-v0')
4 | for episode in range(1, 6):
5 |     obs = env.reset()
6 |     done = False
7 |     episode_return = 0.0
8 |     while not done:  # MountainCar-v0 is time-limited to 200 steps
9 |         env.render()
10 |         action = env.action_space.sample()  # random agent
11 |         obs, reward, done, _ = env.step(action)
12 |         episode_return += reward
13 |     print("Episode return: ", episode_return)
14 | env.close()
--------------------------------------------------------------------------------
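`see.py` just drives a random agent, which almost never reaches the flag; the interesting part is the environment itself. A quick inspection sketch (not a file in this repo):

```python
import gym

env = gym.make('MountainCar-v0')
print(env.observation_space)  # Box(2,): [position, velocity]
print(env.observation_space.low, env.observation_space.high)  # [-1.2 -0.07] [0.6 0.07]
print(env.action_space)  # Discrete(3): 0 = push left, 1 = no push, 2 = push right
```

Every step yields reward -1 until the episode ends, so a failed 200-step episode returns -200.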
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Adam Stelmaszczyk
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Build Status](https://travis-ci.org/AdamStelmaszczyk/rl-tutorial.svg?branch=master)](https://travis-ci.org/AdamStelmaszczyk/rl-tutorial)
2 | 
3 | Solving the Mountain Car environment with a TensorFlow & Keras implementation of DQN.
4 | For similar code solving some Atari games, look [here](https://github.com/AdamStelmaszczyk/dqn).
5 | 
6 | ---
7 | 
8 | 
9 | 
10 | ## Install
11 | 
12 | 1. Clone this repo: `git clone https://github.com/AdamStelmaszczyk/rl-tutorial.git`.
13 | 2. [Install `conda`](https://conda.io/docs/user-guide/install/index.html) for dependency management.
14 | 3. Create the `tutorial` conda environment: `conda create -n tutorial python=3.6.5 -y`.
15 | 4. Activate the `tutorial` conda environment: `source activate tutorial`. All the following commands should be run in the activated `tutorial` environment.
16 | 5. Install basic dependencies: `pip install -r requirements.txt`.
17 | 
18 | There is an automatic build on Travis which [does the same](https://github.com/AdamStelmaszczyk/rl-tutorial/blob/master/.travis.yml).
19 | 
20 | ## Run
21 | 
22 | `python run.py --help`
23 | 
24 | ```
25 | usage: run.py [-h] [--eval] [--images] [--model MODEL] [--name NAME]
26 |               [--seed SEED] [--test] [--view]
27 | 
28 | optional arguments:
29 |   -h, --help     show this help message and exit
30 |   --eval         run evaluation with log only (default: False)
31 |   --images       save images during evaluation (default: False)
32 |   --model MODEL  model filename to load (default: None)
33 |   --name NAME    name for saved files (default: 06-08-21-53)
34 |   --seed SEED    pseudo random number generator seed (default: None)
35 |   --test         run tests (default: False)
36 |   --view         view the model playing the game (default: False)
37 | ```
38 | 
39 | ## Generate GIFs
40 | 
41 | 
42 | 
43 | 1. Generate images: `python run.py --model 06-08-18-42-log/06-08-18-42-200000.h5 --images`.
44 | 2. We will use the `convert` tool, which is part of ImageMagick; [here](https://www.imagemagick.org/script/download.php) are the installation instructions.
45 | 3. Convert images from episode 1 to a GIF: `convert -layers optimize-frame 1_*.png 1.gif`
46 | 
47 | 
48 | 
49 | ## Uninstall
50 | 
51 | 1. Deactivate the conda environment: `source deactivate`.
52 | 2. Remove the `tutorial` conda environment: `conda env remove -n tutorial -y`.
53 | 
54 | ## Links
55 | 
56 | - Playing Atari with Deep Reinforcement Learning: https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
57 | - Human-level control through deep reinforcement learning: https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf
58 | - MountainCar description: https://github.com/openai/gym/wiki/MountainCar-v0
59 | - MountainCar source code: https://github.com/openai/gym/blob/4c460ba6c8959dd8e0a03b13a1ca817da6d4074f/gym/envs/classic_control/mountain_car.py
--------------------------------------------------------------------------------
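The snapshots in the `*-log` directories can also be replayed by hand. A sketch, assuming the pinned TF 1.14 / `tensorflow.contrib.keras` stack from requirements.txt; it mirrors `predict` in run.py, where a mask of ones scores all actions:

```python
import gym
import numpy as np
import tensorflow.contrib.keras as keras

env = gym.make('MountainCar-v0')
model = keras.models.load_model('06-08-18-42-log/06-08-18-42-200000.h5')
obs, done, episode_return = env.reset(), False, 0.0
while not done:
    # The model takes [observations, action_mask]; ones unmask every action's Q value.
    action_mask = np.ones((1, env.action_space.n))
    q_values = model.predict(x=[np.array([obs]), action_mask])
    obs, reward, done, _ = env.step(np.argmax(q_values))
    episode_return += reward
print("episode_return", episode_return)  # above -200 if the car reaches the flag early
```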
/replay_buffer.py:
--------------------------------------------------------------------------------
1 | # This file was based on
2 | # https://github.com/openai/baselines/blob/edb52c22a5e14324304a491edc0f91b6cc07453b/baselines/deepq/replay_buffer.py
3 | # its license:
4 | #
5 | # The MIT License
6 | #
7 | # Copyright (c) 2017 OpenAI (http://openai.com)
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in
17 | # all copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25 | # THE SOFTWARE.
26 | 
27 | import random
28 | 
29 | import numpy as np
30 | 
31 | 
32 | class ReplayBuffer(object):
33 |     def __init__(self, size):
34 |         """
35 |         Parameters
36 |         ----------
37 |         size: int
38 |             Max number of transitions to store in the buffer. When the buffer
39 |             overflows, the oldest memories are dropped.
40 |         """
41 |         self._storage = []
42 |         self._max_size = size
43 |         self._next_idx = 0
44 | 
45 |     def __len__(self):
46 |         return len(self._storage)
47 | 
48 |     def add(self, observation, action, reward, next_obs, done):
49 |         data = (observation, action, reward, next_obs, done)
50 |         if self._next_idx >= len(self._storage):
51 |             self._storage.append(data)
52 |         else:
53 |             self._storage[self._next_idx] = data
54 |         self._next_idx = (self._next_idx + 1) % self._max_size
55 | 
56 |     def _encode_sample(self, indices):
57 |         observations, actions, rewards, next_observations, dones = [], [], [], [], []
58 |         for i in indices:
59 |             data = self._storage[i]
60 |             observation, action, reward, next_obs, done = data
61 |             observations.append(np.array(observation, copy=False))
62 |             actions.append(np.array(action, copy=False))
63 |             rewards.append(reward)
64 |             next_observations.append(np.array(next_obs, copy=False))
65 |             dones.append(done)
66 |         return np.array(observations), np.array(actions), np.array(rewards), np.array(next_observations), np.array(dones)
67 | 
68 |     def sample(self, batch_size):
69 |         """Sample a batch of experiences.
70 | 
71 |         Parameters
72 |         ----------
73 |         batch_size: int
74 |             How many transitions to sample.
75 | 
76 |         Returns
77 |         -------
78 |         observations: np.array
79 |         actions: np.array
80 |         rewards: np.array
81 |         next_observations: np.array
82 |         dones: np.array
83 |         """
84 |         indices = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
85 |         return self._encode_sample(indices)
86 | 
--------------------------------------------------------------------------------
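The buffer is a plain ring buffer: `add` overwrites the oldest slot once `size` is reached, and `sample` draws uniformly with replacement. A tiny sketch (all values made up):

```python
from replay_buffer import ReplayBuffer

buffer = ReplayBuffer(size=3)
for t in range(5):
    buffer.add(observation=[0.0, 0.1 * t], action=0, reward=-1.0,
               next_obs=[0.0, 0.1 * (t + 1)], done=False)
print(len(buffer))  # 3 -- the two oldest transitions were overwritten
observations, actions, rewards, next_observations, dones = buffer.sample(batch_size=2)
print(observations.shape)  # (2, 2)
```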
/run.py:
--------------------------------------------------------------------------------
1 | import random
2 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
3 | from time import strftime, time
4 | 
5 | import cv2
6 | import gym
7 | import numpy as np
8 | import psutil
9 | import tensorflow as tf
10 | import tensorflow.contrib.keras as keras
11 | from tqdm import tqdm
12 | 
13 | from loggers import TensorBoardLogger
14 | from replay_buffer import ReplayBuffer
15 | 
16 | DISCOUNT_FACTOR_GAMMA = 0.99
17 | LEARNING_RATE = 0.001
18 | BATCH_SIZE = 64
19 | TARGET_UPDATE_EVERY = 1000
20 | TRAIN_START = 2000
21 | REPLAY_BUFFER_SIZE = 50000
22 | MAX_STEPS = 200000
23 | LOG_EVERY = 2000
24 | SNAPSHOT_EVERY = 50000
25 | EVAL_EVERY = 20000
26 | EVAL_STEPS = 10000
27 | EVAL_EPSILON = 0
28 | TRAIN_EPSILON = 0.01
29 | Q_VALIDATION_SIZE = 10000
30 | 
31 | 
32 | def one_hot_encode(n, action):
33 |     one_hot = np.zeros(n)
34 |     one_hot[int(action)] = 1
35 |     return one_hot
36 | 
37 | 
38 | def predict(env, model, observations):
39 |     action_mask = np.ones((len(observations), env.action_space.n))
40 |     return model.predict(x=[observations, action_mask])
41 | 
42 | 
43 | def fit_batch(env, model, target_model, batch):
44 |     observations, actions, rewards, next_observations, dones = batch
45 |     # Predict the Q values of the next states. Passing ones as the action mask.
46 |     next_q_values = predict(env, target_model, next_observations)
47 |     # The Q values of terminal states are 0 by definition.
48 |     next_q_values[dones] = 0.0
49 |     # The Q value of each start state is the reward plus gamma times the max Q value of the next state.
50 |     q_values = rewards + DISCOUNT_FACTOR_GAMMA * np.max(next_q_values, axis=1)
51 |     one_hot_actions = np.array([one_hot_encode(env.action_space.n, action) for action in actions])
52 |     history = model.fit(
53 |         x=[observations, one_hot_actions],
54 |         y=one_hot_actions * q_values[:, None],
55 |         batch_size=BATCH_SIZE,
56 |         verbose=0,
57 |     )
58 |     return history.history['loss'][0]
59 | 
60 | 
61 | def create_model(env):
62 |     n_actions = env.action_space.n
63 |     obs_shape = env.observation_space.shape
64 |     observations_input = keras.layers.Input(obs_shape, name='observations_input')
65 |     action_mask = keras.layers.Input((n_actions,), name='action_mask')
66 |     hidden = keras.layers.Dense(32, activation='relu')(observations_input)
67 |     hidden_2 = keras.layers.Dense(32, activation='relu')(hidden)
68 |     output = keras.layers.Dense(n_actions)(hidden_2)
69 |     filtered_output = keras.layers.multiply([output, action_mask])
70 |     model = keras.models.Model([observations_input, action_mask], filtered_output)
71 |     optimizer = keras.optimizers.Adam(lr=LEARNING_RATE, clipnorm=1.0)
72 |     model.compile(optimizer, loss='mean_squared_error')
73 |     return model
74 | 
75 | 
76 | def greedy_action(env, model, observation):
77 |     next_q_values = predict(env, model, observations=[observation])
78 |     return np.argmax(next_q_values)
79 | 
80 | 
81 | def epsilon_greedy_action(env, model, observation, epsilon):
82 |     if random.random() < epsilon:
83 |         action = env.action_space.sample()
84 |     else:
85 |         action = greedy_action(env, model, observation)
86 |     return action
87 | 
88 | 
89 | def save_model(model, step, logdir, name):
90 |     filename = '{}/{}-{}.h5'.format(logdir, name, step)
91 |     model.save(filename)
92 |     print('Saved {}'.format(filename))
93 |     return filename
94 | 
95 | 
96 | def save_image(env, episode, step):
97 |     frame = env.render(mode='rgb_array')
98 |     frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # the cv2.imwrite below expects BGR
99 |     filename = "{}_{:06d}.png".format(episode, step)
100 |     cv2.imwrite(filename, frame, params=[cv2.IMWRITE_PNG_COMPRESSION, 9])
101 | 
102 | 
103 | def evaluate(env, model, view=False, images=False):
104 |     print("Evaluation")
105 |     done = True
106 |     episode = 0
107 |     episode_return_sum = 0.0
108 |     for step in tqdm(range(1, EVAL_STEPS + 1)):
109 |         if done:
110 |             if episode > 0:
111 |                 episode_return_sum += episode_return
112 |             obs = env.reset()
113 |             episode += 1
114 |             episode_return = 0.0
115 |             episode_steps = 0
116 |             if view:
117 |                 env.render()
118 |             if images:
119 |                 save_image(env, episode, step)
120 |         else:
121 |             obs = next_obs
122 |         action = epsilon_greedy_action(env, model, obs, epsilon=EVAL_EPSILON)
123 |         next_obs, reward, done, _ = env.step(action)
124 |         episode_return += reward
125 |         episode_steps += 1
126 |         if view:
127 |             env.render()
128 |         if images:
129 |             save_image(env, episode, step)
130 |     assert episode > 0
131 |     episode_return_avg = episode_return_sum / episode
132 |     return episode_return_avg
133 | 
134 | 
135 | def train(env, model, max_steps, name, logdir, logger):
136 |     target_model = create_model(env)
137 |     replay = ReplayBuffer(REPLAY_BUFFER_SIZE)
138 |     done = True
139 |     episode = 0
140 |     steps_after_logging = 0
141 |     loss = 0.0
142 |     for step in range(1, max_steps + 1):
143 |         try:
144 |             if step % SNAPSHOT_EVERY == 0:
145 |                 save_model(model, step, logdir, name)
146 |             if done:
147 |                 if episode > 0:
148 |                     if steps_after_logging >= LOG_EVERY:
149 |                         steps_after_logging = 0
150 |                         episode_end = time()
151 |                         episode_seconds = episode_end - episode_start
152 |                         episode_steps = step - episode_start_step
153 |                         steps_per_second = episode_steps / episode_seconds
154 |                         memory = psutil.virtual_memory()
155 |                         to_gb = lambda in_bytes: in_bytes / 1024 / 1024 / 1024
156 |                         print(
157 |                             "episode {} "
158 |                             "steps {}/{} "
159 |                             "loss {:.7f} "
160 |                             "return {} "
161 |                             "in {:.2f}s "
162 |                             "{:.1f} steps/s "
163 |                             "{:.1f}/{:.1f} GB RAM".format(
164 |                                 episode,
165 |                                 episode_steps,
166 |                                 step,
167 |                                 loss,
168 |                                 episode_return,
169 |                                 episode_seconds,
170 |                                 steps_per_second,
171 |                                 to_gb(memory.used),
172 |                                 to_gb(memory.total),
173 |                             ))
174 |                         logger.log_scalar('episode_return', episode_return, step)
175 |                         logger.log_scalar('episode_steps', episode_steps, step)
176 |                         logger.log_scalar('episode_seconds', episode_seconds, step)
177 |                         logger.log_scalar('steps_per_second', steps_per_second, step)
178 |                         logger.log_scalar('memory_used', to_gb(memory.used), step)
179 |                         logger.log_scalar('loss', loss, step)
180 |                 episode_start = time()
181 |                 episode_start_step = step
182 |                 obs = env.reset()
183 |                 episode += 1
184 |                 episode_return = 0.0
185 |             else:
186 |                 obs = next_obs
187 | 
188 |             action = epsilon_greedy_action(env, model, obs, epsilon=TRAIN_EPSILON)
189 |             next_obs, reward, done, _ = env.step(action)
190 |             episode_return += reward
191 |             replay.add(obs, action, reward, next_obs, done)
192 | 
193 |             if step >= TRAIN_START:
194 |                 if step % TARGET_UPDATE_EVERY == 0:
195 |                     target_model.set_weights(model.get_weights())
196 |                 batch = replay.sample(BATCH_SIZE)
197 |                 loss = fit_batch(env, model, target_model, batch)
198 |             if step == Q_VALIDATION_SIZE:
199 |                 q_validation_observations, _, _, _, _ = replay.sample(Q_VALIDATION_SIZE)
200 |             if step >= TRAIN_START and step % EVAL_EVERY == 0:
201 |                 episode_return_avg = evaluate(env, model)
202 |                 q_values = predict(env, model, q_validation_observations)
203 |                 max_q_values = np.max(q_values, axis=1)
204 |                 avg_max_q_value = np.mean(max_q_values)
205 |                 print(
206 |                     "episode {} "
207 |                     "step {} "
208 |                     "episode_return_avg {:.3f} "
209 |                     "avg_max_q_value {:.3f}".format(
210 |                         episode,
211 |                         step,
212 |                         episode_return_avg,
213 |                         avg_max_q_value,
214 |                     ))
215 |                 logger.log_scalar('episode_return_avg', episode_return_avg, step)
216 |                 logger.log_scalar('avg_max_q_value', avg_max_q_value, step)
217 |             steps_after_logging += 1
218 |         except KeyboardInterrupt:
219 |             save_model(model, step, logdir, name)
220 |             break
221 | 
222 | 
223 | def load_or_create_model(env, model_filename):
224 |     if model_filename:
225 |         model = keras.models.load_model(model_filename)
226 |         print('Loaded {}'.format(model_filename))
227 |     else:
228 |         model = create_model(env)
229 |     model.summary()
230 |     return model
231 | 
232 | 
233 | def set_seed(env, seed):
234 |     random.seed(seed)
235 |     np.random.seed(seed)
236 |     tf.set_random_seed(seed)
237 |     env.seed(seed)
238 | 
239 | 
240 | def main(args):
241 |     assert BATCH_SIZE <= TRAIN_START <= Q_VALIDATION_SIZE <= REPLAY_BUFFER_SIZE
242 |     print('args', args)
243 |     env = gym.make('MountainCar-v0')
244 |     set_seed(env, args.seed)
245 |     model = load_or_create_model(env, args.model)
246 |     if args.view or args.eval or args.images:
247 |         episode_return_avg = evaluate(env, model, args.view, args.images)
248 |         print("episode_return_avg {:.3f}".format(episode_return_avg))
249 |     else:
250 |         max_steps = 100 if args.test else MAX_STEPS
251 |         logdir = '{}-log'.format(args.name)
252 |         logger = TensorBoardLogger(logdir)
253 |         print('Created {}'.format(logdir))
254 |         train(env, model, max_steps, args.name, logdir, logger)
255 |         if args.test:
256 |             filename = save_model(model, EVAL_STEPS, logdir='.', name='test')
257 |             load_or_create_model(env, filename)
258 | 
259 | 
260 | if __name__ == '__main__':
261 |     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
262 |     parser.add_argument('--eval', action='store_true', default=False, help='run evaluation with log only')
263 |     parser.add_argument('--images', action='store_true', default=False, help='save images during evaluation')
264 |     parser.add_argument('--model', action='store', default=None, help='model filename to load')
265 |     parser.add_argument('--name', action='store', default=strftime("%m-%d-%H-%M"), help='name for saved files')
266 |     parser.add_argument('--seed', action='store', type=int, help='pseudo random number generator seed')
267 |     parser.add_argument('--test', action='store_true', default=False, help='run tests')
268 |     parser.add_argument('--view', action='store_true', default=False, help='view the model playing the game')
269 |     main(parser.parse_args())
270 | 
--------------------------------------------------------------------------------
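The heart of `fit_batch` is the one-step Q target: reward plus gamma times the best next-state Q value, with terminal next states zeroed first. A tiny numeric check (all numbers made up):

```python
import numpy as np

DISCOUNT_FACTOR_GAMMA = 0.99
rewards = np.array([-1.0, -1.0])
next_q_values = np.array([[-10.0, -12.0, -11.0],   # non-terminal transition
                          [-5.0, -4.0, -6.0]])     # terminal transition
dones = np.array([False, True])
next_q_values[dones] = 0.0  # terminal rows contribute nothing
targets = rewards + DISCOUNT_FACTOR_GAMMA * np.max(next_q_values, axis=1)
print(targets)  # [-10.9  -1. ]  ->  -1 + 0.99 * (-10), and just -1 for the terminal case
```

Multiplying the targets by the one-hot action mask, as `fit_batch` does, confines the squared error to the action actually taken: for the other actions both `filtered_output` and the target are masked to 0, so they add nothing to the loss.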