├── CONTRIBUTING.md ├── README.md ├── trfl ├── clipping_ops.py ├── periodic_ops_test.py ├── __init__.py ├── clipping_ops_test.py ├── policy_ops.py ├── dpg_ops.py ├── periodic_ops.py ├── policy_ops_test.py ├── base_ops.py ├── indexing_ops.py ├── continuous_retrace_ops_test.py ├── dpg_ops_test.py ├── target_update_ops.py ├── target_update_ops_test.py ├── distribution_ops_test.py ├── distribution_ops.py ├── pixel_control_ops_test.py ├── continuous_retrace_ops.py ├── sequence_ops.py ├── pixel_control_ops.py ├── value_ops_test.py ├── retrace_ops_test.py ├── vtrace_ops_test.py ├── value_ops.py ├── dist_value_ops.py └── vtrace_ops.py ├── setup.py ├── docs ├── multistep_forward_view.md └── index.md └── LICENSE /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TRFL 2 | 3 | TRFL (pronounced "truffle") is a library built on top of TensorFlow that exposes 4 | several useful building blocks for implementing Reinforcement Learning agents. 5 | 6 | 7 | ## Installation 8 | 9 | TRFL can be installed from pip with the following command: 10 | `pip install trfl` 11 | 12 | TRFL will work with both the CPU and GPU version of tensorflow, but to allow 13 | for that it does not list Tensorflow as a requirement, so you need to install 14 | Tensorflow and Tensorflow-probability separately if you haven't already done so. 15 | 16 | ## Usage Example 17 | 18 | ```python 19 | import tensorflow as tf 20 | import trfl 21 | 22 | # Q-values for the previous and next timesteps, shape [batch_size, num_actions]. 23 | q_tm1 = tf.get_variable( 24 | "q_tm1", initializer=[[1., 1., 0.], [1., 2., 0.]], dtype=tf.float32) 25 | q_t = tf.get_variable( 26 | "q_t", initializer=[[0., 1., 0.], [1., 2., 0.]], dtype=tf.float32) 27 | 28 | # Action indices, discounts and rewards, shape [batch_size]. 29 | a_tm1 = tf.constant([0, 1], dtype=tf.int32) 30 | r_t = tf.constant([1, 1], dtype=tf.float32) 31 | pcont_t = tf.constant([0, 1], dtype=tf.float32) # the discount factor 32 | 33 | # Q-learning loss, and auxiliary data. 34 | loss, q_learning = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t) 35 | ``` 36 | 37 | `loss` is the tensor representing the loss. 
For Q-learning, it is half the 38 | squared difference between the predicted Q-values and the TD targets, shape 39 | `[batch_size]`. Extra information is in the `q_learning` namedtuple, including 40 | `q_learning.td_error` and `q_learning.target`. 41 | 42 | The `loss` tensor can be differentiated to derive the corresponding RL update. 43 | 44 | ```python 45 | reduced_loss = tf.reduce_mean(loss) 46 | optimizer = tf.train.AdamOptimizer(learning_rate=0.1) 47 | train_op = optimizer.minimize(reduced_loss) 48 | ``` 49 | 50 | All loss functions in the package return both a loss tensor and a namedtuple 51 | with extra information, using the above convention, but different functions 52 | may have different `extra` fields. Check the documentation of each function 53 | below for more information. 54 | 55 | # Documentation 56 | 57 | Check out the full documentation page 58 | [here](https://github.com/deepmind/trfl/blob/master/docs/index.md). 59 | -------------------------------------------------------------------------------- /trfl/clipping_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Ops to implement gradient clipping.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | def huber_loss(input_tensor, quadratic_linear_boundary, name=None): 25 | """Calculates huber loss of `input_tensor`. 26 | 27 | For each value x in `input_tensor`, the following is calculated: 28 | 29 | ``` 30 | 0.5 * x^2 if |x| <= d 31 | 0.5 * d^2 + d * (|x| - d) if |x| > d 32 | ``` 33 | 34 | where d is `quadratic_linear_boundary`. 35 | 36 | When `input_tensor` is a loss this results in a form of gradient clipping. 37 | This is, for instance, how gradients are clipped in DQN and its variants. 38 | 39 | Args: 40 | input_tensor: `Tensor`, input values to calculate the huber loss on. 41 | quadratic_linear_boundary: `float`, the point where the huber loss function 42 | changes from a quadratic to linear. 43 | name: `string`, name for the operation (optional). 44 | 45 | Returns: 46 | `Tensor` of the same shape as `input_tensor`, containing values calculated 47 | in the manner described above. 48 | 49 | Raises: 50 | ValueError: if quadratic_linear_boundary <= 0. 
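
  For example, with `quadratic_linear_boundary == 1.0`, an input value of 0.5
  lies in the quadratic regime and yields 0.5 * 0.5**2 == 0.125, while an
  input value of 2.0 lies in the linear regime and yields
  0.5 * 1.0**2 + 1.0 * (2.0 - 1.0) == 1.5.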
51 | """ 52 | if quadratic_linear_boundary < 0: 53 | raise ValueError("quadratic_linear_boundary must be > 0.") 54 | 55 | with tf.name_scope( 56 | name, default_name="huber_loss", 57 | values=[input_tensor, quadratic_linear_boundary]): 58 | abs_x = tf.abs(input_tensor) 59 | delta = quadratic_linear_boundary 60 | quad = tf.minimum(abs_x, delta) 61 | # The following expression is the same in value as 62 | # tf.maximum(abs_x - delta, 0), but importantly the gradient for the 63 | # expression when abs_x == delta is 0 (for tf.maximum it would be 1). This 64 | # is necessary to avoid doubling the gradient, since there is already a 65 | # non-zero contribution to the gradient from the quadratic term. 66 | lin = (abs_x - quad) 67 | return 0.5 * quad**2 + delta * lin 68 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Setup for pip package.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | REQUIRED_PACKAGES = ['six', 'absl-py', 'numpy', 'wrapt', 'dm-tree'] 25 | EXTRA_PACKAGES = { 26 | 'tensorflow': [ 27 | 'tensorflow>=1.15', 'tensorflow-probability>=0.8' 28 | ], 29 | 'tensorflow with gpu': [ 30 | 'tensorflow-gpu>=1.15', 'tensorflow-probability>=0.8' 31 | ], 32 | } 33 | 34 | setup( 35 | name='trfl', 36 | version='1.2.0', 37 | description=('trfl is a library of building blocks for ' 38 | 'reinforcement learning algorithms.'), 39 | long_description='', 40 | url='http://www.github.com/deepmind/trfl/', 41 | author='DeepMind', 42 | author_email='trfl-steering@google.com', 43 | # Contained modules and scripts. 44 | packages=find_packages(), 45 | install_requires=REQUIRED_PACKAGES, 46 | extras_require=EXTRA_PACKAGES, 47 | # Add in any packaged data. 48 | include_package_data=True, 49 | zip_safe=False, 50 | # PyPI package information. 
51 | classifiers=[ 52 | 'Development Status :: 4 - Beta', 53 | 'Intended Audience :: Developers', 54 | 'Intended Audience :: Education', 55 | 'Intended Audience :: Science/Research', 56 | 'License :: OSI Approved :: Apache Software License', 57 | 'Operating System :: MacOS :: MacOS X', 58 | 'Operating System :: POSIX', 59 | 'Operating System :: Unix', 60 | 'Programming Language :: Python :: 3.4', 61 | 'Programming Language :: Python :: 3.5', 62 | 'Programming Language :: Python :: 3.6', 63 | 'Programming Language :: Python :: 3.7', 64 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 65 | 'Topic :: Software Development :: Libraries', 66 | ], 67 | license='Apache 2.0', 68 | keywords='trfl truffle tensorflow tensor machine reinforcement learning', 69 | test_suite='nose.collector', 70 | tests_require=['nose'], 71 | ) 72 | -------------------------------------------------------------------------------- /trfl/periodic_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for periodic_ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | import tensorflow.compat.v1 as tf 23 | from trfl import periodic_ops 24 | 25 | 26 | class PeriodicallyTest(tf.test.TestCase): 27 | """Tests function periodically.""" 28 | 29 | def testPeriodically(self): 30 | """Tests that a function is called exactly every `period` steps.""" 31 | target = tf.Variable(0) 32 | period = 3 33 | 34 | periodic_update = periodic_ops.periodically( 35 | body=lambda: target.assign_add(1).op, period=period) 36 | 37 | with self.test_session() as sess: 38 | sess.run(tf.global_variables_initializer()) 39 | desired_values = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4] 40 | for desired_value in desired_values: 41 | sess.run(periodic_update) 42 | result = sess.run(target) 43 | self.assertEqual(desired_value, result) 44 | 45 | def testPeriodOne(self): 46 | """Tests that the function is called every time if period == 1.""" 47 | target = tf.Variable(0) 48 | 49 | periodic_update = periodic_ops.periodically( 50 | body=lambda: target.assign_add(1).op, period=1) 51 | 52 | with self.test_session() as sess: 53 | sess.run(tf.global_variables_initializer()) 54 | for desired_value in range(1, 11): 55 | _, result = sess.run([periodic_update, target]) 56 | self.assertEqual(desired_value, result) 57 | 58 | def testPeriodNone(self): 59 | """Tests that the function is never called if period == None.""" 60 | target = tf.Variable(0) 61 | 62 | periodic_update = periodic_ops.periodically( 63 | body=lambda: target.assign_add(1).op, period=None) 64 | 65 | with self.test_session() as sess: 66 | sess.run(tf.global_variables_initializer()) 67 | desired_value = 0 68 | for _ in 
range(1, 11): 69 | _, result = sess.run([periodic_update, target]) 70 | self.assertEqual(desired_value, result) 71 | 72 | def testFunctionNotCallable(self): 73 | """Tests value error when argument fn is not a callable.""" 74 | self.assertRaises( 75 | TypeError, periodic_ops.periodically, body=1, period=2) 76 | 77 | 78 | if __name__ == '__main__': 79 | tf.test.main() 80 | -------------------------------------------------------------------------------- /docs/multistep_forward_view.md: -------------------------------------------------------------------------------- 1 | ### Multistep forward view 2 | 3 | The 4 | [multistep_forward_view](https://github.com/deepmind/trfl/blob/master/trfl/sequence_ops.py?q=multistep_forward_view) 5 | function computes mixed multistep returns in terms of the instantaneous 6 | `rewards`, discount factors `pcontinues`, `state_value` estimates, and mixing 7 | weights `lambda_`. In the math that follows we will replace these by 8 | $$r_{0:T-1}$$, $$\gamma_{0:T-1}$$, $$V_{1:T}$$, and $$\lambda_{0:T-1}$$, 9 | respectively. Note that in the implementation, the `state_values` array is 10 | offset in time by 1 relative to the other arrays. 11 | 12 | The mixed returns $$M_t$$ are computed by the following recurrence, using a 13 | backwards scan: 14 | 15 | $$ 16 | M_t = r_t + \gamma_t (\lambda_t M_{t+1} + (1-\lambda_t) V_{t+1}) \label{eq:recurrence} \tag{1}\\ 17 | M_{T-1} = r_{T-1} + \gamma_{T-1} V_T 18 | $$ 19 | 20 | Here we can see why $$M_t$$ is a valid estimate of the return at time $$t$$. The 21 | recurrence is applying the Bellman Equation, using the $$\lambda_t$$-weighted 22 | mixture of $$M_{t+1}$$ and $$V_{t+1}$$, both of which are valid estimates of the 23 | expected return at time $$t+1$$. 24 | 25 | What's left is to show that we are computing the right mixture of multistep 26 | returns. Let $$R(t, k)$$ be the $$\gamma_{t:k}$$-discounted return from time 27 | $$t$$ to $$k$$: 28 | 29 | $$ 30 | R(t, k) = r_t + \gamma_t r_{t+1} + \cdots + (\gamma_t \gamma_{t+1} \cdots \gamma_{k-1}) r_k + (\gamma_t \gamma_{t+1} \cdots \gamma_k) V_{k+1}\\ 31 | = \sum_{i=t}^k \left(\prod_{j=t}^{i-1} \gamma_j \right) r_i + \left(\prod_{i=t}^k \gamma_i \right) V_{k+1} 32 | $$ 33 | 34 | We should mention that $$R(t, k)$$ would correspond to the $$k-t$$ step return 35 | $$G_t^{k-t}$$ in Sutton and Barto's notation. 
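As a sanity check on recurrence $$\eqref{eq:recurrence}$$, the following is a
minimal NumPy sketch (not part of the library) that evaluates $$M_{0:T-1}$$ by
a backwards scan over a single, unbatched trajectory; the helper name and the
flat `[T]` shapes are illustrative only, whereas `multistep_forward_view`
itself operates on `[T, B]` tensors:

```python
import numpy as np


def multistep_forward_view_reference(rewards, pcontinues, state_values,
                                     lambda_):
  """Backwards-scan evaluation of recurrence (1) for one trajectory.

  `state_values[t]` holds V_{t+1}, matching the one-step offset used by the
  library.
  """
  seq_len = len(rewards)
  mixed_returns = np.zeros(seq_len)
  # Base case: M_{T-1} = r_{T-1} + gamma_{T-1} V_T.
  mixed_returns[-1] = rewards[-1] + pcontinues[-1] * state_values[-1]
  for t in range(seq_len - 2, -1, -1):
    # M_t = r_t + gamma_t * (lambda_t M_{t+1} + (1 - lambda_t) V_{t+1}).
    mixed_returns[t] = rewards[t] + pcontinues[t] * (
        lambda_[t] * mixed_returns[t + 1]
        + (1.0 - lambda_[t]) * state_values[t])
  return mixed_returns
```

Setting every $$\lambda_t = 1$$ reduces $$M_t$$ to the longest return
$$R(t, T-1)$$ (bootstrapping only on $$V_T$$), while $$\lambda_t = 0$$
recovers the one-step return $$R(t, t)$$.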
Note that $$R$$ satisfies a 36 | Bellman-style recurrence: 37 | 38 | $$ 39 | R(t-1, k) = r_{t-1} + \gamma_{t-1} R(t, k)\\ 40 | R(t, t) = r_t + \gamma_t V_{t+1} 41 | $$ 42 | 43 | The desired[^1] $$\lambda$$-weighted mixture is given by: 44 | 45 | $$ 46 | L(t, k) = (1-\lambda_t) R(t, t) + (1-\lambda_{t+1}) \lambda_t R(t, t+1) + \cdots + ((1-\lambda_k) \lambda_{k-1} \cdots \lambda_t) R(t, k) + \lambda_k\cdots\lambda_t R(t, k)\\ 47 | = \sum_{i=t}^k \left((1-\lambda_i) \prod_{j=t}^{i-1} \lambda_j \right) R(t, i) + \left(\prod_{i=t}^k \lambda_i \right) R(t, k) \\ 48 | $$ 49 | 50 | We have that 51 | 52 | $$ L(t-1, k) = \sum_{i=t-1}^k \left((1-\lambda_i) \prod_{j=t-1}^{i-1} \lambda_j 53 | \right) R(t-1, i) + \left(\prod_{i=t-1}^k \lambda_i \right) R(t-1, k) \\ 54 | = (1-\lambda_{t-1})R(t-1, t-1) + \lambda_{t-1}\sum_{i=t}^k \left((1-\lambda_i) \prod_{j=t}^{i-1} \lambda_j \right) (r_{t-1} + \gamma_{t-1} R(t, i)) + \lambda_{t-1}\left(\prod_{i=t}^k \lambda_i \right) \left(r_{t-1} + \gamma_{t-1} R(t, k)\right) \\ 55 | = r_{t-1} + (1-\lambda_{t-1})\gamma_{t-1} V_t + \lambda_{t-1}\gamma_{t-1} L(t, k) 56 | $$ 57 | 58 | Thus, $$L(t, k)$$ also satisfies recurrence $$\eqref{eq:recurrence}$$, as 59 | desired. 60 | 61 | [^1]: [Sutton and Barto](http://incompleteideas.net/book/ebook/node74.html), 62 | equation 7.3 63 | -------------------------------------------------------------------------------- /trfl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Flattened namespace for trfl.""" 16 | 17 | from trfl.action_value_ops import double_qlearning 18 | from trfl.action_value_ops import persistent_qlearning 19 | from trfl.action_value_ops import qlambda 20 | from trfl.action_value_ops import qlearning 21 | from trfl.action_value_ops import qv_learning 22 | from trfl.action_value_ops import sarsa 23 | from trfl.action_value_ops import sarsa_lambda 24 | from trfl.action_value_ops import sarse 25 | from trfl.base_ops import assert_rank_and_shape_compatibility 26 | from trfl.base_ops import best_effort_shape 27 | from trfl.clipping_ops import huber_loss 28 | from trfl.continuous_retrace_ops import retrace_from_action_log_probs 29 | from trfl.continuous_retrace_ops import retrace_from_importance_weights 30 | from trfl.discrete_policy_gradient_ops import discrete_policy_entropy_loss 31 | from trfl.discrete_policy_gradient_ops import discrete_policy_gradient 32 | from trfl.discrete_policy_gradient_ops import discrete_policy_gradient_loss 33 | from trfl.discrete_policy_gradient_ops import sequence_advantage_actor_critic_loss 34 | from trfl.dist_value_ops import categorical_dist_double_qlearning 35 | from trfl.dist_value_ops import categorical_dist_qlearning 36 | from trfl.dist_value_ops import categorical_dist_td_learning 37 | from trfl.dpg_ops import dpg 38 | from trfl.indexing_ops import batched_index 39 | from trfl.periodic_ops import periodically 40 | from trfl.pixel_control_ops import pixel_control_loss 41 | from trfl.pixel_control_ops import pixel_control_rewards 42 | from trfl.policy_gradient_ops import policy_entropy_loss 43 | from trfl.policy_gradient_ops import policy_gradient 44 | from trfl.policy_gradient_ops import policy_gradient_loss 45 | from trfl.policy_gradient_ops import sequence_a2c_loss 46 | from trfl.policy_ops import epsilon_greedy 47 | from trfl.retrace_ops import retrace 48 | from trfl.retrace_ops import retrace_core 49 | from trfl.sequence_ops import multistep_forward_view 50 | from trfl.sequence_ops import scan_discounted_sum 51 | from trfl.target_update_ops import periodic_target_update 52 | from trfl.target_update_ops import update_target_variables 53 | from trfl.value_ops import generalized_lambda_returns 54 | from trfl.value_ops import qv_max 55 | from trfl.value_ops import td_lambda 56 | from trfl.value_ops import td_learning 57 | from trfl.vtrace_ops import vtrace_from_importance_weights 58 | from trfl.vtrace_ops import vtrace_from_logits 59 | 60 | __version__ = '1.2.0' 61 | -------------------------------------------------------------------------------- /trfl/clipping_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for clipping_ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | import numpy as np 23 | import tensorflow.compat.v1 as tf 24 | from trfl import clipping_ops 25 | 26 | 27 | class HuberLossTest(tf.test.TestCase): 28 | 29 | def testValue(self): 30 | with self.test_session(): 31 | quadratic_linear_boundary = 2 32 | xs = np.array( 33 | [-3.5, -2.1, 2, -1.9 - 1, -0.5, 0, 0.5, 1, 1.9, 2, 2.1, 3.5]) 34 | ys = clipping_ops.huber_loss(xs, quadratic_linear_boundary).eval() 35 | 36 | d = quadratic_linear_boundary 37 | 38 | # Check values for x <= -2 39 | ys_lo = ys[xs <= -d] 40 | xs_lo = xs[xs <= -d] 41 | expected_ys_lo = [0.5 * d**2 + d * (-x - d) for x in xs_lo] 42 | self.assertAllClose(ys_lo, expected_ys_lo) 43 | 44 | # Check values for x >= 2 45 | ys_hi = ys[xs >= d] 46 | xs_hi = xs[xs >= d] 47 | expected_ys_hi = [0.5 * d**2 + d * (x - d) for x in xs_hi] 48 | self.assertAllClose(ys_hi, expected_ys_hi) 49 | 50 | # Check values for x in (-2, 2) 51 | ys_mid = ys[np.abs(xs) < d] 52 | xs_mid = xs[np.abs(xs) < d] 53 | expected_ys_mid = [0.5 * x**2 for x in xs_mid] 54 | self.assertAllClose(ys_mid, expected_ys_mid) 55 | 56 | def testGradient(self): 57 | with self.test_session() as sess: 58 | x = tf.placeholder(tf.float64) 59 | quadratic_linear_boundary = 3 60 | loss = clipping_ops.huber_loss(x, quadratic_linear_boundary) 61 | xs = np.array([-5, -4, -3.1, -3, -2.9, 2, -1, 0, 1, 2, 2.9, 3, 3.1, 4, 5]) 62 | grads = sess.run(tf.gradients([loss], [x]), feed_dict={x: xs})[0] 63 | 64 | self.assertTrue(np.all(np.abs(grads) <= quadratic_linear_boundary)) 65 | 66 | # Everything <= -3 should have gradient -3. 67 | grads_lo = grads[xs <= -quadratic_linear_boundary] 68 | self.assertAllEqual(grads_lo, 69 | [-quadratic_linear_boundary] * grads_lo.shape[0]) 70 | 71 | # Everything >= 3 should have gradient 3. 72 | grads_hi = grads[xs >= quadratic_linear_boundary] 73 | self.assertAllEqual(grads_hi, 74 | [quadratic_linear_boundary] * grads_hi.shape[0]) 75 | 76 | # x in (-3, 3) should have gradient x. 77 | grads_mid = grads[np.abs(xs) <= quadratic_linear_boundary] 78 | xs_mid = xs[np.abs(xs) <= quadratic_linear_boundary] 79 | self.assertAllEqual(grads_mid, xs_mid) 80 | 81 | 82 | if __name__ == "__main__": 83 | tf.test.main() 84 | -------------------------------------------------------------------------------- /trfl/policy_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """TensorFlow ops for expressing common types of RL policies.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_probability as tfp 25 | 26 | 27 | def epsilon_greedy(action_values, epsilon, legal_actions_mask=None): 28 | """Computes an epsilon-greedy distribution over actions. 29 | 30 | This returns a categorical distribution over a discrete action space. It is 31 | assumed that the trailing dimension of `action_values` is of length A, i.e. 32 | the number of actions. It is also assumed that actions are 0-indexed. 33 | 34 | This policy does the following: 35 | 36 | - With probability 1 - epsilon, take the action corresponding to the highest 37 | action value, breaking ties uniformly at random. 38 | - With probability epsilon, take an action uniformly at random. 39 | 40 | Args: 41 | action_values: A Tensor of action values with any rank >= 1 and dtype float. 42 | Shape can be flat ([A]), batched ([B, A]), a batch of sequences 43 | ([T, B, A]), and so on. 44 | epsilon: A scalar Tensor (or Python float) with value between 0 and 1. 45 | legal_actions_mask: An optional one-hot tensor having the shame shape and 46 | dtypes as `action_values`, defining the legal actions: 47 | legal_actions_mask[..., a] = 1 if a is legal, 0 otherwise. 48 | If not provided, all actions will be considered legal and 49 | `tf.ones_like(action_values)`. 50 | 51 | Returns: 52 | policy: tfp.distributions.Categorical distribution representing the policy. 53 | """ 54 | with tf.name_scope("epsilon_greedy", values=[action_values, epsilon]): 55 | 56 | # Convert inputs to Tensors if they aren't already. 57 | action_values = tf.convert_to_tensor(action_values) 58 | epsilon = tf.convert_to_tensor(epsilon, dtype=action_values.dtype) 59 | 60 | # We compute the action space dynamically. 61 | num_actions = tf.cast(tf.shape(action_values)[-1], action_values.dtype) 62 | 63 | # Dithering action distribution. 64 | if legal_actions_mask is None: 65 | dither_probs = 1 / num_actions * tf.ones_like(action_values) 66 | else: 67 | dither_probs = 1 / tf.reduce_sum( 68 | legal_actions_mask, axis=-1, keepdims=True) * legal_actions_mask 69 | 70 | # Greedy action distribution, breaking ties uniformly at random. 71 | max_value = tf.reduce_max(action_values, axis=-1, keepdims=True) 72 | greedy_probs = tf.cast(tf.equal(action_values, max_value), 73 | action_values.dtype) 74 | greedy_probs /= tf.reduce_sum(greedy_probs, axis=-1, keepdims=True) 75 | 76 | # Epsilon-greedy action distribution. 77 | probs = epsilon * dither_probs + (1 - epsilon) * greedy_probs 78 | 79 | # Make the policy object. 80 | policy = tfp.distributions.Categorical(probs=probs) 81 | 82 | return policy 83 | -------------------------------------------------------------------------------- /trfl/dpg_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Deterministic Policy Gradient (DPG) ops. 16 | 17 | These ops support training a value based agent on control problems with 18 | continuous action spaces. The agent's actions are assumed to be continuous 19 | vectors of size `action_dimension`. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import collections 27 | 28 | # Dependency imports 29 | import tensorflow.compat.v1 as tf 30 | from trfl import base_ops 31 | 32 | DPGExtra = collections.namedtuple("dpg_extra", ["q_max", "a_max", "dqda"]) 33 | 34 | 35 | def dpg(q_max, a_max, dqda_clipping=None, clip_norm=False, name="DpgLearning"): 36 | """Implements the Deterministic Policy Gradient (DPG) loss as a TensorFlow Op. 37 | 38 | This op implements the loss for the `actor`, the `critic` can instead be 39 | updated by minimizing the `value_ops.td_learning` loss. 40 | 41 | See "Deterministic Policy Gradient Algorithms" by Silver, Lever, Heess, 42 | Degris, Wierstra, Riedmiller (http://proceedings.mlr.press/v32/silver14.pdf). 43 | 44 | Args: 45 | q_max: Tensor holding Q-values generated by Q network with the input of 46 | (state, a_max) pair, shape `[B]`. 47 | a_max: Tensor holding the optimal action, shape `[B, action_dimension]`. 48 | dqda_clipping: `int` or `float`, clips the gradient dqda element-wise 49 | between `[-dqda_clipping, dqda_clipping]`. 50 | clip_norm: Whether to perform dqda clipping on the vector norm of the last 51 | dimension, or component wise (default). 52 | name: name to prefix ops created within this op. 53 | 54 | Returns: 55 | A namedtuple with fields: 56 | 57 | * `loss`: a tensor containing the batch of losses, shape `[B]`. 58 | * `extra`: a namedtuple with fields: 59 | * `q_max`: Tensor holding the optimal Q values, `[B]`. 60 | * `a_max`: Tensor holding the optimal action, `[B, action_dimension]`. 61 | * `dqda`: Tensor holding the derivative dq/da, `[B, action_dimension]`. 62 | 63 | Raises: 64 | ValueError: If `q_max` doesn't depend on `a_max` or if `dqda_clipping <= 0`. 65 | """ 66 | 67 | # DPG op. 68 | with tf.name_scope(name, values=[q_max, a_max]): 69 | 70 | # Calculate the gradient dq/da. 71 | dqda = tf.gradients([q_max], [a_max])[0] 72 | 73 | # Check that `q_max` depends on `a_max`. 74 | if dqda is None: 75 | raise ValueError("q_max needs to be a function of a_max") 76 | 77 | # Clipping the gradient dq/da. 78 | if dqda_clipping is not None: 79 | if dqda_clipping <= 0: 80 | raise ValueError("dqda_clipping should be bigger than 0, {} found" 81 | .format(dqda_clipping)) 82 | if clip_norm: 83 | dqda = tf.clip_by_norm(dqda, dqda_clipping, axes=-1) 84 | else: 85 | dqda = tf.clip_by_value(dqda, -1. * dqda_clipping, dqda_clipping) 86 | 87 | # Target_a ensures correct gradient calculated during backprop. 88 | target_a = dqda + a_max 89 | # Stop the gradient going through Q network when backprop. 90 | target_a = tf.stop_gradient(target_a) 91 | # Gradient only go through actor network. 
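    # Since target_a - a_max equals stop_gradient(dqda) in value, the gradient
    # of the loss below with respect to a_max is exactly -dqda, so minimising
    # it ascends the critic's Q estimate through the actor's parameters.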
92 | loss = 0.5 * tf.reduce_sum(tf.square(target_a - a_max), axis=-1) 93 | return base_ops.LossOutput( 94 | loss, DPGExtra(q_max=q_max, a_max=a_max, dqda=dqda)) 95 | -------------------------------------------------------------------------------- /trfl/periodic_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Periodic execution ops. 16 | 17 | It is very common in Reinforcement Learning for certain ops to only need to be 18 | executed periodically, for example: once every N agent steps. The ops below 19 | support this common use-case by wrapping a subgraph as a periodic op that only 20 | actually executes the underlying computation once every N evaluations of the op, 21 | behaving as a no-op in all other calls. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | # Dependency imports 29 | import numpy as np 30 | import tensorflow.compat.v1 as tf 31 | 32 | 33 | def periodically(body, period, counter=None, name="periodically"): 34 | """Periodically performs a tensorflow op. 35 | 36 | The body tensorflow op will be executed every `period` times the periodically 37 | op is executed. More specifically, with `n` the number of times the op has 38 | been executed, the body will be executed when `n` is a non zero positive 39 | multiple of `period` (i.e. there exist an integer `k > 0` such that 40 | `k * period == n`). 41 | 42 | If `period` is 0 or `None`, it would not perform any op and would return a 43 | `tf.no_op()`. 44 | 45 | Args: 46 | body: callable that returns the tensorflow op to be performed every time 47 | an internal counter is divisible by the period. The op must have no 48 | output (for example, a tf.group()). 49 | period: inverse frequency with which to perform the op. 50 | counter: an optional tensorflow variable to use as a counter relative to the 51 | period. It will be incremented per call and reset to 1 in every update. In 52 | order to ensure that `body` is run in the first count, initialize the 53 | counter at a value bigger than `period`. If not given, an internal counter 54 | will be created in the graph. (not that this is incompatible with 55 | Tensorflow 2 behavior) 56 | name: name of the variable_scope. 57 | 58 | Raises: 59 | TypeError: if body is not a callable. 60 | ValueError: if period is negative. 61 | 62 | Returns: 63 | An op that periodically performs the specified op. 
64 | """ 65 | if not callable(body): 66 | raise TypeError("body must be callable.") 67 | 68 | if period is None: 69 | return tf.no_op() 70 | 71 | elif isinstance(period, (int, float)): 72 | if period == 0: 73 | return tf.no_op() 74 | 75 | if period < 0: 76 | raise ValueError("period cannot be less than 0.") 77 | 78 | if period == 1: 79 | return body() 80 | 81 | if counter is None: 82 | with tf.variable_scope(None, default_name=name): 83 | counter = tf.get_variable( 84 | "counter", 85 | shape=[], 86 | dtype=tf.int64, 87 | trainable=False, 88 | initializer=tf.constant_initializer( 89 | np.iinfo(np.int64).max, dtype=tf.int64)) 90 | 91 | def _wrapped_body(): 92 | with tf.control_dependencies([body()]): 93 | # Done the deed, resets the counter. 94 | return counter.assign(1) 95 | 96 | update = tf.cond( 97 | tf.math.greater_equal(counter, tf.to_int64(period)), 98 | _wrapped_body, lambda: counter.assign_add(1)) 99 | 100 | return update 101 | -------------------------------------------------------------------------------- /trfl/policy_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for policy_ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | import tensorflow.compat.v1 as tf 23 | from trfl import policy_ops 24 | 25 | 26 | class EpsilonGreedyTest(tf.test.TestCase): 27 | 28 | def testTieBreaking(self): 29 | num_actions = 4 30 | # Given some action values that are all equal: 31 | action_values = [1.1] * num_actions 32 | epsilon = 0. 33 | 34 | # We expect the policy to be a uniform distribution. 35 | expected = [1 / num_actions] * num_actions 36 | 37 | result = policy_ops.epsilon_greedy(action_values, epsilon).probs 38 | with self.test_session() as sess: 39 | self.assertAllClose(sess.run(result), expected) 40 | 41 | def testGreedy(self): 42 | # Given some action values with one largest value: 43 | action_values = [0.5, 0.99, 0.9, 1., 0.1, -0.1, -100.] 44 | 45 | # And zero epsilon: 46 | epsilon = 0. 47 | 48 | # We expect a deterministic greedy policy that chooses one action. 49 | expected = [0., 0., 0., 1., 0., 0., 0.] 50 | 51 | result = policy_ops.epsilon_greedy(action_values, epsilon).probs 52 | with self.test_session() as sess: 53 | self.assertAllClose(sess.run(result), expected) 54 | 55 | def testDistribution(self): 56 | # Given some action values and non-zero epsilon: 57 | action_values = [0.9, 1., 0.9, 0.1, -0.6] 58 | epsilon = 0.1 59 | 60 | # We expect a distribution that concentrates the right probabilities. 
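    # Each of the 5 actions receives epsilon / 5 == 0.02 of dithering
    # probability; the greedy action additionally receives 1 - epsilon == 0.9,
    # for a total of 0.92.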
61 | expected = [0.02, 0.92, 0.02, 0.02, 0.02] 62 | 63 | result = policy_ops.epsilon_greedy(action_values, epsilon).probs 64 | with self.test_session() as sess: 65 | self.assertAllClose(sess.run(result), expected) 66 | 67 | def testBatched(self): 68 | # Given batched action values: 69 | action_values = [[1., 2., 3.], 70 | [4., 5., 6.], 71 | [6., 5., 4.], 72 | [3., 2., 1.]] 73 | epsilon = 0. 74 | 75 | # We expect batched probabilities. 76 | expected = [[0., 0., 1.], 77 | [0., 0., 1.], 78 | [1., 0., 0.], 79 | [1., 0., 0.]] 80 | 81 | result = policy_ops.epsilon_greedy(action_values, epsilon).probs 82 | with self.test_session() as sess: 83 | self.assertAllClose(sess.run(result), expected) 84 | 85 | def testFloat64(self): 86 | # Given action values that are float 64: 87 | action_values = tf.convert_to_tensor([1., 2., 4., 3.], dtype=tf.float64) 88 | epsilon = 0.1 89 | 90 | expected = [0.025, 0.025, 0.925, 0.025] 91 | 92 | result = policy_ops.epsilon_greedy(action_values, epsilon).probs 93 | with self.test_session() as sess: 94 | self.assertAllClose(sess.run(result), expected) 95 | 96 | def testLegalActionsMask(self): 97 | action_values = [0.9, 1., 0.9, 0.1, -0.6] 98 | legal_actions_mask = [0., 1., 1., 1., 1.] 99 | epsilon = 0.1 100 | 101 | expected = [0.00, 0.925, 0.025, 0.025, 0.025] 102 | 103 | result = policy_ops.epsilon_greedy(action_values, epsilon, 104 | legal_actions_mask).probs 105 | with self.test_session() as sess: 106 | self.assertAllClose(sess.run(result), expected) 107 | 108 | 109 | if __name__ == "__main__": 110 | tf.test.main() 111 | -------------------------------------------------------------------------------- /trfl/base_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Utilities for Reinforcement Learning ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | # Dependency imports 23 | 24 | from six.moves import zip 25 | import tensorflow.compat.v1 as tf 26 | 27 | LossOutput = collections.namedtuple("loss_output", ["loss", "extra"]) 28 | 29 | 30 | def best_effort_shape(tensor, with_rank=None): 31 | """Extract as much static shape information from a tensor as possible. 32 | 33 | Args: 34 | tensor: A `Tensor`. If `with_rank` is None, must have statically-known 35 | number of dimensions. 36 | with_rank: Optional, an integer number of dimensions to force the shape to 37 | be. Useful for tensors with no static shape information that must be 38 | of a particular rank. Default is None (number of dimensions must be 39 | statically known). 
40 | 41 | Returns: 42 | An iterable with length equal to the number of dimensions in `tensor`, 43 | containing integers for the dimensions with statically-known size, and 44 | scalar `Tensor`s for dimensions with size only known at run-time. 45 | 46 | Raises: 47 | ValueError: If `with_rank` is None and `tensor` does not have 48 | statically-known number of dimensions. 49 | """ 50 | tensor_shape = tf.TensorShape(tensor.shape) 51 | if with_rank: 52 | tensor_shape = tensor_shape.with_rank(with_rank) 53 | if tensor_shape.ndims is None: 54 | raise ValueError( 55 | "`tensor` does not have statically-known number of dimensions.") 56 | shape_list = tensor_shape.as_list() 57 | for idx, dim in enumerate(shape_list): 58 | if not dim: 59 | shape_list[idx] = tf.shape(tensor)[idx] 60 | return shape_list 61 | 62 | 63 | def assert_rank_and_shape_compatibility(tensors, rank): 64 | """Asserts that the tensors have the correct rank and compatible shapes. 65 | 66 | Shapes (of equal rank) are compatible if corresponding dimensions are all 67 | equal or unspecified. E.g. `[2, 3]` is compatible with all of `[2, 3]`, 68 | `[None, 3]`, `[2, None]` and `[None, None]`. 69 | 70 | Args: 71 | tensors: List of tensors. 72 | rank: A scalar specifying the rank that the tensors passed need to have. 73 | 74 | Raises: 75 | ValueError: If the list of tensors is empty or fail the rank and mutual 76 | compatibility asserts. 77 | """ 78 | if not tensors: 79 | raise ValueError("List of tensors should be non-empty.") 80 | 81 | union_of_shapes = tf.TensorShape(None) 82 | for tensor in tensors: 83 | tensor_shape = tf.TensorShape(tensor.shape) 84 | tensor_shape.assert_has_rank(rank) 85 | union_of_shapes = union_of_shapes.merge_with(tensor_shape) 86 | 87 | 88 | def wrap_rank_shape_assert(tensors_list, expected_ranks, op_name): 89 | try: 90 | for tensors, rank in zip(tensors_list, expected_ranks): 91 | assert_rank_and_shape_compatibility(tensors, rank) 92 | except ValueError as e: 93 | error_message = ("{}: Error in rank and/or " 94 | "compatibility check, {}".format(op_name, e)) 95 | tf.logging.error(error_message) 96 | raise ValueError(error_message) 97 | 98 | 99 | def assert_arg_bounded(value, min_value, max_value, op_name, arg_name): 100 | if not min_value <= value <= max_value: 101 | raise ValueError( 102 | (op_name + ": " + arg_name + " has to lie in " + 103 | "[" + str(min_value) + ", " + str(max_value) + "].")) 104 | -------------------------------------------------------------------------------- /trfl/indexing_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Indexing ops. 
16 | 17 | These ops support indexing the 2D tensors representing batches of values 18 | (shape: [B, dim]) or 3D tensors representing batches of sequences 19 | of values (shape: [T, B, dim]. `T` is the length of the rollout, `B` is the 20 | batch size, and `dim` the size of the dimension that must be indexed. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | # Dependency imports 28 | import tensorflow.compat.v1 as tf 29 | 30 | 31 | def assert_compatible_shapes(value_shape, index_shape): 32 | """Check shapes of the indices and the tensor to be indexed. 33 | 34 | If all input shapes are known statically, obtain shapes of arguments and 35 | perform compatibility checks. Otherwise, print a warning. The only check 36 | we cannot perform statically (and do not attempt elsewhere) is making 37 | sure that each action index in actions is in [0, num_actions). 38 | 39 | Args: 40 | value_shape: static shape of the values. 41 | index_shape: static shape of the indices. 42 | """ 43 | # note: rank-0 "[]" TensorShape is still True. 44 | if value_shape and index_shape: 45 | try: 46 | msg = ("Shapes of \"values\" and \"indices\" do not correspond to " 47 | "minibatch (2-D) or sequence-minibatch (3-D) indexing") 48 | assert (value_shape.ndims, index_shape.ndims) in [(2, 1), (3, 2)], msg 49 | msg = ("\"values\" and \"indices\" have incompatible shapes of {} " 50 | "and {}, respectively").format(value_shape, index_shape) 51 | assert value_shape[:-1].is_compatible_with(index_shape), msg 52 | except AssertionError as e: 53 | raise ValueError(e) # Convert AssertionError to ValueError. 54 | 55 | else: # No shape information is known ahead of time. 56 | tf.logging.warning( 57 | "indexing function cannot get shapes for tensors \"values\" and " 58 | "\"indices\" at construction time, and so can't check that their " 59 | "shapes are valid or compatible. Incorrect indexing may occur at " 60 | "runtime without error!") 61 | 62 | 63 | def batched_index(values, indices, keepdims=None): 64 | """Equivalent to `values[:, indices]`. 65 | 66 | Performs indexing on batches and sequence-batches by reducing over 67 | zero-masked values. Compared to indexing with `tf.gather` this approach is 68 | more general and TPU-friendly, but may be less efficient if `num_values` 69 | is large. It works with tensors whose shapes are unspecified or 70 | partially-specified, but this op will only do shape checking on shape 71 | information available at graph construction time. When complete shape 72 | information is absent, certain shape incompatibilities may not be detected at 73 | runtime! See `indexing_ops_test` for detailed examples. 74 | 75 | Args: 76 | values: tensor of shape `[B, num_values]` or `[T, B, num_values]` 77 | indices: tensor of shape `[B]` or `[T, B]` containing indices. 78 | keepdims: If `True`, the returned tensor will have an added 1 dimension at 79 | the end (e.g. `[B, 1]` or `[T, B, 1]`). 80 | 81 | Returns: 82 | Tensor of shape `[B]` or `[T, B]` containing values for the given indices. 83 | 84 | Raises: ValueError if values and indices have sizes that are known 85 | statically (i.e. during graph construction), and those sizes are not 86 | compatible (see shape descriptions in Args list above). 
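
  For example, in the 2-D case, `batched_index(values=[[1., 2.], [3., 4.]],
  indices=[0, 1])` evaluates to `[1., 4.]`.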
87 | """ 88 | with tf.name_scope("batch_indexing", values=[values, indices]): 89 | values = tf.convert_to_tensor(values) 90 | indices = tf.convert_to_tensor(indices) 91 | assert_compatible_shapes(values.shape, indices.shape) 92 | 93 | one_hot_indices = tf.one_hot( 94 | indices, tf.shape(values)[-1], dtype=values.dtype) 95 | return tf.reduce_sum(values * one_hot_indices, axis=-1, keepdims=keepdims) 96 | -------------------------------------------------------------------------------- /trfl/continuous_retrace_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for continuous_retrace_ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | 23 | import numpy as np 24 | import six 25 | from six.moves import range 26 | import tensorflow.compat.v1 as tf 27 | 28 | from trfl import continuous_retrace_ops 29 | 30 | 31 | def _shaped_arange(*shape): 32 | """Runs np.arange, converts to float and reshapes.""" 33 | return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) 34 | 35 | 36 | def _ground_truth_calculation(discounts, log_rhos, rewards, q_values, values, 37 | bootstrap_value, lambda_): 38 | """Calculates the ground truth for Retrace in python/numpy.""" 39 | qs = [] 40 | seq_len = len(discounts) 41 | rhos = np.exp(log_rhos) 42 | cs = np.minimum(rhos, 1.0) 43 | cs *= lambda_ 44 | # This is a very inefficient way to calculate the Retrace ground truth. 45 | values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0) 46 | for s in range(seq_len): 47 | q_s = np.copy(q_values[s]) # Very important copy... 
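    # q_s accumulates the Retrace target: the TD error at every step t >= s,
    # weighted by products of the discounts and of the truncated,
    # lambda-scaled importance weights cs between s and t.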
48 | delta = rewards[s] + discounts[s] * values_t_plus_1[s + 1] - q_values[s] 49 | q_s += delta 50 | for t in range(s + 1, seq_len): 51 | q_s += ( 52 | np.prod(discounts[s:t], axis=0) * np.prod(cs[s + 1:t + 1], axis=0) * 53 | (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - q_values[t])) 54 | qs.append(q_s) 55 | qs = np.stack(qs, axis=0) 56 | return qs 57 | 58 | 59 | class ContinuousRetraceTest(tf.test.TestCase): 60 | 61 | def testSingleElem(self): 62 | """Tests Retrace with a single element batch and lambda set to 1.0.""" 63 | batch_size = 1 64 | lambda_ = 1.0 65 | self._main_test(batch_size, lambda_) 66 | 67 | def testLargerBatch(self): 68 | """Tests Retrace with a larger batch.""" 69 | batch_size = 2 70 | lambda_ = 1.0 71 | self._main_test(batch_size, lambda_) 72 | 73 | def testLowerLambda(self): 74 | """Tests Retrace with a lower lambda.""" 75 | batch_size = 2 76 | lambda_ = 0.5 77 | self._main_test(batch_size, lambda_) 78 | 79 | def _main_test(self, batch_size, lambda_): 80 | """Tests Retrace against ground truth data calculated in python.""" 81 | seq_len = 5 82 | # Create log_rhos such that rho will span from near-zero to above the 83 | # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5), 84 | # so that rho is in approx [0.08, 12.2). 85 | log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len) 86 | log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5). 87 | values = { 88 | "discounts": 89 | np.array( # T, B where B_i: [0.9 / (i+1)] * T 90 | [[0.9 / (b + 1) 91 | for b in range(batch_size)] 92 | for _ in range(seq_len)]), 93 | "rewards": 94 | _shaped_arange(seq_len, batch_size), 95 | "q_values": 96 | _shaped_arange(seq_len, batch_size) / batch_size, 97 | "values": 98 | _shaped_arange(seq_len, batch_size) / batch_size, 99 | "bootstrap_value": 100 | _shaped_arange(batch_size) + 1.0, # B 101 | "log_rhos": 102 | log_rhos 103 | } 104 | placeholders = { 105 | key: tf.placeholder(tf.float32, shape=val.shape) 106 | for key, val in six.iteritems(values) 107 | } 108 | placeholders = { 109 | k: tf.placeholder(dtype=p.dtype, shape=[None] * len(p.shape)) 110 | for k, p in placeholders.items() 111 | } 112 | 113 | retrace_returns = continuous_retrace_ops.retrace_from_importance_weights( 114 | lambda_=lambda_, **placeholders) 115 | 116 | feed_dict = {placeholders[k]: v for k, v in values.items()} 117 | with self.test_session() as sess: 118 | retrace_outputvalues = sess.run(retrace_returns, feed_dict=feed_dict) 119 | 120 | ground_truth_data = _ground_truth_calculation(lambda_=lambda_, **values) 121 | 122 | self.assertAllClose(ground_truth_data, retrace_outputvalues.qs) 123 | 124 | 125 | if __name__ == "__main__": 126 | tf.test.main() 127 | -------------------------------------------------------------------------------- /trfl/dpg_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for dpg_ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | import numpy as np 23 | import tensorflow.compat.v1 as tf 24 | from trfl import dpg_ops 25 | 26 | 27 | class DpgTest(tf.test.TestCase): 28 | """Tests for DpgLearning. 29 | """ 30 | 31 | def setUp(self): 32 | """Sets up test scenario. 33 | 34 | a_tm1_max = s_tm1 * w_s + b_s 35 | q_tm1_max = a_tm1_max * w + b 36 | """ 37 | super(DpgTest, self).setUp() 38 | self.s_tm1 = tf.constant([[0, 1, 0], [1, 1, 2]], dtype=tf.float32) 39 | self.w_s = tf.Variable(tf.random_normal([3, 2]), dtype=tf.float32) 40 | self.b_s = tf.Variable(tf.zeros([2]), dtype=tf.float32) 41 | self.a_tm1_max = tf.matmul(self.s_tm1, self.w_s) + self.b_s 42 | self.w = tf.Variable(tf.random_normal([2, 1]), dtype=tf.float32) 43 | self.b = tf.Variable(tf.zeros([1]), dtype=tf.float32) 44 | self.q_tm1_max = tf.matmul(self.a_tm1_max, self.w) + self.b 45 | self.loss, self.dpg_extra = dpg_ops.dpg(self.q_tm1_max, self.a_tm1_max) 46 | self.batch_size = self.a_tm1_max.get_shape()[0] 47 | 48 | def testDpgNoGradient(self): 49 | """Test case: q_tm1_max does not depend on a_tm1_max => exception raised. 50 | """ 51 | with self.test_session(): 52 | a_tm1_max = tf.constant([[0, 1, 0], [1, 1, 2]]) 53 | q_tm1_max = tf.constant([[1], [0]]) 54 | self.assertRaises(ValueError, dpg_ops.dpg, q_tm1_max, a_tm1_max) 55 | 56 | def testDpgDqda(self): 57 | """Tests the gradient qd/qa produced by the DPGLearner is correct.""" 58 | with self.test_session() as sess: 59 | sess.run(tf.global_variables_initializer()) 60 | value_grad = np.transpose(self.w.eval())[0] 61 | for i in range(int(self.batch_size)): 62 | self.assertAllClose(self.dpg_extra.dqda.eval()[i], value_grad) 63 | 64 | def testDpgGradient(self): 65 | """Gradient of loss w.r.t. actor network parameter w_s is correct.""" 66 | with self.test_session() as sess: 67 | weight_gradient = tf.gradients(self.loss, self.w_s) 68 | sess.run(tf.global_variables_initializer()) 69 | value_dpg_gradient, value_s_tm1, value_w = sess.run( 70 | [weight_gradient[0], self.s_tm1, self.w]) 71 | true_grad = self.calculateTrueGradient(value_w, value_s_tm1) 72 | self.assertAllClose(value_dpg_gradient, true_grad) 73 | 74 | def testDpgNoOtherGradients(self): 75 | """No gradient of loss w.r.t. parameters other than that of actor network. 
76 | """ 77 | with self.test_session(): 78 | gradients = tf.gradients([self.loss], [self.q_tm1_max, self.w, self.b]) 79 | self.assertListEqual(gradients, [None] * len(gradients)) 80 | 81 | def testDpgDqdaClippingError(self): 82 | self.assertRaises( 83 | ValueError, dpg_ops.dpg, 84 | self.q_tm1_max, self.a_tm1_max, dqda_clipping=-10) 85 | 86 | def testDpgGradientClipping(self): 87 | """Tests the gradient qd/qa are clipped.""" 88 | _, dpg_extra = dpg_ops.dpg( 89 | self.q_tm1_max, self.a_tm1_max, dqda_clipping=0.01) 90 | with self.test_session() as sess: 91 | sess.run(tf.global_variables_initializer()) 92 | value_grad = np.transpose(self.w.eval())[0] 93 | for i in range(int(self.batch_size)): 94 | self.assertAllClose(dpg_extra.dqda.eval()[i], 95 | np.clip(value_grad, -0.01, 0.01)) 96 | self.assertTrue(np.greater(np.absolute(value_grad), 0.01).any()) 97 | 98 | def testDpgGradientNormClipping(self): 99 | """Tests the gradient qd/qa are clipped using norm clipping.""" 100 | _, dpg_extra = dpg_ops.dpg( 101 | self.q_tm1_max, self.a_tm1_max, dqda_clipping=0.01, clip_norm=True) 102 | with self.test_session() as sess: 103 | sess.run(tf.global_variables_initializer()) 104 | for i in range(int(self.batch_size)): 105 | self.assertAllClose(np.linalg.norm(dpg_extra.dqda.eval()[i]), 0.01) 106 | 107 | def testLossShape(self): 108 | self.assertEqual(self.loss.shape.as_list(), [self.batch_size]) 109 | 110 | def calculateTrueGradient(self, value_w, value_s_tm1): 111 | """Calculates the true gradient over the batch. 112 | 113 | sum_k dq/dw_s = sum_k dq/da * da/dw_s 114 | = w * sum_k da/dw_s 115 | 116 | Args: 117 | value_w: numpy.ndarray containing weights of the linear layer. 118 | value_s_tm1: state representation. 119 | 120 | Returns: 121 | The true_gradient of the test case. 122 | """ 123 | dadws = np.zeros((value_w.shape[0], 124 | np.product(self.w_s.get_shape().as_list()))) 125 | for i in range(self.batch_size): 126 | dadws += np.vstack((np.hstack((value_s_tm1[i], np.zeros(3))), 127 | np.hstack((np.zeros(3), value_s_tm1[i])))) 128 | true_grad = np.dot(np.transpose(value_w), dadws) 129 | true_grad = -np.transpose(np.reshape( 130 | true_grad, self.w_s.get_shape().as_list()[::-1])) 131 | return true_grad 132 | 133 | 134 | if __name__ == '__main__': 135 | tf.test.main() 136 | -------------------------------------------------------------------------------- /trfl/target_update_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tensorflow ops for updating target networks. 16 | 17 | Tensorflow ops that are used to update a target network from a source network. 18 | This is used in agents such as DQN or DPG, which use a target network that 19 | changes more slowly than the online network, in order to improve stability. 
20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | # Dependency imports 27 | import tensorflow.compat.v1 as tf 28 | from trfl import periodic_ops 29 | 30 | 31 | def update_target_variables(target_variables, 32 | source_variables, 33 | tau=1.0, 34 | use_locking=False, 35 | name="update_target_variables"): 36 | """Returns an op to update a list of target variables from source variables. 37 | 38 | The update rule is: 39 | `target_variable = (1 - tau) * target_variable + tau * source_variable`. 40 | 41 | Args: 42 | target_variables: a list of the variables to be updated. 43 | source_variables: a list of the variables used for the update. 44 | tau: weight used to gate the update. The permitted range is 0 < tau <= 1, 45 | with small tau representing an incremental update, and tau == 1 46 | representing a full update (that is, a straight copy). 47 | use_locking: use `tf.Variable.assign`'s locking option when assigning 48 | source variable values to target variables. 49 | name: sets the `name_scope` for this op. 50 | 51 | Raises: 52 | TypeError: when tau is not a Python float 53 | ValueError: when tau is out of range, or the source and target variables 54 | have different numbers or shapes. 55 | 56 | Returns: 57 | An op that executes all the variable updates. 58 | """ 59 | if not isinstance(tau, float) and not tf.is_tensor(tau): 60 | raise TypeError("Tau has wrong type (should be float) {}".format(tau)) 61 | if not tf.is_tensor(tau) and not 0.0 < tau <= 1.0: 62 | raise ValueError("Invalid parameter tau {}".format(tau)) 63 | if len(target_variables) != len(source_variables): 64 | raise ValueError("Number of target variables {} is not the same as " 65 | "number of source variables {}".format( 66 | len(target_variables), len(source_variables))) 67 | 68 | same_shape = all(trg.get_shape() == src.get_shape() 69 | for trg, src in zip(target_variables, source_variables)) 70 | if not same_shape: 71 | raise ValueError("Target variables don't have the same shape as source " 72 | "variables.") 73 | 74 | def update_op(target_variable, source_variable, tau): 75 | if tau == 1.0: 76 | return target_variable.assign(source_variable, use_locking) 77 | else: 78 | return target_variable.assign( 79 | tau * source_variable + (1.0 - tau) * target_variable, use_locking) 80 | 81 | with tf.name_scope(name, values=target_variables + source_variables): 82 | update_ops = [update_op(target_var, source_var, tau) 83 | for target_var, source_var 84 | in zip(target_variables, source_variables)] 85 | return tf.group(name="update_all_variables", *update_ops) 86 | 87 | 88 | def periodic_target_update(target_variables, 89 | source_variables, 90 | update_period, 91 | tau=1.0, 92 | use_locking=False, 93 | counter=None, 94 | name="periodic_target_update"): 95 | """Returns an op to periodically update a list of target variables. 96 | 97 | The `update_target_variables` op is executed every `update_period` 98 | executions of the `periodic_target_update` op. 99 | 100 | The update rule is: 101 | `target_variable = (1 - tau) * target_variable + tau * source_variable`. 102 | 103 | Args: 104 | target_variables: a list of the variables to be updated. 105 | source_variables: a list of the variables used for the update. 106 | update_period: inverse frequency with which to apply the update. 107 | tau: weight used to gate the update. 
The permitted range is 0 < tau <= 1, 108 | with small tau representing an incremental update, and tau == 1 109 | representing a full update (that is, a straight copy). 110 | use_locking: use `tf.variable.Assign`'s locking option when assigning 111 | source variable values to target variables. 112 | counter: an optional tensorflow variable to use as a counter relative to 113 | `update_period`, which be passed to `periodic_ops.periodically`. See 114 | description in `periodic_ops.periodically` for details. 115 | name: sets the `name_scope` for this op. 116 | 117 | Returns: 118 | An op that periodically updates `target_variables` with `source_variables`. 119 | """ 120 | 121 | def update_op(): 122 | return update_target_variables( 123 | target_variables, source_variables, tau, use_locking) 124 | 125 | with tf.name_scope(name, values=target_variables + source_variables): 126 | return periodic_ops.periodically(update_op, update_period, counter=counter) 127 | -------------------------------------------------------------------------------- /trfl/target_update_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for target_update_ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | from absl.testing import parameterized 23 | import tensorflow.compat.v1 as tf 24 | from trfl import target_update_ops 25 | 26 | 27 | class UpdateTargetVariablesTest(tf.test.TestCase, parameterized.TestCase): 28 | """Tests function update_target_variables.""" 29 | 30 | @parameterized.parameters({'use_locking': True}, {'use_locking': False}) 31 | def testFullUpdate(self, use_locking): 32 | """Tests full update of the target variables from the source variables.""" 33 | target_variables = [ 34 | tf.Variable(tf.random_normal(shape=[1, 2])), 35 | tf.Variable(tf.random_normal(shape=[3, 4])), 36 | ] 37 | source_variables = [ 38 | tf.Variable(tf.random_normal(shape=[1, 2])), 39 | tf.Variable(tf.random_normal(shape=[3, 4])), 40 | ] 41 | updated = target_update_ops.update_target_variables( 42 | target_variables, source_variables, use_locking=use_locking) 43 | 44 | # Collect all the tensors and ops we want to evaluate in the session. 45 | vars_ops = target_variables + source_variables 46 | 47 | with self.test_session() as sess: 48 | sess.run(tf.global_variables_initializer()) 49 | sess.run(updated) 50 | results = sess.run(vars_ops) 51 | # First target variable is updated with first source variable. 52 | self.assertAllClose(results[0], results[2]) 53 | # Second target variable is updated with second source variable. 
54 | self.assertAllClose(results[1], results[3]) 55 | 56 | @parameterized.parameters({'use_locking': True}, {'use_locking': False}) 57 | def testIncrementalUpdate(self, use_locking): 58 | """Tests incremental update of the target variables.""" 59 | target_variables = [tf.Variable(tf.random_normal(shape=[1, 2]))] 60 | source_variables = [tf.Variable(tf.random_normal(shape=[1, 2]))] 61 | updated = target_update_ops.update_target_variables( 62 | target_variables, source_variables, tau=0.1, use_locking=use_locking) 63 | 64 | with self.test_session() as sess: 65 | sess.run(tf.global_variables_initializer()) 66 | before_assign = sess.run(target_variables[0]) 67 | sess.run(updated) 68 | results = sess.run([target_variables[0], source_variables[0]]) 69 | self.assertAllClose(results[0], 0.1 * results[1] + 0.9 * before_assign) 70 | 71 | def testIncompatibleLength(self): 72 | """Tests error when variable lists have unequal length.""" 73 | with self.test_session(): 74 | target_variables = [tf.Variable(tf.random_normal(shape=[1, 2]))] 75 | source_variables = [ 76 | tf.Variable(tf.random_normal(shape=[1, 2])), 77 | tf.Variable(tf.random_normal(shape=[3, 4])), 78 | ] 79 | self.assertRaises(ValueError, target_update_ops.update_target_variables, 80 | target_variables, source_variables) 81 | 82 | def testIncompatibleShape(self): 83 | """Tests error when variable lists have unequal shapes.""" 84 | with self.test_session(): 85 | target_variables = [ 86 | tf.Variable(tf.random_normal(shape=[1, 2])), 87 | tf.Variable(tf.random_normal(shape=[1, 2])), 88 | ] 89 | source_variables = [ 90 | tf.Variable(tf.random_normal(shape=[1, 2])), 91 | tf.Variable(tf.random_normal(shape=[3, 4])), 92 | ] 93 | self.assertRaises(ValueError, target_update_ops.update_target_variables, 94 | target_variables, source_variables) 95 | 96 | def testInvalidTypeTau(self): 97 | """Tests error when tau has wrong type.""" 98 | target_variables = [tf.Variable(tf.random_normal(shape=[1, 2]))] 99 | source_variables = [tf.Variable(tf.random_normal(shape=[1, 2]))] 100 | self.assertRaises(TypeError, target_update_ops.update_target_variables, 101 | target_variables, source_variables, 1) 102 | 103 | def testInvalidRangeTau(self): 104 | """Tests error when tau is outside permitted range.""" 105 | target_variables = [tf.Variable(tf.random_normal(shape=[1, 2]))] 106 | source_variables = [tf.Variable(tf.random_normal(shape=[1, 2]))] 107 | self.assertRaises(ValueError, target_update_ops.update_target_variables, 108 | target_variables, source_variables, -0.1) 109 | self.assertRaises(ValueError, target_update_ops.update_target_variables, 110 | target_variables, source_variables, 1.1) 111 | 112 | 113 | class PeriodicTargetUpdateTest(tf.test.TestCase, parameterized.TestCase): 114 | """Tests function period_target_update.""" 115 | 116 | @parameterized.parameters( 117 | {'use_locking': True, 'update_period': 1}, 118 | {'use_locking': False, 'update_period': 1}, 119 | {'use_locking': True, 'update_period': 3}, 120 | {'use_locking': False, 'update_period': 3} 121 | ) 122 | def testPeriodicTargetUpdate(self, use_locking, update_period): 123 | """Tests that the simple success case works as expected. 124 | 125 | This is an integration test. The periodically and update parts are 126 | unit-tested in the preceding. 127 | 128 | Args: 129 | use_locking: value for `periodic_target_update`'s `use_locking` argument. 130 | update_period: how often an update should happen. 
131 | """ 132 | target_variables = [tf.Variable(tf.zeros([1, 2]))] 133 | source_variables = [tf.Variable(tf.random_normal([1, 2]))] 134 | increment = tf.ones([1, 2]) 135 | 136 | update_source_op = tf.assign_add(source_variables[0], increment) 137 | updated = target_update_ops.periodic_target_update( 138 | target_variables, 139 | source_variables, 140 | update_period=update_period, 141 | use_locking=use_locking) 142 | 143 | with self.test_session() as sess: 144 | sess.run(tf.global_variables_initializer()) 145 | 146 | for step in range(3 * update_period): 147 | sess.run(update_source_op) 148 | sess.run(updated) 149 | targets, sources = sess.run([target_variables, source_variables]) 150 | 151 | if step % update_period == 0: 152 | self.assertAllClose(targets, sources) 153 | else: 154 | self.assertNotAllClose(targets, sources) 155 | 156 | 157 | if __name__ == '__main__': 158 | tf.test.main() 159 | 160 | -------------------------------------------------------------------------------- /trfl/distribution_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for distribution_ops.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import itertools 21 | # Dependency imports 22 | from absl.testing import parameterized 23 | import tensorflow.compat.v1 as tf 24 | import tensorflow_probability as tfp 25 | from trfl import distribution_ops 26 | 27 | 28 | l2_project = distribution_ops.l2_project 29 | _MULTIVARIATE_GAUSSIAN_TYPES = [ 30 | tfp.distributions.MultivariateNormalDiagPlusLowRank, 31 | tfp.distributions.MultivariateNormalDiag, 32 | tfp.distributions.MultivariateNormalTriL, 33 | tfp.distributions.MultivariateNormalFullCovariance 34 | ] 35 | 36 | 37 | class FactorisedKLGaussianTest(tf.test.TestCase, parameterized.TestCase): 38 | 39 | def _create_gaussian(self, gaussian_type): 40 | mu = tf.random_normal([3]) 41 | if gaussian_type == tfp.distributions.MultivariateNormalDiag: 42 | scale_diag = tf.random_normal([3]) 43 | dist = tfp.distributions.MultivariateNormalDiag(mu, scale_diag) 44 | if gaussian_type == tfp.distributions.MultivariateNormalDiagPlusLowRank: 45 | scale_diag = tf.random_normal([3]) 46 | perturb_factor = tf.random_normal([3, 2]) 47 | scale_perturb_diag = tf.random_normal([2]) 48 | dist = tfp.distributions.MultivariateNormalDiagPlusLowRank( 49 | mu, 50 | scale_diag, 51 | scale_perturb_factor=perturb_factor, 52 | scale_perturb_diag=scale_perturb_diag) 53 | if gaussian_type == tfp.distributions.MultivariateNormalTriL: 54 | cov = tf.random_uniform([3, 3], minval=0, maxval=1.0) 55 | # Create a PSD matrix. 
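      # Symmetrising the random matrix and adding 3 * identity makes it
      # strictly diagonally dominant, hence positive definite and safe to
      # factorise with `tf.cholesky` below.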
56 | cov = 0.5 * (cov + tf.transpose(cov)) + 3 * tf.eye(3) 57 | scale = tf.cholesky(cov) 58 | dist = tfp.distributions.MultivariateNormalTriL(mu, scale) 59 | if gaussian_type == tfp.distributions.MultivariateNormalFullCovariance: 60 | cov = tf.random_uniform([3, 3], minval=0, maxval=1.0) 61 | # Create a PSD matrix. 62 | cov = 0.5 * (cov + tf.transpose(cov)) + 3 * tf.eye(3) 63 | dist = tfp.distributions.MultivariateNormalFullCovariance(mu, cov) 64 | return (dist, mu, dist.covariance()) 65 | 66 | @parameterized.parameters( 67 | itertools.product(_MULTIVARIATE_GAUSSIAN_TYPES, 68 | _MULTIVARIATE_GAUSSIAN_TYPES)) 69 | def testFactorisedKLGaussian(self, dist1_type, dist2_type): 70 | """Tests that the factorised KL terms sum up to the true KL.""" 71 | dist1, dist1_mean, dist1_cov = self._create_gaussian(dist1_type) 72 | dist2, dist2_mean, dist2_cov = self._create_gaussian(dist2_type) 73 | both_diagonal = _is_diagonal(dist1.scale) and _is_diagonal(dist2.scale) 74 | if both_diagonal: 75 | dist1_cov = dist1.parameters['scale_diag'] 76 | dist2_cov = dist2.parameters['scale_diag'] 77 | kl = tfp.distributions.kl_divergence(dist1, dist2) 78 | kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian( 79 | dist1_mean, 80 | dist1_cov, 81 | dist2_mean, 82 | dist2_cov, 83 | both_diagonal=both_diagonal) 84 | with self.test_session() as sess: 85 | sess.run(tf.global_variables_initializer()) 86 | actual_kl, kl_mean_np, kl_cov_np = sess.run([kl, kl_mean, kl_cov]) 87 | self.assertAllClose(actual_kl, kl_mean_np + kl_cov_np, rtol=1e-4) 88 | 89 | def testShapeAssertion(self): 90 | dist_type = tfp.distributions.MultivariateNormalDiag 91 | _, dist1_mean, dist1_cov = self._create_gaussian(dist_type) 92 | _, dist2_mean, dist2_cov = self._create_gaussian(dist_type) 93 | shape_error_regexp = 'Shape (.*) must have rank [0-9]+' 94 | with self.assertRaisesRegexp(ValueError, shape_error_regexp): 95 | distribution_ops.factorised_kl_gaussian( 96 | dist1_mean, dist1_cov, dist2_mean, dist2_cov, both_diagonal=True) 97 | 98 | def testConsistentGradientsBothDiagonal(self): 99 | dist_type = tfp.distributions.MultivariateNormalDiag 100 | dist1, dist1_mean, _ = self._create_gaussian(dist_type) 101 | dist2, dist2_mean, _ = self._create_gaussian(dist_type) 102 | 103 | kl = tfp.distributions.kl_divergence(dist1, dist2) 104 | dist1_scale = dist1.parameters['scale_diag'] 105 | dist2_scale = dist2.parameters['scale_diag'] 106 | kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian( 107 | dist1_mean, dist1_scale, dist2_mean, dist2_scale, both_diagonal=True) 108 | 109 | dist_params = [dist1_mean, dist2_mean, dist1_scale, dist2_scale] 110 | actual_kl_gradients = tf.gradients(kl, dist_params) 111 | factorised_kl_gradients = tf.gradients(kl_mean + kl_cov, dist_params) 112 | 113 | # Check that no gradients flow into the mean terms from `kl_cov` and 114 | # vice-versa. 
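    # (`tf.gradients` returns `None` for variables that the target does not
    # depend on, which is exactly what these assertions expect.)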
115 | gradients = tf.gradients(kl_mean, [dist1_scale]) 116 | self.assertListEqual(gradients, [None]) 117 | gradients = tf.gradients(kl_cov, [dist1_mean, dist2_mean]) 118 | self.assertListEqual(gradients, [None, None]) 119 | 120 | with self.test_session() as sess: 121 | np_actual_kl, np_factorised_kl = sess.run( 122 | [actual_kl_gradients, factorised_kl_gradients]) 123 | self.assertAllClose(np_actual_kl, np_factorised_kl) 124 | 125 | def testConsistentGradientsFullCovariance(self): 126 | dist_type = tfp.distributions.MultivariateNormalFullCovariance 127 | dist1, dist1_mean, dist1_cov = self._create_gaussian(dist_type) 128 | dist2, dist2_mean, dist2_cov = self._create_gaussian(dist_type) 129 | 130 | kl = tfp.distributions.kl_divergence(dist1, dist2) 131 | kl_mean, kl_cov = distribution_ops.factorised_kl_gaussian( 132 | dist1_mean, dist1_cov, dist2_mean, dist2_cov, both_diagonal=False) 133 | 134 | dist1_cov = dist1.parameters['covariance_matrix'] 135 | dist2_cov = dist2.parameters['covariance_matrix'] 136 | dist_params = [ 137 | dist1_mean, 138 | dist2_mean, 139 | dist1_cov, 140 | dist2_cov, 141 | ] 142 | actual_kl_gradients = tf.gradients(kl, dist_params) 143 | factorised_kl_gradients = tf.gradients(kl_mean + kl_cov, dist_params) 144 | 145 | # Check that no gradients flow into the mean terms from `kl_cov` and 146 | # vice-versa. 147 | gradients = tf.gradients(kl_mean, [dist1_cov]) 148 | self.assertListEqual(gradients, [None]) 149 | gradients = tf.gradients(kl_cov, [dist1_mean, dist2_mean]) 150 | self.assertListEqual(gradients, [None, None]) 151 | 152 | with self.test_session() as sess: 153 | np_actual_kl, np_factorised_kl = sess.run( 154 | [actual_kl_gradients, factorised_kl_gradients]) 155 | self.assertAllClose(np_actual_kl, np_factorised_kl) 156 | 157 | 158 | # Check for diagonal Gaussian distributions. Based on the definition in 159 | # tensorflow_probability/python/distributions/mvn_linear_operator.py 160 | def _is_diagonal(x): 161 | """Helper to identify if `LinearOperator` has only a diagonal component.""" 162 | return (isinstance(x, tf.linalg.LinearOperatorIdentity) or 163 | isinstance(x, tf.linalg.LinearOperatorScaledIdentity) or 164 | isinstance(x, tf.linalg.LinearOperatorDiag)) 165 | 166 | 167 | if __name__ == '__main__': 168 | tf.test.main() 169 | -------------------------------------------------------------------------------- /trfl/distribution_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """TensorFlow ops for various distribution projection operations. 16 | 17 | All ops support multidimensional tensors. All dimensions except for the last 18 | one can be considered as batch dimensions. They are processed in parallel 19 | and are fully independent. The last dimension represents the number of bins. 
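
For instance, `l2_project` (defined below) re-bins a batch of categorical
distributions onto a new support. A minimal sketch, with made-up values:

    import tensorflow.compat.v1 as tf
    from trfl import distribution_ops

    z_p = tf.constant([[0., 1., 2.]])          # source support, [B=1, Kp=3].
    p = tf.constant([[0.2, 0.5, 0.3]])         # probabilities,  [B=1, Kp=3].
    z_q = tf.constant([0., 0.5, 1., 1.5, 2.])  # target support, [Kq=5].
    p_q = distribution_ops.l2_project(z_p, p, z_q)  # projected, [B=1, Kq=5].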
20 | 21 | The op supports broadcasting across all dimensions except for the last one. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | import tensorflow.compat.v1 as tf 29 | import tensorflow_probability as tfp 30 | 31 | 32 | def l2_project(z_p, p, z_q): 33 | """Projects distribution (z_p, p) onto support z_q under L2-metric over CDFs. 34 | 35 | The supports z_p and z_q are specified as tensors of distinct atoms (given 36 | in ascending order). 37 | 38 | Let Kq be len(z_q) and Kp be len(z_p). This projection works for any 39 | support z_q, in particular Kq need not be equal to Kp. 40 | 41 | Args: 42 | z_p: Tensor holding support of distribution p, shape `[batch_size, Kp]`. 43 | p: Tensor holding probability values p(z_p[i]), shape `[batch_size, Kp]`. 44 | z_q: Tensor holding support to project onto, shape `[Kq]`. 45 | 46 | Returns: 47 | Projection of (z_p, p) onto support z_q under Cramer distance. 48 | """ 49 | # Broadcasting of tensors is used extensively in the code below. To avoid 50 | # accidental broadcasting along unintended dimensions, tensors are defensively 51 | # reshaped to have equal number of dimensions (3) throughout and intended 52 | # shapes are indicated alongside tensor definitions. To reduce verbosity, 53 | # extra dimensions of size 1 are inserted by indexing with `None` instead of 54 | # `tf.expand_dims()` (e.g., `x[:, None, :]` reshapes a tensor of shape 55 | # `[k, l]' to one of shape `[k, 1, l]`). 56 | 57 | # Extract vmin and vmax and construct helper tensors from z_q 58 | vmin, vmax = z_q[0], z_q[-1] 59 | d_pos = tf.concat([z_q, vmin[None]], 0)[1:] # 1 x Kq x 1 60 | d_neg = tf.concat([vmax[None], z_q], 0)[:-1] # 1 x Kq x 1 61 | # Clip z_p to be in new support range (vmin, vmax). 62 | z_p = tf.clip_by_value(z_p, vmin, vmax)[:, None, :] # B x 1 x Kp 63 | 64 | # Get the distance between atom values in support. 65 | d_pos = (d_pos - z_q)[None, :, None] # z_q[i+1] - z_q[i]. 1 x B x 1 66 | d_neg = (z_q - d_neg)[None, :, None] # z_q[i] - z_q[i-1]. 1 x B x 1 67 | z_q = z_q[None, :, None] # 1 x Kq x 1 68 | 69 | # Ensure that we do not divide by zero, in case of atoms of identical value. 70 | d_neg = tf.where(d_neg > 0, 1./d_neg, tf.zeros_like(d_neg)) # 1 x Kq x 1 71 | d_pos = tf.where(d_pos > 0, 1./d_pos, tf.zeros_like(d_pos)) # 1 x Kq x 1 72 | 73 | delta_qp = z_p - z_q # clip(z_p)[j] - z_q[i]. B x Kq x Kp 74 | d_sign = tf.cast(delta_qp >= 0., dtype=p.dtype) # B x Kq x Kp 75 | 76 | # Matrix of entries sgn(a_ij) * |a_ij|, with a_ij = clip(z_p)[j] - z_q[i]. 77 | # Shape B x Kq x Kp. 78 | delta_hat = (d_sign * delta_qp * d_pos) - ((1. - d_sign) * delta_qp * d_neg) 79 | p = p[:, None, :] # B x 1 x Kp. 80 | return tf.reduce_sum(tf.clip_by_value(1. - delta_hat, 0., 1.) * p, 2) 81 | 82 | 83 | def factorised_kl_gaussian(dist1_mean, 84 | dist1_covariance_or_scale, 85 | dist2_mean, 86 | dist2_covariance_or_scale, 87 | both_diagonal=False): 88 | """Compute the KL divergence KL(dist1, dist2) between two Gaussians. 89 | 90 | The KL is factorised into two terms - `kl_mean` and `kl_cov`. This 91 | factorisation is specific to multivariate gaussian distributions and arises 92 | from its analytic form. 
93 | Specifically, if we assume two multivariate Gaussian distributions with rank 94 | k and means, M1 and M2 and variance S1 and S2, the analytic KL can be written 95 | out as: 96 | 97 | D_KL(N0 || N1) = 0.5 * (tr(inv(S1) * S0) + ln(det(S1)/det(S0)) - k + 98 | (M1 - M0).T * inv(S1) * (M1 - M0)) 99 | 100 | The terms on the first row correspond to the covariance factor and the terms 101 | on the second row correspond to the mean factor in the factorized KL. 102 | These terms can thus be used to independently control how much the mean and 103 | covariance between the two gaussians can vary. 104 | 105 | This implementation ensures that gradient flow is equivalent to calling 106 | `tfp.distributions.kl_divergence` once. 107 | 108 | More details on the equation can be found here: 109 | https://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians 110 | 111 | 112 | Args: 113 | dist1_mean: The mean of the first Multivariate Gaussian distribution. 114 | dist1_covariance_or_scale: The covariance or scale of the first Multivariate 115 | Gaussian distribution. In cases where *both* distributions are Gaussians 116 | with diagonal covariance matrices (for instance, if both are instances of 117 | `tfp.distributions.MultivariateNormalDiag`), then the `scale` can be 118 | passed in instead and the `both_diagonal` flag must be set to `True`. 119 | A more efficient sparse computation path is used in this case. For all 120 | other cases, the full covariance matrix must be passed in. 121 | dist2_mean: The mean of the second Multivariate Gaussian distribution. 122 | dist2_covariance_or_scale: The covariance or scale tensor of the second 123 | Multivariate Gaussian distribution, as for `dist1_covariance_or_scale`. 124 | both_diagonal: A `bool` indicating that both dist1 and dist2 are diagonal 125 | matrices. A more efficient sparse computation is used in this case. 126 | 127 | Returns: 128 | A tuple consisting of (`kl_mean`, `kl_cov`) which correspond to the mean and 129 | the covariance factorisation of the KL. 130 | """ 131 | if both_diagonal: 132 | dist1_mean_rank = dist1_mean.get_shape().ndims 133 | dist1_covariance_or_scale.get_shape().assert_has_rank(dist1_mean_rank) 134 | dist2_mean_rank = dist2_mean.get_shape().ndims 135 | dist2_covariance_or_scale.get_shape().assert_has_rank(dist2_mean_rank) 136 | 137 | dist_type = tfp.distributions.MultivariateNormalDiag 138 | else: 139 | dist_type = tfp.distributions.MultivariateNormalFullCovariance 140 | 141 | # Recreate the distributions but with stop gradients on the mean and cov. 142 | dist1_stop_grad_mean = dist_type( 143 | tf.stop_gradient(dist1_mean), dist1_covariance_or_scale) 144 | dist2 = dist_type(dist2_mean, dist2_covariance_or_scale) 145 | 146 | # Now create a third distribution with the mean of dist1 and the variance of 147 | # dist2 and appropriate stop_gradients. 
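  # KL(dist3 || dist2) then differs only in the means (the covariances match),
  # while KL(dist1_stop_grad_mean || dist3_stop_grad_mean) differs only in the
  # covariances (the means match and carry no gradient), which yields the two
  # factors returned below.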
148 | dist3 = dist_type(dist1_mean, dist2_covariance_or_scale) 149 | dist3_stop_grad_mean = dist_type( 150 | tf.stop_gradient(dist1_mean), dist2_covariance_or_scale) 151 | 152 | # Finally get the two components of the KL between dist1 and dist2 153 | # using dist3 154 | kl_mean = tfp.distributions.kl_divergence(dist3, dist2) 155 | kl_cov = tfp.distributions.kl_divergence(dist1_stop_grad_mean, 156 | dist3_stop_grad_mean) 157 | return kl_mean, kl_cov 158 | -------------------------------------------------------------------------------- /trfl/pixel_control_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for pixel_control_ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | import numpy as np 23 | import tensorflow.compat.v1 as tf 24 | from trfl import pixel_control_ops 25 | 26 | 27 | class PixelControlRewardsTest(tf.test.TestCase): 28 | """Test the `pixel_control_rewards` op.""" 29 | 30 | def setUp(self): 31 | """Defines example data and expected result for the op.""" 32 | super(PixelControlRewardsTest, self).setUp() 33 | 34 | # Configure. 35 | self._cell = 2 36 | obs_size = (5, 2, 4, 4, 3, 2) 37 | y = obs_size[2] // self._cell 38 | x = obs_size[3] // self._cell 39 | channels = np.prod(obs_size[4:]) 40 | rew_size = (obs_size[0]-1, obs_size[1], x, y) 41 | 42 | # Input data. 43 | self._obs_np = np.random.uniform(size=obs_size) 44 | self._obs_tf = tf.placeholder(tf.float32, obs_size) 45 | 46 | # Expected pseudo-rewards. 47 | abs_diff = np.absolute(self._obs_np[1:] - self._obs_np[:-1]) 48 | abs_diff = abs_diff.reshape((-1,) + obs_size[2:4] + (channels,)) 49 | abs_diff = abs_diff.reshape((-1, y, self._cell, x, self._cell, channels)) 50 | avg_abs_diff = abs_diff.mean(axis=(2, 4, 5)) 51 | self._expected_pseudo_rewards = avg_abs_diff.reshape(rew_size) 52 | 53 | def testPixelControlRewards(self): 54 | """Compute pseudo rewards from observations.""" 55 | pseudo_rewards_tf = pixel_control_ops.pixel_control_rewards( 56 | self._obs_tf, self._cell) 57 | 58 | with self.test_session() as sess: 59 | self.assertAllClose( 60 | sess.run(pseudo_rewards_tf, feed_dict={self._obs_tf: self._obs_np}), 61 | self._expected_pseudo_rewards) 62 | 63 | 64 | class PixelControlLossTest(tf.test.TestCase): 65 | """Test the `pixel_control_loss` op.""" 66 | 67 | def setUp(self): 68 | """Defines example data and expected result for the op.""" 69 | super(PixelControlLossTest, self).setUp() 70 | 71 | # Observation shape is (2,2,3) (i.e., height 2, width 2, and 3 channels). 72 | # We will use no cropping, and a cell size of 1. We have num_actions = 3, 73 | # meaning our Q values should be (2,2,3). 
We will set the Q value equal to 74 | # the observation. 75 | self.seq_length = 3 76 | self.batch_size = 1 77 | num_actions = 3 78 | obs_shape = (2, 2, num_actions) 79 | self.discount = 0.9 80 | self.cell_size = 1 81 | self.scale = 1.0 82 | 83 | # Create ops to feed actions and rewards. 84 | self.observations_ph = tf.placeholder( 85 | shape=(self.seq_length+1, self.batch_size)+obs_shape, dtype=tf.float32) 86 | self.action_values_ph = tf.placeholder( 87 | shape=(self.seq_length+1, self.batch_size)+obs_shape, dtype=tf.float32) 88 | self.actions_ph = tf.placeholder( 89 | shape=(self.seq_length, self.batch_size), dtype=tf.int32) 90 | 91 | # Observations. 92 | obs1 = np.array([[[1, 2, 3], [3, 4, 5]], [[5, 6, 7], [7, 8, 9]]]) 93 | obs2 = np.array([[[7, 8, 9], [1, 2, 3]], [[3, 4, 5], [5, 6, 7]]]) 94 | obs3 = np.array([[[5, 6, 7], [7, 8, 9]], [[1, 2, 3], [3, 4, 5]]]) 95 | obs4 = np.array([[[3, 4, 5], [5, 6, 7]], [[7, 8, 9], [1, 2, 3]]]) 96 | 97 | # Actions. 98 | action1 = 0 99 | action2 = 1 100 | action3 = 2 101 | 102 | # Compute loss for constant discount. 103 | qa_tm1 = obs3[:, :, action3] 104 | reward3 = np.mean(np.abs(obs4 - obs3), axis=2) 105 | qmax_t = np.amax(obs4, axis=2) 106 | target = reward3 + self.discount * qmax_t 107 | error3 = target - qa_tm1 108 | 109 | qa_tm1 = obs2[:, :, action2] 110 | reward2 = np.mean(np.abs(obs3 - obs2), axis=2) 111 | target = reward2 + self.discount * target 112 | error2 = target - qa_tm1 113 | 114 | qa_tm1 = obs1[:, :, action1] 115 | reward1 = np.mean(np.abs(obs2 - obs1), axis=2) 116 | target = reward1 + self.discount * target 117 | error1 = target - qa_tm1 118 | 119 | # Compute loss for episode termination with discount 0. 120 | qa_tm1 = obs1[:, :, action1] 121 | reward1 = np.mean(np.abs(obs2 - obs1), axis=2) 122 | target = reward1 + 0. * target 123 | error1_term = target - qa_tm1 124 | 125 | self.error = np.sum( 126 | np.square(error1) + np.square(error2) + np.square(error3)) * 0.5 127 | self.error_term = np.sum( 128 | np.square(error1_term) + np.square(error2) + np.square(error3)) * 0.5 129 | 130 | # Placeholder data. 
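    # Stacking yields observations of shape [T+1, B, H, W, C] = [4, 1, 2, 2, 3];
    # the same tensor doubles as the action values fed to the op.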
131 | self.observations = np.expand_dims( 132 | np.stack([obs1, obs2, obs3, obs4], axis=0), axis=1) 133 | self.action_values = self.observations 134 | self.actions = np.stack( 135 | [np.array([action1]), np.array([action2]), np.array([action3])], axis=0) 136 | 137 | def testPixelControlLossScalarDiscount(self): 138 | """Compute loss for given observations, actions, values, scalar discount.""" 139 | 140 | loss, _ = pixel_control_ops.pixel_control_loss( 141 | self.observations_ph, self.actions_ph, self.action_values_ph, 142 | self.cell_size, self.discount, self.scale) 143 | init = tf.global_variables_initializer() 144 | 145 | with self.test_session() as sess: 146 | sess.run(init) 147 | feed_dict = { 148 | self.observations_ph: self.observations, 149 | self.action_values_ph: self.action_values, 150 | self.actions_ph: self.actions} 151 | loss_np = sess.run(loss, feed_dict=feed_dict) 152 | self.assertNear(loss_np, self.error, 1e-3) 153 | 154 | def testPixelControlLossTensorDiscount(self): 155 | """Compute loss for given observations, actions, values, tensor discount.""" 156 | 157 | zero_discount = tf.zeros((1, self.batch_size)) 158 | non_zero_discount = tf.tile( 159 | tf.reshape(self.discount, [1, 1]), 160 | [self.seq_length - 1, self.batch_size]) 161 | tensor_discount = tf.concat([zero_discount, non_zero_discount], axis=0) 162 | loss, _ = pixel_control_ops.pixel_control_loss( 163 | self.observations_ph, self.actions_ph, self.action_values_ph, 164 | self.cell_size, tensor_discount, self.scale) 165 | init = tf.global_variables_initializer() 166 | 167 | with self.test_session() as sess: 168 | sess.run(init) 169 | feed_dict = { 170 | self.observations_ph: self.observations, 171 | self.action_values_ph: self.action_values, 172 | self.actions_ph: self.actions} 173 | loss_np = sess.run(loss, feed_dict=feed_dict) 174 | self.assertNear(loss_np, self.error_term, 1e-3) 175 | 176 | def testPixelControlLossShapes(self): 177 | with self.assertRaisesRegexp( 178 | ValueError, "Pixel Control values are not compatible"): 179 | pixel_control_ops.pixel_control_loss( 180 | self.observations_ph, self.actions_ph, 181 | self.action_values_ph[:, :, :-1], self.cell_size, self.discount, 182 | self.scale) 183 | 184 | def testTensorDiscountShape(self): 185 | with self.assertRaisesRegexp( 186 | ValueError, "discount_factor must be a scalar or a tensor of rank 2"): 187 | tensor_discount = tf.tile( 188 | tf.reshape(self.discount, [1, 1, 1]), 189 | [self.seq_length, self.batch_size, 1]) 190 | pixel_control_ops.pixel_control_loss( 191 | self.observations_ph, self.actions_ph, 192 | self.action_values_ph, self.cell_size, tensor_discount, 193 | self.scale) 194 | 195 | 196 | if __name__ == "__main__": 197 | tf.test.main() 198 | -------------------------------------------------------------------------------- /trfl/continuous_retrace_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """TensorFlow ops for the Retrace algorithm and continuous actions. 16 | 17 | Safe and Efficient Off-Policy Reinforcement Learning 18 | R. Munos, T. Stepleton, A. Harutyunyan, M. G. Bellemare 19 | https://arxiv.org/abs/1606.02647 20 | 21 | This variant is commonly used to update the Q function in RS0, which 22 | additionally uses SVG or a SVG variant to update the policy. 23 | 24 | Learning by Playing - Solving Sparse Reward Tasks from Scratch 25 | M. Riedmiller, R. Hafner, T. Lampe, M. Neunert, J. Degrave, T. Van de Wiele, 26 | V. Mnih, N. Heess, J. T. Springenberg 27 | https://arxiv.org/abs/1802.10567 28 | 29 | Learning Continuous Control Policies by Stochastic Value Gradients 30 | N. Heess, G. Wayne, D. Silver, T. Lillicrap, Y. Tassa, T. Erez 31 | https://arxiv.org/abs/1510.09142 32 | 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import collections 40 | 41 | import tensorflow.compat.v1 as tf 42 | 43 | 44 | QTraceReturns = collections.namedtuple("QTraceReturns", [ 45 | "qs", "importance_weights", "log_importance_weights", 46 | "truncated_importance_weights", "deltas", "vs_minus_q_xs" 47 | ]) 48 | 49 | 50 | def retrace_from_action_log_probs( 51 | behaviour_action_log_probs, 52 | target_action_log_probs, 53 | discounts, 54 | rewards, 55 | q_values, 56 | values, 57 | bootstrap_value, 58 | lambda_=1., 59 | name="retrace_from_action_log_probs"): 60 | """Constructs Q/Retrace ops. 61 | 62 | This is an implementation of Retrace. In the description of the arguments 63 | the notation is as follows: `T` refers to the sequence size over which 64 | the return is calculated, finally `B` denotes the batch size. 65 | 66 | Args: 67 | behaviour_action_log_probs: Log-probabilities. Shape [T, B]. 68 | target_action_log_probs: Log-probabilities for target policy. Shape [T, B] 69 | discounts: Also called pcontinues. Discount encountered when following 70 | the behaviour policy. Shape [T, B]. 71 | rewards: A tensor containing rewards generated by following the behaviour 72 | policy. Shape [T, B]. 73 | q_values: Q-function estimates wrt. the target policy. Shape [T, B]. 74 | values: Value function estimates wrt. the target policy. Shape [T, B]. 75 | bootstrap_value: Value function estimate at time `T`. Shape [B]. 76 | lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1). 77 | name: The name scope that all qtrace ops will be created in. 78 | 79 | Returns: 80 | A `QTraceReturns` namedtuple containing: 81 | 82 | * qs: The Retrace regression/policy gradient targets. 83 | Can be used to calculate estimates of the advantage for policy 84 | gradients or as regression target for Q-value functions. Shape [T, B]. 85 | * importance_weights: Importance sampling weights. Shape [T, B]. 86 | * log_importance_weights: Importance sampling weights. Shape [T, B]. 87 | * truncated_importance_weights: Called c_t in the paper. Shape [T, B]. 88 | * deltas: Shape [T, B] 89 | * vs_minus_q_xs: Q-Retrace targets - Q(x_s, u_s). Shape [T, B]. 90 | """ 91 | # Turn arguments to tensors. 
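  # (Python lists and NumPy arrays are accepted as well; `convert_to_tensor`
  # turns them into float32 tensors before the rank checks below.)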
92 | behaviour_action_log_probs = tf.convert_to_tensor( 93 | behaviour_action_log_probs, dtype=tf.float32) 94 | target_action_log_probs = tf.convert_to_tensor( 95 | target_action_log_probs, dtype=tf.float32) 96 | values = tf.convert_to_tensor(values, dtype=tf.float32) 97 | q_values = tf.convert_to_tensor(q_values, dtype=tf.float32) 98 | bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32) 99 | discounts = tf.convert_to_tensor(discounts, dtype=tf.float32) 100 | rewards = tf.convert_to_tensor(rewards, dtype=tf.float32) 101 | 102 | # Make sure tensor ranks are as expected. 103 | behaviour_action_log_probs.get_shape().assert_has_rank(2) 104 | target_action_log_probs.get_shape().assert_has_rank(2) 105 | values.get_shape().assert_has_rank(2) 106 | q_values.get_shape().assert_has_rank(2) 107 | bootstrap_value.get_shape().assert_has_rank(1) 108 | discounts.get_shape().assert_has_rank(2) 109 | rewards.get_shape().assert_has_rank(2) 110 | 111 | with tf.name_scope( 112 | name, 113 | values=[ 114 | behaviour_action_log_probs, target_action_log_probs, discounts, 115 | rewards, q_values, values, bootstrap_value 116 | ]): 117 | log_rhos = target_action_log_probs - behaviour_action_log_probs 118 | return retrace_from_importance_weights( 119 | log_rhos=log_rhos, 120 | discounts=discounts, 121 | rewards=rewards, 122 | q_values=q_values, 123 | values=values, 124 | bootstrap_value=bootstrap_value, 125 | lambda_=lambda_) 126 | 127 | 128 | def retrace_from_importance_weights(log_rhos, 129 | discounts, 130 | rewards, 131 | q_values, 132 | values, 133 | bootstrap_value, 134 | lambda_=1.0, 135 | name="retrace_from_importance_weights"): 136 | """Constructs Q/Retrace ops. 137 | 138 | This is an implementation of Retrace. In the description of the arguments 139 | the notation is as follows: `T` refers to the sequence size over which 140 | the return is calculated, finally `B` denotes the batch size. 141 | 142 | Args: 143 | log_rhos: Log-probabilities for target policy. Shape [T, B] 144 | discounts: Also called pcontinues. Discount encountered when following 145 | the behaviour policy. Shape [T, B]. 146 | rewards: A tensor containing rewards generated by following the behaviour 147 | policy. Shape [T, B]. 148 | q_values: Q-function estimates wrt. the target policy. Shape [T, B]. 149 | values: Value function estimates wrt. the target policy. Shape [T, B]. 150 | bootstrap_value: Value function estimate at time `T`. Shape [B]. 151 | lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1). 152 | name: The name scope that all qtrace ops will be created in. 153 | 154 | Returns: 155 | A `QTraceReturns` namedtuple containing: 156 | 157 | * qs: The Retrace regression/policy gradient targets. 158 | Can be used to calculate estimates of the advantage for policy 159 | gradients or as regression target for Q-value functions. Shape [T, B]. 160 | * importance_weights: Importance sampling weights. Shape [T, B]. 161 | * log_importance_weights: Importance sampling weights. Shape [T, B]. 162 | * truncated_importance_weights: Called c_t in the paper. Shape [T, B]. 163 | * deltas: Shape [T, B] 164 | * vs_minus_q_xs: Q-Retrace targets - Q(x_s, u_s). Shape [T, B]. 165 | 166 | Raises: 167 | ValueError: If compiled=True, but log_rhos has rank other than 2. 168 | """ 169 | # Make sure tensor ranks are consistent. 170 | rho_rank = log_rhos.get_shape().ndims # Usually 2. 
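  # Every per-timestep input must share the same [T, B] rank as `log_rhos`;
  # only `bootstrap_value` drops the leading time dimension.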
171 | q_values.get_shape().assert_has_rank(rho_rank) 172 | values.get_shape().assert_has_rank(rho_rank) 173 | bootstrap_value.get_shape().assert_has_rank(rho_rank - 1) 174 | discounts.get_shape().assert_has_rank(rho_rank) 175 | rewards.get_shape().assert_has_rank(rho_rank) 176 | 177 | lambda_ = tf.convert_to_tensor(lambda_, dtype=tf.float32) 178 | 179 | with tf.name_scope( 180 | name, values=[log_rhos, discounts, rewards, values, bootstrap_value]): 181 | rhos = tf.exp(log_rhos) 182 | 183 | cs = tf.minimum(1.0, rhos, name="cs") 184 | 185 | # Set the last c to 1. 186 | cs = tf.concat([cs[1:], tf.ones_like(cs[-1:])], axis=0) 187 | cs *= lambda_ 188 | 189 | # Append bootstrapped value to get [v1, ..., v_t+1] 190 | values_t_plus_1 = tf.concat( 191 | [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) 192 | 193 | # delta_t = (r_t + discount * V(x_{t+1}) - Q(x_t, a_t)) 194 | deltas = (rewards + discounts * values_t_plus_1 - q_values) 195 | 196 | # Note that all sequences are reversed, computation starts from the back. 197 | sequences = ( 198 | tf.reverse(discounts, axis=[0]), 199 | tf.reverse(cs, axis=[0]), 200 | tf.reverse(deltas, axis=[0]), 201 | ) 202 | 203 | # Re-trace vs are calculated through a scan from the back to the beginning 204 | # of the given trajectory. 205 | def scanfunc(acc, sequence_item): 206 | discount_t, c_t, delta_t = sequence_item 207 | return delta_t + discount_t * c_t * acc 208 | 209 | initial_values = tf.zeros_like(bootstrap_value) 210 | vs_minus_q_xs = tf.scan( 211 | fn=scanfunc, 212 | elems=sequences, 213 | initializer=initial_values, 214 | parallel_iterations=1, 215 | back_prop=False, 216 | name="scan") 217 | # Reverse the results back to original order. 218 | vs_minus_q_xs = tf.reverse(vs_minus_q_xs, [0], name="vs_minus_q_xs") 219 | 220 | # Add V(x_s) to get q targets. 221 | qs = tf.add(vs_minus_q_xs, q_values, name="s") 222 | 223 | result = QTraceReturns( 224 | qs=tf.stop_gradient(qs), 225 | importance_weights=tf.stop_gradient(rhos), 226 | log_importance_weights=tf.stop_gradient(log_rhos), 227 | truncated_importance_weights=tf.stop_gradient(cs), 228 | deltas=tf.stop_gradient(deltas), 229 | vs_minus_q_xs=tf.stop_gradient(vs_minus_q_xs)) 230 | return result 231 | -------------------------------------------------------------------------------- /trfl/sequence_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tensorflow ops for multistep return evaluation.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | def _reverse_seq(sequence, sequence_lengths=None): 26 | """Reverse sequence along dim 0. 
27 | 28 | Args: 29 | sequence: Tensor of shape [T, B, ...]. 30 | sequence_lengths: (optional) tensor of shape [B]. If `None`, only reverse 31 | along dim 0. 32 | 33 | Returns: 34 | Tensor of same shape as sequence with dim 0 reversed up to sequence_lengths. 35 | """ 36 | if sequence_lengths is None: 37 | return tf.reverse(sequence, [0]) 38 | 39 | sequence_lengths = tf.convert_to_tensor(sequence_lengths) 40 | with tf.control_dependencies( 41 | [tf.assert_equal(sequence.shape[1], sequence_lengths.shape[0])]): 42 | return tf.reverse_sequence( 43 | sequence, sequence_lengths, seq_axis=0, batch_axis=1) 44 | 45 | 46 | def scan_discounted_sum(sequence, decay, initial_value, reverse=False, 47 | sequence_lengths=None, back_prop=True, 48 | name="scan_discounted_sum"): 49 | """Evaluates a cumulative discounted sum along dimension 0. 50 | 51 | ```python 52 | if reverse = False: 53 | result[1] = sequence[1] + decay[1] * initial_value 54 | result[k] = sequence[k] + decay[k] * result[k - 1] 55 | if reverse = True: 56 | result[last] = sequence[last] + decay[last] * initial_value 57 | result[k] = sequence[k] + decay[k] * result[k + 1] 58 | ``` 59 | 60 | Respective dimensions T, B and ... have to be the same for all input tensors. 61 | T: temporal dimension of the sequence; B: batch dimension of the sequence. 62 | 63 | if sequence_lengths is set then x1 and x2 below are equivalent: 64 | ```python 65 | x1 = zero_pad_to_length( 66 | scan_discounted_sum( 67 | sequence[:length], decays[:length], **kwargs), length=T) 68 | x2 = scan_discounted_sum(sequence, decays, 69 | sequence_lengths=[length], **kwargs) 70 | ``` 71 | 72 | Args: 73 | sequence: Tensor of shape `[T, B, ...]` containing values to be summed. 74 | decay: Tensor of shape `[T, B, ...]` containing decays/discounts. 75 | initial_value: Tensor of shape `[B, ...]` containing initial value. 76 | reverse: Whether to process the sum in a reverse order. 77 | sequence_lengths: Tensor of shape `[B]` containing sequence lengths to be 78 | (reversed and then) summed. 79 | back_prop: Whether to backpropagate. 80 | name: Sets the name_scope for this op. 81 | 82 | Returns: 83 | Cumulative sum with discount. Same shape and type as `sequence`. 84 | """ 85 | # Note this can be implemented in terms of cumprod and cumsum, 86 | # approximately as (ignoring boundary issues and initial_value): 87 | # 88 | # cumsum(decay_prods * sequence) / decay_prods 89 | # where decay_prods = reverse_cumprod(decay) 90 | # 91 | # One reason this hasn't been done is that multiplying then dividing again by 92 | # products of decays isn't ideal numerically, in particular if any of the 93 | # decays are zero it results in NaNs. 94 | with tf.name_scope(name, values=[sequence, decay, initial_value]): 95 | if sequence_lengths is not None: 96 | # Zero out sequence and decay beyond sequence_lengths. 97 | with tf.control_dependencies( 98 | [tf.assert_equal(sequence.shape[0], decay.shape[0])]): 99 | mask = tf.sequence_mask(sequence_lengths, maxlen=sequence.shape[0], 100 | dtype=sequence.dtype) 101 | mask = tf.transpose(mask) 102 | 103 | # Adding trailing dimensions to mask to allow for broadcasting. 
104 | to_seq = mask.shape.dims + [1] * (sequence.shape.ndims - mask.shape.ndims) 105 | sequence *= tf.reshape(mask, to_seq) 106 | to_decay = mask.shape.dims + [1] * (decay.shape.ndims - mask.shape.ndims) 107 | decay *= tf.reshape(mask, to_decay) 108 | 109 | sequences = [sequence, decay] 110 | if reverse: 111 | sequences = [_reverse_seq(s, sequence_lengths) for s in sequences] 112 | 113 | summed = tf.scan(lambda a, x: x[0] + x[1] * a, 114 | sequences, 115 | initializer=tf.convert_to_tensor(initial_value), 116 | parallel_iterations=1, 117 | back_prop=back_prop) 118 | 119 | if not back_prop: 120 | summed = tf.stop_gradient(summed) 121 | if reverse: 122 | summed = _reverse_seq(summed, sequence_lengths) 123 | return summed 124 | 125 | 126 | def multistep_forward_view(rewards, pcontinues, state_values, lambda_, 127 | back_prop=True, sequence_lengths=None, 128 | name="multistep_forward_view_op"): 129 | """Evaluates complex backups (forward view of eligibility traces). 130 | 131 | ```python 132 | result[t] = rewards[t] + 133 | pcontinues[t]*(lambda_[t]*result[t+1] + (1-lambda_[t])*state_values[t]) 134 | result[last] = rewards[last] + pcontinues[last]*state_values[last] 135 | ``` 136 | 137 | This operation evaluates multistep returns where lambda_ parameter controls 138 | mixing between full returns and boostrapping. It is users responsibility 139 | to provide state_values. Depending on how state_values are evaluated this 140 | function can evaluate targets for Q(lambda), Sarsa(lambda) or some other 141 | multistep boostrapping algorithm. 142 | 143 | More information about a forward view is given here: 144 | http://incompleteideas.net/sutton/book/ebook/node74.html 145 | 146 | Please note that instead of evaluating traces and then explicitly summing 147 | them we instead evaluate mixed returns in the reverse temporal order 148 | by using the recurrent relationship given above. 149 | 150 | The parameter lambda_ can either be a constant value (e.g for Peng's 151 | Q(lambda) and Sarsa(_lambda)) or alternatively it can be a tensor containing 152 | arbitrary values (Watkins' Q(lambda), Munos' Retrace, etc). 153 | 154 | The result of evaluating this recurrence relation is a weighted sum of 155 | n-step returns, as depicted in the diagram below. One strategy to prove this 156 | equivalence notes that many of the terms in adjacent n-step returns 157 | "telescope", or cancel out, when the returns are summed. 158 | 159 | Below L3 is lambda at time step 3 (important: this diagram is 1-indexed, not 160 | 0-indexed like Python). If lambda is scalar then L1=L2=...=Ln. 161 | g1,...,gn are discounts. 162 | 163 | ``` 164 | Weights: (1-L1) (1-L2)*l1 (1-L3)*l1*l2 ... L1*L2*...*L{n-1} 165 | Returns: |r1*(g1)+ |r1*(g1)+ |r1*(g1)+ |r1*(g1)+ 166 | v1*(g1) |r2*(g1*g2)+ |r2*(g1*g2)+ |r2*(g1*g2)+ 167 | v2*(g1*g2) |r3*(g1*g2*g3)+ |r3*(g1*g2*g3)+ 168 | v3*(g1*g2*g3) ... 169 | |rn*(g1*...*gn)+ 170 | vn*(g1*...*gn) 171 | ``` 172 | 173 | Args: 174 | rewards: Tensor of shape `[T, B]` containing rewards. 175 | pcontinues: Tensor of shape `[T, B]` containing discounts. 176 | state_values: Tensor of shape `[T, B]` containing state values. 177 | lambda_: Mixing parameter lambda. 178 | The parameter can either be a scalar or a Tensor of shape `[T, B]` 179 | if mixing is a function of state. 180 | back_prop: Whether to backpropagate. 181 | sequence_lengths: Tensor of shape `[B]` containing sequence lengths to be 182 | (reversed and then) summed, same as in `scan_discounted_sum`. 183 | name: Sets the name_scope for this op. 
184 | 185 | Returns: 186 | Tensor of shape `[T, B]` containing multistep returns. 187 | """ 188 | with tf.name_scope(name, values=[rewards, pcontinues, state_values]): 189 | # Regroup: 190 | # result[t] = (rewards[t] + pcontinues[t]*(1-lambda_)*state_values[t]) + 191 | # pcontinues[t]*lambda_*result[t + 1] 192 | # Define: 193 | # sequence[t] = rewards[t] + pcontinues[t]*(1-lambda_)*state_values[t] 194 | # discount[t] = pcontinues[t]*lambda_ 195 | # Substitute: 196 | # result[t] = sequence[t] + discount[t]*result[t + 1] 197 | # Boundary condition: 198 | # result[last] = rewards[last] + pcontinues[last]*state_values[last] 199 | # Add and subtract the same quantity at BC: 200 | # state_values[last] = 201 | # lambda_*state_values[last] + (1-lambda_)*state_values[last] 202 | # This makes: 203 | # result[last] = 204 | # (rewards[last] + pcontinues[last]*(1-lambda_)*state_values[last]) + 205 | # pcontinues[last]*lambda_*state_values[last] 206 | # Substitute in definitions for sequence and discount: 207 | # result[last] = sequence[last] + discount[last]*state_values[last] 208 | # Define: 209 | # initial_value=state_values[last] 210 | # We get the following recurrent relationship: 211 | # result[last] = sequence[last] + decay[last]*initial_value 212 | # result[k] = sequence[k] + decay[k] * result[k + 1] 213 | # This matches the form of scan_discounted_sum: 214 | # result = scan_sum_with_discount(sequence, discount, 215 | # initial_value = state_values[last]) 216 | sequence = rewards + pcontinues * state_values * (1 - lambda_) 217 | discount = pcontinues * lambda_ 218 | return scan_discounted_sum(sequence, discount, state_values[-1], 219 | reverse=True, sequence_lengths=sequence_lengths, 220 | back_prop=back_prop) 221 | -------------------------------------------------------------------------------- /trfl/pixel_control_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """TensorFlow ops for implementing Pixel Control. 16 | 17 | Pixel Control is an auxiliary task introduced in the UNREAL agent. 18 | In Pixel Control an additional agent head is trained off-policy to predict 19 | action-value functions for a host of pseudo rewards derived from the stream of 20 | observations. This leads to better state representations and therefore improved 21 | performance, both in terms of data efficiency and final performance. 
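
A minimal sketch of how the ops below fit together (assuming `obs` is a
`[T+1, B, H, W, C]` float tensor rescaled to [0, 1], `actions` is an integer
`[T, B]` tensor, and `q` is the `[T+1, B, H', W', num_actions]` output of a
deconvolutional pixel-control head; all three names are placeholders):

    from trfl import pixel_control_ops

    pseudo_rewards = pixel_control_ops.pixel_control_rewards(obs, cell_size=4)
    loss, extra = pixel_control_ops.pixel_control_loss(
        obs, actions, q, cell_size=4, discount_factor=0.9, scale=1.0)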
22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | import collections 29 | 30 | # Dependency imports 31 | import tensorflow.compat.v1 as tf 32 | from trfl import action_value_ops 33 | from trfl import base_ops 34 | 35 | 36 | PixelControlExtra = collections.namedtuple( 37 | "pixel_control_extra", ["spatial_loss", "pseudo_rewards"]) 38 | 39 | 40 | def pixel_control_rewards(observations, cell_size): 41 | """Calculates pixel control task rewards from observation sequence. 42 | 43 | The observations are first split in a grid of KxK cells. For each cell a 44 | distinct pseudo reward is computed as the average absolute change in pixel 45 | intensity for all pixels in the cell. The change in intensity is averaged 46 | across both pixels and channels (e.g. RGB). 47 | 48 | The `observations` provided to this function should be cropped suitably, to 49 | ensure that the observations' height and width are a multiple of `cell_size`. 50 | The values of the `observations` tensor should be rescaled to [0, 1]. In the 51 | UNREAL agent observations are cropped to 80x80, and each cell is 4x4 in size. 52 | 53 | See "Reinforcement Learning with Unsupervised Auxiliary Tasks" by Jaderberg, 54 | Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397). 55 | 56 | Args: 57 | observations: A tensor of shape `[T+1,B,H,W,C...]`, where 58 | * `T` is the sequence length, `B` is the batch size. 59 | * `H` is height, `W` is width. 60 | * `C...` is at least one channel dimension (e.g., colour, stack). 61 | * `T` and `B` can be statically unknown. 62 | cell_size: The size of each cell. 63 | 64 | Returns: 65 | A tensor of pixel control rewards calculated from the observation. The 66 | shape is `[T,B,H',W']`, where `H'` and `W'` are determined by the 67 | `cell_size`. If evenly-divisible, `H' = H/cell_size`, and similar for `W`. 68 | """ 69 | # Calculate the absolute differences across the sequence. 70 | abs_diff = tf.abs(observations[1:] - observations[:-1]) 71 | # Average over cells. `abs_diff` has shape [T,B,H,W,C...], e.g., 72 | # [T,B,H,W,C] if we have a colour channel. We want to use the TF avg_pool3d 73 | # op, but it expects 5D inputs so we collapse all channel dimensions. 74 | # Merge remaining dimensions after W: [T,B,H,W,C']. 75 | full_shape = tf.shape(abs_diff) 76 | preserved_shape = full_shape[:4] 77 | trailing_shape = (tf.reduce_prod(full_shape[4:]),) 78 | shape = tf.concat([preserved_shape, trailing_shape], 0) 79 | abs_diff = tf.reshape(abs_diff, shape) 80 | # Apply the averaging using average pooling and reducing over channel. 81 | avg_abs_diff = tf.nn.avg_pool3d( 82 | abs_diff, 83 | ksize=[1, 1, cell_size, cell_size, 1], 84 | strides=[1, 1, cell_size, cell_size, 1], 85 | padding="VALID") # [T,B,H',W',C']. 86 | pseudo_rewards = tf.reduce_mean( 87 | avg_abs_diff, axis=[4], name="pseudo_rewards") # [T,B,H',W']. 88 | sequence_batch = abs_diff.get_shape()[:2] 89 | new_height_width = avg_abs_diff.get_shape()[2:4] 90 | pseudo_rewards.set_shape(sequence_batch.concatenate(new_height_width)) 91 | return pseudo_rewards 92 | 93 | 94 | def pixel_control_loss( 95 | observations, actions, action_values, cell_size, discount_factor, 96 | scale, crop_height_dim=(None, None), crop_width_dim=(None, None)): 97 | """Calculate n-step Q-learning loss for pixel control auxiliary task. 98 | 99 | For each pixel-based pseudo reward signal, the corresponding action-value 100 | function is trained off-policy, using Q(lambda). 
A discount of 0.9 is 101 | commonly used for learning the value functions. 102 | 103 | Note that, since pseudo rewards have a spatial structure, with neighbouring 104 | cells exhibiting strong correlations, it is convenient to predict the action 105 | values for all the cells through a deconvolutional head. 106 | 107 | See "Reinforcement Learning with Unsupervised Auxiliary Tasks" by Jaderberg, 108 | Mnih, Czarnecki et al. (https://arxiv.org/abs/1611.05397). 109 | 110 | Args: 111 | observations: A tensor of shape `[T+1,B, ...]`; `...` is the observation 112 | shape, `T` the sequence length, and `B` the batch size. `T` and `B` can 113 | be statically unknown for `observations`, `actions` and `action_values`. 114 | actions: A tensor, shape `[T,B]`, of the actions across each sequence. 115 | action_values: A tensor, shape `[T+1,B,H,W,N]` of pixel control action 116 | values, where `H`, `W` are the number of pixel control cells/tasks, and 117 | `N` is the number of actions. 118 | cell_size: size of the cells used to derive the pixel based pseudo-rewards. 119 | discount_factor: discount used for learning the value function associated 120 | to the pseudo rewards; must be a scalar or a Tensor of shape [T,B]. 121 | scale: scale factor for pixels in `observations`. 122 | crop_height_dim: tuple (min_height, max_height) specifying how 123 | to crop the input observations before computing the pseudo-rewards. 124 | crop_width_dim: tuple (min_width, max_width) specifying how 125 | to crop the input observations before computing the pseudo-rewards. 126 | 127 | Returns: 128 | A namedtuple with fields: 129 | 130 | * `loss`: a tensor containing the batch of losses, shape [B]. 131 | * `extra`: a namedtuple with fields: 132 | * `target`: batch of target values for `q_tm1[a_tm1]`, shape [B]. 133 | * `td_error`: batch of temporal difference errors, shape [B]. 134 | 135 | Raises: 136 | ValueError: if the shape of `action_values` is not compatible with that of 137 | the pseudo-rewards derived from the observations. 138 | """ 139 | # Useful shapes. 140 | sequence_length, batch_size = base_ops.best_effort_shape(actions) 141 | num_actions = action_values.get_shape().as_list()[-1] 142 | height_width_q = action_values.get_shape().as_list()[2:-1] 143 | # Calculate rewards using the observations. Crop observations if appropriate. 144 | if crop_height_dim[0] is not None: 145 | h_low, h_high = crop_height_dim 146 | observations = observations[:, :, h_low:h_high, :] 147 | if crop_width_dim[0] is not None: 148 | w_low, w_high = crop_width_dim 149 | observations = observations[:, :, :, w_low:w_high] 150 | # Rescale observations by a constant factor. 151 | observations *= tf.constant(scale) 152 | # Compute pseudo-rewards and get their shape. 153 | pseudo_rewards = pixel_control_rewards(observations, cell_size) 154 | height_width = pseudo_rewards.get_shape().as_list()[2:] 155 | # Check that pseudo-rewards and Q-values are compatible in shape. 156 | if height_width != height_width_q: 157 | raise ValueError( 158 | "Pixel Control values are not compatible with the shape of the" 159 | "pseudo-rewards derived from the observation. Pseudo-rewards have shape" 160 | "{}, while Pixel Control values have shape {}".format( 161 | height_width, height_width_q)) 162 | # We now have Q(s,a) and rewards, so can calculate the n-step loss. The 163 | # QLambda loss op expects inputs of shape [T,B,N] and [T,B], but our tensors 164 | # are in a variety of incompatible shapes. 
The state-action values have 165 | # shape [T,B,H,W,N] and rewards have shape [T,B,H,W]. We can think of the 166 | # [H,W] dimensions as extra batch dimensions for the purposes of the loss 167 | # calculation, so we first collapse [B,H,W] into a single dimension. 168 | q_tm1 = tf.reshape( 169 | action_values[:-1], # [T,B,H,W,N]. 170 | [sequence_length, -1, num_actions], 171 | name="q_tm1") # [T,BHW,N]. 172 | r_t = tf.reshape( 173 | pseudo_rewards, # [T,B,H,W]. 174 | [sequence_length, -1], 175 | name="r_t") # [T,BHW]. 176 | q_t = tf.reshape( 177 | action_values[1:], # [T,B,H,W,N]. 178 | [sequence_length, -1, num_actions], 179 | name="q_t") # [T,BHW,N]. 180 | # The actions tensor is of shape [T,B], and is the same for each H and W. 181 | # We thus expand it to be same shape as the reward tensor, [T,BHW]. 182 | expanded_actions = tf.expand_dims(tf.expand_dims(actions, -1), -1) 183 | a_tm1 = tf.tile( 184 | expanded_actions, multiples=[1, 1] + height_width) # [T,B,H,W]. 185 | a_tm1 = tf.reshape(a_tm1, [sequence_length, -1]) # [T,BHW]. 186 | # We similarly expand-and-tile the discount to [T,BHW]. 187 | discount_factor = tf.convert_to_tensor(discount_factor) 188 | if discount_factor.shape.ndims == 0: 189 | pcont_t = tf.reshape(discount_factor, [1, 1]) # [1,1]. 190 | pcont_t = tf.tile(pcont_t, tf.shape(a_tm1)) # [T,BHW]. 191 | elif discount_factor.shape.ndims == 2: 192 | tiled_pcont = tf.tile( 193 | tf.expand_dims(tf.expand_dims(discount_factor, -1), -1), 194 | [1, 1] + height_width) 195 | pcont_t = tf.reshape(tiled_pcont, [sequence_length, -1]) 196 | else: 197 | raise ValueError( 198 | "The discount_factor must be a scalar or a tensor of rank 2." 199 | "instead is a tensor of shape {}".format( 200 | discount_factor.shape.as_list())) 201 | # Compute a QLambda loss of shape [T,BHW] 202 | loss, _ = action_value_ops.qlambda(q_tm1, a_tm1, r_t, pcont_t, q_t, lambda_=1) 203 | # Take sum over sequence, sum over cells. 204 | expanded_shape = [sequence_length, batch_size] + height_width 205 | spatial_loss = tf.reshape(loss, expanded_shape) # [T,B,H,W]. 206 | # Return. 207 | extra = PixelControlExtra( 208 | spatial_loss=spatial_loss, pseudo_rewards=pseudo_rewards) 209 | return base_ops.LossOutput( 210 | tf.reduce_sum(spatial_loss, axis=[0, 2, 3]), extra) # [B] 211 | -------------------------------------------------------------------------------- /trfl/value_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for value_ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | from absl.testing import parameterized 23 | import tensorflow.compat.v1 as tf 24 | import tree as nest 25 | from trfl import value_ops 26 | 27 | 28 | class TDLearningTest(tf.test.TestCase): 29 | """Tests for ValueLearning.""" 30 | 31 | def setUp(self): 32 | super(TDLearningTest, self).setUp() 33 | self.v_tm1 = tf.constant([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=tf.float32) 34 | self.v_t = tf.constant([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=tf.float32) 35 | self.pcont_t = tf.constant( 36 | [0, 0.5, 1, 0, 0.5, 1, 0, 0.5, 1], dtype=tf.float32) 37 | self.r_t = tf.constant( 38 | [-1, -1, -1, -1, -1, -1, -1, -1, -1], dtype=tf.float32) 39 | self.value_learning = value_ops.td_learning( 40 | self.v_tm1, self.r_t, self.pcont_t, self.v_t) 41 | 42 | def testRankCheck(self): 43 | v_tm1 = tf.placeholder(tf.float32, [None, None]) 44 | with self.assertRaisesRegexp( 45 | ValueError, 'TDLearning: Error in rank and/or compatibility check'): 46 | self.value_learning = value_ops.td_learning( 47 | v_tm1, self.r_t, self.pcont_t, self.v_t) 48 | 49 | def testCompatibilityCheck(self): 50 | pcont_t = tf.placeholder(tf.float32, [8]) 51 | with self.assertRaisesRegexp( 52 | ValueError, 'TDLearning: Error in rank and/or compatibility check'): 53 | self.value_learning = value_ops.td_learning( 54 | self.v_tm1, self.r_t, pcont_t, self.v_t) 55 | 56 | def testTarget(self): 57 | """Tests that target value == r_t + pcont_t * v_t.""" 58 | with self.test_session() as sess: 59 | self.assertAllClose( 60 | sess.run(self.value_learning.extra.target), 61 | [-1, -1, -1, -1, -0.5, 0, -1, 0, 1]) 62 | 63 | def testTDError(self): 64 | """Tests that td_error == target_value - v_tm1.""" 65 | with self.test_session() as sess: 66 | self.assertAllClose( 67 | sess.run(self.value_learning.extra.td_error), 68 | [-2, -2, -2, -2, -1.5, -1, -2, -1, 0]) 69 | 70 | def testLoss(self): 71 | """Tests that loss == 0.5 * td_error^2.""" 72 | with self.test_session() as sess: 73 | # Loss is 0.5 * td_error^2 74 | self.assertAllClose( 75 | sess.run(self.value_learning.loss), 76 | [2, 2, 2, 2, 1.125, 0.5, 2, 0.5, 0]) 77 | 78 | def testGradVtm1(self): 79 | """Tests that the gradients of negative loss are equal to the td_error.""" 80 | with self.test_session() as sess: 81 | # Take gradients of the negative loss, so that the tests here check the 82 | # values propagated during gradient _descent_, rather than _ascent_. 83 | gradients = tf.gradients([-self.value_learning.loss], [self.v_tm1]) 84 | grad_v_tm1 = sess.run(gradients[0]) 85 | self.assertAllClose(grad_v_tm1, [-2, -2, -2, -2, -1.5, -1, -2, -1, 0]) 86 | 87 | def testNoOtherGradients(self): 88 | """Tests no gradient propagates through things other than v_tm1.""" 89 | # Gradients are only defined for v_tm1, not any other input. 
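    # (`tf.gradients` returns `None` for inputs that the loss does not depend
    # on; the TD target is wrapped in `tf.stop_gradient`, so no gradient path
    # reaches `v_t`, `r_t` or `pcont_t`.)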
90 | gradients = tf.gradients([self.value_learning.loss], 91 | [self.v_t, self.r_t, self.pcont_t]) 92 | self.assertEqual(gradients, [None] * len(gradients)) 93 | 94 | 95 | class TDLambdaTest(parameterized.TestCase, tf.test.TestCase): 96 | 97 | def _setUp_td_loss(self, gae_lambda=1, sequence_length=4, batch_size=2): 98 | t, b = sequence_length, batch_size 99 | self._state_values = tf.placeholder(tf.float32, shape=(t, b)) 100 | self._rewards = tf.placeholder(tf.float32, shape=(t, b)) 101 | self._pcontinues = tf.placeholder(tf.float32, shape=(t, b)) 102 | self._bootstrap_value = tf.placeholder(tf.float32, shape=(b,)) 103 | loss, (td, discounted_returns) = value_ops.td_lambda( 104 | state_values=self._state_values, 105 | rewards=self._rewards, 106 | pcontinues=self._pcontinues, 107 | bootstrap_value=self._bootstrap_value, 108 | lambda_=gae_lambda) 109 | self._loss = loss 110 | self._temporal_differences = td 111 | self._discounted_returns = discounted_returns 112 | 113 | @parameterized.parameters( 114 | (1,), 115 | (0.9,),) 116 | def testShapeInference(self, gae_lambda): 117 | sequence_length = 4 118 | batch_size = 2 119 | self._setUp_td_loss( 120 | gae_lambda, sequence_length=sequence_length, batch_size=batch_size) 121 | sequence_batch_shape = tf.TensorShape([sequence_length, batch_size]) 122 | batch_shape = tf.TensorShape(batch_size) 123 | self.assertEqual(self._discounted_returns.get_shape(), sequence_batch_shape) 124 | self.assertEqual(self._temporal_differences.get_shape(), 125 | sequence_batch_shape) 126 | self.assertEqual(self._loss.get_shape(), batch_shape) 127 | 128 | @parameterized.named_parameters( 129 | ('Length', None, 4), 130 | ('Batch', 5, None), 131 | ('BatchAndLength', None, None),) 132 | def testShapeInferenceDynamic(self, sequence_length, batch_size): 133 | self._setUp_td_loss( 134 | sequence_length=sequence_length, batch_size=batch_size, gae_lambda=1.) 135 | t, b = sequence_length, batch_size 136 | 137 | self.assertEqual(self._discounted_returns.get_shape().as_list(), [t, b]) 138 | self.assertEqual(self._temporal_differences.get_shape().as_list(), [t, b]) 139 | self.assertEqual(self._loss.get_shape().as_list(), [b]) 140 | 141 | @parameterized.parameters( 142 | (1,), 143 | (0.9,),) 144 | def testInvalidGradients(self, gae_lambda): 145 | self._setUp_td_loss(gae_lambda=gae_lambda) 146 | ins = nest.flatten([self._rewards, self._pcontinues, self._bootstrap_value]) 147 | outs = [None] * len(ins) 148 | 149 | self.assertAllEqual(tf.gradients(self._loss, ins), outs) 150 | 151 | def testGradientsLoss(self): 152 | self._setUp_td_loss() 153 | gradient = tf.gradients(self._loss, self._state_values)[0] 154 | self.assertEqual(gradient.get_shape(), self._state_values.get_shape()) 155 | 156 | 157 | class GeneralizedLambdaReturnsTest(parameterized.TestCase, tf.test.TestCase): 158 | 159 | @parameterized.parameters(0.25, 0.5, 1) 160 | def testGeneralizedLambdaReturns(self, lambda_): 161 | """Tests the module-level function generalized_lambda_returns.""" 162 | 163 | # Sequence length 2, batch size 1. 164 | state_values = tf.constant([[0.2], [0.3]], dtype=tf.float32) 165 | rewards = tf.constant([[0.4], [0.5]], dtype=tf.float32) 166 | pcontinues = tf.constant([[0.9], [0.8]], dtype=tf.float32) 167 | bootstrap_value = tf.constant([0.1], dtype=tf.float32) 168 | 169 | discounted_returns = value_ops.generalized_lambda_returns( 170 | rewards, pcontinues, state_values, bootstrap_value, lambda_) 171 | 172 | # Manually calculate the discounted returns. 
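    # return1 bootstraps from bootstrap_value: r[1] + pcontinues[1] * 0.1.
    # return0 blends the one-step and multistep targets with weight lambda_:
    # r[0] + pcontinues[0] * (lambda_ * return1 + (1 - lambda_) * V(s_1)).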
173 | return1 = 0.5 + 0.8 * 0.1 174 | return0 = 0.4 + 0.9 * (lambda_ * return1 + (1 - lambda_) * 0.3) 175 | 176 | with self.test_session() as sess: 177 | self.assertAllClose(sess.run(discounted_returns), [[return0], [return1]]) 178 | 179 | 180 | class QVMAXTest(tf.test.TestCase): 181 | """Tests for the QVMAX loss.""" 182 | 183 | def setUp(self): 184 | super(QVMAXTest, self).setUp() 185 | self.v_tm1 = tf.constant([1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=tf.float32) 186 | self.pcont_t = tf.constant( 187 | [0, 0.5, 1, 0, 0.5, 1, 0, 0.5, 1], dtype=tf.float32) 188 | self.r_t = tf.constant( 189 | [-1, -1, -1, -1, -1, -1, -1, -1, -1], dtype=tf.float32) 190 | self.q_t = tf.constant( 191 | [[0, -1], [-2, 0], [0, -3], [1, 0], [1, 1], 192 | [0, 1], [1, 2], [2, -2], [2, 2]], dtype=tf.float32) 193 | self.loss_op, self.extra_ops = value_ops.qv_max( 194 | self.v_tm1, self.r_t, self.pcont_t, self.q_t) 195 | 196 | def testRankCheck(self): 197 | v_tm1 = tf.placeholder(tf.float32, [None, None]) 198 | with self.assertRaisesRegexp( 199 | ValueError, 'QVMAX: Error in rank and/or compatibility check'): 200 | value_ops.qv_max(v_tm1, self.r_t, self.pcont_t, self.q_t) 201 | 202 | def testCompatibilityCheck(self): 203 | pcont_t = tf.placeholder(tf.float32, [8]) 204 | with self.assertRaisesRegexp( 205 | ValueError, 'QVMAX: Error in rank and/or compatibility check'): 206 | value_ops.qv_max(self.v_tm1, self.r_t, pcont_t, self.q_t) 207 | 208 | def testTarget(self): 209 | """Tests that target value == r_t + pcont_t * max q_t.""" 210 | with self.test_session() as sess: 211 | self.assertAllClose( 212 | sess.run(self.extra_ops.target), 213 | [-1, -1, -1, -1, -0.5, 0, -1, 0, 1]) 214 | 215 | def testTDError(self): 216 | """Tests that td_error == target_value - v_tm1.""" 217 | with self.test_session() as sess: 218 | self.assertAllClose( 219 | sess.run(self.extra_ops.td_error), 220 | [-2, -2, -2, -2, -1.5, -1, -2, -1, 0]) 221 | 222 | def testLoss(self): 223 | """Tests that loss == 0.5 * td_error^2.""" 224 | with self.test_session() as sess: 225 | # Loss is 0.5 * td_error^2 226 | self.assertAllClose( 227 | sess.run(self.loss_op), 228 | [2, 2, 2, 2, 1.125, 0.5, 2, 0.5, 0]) 229 | 230 | def testGradVtm1(self): 231 | """Tests that the gradients of negative loss are equal to the td_error.""" 232 | with self.test_session() as sess: 233 | # Take gradients of the negative loss, so that the tests here check the 234 | # values propagated during gradient _descent_, rather than _ascent_. 235 | gradients = tf.gradients([-self.loss_op], [self.v_tm1]) 236 | grad_v_tm1 = sess.run(gradients[0]) 237 | self.assertAllClose(grad_v_tm1, [-2, -2, -2, -2, -1.5, -1, -2, -1, 0]) 238 | 239 | def testNoOtherGradients(self): 240 | """Tests no gradient propagates through things other than v_tm1.""" 241 | # Gradients are only defined for v_tm1, not any other input. 242 | gradients = tf.gradients([self.loss_op], 243 | [self.q_t, self.r_t, self.pcont_t]) 244 | self.assertEqual(gradients, [None] * len(gradients)) 245 | 246 | 247 | if __name__ == '__main__': 248 | tf.test.main() 249 | -------------------------------------------------------------------------------- /trfl/retrace_ops_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | # Copyright 2018 The trfl Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================ 16 | """Tests for retrace_ops.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | import numpy as np 24 | from six.moves import xrange # pylint: disable=redefined-builtin 25 | import tensorflow.compat.v1 as tf 26 | from trfl import retrace_ops 27 | 28 | 29 | class RetraceOpsTest(tf.test.TestCase): 30 | """Tests for `Retrace` ops.""" 31 | 32 | def setUp(self): 33 | """Defines example data for, and an expected result of, Retrace operations. 34 | 35 | The example data comprises a minibatch of two sequences of four 36 | consecutive timesteps, allowing the data to be interpreted by Retrace 37 | as three successive transitions. 38 | """ 39 | super(RetraceOpsTest, self).setUp() 40 | 41 | ### Example input data: 42 | 43 | self.lambda_ = 0.9 44 | self.qs = [ 45 | [[2.2, 3.2, 4.2], 46 | [5.2, 6.2, 7.2]], 47 | 48 | [[7.2, 6.2, 5.2], 49 | [4.2, 3.2, 2.2]], 50 | 51 | [[3.2, 5.2, 7.2], 52 | [4.2, 6.2, 9.2]], 53 | 54 | [[2.2, 8.2, 4.2], 55 | [9.2, 1.2, 8.2]]] 56 | self.targnet_qs = [ 57 | [[2., 3., 4.], 58 | [5., 6., 7.]], 59 | 60 | [[7., 6., 5.], 61 | [4., 3., 2.]], 62 | 63 | [[3., 5., 7.], 64 | [4., 6., 9.]], 65 | 66 | [[2., 8., 4.], 67 | [9., 1., 8.]]] 68 | self.actions = [ 69 | [2, 70 | 0], [1, 71 | 2], [0, 72 | 1], [2, 73 | 0]] 74 | self.rewards = [ 75 | [1.9, 76 | 2.9], [3.9, 77 | 4.9], [5.9, 78 | 6.9], [np.nan, # nan marks entries we should never use. 79 | np.nan]] 80 | self.pcontinues = [ 81 | [0.8, 82 | 0.9], [0.7, 83 | 0.8], [0.6, 84 | 0.5], [np.nan, 85 | np.nan]] 86 | self.target_policy_probs = [ 87 | [[np.nan] * 3, 88 | [np.nan] * 3], 89 | 90 | [[0.41, 0.28, 0.31], 91 | [0.19, 0.77, 0.04]], 92 | 93 | [[0.22, 0.44, 0.34], 94 | [0.14, 0.25, 0.61]], 95 | 96 | [[0.16, 0.72, 0.12], 97 | [0.33, 0.30, 0.37]]] 98 | self.behaviour_policy_probs = [ 99 | [np.nan, 100 | np.nan], [0.85, 101 | 0.86], [0.87, 102 | 0.88], [0.89, 103 | 0.84]] 104 | 105 | ### Expected results of Retrace as applied to the above: 106 | 107 | # NOTE: To keep the test code compact, we don't use the example data when 108 | # manually computing our expected results, but instead duplicate their 109 | # values explictly in those calculations. Some patterns in the values can 110 | # help you track who's who: for example, note that target network Q values 111 | # are integers, whilst learning network Q values all end in 0.2. 112 | 113 | # In a purely theoretical setting, we would compute the quantity we call 114 | # the "trace" using this recurrence relation: 115 | # 116 | # ΔQ_tm1 = δ_tm1 + λγπ(a_t | s_t)/μ(a_t | s_t) ⋅ ΔQ_t 117 | # δ_tm1 = r_t + γ𝔼_π[Q(s_t, .)] - Q(s_tm1, a_tm1) 118 | # 119 | # In a target network setting, you might rewrite ΔQ_t as ΔQ'_t, indicating 120 | # that this value is the next-timestep trace as computed when all 121 | # Q(s_tm1, a_tm1) terms (in δ_t, δ_t+1, ...) come from the target network, 122 | # not the learning network. 
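    # (The Retrace paper truncates the per-step importance ratio in the trace
    # coefficient, c_t = λ·min(1, π(a_t | s_t)/μ(a_t | s_t)); in this example
    # every ratio π/μ is below one, so the untruncated form written above
    # yields the same numbers.)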
123 | # 124 | # To generate our collection of expected outputs, we'll first compute 125 | # "ΔQ'_tm1" (the "target network trace") at all timesteps. 126 | # 127 | # We start at the end of the sequence and work backward, like the 128 | # implementation does. 129 | targ_trace = np.zeros((3, 2)) 130 | targ_trace[2, 0] = (5.9 + 0.6*(0.16*2 + 0.72*8 + 0.12*4) - 3) # δ_tm1[2,0] 131 | targ_trace[2, 1] = (6.9 + 0.5*(0.33*9 + 0.30*1 + 0.37*8) - 6) # δ_tm1[2,1] 132 | 133 | targ_trace[1, 0] = (3.9 + 0.7*(0.22*3 + 0.44*5 + 0.34*7) - 6 + # δ_tm1[1,0] 134 | 0.9*0.7*0.22/0.87 * targ_trace[2, 0]) 135 | targ_trace[1, 1] = (4.9 + 0.8*(0.14*4 + 0.25*6 + 0.61*9) - 2 + # δ_tm1[1,1] 136 | 0.9*0.8*0.25/0.88 * targ_trace[2, 1]) 137 | 138 | targ_trace[0, 0] = (1.9 + 0.8*(0.41*7 + 0.28*6 + 0.31*5) - 4 + # δ_tm1[0,0] 139 | 0.9*0.8*0.28/0.85 * targ_trace[1, 0]) 140 | targ_trace[0, 1] = (2.9 + 0.9*(0.19*4 + 0.77*3 + 0.04*2) - 5 + # δ_tm1[0,1] 141 | 0.9*0.9*0.04/0.86 * targ_trace[1, 1]) 142 | 143 | # We can evaluate target Q values by adding targ_trace to single step 144 | # returns. 145 | target_q = np.zeros((3, 2)) 146 | target_q[2, 0] = (5.9 + 0.6*(0.16*2 + 0.72*8 + 0.12*4)) 147 | target_q[2, 1] = (6.9 + 0.5*(0.33*9 + 0.30*1 + 0.37*8)) 148 | 149 | target_q[1, 0] = (3.9 + 0.7*(0.22*3 + 0.44*5 + 0.34*7) + 150 | 0.9*0.7*0.22/0.87 * targ_trace[2, 0]) 151 | target_q[1, 1] = (4.9 + 0.8*(0.14*4 + 0.25*6 + 0.61*9) + 152 | 0.9*0.8*0.25/0.88 * targ_trace[2, 1]) 153 | 154 | target_q[0, 0] = (1.9 + 0.8*(0.41*7 + 0.28*6 + 0.31*5) + 155 | 0.9*0.8*0.28/0.85 * targ_trace[1, 0]) 156 | target_q[0, 1] = (2.9 + 0.9*(0.19*4 + 0.77*3 + 0.04*2) + 157 | 0.9*0.9*0.04/0.86 * targ_trace[1, 1]) 158 | 159 | # Now we can compute the "official" trace (ΔQ_tm1), which involves the 160 | # learning network. The only difference from the "target network trace" 161 | # calculations is the Q(s_tm1, a_tm1) terms we use: 162 | trace = np.zeros((3, 2)) # ↓ Q(s_tm1, a_tm1) 163 | trace[2, 0] = target_q[2, 0] - 3.2 # δ_tm1[2,0] 164 | trace[2, 1] = target_q[2, 1] - 6.2 # δ_tm1[2,1] 165 | 166 | trace[1, 0] = target_q[1, 0] - 6.2 # δ_tm1[1,0] 167 | trace[1, 1] = target_q[1, 1] - 2.2 # δ_tm1[1,1] 168 | 169 | trace[0, 0] = target_q[0, 0] - 4.2 # δ_tm1[0,0] 170 | trace[0, 1] = target_q[0, 1] - 5.2 # δ_tm1[0,0] 171 | 172 | self.expected_result = 0.5 * np.square(trace) 173 | self.target_q = target_q 174 | 175 | def testRetraceThreeTimeSteps(self): 176 | """Subject Retrace to a two-sequence, three-timestep minibatch.""" 177 | retrace = retrace_ops.retrace( 178 | self.lambda_, self.qs, self.targnet_qs, self.actions, self.rewards, 179 | self.pcontinues, self.target_policy_probs, self.behaviour_policy_probs) 180 | 181 | with self.test_session() as sess: 182 | self.assertAllClose(sess.run(retrace.loss), self.expected_result) 183 | 184 | def _get_retrace_core(self): 185 | """Constructs a tf subgraph from `retrace_core` op. 186 | 187 | A retrace core namedtuple is built from a two-sequence, three-timestep 188 | input minibatch. 189 | 190 | Returns: 191 | Tuple of size 3 containing non-differentiable inputs, differentiable 192 | inputs and retrace_core namedtuple. 193 | """ 194 | # Here we essentially replicate the preprocessing that `retrace` does 195 | # as it constructs the inputs to `retrace_core`. 196 | # These ops must be Tensors so that we can use them in the 197 | # `testNoOtherGradients` unit test. 
TensorFlow can only compute gradients 198 | # with respect to other parts of the graph 199 | lambda_ = tf.constant(self.lambda_) 200 | q_tm1 = tf.constant(self.qs[:3]) 201 | a_tm1 = tf.constant(self.actions[:3]) 202 | r_t = tf.constant(self.rewards[:3]) 203 | pcont_t = tf.constant(self.pcontinues[:3]) 204 | target_policy_t = tf.constant(self.target_policy_probs[1:4]) 205 | behaviour_policy_t = tf.constant(self.behaviour_policy_probs[1:4]) 206 | targnet_q_t = tf.constant(self.targnet_qs[1:4]) 207 | a_t = tf.constant(self.actions[1:4]) 208 | static_args = [lambda_, a_tm1, r_t, pcont_t, target_policy_t, 209 | behaviour_policy_t, targnet_q_t, a_t] 210 | diff_args = [q_tm1] 211 | return (static_args, diff_args, 212 | retrace_ops.retrace_core(lambda_, q_tm1, a_tm1, r_t, pcont_t, 213 | target_policy_t, behaviour_policy_t, 214 | targnet_q_t, a_t)) 215 | 216 | def testRetraceCoreTargetQThreeTimeSteps(self): 217 | """Tests whether retrace_core evaluates correct targets for regression.""" 218 | _, _, retrace = self._get_retrace_core() 219 | with self.test_session() as sess: 220 | self.assertAllClose(sess.run(retrace.extra.target), self.target_q) 221 | 222 | def testRetraceCoreLossThreeTimeSteps(self): 223 | """Tests whether retrace_core evaluates correct losses.""" 224 | _, _, retrace = self._get_retrace_core() 225 | with self.test_session() as sess: 226 | self.assertAllClose(sess.run(retrace.loss), self.expected_result) 227 | 228 | def testNoOtherGradients(self): 229 | """Tests no gradient propagates through things other than q_tm1.""" 230 | static_args, _, retrace = self._get_retrace_core() 231 | gradients = tf.gradients([retrace.loss], static_args) 232 | self.assertEqual(gradients, [None] * len(gradients)) 233 | 234 | def testMovingNetworkGradientIsEvaluated(self): 235 | """Tests that gradients are evaluated w.r.t. q_tm1.""" 236 | _, diff_args, retrace = self._get_retrace_core() 237 | gradients = tf.gradients([retrace.loss], diff_args) 238 | for gradient in gradients: 239 | self.assertNotEqual(gradient, None) 240 | 241 | def testRetraceHatesBadlyRankedInputs(self): 242 | """Ensure Retrace notices inputs with the wrong rank.""" 243 | # No problems if we create a Retrace using correctly-ranked arguments. 244 | proper_args = [self.lambda_, self.qs, self.targnet_qs, self.actions, 245 | self.rewards, self.pcontinues, self.target_policy_probs, 246 | self.behaviour_policy_probs] 247 | retrace_ops.retrace(*proper_args) 248 | 249 | # Now make a local copy of the args and try modifying each element to have 250 | # an inappropriate rank. We should get an error each time. 251 | for i in xrange(len(proper_args)): 252 | bad_args = list(proper_args) 253 | bad_args[i] = [bad_args[i]] 254 | with self.assertRaises(ValueError): 255 | retrace_ops.retrace(*bad_args) 256 | 257 | 258 | if __name__ == '__main__': 259 | tf.test.main() 260 | -------------------------------------------------------------------------------- /trfl/vtrace_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for vtrace_ops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # Dependency imports 22 | from absl.testing import parameterized 23 | import numpy as np 24 | import tensorflow.compat.v1 as tf 25 | from trfl import vtrace_ops 26 | 27 | 28 | def _shaped_arange(*shape): 29 | """Runs np.arange, converts to float and reshapes.""" 30 | return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) 31 | 32 | 33 | def _softmax(logits): 34 | """Applies softmax non-linearity on inputs.""" 35 | return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True) 36 | 37 | 38 | def _ground_truth_calculation(discounts, log_rhos, rewards, values, 39 | bootstrap_value, clip_rho_threshold, 40 | clip_pg_rho_threshold): 41 | """Calculates the ground truth for V-trace in Python/Numpy.""" 42 | vs = [] 43 | seq_len = len(discounts) 44 | rhos = np.exp(log_rhos) 45 | cs = np.minimum(rhos, 1.0) 46 | clipped_rhos = rhos 47 | if clip_rho_threshold: 48 | clipped_rhos = np.minimum(rhos, clip_rho_threshold) 49 | clipped_pg_rhos = rhos 50 | if clip_pg_rho_threshold: 51 | clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold) 52 | 53 | # This is a very inefficient way to calculate the V-trace ground truth. 54 | # We calculate it this way because it is close to the mathematical notation of 55 | # V-trace. 56 | # v_s = V(x_s) 57 | # + \sum^{T-1}_{t=s} \gamma^{t-s} 58 | # * \prod_{i=s}^{t-1} c_i 59 | # * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t)) 60 | # Note that when we take the product over c_i, we write `s:t` as the notation 61 | # of the paper is inclusive of the `t-1`, but Python is exclusive. 62 | # Also note that np.prod([]) == 1. 63 | values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], axis=0) 64 | for s in range(seq_len): 65 | v_s = np.copy(values[s]) # Very important copy. 
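      # (Without the copy, `values[s]` would be a view and the in-place `+=`
      # below would also mutate the caller's `values` array.)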
66 | for t in range(s, seq_len): 67 | v_s += ( 68 | np.prod(discounts[s:t], axis=0) * np.prod(cs[s:t - 1], 69 | axis=0) * clipped_rhos[t] * 70 | (rewards[t] + discounts[t] * values_t_plus_1[t + 1] - values[t])) 71 | vs.append(v_s) 72 | vs = np.stack(vs, axis=0) 73 | pg_advantages = ( 74 | clipped_pg_rhos * (rewards + discounts * np.concatenate( 75 | [vs[1:], bootstrap_value[None, :]], axis=0) - values)) 76 | 77 | return vtrace_ops.VTraceReturns(vs=vs, pg_advantages=pg_advantages) 78 | 79 | 80 | class LogProbsFromLogitsAndActionsTest(tf.test.TestCase, 81 | parameterized.TestCase): 82 | 83 | @parameterized.named_parameters(('Batch1', 1), ('Batch2', 2)) 84 | def testLogProbsFromLogitsAndActions(self, batch_size): 85 | """Tests log_probs_from_logits_and_actions.""" 86 | seq_len = 7 87 | num_actions = 3 88 | 89 | policy_logits = _shaped_arange(seq_len, batch_size, num_actions) + 10 90 | actions = np.random.randint( 91 | 0, num_actions - 1, size=(seq_len, batch_size), dtype=np.int32) 92 | 93 | action_log_probs_tensor = vtrace_ops.log_probs_from_logits_and_actions( 94 | policy_logits, actions) 95 | 96 | # Ground Truth 97 | # Using broadcasting to create a mask that indexes action logits 98 | action_index_mask = actions[..., None] == np.arange(num_actions) 99 | 100 | def index_with_mask(array, mask): 101 | return array[mask].reshape(*array.shape[:-1]) 102 | 103 | # Note: Normally log(softmax) is not a good idea because it's not 104 | # numerically stable. However, in this test we have well-behaved values. 105 | ground_truth_v = index_with_mask( 106 | np.log(_softmax(policy_logits)), action_index_mask) 107 | 108 | with self.test_session() as session: 109 | self.assertAllClose(ground_truth_v, session.run(action_log_probs_tensor)) 110 | 111 | 112 | class VtraceTest(tf.test.TestCase, parameterized.TestCase): 113 | 114 | @parameterized.named_parameters(('Batch1', 1), ('Batch5', 5)) 115 | def testVTrace(self, batch_size): 116 | """Tests V-trace against ground truth data calculated in python.""" 117 | seq_len = 5 118 | 119 | values = { 120 | # Note that this is only for testing purposes using well-formed inputs. 121 | # In practice we'd be more careful about taking log() of arbitrary 122 | # quantities. 123 | 'log_rhos': 124 | np.log((_shaped_arange(seq_len, batch_size)) / batch_size / 125 | seq_len + 1), 126 | # T, B where B_i: [0.9 / (i+1)] * T 127 | 'discounts': 128 | np.array([[0.9 / (b + 1) 129 | for b in range(batch_size)] 130 | for _ in range(seq_len)]), 131 | 'rewards': 132 | _shaped_arange(seq_len, batch_size), 133 | 'values': 134 | _shaped_arange(seq_len, batch_size) / batch_size, 135 | 'bootstrap_value': 136 | _shaped_arange(batch_size) + 1.0, 137 | 'clip_rho_threshold': 138 | 3.7, 139 | 'clip_pg_rho_threshold': 140 | 2.2, 141 | } 142 | 143 | output = vtrace_ops.vtrace_from_importance_weights(**values) 144 | 145 | with self.test_session() as session: 146 | output_v = session.run(output) 147 | 148 | ground_truth_v = _ground_truth_calculation(**values) 149 | for a, b in zip(ground_truth_v, output_v): 150 | self.assertAllClose(a, b) 151 | 152 | @parameterized.named_parameters(('Batch1', 1), ('Batch2', 2)) 153 | def testVTraceFromLogits(self, batch_size): 154 | """Tests V-trace calculated from logits.""" 155 | seq_len = 5 156 | num_actions = 3 157 | clip_rho_threshold = None # No clipping. 158 | clip_pg_rho_threshold = None # No clipping. 159 | 160 | # Intentionally leaving shapes unspecified to test if V-trace can 161 | # deal with that. 
162 | placeholders = { 163 | # T, B, NUM_ACTIONS 164 | 'behaviour_policy_logits': 165 | tf.placeholder(dtype=tf.float32, shape=[None, None, None]), 166 | # T, B, NUM_ACTIONS 167 | 'target_policy_logits': 168 | tf.placeholder(dtype=tf.float32, shape=[None, None, None]), 169 | 'actions': 170 | tf.placeholder(dtype=tf.int32, shape=[None, None]), 171 | 'discounts': 172 | tf.placeholder(dtype=tf.float32, shape=[None, None]), 173 | 'rewards': 174 | tf.placeholder(dtype=tf.float32, shape=[None, None]), 175 | 'values': 176 | tf.placeholder(dtype=tf.float32, shape=[None, None]), 177 | 'bootstrap_value': 178 | tf.placeholder(dtype=tf.float32, shape=[None]), 179 | } 180 | 181 | from_logits_output = vtrace_ops.vtrace_from_logits( 182 | clip_rho_threshold=clip_rho_threshold, 183 | clip_pg_rho_threshold=clip_pg_rho_threshold, 184 | **placeholders) 185 | 186 | target_log_probs = vtrace_ops.log_probs_from_logits_and_actions( 187 | placeholders['target_policy_logits'], placeholders['actions']) 188 | behaviour_log_probs = vtrace_ops.log_probs_from_logits_and_actions( 189 | placeholders['behaviour_policy_logits'], placeholders['actions']) 190 | log_rhos = target_log_probs - behaviour_log_probs 191 | ground_truth = (log_rhos, behaviour_log_probs, target_log_probs) 192 | 193 | values = { 194 | 'behaviour_policy_logits': 195 | _shaped_arange(seq_len, batch_size, num_actions), 196 | 'target_policy_logits': 197 | _shaped_arange(seq_len, batch_size, num_actions), 198 | 'actions': 199 | np.random.randint(0, num_actions - 1, size=(seq_len, batch_size)), 200 | 'discounts': 201 | np.array( # T, B where B_i: [0.9 / (i+1)] * T 202 | [[0.9 / (b + 1) 203 | for b in range(batch_size)] 204 | for _ in range(seq_len)]), 205 | 'rewards': 206 | _shaped_arange(seq_len, batch_size), 207 | 'values': 208 | _shaped_arange(seq_len, batch_size) / batch_size, 209 | 'bootstrap_value': 210 | _shaped_arange(batch_size) + 1.0, # B 211 | } 212 | 213 | feed_dict = {placeholders[k]: v for k, v in values.items()} 214 | with self.test_session() as session: 215 | from_logits_output_v = session.run( 216 | from_logits_output, feed_dict=feed_dict) 217 | (ground_truth_log_rhos, ground_truth_behaviour_action_log_probs, 218 | ground_truth_target_action_log_probs) = session.run( 219 | ground_truth, feed_dict=feed_dict) 220 | 221 | # Calculate V-trace using the ground truth logits. 
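      # (More precisely: from the log importance weights `log_rhos` derived
      # from the ground-truth log-probabilities computed above.)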
222 | from_iw = vtrace_ops.vtrace_from_importance_weights( 223 | log_rhos=ground_truth_log_rhos, 224 | discounts=values['discounts'], 225 | rewards=values['rewards'], 226 | values=values['values'], 227 | bootstrap_value=values['bootstrap_value'], 228 | clip_rho_threshold=clip_rho_threshold, 229 | clip_pg_rho_threshold=clip_pg_rho_threshold) 230 | 231 | with self.test_session() as session: 232 | from_iw_v = session.run(from_iw) 233 | 234 | self.assertAllClose(from_iw_v.vs, from_logits_output_v.vs) 235 | self.assertAllClose(from_iw_v.pg_advantages, 236 | from_logits_output_v.pg_advantages) 237 | self.assertAllClose(ground_truth_behaviour_action_log_probs, 238 | from_logits_output_v.behaviour_action_log_probs) 239 | self.assertAllClose(ground_truth_target_action_log_probs, 240 | from_logits_output_v.target_action_log_probs) 241 | self.assertAllClose(ground_truth_log_rhos, from_logits_output_v.log_rhos) 242 | 243 | def testHigherRankInputsForIW(self): 244 | """Checks support for additional dimensions in inputs.""" 245 | placeholders = { 246 | 'log_rhos': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]), 247 | 'discounts': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]), 248 | 'rewards': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), 249 | 'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), 250 | 'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None, 42]) 251 | } 252 | output = vtrace_ops.vtrace_from_importance_weights(**placeholders) 253 | self.assertEqual(output.vs.shape.as_list()[-1], 42) 254 | 255 | def testInconsistentRankInputsForIW(self): 256 | """Test one of many possible errors in shape of inputs.""" 257 | placeholders = { 258 | 'log_rhos': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]), 259 | 'discounts': tf.placeholder(dtype=tf.float32, shape=[None, None, 1]), 260 | 'rewards': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), 261 | 'values': tf.placeholder(dtype=tf.float32, shape=[None, None, 42]), 262 | # Should be [None, 42]. 263 | 'bootstrap_value': tf.placeholder(dtype=tf.float32, shape=[None]) 264 | } 265 | with self.assertRaisesRegexp(ValueError, 'must have rank 2'): 266 | vtrace_ops.vtrace_from_importance_weights(**placeholders) 267 | 268 | 269 | if __name__ == '__main__': 270 | tf.test.main() 271 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /trfl/value_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """TensorFlow ops for state value learning.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | 23 | # Dependency imports 24 | import tensorflow.compat.v1 as tf 25 | from trfl import base_ops 26 | from trfl import sequence_ops 27 | 28 | 29 | TDExtra = collections.namedtuple("td_extra", ["target", "td_error"]) 30 | TDLambdaExtra = collections.namedtuple( 31 | "td_lambda_extra", ["temporal_differences", "discounted_returns"]) 32 | 33 | 34 | def td_learning(v_tm1, r_t, pcont_t, v_t, name="TDLearning"): 35 | """Implements the TD(0)-learning loss as a TensorFlow op. 36 | 37 | The TD loss is `0.5` times the squared difference between `v_tm1` and 38 | the target `r_t + pcont_t * v_t`. 39 | 40 | See "Learning to Predict by the Methods of Temporal Differences" by Sutton. 41 | (https://link.springer.com/article/10.1023/A:1022633531479). 42 | 43 | Args: 44 | v_tm1: Tensor holding values at previous timestep, shape `[B]`. 45 | r_t: Tensor holding rewards, shape `[B]`. 46 | pcont_t: Tensor holding pcontinue values, shape `[B]`. 47 | v_t: Tensor holding values at current timestep, shape `[B]`. 48 | name: name to prefix ops created by this function. 49 | 50 | Returns: 51 | A namedtuple with fields: 52 | 53 | * `loss`: a tensor containing the batch of losses, shape `[B]`. 54 | * `extra`: a namedtuple with fields: 55 | * `target`: batch of target values for `v_tm1`, shape `[B]`. 56 | * `td_error`: batch of temporal difference errors, shape `[B]`. 57 | """ 58 | # Rank and compatibility checks. 59 | base_ops.wrap_rank_shape_assert([[v_tm1, v_t, r_t, pcont_t]], [1], name) 60 | 61 | # TD(0)-learning op. 62 | with tf.name_scope(name, values=[v_tm1, r_t, pcont_t, v_t]): 63 | 64 | # Build target. 65 | target = tf.stop_gradient(r_t + pcont_t * v_t) 66 | 67 | # Temporal difference error and loss. 68 | # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. 69 | td_error = target - v_tm1 70 | loss = 0.5 * tf.square(td_error) 71 | return base_ops.LossOutput(loss, TDExtra(target, td_error)) 72 | 73 | 74 | def generalized_lambda_returns(rewards, 75 | pcontinues, 76 | values, 77 | bootstrap_value, 78 | lambda_=1, 79 | name="generalized_lambda_returns"): 80 | """Computes lambda-returns along a batch of (chunks of) trajectories. 81 | 82 | For lambda=1 these will be multistep returns looking ahead from each 83 | state to the end of the chunk, where bootstrap_value is used. 
If you pass an 84 | entire trajectory and zeros for bootstrap_value, this is just the Monte-Carlo 85 | return / TD(1) target. 86 | 87 | For lambda=0 these are one-step TD(0) targets. 88 | 89 | For inbetween values of lambda these are lambda-returns / TD(lambda) targets, 90 | except that traces are always cut off at the end of the chunk, since we can't 91 | see returns beyond then. If you pass an entire trajectory with zeros for 92 | bootstrap_value though, then they're plain TD(lambda) targets. 93 | 94 | lambda can also be a tensor of values in [0, 1], determining the mix of 95 | bootstrapping vs further accumulation of multistep returns at each timestep. 96 | This can be used to implement Retrace and other algorithms. See 97 | `sequence_ops.multistep_forward_view` for more info on this. Another way to 98 | think about the end-of-chunk cutoff is that lambda is always effectively zero 99 | on the timestep after the end of the chunk, since at the end of the chunk we 100 | rely entirely on bootstrapping and can't accumulate returns looking further 101 | into the future. 102 | 103 | The sequences in the tensors should be aligned such that an agent in a state 104 | with value `V` transitions into another state with value `V'`, receiving 105 | reward `r` and pcontinue `p`. Then `V`, `r` and `p` are all at the same index 106 | `i` in the corresponding tensors. `V'` is at index `i+1`, or in the 107 | `bootstrap_value` tensor if `i == T`. 108 | 109 | Subtracting `values` from these lambda-returns will yield estimates of the 110 | advantage function which can be used for both the policy gradient loss and 111 | the baseline value function loss in A3C / GAE. 112 | 113 | Args: 114 | rewards: 2-D Tensor with shape `[T, B]`. 115 | pcontinues: 2-D Tensor with shape `[T, B]`. 116 | values: 2-D Tensor containing estimates of the state values for timesteps 117 | 0 to `T-1`. Shape `[T, B]`. 118 | bootstrap_value: 1-D Tensor containing an estimate of the value of the 119 | final state at time `T`, used for bootstrapping the target n-step 120 | returns. Shape `[B]`. 121 | lambda_: an optional scalar or 2-D Tensor with shape `[T, B]`. 122 | name: Customises the name_scope for this op. 123 | 124 | Returns: 125 | 2-D Tensor with shape `[T, B]` 126 | """ 127 | values.get_shape().assert_has_rank(2) 128 | rewards.get_shape().assert_has_rank(2) 129 | pcontinues.get_shape().assert_has_rank(2) 130 | bootstrap_value.get_shape().assert_has_rank(1) 131 | scoped_values = [rewards, pcontinues, values, bootstrap_value, lambda_] 132 | with tf.name_scope(name, values=scoped_values): 133 | if lambda_ == 1: 134 | # This is actually equivalent to the branch below, just an optimisation 135 | # to avoid unnecessary work in this case: 136 | return sequence_ops.scan_discounted_sum( 137 | rewards, 138 | pcontinues, 139 | initial_value=bootstrap_value, 140 | reverse=True, 141 | back_prop=False, 142 | name="multistep_returns") 143 | else: 144 | v_tp1 = tf.concat( 145 | axis=0, values=[values[1:, :], 146 | tf.expand_dims(bootstrap_value, 0)]) 147 | # `back_prop=False` prevents gradients flowing into values and 148 | # bootstrap_value, which is what you want when using the bootstrapped 149 | # lambda-returns in an update as targets for values. 
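    # multistep_forward_view implements the recurrence
    #   result[t] = rewards[t] + pcontinues[t] * (
    #       (1 - lambda_) * v_tp1[t] + lambda_ * result[t + 1]),
    # i.e. the TD(lambda) / lambda-return target described in the docstring.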
150 | return sequence_ops.multistep_forward_view( 151 | rewards, 152 | pcontinues, 153 | v_tp1, 154 | lambda_, 155 | back_prop=False, 156 | name="generalized_lambda_returns") 157 | 158 | 159 | def td_lambda(state_values, 160 | rewards, 161 | pcontinues, 162 | bootstrap_value, 163 | lambda_=1, 164 | name="BaselineLoss"): 165 | """Constructs a TensorFlow graph computing the L2 loss for sequences. 166 | 167 | This loss learns the baseline for advantage actor-critic models. Gradients 168 | for this loss flow through each tensor in `state_values`, but no other 169 | input tensors. The baseline is regressed towards the n-step bootstrapped 170 | returns given by the reward/pcontinue sequence. 171 | 172 | This function is designed for batches of sequences of data. Tensors are 173 | assumed to be time major (i.e. the outermost dimension is time, the second 174 | outermost dimension is the batch dimension). We denote the sequence length 175 | in the shapes of the arguments with the variable `T`, the batch size with 176 | the variable `B`, neither of which needs to be known at construction time. 177 | Index `0` of the time dimension is assumed to be the start of the sequence. 178 | 179 | `rewards` and `pcontinues` are the sequences of data taken directly from the 180 | environment, possibly modulated by a discount. `state_values` are the 181 | sequences of (typically learnt) estimates of the values of the states 182 | visited along a batch of trajectories. 183 | 184 | The sequences in the tensors should be aligned such that an agent in a state 185 | with value `V` that takes an action transitions into another state 186 | with value `V'`, receiving reward `r` and pcontinue `p`. Then `V`, `r` 187 | and `p` are all at the same index `i` in the corresponding tensors. `V'` is 188 | at index `i+1`, or in the `bootstrap_value` tensor if `i == T`. 189 | 190 | See "High-dimensional continuous control using generalized advantage 191 | estimation" by Schulman, Moritz, Levine et al. 192 | (https://arxiv.org/abs/1506.02438). 193 | 194 | Args: 195 | state_values: 2-D Tensor of state-value estimates with shape `[T, B]`. 196 | rewards: 2-D Tensor with shape `[T, B]`. 197 | pcontinues: 2-D Tensor with shape `[T, B]`. 198 | bootstrap_value: 1-D Tensor with shape `[B]`. 199 | lambda_: an optional scalar or 2-D Tensor with shape `[T, B]`. 200 | name: Customises the name_scope for this op. 201 | 202 | Returns: 203 | A namedtuple with fields: 204 | 205 | * `loss`: a tensor containing the batch of losses, shape `[B]`. 206 | * `extra`: a namedtuple with fields: 207 | * temporal_differences, Tensor of shape `[T, B]` 208 | * discounted_returns, Tensor of shape `[T, B]` 209 | """ 210 | scoped_values = [state_values, rewards, pcontinues, bootstrap_value] 211 | with tf.name_scope(name, values=scoped_values): 212 | discounted_returns = generalized_lambda_returns( 213 | rewards, pcontinues, state_values, bootstrap_value, lambda_) 214 | temporal_differences = discounted_returns - state_values 215 | loss = 0.5 * tf.reduce_sum( 216 | tf.square(temporal_differences), axis=0, name="l2_loss") 217 | 218 | return base_ops.LossOutput( 219 | loss, TDLambdaExtra( 220 | temporal_differences=temporal_differences, 221 | discounted_returns=discounted_returns)) 222 | 223 | 224 | def qv_max(v_tm1, r_t, pcont_t, q_t, name="QVMAX"): 225 | """Implements the QVMAX learning loss as a TensorFlow op. 
226 | 227 | The QVMAX loss is `0.5` times the squared difference between `v_tm1` and 228 | the target `r_t + pcont_t * max q_t`, where `q_t` is separately learned 229 | through QV learning (c.f. `action_value_ops.qv_learning`). 230 | 231 | See "The QV Family Compared to Other Reinforcement Learning Algorithms" by 232 | Wiering and van Hasselt (2009). 233 | (http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.713.1931) 234 | 235 | Args: 236 | v_tm1: Tensor holding values at previous timestep, shape `[B]`. 237 | r_t: Tensor holding rewards, shape `[B]`. 238 | pcont_t: Tensor holding pcontinue values, shape `[B]`. 239 | q_t: Tensor of action values at current timestep, shape `[B, num_actions]`. 240 | name: name to prefix ops created by this function. 241 | 242 | Returns: 243 | A namedtuple with fields: 244 | 245 | * `loss`: a tensor containing the batch of losses, shape `[B]`. 246 | * `extra`: a namedtuple with fields: 247 | * `target`: batch of target values for `v_tm1`, shape `[B]`. 248 | * `td_error`: batch of temporal difference errors, shape `[B]`. 249 | """ 250 | # Rank and compatibility checks. 251 | base_ops.wrap_rank_shape_assert([[v_tm1, r_t, pcont_t], [q_t]], [1, 2], name) 252 | 253 | # The QVMAX op. 254 | with tf.name_scope(name, values=[v_tm1, r_t, pcont_t, q_t]): 255 | 256 | # Build target. 257 | target = tf.stop_gradient(r_t + pcont_t * tf.reduce_max(q_t, axis=1)) 258 | 259 | # Temporal difference error and loss. 260 | # Loss is MSE scaled by 0.5, so the gradient is equal to the TD error. 261 | td_error = target - v_tm1 262 | loss = 0.5 * tf.square(td_error) 263 | return base_ops.LossOutput(loss, TDExtra(target, td_error)) 264 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # TRFL: Reinforcement Learning Building Blocks 2 | 3 | TRFL (pronounced "truffle") is a library built on top of TensorFlow that exposes 4 | several useful building blocks for implementing Reinforcement Learning agents. 5 | 6 | ## Background 7 | 8 | Common RL algorithms describe a particular update to either a Policy, a Value 9 | function, or an Action-Value (Q) function. In Deep-RL, a policy, value- or Q- 10 | function is typically represented by a neural network (the _model_, not to be 11 | confused with an _environment model_, which is used in _model-based RL_). We 12 | formulate common RL update rules for these neural networks as differentiable 13 | _loss_ functions, as is common in (un-)supervised learning. Under automatic 14 | differentiation, the original update rule is recovered. We find that loss 15 | functions are more modular and composable than traditional RL updates, and more 16 | natural when combining with supervised or unsupervised objectives. 17 | 18 | The loss functions and other operations provided here are implemented in pure 19 | TensorFlow. They are not complete algorithms, but implementations of RL-specific 20 | mathematical operations needed when building fully-functional RL agents. In 21 | particular, the updates are only valid if the input data are sampled in the 22 | correct manner. For example, the [sequence-advantage-actor-critic 23 | loss](trfl.md#sequence_advantage_actor_critic_loss) (i.e. A2C) is only valid if 24 | the input trajectory is an unbiased sample from the current policy; i.e. the 25 | data are _on-policy_. This library cannot check or enforce such constraints. 
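To make the earlier point about recovering updates concrete: for TD(0) value
learning on a single transition, the loss is `0.5 * (r + pcont * V(s') - V(s))^2`
with the target `r + pcont * V(s')` held fixed (e.g. via `tf.stop_gradient`).
Differentiating with respect to the value network's parameters gives
`-(r + pcont * V(s') - V(s)) * dV(s)/dtheta`, which is exactly the usual TD(0)
update direction.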
26 | 27 | ## Installation 28 | 29 | TRFL can be installed from pip directly from github, with the following command: 30 | `pip install git+git://github.com/deepmind/trfl.git` 31 | 32 | TRFL will work with both the CPU and GPU version of tensorflow, but to allow 33 | for that it does not list Tensorflow as a requirement, so you need to install 34 | Tensorflow and Tensorflow-probability separately if you haven't already done so. 35 | 36 | ## Example usage 37 | 38 | Import TensorFlow and TRFL. 39 | 40 | ```python 41 | import tensorflow as tf 42 | import trfl 43 | ``` 44 | 45 | Define the relevant data associated to a `transition` in the environment from 46 | state `s_tm1` to state `s_t`. This typically includes action values (or other 47 | characterization of the agent's policy) in both the `source` and `destination` 48 | states. The action `a_tm1` is the one selected after observing `s_tm1`, and 49 | resulted in observing the immediate reward `r_t` and the subsequent state `s_t`. 50 | `pcont_t` represents a time dependent discount factor, or (equivalently) the 51 | continuation probability from state `s_t`. In most applications, its value will 52 | be equal to a constant factor (e.g., 0.99), except if `s_t` is the final state 53 | in an episode, in which case it is set to zero. 54 | 55 | ```python 56 | # Q-values for the previous and next timesteps, shape [batch_size, num_actions]. 57 | q_tm1 = tf.get_variable( 58 | "q_tm1", initializer=[[1., 1., 0.], [1., 2., 0.]], dtype=tf.float32) 59 | q_t = tf.get_variable( 60 | "q_t", initializer=[[0., 1., 0.], [1., 2., 0.]], dtype=tf.float32) 61 | 62 | # Action indices, discounts and rewards, shape [batch_size]. 63 | a_tm1 = tf.constant([0, 1], dtype=tf.int32) 64 | r_t = tf.constant([1, 1], dtype=tf.float32) 65 | pcont_t = tf.constant([0, 1], dtype=tf.float32) # the discount factor 66 | 67 | # Q-learning loss, and auxiliary data. 68 | loss, q_learning = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t) 69 | ``` 70 | 71 | `loss` is the tensor representing the loss. For Q-learning, it is half the 72 | squared difference between the predicted Q-values and the TD targets, shape 73 | `[batch_size]`. Extra information is in the `q_learning` namedtuple, including 74 | `q_learning.td_error` and `q_learning.target`. 75 | 76 | Most of the time, you may only be interested in the loss, in which case you can 77 | use any of the following expressions: 78 | 79 | ```python 80 | loss, _ = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t) 81 | loss = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t).loss 82 | ``` 83 | 84 | The `loss` tensor can be differentiated to derive the corresponding RL update. 85 | Note that in Q-learning, as in other bootstrapped losses, the TD targets 86 | are wrapped in a `tf.stop_gradient`. Differentiating `loss` therefore 87 | results in gradients with respect to `q_tm1` but not with respect to `q_t`. 88 | 89 | ```python 90 | reduced_loss = tf.reduce_mean(loss) 91 | optimizer = tf.train.AdamOptimizer(learning_rate=0.1) 92 | train_op = optimizer.minimize(reduced_loss) 93 | ``` 94 | 95 | All loss functions in the package return both a loss tensor and a namedtuple 96 | with extra information, using the above convention, but different functions 97 | may have different `extra` fields. Check the documentation of each function 98 | below for more information. 
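Because the loss is returned with shape `[batch_size]` rather than already
reduced, individual batch elements can also be re-weighted before the
reduction, for example by importance sampling corrections. A minimal sketch,
reusing `loss` from the example above (the weight values here are purely
illustrative):

```python
# Hypothetical per-element weights, e.g. importance sampling corrections,
# shape [batch_size].
weights = tf.constant([0.5, 2.0], dtype=tf.float32)

weighted_loss = tf.reduce_mean(weights * loss)
train_op = tf.train.AdamOptimizer(learning_rate=0.1).minimize(weighted_loss)
```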
99 | 100 | ## Naming Conventions and Developer Guidelines 101 | 102 | Throughout the package, we use the following conventions: 103 | 104 | * Time indices and variable names: 105 | 106 | * `q_tm1`: the action value in the `source` state of a transition. 107 | * `a_tm1`: the action that was selected in the `source` state. 108 | * `r_t`: the resulting rewards collected in the `destination` state. 109 | * `pcont_t`: the continuation probability / `discount` for a transition. 110 | * `q_t`: the action values in the `destination` state. 111 | 112 | * Tensor shapes: 113 | 114 | * All ops support minibatches only. We use `B` to denote the batch size. 115 | * Batches of rewards, continuation probabilities / discounts have shape [B]. 116 | * Batches of state-values have shape `[B]`. 117 | * Batches of action-values / q-values have shape `[B, num_actions]`. 118 | * All losses have shape [B], i.e. the loss is not reduced over the batch 119 | dimension. This allows the user to easily weight the loss for different 120 | elements of the batch (e.g., by their importance sampling weights). 121 | * For ops that take batches of sequences of data, `T` denotes the sequence 122 | length. Tensors are time-major, and have shape `[T, B, ...]`. Index `0` 123 | of the time dimension is assumed to be the start of the sequence. 124 | 125 | ## Learning updates 126 | 127 | * State Value learning: 128 | 129 | * [td_learning](trfl.md#td_learningv_tm1-r_t-pcont_t-v_t-nametdlearning) 130 | * [generalized_lambda_returns](trfl.md#generalized_lambda_returnsrewards-pcontinues-values-bootstrap_value-lambda_1-namegeneralized_lambda_returns) 131 | * [td_lambda](trfl.md#td_lambdastate_values-rewards-pcontinues-bootstrap_value-lambda_1-namebaselineloss) 132 | * [qv_max](trfl.md#qv_maxv_tm1-r_t-pcont_t-q_t-nameqvmax) 133 | 134 | * Discrete-action Value learning: 135 | 136 | * [qlearning](trfl.md#qlearningq_tm1-a_tm1-r_t-pcont_t-q_t-nameqlearning) 137 | * [double_qlearning](trfl.md#double_qlearningq_tm1-a_tm1-r_t-pcont_t-q_t_value-q_t_selector-namedoubleqlearning) 138 | * [persistent_qlearning](trfl.md#persistent_qlearningq_tm1-a_tm1-r_t-pcont_t-q_t-action_gap_scale05-namepersistentqlearning) 139 | * [sarsa](trfl.md#sarsaq_tm1-a_tm1-r_t-pcont_t-q_t-a_t-namesarsa) 140 | * [sarse](trfl.md#sarseq_tm1-a_tm1-r_t-pcont_t-q_t-probs_a_t-debugfalse-namesarse) 141 | * [qlambda](trfl.md#qlambdaq_tm1-a_tm1-r_t-pcont_t-q_t-lambda_-namegeneralizedqlambda) 142 | * [qv_learning](trfl.md#qv_learningq_tm1-a_tm1-r_t-pcont_t-v_t-nameqvlearning) 143 | 144 | * Distributional Value learning: 145 | 146 | * [categorical_dist_qlearning](trfl.md#categorical_dist_qlearningatoms_tm1-logits_q_tm1-a_tm1-r_t-pcont_t-atoms_t-logits_q_t-namecategoricaldistqlearning) 147 | * [categorical_dist_double_qlearning](trfl.md#categorical_dist_double_qlearningatoms_tm1-logits_q_tm1-a_tm1-r_t-pcont_t-atoms_t-logits_q_t-q_t_selector-namecategoricaldistdoubleqlearning) 148 | * [categorical_dist_td_learning](trfl.md#categorical_dist_td_learningatoms_tm1-logits_v_tm1-r_t-pcont_t-atoms_t-logits_v_t-namecategoricaldisttdlearning) 149 | 150 | * Continuous-action Policy Gradient: 151 | 152 | * [policy_gradient](trfl.md#policy_gradientpolicies-actions-action_values-policy_varsnone-namepolicy_gradient) 153 | * [policy_gradient_loss](trfl.md#policy_gradient_losspolicies-actions-action_values-policy_varsnone-namepolicy_gradient_loss) 154 | * [policy_entropy_loss](trfl.md#policy_entropy_losspolicies-policy_varsnone-scale_opnone-namepolicy_entropy_loss) 155 | * 
[sequence_a2c_loss](trfl.md#sequence_a2c_losspolicies-baseline_values-actions-rewards-pcontinues-bootstrap_value-policy_varsnone-lambda_1-entropy_costnone-baseline_cost1-entropy_scale_opnone-namesequencea2closs) 156 | 157 | * Deterministic Policy Gradient: 158 | 159 | * [dpg](trfl.md#dpg) 160 | 161 | * Discrete-action Policy Gradient: 162 | 163 | * [discrete_policy_entropy_loss](trfl.md#discrete_policy_entropy_losspolicy_logits-normalisefalse-namediscrete_policy_entropy_loss) 164 | * [sequence_advantage_actor_critic_loss](trfl.md#sequence_advantage_actor_critic_losspolicy_logits-baseline_values-actions-rewards-pcontinues-bootstrap_value-lambda_1-entropy_costnone-baseline_cost1-normalise_entropyfalse-namesequenceadvantageactorcriticloss): 165 | this is the commonly-used A2C/A3C loss function. 166 | * [discrete_policy_gradient](trfl.md#discrete_policy_gradientpolicy_logits-actions-action_values-namediscrete_policy_gradient) 167 | * [discrete_policy_gradient_loss](trfl.md#discrete_policy_gradient_losspolicy_logits-actions-action_values-namediscrete_policy_gradient_loss) 168 | 169 | * Pixel control: 170 | 171 | * [pixel_control_rewards](trfl.md#pixel_control_rewardsobservations-cell_size) 172 | * [pixel_control_loss](trfl.md#pixel_control_lossobservations-actions-action_values-cell_size-discount_factor-scale-crop_height_dimnone-none-crop_width_dimnone-none) 173 | 174 | * Retrace: 175 | 176 | * [retrace](trfl.md#retracelambda_-qs-targnet_qs-actions-rewards-pcontinues-target_policy_probs-behaviour_policy_probs-stop_targnet_gradientstrue-namenone) 177 | * [retrace_core](trfl.md#retrace_corelambda_-q_tm1-a_tm1-r_t-pcont_t-target_policy_t-behaviour_policy_t-targnet_q_t-a_t-stop_targnet_gradientstrue-namenone) 178 | 179 | * Target Network Updating: 180 | 181 | * [update_target_variables](trfl.md#update_target_variablestarget_variables-source_variables-tau10-use_lockingfalse-nameupdate_target_variables) 182 | * [periodic_target_update](trfl.md#periodic_target_updatetarget_variables-source_variables-update_period-tau10,use_lockingfalse-nameperiodic_target_update) 183 | 184 | * V-trace: 185 | 186 | * [vtrace_from_logits](trfl.md#vtrace_from_logitsbehaviour_policy_logits-target_policy_logits-actions-discounts-rewards-values-bootstrap_value-clip_rho_threshold10-clip_pg_rho_threshold10-namevtrace_from_logits) 187 | * [vtrace_from_importance_weights](trfl.md#vtrace_from_importance_weightslog_rhos-discounts-rewards-values-bootstrap_value-clip_rho_threshold10-clip_pg_rho_threshold10-namevtrace_from_importance_weights) 188 | 189 | ## Other 190 | 191 | * Clipping ops 192 | 193 | * [huber_loss](trfl.md#huber_lossinput_tensor-quadratic_linear_boundary-namenone) 194 | 195 | * Distributions 196 | 197 | * [l2_project](trfl.md#l2_projectsupport-weights-new_support) 198 | * [factorised_kl_gaussian](trfl.md#factorised_kl_gaussiandist1_mean-dist1_covariance_or_scale-dist2_mean-dist2_covariance_or_scale-both_diagonalfalse) 199 | 200 | * Indexing ops 201 | 202 | * [batched_index](trfl.md#batched_indexvalues-indices) 203 | 204 | * Periodic execution ops 205 | 206 | * [periodically](trfl.md#periodicallybody-period-nameperiodically) 207 | 208 | * Policy ops 209 | 210 | * [epsilon_greedy](trfl.md#epsilon_greedy) 211 | 212 | * Sequence ops 213 | 214 | * [scan_discounted_sum](trfl.md#scan_discounted_sumsequence-decay-initial_value-reversefalse-sequence_lengthsnone-back_proptrue-namescan_discounted_sum) 215 | * 
[multistep_forward_view](trfl.md#multistep_forward_viewrewards-pcontinues-state_values-lambda_-back_proptrue-sequence_lengthsnone-namemultistep_forward_view_op) 216 | 217 | ## More information 218 | 219 | * [Multistep Forward View](multistep_forward_view.md) 220 | -------------------------------------------------------------------------------- /trfl/dist_value_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tensorflow ops for common Distributional RL learning rules. 16 | 17 | Distributions are taken to be categorical over a support of 'N' distinct atoms, 18 | which are always specified in ascending order. 19 | 20 | These ops define state/action value distribution learning rules for discrete, 21 | scalar, action spaces. Actions must be represented as indices in the range 22 | `[0, K)` where `K` is the number of distinct actions. 23 | """ 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | import collections 29 | 30 | # Dependency imports 31 | import tensorflow.compat.v1 as tf 32 | from trfl import base_ops 33 | from trfl import distribution_ops 34 | 35 | Extra = collections.namedtuple("dist_value_extra", ["target"]) 36 | 37 | _l2_project = distribution_ops.l2_project 38 | 39 | 40 | def _slice_with_actions(embeddings, actions): 41 | """Slice a Tensor. 42 | 43 | Take embeddings of the form [batch_size, num_actions, embed_dim] 44 | and actions of the form [batch_size, 1], and return the sliced embeddings 45 | like embeddings[:, actions, :]. 46 | 47 | Args: 48 | embeddings: Tensor of embeddings to index. 49 | actions: int Tensor to use as index into embeddings 50 | 51 | Returns: 52 | Tensor of embeddings indexed by actions 53 | """ 54 | batch_size, num_actions = embeddings.get_shape()[:2] 55 | 56 | # Values are the 'values' in a sparse tensor we will be setting 57 | act_indx = tf.cast(actions, tf.int64)[:, None] 58 | values = tf.reshape(tf.cast(tf.ones(tf.shape(actions)), tf.bool), [-1]) 59 | 60 | # Create a range for each index into the batch 61 | act_range = tf.range(0, batch_size, dtype=tf.int64)[:, None] 62 | # Combine this into coordinates with the action indices 63 | indices = tf.concat([act_range, act_indx], 1) 64 | 65 | actions_mask = tf.SparseTensor(indices, values, [batch_size, num_actions]) 66 | actions_mask = tf.stop_gradient( 67 | tf.sparse_tensor_to_dense(actions_mask, default_value=False)) 68 | sliced_emb = tf.boolean_mask(embeddings, actions_mask) 69 | return sliced_emb 70 | 71 | 72 | def categorical_dist_qlearning(atoms_tm1, 73 | logits_q_tm1, 74 | a_tm1, 75 | r_t, 76 | pcont_t, 77 | atoms_t, 78 | logits_q_t, 79 | name="CategoricalDistQLearning"): 80 | """Implements Distributional Q-learning as TensorFlow ops. 
81 | 82 | The function assumes categorical value distributions parameterized by logits. 83 | 84 | See "A Distributional Perspective on Reinforcement Learning" by Bellemare, 85 | Dabney and Munos. (https://arxiv.org/abs/1707.06887). 86 | 87 | Args: 88 | atoms_tm1: 1-D tensor containing atom values for first timestep, 89 | shape `[num_atoms]`. 90 | logits_q_tm1: Tensor holding logits for first timestep in a batch of 91 | transitions, shape `[B, num_actions, num_atoms]`. 92 | a_tm1: Tensor holding action indices, shape `[B]`. 93 | r_t: Tensor holding rewards, shape `[B]`. 94 | pcont_t: Tensor holding pcontinue values, shape `[B]`. 95 | atoms_t: 1-D tensor containing atom values for second timestep, 96 | shape `[num_atoms]`. 97 | logits_q_t: Tensor holding logits for second timestep in a batch of 98 | transitions, shape `[B, num_actions, num_atoms]`. 99 | name: name to prefix ops created by this function. 100 | 101 | Returns: 102 | A namedtuple with fields: 103 | 104 | * `loss`: a tensor containing the batch of losses, shape `[B]`. 105 | * `extra`: a namedtuple with fields: 106 | * `target`: a tensor containing the values that `q_tm1` at actions 107 | `a_tm1` are regressed towards, shape `[B, num_atoms]`. 108 | 109 | Raises: 110 | ValueError: If the tensors do not have the correct rank or compatibility. 111 | """ 112 | # Rank and compatibility checks. 113 | assertion_lists = [[logits_q_tm1, logits_q_t], [a_tm1, r_t, pcont_t], 114 | [atoms_tm1, atoms_t]] 115 | base_ops.wrap_rank_shape_assert(assertion_lists, [3, 1, 1], name) 116 | 117 | # Categorical distributional Q-learning op. 118 | with tf.name_scope( 119 | name, 120 | values=[ 121 | atoms_tm1, logits_q_tm1, a_tm1, r_t, pcont_t, atoms_t, logits_q_t 122 | ]): 123 | 124 | with tf.name_scope("target"): 125 | # Scale and shift time-t distribution atoms by discount and reward. 126 | target_z = r_t[:, None] + pcont_t[:, None] * atoms_t[None, :] 127 | 128 | # Convert logits to distribution, then find greedy action in state s_t. 129 | q_t_probs = tf.nn.softmax(logits_q_t) 130 | q_t_mean = tf.reduce_sum(q_t_probs * atoms_t, 2) 131 | pi_t = tf.argmax(q_t_mean, 1, output_type=tf.int32) 132 | 133 | # Compute distribution for greedy action. 134 | p_target_z = _slice_with_actions(q_t_probs, pi_t) 135 | 136 | # Project using the Cramer distance 137 | target = tf.stop_gradient(_l2_project(target_z, p_target_z, atoms_tm1)) 138 | 139 | logit_qa_tm1 = _slice_with_actions(logits_q_tm1, a_tm1) 140 | 141 | loss = tf.nn.softmax_cross_entropy_with_logits( 142 | logits=logit_qa_tm1, labels=target) 143 | 144 | return base_ops.LossOutput(loss, Extra(target)) 145 | 146 | 147 | def categorical_dist_double_qlearning(atoms_tm1, 148 | logits_q_tm1, 149 | a_tm1, 150 | r_t, 151 | pcont_t, 152 | atoms_t, 153 | logits_q_t, 154 | q_t_selector, 155 | name="CategoricalDistDoubleQLearning"): 156 | """Implements Distributional Double Q-learning as TensorFlow ops. 157 | 158 | The function assumes categorical value distributions parameterized by logits, 159 | and combines distributional RL with double Q-learning. 160 | 161 | See "Rainbow: Combining Improvements in Deep Reinforcement Learning" by 162 | Hessel, Modayil, van Hasselt, Schaul et al. 163 | (https://arxiv.org/abs/1710.02298). 164 | 165 | Args: 166 | atoms_tm1: 1-D tensor containing atom values for first timestep, 167 | shape `[num_atoms]`. 168 | logits_q_tm1: Tensor holding logits for first timestep in a batch of 169 | transitions, shape `[B, num_actions, num_atoms]`. 
170 | a_tm1: Tensor holding action indices, shape `[B]`. 171 | r_t: Tensor holding rewards, shape `[B]`. 172 | pcont_t: Tensor holding pcontinue values, shape `[B]`. 173 | atoms_t: 1-D tensor containing atom values for second timestep, 174 | shape `[num_atoms]`. 175 | logits_q_t: Tensor holding logits for second timestep in a batch of 176 | transitions, shape `[B, num_actions, num_atoms]`. 177 | q_t_selector: Tensor holding another set of Q-values for second timestep 178 | in a batch of transitions, shape `[B, num_actions]`. 179 | These values are used for estimating the best action. In Double DQN they 180 | come from the online network. 181 | name: name to prefix ops created by this function. 182 | 183 | Returns: 184 | A namedtuple with fields: 185 | 186 | * `loss`: Tensor containing the batch of losses, shape `[B]`. 187 | * `extra`: a namedtuple with fields: 188 | * `target`: Tensor containing the values that `q_tm1` at actions 189 | `a_tm1` are regressed towards, shape `[B, num_atoms]` . 190 | 191 | Raises: 192 | ValueError: If the tensors do not have the correct rank or compatibility. 193 | """ 194 | # Rank and compatibility checks. 195 | assertion_lists = [[logits_q_tm1, logits_q_t], [a_tm1, r_t, pcont_t], 196 | [atoms_tm1, atoms_t], [q_t_selector]] 197 | base_ops.wrap_rank_shape_assert(assertion_lists, [3, 1, 1, 2], name) 198 | 199 | # Categorical distributional double Q-learning op. 200 | with tf.name_scope( 201 | name, 202 | values=[ 203 | atoms_tm1, logits_q_tm1, a_tm1, r_t, pcont_t, atoms_t, logits_q_t, 204 | q_t_selector 205 | ]): 206 | 207 | with tf.name_scope("target"): 208 | # Scale and shift time-t distribution atoms by discount and reward. 209 | target_z = r_t[:, None] + pcont_t[:, None] * atoms_t[None, :] 210 | 211 | # Convert logits to distribution, then find greedy policy action in 212 | # state s_t. 213 | q_t_probs = tf.nn.softmax(logits_q_t) 214 | pi_t = tf.argmax(q_t_selector, 1, output_type=tf.int32) 215 | # Compute distribution for greedy action. 216 | p_target_z = _slice_with_actions(q_t_probs, pi_t) 217 | 218 | # Project using the Cramer distance 219 | target = tf.stop_gradient(_l2_project(target_z, p_target_z, atoms_tm1)) 220 | 221 | logit_qa_tm1 = _slice_with_actions(logits_q_tm1, a_tm1) 222 | 223 | loss = tf.nn.softmax_cross_entropy_with_logits( 224 | logits=logit_qa_tm1, labels=target) 225 | 226 | return base_ops.LossOutput(loss, Extra(target)) 227 | 228 | 229 | def categorical_dist_td_learning(atoms_tm1, 230 | logits_v_tm1, 231 | r_t, 232 | pcont_t, 233 | atoms_t, 234 | logits_v_t, 235 | name="CategoricalDistTDLearning"): 236 | """Implements Distributional TD-learning as TensorFlow ops. 237 | 238 | The function assumes categorical value distributions parameterized by logits. 239 | 240 | See "A Distributional Perspective on Reinforcement Learning" by Bellemare, 241 | Dabney and Munos. (https://arxiv.org/abs/1707.06887). 242 | 243 | Args: 244 | atoms_tm1: 1-D tensor containing atom values for first timestep, 245 | shape `[num_atoms]`. 246 | logits_v_tm1: Tensor holding logits for first timestep in a batch of 247 | transitions, shape `[B, num_atoms]`. 248 | r_t: Tensor holding rewards, shape `[B]`. 249 | pcont_t: Tensor holding pcontinue values, shape `[B]`. 250 | atoms_t: 1-D tensor containing atom values for second timestep, 251 | shape `[num_atoms]`. 252 | logits_v_t: Tensor holding logits for second timestep in a batch of 253 | transitions, shape `[B, num_atoms]`. 254 | name: name to prefix ops created by this function. 
255 | 256 | Returns: 257 | A namedtuple with fields: 258 | 259 | * `loss`: Tensor containing the batch of losses, shape `[B]`. 260 | * `extra`: A namedtuple with fields: 261 | * `target`: Tensor containing the values that `v_tm1` are 262 | regressed towards, shape `[B, num_atoms]`. 263 | 264 | Raises: 265 | ValueError: If the tensors do not have the correct rank or compatibility. 266 | """ 267 | # Rank and compatibility checks. 268 | assertion_lists = [[logits_v_tm1, logits_v_t], [r_t, pcont_t], 269 | [atoms_tm1, atoms_t]] 270 | base_ops.wrap_rank_shape_assert(assertion_lists, [2, 1, 1], name) 271 | 272 | # Categorical distributional TD-learning op. 273 | with tf.name_scope( 274 | name, values=[atoms_tm1, logits_v_tm1, r_t, pcont_t, atoms_t, 275 | logits_v_t]): 276 | 277 | with tf.name_scope("target"): 278 | # Scale and shift time-t distribution atoms by discount and reward. 279 | target_z = r_t[:, None] + pcont_t[:, None] * atoms_t[None, :] 280 | v_t_probs = tf.nn.softmax(logits_v_t) 281 | 282 | # Project using the Cramer distance 283 | target = tf.stop_gradient(_l2_project(target_z, v_t_probs, atoms_tm1)) 284 | 285 | loss = tf.nn.softmax_cross_entropy_with_logits( 286 | logits=logits_v_tm1, labels=target) 287 | 288 | return base_ops.LossOutput(loss, Extra(target)) 289 | -------------------------------------------------------------------------------- /trfl/vtrace_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The trfl Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Ops for computing v-trace learning targets.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | 23 | # Dependency imports 24 | import tensorflow.compat.v1 as tf 25 | 26 | 27 | VTraceFromLogitsReturns = collections.namedtuple( 28 | 'VTraceFromLogitsReturns', 29 | ['vs', 'pg_advantages', 'log_rhos', 30 | 'behaviour_action_log_probs', 'target_action_log_probs']) 31 | VTraceReturns = collections.namedtuple('VTraceReturns', 'vs pg_advantages') 32 | 33 | 34 | def log_probs_from_logits_and_actions(policy_logits, actions): 35 | """Computes action log-probs from policy logits and actions. 36 | 37 | In the notation used throughout documentation and comments, T refers to the 38 | time dimension ranging from 0 to T-1. B refers to the batch size and 39 | NUM_ACTIONS refers to the number of actions. 40 | 41 | Args: 42 | policy_logits: A float32 tensor of shape `[T, B, NUM_ACTIONS]` with 43 | un-normalized log-probabilities parameterizing a softmax policy. 44 | actions: An int32 tensor of shape `[T, B]` with actions. 45 | 46 | Returns: 47 | A float32 tensor of shape `[T, B]` corresponding to the sampling log 48 | probability of the chosen action w.r.t. the policy. 
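  Informally, for each `t` and `b` the result is the log-probability of the
  chosen action under the softmax policy, i.e.
  `log_softmax(policy_logits[t, b])[actions[t, b]]`.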
49 | """ 50 | policy_logits = tf.convert_to_tensor(policy_logits, dtype=tf.float32) 51 | actions = tf.convert_to_tensor(actions, dtype=tf.int32) 52 | 53 | policy_logits.shape.assert_has_rank(3) 54 | actions.shape.assert_has_rank(2) 55 | 56 | return -tf.nn.sparse_softmax_cross_entropy_with_logits( 57 | logits=policy_logits, labels=actions) 58 | 59 | 60 | def vtrace_from_logits( 61 | behaviour_policy_logits, target_policy_logits, actions, 62 | discounts, rewards, values, bootstrap_value, 63 | clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, 64 | name='vtrace_from_logits'): 65 | r"""V-trace for softmax policies. 66 | 67 | Calculates V-trace actor critic targets for softmax polices as described in 68 | 69 | "IMPALA: Scalable Distributed Deep-RL with 70 | Importance Weighted Actor-Learner Architectures" 71 | by Espeholt, Soyer, Munos et al. 72 | 73 | Target policy refers to the policy we are interested in improving and 74 | behaviour policy refers to the policy that generated the given 75 | rewards and actions. 76 | 77 | In the notation used throughout documentation and comments, `T` refers to the 78 | time dimension ranging from `0` to `T-1`. `B` refers to the batch size and 79 | `NUM_ACTIONS` refers to the number of actions. 80 | 81 | Args: 82 | behaviour_policy_logits: A float32 tensor of shape `[T, B, NUM_ACTIONS]` 83 | with un-normalized log-probabilities parametrizing the softmax behaviour 84 | policy. 85 | target_policy_logits: A float32 tensor of shape `[T, B, NUM_ACTIONS]` with 86 | un-normalized log-probabilities parametrizing the softmax target policy. 87 | actions: An int32 tensor of shape `[T, B]` of actions sampled from the 88 | behaviour policy. 89 | discounts: A float32 tensor of shape `[T, B]` with the discount encountered 90 | when following the behaviour policy. 91 | rewards: A float32 tensor of shape `[T, B]` with the rewards generated by 92 | following the behaviour policy. 93 | values: A float32 tensor of shape `[T, B]` with the value function estimates 94 | wrt. the target policy. 95 | bootstrap_value: A float32 of shape `[B]` with the value function estimate 96 | at time T. 97 | clip_rho_threshold: A scalar float32 tensor with the clipping threshold for 98 | importance weights (rho) when calculating the baseline targets (vs). 99 | rho^bar in the paper. 100 | clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold 101 | on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). 102 | name: The name scope that all V-trace operations will be created in. 103 | 104 | Returns: 105 | A `VTraceFromLogitsReturns` namedtuple with the following fields: 106 | vs: A float32 tensor of shape `[T, B]`. Can be used as target to train a 107 | baseline (V(x_t) - vs_t)^2. 108 | pg_advantages: A float 32 tensor of shape `[T, B]`. Can be used as an 109 | estimate of the advantage in the calculation of policy gradients. 110 | log_rhos: A float32 tensor of shape `[T, B]` containing the log importance 111 | sampling weights (log rhos). 112 | behaviour_action_log_probs: A float32 tensor of shape `[T, B]` containing 113 | behaviour policy action log probabilities (log \mu(a_t)). 114 | target_action_log_probs: A float32 tensor of shape `[T, B]` containing 115 | target policy action probabilities (log \pi(a_t)). 
116 | """ 117 | behaviour_policy_logits = tf.convert_to_tensor( 118 | behaviour_policy_logits, dtype=tf.float32) 119 | target_policy_logits = tf.convert_to_tensor( 120 | target_policy_logits, dtype=tf.float32) 121 | actions = tf.convert_to_tensor(actions, dtype=tf.int32) 122 | 123 | # Make sure tensor ranks are as expected. 124 | # The rest will be checked by from_action_log_probs. 125 | behaviour_policy_logits.shape.assert_has_rank(3) 126 | target_policy_logits.shape.assert_has_rank(3) 127 | actions.shape.assert_has_rank(2) 128 | 129 | with tf.name_scope(name, values=[ 130 | behaviour_policy_logits, target_policy_logits, actions, 131 | discounts, rewards, values, bootstrap_value]): 132 | target_action_log_probs = log_probs_from_logits_and_actions( 133 | target_policy_logits, actions) 134 | behaviour_action_log_probs = log_probs_from_logits_and_actions( 135 | behaviour_policy_logits, actions) 136 | log_rhos = target_action_log_probs - behaviour_action_log_probs 137 | vtrace_returns = vtrace_from_importance_weights( 138 | log_rhos=log_rhos, 139 | discounts=discounts, 140 | rewards=rewards, 141 | values=values, 142 | bootstrap_value=bootstrap_value, 143 | clip_rho_threshold=clip_rho_threshold, 144 | clip_pg_rho_threshold=clip_pg_rho_threshold) 145 | return VTraceFromLogitsReturns( 146 | log_rhos=log_rhos, 147 | behaviour_action_log_probs=behaviour_action_log_probs, 148 | target_action_log_probs=target_action_log_probs, 149 | **vtrace_returns._asdict() 150 | ) 151 | 152 | 153 | def vtrace_from_importance_weights( 154 | log_rhos, discounts, rewards, values, bootstrap_value, 155 | clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, 156 | name='vtrace_from_importance_weights'): 157 | r"""V-trace from log importance weights. 158 | 159 | Calculates V-trace actor critic targets as described in 160 | 161 | "IMPALA: Scalable Distributed Deep-RL with 162 | Importance Weighted Actor-Learner Architectures" 163 | by Espeholt, Soyer, Munos et al. 164 | 165 | In the notation used throughout documentation and comments, T refers to the 166 | time dimension ranging from 0 to T-1. B refers to the batch size. This code 167 | also supports the case where all tensors have the same number of additional 168 | dimensions, e.g., `rewards` is `[T, B, C]`, `values` is `[T, B, C]`, 169 | `bootstrap_value` is `[B, C]`. 170 | 171 | Args: 172 | log_rhos: A float32 tensor of shape `[T, B]` representing the 173 | log importance sampling weights, i.e. 174 | log(target_policy(a) / behaviour_policy(a)). V-trace performs operations 175 | on rhos in log-space for numerical stability. 176 | discounts: A float32 tensor of shape `[T, B]` with discounts encountered 177 | when following the behaviour policy. 178 | rewards: A float32 tensor of shape `[T, B]` containing rewards generated by 179 | following the behaviour policy. 180 | values: A float32 tensor of shape `[T, B]` with the value function estimates 181 | wrt. the target policy. 182 | bootstrap_value: A float32 of shape `[B]` with the value function estimate 183 | at time T. 184 | clip_rho_threshold: A scalar float32 tensor with the clipping threshold for 185 | importance weights (rho) when calculating the baseline targets (vs). 186 | rho^bar in the paper. If None, no clipping is applied. 187 | clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold 188 | on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). If 189 | None, no clipping is applied. 190 | name: The name scope that all V-trace operations will be created in. 
191 | 192 | Returns: 193 | A VTraceReturns namedtuple (vs, pg_advantages) where: 194 | vs: A float32 tensor of shape `[T, B]`. Can be used as target to 195 | train a baseline (V(x_t) - vs_t)^2. 196 | pg_advantages: A float32 tensor of shape `[T, B]`. Can be used as the 197 | advantage in the calculation of policy gradients. 198 | """ 199 | log_rhos = tf.convert_to_tensor(log_rhos, dtype=tf.float32) 200 | discounts = tf.convert_to_tensor(discounts, dtype=tf.float32) 201 | rewards = tf.convert_to_tensor(rewards, dtype=tf.float32) 202 | values = tf.convert_to_tensor(values, dtype=tf.float32) 203 | bootstrap_value = tf.convert_to_tensor(bootstrap_value, dtype=tf.float32) 204 | if clip_rho_threshold is not None: 205 | clip_rho_threshold = tf.convert_to_tensor(clip_rho_threshold, 206 | dtype=tf.float32) 207 | if clip_pg_rho_threshold is not None: 208 | clip_pg_rho_threshold = tf.convert_to_tensor(clip_pg_rho_threshold, 209 | dtype=tf.float32) 210 | 211 | # Make sure tensor ranks are consistent. 212 | rho_rank = log_rhos.shape.ndims # Usually 2. 213 | values.shape.assert_has_rank(rho_rank) 214 | bootstrap_value.shape.assert_has_rank(rho_rank - 1) 215 | discounts.shape.assert_has_rank(rho_rank) 216 | rewards.shape.assert_has_rank(rho_rank) 217 | if clip_rho_threshold is not None: 218 | clip_rho_threshold.shape.assert_has_rank(0) 219 | if clip_pg_rho_threshold is not None: 220 | clip_pg_rho_threshold.shape.assert_has_rank(0) 221 | 222 | with tf.name_scope(name, values=[ 223 | log_rhos, discounts, rewards, values, bootstrap_value]): 224 | rhos = tf.exp(log_rhos) 225 | if clip_rho_threshold is not None: 226 | clipped_rhos = tf.minimum(clip_rho_threshold, rhos, name='clipped_rhos') 227 | else: 228 | clipped_rhos = rhos 229 | 230 | cs = tf.minimum(1.0, rhos, name='cs') 231 | # Append bootstrapped value to get [v1, ..., v_t+1] 232 | values_t_plus_1 = tf.concat( 233 | [values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) 234 | deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) 235 | 236 | # Note that all sequences are reversed, computation starts from the back. 237 | sequences = ( 238 | tf.reverse(discounts, axis=[0]), 239 | tf.reverse(cs, axis=[0]), 240 | tf.reverse(deltas, axis=[0]), 241 | ) 242 | # V-trace vs are calculated through a scan from the back to the beginning 243 | # of the given trajectory. 244 | def scanfunc(acc, sequence_item): 245 | discount_t, c_t, delta_t = sequence_item 246 | return delta_t + discount_t * c_t * acc 247 | 248 | initial_values = tf.zeros_like(bootstrap_value) 249 | vs_minus_v_xs = tf.scan( 250 | fn=scanfunc, 251 | elems=sequences, 252 | initializer=initial_values, 253 | parallel_iterations=1, 254 | back_prop=False, 255 | name='scan') 256 | # Reverse the results back to original order. 257 | vs_minus_v_xs = tf.reverse(vs_minus_v_xs, [0], name='vs_minus_v_xs') 258 | 259 | # Add V(x_s) to get v_s. 260 | vs = tf.add(vs_minus_v_xs, values, name='vs') 261 | 262 | # Advantage for policy gradient. 263 | vs_t_plus_1 = tf.concat([ 264 | vs[1:], tf.expand_dims(bootstrap_value, 0)], axis=0) 265 | if clip_pg_rho_threshold is not None: 266 | clipped_pg_rhos = tf.minimum(clip_pg_rho_threshold, rhos, 267 | name='clipped_pg_rhos') 268 | else: 269 | clipped_pg_rhos = rhos 270 | pg_advantages = ( 271 | clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)) 272 | 273 | # Make sure no gradients backpropagated through the returned values. 
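    # The returned `vs` satisfy the V-trace recursion
    #   vs_s = V(x_s) + delta_s + discounts_s * c_s * (vs_{s+1} - V(x_{s+1})),
    # with delta_s as in `deltas` above and vs_T taken to be `bootstrap_value`;
    # this is what the reverse scan above computes.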
274 | return VTraceReturns(vs=tf.stop_gradient(vs), 275 | pg_advantages=tf.stop_gradient(pg_advantages)) 276 | --------------------------------------------------------------------------------
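A minimal usage sketch for the V-trace ops above, assuming `trfl` re-exports
`vtrace_from_logits` at the package level (as the documentation index
suggests); all shapes and values are illustrative only:

```python
import tensorflow as tf
import trfl

# Tiny illustrative problem: T=3 timesteps, B=2 trajectories, 4 actions.
behaviour_logits = tf.zeros([3, 2, 4], dtype=tf.float32)
target_logits = tf.zeros([3, 2, 4], dtype=tf.float32)
actions = tf.constant([[0, 1], [2, 3], [1, 0]], dtype=tf.int32)
discounts = 0.99 * tf.ones([3, 2], dtype=tf.float32)
rewards = tf.ones([3, 2], dtype=tf.float32)
values = tf.zeros([3, 2], dtype=tf.float32)
bootstrap_value = tf.zeros([2], dtype=tf.float32)

vtrace_returns = trfl.vtrace_from_logits(
    behaviour_policy_logits=behaviour_logits,
    target_policy_logits=target_logits,
    actions=actions,
    discounts=discounts,
    rewards=rewards,
    values=values,
    bootstrap_value=bootstrap_value)

# vtrace_returns.vs, shape [3, 2]: regression targets for the baseline.
# vtrace_returns.pg_advantages, shape [3, 2]: advantages for the policy
# gradient term. Both are wrapped in tf.stop_gradient by the op.
```

The `vs` targets would typically feed an L2 baseline loss and `pg_advantages`
a policy-gradient loss, as described in the docstrings above.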