├── images
│   ├── good.gif
│   ├── random.gif
│   └── tensorboard.png
├── 06-08-18-20-log
│   ├── 06-08-18-20-100000.h5
│   ├── 06-08-18-20-150000.h5
│   ├── 06-08-18-20-200000.h5
│   ├── 06-08-18-20-50000.h5
│   └── events.out.tfevents.1528474841.adam-ThinkPad-T520
├── 06-08-18-42-log
│   ├── 06-08-18-42-100000.h5
│   ├── 06-08-18-42-150000.h5
│   ├── 06-08-18-42-200000.h5
│   ├── 06-08-18-42-50000.h5
│   └── events.out.tfevents.1528476165.adam-ThinkPad-T520
├── 06-08-19-15-log
│   ├── 06-08-19-15-100000.h5
│   ├── 06-08-19-15-150000.h5
│   ├── 06-08-19-15-200000.h5
│   ├── 06-08-19-15-50000.h5
│   └── events.out.tfevents.1528478105.adam-ThinkPad-T520
├── 06-08-20-23-log
│   ├── 06-08-20-23-100000.h5
│   ├── 06-08-20-23-150000.h5
│   ├── 06-08-20-23-200000.h5
│   ├── 06-08-20-23-50000.h5
│   └── events.out.tfevents.1528482213.adam-ThinkPad-T520
├── 06-08-20-57-log
│   ├── 06-08-20-57-100000.h5
│   ├── 06-08-20-57-150000.h5
│   ├── 06-08-20-57-200000.h5
│   ├── 06-08-20-57-50000.h5
│   └── events.out.tfevents.1528484236.adam-ThinkPad-T520
├── loggers.py
├── .travis.yml
├── requirements.txt
├── see.py
├── LICENSE
├── .gitignore
├── README.md
├── replay_buffer.py
└── run.py
/images/good.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/images/good.gif
--------------------------------------------------------------------------------
/images/random.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/images/random.gif
--------------------------------------------------------------------------------
/images/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/images/tensorboard.png
--------------------------------------------------------------------------------
/06-08-18-20-log/06-08-18-20-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/06-08-18-20-100000.h5
--------------------------------------------------------------------------------
/06-08-18-20-log/06-08-18-20-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/06-08-18-20-150000.h5
--------------------------------------------------------------------------------
/06-08-18-20-log/06-08-18-20-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/06-08-18-20-200000.h5
--------------------------------------------------------------------------------
/06-08-18-20-log/06-08-18-20-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/06-08-18-20-50000.h5
--------------------------------------------------------------------------------
/06-08-18-42-log/06-08-18-42-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/06-08-18-42-100000.h5
--------------------------------------------------------------------------------
/06-08-18-42-log/06-08-18-42-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/06-08-18-42-150000.h5
--------------------------------------------------------------------------------
/06-08-18-42-log/06-08-18-42-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/06-08-18-42-200000.h5
--------------------------------------------------------------------------------
/06-08-18-42-log/06-08-18-42-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/06-08-18-42-50000.h5
--------------------------------------------------------------------------------
/06-08-19-15-log/06-08-19-15-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/06-08-19-15-100000.h5
--------------------------------------------------------------------------------
/06-08-19-15-log/06-08-19-15-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/06-08-19-15-150000.h5
--------------------------------------------------------------------------------
/06-08-19-15-log/06-08-19-15-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/06-08-19-15-200000.h5
--------------------------------------------------------------------------------
/06-08-19-15-log/06-08-19-15-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/06-08-19-15-50000.h5
--------------------------------------------------------------------------------
/06-08-20-23-log/06-08-20-23-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/06-08-20-23-100000.h5
--------------------------------------------------------------------------------
/06-08-20-23-log/06-08-20-23-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/06-08-20-23-150000.h5
--------------------------------------------------------------------------------
/06-08-20-23-log/06-08-20-23-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/06-08-20-23-200000.h5
--------------------------------------------------------------------------------
/06-08-20-23-log/06-08-20-23-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/06-08-20-23-50000.h5
--------------------------------------------------------------------------------
/06-08-20-57-log/06-08-20-57-100000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/06-08-20-57-100000.h5
--------------------------------------------------------------------------------
/06-08-20-57-log/06-08-20-57-150000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/06-08-20-57-150000.h5
--------------------------------------------------------------------------------
/06-08-20-57-log/06-08-20-57-200000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/06-08-20-57-200000.h5
--------------------------------------------------------------------------------
/06-08-20-57-log/06-08-20-57-50000.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/06-08-20-57-50000.h5
--------------------------------------------------------------------------------
/06-08-18-20-log/events.out.tfevents.1528474841.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-20-log/events.out.tfevents.1528474841.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/06-08-18-42-log/events.out.tfevents.1528476165.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-18-42-log/events.out.tfevents.1528476165.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/06-08-19-15-log/events.out.tfevents.1528478105.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-19-15-log/events.out.tfevents.1528478105.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/06-08-20-23-log/events.out.tfevents.1528482213.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-23-log/events.out.tfevents.1528482213.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/06-08-20-57-log/events.out.tfevents.1528484236.adam-ThinkPad-T520:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdamStelmaszczyk/rl-tutorial/HEAD/06-08-20-57-log/events.out.tfevents.1528484236.adam-ThinkPad-T520
--------------------------------------------------------------------------------
/loggers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class TensorBoardLogger(object):
5 |     def __init__(self, log_dir):
6 |         self.writer = tf.summary.FileWriter(log_dir)
7 | 
8 |     def log_scalar(self, name, value, step):
9 |         summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=value)])
10 |         self.writer.add_summary(summary, step)
11 |
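12 | 
13 | if __name__ == '__main__':
14 |     # Minimal usage sketch: write a few scalars into an 'example-log'
15 |     # directory (an arbitrary name) that TensorBoard can then plot.
16 |     logger = TensorBoardLogger('example-log')
17 |     for step in range(3):
18 |         logger.log_scalar('example_value', 0.5 * step, step)
19 |     logger.writer.flush()  # FileWriter buffers events; flush writes them to disk
20 | 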
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | install:
4 |   - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
5 |   - bash Miniconda3-latest-Linux-x86_64.sh -b -p $HOME/miniconda3
6 |   - export PATH="$HOME/miniconda3/bin:$PATH"
7 |   - conda config --set always_yes yes --set changeps1 no
8 |   - conda create -n tutorial python=3.6.5 -y
9 |   - source activate tutorial
10 |   - git clone https://github.com/AdamStelmaszczyk/rl-tutorial.git
11 |   - pip install -r requirements.txt
12 |
13 | script:
14 |   - cd rl-tutorial
15 |   - python run.py --test
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.8.0
2 | astor==0.8.0
3 | certifi==2019.9.11
4 | cloudpickle==1.2.2
5 | future==0.18.1
6 | gast==0.3.2
7 | google-pasta==0.1.7
8 | grpcio==1.16.1
9 | gym==0.15.3
10 | h5py==2.8.0
11 | Keras-Applications==1.0.8
12 | Keras-Preprocessing==1.1.0
13 | Markdown==3.1.1
14 | numpy==1.17.2
15 | opencv-python==4.1.1.26
16 | protobuf==3.9.2
17 | psutil==5.6.3
18 | pyglet==1.3.2
19 | scipy==1.3.1
20 | six==1.12.0
21 | tensorboard==1.14.0
22 | tensorflow==1.14.0
23 | tensorflow-estimator==1.14.0
24 | termcolor==1.1.0
25 | tqdm==4.36.1
26 | Werkzeug==0.16.0
27 | wrapt==1.11.2
28 |
--------------------------------------------------------------------------------
/see.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | env = gym.make('MountainCar-v0')
4 | done = True
5 | episode = 0
6 | episode_return = 0.0
7 | for _ in range(5):  # 5 episodes of at most 200 steps each
8 |     for step in range(200):
9 |         if done:
10 |             if episode > 0:
11 |                 print("Episode return: ", episode_return)
12 |             obs = env.reset()
13 |             episode += 1
14 |             episode_return = 0.0
15 |             env.render()
16 |         else:
17 |             obs = next_obs
18 |         action = env.action_space.sample()
19 |         next_obs, reward, done, _ = env.step(action)
20 |         episode_return += reward
21 |         env.render()
22 | print("Episode return: ", episode_return)  # return of the final episode
23 | env.close()
24 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Adam Stelmaszczyk
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://travis-ci.org/AdamStelmaszczyk/rl-tutorial)
2 |
3 | Solving the Mountain Car environment with a TensorFlow & Keras implementation of DQN.
4 | For similar code solving some Atari games, look [here](https://github.com/AdamStelmaszczyk/dqn).
5 |
6 | ---
7 |
8 | <img src="images/random.gif"> <img src="images/good.gif">
9 |
10 | ## Install
11 |
12 | 1. Clone this repo: `git clone https://github.com/AdamStelmaszczyk/rl-tutorial.git`.
13 | 2. [Install `conda`](https://conda.io/docs/user-guide/install/index.html) for dependency management.
14 | 3. Create `tutorial` conda environment: `conda create -n tutorial python=3.6.5 -y`.
15 | 4. Activate `tutorial` conda environment: `source activate tutorial`. All the following commands should be run in the activated `tutorial` environment.
16 | 5. Install basic dependencies: `pip install -r requirements.txt`.
17 |
18 | There is an automatic build on Travis which [does the same](https://github.com/AdamStelmaszczyk/rl-tutorial/blob/master/.travis.yml).
19 |
20 | ## Run
21 |
22 | `python run.py --help`
23 |
24 | ```
25 | usage: run.py [-h] [--eval] [--images] [--model MODEL] [--name NAME]
26 |               [--seed SEED] [--test] [--view]
27 | 
28 | optional arguments:
29 |   -h, --help     show this help message and exit
30 |   --eval         run evaluation with log only (default: False)
31 |   --images       save images during evaluation (default: False)
32 |   --model MODEL  model filename to load (default: None)
33 |   --name NAME    name for saved files (default: 06-08-21-53)
34 |   --seed SEED    pseudo random number generator seed (default: None)
35 |   --test         run tests (default: False)
36 |   --view         view the model playing the game (default: False)
37 | ```
38 |
39 | ## Generate GIFs
40 |
41 | <img src="images/good.gif">
42 |
43 | 1. Generate images: `python run.py --model 06-08-18-42-log/06-08-18-42-200000.h5 --images`.
44 | 2. We will use the `convert` tool, which is part of ImageMagick; installation instructions are [here](https://www.imagemagick.org/script/download.php).
45 | 3. Convert the images from episode 1 into a GIF: `convert -layers optimize-frame 1_*.png 1.gif`.
46 |
47 |
48 |
49 | ## Uninstall
50 |
51 | 1. Deactivate conda environment: `source deactivate`.
52 | 2. Remove `tutorial` conda environment: `conda env remove -n tutorial -y`.
53 |
54 | ## Links
55 |
56 | - Playing Atari with Deep Reinforcement Learning, https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
57 | - Human-level control through deep reinforcement learning, https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf
58 | - MountainCar description: https://github.com/openai/gym/wiki/MountainCar-v0
59 | - MountainCar source code: https://github.com/openai/gym/blob/4c460ba6c8959dd8e0a03b13a1ca817da6d4074f/gym/envs/classic_control/mountain_car.py
60 |
--------------------------------------------------------------------------------
/replay_buffer.py:
--------------------------------------------------------------------------------
1 | # This file was based on
2 | # https://github.com/openai/baselines/blob/edb52c22a5e14324304a491edc0f91b6cc07453b/baselines/deepq/replay_buffer.py
3 | # its license:
4 | #
5 | # The MIT License
6 | #
7 | # Copyright (c) 2017 OpenAI (http://openai.com)
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in
17 | # all copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25 | # THE SOFTWARE.
26 |
27 | import random
28 |
29 | import numpy as np
30 |
31 |
32 | class ReplayBuffer(object):
33 |     def __init__(self, size):
34 |         """
35 |         Parameters
36 |         ----------
37 |         size: int
38 |             Max number of transitions to store in the buffer. When the buffer
39 |             overflows the old memories are dropped.
40 |         """
41 |         self._storage = []
42 |         self._max_size = size
43 |         self._next_idx = 0
44 | 
45 |     def __len__(self):
46 |         return len(self._storage)
47 | 
48 |     def add(self, observation, action, reward, next_obs, done):
49 |         data = (observation, action, reward, next_obs, done)
50 |         if self._next_idx >= len(self._storage):
51 |             self._storage.append(data)
52 |         else:
53 |             self._storage[self._next_idx] = data
54 |         self._next_idx = (self._next_idx + 1) % self._max_size
55 | 
56 |     def _encode_sample(self, indices):
57 |         observations, actions, rewards, next_observations, dones = [], [], [], [], []
58 |         for i in indices:
59 |             data = self._storage[i]
60 |             observation, action, reward, next_obs, done = data
61 |             observations.append(np.array(observation, copy=False))
62 |             actions.append(np.array(action, copy=False))
63 |             rewards.append(reward)
64 |             next_observations.append(np.array(next_obs, copy=False))
65 |             dones.append(done)
66 |         return np.array(observations), np.array(actions), np.array(rewards), np.array(next_observations), np.array(dones)
67 | 
68 |     def sample(self, batch_size):
69 |         """Sample a batch of experiences.
70 | 
71 |         Parameters
72 |         ----------
73 |         batch_size: int
74 |             How many transitions to sample.
75 | 
76 |         Returns
77 |         -------
78 |         observations: np.array
79 |         actions: np.array
80 |         rewards: np.array
81 |         next_observations: np.array
82 |         dones: np.array
83 |         """
84 |         indices = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
85 |         return self._encode_sample(indices)
86 |
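87 | 
88 | if __name__ == '__main__':
89 |     # Minimal sanity-check sketch: fill a small buffer past its capacity and
90 |     # sample a batch; the oldest transitions are overwritten ring-buffer style.
91 |     buffer = ReplayBuffer(size=4)
92 |     for i in range(6):
93 |         buffer.add(np.array([float(i), 0.0]), i % 3, -1.0, np.array([i + 1.0, 0.0]), i == 5)
94 |     assert len(buffer) == 4  # capped at max size
95 |     observations, actions, rewards, next_observations, dones = buffer.sample(2)
96 |     print(observations.shape, actions.shape)  # (2, 2) (2,)
97 | 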
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import random
2 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
3 | from time import strftime, time
4 |
5 | import cv2
6 | import gym
7 | import numpy as np
8 | import psutil
9 | import tensorflow as tf
10 | import tensorflow.contrib.keras as keras
11 | from tqdm import tqdm
12 |
13 | from loggers import TensorBoardLogger
14 | from replay_buffer import ReplayBuffer
15 |
16 | DISCOUNT_FACTOR_GAMMA = 0.99
17 | LEARNING_RATE = 0.001
18 | BATCH_SIZE = 64
19 | TARGET_UPDATE_EVERY = 1000
20 | TRAIN_START = 2000
21 | REPLAY_BUFFER_SIZE = 50000
22 | MAX_STEPS = 200000
23 | LOG_EVERY = 2000
24 | SNAPSHOT_EVERY = 50000
25 | EVAL_EVERY = 20000
26 | EVAL_STEPS = 10000
27 | EVAL_EPSILON = 0
28 | TRAIN_EPSILON = 0.01
29 | Q_VALIDATION_SIZE = 10000
30 |
31 |
32 | def one_hot_encode(n, action):
33 |     one_hot = np.zeros(n)
34 |     one_hot[int(action)] = 1
35 |     return one_hot
36 | 
37 | 
38 | def predict(env, model, observations):
39 |     action_mask = np.ones((len(observations), env.action_space.n))
40 |     return model.predict(x=[observations, action_mask])
41 | 
42 | 
43 | def fit_batch(env, model, target_model, batch):
44 |     observations, actions, rewards, next_observations, dones = batch
45 |     # Predict the Q values of the next states, passing all-ones as the action mask.
46 |     next_q_values = predict(env, target_model, next_observations)
47 |     # The Q value of a terminal state is 0 by definition.
48 |     next_q_values[dones] = 0.0
49 |     # The target Q value of each start state is the reward + gamma * the max Q value of the next state.
50 |     q_values = rewards + DISCOUNT_FACTOR_GAMMA * np.max(next_q_values, axis=1)
51 |     one_hot_actions = np.array([one_hot_encode(env.action_space.n, action) for action in actions])
52 |     history = model.fit(
53 |         x=[observations, one_hot_actions],
54 |         y=one_hot_actions * q_values[:, None],
55 |         batch_size=BATCH_SIZE,
56 |         verbose=0,
57 |     )
58 |     return history.history['loss'][0]
59 | 
60 | 
61 | def create_model(env):
62 |     n_actions = env.action_space.n
63 |     obs_shape = env.observation_space.shape
64 |     observations_input = keras.layers.Input(obs_shape, name='observations_input')
65 |     action_mask = keras.layers.Input((n_actions,), name='action_mask')
66 |     hidden = keras.layers.Dense(32, activation='relu')(observations_input)
67 |     hidden_2 = keras.layers.Dense(32, activation='relu')(hidden)
68 |     output = keras.layers.Dense(n_actions)(hidden_2)
69 |     filtered_output = keras.layers.multiply([output, action_mask])
70 |     model = keras.models.Model([observations_input, action_mask], filtered_output)
71 |     optimizer = keras.optimizers.Adam(lr=LEARNING_RATE, clipnorm=1.0)
72 |     model.compile(optimizer, loss='mean_squared_error')
73 |     return model
74 | 
75 | 
76 | def greedy_action(env, model, observation):
77 |     next_q_values = predict(env, model, observations=[observation])
78 |     return np.argmax(next_q_values)
79 | 
80 | 
81 | def epsilon_greedy_action(env, model, observation, epsilon):
82 |     if random.random() < epsilon:
83 |         action = env.action_space.sample()
84 |     else:
85 |         action = greedy_action(env, model, observation)
86 |     return action
87 | 
88 | 
89 | def save_model(model, step, logdir, name):
90 |     filename = '{}/{}-{}.h5'.format(logdir, name, step)
91 |     model.save(filename)
92 |     print('Saved {}'.format(filename))
93 |     return filename
94 | 
95 | 
96 | def save_image(env, episode, step):
97 |     frame = env.render(mode='rgb_array')
98 |     frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # the following cv2.imwrite assumes BGR
99 |     filename = "{}_{:06d}.png".format(episode, step)
100 |     cv2.imwrite(filename, frame, params=[cv2.IMWRITE_PNG_COMPRESSION, 9])
101 | 
102 | 
103 | def evaluate(env, model, view=False, images=False):
104 |     print("Evaluation")
105 |     done = True
106 |     episode = 0
107 |     episode_return_sum = 0.0
108 |     for step in tqdm(range(1, EVAL_STEPS + 1)):
109 |         if done:
110 |             if episode > 0:
111 |                 episode_return_sum += episode_return
112 |             obs = env.reset()
113 |             episode += 1
114 |             episode_return = 0.0
115 |             episode_steps = 0
116 |             if view:
117 |                 env.render()
118 |             if images:
119 |                 save_image(env, episode, step)
120 |         else:
121 |             obs = next_obs
122 |         action = epsilon_greedy_action(env, model, obs, epsilon=EVAL_EPSILON)
123 |         next_obs, reward, done, _ = env.step(action)
124 |         episode_return += reward
125 |         episode_steps += 1
126 |         if view:
127 |             env.render()
128 |         if images:
129 |             save_image(env, episode, step)
130 |     assert episode > 0
131 |     episode_return_avg = episode_return_sum / episode
132 |     return episode_return_avg
133 | 
134 | 
135 | def train(env, model, max_steps, name, logdir, logger):
136 |     target_model = create_model(env)
137 |     replay = ReplayBuffer(REPLAY_BUFFER_SIZE)
138 |     done = True
139 |     episode = 0
140 |     steps_after_logging = 0
141 |     loss = 0.0
142 |     for step in range(1, max_steps + 1):
143 |         try:
144 |             if step % SNAPSHOT_EVERY == 0:
145 |                 save_model(model, step, logdir, name)
146 |             if done:
147 |                 if episode > 0:
148 |                     if steps_after_logging >= LOG_EVERY:
149 |                         steps_after_logging = 0
150 |                         episode_end = time()
151 |                         episode_seconds = episode_end - episode_start
152 |                         episode_steps = step - episode_start_step
153 |                         steps_per_second = episode_steps / episode_seconds
154 |                         memory = psutil.virtual_memory()
155 |                         to_gb = lambda in_bytes: in_bytes / 1024 / 1024 / 1024
156 |                         print(
157 |                             "episode {} "
158 |                             "steps {}/{} "
159 |                             "loss {:.7f} "
160 |                             "return {} "
161 |                             "in {:.2f}s "
162 |                             "{:.1f} steps/s "
163 |                             "{:.1f}/{:.1f} GB RAM".format(
164 |                                 episode,
165 |                                 episode_steps,
166 |                                 step,
167 |                                 loss,
168 |                                 episode_return,
169 |                                 episode_seconds,
170 |                                 steps_per_second,
171 |                                 to_gb(memory.used),
172 |                                 to_gb(memory.total),
173 |                             ))
174 |                         logger.log_scalar('episode_return', episode_return, step)
175 |                         logger.log_scalar('episode_steps', episode_steps, step)
176 |                         logger.log_scalar('episode_seconds', episode_seconds, step)
177 |                         logger.log_scalar('steps_per_second', steps_per_second, step)
178 |                         logger.log_scalar('memory_used', to_gb(memory.used), step)
179 |                         logger.log_scalar('loss', loss, step)
180 |                 episode_start = time()
181 |                 episode_start_step = step
182 |                 obs = env.reset()
183 |                 episode += 1
184 |                 episode_return = 0.0
185 |             else:
186 |                 obs = next_obs
187 | 
188 |             action = epsilon_greedy_action(env, model, obs, epsilon=TRAIN_EPSILON)
189 |             next_obs, reward, done, _ = env.step(action)
190 |             episode_return += reward
191 |             replay.add(obs, action, reward, next_obs, done)
192 | 
193 |             if step >= TRAIN_START:
194 |                 if step % TARGET_UPDATE_EVERY == 0:
195 |                     target_model.set_weights(model.get_weights())
196 |                 batch = replay.sample(BATCH_SIZE)
197 |                 loss = fit_batch(env, model, target_model, batch)
198 |             if step == Q_VALIDATION_SIZE:
199 |                 q_validation_observations, _, _, _, _ = replay.sample(Q_VALIDATION_SIZE)
200 |             if step >= TRAIN_START and step % EVAL_EVERY == 0:
201 |                 episode_return_avg = evaluate(env, model)
202 |                 q_values = predict(env, model, q_validation_observations)
203 |                 max_q_values = np.max(q_values, axis=1)
204 |                 avg_max_q_value = np.mean(max_q_values)
205 |                 print(
206 |                     "episode {} "
207 |                     "step {} "
208 |                     "episode_return_avg {:.3f} "
209 |                     "avg_max_q_value {:.3f}".format(
210 |                         episode,
211 |                         step,
212 |                         episode_return_avg,
213 |                         avg_max_q_value,
214 |                     ))
215 |                 logger.log_scalar('episode_return_avg', episode_return_avg, step)
216 |                 logger.log_scalar('avg_max_q_value', avg_max_q_value, step)
217 |             steps_after_logging += 1
218 |         except KeyboardInterrupt:
219 |             save_model(model, step, logdir, name)
220 |             break
221 | 
222 | 
223 | def load_or_create_model(env, model_filename):
224 |     if model_filename:
225 |         model = keras.models.load_model(model_filename)
226 |         print('Loaded {}'.format(model_filename))
227 |     else:
228 |         model = create_model(env)
229 |     model.summary()
230 |     return model
231 | 
232 | 
233 | def set_seed(env, seed):
234 |     random.seed(seed)
235 |     np.random.seed(seed)
236 |     tf.set_random_seed(seed)
237 |     env.seed(seed)
238 | 
239 | 
240 | def main(args):
241 |     assert BATCH_SIZE <= TRAIN_START <= Q_VALIDATION_SIZE <= REPLAY_BUFFER_SIZE
242 |     print('args', args)
243 |     env = gym.make('MountainCar-v0')
244 |     set_seed(env, args.seed)
245 |     model = load_or_create_model(env, args.model)
246 |     if args.view or args.eval or args.images:
247 |         episode_return_avg = evaluate(env, model, args.view, args.images)
248 |         print("episode_return_avg {:.3f}".format(episode_return_avg))
249 |     else:
250 |         max_steps = 100 if args.test else MAX_STEPS
251 |         logdir = '{}-log'.format(args.name)
252 |         logger = TensorBoardLogger(logdir)
253 |         print('Created {}'.format(logdir))
254 |         train(env, model, max_steps, args.name, logdir, logger)
255 |         if args.test:
256 |             filename = save_model(model, EVAL_STEPS, logdir='.', name='test')
257 |             load_or_create_model(env, filename)
258 | 
259 | 
260 | if __name__ == '__main__':
261 |     parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
262 |     parser.add_argument('--eval', action='store_true', default=False, help='run evaluation with log only')
263 |     parser.add_argument('--images', action='store_true', default=False, help='save images during evaluation')
264 |     parser.add_argument('--model', action='store', default=None, help='model filename to load')
265 |     parser.add_argument('--name', action='store', default=strftime("%m-%d-%H-%M"), help='name for saved files')
266 |     parser.add_argument('--seed', action='store', type=int, help='pseudo random number generator seed')
267 |     parser.add_argument('--test', action='store_true', default=False, help='run tests')
268 |     parser.add_argument('--view', action='store_true', default=False, help='view the model playing the game')
269 |     main(parser.parse_args())
270 |
--------------------------------------------------------------------------------