├── .gitignore ├── logo.png ├── 0_quick_start └── 0_logging_device_placement.py ├── 777_workarounds └── 777_1_tf2_cuda10.py ├── 1_keras_api ├── 2_sequential_model.py ├── 1_numbers_classification.ipynb └── 4_text_classification.ipynb ├── LICENSE ├── 2_estimators ├── 2_1_linear_model.ipynb └── .ipynb_checkpoints │ └── 2_1_linear_model-checkpoint.ipynb ├── 19_lingvo └── 19_1_task_config.py ├── README.md └── 20_tf2 ├── 20_2_a2c.py └── 20_1_actor_critic_agent.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0101011/bootstrap-ml/master/logo.png -------------------------------------------------------------------------------- /0_quick_start/0_logging_device_placement.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # Creates a graph 4 | a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a') 5 | b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b') 6 | c = tf.matmul(a, b) 7 | 8 | # Creates a session with log_device_placement set to True 9 | sess = tf.Session(config=tf.ConfigProto(log_device_placement=True)) 10 | 11 | # Runs the op 12 | print(sess.run(c)) 13 | -------------------------------------------------------------------------------- /777_workarounds/777_1_tf2_cuda10.py: -------------------------------------------------------------------------------- 1 | !pip install tf-nightly-gpu-2.0-preview 2 | 3 | !wget https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda-repo-ubuntu1604-10-0-local-10.0.130-410.48_1.0-1_amd64 -O cuda-repo-ubuntu1604-10-0-local-10.0.130-410.48_1.0-1_amd64.deb 4 | !dpkg -i cuda-repo-ubuntu1604-10-0-local-10.0.130-410.48_1.0-1_amd64.deb 5 | !apt-key add 
/var/cuda-repo-10-0-local-10.0.130-410.48/7fa2af80.pub 6 | !apt-get update 7 | !apt-get install cuda 8 | !pip install tf-nightly-gpu-2.0-preview 9 | 10 | import tensorflow as tf 11 | 12 | print(tf.__version__) 13 | -------------------------------------------------------------------------------- /1_keras_api/2_sequential_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Running the tiny sample model is faster on the CPU: Batch loading 3 | from RAM to GPU is slower at the start of each operation. Forward/backward 4 | computations are very quick in tiny networks so it's rational to use CPU. 5 | You can also try using model.fit_generator instead of plain fit, so that 6 | CPU thread which loads minibatches works in parallel. 7 | 8 | At the time there is no way I am aware of to preload the whole dataset 9 | on GPU with Keras. 10 | """ 11 | 12 | # Hiding a GPU 13 | import os 14 | os.environ["CUDA_VISIBLE_DEVICES"] = '-1' 15 | 16 | # Importing tf and numpy 17 | import tensorflow as tf 18 | from tensorflow import keras 19 | import numpy as np 20 | 21 | # Defining and running the model 22 | model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])]) 23 | model.compile(optimizer='sgd', loss='mean_squared_error') 24 | 25 | xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float) 26 | ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float) 27 | 28 | model.fit(xs, ys, epochs=500) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Andrew Stepin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /2_estimators/2_1_linear_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "import tensorflow.feature_column as fc \n", 11 | "\n", 12 | "import os\n", 13 | "import sys\n", 14 | "\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from IPython.display import clear_output" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "tf.enable_eager_execution()" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "! 
git clone --depth 1 https://github.com/tensorflow/models" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [] 43 | } 44 | ], 45 | "metadata": { 46 | "kernelspec": { 47 | "display_name": "Python 3", 48 | "language": "python", 49 | "name": "python3" 50 | }, 51 | "language_info": { 52 | "codemirror_mode": { 53 | "name": "ipython", 54 | "version": 3 55 | }, 56 | "file_extension": ".py", 57 | "mimetype": "text/x-python", 58 | "name": "python", 59 | "nbconvert_exporter": "python", 60 | "pygments_lexer": "ipython3", 61 | "version": "3.6.8" 62 | } 63 | }, 64 | "nbformat": 4, 65 | "nbformat_minor": 2 66 | } 67 | -------------------------------------------------------------------------------- /2_estimators/.ipynb_checkpoints/2_1_linear_model-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "import tensorflow.feature_column as fc \n", 11 | "\n", 12 | "import os\n", 13 | "import sys\n", 14 | "\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "from IPython.display import clear_output" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "tf.enable_eager_execution()" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "! 
git clone --depth 1 https://github.com/tensorflow/models" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [] 43 | } 44 | ], 45 | "metadata": { 46 | "kernelspec": { 47 | "display_name": "Python 3", 48 | "language": "python", 49 | "name": "python3" 50 | }, 51 | "language_info": { 52 | "codemirror_mode": { 53 | "name": "ipython", 54 | "version": 3 55 | }, 56 | "file_extension": ".py", 57 | "mimetype": "text/x-python", 58 | "name": "python", 59 | "nbconvert_exporter": "python", 60 | "pygments_lexer": "ipython3", 61 | "version": "3.6.8" 62 | } 63 | }, 64 | "nbformat": 4, 65 | "nbformat_minor": 2 66 | } 67 | -------------------------------------------------------------------------------- /19_lingvo/19_1_task_config.py: -------------------------------------------------------------------------------- 1 | def Task(cls): 2 | p = model.AsrModel.Params() 3 | p.name = 'librispeech' 4 | 5 | # Initialize encoder params. 6 | ep = p.encoder 7 | 8 | # Data consists 240 dimensional frames (80 x 3 frames), which we 9 | # re-interpret as individual 80 dimensional frames. See also, 10 | # LibrispeechCommonAsrInputParams. 11 | ep.input_shape = [None, None, 80, 1] 12 | ep.lstm_cell_size = 1024 13 | ep.num_lstm_layers = 4 14 | ep.conv_filter_shapes = [(3, 3, 1, 32), (3, 3, 32, 32)] 15 | ep.conv_filter_strides = [(2, 2), (2, 2)] 16 | ep.cnn_tpl.params_init = py_utils.WeightInit.Gaussian(0.001) 17 | 18 | # Disable conv LSTM layers. 19 | ep.num_conv_lstm_layers = 0 20 | 21 | # Initialize decoder params. 22 | dp = p.decoder 23 | dp.rnn_cell_dim = 1024 24 | dp.rnn_layers = 2 25 | dp.source_dim = 2048 26 | # Use functional while based unrolling. 
27 | dp.use_while_loop_based_unrolling = False 28 | 29 | tp = p.train 30 | tp.learning_rate = 2.5e-4 31 | tp.lr_schedule = lr_schedule.ContinuousLearningRateSchedule.Params().Set( 32 | start_step=50000, half_life_steps=100000, min=0.01) 33 | 34 | # Setting p.eval.samples_per_summary to a large value ensures that dev, 35 | # devother, test, testother are evaluated completely (since num_samples for 36 | # each of these sets is less than 5000), while train summaries will be 37 | # computed on 5000 examples. 38 | p.eval.samples_per_summary = 5000 39 | p.eval.decoder_samples_per_summary = 0 40 | 41 | # Use variational weight noise to prevent overfitting. 42 | p.vn.global_vn = True 43 | p.train.vn_std = 0.075 44 | p.train.vn_start_step = 20000 45 | 46 | return p 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bootstrap ML 2 | 3 | ![Bootstrap ML Logo](logo.png) 4 | 5 | **Bootstrap ML** is a comprehensive collection of pre-written code for machine learning and deep learning use cases, all in one convenient place. Whether you're a seasoned practitioner or just starting your ML journey, this repository provides a solid foundation to build upon. 6 | 7 | ## What Is It About? 8 | 9 | **Bootstrap ML** aims to accelerate your machine learning and deep learning projects by providing reusable, well-documented code snippets and notebooks. It covers a range of use cases, from quick starts to advanced neural network implementations. 10 | 11 | ### Folder Overview 12 | 13 | - **0_quick_start**: 14 | - `0_logging_device_placement.py`: Logs device placement to help identify performance bottlenecks. 15 | 16 | - **1_keras_api**: 17 | - `1_numbers_classification.ipynb`: Notebook demonstrating number classification using Keras. 18 | - `2_sequential_model.py`: Basic Sequential model example using Keras. 
19 | - `3_basic_classification.ipynb`: Notebook for basic classification using Keras. 20 | - `4_text_classification.ipynb`: Notebook for text classification using Keras. 21 | 22 | - **2_estimators**: 23 | - `2_1_linear_model.ipynb`: Notebook demonstrating a linear model implementation using TensorFlow Estimators. 24 | 25 | - **19_lingvo**: 26 | - `19_1_task_config.py`: Task configuration example using the Lingvo framework. 27 | 28 | - **20_tf2**: 29 | - `20_1_actor_critic_agent.ipynb`: Notebook demonstrating an Actor-Critic agent. 30 | - `20_2_a2c.py`: Advantage Actor-Critic (A2C) implementation. 31 | 32 | - **777_workarounds**: 33 | - `777_1_tf2_cuda10.py`: Workaround for TensorFlow 2.x with CUDA 10 compatibility issues. 34 | 35 | ## Benefits 36 | 37 | - **Plug-and-Play**: Pre-written, reusable code that can be easily integrated into your projects. 38 | - **Wide Range of Use Cases**: From data preprocessing to advanced neural network models. 39 | - **Scalable and Efficient**: Optimized for both small-scale experiments and large-scale production workloads. 40 | - **Customizable**: Easily modify and extend the code to suit your specific needs. 41 | 42 | ## TODO List 43 | 44 | - [ ] Add more examples for TensorFlow 2.x. 45 | - [ ] Add the most used deep learning architectures with practical examples. 46 | - [ ] Expand the Lingvo framework examples. 47 | - [ ] Add PyTorch models and examples. 48 | - [ ] Add a benchmarking suite for model comparisons. 49 | 50 | ## Contributing 51 | 52 | I've been working on this repo in my free time, contributing on and off as time allowed. Here's how you can get involved: 53 | 54 | 1. Fork the repository. 55 | 2. Create a new branch (`git checkout -b feature-branch`). 56 | 3. Make your changes and commit them (`git commit -m 'Add new feature'`). 57 | 4. Push to your branch (`git push origin feature-branch`). 58 | 5. Create a new Pull Request. 59 | 60 | Feel free to reach out for questions, suggestions, or feedback! 
61 | 62 | -- Andrew 63 | -------------------------------------------------------------------------------- /1_keras_api/1_numbers_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "\n", 11 | "tf.__version__\n", 12 | "\n", 13 | "mnist = tf.keras.datasets.mnist # 28x28 images of hand-written digits 0-9\n", 14 | "\n", 15 | "(x_train, y_train), (x_test, y_test) = mnist.load_data\n", 16 | "\n", 17 | "x_train = tf.keras.utils.normalize(x_train, axis=1)\n", 18 | "x_test = tf.keras.utils.normalize(x_test, axis=1)\n", 19 | "\n", 20 | "model = tf.keras.models.Sequential()\n", 21 | "model.add(tf.keras.layers.Flatten())\n", 22 | "model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))\n", 23 | "model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))\n", 24 | "model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))\n", 25 | "\n", 26 | "model.compile(optimizer='adam',\n", 27 | " loss='sparse_categorical_crossentropy',\n", 28 | " metrics=['accuracy'])\n", 29 | "\n", 30 | "# 1. Loss is the degree of error, what you've got wrong\n", 31 | "# 2. Neural networks always try to minimize loss\n", 32 | "# 3. Use adam as the default go-to optimizer in most cases\n", 33 | "\n", 34 | "model.fit(x_train, y_train, epochs=3)\n", 35 | "\n", 36 | "# Validation loss and accuracy calculation\n", 37 | "\n", 38 | "val_loss, val_acc = model.evaluate(x_test, y_test)\n", 39 | "print(val_loss, val_acc)\n", 40 | "\n", 41 | "# 1. Expect loss to be slightly higher, and accuracy to be slightly lower\n", 42 | "# 2. 
If there's a huge delta then you've probably overfit the model" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import matplotlib.pyplot as plt\n", 52 | "\n", 53 | "plt.imshow(x_train[0], cmap = plt.cm.binary)\n", 54 | "plt.show()\n", 55 | "\n", 56 | "# print(x_train[0])" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "model.save('num_reader.model')" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "new_model = tf.keras.models.load_model('num_reader.model')" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "predictions = new_model.predict([x_test])\n", 84 | "print(predictions)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "import numpy as np\n", 94 | "\n", 95 | "print(np.argmax(predictions[0]))" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "plt.imshow(x_test[0], cmap = plt.cm.binary)\n", 105 | "plt.show()" 106 | ] 107 | } 108 | ], 109 | "metadata": { 110 | "kernelspec": { 111 | "display_name": "Python 3", 112 | "language": "python", 113 | "name": "python3" 114 | }, 115 | "language_info": { 116 | "codemirror_mode": { 117 | "name": "ipython", 118 | "version": 3 119 | }, 120 | "file_extension": ".py", 121 | "mimetype": "text/x-python", 122 | "name": "python", 123 | "nbconvert_exporter": "python", 124 | "pygments_lexer": "ipython3", 125 | "version": "3.7.1" 126 | } 127 | }, 128 | "nbformat": 4, 129 | "nbformat_minor": 2 130 | } 131 | 
-------------------------------------------------------------------------------- /20_tf2/20_2_a2c.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import logging 3 | import numpy as np 4 | import tensorflow as tf 5 | import matplotlib.pyplot as plt 6 | import tensorflow.keras.layers as kl 7 | import tensorflow.keras.losses as kls 8 | import tensorflow.keras.optimizers as ko 9 | 10 | 11 | class ProbabilityDistribution(tf.keras.Model): 12 | def call(self, logits): 13 | # sample a random categorical action from given logits 14 | return tf.squeeze(tf.random.categorical(logits, 1), axis=-1) 15 | 16 | 17 | class Model(tf.keras.Model): 18 | def __init__(self, num_actions): 19 | super().__init__('mlp_policy') 20 | # no tf.get_variable(), just simple Keras API 21 | self.hidden1 = kl.Dense(128, activation='relu') 22 | self.hidden2 = kl.Dense(128, activation='relu') 23 | self.value = kl.Dense(1, name='value') 24 | # logits are unnormalized log probabilities 25 | self.logits = kl.Dense(num_actions, name='policy_logits') 26 | self.dist = ProbabilityDistribution() 27 | 28 | def call(self, inputs): 29 | # inputs is a numpy array, convert to Tensor 30 | x = tf.convert_to_tensor(inputs) 31 | # separate hidden layers from the same input tensor 32 | hidden_logs = self.hidden1(x) 33 | hidden_vals = self.hidden2(x) 34 | return self.logits(hidden_logs), self.value(hidden_vals) 35 | 36 | def action_value(self, obs): 37 | # executes call() under the hood 38 | logits, value = self.predict(obs) 39 | action = self.dist.predict(logits) 40 | # a simpler option, will become clear later why we don't use it 41 | # action = tf.random.categorical(logits, 1) 42 | return np.squeeze(action, axis=-1), np.squeeze(value, axis=-1) 43 | 44 | 45 | class A2CAgent: 46 | def __init__(self, model): 47 | # hyperparameters for loss terms, gamma is the discount coefficient 48 | self.params = { 49 | 'gamma': 0.99, 50 | 'value': 0.5, 51 | 'entropy': 0.0001 52 | 
} 53 | self.model = model 54 | self.model.compile( 55 | optimizer=ko.RMSprop(lr=0.0007), 56 | # define separate losses for policy logits and value estimate 57 | loss=[self._logits_loss, self._value_loss] 58 | ) 59 | 60 | def train(self, env, batch_sz=32, updates=1000): 61 | # storage helpers for a single batch of data 62 | actions = np.empty((batch_sz,), dtype=np.int32) 63 | rewards, dones, values = np.empty((3, batch_sz)) 64 | observations = np.empty((batch_sz,) + env.observation_space.shape) 65 | # training loop: collect samples, send to optimizer, repeat updates times 66 | ep_rews = [0.0] 67 | next_obs = env.reset() 68 | for update in range(updates): 69 | for step in range(batch_sz): 70 | observations[step] = next_obs.copy() 71 | actions[step], values[step] = self.model.action_value(next_obs[None, :]) 72 | next_obs, rewards[step], dones[step], _ = env.step(actions[step]) 73 | 74 | ep_rews[-1] += rewards[step] 75 | if dones[step]: 76 | ep_rews.append(0.0) 77 | next_obs = env.reset() 78 | logging.info("Episode: %03d, Reward: %03d" % (len(ep_rews)-1, ep_rews[-2])) 79 | 80 | _, next_value = self.model.action_value(next_obs[None, :]) 81 | returns, advs = self._returns_advantages(rewards, dones, values, next_value) 82 | # a trick to input actions and advantages through same API 83 | acts_and_advs = np.concatenate([actions[:, None], advs[:, None]], axis=-1) 84 | # performs a full training step on the collected batch 85 | # note: no need to mess around with gradients, Keras API handles it 86 | losses = self.model.train_on_batch(observations, [acts_and_advs, returns]) 87 | logging.debug("[%d/%d] Losses: %s" % (update+1, updates, losses)) 88 | return ep_rews 89 | 90 | def test(self, env, render=False): 91 | obs, done, ep_reward = env.reset(), False, 0 92 | while not done: 93 | action, _ = self.model.action_value(obs[None, :]) 94 | obs, reward, done, _ = env.step(action) 95 | ep_reward += reward 96 | if render: 97 | env.render() 98 | return ep_reward 99 | 100 | def 
_returns_advantages(self, rewards, dones, values, next_value): 101 | # next_value is the bootstrap value estimate of a future state (the critic) 102 | returns = np.append(np.zeros_like(rewards), next_value, axis=-1) 103 | # returns are calculated as discounted sum of future rewards 104 | for t in reversed(range(rewards.shape[0])): 105 | returns[t] = rewards[t] + self.params['gamma'] * returns[t+1] * (1-dones[t]) 106 | returns = returns[:-1] 107 | # advantages are returns - baseline, value estimates in our case 108 | advantages = returns - values 109 | return returns, advantages 110 | 111 | def _value_loss(self, returns, value): 112 | # value loss is typically MSE between value estimates and returns 113 | return self.params['value']*kls.mean_squared_error(returns, value) 114 | 115 | def _logits_loss(self, acts_and_advs, logits): 116 | # a trick to input actions and advantages through same API 117 | actions, advantages = tf.split(acts_and_advs, 2, axis=-1) 118 | # polymorphic CE loss function that supports sparse and weighted options 119 | # from_logits argument ensures transformation into normalized probabilities 120 | cross_entropy = kls.CategoricalCrossentropy(from_logits=True) 121 | # policy loss is defined by policy gradients, weighted by advantages 122 | # note: we only calculate the loss on the actions we've actually taken 123 | # thus under the hood a sparse version of CE loss will be executed 124 | actions = tf.cast(actions, tf.int32) 125 | policy_loss = cross_entropy(actions, logits, sample_weight=advantages) 126 | # entropy loss can be calculated via CE over itself 127 | entropy_loss = cross_entropy(logits, logits) 128 | # here signs are flipped because optimizer minimizes 129 | return policy_loss - self.params['entropy']*entropy_loss 130 | 131 | 132 | if __name__ == '__main__': 133 | logging.basicConfig(level=logging.INFO) 134 | 135 | env = gym.make('CartPole-v0') 136 | model = Model(num_actions=env.action_space.n) 137 | agent = A2CAgent(model) 138 | 139 
| rewards_history = agent.train(env) 140 | print("Finished training.") 141 | print("Total Episode Reward: %d out of 200" % agent.test(env, True)) 142 | 143 | plt.style.use('seaborn') 144 | plt.plot(np.arange(0, len(rewards_history), 25), rewards_history[::25]) 145 | plt.xlabel('Episode') 146 | plt.ylabel('Total Reward') 147 | plt.show() 148 | -------------------------------------------------------------------------------- /1_keras_api/4_text_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "1.12.0\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import tensorflow as tf\n", 18 | "from tensorflow import keras\n", 19 | "\n", 20 | "import numpy as np\n", 21 | "\n", 22 | "print(tf.__version__)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "imdb = keras.datasets.imdb\n", 32 | "\n", 33 | "(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "Training entries: 25000, labels: 25000\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "print(\"Training entries: {}, labels: {}\".format(len(train_data), len(train_labels)))" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 4, 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "[1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 
13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32]\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "print(train_data[0])" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 5, 73 | "metadata": {}, 74 | "outputs": [ 75 | { 76 | "data": { 77 | "text/plain": [ 78 | "(218, 189)" 79 | ] 80 | }, 81 | "execution_count": 5, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | } 85 | ], 86 | "source": [ 87 | "len(train_data[0]), len(train_data[1])" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 6, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "word_index = imdb.get_word_index()\n", 97 | "\n", 98 | "word_index = {k:(v+3) for k,v in word_index.items()} \n", 99 | "word_index[\"\"] = 0\n", 100 | "word_index[\"\"] = 1\n", 101 | "word_index[\"\"] = 2 \n", 102 | "word_index[\"\"] = 3\n", 103 | "\n", 104 | "word_index = imdb.get_word_index()\n", 105 | "\n", 106 | "word_index = {k:(v+3) for k,v in word_index.items()} \n", 107 | "word_index[\"\"] = 0\n", 108 | "word_index[\"\"] = 1\n", 109 | "word_index[\"\"] = 2 \n", 110 | "word_index[\"\"] = 3\n", 111 | "\n", 112 | "reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])\n", 113 | "\n", 114 | "def decode_review(text):\n", 115 | " return ' '.join([reverse_word_index.get(i, 
'?') for i in text])" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 8, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "\" this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert is an amazing actor and now the same being director father came from the same scottish island as myself so i loved the fact there was a real connection with this film the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also to the two little boy's that played the of norman and paul they were just brilliant children are often left out of the list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all\"" 127 | ] 128 | }, 129 | "execution_count": 8, 130 | "metadata": {}, 131 | "output_type": "execute_result" 132 | } 133 | ], 134 | "source": [ 135 | "decode_review(train_data[0])" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 10, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "train_data = keras.preprocessing.sequence.pad_sequences(train_data,\n", 145 | " value=word_index[\"\"],\n", 146 | " padding='post',\n", 147 | " maxlen=256)\n", 148 | "\n", 149 | "test_data = keras.preprocessing.sequence.pad_sequences(test_data,\n", 150 | " value=word_index[\"\"],\n", 151 | " padding='post',\n", 152 | " maxlen=256)" 153 | ] 154 | }, 155 | { 156 | 
"cell_type": "code", 157 | "execution_count": 12, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "data": { 162 | "text/plain": [ 163 | "(256, 256)" 164 | ] 165 | }, 166 | "execution_count": 12, 167 | "metadata": {}, 168 | "output_type": "execute_result" 169 | } 170 | ], 171 | "source": [ 172 | "len(train_data[0]), len(train_data[1])" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 13, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "name": "stdout", 182 | "output_type": "stream", 183 | "text": [ 184 | "[ 1 14 22 16 43 530 973 1622 1385 65 458 4468 66 3941\n", 185 | " 4 173 36 256 5 25 100 43 838 112 50 670 2 9\n", 186 | " 35 480 284 5 150 4 172 112 167 2 336 385 39 4\n", 187 | " 172 4536 1111 17 546 38 13 447 4 192 50 16 6 147\n", 188 | " 2025 19 14 22 4 1920 4613 469 4 22 71 87 12 16\n", 189 | " 43 530 38 76 15 13 1247 4 22 17 515 17 12 16\n", 190 | " 626 18 2 5 62 386 12 8 316 8 106 5 4 2223\n", 191 | " 5244 16 480 66 3785 33 4 130 12 16 38 619 5 25\n", 192 | " 124 51 36 135 48 25 1415 33 6 22 12 215 28 77\n", 193 | " 52 5 14 407 16 82 2 8 4 107 117 5952 15 256\n", 194 | " 4 2 7 3766 5 723 36 71 43 530 476 26 400 317\n", 195 | " 46 7 4 2 1029 13 104 88 4 381 15 297 98 32\n", 196 | " 2071 56 26 141 6 194 7486 18 4 226 22 21 134 476\n", 197 | " 26 480 5 144 30 5535 18 51 36 28 224 92 25 104\n", 198 | " 4 226 65 16 38 1334 88 12 16 283 5 16 4472 113\n", 199 | " 103 32 15 16 5345 19 178 32 0 0 0 0 0 0\n", 200 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 201 | " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", 202 | " 0 0 0 0]\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "print(train_data[0])" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 14, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "_________________________________________________________________\n", 220 | "Layer (type) Output Shape Param # \n", 221 | 
"=================================================================\n", 222 | "embedding (Embedding) (None, None, 16) 160000 \n", 223 | "_________________________________________________________________\n", 224 | "global_average_pooling1d (Gl (None, 16) 0 \n", 225 | "_________________________________________________________________\n", 226 | "dense (Dense) (None, 16) 272 \n", 227 | "_________________________________________________________________\n", 228 | "dense_1 (Dense) (None, 1) 17 \n", 229 | "=================================================================\n", 230 | "Total params: 160,289\n", 231 | "Trainable params: 160,289\n", 232 | "Non-trainable params: 0\n", 233 | "_________________________________________________________________\n" 234 | ] 235 | } 236 | ], 237 | "source": [ 238 | "vocab_size = 10000\n", 239 | "\n", 240 | "model = keras.Sequential()\n", 241 | "model.add(keras.layers.Embedding(vocab_size, 16))\n", 242 | "model.add(keras.layers.GlobalAveragePooling1D())\n", 243 | "model.add(keras.layers.Dense(16, activation=tf.nn.relu))\n", 244 | "model.add(keras.layers.Dense(1, activation=tf.nn.sigmoid))\n", 245 | "\n", 246 | "model.summary()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 15, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "model.compile(optimizer=tf.train.AdamOptimizer(),\n", 256 | " loss='binary_crossentropy',\n", 257 | " metrics=['accuracy'])" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 16, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "x_val = train_data[:10000]\n", 267 | "partial_x_train = train_data[10000:]\n", 268 | "\n", 269 | "y_val = train_labels[:10000]\n", 270 | "partial_y_train = train_labels[10000:]" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 17, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "name": "stdout", 280 | "output_type": "stream", 281 | "text": [ 282 | "Train on 15000 
samples, validate on 10000 samples\n", 283 | "Epoch 1/40\n", 284 | "15000/15000 [==============================] - 4s 273us/step - loss: 0.6921 - acc: 0.6027 - val_loss: 0.6901 - val_acc: 0.7397\n", 285 | "Epoch 2/40\n", 286 | "15000/15000 [==============================] - 1s 66us/step - loss: 0.6865 - acc: 0.7416 - val_loss: 0.6826 - val_acc: 0.7406\n", 287 | "Epoch 3/40\n", 288 | "15000/15000 [==============================] - 1s 54us/step - loss: 0.6742 - acc: 0.7606 - val_loss: 0.6663 - val_acc: 0.7586\n", 289 | "Epoch 4/40\n", 290 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.6505 - acc: 0.7718 - val_loss: 0.6397 - val_acc: 0.7692\n", 291 | "Epoch 5/40\n", 292 | "15000/15000 [==============================] - 1s 55us/step - loss: 0.6154 - acc: 0.7965 - val_loss: 0.6021 - val_acc: 0.7884\n", 293 | "Epoch 6/40\n", 294 | "15000/15000 [==============================] - 1s 53us/step - loss: 0.5711 - acc: 0.8144 - val_loss: 0.5594 - val_acc: 0.8023\n", 295 | "Epoch 7/40\n", 296 | "15000/15000 [==============================] - 1s 53us/step - loss: 0.5221 - acc: 0.8331 - val_loss: 0.5147 - val_acc: 0.8219\n", 297 | "Epoch 8/40\n", 298 | "15000/15000 [==============================] - 1s 54us/step - loss: 0.4736 - acc: 0.8485 - val_loss: 0.4727 - val_acc: 0.8355\n", 299 | "Epoch 9/40\n", 300 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.4296 - acc: 0.8613 - val_loss: 0.4359 - val_acc: 0.8450\n", 301 | "Epoch 10/40\n", 302 | "15000/15000 [==============================] - 1s 52us/step - loss: 0.3902 - acc: 0.8759 - val_loss: 0.4050 - val_acc: 0.8526\n", 303 | "Epoch 11/40\n", 304 | "15000/15000 [==============================] - 1s 49us/step - loss: 0.3573 - acc: 0.8833 - val_loss: 0.3821 - val_acc: 0.8568\n", 305 | "Epoch 12/40\n", 306 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.3301 - acc: 0.8903 - val_loss: 0.3600 - val_acc: 0.8655\n", 307 | "Epoch 13/40\n", 308 | "15000/15000 
[==============================] - 1s 51us/step - loss: 0.3058 - acc: 0.8975 - val_loss: 0.3449 - val_acc: 0.8690\n", 309 | "Epoch 14/40\n", 310 | "15000/15000 [==============================] - 1s 50us/step - loss: 0.2859 - acc: 0.9031 - val_loss: 0.3316 - val_acc: 0.8731\n", 311 | "Epoch 15/40\n", 312 | "15000/15000 [==============================] - 1s 50us/step - loss: 0.2687 - acc: 0.9076 - val_loss: 0.3215 - val_acc: 0.8757\n", 313 | "Epoch 16/40\n", 314 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.2540 - acc: 0.9116 - val_loss: 0.3133 - val_acc: 0.8786\n", 315 | "Epoch 17/40\n", 316 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.2399 - acc: 0.9171 - val_loss: 0.3066 - val_acc: 0.8787\n", 317 | "Epoch 18/40\n", 318 | "15000/15000 [==============================] - 1s 50us/step - loss: 0.2277 - acc: 0.9219 - val_loss: 0.3011 - val_acc: 0.8809\n", 319 | "Epoch 19/40\n", 320 | "15000/15000 [==============================] - 1s 50us/step - loss: 0.2163 - acc: 0.9258 - val_loss: 0.2968 - val_acc: 0.8823\n", 321 | "Epoch 20/40\n", 322 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.2063 - acc: 0.9291 - val_loss: 0.2932 - val_acc: 0.8826\n", 323 | "Epoch 21/40\n", 324 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.1965 - acc: 0.9328 - val_loss: 0.2904 - val_acc: 0.8836\n", 325 | "Epoch 22/40\n", 326 | "15000/15000 [==============================] - 1s 49us/step - loss: 0.1876 - acc: 0.9367 - val_loss: 0.2886 - val_acc: 0.8841\n", 327 | "Epoch 23/40\n", 328 | "15000/15000 [==============================] - 1s 55us/step - loss: 0.1795 - acc: 0.9405 - val_loss: 0.2874 - val_acc: 0.8845\n", 329 | "Epoch 24/40\n", 330 | "15000/15000 [==============================] - 1s 53us/step - loss: 0.1713 - acc: 0.9445 - val_loss: 0.2858 - val_acc: 0.8842\n", 331 | "Epoch 25/40\n", 332 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.1642 - acc: 0.9481 - 
val_loss: 0.2852 - val_acc: 0.8852\n", 333 | "Epoch 26/40\n", 334 | "15000/15000 [==============================] - 1s 50us/step - loss: 0.1571 - acc: 0.9499 - val_loss: 0.2854 - val_acc: 0.8854\n", 335 | "Epoch 27/40\n", 336 | "15000/15000 [==============================] - 1s 57us/step - loss: 0.1511 - acc: 0.9529 - val_loss: 0.2860 - val_acc: 0.8851\n", 337 | "Epoch 28/40\n", 338 | "15000/15000 [==============================] - 1s 76us/step - loss: 0.1448 - acc: 0.9558 - val_loss: 0.2859 - val_acc: 0.8869\n", 339 | "Epoch 29/40\n", 340 | "15000/15000 [==============================] - 1s 61us/step - loss: 0.1390 - acc: 0.9567 - val_loss: 0.2865 - val_acc: 0.8862\n", 341 | "Epoch 30/40\n", 342 | "15000/15000 [==============================] - 1s 53us/step - loss: 0.1340 - acc: 0.9599 - val_loss: 0.2880 - val_acc: 0.8867\n", 343 | "Epoch 31/40\n", 344 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.1279 - acc: 0.9619 - val_loss: 0.2896 - val_acc: 0.8863\n", 345 | "Epoch 32/40\n", 346 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.1233 - acc: 0.9641 - val_loss: 0.2915 - val_acc: 0.8859\n", 347 | "Epoch 33/40\n", 348 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.1179 - acc: 0.9662 - val_loss: 0.2936 - val_acc: 0.8852\n", 349 | "Epoch 34/40\n", 350 | "15000/15000 [==============================] - 1s 50us/step - loss: 0.1135 - acc: 0.9678 - val_loss: 0.2961 - val_acc: 0.8851\n", 351 | "Epoch 35/40\n", 352 | "15000/15000 [==============================] - 1s 57us/step - loss: 0.1095 - acc: 0.9693 - val_loss: 0.2980 - val_acc: 0.8854\n", 353 | "Epoch 36/40\n", 354 | "15000/15000 [==============================] - 1s 56us/step - loss: 0.1046 - acc: 0.9710 - val_loss: 0.3011 - val_acc: 0.8846\n", 355 | "Epoch 37/40\n", 356 | "15000/15000 [==============================] - 1s 51us/step - loss: 0.1007 - acc: 0.9727 - val_loss: 0.3042 - val_acc: 0.8840\n", 357 | "Epoch 38/40\n", 358 | 
"15000/15000 [==============================] - 1s 50us/step - loss: 0.0973 - acc: 0.9730 - val_loss: 0.3073 - val_acc: 0.8833\n", 359 | "Epoch 39/40\n", 360 | "15000/15000 [==============================] - 1s 50us/step - loss: 0.0930 - acc: 0.9751 - val_loss: 0.3099 - val_acc: 0.8834\n", 361 | "Epoch 40/40\n", 362 | "15000/15000 [==============================] - 1s 50us/step - loss: 0.0893 - acc: 0.9775 - val_loss: 0.3134 - val_acc: 0.8825\n" 363 | ] 364 | } 365 | ], 366 | "source": [ 367 | "history = model.fit(partial_x_train,\n", 368 | " partial_y_train,\n", 369 | " epochs=40,\n", 370 | " batch_size=512,\n", 371 | " validation_data=(x_val, y_val),\n", 372 | " verbose=1)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": 18, 378 | "metadata": {}, 379 | "outputs": [ 380 | { 381 | "name": "stdout", 382 | "output_type": "stream", 383 | "text": [ 384 | "25000/25000 [==============================] - 1s 54us/step\n", 385 | "[0.3345097375965118, 0.87236]\n" 386 | ] 387 | } 388 | ], 389 | "source": [ 390 | "results = model.evaluate(test_data, test_labels)\n", 391 | "\n", 392 | "print(results)" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": 19, 398 | "metadata": {}, 399 | "outputs": [ 400 | { 401 | "data": { 402 | "text/plain": [ 403 | "dict_keys(['val_loss', 'val_acc', 'loss', 'acc'])" 404 | ] 405 | }, 406 | "execution_count": 19, 407 | "metadata": {}, 408 | "output_type": "execute_result" 409 | } 410 | ], 411 | "source": [ 412 | "history_dict = history.history\n", 413 | "history_dict.keys()" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 22, 419 | "metadata": {}, 420 | "outputs": [ 421 | { 422 | "data": { 423 | "image/png": "\n", 424 | "text/plain": [ 425 | "
" 426 | ] 427 | }, 428 | "metadata": { 429 | "needs_background": "light" 430 | }, 431 | "output_type": "display_data" 432 | } 433 | ], 434 | "source": [ 435 | "import matplotlib.pyplot as plt\n", 436 | "\n", 437 | "acc = history.history['acc']\n", 438 | "val_acc = history.history['val_acc']\n", 439 | "loss = history.history['loss']\n", 440 | "val_loss = history.history['val_loss']\n", 441 | "\n", 442 | "epochs = range(1, len(acc) + 1)\n", 443 | "\n", 444 | "plt.plot(epochs, loss, 'bo', label='Потери обучения')\n", 445 | "plt.plot(epochs, val_loss, 'b', label='Потери проверки')\n", 446 | "plt.title('Потери во время обучения и проверки')\n", 447 | "plt.xlabel('Эпохи')\n", 448 | "plt.ylabel('Потери')\n", 449 | "plt.legend()\n", 450 | "\n", 451 | "plt.show()" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 23, 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "data": { 461 | "image/png": "\n", 462 | "text/plain": [ 463 | "
" 464 | ] 465 | }, 466 | "metadata": { 467 | "needs_background": "light" 468 | }, 469 | "output_type": "display_data" 470 | } 471 | ], 472 | "source": [ 473 | "plt.clf() # Clear the figure\n", 474 | "acc_values = history_dict['acc']\n", 475 | "val_acc_values = history_dict['val_acc']\n", 476 | "\n", 477 | "plt.plot(epochs, acc_values, 'bo', label='Точность обучения')\n", 478 | "plt.plot(epochs, val_acc_values, 'b', label='Точность проверки')\n", 479 | "plt.title('Точность во время обучения и проверки')\n", 480 | "plt.xlabel('Эпохи')\n", 481 | "plt.ylabel('Точность')\n", 482 | "plt.legend()\n", 483 | "\n", 484 | "plt.show()" 485 | ] 486 | } 487 | ], 488 | "metadata": { 489 | "kernelspec": { 490 | "display_name": "Python 3", 491 | "language": "python", 492 | "name": "python3" 493 | }, 494 | "language_info": { 495 | "codemirror_mode": { 496 | "name": "ipython", 497 | "version": 3 498 | }, 499 | "file_extension": ".py", 500 | "mimetype": "text/x-python", 501 | "name": "python", 502 | "nbconvert_exporter": "python", 503 | "pygments_lexer": "ipython3", 504 | "version": "3.6.7" 505 | } 506 | }, 507 | "nbformat": 4, 508 | "nbformat_minor": 2 509 | } 510 | -------------------------------------------------------------------------------- /20_tf2/20_1_actor_critic_agent.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Setup" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import gym\n", 17 | "import logging\n", 18 | "import numpy as np\n", 19 | "import tensorflow as tf\n", 20 | "import tensorflow.keras.layers as kl\n", 21 | "import tensorflow.keras.losses as kls\n", 22 | "import tensorflow.keras.optimizers as ko" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import
matplotlib\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "%matplotlib inline " 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "TensorFlow Ver: 1.13.0-dev20190117\n", 46 | "Eager Execution: True\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "print(\"TensorFlow Ver: \", tf.__version__)\n", 52 | "print(\"Eager Execution:\", tf.executing_eagerly())" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "1 + 2 + 3 + 4 + 5 = tf.Tensor(15, shape=(), dtype=int32)\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "# eager by default!\n", 70 | "print(\"1 + 2 + 3 + 4 + 5 =\", tf.reduce_sum([1, 2, 3, 4, 5]))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "# Advantage Actor-Critic with TensorFlow 2.0" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## Policy & Value Model Class" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "class ProbabilityDistribution(tf.keras.Model):\n", 94 | " def call(self, logits):\n", 95 | " # sample a random categorical action from given logits\n", 96 | " return tf.squeeze(tf.random.categorical(logits, 1), axis=-1)\n", 97 | "\n", 98 | "class Model(tf.keras.Model):\n", 99 | " def __init__(self, num_actions):\n", 100 | " super().__init__('mlp_policy')\n", 101 | " # no tf.get_variable(), just simple Keras API\n", 102 | " self.hidden1 = kl.Dense(128, activation='relu')\n", 103 | " self.hidden2 = kl.Dense(128, activation='relu')\n", 104 | " self.value = kl.Dense(1, name='value')\n", 105 | " # logits are unnormalized log probabilities\n", 106 | " self.logits = kl.Dense(num_actions, 
name='policy_logits')\n", 107 | " self.dist = ProbabilityDistribution()\n", 108 | "\n", 109 | " def call(self, inputs):\n", 110 | " # inputs is a numpy array, convert to Tensor\n", 111 | " x = tf.convert_to_tensor(inputs)\n", 112 | " # separate hidden layers from the same input tensor\n", 113 | " hidden_logs = self.hidden1(x)\n", 114 | " hidden_vals = self.hidden2(x)\n", 115 | " return self.logits(hidden_logs), self.value(hidden_vals)\n", 116 | "\n", 117 | " def action_value(self, obs):\n", 118 | " # executes call() under the hood\n", 119 | " logits, value = self.predict(obs)\n", 120 | " action = self.dist.predict(logits)\n", 121 | " # a simpler option, will become clear later why we don't use it\n", 122 | " # action = tf.random.categorical(logits, 1)\n", 123 | " return np.squeeze(action, axis=-1), np.squeeze(value, axis=-1)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "## Advantage Actor-Critic Agent Class" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 6, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "class A2CAgent:\n", 140 | " def __init__(self, model):\n", 141 | " # hyperparameters for loss terms, gamma is the discount coefficient\n", 142 | " self.params = {\n", 143 | " 'gamma': 0.99,\n", 144 | " 'value': 0.5,\n", 145 | " 'entropy': 0.0001\n", 146 | " }\n", 147 | " self.model = model\n", 148 | " self.model.compile(\n", 149 | " optimizer=ko.RMSprop(lr=0.0007),\n", 150 | " # define separate losses for policy logits and value estimate\n", 151 | " loss=[self._logits_loss, self._value_loss]\n", 152 | " )\n", 153 | " \n", 154 | " def train(self, env, batch_sz=32, updates=1000):\n", 155 | " # storage helpers for a single batch of data\n", 156 | " actions = np.empty((batch_sz,), dtype=np.int32)\n", 157 | " rewards, dones, values = np.empty((3, batch_sz))\n", 158 | " observations = np.empty((batch_sz,) + env.observation_space.shape)\n", 159 | " # training loop: 
collect samples, send to optimizer, repeat updates times\n", 160 | " ep_rews = [0.0]\n", 161 | " next_obs = env.reset()\n", 162 | " for update in range(updates):\n", 163 | " for step in range(batch_sz):\n", 164 | " observations[step] = next_obs.copy()\n", 165 | " actions[step], values[step] = self.model.action_value(next_obs[None, :])\n", 166 | " next_obs, rewards[step], dones[step], _ = env.step(actions[step])\n", 167 | "\n", 168 | " ep_rews[-1] += rewards[step]\n", 169 | " if dones[step]:\n", 170 | " ep_rews.append(0.0)\n", 171 | " next_obs = env.reset()\n", 172 | " logging.info(\"Episode: %03d, Reward: %03d\" % (len(ep_rews)-1, ep_rews[-2]))\n", 173 | "\n", 174 | " _, next_value = self.model.action_value(next_obs[None, :])\n", 175 | " returns, advs = self._returns_advantages(rewards, dones, values, next_value)\n", 176 | " # a trick to input actions and advantages through same API\n", 177 | " acts_and_advs = np.concatenate([actions[:, None], advs[:, None]], axis=-1)\n", 178 | " # performs a full training step on the collected batch\n", 179 | " # note: no need to mess around with gradients, Keras API handles it\n", 180 | " losses = self.model.train_on_batch(observations, [acts_and_advs, returns])\n", 181 | " logging.debug(\"[%d/%d] Losses: %s\" % (update+1, updates, losses))\n", 182 | " return ep_rews\n", 183 | "\n", 184 | " def test(self, env, render=False):\n", 185 | " obs, done, ep_reward = env.reset(), False, 0\n", 186 | " while not done:\n", 187 | " action, _ = self.model.action_value(obs[None, :])\n", 188 | " obs, reward, done, _ = env.step(action)\n", 189 | " ep_reward += reward\n", 190 | " if render:\n", 191 | " env.render()\n", 192 | " return ep_reward\n", 193 | "\n", 194 | " def _returns_advantages(self, rewards, dones, values, next_value):\n", 195 | " # next_value is the bootstrap value estimate of a future state (the critic)\n", 196 | " returns = np.append(np.zeros_like(rewards), next_value, axis=-1)\n", 197 | " # returns are calculated as discounted 
sum of future rewards\n", 198 | " for t in reversed(range(rewards.shape[0])):\n", 199 | " returns[t] = rewards[t] + self.params['gamma'] * returns[t+1] * (1-dones[t])\n", 200 | " returns = returns[:-1]\n", 201 | " # advantages are returns - baseline, value estimates in our case\n", 202 | " advantages = returns - values\n", 203 | " return returns, advantages\n", 204 | " \n", 205 | " def _value_loss(self, returns, value):\n", 206 | " # value loss is typically MSE between value estimates and returns\n", 207 | " return self.params['value']*kls.mean_squared_error(returns, value)\n", 208 | "\n", 209 | " def _logits_loss(self, acts_and_advs, logits):\n", 210 | " # a trick to input actions and advantages through same API\n", 211 | " actions, advantages = tf.split(acts_and_advs, 2, axis=-1)\n", 212 | " # polymorphic CE loss function that supports sparse and weighted options\n", 213 | " # from_logits argument ensures transformation into normalized probabilities\n", 214 | " cross_entropy = kls.CategoricalCrossentropy(from_logits=True)\n", 215 | " # policy loss is defined by policy gradients, weighted by advantages\n", 216 | " # note: we only calculate the loss on the actions we've actually taken\n", 217 | " # thus under the hood a sparse version of CE loss will be executed\n", 218 | " actions = tf.cast(actions, tf.int32)\n", 219 | " policy_loss = cross_entropy(actions, logits, sample_weight=advantages)\n", 220 | " # entropy loss can be calculated via CE over itself\n", 221 | " entropy_loss = cross_entropy(logits, logits)\n", 222 | " # here signs are flipped because optimizer minimizes\n", 223 | " return policy_loss - self.params['entropy']*entropy_loss" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 7, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "WARNING:tensorflow:From 
/home/inoryy/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:655: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 236 | "Instructions for updating:\n", 237 | "Colocations handled automatically by placer.\n" 238 | ] 239 | }, 240 | { 241 | "data": { 242 | "text/plain": [ 243 | "(array(1), array([0.00197717], dtype=float32))" 244 | ] 245 | }, 246 | "execution_count": 7, 247 | "metadata": {}, 248 | "output_type": "execute_result" 249 | } 250 | ], 251 | "source": [ 252 | "env = gym.make('CartPole-v0')\n", 253 | "model = Model(num_actions=env.action_space.n)\n", 254 | "model.action_value(env.reset()[None, :])" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "# Training A2C Agent & Results" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 8, 267 | "metadata": {}, 268 | "outputs": [], 269 | "source": [ 270 | "env = gym.make('CartPole-v0')\n", 271 | "model = Model(num_actions=env.action_space.n)\n", 272 | "agent = A2CAgent(model)" 273 | ] 274 | }, 275 | { 276 | "cell_type": "markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "## Testing with Random Weights" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": 9, 285 | "metadata": {}, 286 | "outputs": [ 287 | { 288 | "name": "stdout", 289 | "output_type": "stream", 290 | "text": [ 291 | "Total Episode Reward: 12 out of 200\n" 292 | ] 293 | } 294 | ], 295 | "source": [ 296 | "rewards_sum = agent.test(env)\n", 297 | "print(\"Total Episode Reward: %d out of 200\" % rewards_sum)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 10, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "name": "stderr", 307 | "output_type": "stream", 308 | "text": [ 309 | "INFO:root:Episode: 001, Reward: 019\n", 310 | "INFO:root:Episode: 002, Reward: 023\n", 311 | "INFO:root:Episode: 003, Reward:
027\n", 312 | "INFO:root:Episode: 004, Reward: 016\n", 313 | "INFO:root:Episode: 005, Reward: 035\n", 314 | "INFO:root:Episode: 006, Reward: 021\n", 315 | "INFO:root:Episode: 007, Reward: 012\n", 316 | "INFO:root:Episode: 008, Reward: 023\n", 317 | "INFO:root:Episode: 009, Reward: 021\n", 318 | "INFO:root:Episode: 010, Reward: 026\n", 319 | "INFO:root:Episode: 011, Reward: 059\n", 320 | "INFO:root:Episode: 012, Reward: 021\n", 321 | "INFO:root:Episode: 013, Reward: 012\n", 322 | "INFO:root:Episode: 014, Reward: 018\n", 323 | "INFO:root:Episode: 015, Reward: 016\n", 324 | "INFO:root:Episode: 016, Reward: 027\n", 325 | "INFO:root:Episode: 017, Reward: 032\n", 326 | "INFO:root:Episode: 018, Reward: 013\n", 327 | "INFO:root:Episode: 019, Reward: 017\n", 328 | "INFO:root:Episode: 020, Reward: 041\n", 329 | "INFO:root:Episode: 021, Reward: 015\n", 330 | "INFO:root:Episode: 022, Reward: 015\n", 331 | "INFO:root:Episode: 023, Reward: 045\n", 332 | "INFO:root:Episode: 024, Reward: 014\n", 333 | "INFO:root:Episode: 025, Reward: 018\n", 334 | "INFO:root:Episode: 026, Reward: 037\n", 335 | "INFO:root:Episode: 027, Reward: 017\n", 336 | "INFO:root:Episode: 028, Reward: 025\n", 337 | "INFO:root:Episode: 029, Reward: 044\n", 338 | "INFO:root:Episode: 030, Reward: 010\n", 339 | "INFO:root:Episode: 031, Reward: 014\n", 340 | "INFO:root:Episode: 032, Reward: 013\n", 341 | "INFO:root:Episode: 033, Reward: 017\n", 342 | "INFO:root:Episode: 034, Reward: 022\n", 343 | "INFO:root:Episode: 035, Reward: 021\n", 344 | "INFO:root:Episode: 036, Reward: 039\n", 345 | "INFO:root:Episode: 037, Reward: 013\n", 346 | "INFO:root:Episode: 038, Reward: 041\n", 347 | "INFO:root:Episode: 039, Reward: 036\n", 348 | "INFO:root:Episode: 040, Reward: 020\n", 349 | "INFO:root:Episode: 041, Reward: 041\n", 350 | "INFO:root:Episode: 042, Reward: 020\n", 351 | "INFO:root:Episode: 043, Reward: 028\n", 352 | "INFO:root:Episode: 044, Reward: 023\n", 353 | "INFO:root:Episode: 045, Reward: 077\n", 354 | 
"INFO:root:Episode: 046, Reward: 010\n", 355 | "INFO:root:Episode: 047, Reward: 021\n", 356 | "INFO:root:Episode: 048, Reward: 012\n", 357 | "INFO:root:Episode: 049, Reward: 031\n", 358 | "INFO:root:Episode: 050, Reward: 049\n", 359 | "INFO:root:Episode: 051, Reward: 034\n", 360 | "INFO:root:Episode: 052, Reward: 016\n", 361 | "INFO:root:Episode: 053, Reward: 034\n", 362 | "INFO:root:Episode: 054, Reward: 027\n", 363 | "INFO:root:Episode: 055, Reward: 031\n", 364 | "INFO:root:Episode: 056, Reward: 015\n", 365 | "INFO:root:Episode: 057, Reward: 012\n", 366 | "INFO:root:Episode: 058, Reward: 024\n", 367 | "INFO:root:Episode: 059, Reward: 082\n", 368 | "INFO:root:Episode: 060, Reward: 038\n", 369 | "INFO:root:Episode: 061, Reward: 026\n", 370 | "INFO:root:Episode: 062, Reward: 012\n", 371 | "INFO:root:Episode: 063, Reward: 018\n", 372 | "INFO:root:Episode: 064, Reward: 011\n", 373 | "INFO:root:Episode: 065, Reward: 048\n", 374 | "INFO:root:Episode: 066, Reward: 044\n", 375 | "INFO:root:Episode: 067, Reward: 018\n", 376 | "INFO:root:Episode: 068, Reward: 016\n", 377 | "INFO:root:Episode: 069, Reward: 012\n", 378 | "INFO:root:Episode: 070, Reward: 023\n", 379 | "INFO:root:Episode: 071, Reward: 013\n", 380 | "INFO:root:Episode: 072, Reward: 021\n", 381 | "INFO:root:Episode: 073, Reward: 014\n", 382 | "INFO:root:Episode: 074, Reward: 032\n", 383 | "INFO:root:Episode: 075, Reward: 016\n", 384 | "INFO:root:Episode: 076, Reward: 033\n", 385 | "INFO:root:Episode: 077, Reward: 022\n", 386 | "INFO:root:Episode: 078, Reward: 019\n", 387 | "INFO:root:Episode: 079, Reward: 022\n", 388 | "INFO:root:Episode: 080, Reward: 082\n", 389 | "INFO:root:Episode: 081, Reward: 016\n", 390 | "INFO:root:Episode: 082, Reward: 017\n", 391 | "INFO:root:Episode: 083, Reward: 049\n", 392 | "INFO:root:Episode: 084, Reward: 020\n", 393 | "INFO:root:Episode: 085, Reward: 023\n", 394 | "INFO:root:Episode: 086, Reward: 032\n", 395 | "INFO:root:Episode: 087, Reward: 029\n", 396 | "INFO:root:Episode: 088, 
Reward: 030\n", 397 | "INFO:root:Episode: 089, Reward: 029\n", 398 | "INFO:root:Episode: 090, Reward: 030\n", 399 | "INFO:root:Episode: 091, Reward: 038\n", 400 | "INFO:root:Episode: 092, Reward: 070\n", 401 | "INFO:root:Episode: 093, Reward: 018\n", 402 | "INFO:root:Episode: 094, Reward: 051\n", 403 | "INFO:root:Episode: 095, Reward: 052\n", 404 | "INFO:root:Episode: 096, Reward: 058\n", 405 | "INFO:root:Episode: 097, Reward: 020\n", 406 | "INFO:root:Episode: 098, Reward: 043\n", 407 | "INFO:root:Episode: 099, Reward: 038\n", 408 | "INFO:root:Episode: 100, Reward: 023\n", 409 | "INFO:root:Episode: 101, Reward: 025\n", 410 | "INFO:root:Episode: 102, Reward: 038\n", 411 | "INFO:root:Episode: 103, Reward: 050\n", 412 | "INFO:root:Episode: 104, Reward: 034\n", 413 | "INFO:root:Episode: 105, Reward: 022\n", 414 | "INFO:root:Episode: 106, Reward: 020\n", 415 | "INFO:root:Episode: 107, Reward: 022\n", 416 | "INFO:root:Episode: 108, Reward: 033\n", 417 | "INFO:root:Episode: 109, Reward: 021\n", 418 | "INFO:root:Episode: 110, Reward: 038\n", 419 | "INFO:root:Episode: 111, Reward: 042\n", 420 | "INFO:root:Episode: 112, Reward: 014\n", 421 | "INFO:root:Episode: 113, Reward: 081\n", 422 | "INFO:root:Episode: 114, Reward: 029\n", 423 | "INFO:root:Episode: 115, Reward: 025\n", 424 | "INFO:root:Episode: 116, Reward: 029\n", 425 | "INFO:root:Episode: 117, Reward: 022\n", 426 | "INFO:root:Episode: 118, Reward: 109\n", 427 | "INFO:root:Episode: 119, Reward: 048\n", 428 | "INFO:root:Episode: 120, Reward: 022\n", 429 | "INFO:root:Episode: 121, Reward: 024\n", 430 | "INFO:root:Episode: 122, Reward: 029\n", 431 | "INFO:root:Episode: 123, Reward: 023\n", 432 | "INFO:root:Episode: 124, Reward: 042\n", 433 | "INFO:root:Episode: 125, Reward: 023\n", 434 | "INFO:root:Episode: 126, Reward: 013\n", 435 | "INFO:root:Episode: 127, Reward: 034\n", 436 | "INFO:root:Episode: 128, Reward: 033\n", 437 | "INFO:root:Episode: 129, Reward: 034\n", 438 | "INFO:root:Episode: 130, Reward: 063\n", 439 | 
"INFO:root:Episode: 131, Reward: 060\n", 440 | "INFO:root:Episode: 132, Reward: 018\n", 441 | "INFO:root:Episode: 133, Reward: 039\n", 442 | "INFO:root:Episode: 134, Reward: 015\n", 443 | "INFO:root:Episode: 135, Reward: 035\n", 444 | "INFO:root:Episode: 136, Reward: 132\n", 445 | "INFO:root:Episode: 137, Reward: 035\n", 446 | "INFO:root:Episode: 138, Reward: 033\n", 447 | "INFO:root:Episode: 139, Reward: 028\n", 448 | "INFO:root:Episode: 140, Reward: 015\n", 449 | "INFO:root:Episode: 141, Reward: 013\n", 450 | "INFO:root:Episode: 142, Reward: 101\n", 451 | "INFO:root:Episode: 143, Reward: 028\n", 452 | "INFO:root:Episode: 144, Reward: 066\n", 453 | "INFO:root:Episode: 145, Reward: 200\n", 454 | "INFO:root:Episode: 146, Reward: 059\n", 455 | "INFO:root:Episode: 147, Reward: 077\n", 456 | "INFO:root:Episode: 148, Reward: 021\n", 457 | "INFO:root:Episode: 149, Reward: 030\n", 458 | "INFO:root:Episode: 150, Reward: 053\n", 459 | "INFO:root:Episode: 151, Reward: 019\n", 460 | "INFO:root:Episode: 152, Reward: 035\n", 461 | "INFO:root:Episode: 153, Reward: 035\n", 462 | "INFO:root:Episode: 154, Reward: 069\n", 463 | "INFO:root:Episode: 155, Reward: 108\n", 464 | "INFO:root:Episode: 156, Reward: 079\n", 465 | "INFO:root:Episode: 157, Reward: 021\n", 466 | "INFO:root:Episode: 158, Reward: 026\n", 467 | "INFO:root:Episode: 159, Reward: 045\n", 468 | "INFO:root:Episode: 160, Reward: 025\n", 469 | "INFO:root:Episode: 161, Reward: 069\n", 470 | "INFO:root:Episode: 162, Reward: 016\n", 471 | "INFO:root:Episode: 163, Reward: 036\n", 472 | "INFO:root:Episode: 164, Reward: 063\n", 473 | "INFO:root:Episode: 165, Reward: 039\n", 474 | "INFO:root:Episode: 166, Reward: 075\n", 475 | "INFO:root:Episode: 167, Reward: 035\n", 476 | "INFO:root:Episode: 168, Reward: 059\n", 477 | "INFO:root:Episode: 169, Reward: 025\n", 478 | "INFO:root:Episode: 170, Reward: 069\n", 479 | "INFO:root:Episode: 171, Reward: 063\n", 480 | "INFO:root:Episode: 172, Reward: 024\n", 481 | "INFO:root:Episode: 173, 
Reward: 023\n", 482 | "INFO:root:Episode: 174, Reward: 082\n", 483 | "INFO:root:Episode: 175, Reward: 048\n", 484 | "INFO:root:Episode: 176, Reward: 049\n", 485 | "INFO:root:Episode: 177, Reward: 076\n", 486 | "INFO:root:Episode: 178, Reward: 024\n", 487 | "INFO:root:Episode: 179, Reward: 067\n", 488 | "INFO:root:Episode: 180, Reward: 045\n", 489 | "INFO:root:Episode: 181, Reward: 035\n", 490 | "INFO:root:Episode: 182, Reward: 044\n", 491 | "INFO:root:Episode: 183, Reward: 044\n", 492 | "INFO:root:Episode: 184, Reward: 026\n", 493 | "INFO:root:Episode: 185, Reward: 068\n", 494 | "INFO:root:Episode: 186, Reward: 020\n", 495 | "INFO:root:Episode: 187, Reward: 047\n", 496 | "INFO:root:Episode: 188, Reward: 028\n", 497 | "INFO:root:Episode: 189, Reward: 053\n", 498 | "INFO:root:Episode: 190, Reward: 089\n", 499 | "INFO:root:Episode: 191, Reward: 042\n", 500 | "INFO:root:Episode: 192, Reward: 023\n", 501 | "INFO:root:Episode: 193, Reward: 079\n", 502 | "INFO:root:Episode: 194, Reward: 051\n", 503 | "INFO:root:Episode: 195, Reward: 038\n", 504 | "INFO:root:Episode: 196, Reward: 116\n", 505 | "INFO:root:Episode: 197, Reward: 067\n", 506 | "INFO:root:Episode: 198, Reward: 082\n", 507 | "INFO:root:Episode: 199, Reward: 122\n", 508 | "INFO:root:Episode: 200, Reward: 113\n", 509 | "INFO:root:Episode: 201, Reward: 035\n", 510 | "INFO:root:Episode: 202, Reward: 061\n", 511 | "INFO:root:Episode: 203, Reward: 132\n", 512 | "INFO:root:Episode: 204, Reward: 033\n", 513 | "INFO:root:Episode: 205, Reward: 093\n", 514 | "INFO:root:Episode: 206, Reward: 125\n", 515 | "INFO:root:Episode: 207, Reward: 040\n", 516 | "INFO:root:Episode: 208, Reward: 044\n", 517 | "INFO:root:Episode: 209, Reward: 034\n", 518 | "INFO:root:Episode: 210, Reward: 059\n", 519 | "INFO:root:Episode: 211, Reward: 063\n", 520 | "INFO:root:Episode: 212, Reward: 116\n", 521 | "INFO:root:Episode: 213, Reward: 061\n", 522 | "INFO:root:Episode: 214, Reward: 086\n", 523 | "INFO:root:Episode: 215, Reward: 065\n", 524 | 
"INFO:root:Episode: 216, Reward: 031\n", 525 | "INFO:root:Episode: 217, Reward: 064\n", 526 | "INFO:root:Episode: 218, Reward: 153\n", 527 | "INFO:root:Episode: 219, Reward: 200\n", 528 | "INFO:root:Episode: 220, Reward: 088\n", 529 | "INFO:root:Episode: 221, Reward: 035\n", 530 | "INFO:root:Episode: 222, Reward: 113\n", 531 | "INFO:root:Episode: 223, Reward: 080\n", 532 | "INFO:root:Episode: 224, Reward: 048\n", 533 | "INFO:root:Episode: 225, Reward: 044\n", 534 | "INFO:root:Episode: 226, Reward: 061\n", 535 | "INFO:root:Episode: 227, Reward: 077\n", 536 | "INFO:root:Episode: 228, Reward: 025\n" 537 | ] 538 | }, 539 | { 540 | "name": "stderr", 541 | "output_type": "stream", 542 | "text": [ 543 | "INFO:root:Episode: 229, Reward: 026\n", 544 | "INFO:root:Episode: 230, Reward: 054\n", 545 | "INFO:root:Episode: 231, Reward: 120\n", 546 | "INFO:root:Episode: 232, Reward: 074\n", 547 | "INFO:root:Episode: 233, Reward: 122\n", 548 | "INFO:root:Episode: 234, Reward: 098\n", 549 | "INFO:root:Episode: 235, Reward: 034\n", 550 | "INFO:root:Episode: 236, Reward: 086\n", 551 | "INFO:root:Episode: 237, Reward: 126\n", 552 | "INFO:root:Episode: 238, Reward: 200\n", 553 | "INFO:root:Episode: 239, Reward: 175\n", 554 | "INFO:root:Episode: 240, Reward: 059\n", 555 | "INFO:root:Episode: 241, Reward: 045\n", 556 | "INFO:root:Episode: 242, Reward: 029\n", 557 | "INFO:root:Episode: 243, Reward: 027\n", 558 | "INFO:root:Episode: 244, Reward: 128\n", 559 | "INFO:root:Episode: 245, Reward: 104\n", 560 | "INFO:root:Episode: 246, Reward: 133\n", 561 | "INFO:root:Episode: 247, Reward: 101\n", 562 | "INFO:root:Episode: 248, Reward: 043\n", 563 | "INFO:root:Episode: 249, Reward: 053\n", 564 | "INFO:root:Episode: 250, Reward: 065\n", 565 | "INFO:root:Episode: 251, Reward: 072\n", 566 | "INFO:root:Episode: 252, Reward: 093\n", 567 | "INFO:root:Episode: 253, Reward: 200\n", 568 | "INFO:root:Episode: 254, Reward: 156\n", 569 | "INFO:root:Episode: 255, Reward: 053\n", 570 | "INFO:root:Episode: 256, 
Reward: 057\n", 571 | "INFO:root:Episode: 257, Reward: 121\n", 572 | "INFO:root:Episode: 258, Reward: 051\n", 573 | "INFO:root:Episode: 259, Reward: 095\n", 574 | "INFO:root:Episode: 260, Reward: 096\n", 575 | "INFO:root:Episode: 261, Reward: 053\n", 576 | "INFO:root:Episode: 262, Reward: 193\n", 577 | "INFO:root:Episode: 263, Reward: 083\n", 578 | "INFO:root:Episode: 264, Reward: 060\n", 579 | "INFO:root:Episode: 265, Reward: 100\n", 580 | "INFO:root:Episode: 266, Reward: 113\n", 581 | "INFO:root:Episode: 267, Reward: 120\n", 582 | "INFO:root:Episode: 268, Reward: 038\n", 583 | "INFO:root:Episode: 269, Reward: 084\n", 584 | "INFO:root:Episode: 270, Reward: 049\n", 585 | "INFO:root:Episode: 271, Reward: 066\n", 586 | "INFO:root:Episode: 272, Reward: 166\n", 587 | "INFO:root:Episode: 273, Reward: 144\n", 588 | "INFO:root:Episode: 274, Reward: 053\n", 589 | "INFO:root:Episode: 275, Reward: 057\n", 590 | "INFO:root:Episode: 276, Reward: 092\n", 591 | "INFO:root:Episode: 277, Reward: 122\n", 592 | "INFO:root:Episode: 278, Reward: 153\n", 593 | "INFO:root:Episode: 279, Reward: 131\n", 594 | "INFO:root:Episode: 280, Reward: 200\n", 595 | "INFO:root:Episode: 281, Reward: 074\n", 596 | "INFO:root:Episode: 282, Reward: 147\n", 597 | "INFO:root:Episode: 283, Reward: 079\n", 598 | "INFO:root:Episode: 284, Reward: 120\n", 599 | "INFO:root:Episode: 285, Reward: 136\n", 600 | "INFO:root:Episode: 286, Reward: 133\n", 601 | "INFO:root:Episode: 287, Reward: 133\n", 602 | "INFO:root:Episode: 288, Reward: 088\n", 603 | "INFO:root:Episode: 289, Reward: 057\n", 604 | "INFO:root:Episode: 290, Reward: 185\n", 605 | "INFO:root:Episode: 291, Reward: 087\n", 606 | "INFO:root:Episode: 292, Reward: 154\n", 607 | "INFO:root:Episode: 293, Reward: 200\n", 608 | "INFO:root:Episode: 294, Reward: 114\n", 609 | "INFO:root:Episode: 295, Reward: 118\n", 610 | "INFO:root:Episode: 296, Reward: 089\n", 611 | "INFO:root:Episode: 297, Reward: 069\n", 612 | "INFO:root:Episode: 298, Reward: 155\n", 613 | 
"INFO:root:Episode: 299, Reward: 109\n", 614 | "INFO:root:Episode: 300, Reward: 095\n", 615 | "INFO:root:Episode: 301, Reward: 200\n", 616 | "INFO:root:Episode: 302, Reward: 200\n", 617 | "INFO:root:Episode: 303, Reward: 139\n", 618 | "INFO:root:Episode: 304, Reward: 200\n", 619 | "INFO:root:Episode: 305, Reward: 099\n", 620 | "INFO:root:Episode: 306, Reward: 133\n", 621 | "INFO:root:Episode: 307, Reward: 152\n", 622 | "INFO:root:Episode: 308, Reward: 177\n", 623 | "INFO:root:Episode: 309, Reward: 140\n", 624 | "INFO:root:Episode: 310, Reward: 167\n", 625 | "INFO:root:Episode: 311, Reward: 134\n", 626 | "INFO:root:Episode: 312, Reward: 200\n", 627 | "INFO:root:Episode: 313, Reward: 200\n", 628 | "INFO:root:Episode: 314, Reward: 154\n", 629 | "INFO:root:Episode: 315, Reward: 200\n", 630 | "INFO:root:Episode: 316, Reward: 141\n", 631 | "INFO:root:Episode: 317, Reward: 200\n", 632 | "INFO:root:Episode: 318, Reward: 072\n", 633 | "INFO:root:Episode: 319, Reward: 128\n", 634 | "INFO:root:Episode: 320, Reward: 190\n", 635 | "INFO:root:Episode: 321, Reward: 200\n", 636 | "INFO:root:Episode: 322, Reward: 108\n", 637 | "INFO:root:Episode: 323, Reward: 038\n", 638 | "INFO:root:Episode: 324, Reward: 200\n", 639 | "INFO:root:Episode: 325, Reward: 102\n", 640 | "INFO:root:Episode: 326, Reward: 200\n", 641 | "INFO:root:Episode: 327, Reward: 200\n", 642 | "INFO:root:Episode: 328, Reward: 200\n", 643 | "INFO:root:Episode: 329, Reward: 151\n", 644 | "INFO:root:Episode: 330, Reward: 200\n", 645 | "INFO:root:Episode: 331, Reward: 129\n", 646 | "INFO:root:Episode: 332, Reward: 086\n", 647 | "INFO:root:Episode: 333, Reward: 174\n", 648 | "INFO:root:Episode: 334, Reward: 157\n", 649 | "INFO:root:Episode: 335, Reward: 200\n", 650 | "INFO:root:Episode: 336, Reward: 060\n", 651 | "INFO:root:Episode: 337, Reward: 200\n", 652 | "INFO:root:Episode: 338, Reward: 200\n", 653 | "INFO:root:Episode: 339, Reward: 036\n", 654 | "INFO:root:Episode: 340, Reward: 111\n", 655 | "INFO:root:Episode: 341, 
Reward: 200\n", 656 | "INFO:root:Episode: 342, Reward: 200\n", 657 | "INFO:root:Episode: 343, Reward: 200\n", 658 | "INFO:root:Episode: 344, Reward: 193\n", 659 | "INFO:root:Episode: 345, Reward: 200\n", 660 | "INFO:root:Episode: 346, Reward: 174\n", 661 | "INFO:root:Episode: 347, Reward: 200\n", 662 | "INFO:root:Episode: 348, Reward: 146\n", 663 | "INFO:root:Episode: 349, Reward: 150\n", 664 | "INFO:root:Episode: 350, Reward: 146\n", 665 | "INFO:root:Episode: 351, Reward: 148\n", 666 | "INFO:root:Episode: 352, Reward: 144\n", 667 | "INFO:root:Episode: 353, Reward: 162\n", 668 | "INFO:root:Episode: 354, Reward: 200\n", 669 | "INFO:root:Episode: 355, Reward: 200\n", 670 | "INFO:root:Episode: 356, Reward: 133\n", 671 | "INFO:root:Episode: 357, Reward: 152\n", 672 | "INFO:root:Episode: 358, Reward: 096\n", 673 | "INFO:root:Episode: 359, Reward: 069\n", 674 | "INFO:root:Episode: 360, Reward: 039\n", 675 | "INFO:root:Episode: 361, Reward: 115\n", 676 | "INFO:root:Episode: 362, Reward: 130\n", 677 | "INFO:root:Episode: 363, Reward: 077\n", 678 | "INFO:root:Episode: 364, Reward: 128\n", 679 | "INFO:root:Episode: 365, Reward: 098\n", 680 | "INFO:root:Episode: 366, Reward: 129\n", 681 | "INFO:root:Episode: 367, Reward: 033\n", 682 | "INFO:root:Episode: 368, Reward: 200\n", 683 | "INFO:root:Episode: 369, Reward: 140\n", 684 | "INFO:root:Episode: 370, Reward: 155\n", 685 | "INFO:root:Episode: 371, Reward: 130\n", 686 | "INFO:root:Episode: 372, Reward: 167\n", 687 | "INFO:root:Episode: 373, Reward: 170\n", 688 | "INFO:root:Episode: 374, Reward: 180\n", 689 | "INFO:root:Episode: 375, Reward: 147\n", 690 | "INFO:root:Episode: 376, Reward: 114\n", 691 | "INFO:root:Episode: 377, Reward: 054\n", 692 | "INFO:root:Episode: 378, Reward: 200\n", 693 | "INFO:root:Episode: 379, Reward: 072\n", 694 | "INFO:root:Episode: 380, Reward: 200\n", 695 | "INFO:root:Episode: 381, Reward: 200\n", 696 | "INFO:root:Episode: 382, Reward: 200\n", 697 | "INFO:root:Episode: 383, Reward: 182\n", 698 | 
"INFO:root:Episode: 384, Reward: 200\n", 699 | "INFO:root:Episode: 385, Reward: 200\n", 700 | "INFO:root:Episode: 386, Reward: 193\n", 701 | "INFO:root:Episode: 387, Reward: 200\n", 702 | "INFO:root:Episode: 388, Reward: 095\n", 703 | "INFO:root:Episode: 389, Reward: 200\n", 704 | "INFO:root:Episode: 390, Reward: 125\n", 705 | "INFO:root:Episode: 391, Reward: 158\n", 706 | "INFO:root:Episode: 392, Reward: 148\n", 707 | "INFO:root:Episode: 393, Reward: 083\n", 708 | "INFO:root:Episode: 394, Reward: 200\n", 709 | "INFO:root:Episode: 395, Reward: 200\n", 710 | "INFO:root:Episode: 396, Reward: 200\n", 711 | "INFO:root:Episode: 397, Reward: 156\n", 712 | "INFO:root:Episode: 398, Reward: 068\n" 713 | ] 714 | }, 715 | { 716 | "name": "stdout", 717 | "output_type": "stream", 718 | "text": [ 719 | "Finished training.\n" 720 | ] 721 | } 722 | ], 723 | "source": [ 724 | "# set to logging.WARNING to disable logs or logging.DEBUG to see losses as well\n", 725 | "logging.basicConfig(level=logging.INFO)\n", 726 | "\n", 727 | "rewards_history = agent.train(env)\n", 728 | "print(\"Finished training.\")" 729 | ] 730 | }, 731 | { 732 | "cell_type": "markdown", 733 | "metadata": {}, 734 | "source": [ 735 | "## Testing with Trained Model" 736 | ] 737 | }, 738 | { 739 | "cell_type": "code", 740 | "execution_count": 11, 741 | "metadata": {}, 742 | "outputs": [ 743 | { 744 | "name": "stdout", 745 | "output_type": "stream", 746 | "text": [ 747 | "Total Episode Reward: 200 out of 200\n" 748 | ] 749 | } 750 | ], 751 | "source": [ 752 | "print(\"Total Episode Reward: %d out of 200\" % agent.test(env))" 753 | ] 754 | }, 755 | { 756 | "cell_type": "markdown", 757 | "metadata": {}, 758 | "source": [ 759 | "## Training Rewards History" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": 12, 765 | "metadata": {}, 766 | "outputs": [ 767 | { 768 | "data": { 769 | "image/png": "\n", 770 | "text/plain": [ 771 | "
" 772 | ] 773 | }, 774 | "metadata": {}, 775 | "output_type": "display_data" 776 | } 777 | ], 778 | "source": [ 779 | "plt.style.use('seaborn')\n", 780 | "plt.plot(np.arange(0, len(rewards_history), 25), rewards_history[::25])\n", 781 | "plt.xlabel('Episode')\n", 782 | "plt.ylabel('Total Reward')\n", 783 | "plt.show()" 784 | ] 785 | }, 786 | { 787 | "cell_type": "markdown", 788 | "metadata": {}, 789 | "source": [ 790 | "# Static Computational Graph" 791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": 13, 796 | "metadata": {}, 797 | "outputs": [ 798 | { 799 | "name": "stdout", 800 | "output_type": "stream", 801 | "text": [ 802 | "Eager Execution: False\n", 803 | "WARNING:tensorflow:From /home/inoryy/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1253: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", 804 | "Instructions for updating:\n", 805 | "Call initializer instance with the dtype argument instead of passing it to the constructor\n" 806 | ] 807 | }, 808 | { 809 | "name": "stderr", 810 | "output_type": "stream", 811 | "text": [ 812 | "WARNING:tensorflow:From /home/inoryy/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1253: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", 813 | "Instructions for updating:\n", 814 | "Call initializer instance with the dtype argument instead of passing it to the constructor\n", 815 | "INFO:root:Episode: 001, Reward: 020\n" 816 | ] 817 | }, 818 | { 819 | "name": "stdout", 820 | "output_type": "stream", 821 | "text": [ 822 | "WARNING:tensorflow:From /home/inoryy/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_utils.py:123: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be 
removed in a future version.\n", 823 | "Instructions for updating:\n", 824 | "Call initializer instance with the dtype argument instead of passing it to the constructor\n" 825 | ] 826 | }, 827 | { 828 | "name": "stderr", 829 | "output_type": "stream", 830 | "text": [ 831 | "WARNING:tensorflow:From /home/inoryy/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_utils.py:123: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n", 832 | "Instructions for updating:\n", 833 | "Call initializer instance with the dtype argument instead of passing it to the constructor\n", 834 | "INFO:root:Episode: 002, Reward: 027\n", 835 | "INFO:root:Episode: 003, Reward: 032\n", 836 | "INFO:root:Episode: 004, Reward: 017\n", 837 | "INFO:root:Episode: 005, Reward: 012\n", 838 | "INFO:root:Episode: 006, Reward: 015\n", 839 | "INFO:root:Episode: 007, Reward: 025\n", 840 | "INFO:root:Episode: 008, Reward: 010\n", 841 | "INFO:root:Episode: 009, Reward: 026\n", 842 | "INFO:root:Episode: 010, Reward: 014\n", 843 | "INFO:root:Episode: 011, Reward: 054\n", 844 | "INFO:root:Episode: 012, Reward: 012\n", 845 | "INFO:root:Episode: 013, Reward: 038\n", 846 | "INFO:root:Episode: 014, Reward: 024\n", 847 | "INFO:root:Episode: 015, Reward: 022\n", 848 | "INFO:root:Episode: 016, Reward: 034\n", 849 | "INFO:root:Episode: 017, Reward: 072\n", 850 | "INFO:root:Episode: 018, Reward: 022\n", 851 | "INFO:root:Episode: 019, Reward: 029\n", 852 | "INFO:root:Episode: 020, Reward: 020\n", 853 | "INFO:root:Episode: 021, Reward: 017\n", 854 | "INFO:root:Episode: 022, Reward: 013\n", 855 | "INFO:root:Episode: 023, Reward: 067\n", 856 | "INFO:root:Episode: 024, Reward: 100\n", 857 | "INFO:root:Episode: 025, Reward: 056\n", 858 | "INFO:root:Episode: 026, Reward: 096\n", 859 | "INFO:root:Episode: 027, Reward: 019\n", 860 | "INFO:root:Episode: 028, Reward: 016\n", 861 | "INFO:root:Episode: 029, Reward: 
023\n", 862 | "INFO:root:Episode: 030, Reward: 013\n", 863 | "INFO:root:Episode: 031, Reward: 020\n", 864 | "INFO:root:Episode: 032, Reward: 023\n", 865 | "INFO:root:Episode: 033, Reward: 026\n", 866 | "INFO:root:Episode: 034, Reward: 070\n", 867 | "INFO:root:Episode: 035, Reward: 033\n", 868 | "INFO:root:Episode: 036, Reward: 028\n", 869 | "INFO:root:Episode: 037, Reward: 059\n", 870 | "INFO:root:Episode: 038, Reward: 047\n", 871 | "INFO:root:Episode: 039, Reward: 026\n", 872 | "INFO:root:Episode: 040, Reward: 028\n", 873 | "INFO:root:Episode: 041, Reward: 034\n", 874 | "INFO:root:Episode: 042, Reward: 065\n", 875 | "INFO:root:Episode: 043, Reward: 014\n", 876 | "INFO:root:Episode: 044, Reward: 028\n", 877 | "INFO:root:Episode: 045, Reward: 018\n", 878 | "INFO:root:Episode: 046, Reward: 011\n", 879 | "INFO:root:Episode: 047, Reward: 011\n", 880 | "INFO:root:Episode: 048, Reward: 020\n", 881 | "INFO:root:Episode: 049, Reward: 015\n", 882 | "INFO:root:Episode: 050, Reward: 031\n", 883 | "INFO:root:Episode: 051, Reward: 017\n", 884 | "INFO:root:Episode: 052, Reward: 025\n", 885 | "INFO:root:Episode: 053, Reward: 027\n", 886 | "INFO:root:Episode: 054, Reward: 026\n", 887 | "INFO:root:Episode: 055, Reward: 060\n", 888 | "INFO:root:Episode: 056, Reward: 020\n", 889 | "INFO:root:Episode: 057, Reward: 056\n", 890 | "INFO:root:Episode: 058, Reward: 051\n", 891 | "INFO:root:Episode: 059, Reward: 036\n", 892 | "INFO:root:Episode: 060, Reward: 022\n", 893 | "INFO:root:Episode: 061, Reward: 013\n", 894 | "INFO:root:Episode: 062, Reward: 026\n", 895 | "INFO:root:Episode: 063, Reward: 030\n", 896 | "INFO:root:Episode: 064, Reward: 019\n", 897 | "INFO:root:Episode: 065, Reward: 044\n", 898 | "INFO:root:Episode: 066, Reward: 078\n", 899 | "INFO:root:Episode: 067, Reward: 047\n", 900 | "INFO:root:Episode: 068, Reward: 019\n", 901 | "INFO:root:Episode: 069, Reward: 020\n", 902 | "INFO:root:Episode: 070, Reward: 066\n", 903 | "INFO:root:Episode: 071, Reward: 026\n", 904 | 
"INFO:root:Episode: 072, Reward: 037\n", 905 | "INFO:root:Episode: 073, Reward: 037\n", 906 | "INFO:root:Episode: 074, Reward: 023\n", 907 | "INFO:root:Episode: 075, Reward: 010\n", 908 | "INFO:root:Episode: 076, Reward: 039\n", 909 | "INFO:root:Episode: 077, Reward: 033\n", 910 | "INFO:root:Episode: 078, Reward: 063\n", 911 | "INFO:root:Episode: 079, Reward: 016\n", 912 | "INFO:root:Episode: 080, Reward: 053\n", 913 | "INFO:root:Episode: 081, Reward: 037\n", 914 | "INFO:root:Episode: 082, Reward: 035\n", 915 | "INFO:root:Episode: 083, Reward: 054\n", 916 | "INFO:root:Episode: 084, Reward: 014\n", 917 | "INFO:root:Episode: 085, Reward: 061\n", 918 | "INFO:root:Episode: 086, Reward: 012\n", 919 | "INFO:root:Episode: 087, Reward: 040\n", 920 | "INFO:root:Episode: 088, Reward: 059\n", 921 | "INFO:root:Episode: 089, Reward: 031\n", 922 | "INFO:root:Episode: 090, Reward: 114\n", 923 | "INFO:root:Episode: 091, Reward: 017\n", 924 | "INFO:root:Episode: 092, Reward: 023\n", 925 | "INFO:root:Episode: 093, Reward: 042\n", 926 | "INFO:root:Episode: 094, Reward: 025\n", 927 | "INFO:root:Episode: 095, Reward: 027\n", 928 | "INFO:root:Episode: 096, Reward: 013\n", 929 | "INFO:root:Episode: 097, Reward: 051\n", 930 | "INFO:root:Episode: 098, Reward: 048\n", 931 | "INFO:root:Episode: 099, Reward: 071\n", 932 | "INFO:root:Episode: 100, Reward: 034\n", 933 | "INFO:root:Episode: 101, Reward: 032\n", 934 | "INFO:root:Episode: 102, Reward: 045\n", 935 | "INFO:root:Episode: 103, Reward: 096\n", 936 | "INFO:root:Episode: 104, Reward: 030\n", 937 | "INFO:root:Episode: 105, Reward: 071\n", 938 | "INFO:root:Episode: 106, Reward: 048\n", 939 | "INFO:root:Episode: 107, Reward: 037\n", 940 | "INFO:root:Episode: 108, Reward: 027\n", 941 | "INFO:root:Episode: 109, Reward: 024\n", 942 | "INFO:root:Episode: 110, Reward: 036\n", 943 | "INFO:root:Episode: 111, Reward: 080\n", 944 | "INFO:root:Episode: 112, Reward: 037\n", 945 | "INFO:root:Episode: 113, Reward: 048\n", 946 | "INFO:root:Episode: 114, 
Reward: 024\n", 947 | "INFO:root:Episode: 115, Reward: 042\n", 948 | "INFO:root:Episode: 116, Reward: 057\n", 949 | "INFO:root:Episode: 117, Reward: 104\n", 950 | "INFO:root:Episode: 118, Reward: 017\n", 951 | "INFO:root:Episode: 119, Reward: 020\n", 952 | "INFO:root:Episode: 120, Reward: 029\n", 953 | "INFO:root:Episode: 121, Reward: 041\n", 954 | "INFO:root:Episode: 122, Reward: 070\n", 955 | "INFO:root:Episode: 123, Reward: 049\n", 956 | "INFO:root:Episode: 124, Reward: 029\n", 957 | "INFO:root:Episode: 125, Reward: 029\n", 958 | "INFO:root:Episode: 126, Reward: 030\n", 959 | "INFO:root:Episode: 127, Reward: 065\n", 960 | "INFO:root:Episode: 128, Reward: 024\n", 961 | "INFO:root:Episode: 129, Reward: 018\n", 962 | "INFO:root:Episode: 130, Reward: 062\n", 963 | "INFO:root:Episode: 131, Reward: 033\n", 964 | "INFO:root:Episode: 132, Reward: 020\n", 965 | "INFO:root:Episode: 133, Reward: 050\n", 966 | "INFO:root:Episode: 134, Reward: 029\n", 967 | "INFO:root:Episode: 135, Reward: 016\n", 968 | "INFO:root:Episode: 136, Reward: 056\n", 969 | "INFO:root:Episode: 137, Reward: 026\n", 970 | "INFO:root:Episode: 138, Reward: 025\n", 971 | "INFO:root:Episode: 139, Reward: 047\n", 972 | "INFO:root:Episode: 140, Reward: 038\n", 973 | "INFO:root:Episode: 141, Reward: 033\n", 974 | "INFO:root:Episode: 142, Reward: 017\n", 975 | "INFO:root:Episode: 143, Reward: 068\n", 976 | "INFO:root:Episode: 144, Reward: 023\n", 977 | "INFO:root:Episode: 145, Reward: 168\n", 978 | "INFO:root:Episode: 146, Reward: 046\n", 979 | "INFO:root:Episode: 147, Reward: 044\n", 980 | "INFO:root:Episode: 148, Reward: 022\n", 981 | "INFO:root:Episode: 149, Reward: 026\n", 982 | "INFO:root:Episode: 150, Reward: 037\n", 983 | "INFO:root:Episode: 151, Reward: 091\n", 984 | "INFO:root:Episode: 152, Reward: 025\n", 985 | "INFO:root:Episode: 153, Reward: 038\n", 986 | "INFO:root:Episode: 154, Reward: 039\n", 987 | "INFO:root:Episode: 155, Reward: 047\n", 988 | "INFO:root:Episode: 156, Reward: 025\n", 989 | 
"INFO:root:Episode: 157, Reward: 047\n", 990 | "INFO:root:Episode: 158, Reward: 013\n", 991 | "INFO:root:Episode: 159, Reward: 069\n", 992 | "INFO:root:Episode: 160, Reward: 019\n", 993 | "INFO:root:Episode: 161, Reward: 035\n", 994 | "INFO:root:Episode: 162, Reward: 039\n", 995 | "INFO:root:Episode: 163, Reward: 028\n", 996 | "INFO:root:Episode: 164, Reward: 021\n", 997 | "INFO:root:Episode: 165, Reward: 049\n", 998 | "INFO:root:Episode: 166, Reward: 119\n", 999 | "INFO:root:Episode: 167, Reward: 043\n", 1000 | "INFO:root:Episode: 168, Reward: 067\n", 1001 | "INFO:root:Episode: 169, Reward: 124\n", 1002 | "INFO:root:Episode: 170, Reward: 021\n", 1003 | "INFO:root:Episode: 171, Reward: 049\n", 1004 | "INFO:root:Episode: 172, Reward: 051\n", 1005 | "INFO:root:Episode: 173, Reward: 088\n", 1006 | "INFO:root:Episode: 174, Reward: 056\n", 1007 | "INFO:root:Episode: 175, Reward: 144\n", 1008 | "INFO:root:Episode: 176, Reward: 085\n", 1009 | "INFO:root:Episode: 177, Reward: 116\n", 1010 | "INFO:root:Episode: 178, Reward: 090\n", 1011 | "INFO:root:Episode: 179, Reward: 020\n", 1012 | "INFO:root:Episode: 180, Reward: 038\n", 1013 | "INFO:root:Episode: 181, Reward: 127\n", 1014 | "INFO:root:Episode: 182, Reward: 037\n", 1015 | "INFO:root:Episode: 183, Reward: 053\n", 1016 | "INFO:root:Episode: 184, Reward: 059\n", 1017 | "INFO:root:Episode: 185, Reward: 022\n", 1018 | "INFO:root:Episode: 186, Reward: 068\n", 1019 | "INFO:root:Episode: 187, Reward: 033\n", 1020 | "INFO:root:Episode: 188, Reward: 072\n", 1021 | "INFO:root:Episode: 189, Reward: 077\n", 1022 | "INFO:root:Episode: 190, Reward: 041\n", 1023 | "INFO:root:Episode: 191, Reward: 038\n", 1024 | "INFO:root:Episode: 192, Reward: 074\n", 1025 | "INFO:root:Episode: 193, Reward: 028\n", 1026 | "INFO:root:Episode: 194, Reward: 027\n", 1027 | "INFO:root:Episode: 195, Reward: 036\n", 1028 | "INFO:root:Episode: 196, Reward: 040\n", 1029 | "INFO:root:Episode: 197, Reward: 028\n", 1030 | "INFO:root:Episode: 198, Reward: 030\n", 
1031 | "INFO:root:Episode: 199, Reward: 034\n", 1032 | "INFO:root:Episode: 200, Reward: 044\n", 1033 | "INFO:root:Episode: 201, Reward: 113\n", 1034 | "INFO:root:Episode: 202, Reward: 089\n", 1035 | "INFO:root:Episode: 203, Reward: 147\n", 1036 | "INFO:root:Episode: 204, Reward: 077\n", 1037 | "INFO:root:Episode: 205, Reward: 056\n", 1038 | "INFO:root:Episode: 206, Reward: 024\n", 1039 | "INFO:root:Episode: 207, Reward: 091\n", 1040 | "INFO:root:Episode: 208, Reward: 033\n", 1041 | "INFO:root:Episode: 209, Reward: 078\n", 1042 | "INFO:root:Episode: 210, Reward: 044\n", 1043 | "INFO:root:Episode: 211, Reward: 110\n", 1044 | "INFO:root:Episode: 212, Reward: 163\n", 1045 | "INFO:root:Episode: 213, Reward: 053\n", 1046 | "INFO:root:Episode: 214, Reward: 102\n", 1047 | "INFO:root:Episode: 215, Reward: 136\n", 1048 | "INFO:root:Episode: 216, Reward: 128\n", 1049 | "INFO:root:Episode: 217, Reward: 066\n", 1050 | "INFO:root:Episode: 218, Reward: 034\n" 1051 | ] 1052 | }, 1053 | { 1054 | "name": "stderr", 1055 | "output_type": "stream", 1056 | "text": [ 1057 | "INFO:root:Episode: 219, Reward: 069\n", 1058 | "INFO:root:Episode: 220, Reward: 164\n", 1059 | "INFO:root:Episode: 221, Reward: 054\n", 1060 | "INFO:root:Episode: 222, Reward: 061\n", 1061 | "INFO:root:Episode: 223, Reward: 090\n", 1062 | "INFO:root:Episode: 224, Reward: 109\n", 1063 | "INFO:root:Episode: 225, Reward: 161\n", 1064 | "INFO:root:Episode: 226, Reward: 200\n", 1065 | "INFO:root:Episode: 227, Reward: 062\n", 1066 | "INFO:root:Episode: 228, Reward: 059\n", 1067 | "INFO:root:Episode: 229, Reward: 102\n", 1068 | "INFO:root:Episode: 230, Reward: 181\n", 1069 | "INFO:root:Episode: 231, Reward: 031\n", 1070 | "INFO:root:Episode: 232, Reward: 107\n", 1071 | "INFO:root:Episode: 233, Reward: 037\n", 1072 | "INFO:root:Episode: 234, Reward: 113\n", 1073 | "INFO:root:Episode: 235, Reward: 102\n", 1074 | "INFO:root:Episode: 236, Reward: 029\n", 1075 | "INFO:root:Episode: 237, Reward: 023\n", 1076 | "INFO:root:Episode: 
238, Reward: 145\n", 1077 | "INFO:root:Episode: 239, Reward: 062\n", 1078 | "INFO:root:Episode: 240, Reward: 068\n", 1079 | "INFO:root:Episode: 241, Reward: 157\n", 1080 | "INFO:root:Episode: 242, Reward: 073\n", 1081 | "INFO:root:Episode: 243, Reward: 077\n", 1082 | "INFO:root:Episode: 244, Reward: 146\n", 1083 | "INFO:root:Episode: 245, Reward: 067\n", 1084 | "INFO:root:Episode: 246, Reward: 130\n", 1085 | "INFO:root:Episode: 247, Reward: 080\n", 1086 | "INFO:root:Episode: 248, Reward: 034\n", 1087 | "INFO:root:Episode: 249, Reward: 188\n", 1088 | "INFO:root:Episode: 250, Reward: 142\n", 1089 | "INFO:root:Episode: 251, Reward: 186\n", 1090 | "INFO:root:Episode: 252, Reward: 049\n", 1091 | "INFO:root:Episode: 253, Reward: 048\n", 1092 | "INFO:root:Episode: 254, Reward: 056\n", 1093 | "INFO:root:Episode: 255, Reward: 061\n", 1094 | "INFO:root:Episode: 256, Reward: 138\n", 1095 | "INFO:root:Episode: 257, Reward: 076\n", 1096 | "INFO:root:Episode: 258, Reward: 125\n", 1097 | "INFO:root:Episode: 259, Reward: 161\n", 1098 | "INFO:root:Episode: 260, Reward: 053\n", 1099 | "INFO:root:Episode: 261, Reward: 045\n", 1100 | "INFO:root:Episode: 262, Reward: 141\n", 1101 | "INFO:root:Episode: 263, Reward: 050\n", 1102 | "INFO:root:Episode: 264, Reward: 089\n", 1103 | "INFO:root:Episode: 265, Reward: 123\n", 1104 | "INFO:root:Episode: 266, Reward: 082\n", 1105 | "INFO:root:Episode: 267, Reward: 064\n", 1106 | "INFO:root:Episode: 268, Reward: 088\n", 1107 | "INFO:root:Episode: 269, Reward: 189\n", 1108 | "INFO:root:Episode: 270, Reward: 081\n", 1109 | "INFO:root:Episode: 271, Reward: 041\n", 1110 | "INFO:root:Episode: 272, Reward: 140\n", 1111 | "INFO:root:Episode: 273, Reward: 107\n", 1112 | "INFO:root:Episode: 274, Reward: 105\n", 1113 | "INFO:root:Episode: 275, Reward: 174\n", 1114 | "INFO:root:Episode: 276, Reward: 112\n", 1115 | "INFO:root:Episode: 277, Reward: 080\n", 1116 | "INFO:root:Episode: 278, Reward: 195\n", 1117 | "INFO:root:Episode: 279, Reward: 186\n", 1118 | 
"INFO:root:Episode: 280, Reward: 036\n", 1119 | "INFO:root:Episode: 281, Reward: 087\n", 1120 | "INFO:root:Episode: 282, Reward: 133\n", 1121 | "INFO:root:Episode: 283, Reward: 037\n", 1122 | "INFO:root:Episode: 284, Reward: 114\n", 1123 | "INFO:root:Episode: 285, Reward: 065\n", 1124 | "INFO:root:Episode: 286, Reward: 031\n", 1125 | "INFO:root:Episode: 287, Reward: 071\n", 1126 | "INFO:root:Episode: 288, Reward: 168\n", 1127 | "INFO:root:Episode: 289, Reward: 121\n", 1128 | "INFO:root:Episode: 290, Reward: 200\n", 1129 | "INFO:root:Episode: 291, Reward: 046\n", 1130 | "INFO:root:Episode: 292, Reward: 048\n", 1131 | "INFO:root:Episode: 293, Reward: 100\n", 1132 | "INFO:root:Episode: 294, Reward: 088\n", 1133 | "INFO:root:Episode: 295, Reward: 158\n", 1134 | "INFO:root:Episode: 296, Reward: 151\n", 1135 | "INFO:root:Episode: 297, Reward: 037\n", 1136 | "INFO:root:Episode: 298, Reward: 136\n", 1137 | "INFO:root:Episode: 299, Reward: 096\n", 1138 | "INFO:root:Episode: 300, Reward: 047\n", 1139 | "INFO:root:Episode: 301, Reward: 121\n", 1140 | "INFO:root:Episode: 302, Reward: 041\n", 1141 | "INFO:root:Episode: 303, Reward: 128\n", 1142 | "INFO:root:Episode: 304, Reward: 163\n", 1143 | "INFO:root:Episode: 305, Reward: 181\n", 1144 | "INFO:root:Episode: 306, Reward: 104\n", 1145 | "INFO:root:Episode: 307, Reward: 121\n", 1146 | "INFO:root:Episode: 308, Reward: 142\n", 1147 | "INFO:root:Episode: 309, Reward: 200\n", 1148 | "INFO:root:Episode: 310, Reward: 200\n", 1149 | "INFO:root:Episode: 311, Reward: 198\n", 1150 | "INFO:root:Episode: 312, Reward: 181\n", 1151 | "INFO:root:Episode: 313, Reward: 062\n", 1152 | "INFO:root:Episode: 314, Reward: 159\n", 1153 | "INFO:root:Episode: 315, Reward: 123\n", 1154 | "INFO:root:Episode: 316, Reward: 097\n", 1155 | "INFO:root:Episode: 317, Reward: 200\n", 1156 | "INFO:root:Episode: 318, Reward: 080\n", 1157 | "INFO:root:Episode: 319, Reward: 070\n", 1158 | "INFO:root:Episode: 320, Reward: 200\n", 1159 | "INFO:root:Episode: 321, 
Reward: 095\n", 1160 | "INFO:root:Episode: 322, Reward: 142\n", 1161 | "INFO:root:Episode: 323, Reward: 138\n", 1162 | "INFO:root:Episode: 324, Reward: 141\n", 1163 | "INFO:root:Episode: 325, Reward: 175\n", 1164 | "INFO:root:Episode: 326, Reward: 092\n", 1165 | "INFO:root:Episode: 327, Reward: 124\n", 1166 | "INFO:root:Episode: 328, Reward: 200\n", 1167 | "INFO:root:Episode: 329, Reward: 113\n", 1168 | "INFO:root:Episode: 330, Reward: 191\n", 1169 | "INFO:root:Episode: 331, Reward: 177\n", 1170 | "INFO:root:Episode: 332, Reward: 200\n", 1171 | "INFO:root:Episode: 333, Reward: 074\n", 1172 | "INFO:root:Episode: 334, Reward: 034\n", 1173 | "INFO:root:Episode: 335, Reward: 159\n", 1174 | "INFO:root:Episode: 336, Reward: 127\n", 1175 | "INFO:root:Episode: 337, Reward: 200\n", 1176 | "INFO:root:Episode: 338, Reward: 140\n", 1177 | "INFO:root:Episode: 339, Reward: 135\n", 1178 | "INFO:root:Episode: 340, Reward: 200\n", 1179 | "INFO:root:Episode: 341, Reward: 200\n", 1180 | "INFO:root:Episode: 342, Reward: 177\n", 1181 | "INFO:root:Episode: 343, Reward: 187\n", 1182 | "INFO:root:Episode: 344, Reward: 149\n", 1183 | "INFO:root:Episode: 345, Reward: 191\n", 1184 | "INFO:root:Episode: 346, Reward: 155\n", 1185 | "INFO:root:Episode: 347, Reward: 157\n", 1186 | "INFO:root:Episode: 348, Reward: 164\n", 1187 | "INFO:root:Episode: 349, Reward: 158\n", 1188 | "INFO:root:Episode: 350, Reward: 200\n", 1189 | "INFO:root:Episode: 351, Reward: 138\n", 1190 | "INFO:root:Episode: 352, Reward: 144\n", 1191 | "INFO:root:Episode: 353, Reward: 147\n", 1192 | "INFO:root:Episode: 354, Reward: 200\n", 1193 | "INFO:root:Episode: 355, Reward: 145\n", 1194 | "INFO:root:Episode: 356, Reward: 150\n", 1195 | "INFO:root:Episode: 357, Reward: 062\n", 1196 | "INFO:root:Episode: 358, Reward: 149\n", 1197 | "INFO:root:Episode: 359, Reward: 187\n", 1198 | "INFO:root:Episode: 360, Reward: 164\n", 1199 | "INFO:root:Episode: 361, Reward: 144\n", 1200 | "INFO:root:Episode: 362, Reward: 200\n", 1201 | 
"INFO:root:Episode: 363, Reward: 200\n", 1202 | "INFO:root:Episode: 364, Reward: 104\n", 1203 | "INFO:root:Episode: 365, Reward: 127\n", 1204 | "INFO:root:Episode: 366, Reward: 200\n", 1205 | "INFO:root:Episode: 367, Reward: 111\n", 1206 | "INFO:root:Episode: 368, Reward: 200\n", 1207 | "INFO:root:Episode: 369, Reward: 200\n", 1208 | "INFO:root:Episode: 370, Reward: 124\n", 1209 | "INFO:root:Episode: 371, Reward: 200\n", 1210 | "INFO:root:Episode: 372, Reward: 200\n", 1211 | "INFO:root:Episode: 373, Reward: 178\n", 1212 | "INFO:root:Episode: 374, Reward: 200\n", 1213 | "INFO:root:Episode: 375, Reward: 200\n", 1214 | "INFO:root:Episode: 376, Reward: 057\n", 1215 | "INFO:root:Episode: 377, Reward: 166\n", 1216 | "INFO:root:Episode: 378, Reward: 118\n", 1217 | "INFO:root:Episode: 379, Reward: 200\n", 1218 | "INFO:root:Episode: 380, Reward: 082\n", 1219 | "INFO:root:Episode: 381, Reward: 118\n", 1220 | "INFO:root:Episode: 382, Reward: 058\n", 1221 | "INFO:root:Episode: 383, Reward: 200\n", 1222 | "INFO:root:Episode: 384, Reward: 171\n", 1223 | "INFO:root:Episode: 385, Reward: 113\n", 1224 | "INFO:root:Episode: 386, Reward: 169\n", 1225 | "INFO:root:Episode: 387, Reward: 103\n", 1226 | "INFO:root:Episode: 388, Reward: 141\n", 1227 | "INFO:root:Episode: 389, Reward: 191\n", 1228 | "INFO:root:Episode: 390, Reward: 200\n", 1229 | "INFO:root:Episode: 391, Reward: 171\n", 1230 | "INFO:root:Episode: 392, Reward: 052\n", 1231 | "INFO:root:Episode: 393, Reward: 171\n" 1232 | ] 1233 | }, 1234 | { 1235 | "name": "stdout", 1236 | "output_type": "stream", 1237 | "text": [ 1238 | "Finished training, testing...\n", 1239 | "Total Episode Reward: 200 out of 200\n" 1240 | ] 1241 | } 1242 | ], 1243 | "source": [ 1244 | "with tf.Graph().as_default():\n", 1245 | " print(\"Eager Execution:\", tf.executing_eagerly()) # False\n", 1246 | "\n", 1247 | " model = Model(num_actions=env.action_space.n)\n", 1248 | " agent = A2CAgent(model)\n", 1249 | "\n", 1250 | " rewards_history = 
agent.train(env)\n", 1251 | " print(\"Finished training, testing...\")\n", 1252 | " print(\"Total Episode Reward: %d out of 200\" % agent.test(env))" 1253 | ] 1254 | }, 1255 | { 1256 | "cell_type": "markdown", 1257 | "metadata": {}, 1258 | "source": [ 1259 | "# Benchmarks" 1260 | ] 1261 | }, 1262 | { 1263 | "cell_type": "code", 1264 | "execution_count": 18, 1265 | "metadata": {}, 1266 | "outputs": [], 1267 | "source": [ 1268 | "# Note: comparing wall time isn't exactly fair due to specifics of how things are executed on multi-core CPU" 1269 | ] 1270 | }, 1271 | { 1272 | "cell_type": "code", 1273 | "execution_count": 14, 1274 | "metadata": {}, 1275 | "outputs": [], 1276 | "source": [ 1277 | "env = gym.make('CartPole-v0')\n", 1278 | "obs = np.repeat(env.reset()[None, :], 100000, axis=0)" 1279 | ] 1280 | }, 1281 | { 1282 | "cell_type": "markdown", 1283 | "metadata": {}, 1284 | "source": [ 1285 | "## Eager Benchmark" 1286 | ] 1287 | }, 1288 | { 1289 | "cell_type": "code", 1290 | "execution_count": 15, 1291 | "metadata": {}, 1292 | "outputs": [ 1293 | { 1294 | "name": "stdout", 1295 | "output_type": "stream", 1296 | "text": [ 1297 | "Eager Execution: True\n", 1298 | "Eager Keras Model: True\n", 1299 | "CPU times: user 639 ms, sys: 736 ms, total: 1.38 s\n", 1300 | "Wall time: 116 ms\n" 1301 | ] 1302 | } 1303 | ], 1304 | "source": [ 1305 | "%%time\n", 1306 | "\n", 1307 | "model = Model(env.action_space.n)\n", 1308 | "model.run_eagerly = True\n", 1309 | "\n", 1310 | "print(\"Eager Execution: \", tf.executing_eagerly())\n", 1311 | "print(\"Eager Keras Model:\", model.run_eagerly)\n", 1312 | "\n", 1313 | "_ = model(obs)\n", 1314 | "# _ = model.predict(obs)" 1315 | ] 1316 | }, 1317 | { 1318 | "cell_type": "markdown", 1319 | "metadata": {}, 1320 | "source": [ 1321 | "## Static Benchmark" 1322 | ] 1323 | }, 1324 | { 1325 | "cell_type": "code", 1326 | "execution_count": 16, 1327 | "metadata": {}, 1328 | "outputs": [ 1329 | { 1330 | "name": "stdout", 1331 | "output_type": 
"stream", 1332 | "text": [ 1333 | "Eager Execution: False\n", 1334 | "Eager Keras Model: False\n", 1335 | "CPU times: user 793 ms, sys: 79.7 ms, total: 873 ms\n", 1336 | "Wall time: 656 ms\n" 1337 | ] 1338 | } 1339 | ], 1340 | "source": [ 1341 | "%%time\n", 1342 | "\n", 1343 | "with tf.Graph().as_default():\n", 1344 | " model = Model(env.action_space.n)\n", 1345 | "\n", 1346 | " print(\"Eager Execution: \", tf.executing_eagerly())\n", 1347 | " print(\"Eager Keras Model:\", model.run_eagerly)\n", 1348 | "\n", 1349 | " _ = model.predict(obs)" 1350 | ] 1351 | }, 1352 | { 1353 | "cell_type": "markdown", 1354 | "metadata": {}, 1355 | "source": [ 1356 | "## Default Benchmark" 1357 | ] 1358 | }, 1359 | { 1360 | "cell_type": "code", 1361 | "execution_count": 17, 1362 | "metadata": {}, 1363 | "outputs": [ 1364 | { 1365 | "name": "stdout", 1366 | "output_type": "stream", 1367 | "text": [ 1368 | "Eager Execution: True\n", 1369 | "Eager Keras Model: False\n", 1370 | "CPU times: user 994 ms, sys: 23.1 ms, total: 1.02 s\n", 1371 | "Wall time: 769 ms\n" 1372 | ] 1373 | } 1374 | ], 1375 | "source": [ 1376 | "%%time\n", 1377 | "\n", 1378 | "model = Model(env.action_space.n)\n", 1379 | "\n", 1380 | "print(\"Eager Execution: \", tf.executing_eagerly())\n", 1381 | "print(\"Eager Keras Model:\", model.run_eagerly)\n", 1382 | "\n", 1383 | "_ = model.predict(obs)" 1384 | ] 1385 | } 1386 | ], 1387 | "metadata": { 1388 | "kernelspec": { 1389 | "display_name": "Python 3", 1390 | "language": "python", 1391 | "name": "python3" 1392 | }, 1393 | "language_info": { 1394 | "codemirror_mode": { 1395 | "name": "ipython", 1396 | "version": 3 1397 | }, 1398 | "file_extension": ".py", 1399 | "mimetype": "text/x-python", 1400 | "name": "python", 1401 | "nbconvert_exporter": "python", 1402 | "pygments_lexer": "ipython3", 1403 | "version": "3.6.8" 1404 | } 1405 | }, 1406 | "nbformat": 4, 1407 | "nbformat_minor": 2 1408 | } 1409 | 
--------------------------------------------------------------------------------