├── LICENSE
├── README.md
├── data
├── mher_all.png
├── mher_all_step.png
└── mher_sac.png
├── mher
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── config.cpython-36.pyc
│ ├── default_cfg.cpython-36.pyc
│ ├── play.cpython-36.pyc
│ ├── run.cpython-36.pyc
│ └── train.cpython-36.pyc
├── algos
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── actor_critic.cpython-36.pyc
│ │ ├── ddpg.cpython-36.pyc
│ │ ├── normalizer.cpython-36.pyc
│ │ ├── rollout.cpython-36.pyc
│ │ └── util.cpython-36.pyc
│ ├── actor_critic.py
│ ├── algorithm.py
│ ├── ddpg.py
│ ├── dynamics.py
│ ├── sac.py
│ ├── sac_utils.py
│ └── util.py
├── buffers
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── replay_buffer.cpython-36.pyc
│ │ └── samplers.cpython-36.pyc
│ ├── prioritized_buffer.py
│ └── replay_buffer.py
├── common
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── atari_wrappers.cpython-36.pyc
│ │ ├── cmd_util.cpython-36.pyc
│ │ ├── console_util.cpython-36.pyc
│ │ ├── dataset.cpython-36.pyc
│ │ ├── import_util.cpython-36.pyc
│ │ ├── init_utils.cpython-36.pyc
│ │ ├── logger.cpython-36.pyc
│ │ ├── math_util.cpython-36.pyc
│ │ ├── misc_util.cpython-36.pyc
│ │ ├── monitor.cpython-36.pyc
│ │ ├── mpi_adam.cpython-36.pyc
│ │ ├── mpi_moments.cpython-36.pyc
│ │ ├── retro_wrappers.cpython-36.pyc
│ │ ├── tf_util.cpython-36.pyc
│ │ ├── tile_images.cpython-36.pyc
│ │ └── wrappers.cpython-36.pyc
│ ├── atari_wrappers.py
│ ├── cg.py
│ ├── cmd_util.py
│ ├── console_util.py
│ ├── dataset.py
│ ├── distributions.py
│ ├── import_util.py
│ ├── init_utils.py
│ ├── input.py
│ ├── logger.py
│ ├── math_util.py
│ ├── misc_util.py
│ ├── models.py
│ ├── monitor.py
│ ├── mpi_adam.py
│ ├── mpi_adam_optimizer.py
│ ├── mpi_fork.py
│ ├── mpi_moments.py
│ ├── mpi_running_mean_std.py
│ ├── mpi_util.py
│ ├── normalizer.py
│ ├── plot
│ │ ├── plot.py
│ │ └── results_plotter.py
│ ├── plot_util.py
│ ├── policies.py
│ ├── retro_wrappers.py
│ ├── runners.py
│ ├── running_mean_std.py
│ ├── schedules.py
│ ├── segment_tree.py
│ ├── test_mpi_util.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── envs
│ │ │ ├── __init__.py
│ │ │ ├── fixed_sequence_env.py
│ │ │ ├── identity_env.py
│ │ │ ├── identity_env_test.py
│ │ │ └── mnist_env.py
│ │ ├── test_cartpole.py
│ │ ├── test_doc_examples.py
│ │ ├── test_env_after_learn.py
│ │ ├── test_fetchreach.py
│ │ ├── test_fixed_sequence.py
│ │ ├── test_identity.py
│ │ ├── test_mnist.py
│ │ ├── test_plot_util.py
│ │ ├── test_schedules.py
│ │ ├── test_segment_tree.py
│ │ ├── test_serialization.py
│ │ ├── test_tf_util.py
│ │ ├── test_with_mpi.py
│ │ └── util.py
│ ├── tf_util.py
│ ├── tile_images.py
│ ├── vec_env
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-36.pyc
│ │ │ ├── dummy_vec_env.cpython-36.pyc
│ │ │ ├── shmem_vec_env.cpython-36.pyc
│ │ │ ├── subproc_vec_env.cpython-36.pyc
│ │ │ ├── util.cpython-36.pyc
│ │ │ ├── vec_env.cpython-36.pyc
│ │ │ ├── vec_frame_stack.cpython-36.pyc
│ │ │ ├── vec_monitor.cpython-36.pyc
│ │ │ ├── vec_normalize.cpython-36.pyc
│ │ │ ├── vec_remove_dict_obs.cpython-36.pyc
│ │ │ └── vec_video_recorder.cpython-36.pyc
│ │ ├── dummy_vec_env.py
│ │ ├── shmem_vec_env.py
│ │ ├── subproc_vec_env.py
│ │ ├── test_vec_env.py
│ │ ├── test_video_recorder.py
│ │ ├── util.py
│ │ ├── vec_env.py
│ │ ├── vec_frame_stack.py
│ │ ├── vec_monitor.py
│ │ ├── vec_normalize.py
│ │ ├── vec_remove_dict_obs.py
│ │ └── vec_video_recorder.py
│ └── wrappers.py
├── config.py
├── default_cfg.py
├── envs
│ ├── __pycache__
│ │ ├── env_utils.cpython-36.pyc
│ │ └── make_env_utils.cpython-36.pyc
│ ├── env_utils.py
│ ├── make_env_utils.py
│   └── wrappers
│       ├── __pycache__
│       │   └── wrapper_utils.cpython-36.pyc
│       ├── multi_world_wrapper.py
│       └── wrapper_utils.py
├── play.py
├── plot.py
├── rollouts
│ ├── __init__.py
│ └── rollout.py
├── run.py
├── samplers
│ ├── __init__.py
│ ├── her_sampler.py
│ ├── nstep_sampler.py
│ ├── prioritized_sampler.py
│ └── sampler.py
└── train.py
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Rui
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Modular-HER
2 |   
3 |
4 | Modular-HER is derived from OpenAI Baselines and supports many improvements to Hindsight Experience Replay (HER) as modules. We aim to provide a more **modular**, **readable** and **concise** package for Multi-goal Reinforcement Learning.
5 |
6 | Everyone is welcome to contribute suggestions or code!
7 |
8 |
9 | ## Functions
10 | - [x] DDPG (https://arxiv.org/abs/1509.02971);
11 | - [x] HER (future, episode, final, random) (https://arxiv.org/abs/1707.01495), see the relabeling sketch after this list;
12 | - [x] Cut HER (incrementally increases the future sample length);
13 | - [x] SHER (https://arxiv.org/abs/2002.02089);
14 | - [x] Prioritized HER (same as PHER in https://arxiv.org/abs/1905.08786);
15 | - [ ] Energy-based Prioritized HER (https://www.researchgate.net/publication/341776498_Energy-Based_Hindsight_Experience_Prioritization);
16 | - [ ] Curriculum-guided Hindsight Experience Replay (http://papers.nips.cc/paper/9425-curriculum-guided-hindsight-experience-replay);
17 | - [x] nstep DDPG and nstep HER;
18 | - [ ] more to be continued...
19 |
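As a quick illustration of what the HER samplers above do, here is a minimal, hypothetical sketch of the "future" relabeling strategy (the function name, the 0.8 relabeling probability, and the 0.05 success threshold are illustrative assumptions, not this package's API):

```python
import numpy as np

def her_future_relabel(achieved_goals, goals, future_p=0.8, threshold=0.05):
    """Illustrative 'future' strategy: with probability future_p, replace the goal of
    transition t with an achieved goal from a later step of the same episode.
    achieved_goals: array (T+1, goal_dim); goals: array (T, goal_dim)."""
    T = goals.shape[0]
    relabeled = goals.copy()
    for t in range(T):
        if np.random.uniform() < future_p:
            future_t = np.random.randint(t + 1, T + 1)   # a later timestep in the episode
            relabeled[t] = achieved_goals[future_t]      # hindsight goal
    # recompute the sparse reward against the (possibly relabeled) goals:
    # 0 if the next achieved goal is within the threshold, -1 otherwise
    dists = np.linalg.norm(achieved_goals[1:] - relabeled, axis=-1)
    rewards = -(dists > threshold).astype(np.float32)
    return relabeled, rewards
```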
20 |
21 | ## Prerequisites
22 | Requires Python 3 (>=3.5), TensorFlow (>=1.4, <=1.14), and the system packages CMake, OpenMPI, and zlib. These can be installed as follows:
23 |
24 | #### Ubuntu :
25 | ```bash
26 | sudo apt-get update && sudo apt-get install cmake libopenmpi-dev python3-dev zlib1g-dev
27 | ```
28 |
29 | #### Mac OS X :
30 | With [Homebrew](https://brew.sh) installed, run the following:
31 | ```bash
32 | brew install cmake openmpi
33 | ```
34 |
35 | ## Installation
36 | ```bash
37 | git clone https://github.com/YangRui2015/Modular_HER.git
38 | cd Modular_HER
39 | pip install -e .
40 | ```
41 |
42 |
43 | ## Usage
44 | Training DDPG and saving logs and models.
45 | ```bash
46 | python -m mher.run --env=FetchReach-v1 --num_epoch 30 --num_env 1 --sampler random --play_episodes 5 --log_path=~/logs/fetchreach/ --save_path=~/logs/models/fetchreach_ddpg/
47 | ```
48 |
49 | Training HER + DDPG with different samplers ('her_future', 'her_random', 'her_last', 'her_episode' are supported).
50 | ```bash
51 | python -m mher.run --env=FetchReach-v1 --num_epoch 30 --num_env 1 --sampler her_future --play_episodes 5 --log_path=~/logs/fetchreach/ --save_path=~/logs/models/fetchreach_herfuture/
52 | ```
53 |
54 | Training SAC + HER.
55 | ```bash
56 | python -m mher.run --env=FetchReach-v1 --num_epoch 50 --algo sac --sac_alpha 0.05 --sampler her_episode
57 | ```
58 |
59 | All supported sampler flags:
60 | | Group | Samplers |
61 | | ------ | ------ |
62 | | Random sampler | random |
63 | | HER | her_future, her_episode, her_last, her_random |
64 | | Nstep| nstep, nstep_her_future, nstep_her_epsisode, nstep_her_last, nstep_her_random|
65 | | Priority| priority, priority_her_future, priority_her_episode, priority_her_random, priority_her_last|
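
Any name from the table can be passed through the `--sampler` flag used in the usage examples above, e.g. `--sampler nstep_her_future` or `--sampler priority_her_episode` (illustrative choices; the remaining flags keep their defaults).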
66 |
67 |
68 | ## Results
69 |
70 | We use the group of test parameters in DEFAULT_ENV_PARAMS for performance comparison in the FetchReach-v1 environment.
71 |
72 | 1. Performance of HER with different goal sampling methods (future, random, episode, last).
73 |
74 | 
75 |
76 | 2. Performance of Nstep HER and Nstep DDPG.
77 |
78 | 
79 |
80 | 3. Performance of SHER (not good enough in the FetchReach environment; more environments will be tested and reported).
81 |
82 | 
83 |
84 |
85 | ## Update
86 |
87 | * 9.27 V0.0: updated readme;
88 | * 10.3 V0.5: heavily revised the code framework; supported DDPG and HER (future, last, final, random);
89 | * 10.4 V0.6: updated the code framework, added rollouts and samplers packages;
90 | * 10.6 added nstep sampler and nstep HER sampler;
91 | * 10.7 fixed a bug in the nstep HER sampler;
92 | * 10.16 added prioritized experience replay and Cut HER;
93 | * 10.31 V1.0: added SHER support;
94 |
--------------------------------------------------------------------------------
/data/mher_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/data/mher_all.png
--------------------------------------------------------------------------------
/data/mher_all_step.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/data/mher_all_step.png
--------------------------------------------------------------------------------
/data/mher_sac.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/data/mher_sac.png
--------------------------------------------------------------------------------
/mher/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/__init__.py
--------------------------------------------------------------------------------
/mher/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/__pycache__/default_cfg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/__pycache__/default_cfg.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/__pycache__/play.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/__pycache__/play.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/__pycache__/run.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/__pycache__/run.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/__pycache__/train.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/__pycache__/train.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/algos/__init__.py:
--------------------------------------------------------------------------------
1 | from mher.algos.ddpg import DDPG
2 | from mher.algos.sac import SAC
--------------------------------------------------------------------------------
/mher/algos/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/algos/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/algos/__pycache__/actor_critic.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/algos/__pycache__/actor_critic.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/algos/__pycache__/ddpg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/algos/__pycache__/ddpg.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/algos/__pycache__/normalizer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/algos/__pycache__/normalizer.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/algos/__pycache__/rollout.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/algos/__pycache__/rollout.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/algos/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/algos/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/algos/actor_critic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from mher.algos.sac_utils import apply_squashing_func, mlp_gaussian_policy
4 | from mher.algos.util import nn, store_args
5 |
6 |
7 | class ActorCritic:
8 | @store_args
9 | def __init__(self, inputs_tf, dimo, dimg, dimu, max_u, o_stats, g_stats, hidden, layers, sess):
10 | """The actor-critic network and related training code.
11 | Args:
12 | inputs_tf (dict of tensors): all necessary inputs for the network: the
13 | observation (o), the goal (g), and the action (u)
14 | dimo (int): the dimension of the observations
15 | dimg (int): the dimension of the goals
16 | dimu (int): the dimension of the actions
17 | max_u (float): the maximum magnitude of actions; action outputs will be scaled accordingly
18 | o_stats (mher.algos.Normalizer): normalizer for observations
19 | g_stats (mher.algos.Normalizer): normalizer for goals
20 | hidden (int): number of hidden units that should be used in hidden layers
21 | layers (int): number of hidden layers
22 | """
23 | self.o_tf = inputs_tf['o']
24 | self.g_tf = inputs_tf['g']
25 | self.u_tf = inputs_tf['u']
26 |
27 | # Prepare inputs for actor and critic.
28 | o = self.o_stats.normalize(self.o_tf)
29 | g = self.g_stats.normalize(self.g_tf)
30 | input_pi = tf.concat(axis=1, values=[o, g]) # for actor
31 | self._network(input_pi, o, g)
32 |
33 |
34 | def _network(self, input_pi, o, g):
35 | # Networks.
36 | with tf.variable_scope('pi'):
37 | self.pi_tf = self.max_u * tf.tanh(nn(input_pi, [self.hidden] * self.layers + [self.dimu]))
38 |
39 | with tf.variable_scope('Q'):
40 | # for policy training
41 | input_Q = tf.concat(axis=1, values=[o, g, self.pi_tf / self.max_u])
42 | self.Q_pi_tf = nn(input_Q, [self.hidden] * self.layers + [1])
43 | # for critic training
44 | input_Q = tf.concat(axis=1, values=[o, g, self.u_tf / self.max_u])
45 | self._input_Q = input_Q # exposed for tests
46 | self.Q_tf = nn(input_Q, [self.hidden] * self.layers + [1], reuse=True)
47 |
48 | def get_Q(self, o, g, u):
49 | feed = {
50 | self.o_tf: o.reshape(-1, self.dimo),
51 | self.g_tf: g.reshape(-1, self.dimg),
52 | self.u_tf: u.reshape(-1, self.dimu)
53 | }
54 | return self.sess.run(self.Q_tf, feed_dict=feed)
55 |
56 | def get_Q_pi(self, o, g):
57 | feed = {
58 | self.o_tf: o.reshape(-1, self.dimo),
59 | self.g_tf:g.reshape(-1, self.dimg)
60 | }
61 | return self.sess.run(self.Q_pi_tf, feed_dict=feed)
62 |
63 |
64 | class SAC_ActorCritic(ActorCritic):
65 | @store_args
66 | def __init__(self, inputs_tf, dimo, dimg, dimu, max_u, o_stats, g_stats, hidden, layers, sess):
67 | super(SAC_ActorCritic, self).__init__(**self.__dict__)
68 |
69 |
70 | def _network(self, input_pi, o, g):
71 | with tf.variable_scope('pi'):
72 | self.mu_tf, self.pi_tf, self.logp_pi_tf, self.log_std = mlp_gaussian_policy(input_pi, self.dimu,
73 | hidden_sizes=[self.hidden] * self.layers,
74 | activation=tf.nn.relu,
75 | output_activation=None)
76 | self.mu_tf, self.pi_tf, self.logp_pi_tf = apply_squashing_func(self.mu_tf, self.pi_tf, self.logp_pi_tf)
77 |
78 | with tf.variable_scope('q1'):
79 | self.q1_pi_tf = nn(tf.concat(axis=1, values=[o, g, self.pi_tf]),
80 | layers_sizes=[self.hidden] * self.layers + [1])
81 | self.q1_tf = nn(tf.concat(axis=1, values=[o, g, self.u_tf]),
82 | layers_sizes=[self.hidden] * self.layers + [1], reuse=True)
83 | with tf.variable_scope('q2'):
84 | self.q2_pi_tf = nn(tf.concat(axis=1, values=[o, g, self.pi_tf]),
85 | layers_sizes=[self.hidden] * self.layers + [1])
86 | self.q2_tf = nn(tf.concat(axis=1, values=[o, g, self.u_tf]),
87 | layers_sizes=[self.hidden] * self.layers + [1], reuse=True)
88 | with tf.variable_scope('min'):
89 | self.min_q_pi_tf = tf.minimum(self.q1_pi_tf, self.q2_pi_tf)
90 | self.min_q_tf = tf.minimum(self.q1_tf, self.q2_tf)
91 | with tf.variable_scope('v'):
92 | self.v_tf = nn(input_pi,layers_sizes=[self.hidden] * self.layers + [1])
93 |
94 | def get_Q(self, o, g, u):
95 | feed = {
96 | self.o_tf: o.reshape(-1, self.dimo),
97 | self.g_tf: g.reshape(-1, self.dimg),
98 | self.u_tf: u.reshape(-1, self.dimu)
99 | }
100 | return self.sess.run(self.min_q_tf, feed_dict=feed)
101 |
102 | def get_Q_pi(self, o, g):
103 | feed = {
104 | self.o_tf: o.reshape(-1, self.dimo),
105 | self.g_tf: g.reshape(-1, self.dimg)
106 | }
107 | return self.sess.run(self.min_q_pi_tf, feed_dict=feed)
108 |
109 | def get_V(self, o, g):
110 | feed = {
111 | self.o_tf: o.reshape(-1, self.dimo),
112 | self.g_tf: g.reshape(-1, self.dimg)
113 | }
114 | return self.sess.run(self.v_tf, feed_dict=feed)
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
--------------------------------------------------------------------------------
/mher/algos/ddpg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from mher.algos.actor_critic import ActorCritic
4 | from mher.algos.algorithm import Algorithm
5 | from mher.algos.util import flatten_grads, get_var, store_args
6 | from mher.common import logger, tf_util
7 | from mher.common.mpi_adam import MpiAdam
8 |
9 |
10 | class DDPG(Algorithm):
11 | @store_args
12 | def __init__(self, buffer, input_dims, hidden, layers, polyak, Q_lr, pi_lr,
13 | norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, subtract_goals,
14 | relative_goals, clip_pos_returns, clip_return, gamma, vloss_type='normal',
15 | priority=False, reuse=False, **kwargs):
16 | """
17 |         See the Algorithm base class for parameter documentation.
18 | """
19 | super(DDPG, self).__init__(**self.__dict__)
20 |
21 | def _create_network(self, reuse=False):
22 | logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))
23 | self.sess = tf_util.get_session()
24 | # normalizer for input
25 | self._create_normalizer(reuse)
26 | batch_tf = self._get_batch_tf()
27 |
28 | # networks
29 | self._create_target_main(ActorCritic, reuse, batch_tf)
30 |
31 | # loss functions
32 | target_Q_pi_tf = self.target.Q_pi_tf
33 | clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
34 | target_tf = self._clip_target(batch_tf, clip_range, target_Q_pi_tf)
35 |
36 | self.abs_td_error_tf = tf.abs(tf.stop_gradient(target_tf) - self.main.Q_tf)
37 | self.Q_loss = tf.square(self.abs_td_error_tf)
38 | if self.priority:
39 | self.Q_loss_tf = tf.reduce_mean(batch_tf['w'] * self.Q_loss)
40 | else:
41 | self.Q_loss_tf = tf.reduce_mean(self.Q_loss)
42 | self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
43 | self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
44 |
45 |         # variables
46 | self.main_Q_var = get_var(self.scope + '/main/Q')
47 | self.main_pi_var = get_var(self.scope + '/main/pi')
48 | self.target_Q_var = get_var(self.scope + '/target/Q')
49 | self.target_pi_var = get_var(self.scope + '/target/pi')
50 |
51 | Q_grads_tf = tf.gradients(self.Q_loss_tf, self.main_Q_var)
52 | pi_grads_tf = tf.gradients(self.pi_loss_tf, self.main_pi_var)
53 | assert len(self.main_Q_var) == len(Q_grads_tf)
54 | assert len(self.main_pi_var) == len(pi_grads_tf)
55 | self.Q_grads_vars_tf = zip(Q_grads_tf, self.main_Q_var)
56 | self.pi_grads_vars_tf = zip(pi_grads_tf, self.main_pi_var)
57 | self.Q_grad_tf = flatten_grads(grads=Q_grads_tf, var_list=self.main_Q_var)
58 | self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=self.main_pi_var)
59 |
60 | # optimizers
61 | self.Q_adam = MpiAdam(self.main_Q_var, scale_grad_by_procs=False)
62 | self.pi_adam = MpiAdam(self.main_pi_var, scale_grad_by_procs=False)
63 | self.main_vars = self.main_Q_var + self.main_pi_var
64 | self.target_vars = self.target_Q_var+ self.target_pi_var
65 | self.init_target_net_op = list(map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
66 | self.update_target_net_op = list(map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]),
67 | zip(self.target_vars, self.main_vars)))
68 |
69 | # initialize all variables
70 | self.global_vars = get_var(self.scope, key='global')
71 | tf.variables_initializer(self.global_vars).run()
72 | self._sync_optimizers()
73 | self._init_target_net()
74 |
75 | def _sync_optimizers(self):
76 | self.Q_adam.sync()
77 | self.pi_adam.sync()
78 |
79 | def _grads(self): # Avoid feed_dict here for performance!
80 | critic_loss, actor_loss, Q_grad, pi_grad, abs_td_error = self.sess.run([
81 | self.Q_loss_tf,
82 | self.main.Q_pi_tf,
83 | self.Q_grad_tf,
84 | self.pi_grad_tf,
85 | self.abs_td_error_tf
86 | ])
87 | return critic_loss, actor_loss, Q_grad, pi_grad, abs_td_error
88 |
89 | def _update(self, Q_grad, pi_grad):
90 | self.Q_adam.update(Q_grad, self.Q_lr)
91 | self.pi_adam.update(pi_grad, self.pi_lr)
92 |
93 |
94 |
--------------------------------------------------------------------------------
/mher/algos/sac.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from mher.algos.actor_critic import SAC_ActorCritic
4 | from mher.algos.algorithm import Algorithm
5 | from mher.algos.util import flatten_grads, get_var, store_args
6 | from mher.common import logger, tf_util
7 | from mher.common.mpi_adam import MpiAdam
8 |
9 |
10 |
11 | class SAC(Algorithm):
12 | @store_args
13 | def __init__(self, buffer, input_dims, hidden, layers, polyak, Q_lr, pi_lr,
14 | norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, subtract_goals,
15 | relative_goals, clip_pos_returns, clip_return, gamma, vloss_type='normal',
16 | priority=False, sac_alpha=0.03, reuse=False, **kwargs):
17 | """Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
18 | Args:
19 |             sac_alpha: entropy temperature hyperparameter in SAC
20 | """
21 | super(SAC, self).__init__(**self.__dict__)
22 |
23 | def _name_variable(self, name, main=True):
24 | if main:
25 | return self.scope + '/main/' + name
26 | else:
27 | return self.scope + '/target/' + name
28 |
29 | def _create_network(self, reuse=False):
30 | logger.info("Creating a SAC agent with action space %d x %s..." % (self.dimu, self.max_u))
31 | self.sess = tf_util.get_session()
32 | self._create_normalizer(reuse)
33 | batch_tf = self._get_batch_tf()
34 |
35 | # networks
36 | self._create_target_main(SAC_ActorCritic, reuse, batch_tf)
37 |
38 | # loss functions
39 | clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
40 | target_tf = self._clip_target(batch_tf, clip_range, self.target.v_tf)
41 | q_backup_tf = tf.stop_gradient(target_tf)
42 | v_backup_tf = tf.stop_gradient(self.main.min_q_pi_tf - self.sac_alpha * self.main.logp_pi_tf)
43 |
44 | q1_loss_tf = 0.5 * tf.reduce_mean((q_backup_tf - self.main.q1_tf) ** 2)
45 | q2_loss_tf = 0.5 * tf.reduce_mean((q_backup_tf - self.main.q2_tf) ** 2)
46 | v_loss_tf = 0.5 * tf.reduce_mean((v_backup_tf - self.main.v_tf) ** 2)
47 | self.abs_tf_error_tf = tf.reduce_mean(tf.abs(q_backup_tf - self.main.q1_tf) + tf.abs(q_backup_tf -self.main.q2_tf))
48 |
49 | self.value_loss_tf = q1_loss_tf + q2_loss_tf + v_loss_tf
50 | self.pi_loss_tf = tf.reduce_mean(self.sac_alpha * self.main.logp_pi_tf - self.main.q1_pi_tf)
51 |
52 |         # variables
53 | value_params = get_var(self._name_variable('q')) + get_var(self._name_variable('v'))
54 | pi_params = get_var(self._name_variable('pi'))
55 | # gradients
56 | V_grads_tf = tf.gradients(self.value_loss_tf, value_params)
57 | pi_grads_tf = tf.gradients(self.pi_loss_tf, pi_params)
58 | self.V_grad_tf = flatten_grads(grads=V_grads_tf, var_list=value_params)
59 | self.pi_grad_tf = flatten_grads(grads=pi_grads_tf, var_list=pi_params)
60 |
61 | # optimizers
62 | self.V_adam = MpiAdam(value_params, scale_grad_by_procs=False)
63 | self.pi_adam = MpiAdam(pi_params, scale_grad_by_procs=False)
64 |
65 | # polyak averaging
66 | self.main_vars = get_var(self._name_variable('pi')) + get_var(self._name_variable('q1')) + get_var(self._name_variable('q2')) + get_var(self._name_variable('v'))
67 | self.target_vars = get_var(self._name_variable('pi', main=False)) + get_var(self._name_variable('q1', main=False)) + get_var(self._name_variable('q2', main=False)) + get_var(self._name_variable('v', main=False))
68 |
69 | self.init_target_net_op = list(map(lambda v: v[0].assign(v[1]), zip(self.target_vars, self.main_vars)))
70 | self.update_target_net_op = list(map(lambda v: v[0].assign(self.polyak * v[0] + (1. - self.polyak) * v[1]), \
71 | zip(self.target_vars, self.main_vars)))
72 |
73 | # initialize all variables
74 | self.global_vars = get_var(self.scope, key='global')
75 | tf.variables_initializer(self.global_vars).run()
76 | self._sync_optimizers()
77 | self._init_target_net()
78 |
79 |
80 | def _sync_optimizers(self):
81 | self.V_adam.sync()
82 | self.pi_adam.sync()
83 |
84 | def _grads(self):
85 | critic_loss, actor_loss, V_grad, pi_grad, abs_td_error = self.sess.run([
86 | self.value_loss_tf,
87 | self.pi_loss_tf,
88 | self.V_grad_tf,
89 | self.pi_grad_tf,
90 | self.abs_tf_error_tf
91 | ])
92 | return critic_loss, actor_loss, V_grad, pi_grad, abs_td_error
93 |
94 | def _update(self, V_grad, pi_grad):
95 | self.V_adam.update(V_grad, self.Q_lr)
96 | self.pi_adam.update(pi_grad, self.pi_lr)
97 |
98 |     # SAC does not need action noise; exploration comes from the stochastic policy
99 | def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False, compute_Q=False):
100 | o, g = self._preprocess_og(o=o, g=g, ag=ag)
101 | if not noise_eps and not random_eps:
102 | u = self.simple_get_action(o, g, use_target_net, deterministic=True)
103 | else:
104 | u = self.simple_get_action(o, g, use_target_net, deterministic=False)
105 |
106 | if compute_Q:
107 | Q_pi = self.get_Q_fun(o, g)
108 |
109 | u = np.clip(u, -self.max_u, self.max_u)
110 | if u.shape[0] == 1:
111 | u = u[0]
112 |
113 | if compute_Q:
114 | return [u, Q_pi]
115 | else:
116 | return u
117 |
118 | def simple_get_action(self, o, g, use_target_net=False, deterministic=False):
119 | o,g = self._preprocess_og(o=o,g=g)
120 | policy = self.target if use_target_net else self.main # in n-step self.target performs better
121 | act_tf = policy.mu_tf if deterministic else policy.pi_tf
122 | action, logp_pi, min_q_pi, q1_pi, q2_pi,log_std = self.sess.run( \
123 | [act_tf, policy.logp_pi_tf, policy.min_q_pi_tf, policy.q1_pi_tf, policy.q2_pi_tf, policy.log_std], \
124 | feed_dict={
125 | policy.o_tf: o.reshape(-1, self.dimo),
126 | policy.g_tf: g.reshape(-1, self.dimg)
127 | })
128 | return action
129 |
--------------------------------------------------------------------------------
/mher/algos/sac_utils.py:
--------------------------------------------------------------------------------
1 | from mher.algos.util import nn
2 | import tensorflow as tf
3 | import numpy as np
4 |
5 | EPS = 1e-6
6 | LOG_STD_MAX = 2
7 | LOG_STD_MIN = -20
8 |
9 | def gaussian_likelihood(x, mu, log_std):
10 | pre_sum = -0.5 * (((x-mu)/(tf.exp(log_std)+EPS))**2 + 2*log_std + np.log(2*np.pi))
11 | return tf.reduce_sum(pre_sum, axis=1)
12 |
13 | def clip_but_pass_gradient(x, l=-1., u=1.):
14 | clip_up = tf.cast(x > u, tf.float32)
15 | clip_low = tf.cast(x < l, tf.float32)
16 | return x + tf.stop_gradient((u - x)*clip_up + (l - x)*clip_low)
17 |
18 | def nn_gaussian_policy(x, a, dimu, layers_sizes,output_activation):
19 | act_dim = dimu
20 | net = nn(x, layers_sizes)
21 | mu = tf.layers.dense(net, act_dim, activation=output_activation)
22 |
23 | log_std = tf.layers.dense(net, act_dim, activation=None)
24 | log_std = tf.clip_by_value(log_std, LOG_STD_MIN, LOG_STD_MAX)
25 |
26 | std = tf.exp(log_std)
27 | pi = mu + tf.random_normal(tf.shape(mu)) * std
28 | logp_pi = gaussian_likelihood(pi, mu, log_std)
29 | return mu, pi, logp_pi
30 |
31 | def apply_squashing_func(mu, pi, logp_pi):
32 | mu = tf.tanh(mu)
33 | pi = tf.tanh(pi)
34 | # To avoid evil machine precision error, strictly clip 1-pi**2 to [0,1] range.
35 | logp_pi -= tf.reduce_sum(tf.log(clip_but_pass_gradient(1 - pi**2, l=0, u=1) + 1e-6), axis=1)
36 | return mu, pi, logp_pi
37 |
38 | def mlp_gaussian_policy(x, act_dim, hidden_sizes, activation, output_activation):
39 | net = nn(x, hidden_sizes)
40 | mu = tf.layers.dense(net, act_dim, activation=output_activation)
41 |
42 | log_std = tf.layers.dense(net, act_dim, activation=tf.tanh)
43 | log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1)
44 |
45 | std = tf.exp(log_std)
46 | pi = mu + tf.random_normal(tf.shape(mu)) * std
47 | logp_pi = gaussian_likelihood(pi, mu, log_std)
48 | return mu, pi, logp_pi, log_std
49 |
50 |
--------------------------------------------------------------------------------
/mher/algos/util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import importlib
3 | import inspect
4 | import os
5 | import subprocess
6 | import sys
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 | from mher.common import tf_util as U
11 |
12 |
13 | def dims_to_shapes(input_dims):
14 | return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()}
15 |
16 | def store_args(method):
17 | """Stores provided method args as instance attributes.
18 | """
19 | argspec = inspect.getfullargspec(method)
20 | defaults = {}
21 | if argspec.defaults is not None:
22 | defaults = dict(
23 | zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
24 | if argspec.kwonlydefaults is not None:
25 | defaults.update(argspec.kwonlydefaults)
26 | arg_names = argspec.args[1:]
27 |
28 | @functools.wraps(method)
29 | def wrapper(*positional_args, **keyword_args):
30 | self = positional_args[0]
31 | # Get default arg values
32 | args = defaults.copy()
33 | # Add provided arg values
34 | for name, value in zip(arg_names, positional_args[1:]):
35 | args[name] = value
36 | args.update(keyword_args)
37 | self.__dict__.update(args)
38 | return method(*positional_args, **keyword_args)
39 |
40 | return wrapper
41 |
42 |
43 | def import_function(spec):
44 | """Import a function identified by a string like "pkg.module:fn_name".
45 | """
46 | mod_name, fn_name = spec.split(':')
47 | module = importlib.import_module(mod_name)
48 | fn = getattr(module, fn_name)
49 | return fn
50 |
51 |
52 | def flatten_grads(var_list, grads):
53 | """Flattens a variables and their gradients.
54 | """
55 | return tf.concat([tf.reshape(grad, [U.numel(v)])
56 | for (v, grad) in zip(var_list, grads)], 0)
57 |
58 |
59 | def nn(input, layers_sizes, reuse=None, flatten=False, name="", trainable=True, init='xavier', init_range=0.01):
60 | """Creates a simple neural network
61 | """
62 | if init == 'xavier':
63 | initializer = tf.contrib.layers.xavier_initializer()
64 | elif init == 'random':
65 | initializer = tf.random_uniform_initializer(minval=-init_range, maxval=init_range)
66 | else:
67 | raise NotImplementedError
68 |
69 | for i, size in enumerate(layers_sizes):
70 | activation = tf.nn.relu if i < len(layers_sizes) - 1 else None
71 | input = tf.layers.dense(inputs=input,
72 | units=size,
73 | kernel_initializer=initializer,
74 | reuse=reuse,
75 | name=name + '_' + str(i),
76 | trainable=trainable)
77 | if activation:
78 | input = activation(input)
79 | if flatten:
80 | assert layers_sizes[-1] == 1
81 | input = tf.reshape(input, [-1])
82 | return input
83 |
84 |
85 | def install_mpi_excepthook():
86 | import sys
87 |
88 | from mpi4py import MPI
89 | old_hook = sys.excepthook
90 |
91 | def new_hook(a, b, c):
92 | old_hook(a, b, c)
93 | sys.stdout.flush()
94 | sys.stderr.flush()
95 | MPI.COMM_WORLD.Abort()
96 | sys.excepthook = new_hook
97 |
98 |
99 | def mpi_fork(n, extra_mpi_args=[]):
100 | """Re-launches the current script with workers
101 | Returns "parent" for original parent, "child" for MPI children
102 | """
103 | if n <= 1:
104 | return "child"
105 | if os.getenv("IN_MPI") is None:
106 | env = os.environ.copy()
107 | env.update(
108 | MKL_NUM_THREADS="1",
109 | OMP_NUM_THREADS="1",
110 | IN_MPI="1"
111 | )
112 | # "-bind-to core" is crucial for good performance
113 | args = ["mpirun", "-np", str(n)] + \
114 | extra_mpi_args + \
115 | [sys.executable]
116 |
117 | args += sys.argv
118 | subprocess.check_call(args, env=env)
119 | return "parent"
120 | else:
121 | install_mpi_excepthook()
122 | return "child"
123 |
124 |
125 | def convert_episode_to_batch_major(episode):
126 | """Converts an episode to have the batch dimension in the major (first)
127 | dimension.
128 | """
129 | episode_batch = {}
130 | for key in episode.keys():
131 | val = np.array(episode[key]).copy()
132 | # make inputs batch-major instead of time-major
133 | episode_batch[key] = val.swapaxes(0, 1)
134 |
135 | return episode_batch
136 |
137 |
138 | def transitions_in_episode_batch(episode_batch):
139 | """Number of transitions in a given episode batch.
140 | """
141 | shape = episode_batch['u'].shape
142 | return shape[0] * shape[1]
143 |
144 |
145 | def reshape_for_broadcasting(source, target):
146 | """Reshapes a tensor (source) to have the correct shape and dtype of the target
147 | before broadcasting it with MPI.
148 | """
149 | dim = len(target.get_shape())
150 | shape = ([1] * (dim - 1)) + [-1]
151 | return tf.reshape(tf.cast(source, target.dtype), shape)
152 |
153 | def get_var(scope, key='trainable'):
154 | if key == 'trainable':
155 | tf_key = tf.GraphKeys.TRAINABLE_VARIABLES
156 | elif key == 'global':
157 | tf_key = tf.GraphKeys.GLOBAL_VARIABLES
158 | else:
159 | print('No such key {} for tensorflow'.format(key))
160 | raise NotImplementedError
161 | res = tf.get_collection(tf_key, scope=scope)
162 | return res
163 |
164 |
--------------------------------------------------------------------------------
/mher/buffers/__init__.py:
--------------------------------------------------------------------------------
1 | from mher.buffers.replay_buffer import ReplayBuffer
2 | from mher.buffers.prioritized_buffer import PrioritizedReplayBuffer
--------------------------------------------------------------------------------
/mher/buffers/__pycache__/replay_buffer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/buffers/__pycache__/replay_buffer.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/buffers/__pycache__/samplers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/buffers/__pycache__/samplers.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/buffers/prioritized_buffer.py:
--------------------------------------------------------------------------------
1 | import threading
2 |
3 | import numpy as np
4 | from mher.buffers.replay_buffer import ReplayBuffer
5 | from mher.common.segment_tree import MinSegmentTree, SumSegmentTree
6 |
7 |
8 | class PrioritizedReplayBuffer(ReplayBuffer):
9 | def __init__(self, buffer_shapes, size_in_transitions, T, sampler):
10 | """Create Prioritized Replay buffer"""
11 | super(PrioritizedReplayBuffer, self).__init__(buffer_shapes, size_in_transitions, T, sampler)
12 |
13 | def store_episode(self, episode_batch):
14 | """episode_batch: array(batch_size x (T or T+1) x dim_key)"""
15 | episode_idxs = super().store_episode(episode_batch)
16 | # save priority
17 | if not hasattr(episode_idxs, '__len__'):
18 | episode_idxs = np.array([episode_idxs])
19 | self.sampler.update_new_priorities(episode_idxs)
20 |
21 | def update_priorities(self, idxs, priorities):
22 | self.sampler.update_priorities(idxs, priorities)
23 |
24 |
--------------------------------------------------------------------------------
/mher/buffers/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import threading
2 |
3 | import numpy as np
4 |
5 |
6 | class ReplayBuffer:
7 | def __init__(self, buffer_shapes, size_in_transitions, T, sampler):
8 | """Creates a replay buffer.
9 | Args:
10 | buffer_shapes (dict of ints): the shape for all buffers that are used in buffer
11 | size_in_transitions (int): the size of the buffer, measured in transitions
12 | T (int): the time horizon for episodes
13 | sampler (class): sampler class used to sample from buffer
14 | """
15 | self.buffer_shapes = buffer_shapes
16 | self.size = size_in_transitions // T # size in episodes
17 | self.T = T
18 | self.sampler = sampler
19 | # self.buffers is {key: array(size_in_episodes x T or T+1 x dim_key)}
20 | self.buffers = {key: np.empty([self.size, *shape]) for key, shape in buffer_shapes.items()}
21 | # memory management
22 | self.point = 0
23 | self.current_size = 0
24 | self.n_transitions_stored = 0
25 | self.lock = threading.Lock()
26 |
27 | @property
28 | def full(self):
29 | with self.lock:
30 | return self.current_size == self.size
31 |
32 | def sample(self):
33 | """Returns a dict {key: array(batch_size x shapes[key])}
34 | """
35 | buffers = {}
36 | with self.lock:
37 | assert self.current_size > 0
38 | for key in self.buffers.keys():
39 | buffers[key] = self.buffers[key][:self.current_size]
40 | # make o_2 and ag_2
41 | if 'o_2' not in buffers and 'ag_2' not in buffers:
42 | buffers['o_2'] = buffers['o'][:, 1:, :]
43 | buffers['ag_2'] = buffers['ag'][:, 1:, :]
44 | transitions = self.sampler.sample(buffers)
45 | return transitions
46 |
47 | def store_episode(self, episode_batch):
48 | """episode_batch: array(rollout_batch_size x (T or T+1) x dim_key)"""
49 | buffer_sizes = [len(episode_batch[key]) for key in episode_batch.keys()]
50 | assert np.all(np.array(buffer_sizes) == buffer_sizes[0])
51 | buffer_size = buffer_sizes[0]
52 | with self.lock:
53 |             idxs = self._get_storage_idx(buffer_size)  # using ordered idxs gives lower performance
54 | # load inputs into buffers
55 | for key in episode_batch.keys():
56 | if key in self.buffers:
57 | self.buffers[key][idxs] = episode_batch[key]
58 | self.n_transitions_stored += buffer_size * self.T
59 | return idxs
60 |
61 | def get_current_episode_size(self):
62 | with self.lock:
63 | return self.current_size
64 |
65 | def get_current_size(self):
66 | with self.lock:
67 | return self.current_size * self.T
68 |
69 | def get_transitions_stored(self):
70 | with self.lock:
71 | return self.n_transitions_stored
72 |
73 | def clear_buffer(self):
74 | with self.lock:
75 | self.current_size = 0
76 |
77 | # if full, insert randomly
78 | def _get_storage_idx(self, inc=None):
79 | inc = inc or 1 # size increment
80 | assert inc <= self.size, "Batch committed to replay is too large!"
81 | # go consecutively until you hit the end, and then go randomly.
82 | if self.current_size+inc <= self.size:
83 | idx = np.arange(self.current_size, self.current_size+inc)
84 | elif self.current_size < self.size:
85 | overflow = inc - (self.size - self.current_size)
86 | idx_a = np.arange(self.current_size, self.size)
87 | idx_b = np.random.randint(0, self.current_size, overflow)
88 | idx = np.concatenate([idx_a, idx_b])
89 | else:
90 | idx = np.random.randint(0, self.size, inc)
91 | # update replay size
92 | self.current_size = min(self.size, self.current_size+inc)
93 |
94 | if inc == 1:
95 | idx = idx[0]
96 | return idx
97 |
98 | # if full, insert in order
99 | def _get_ordered_storage_idx(self, inc=None):
100 | inc = inc or 1 # size increment
101 | assert inc <= self.size, "Batch committed to replay is too large!"
102 |
103 | if self.point+inc <= self.size - 1:
104 | idx = np.arange(self.point, self.point + inc)
105 | else:
106 | overflow = inc - (self.size - self.point)
107 | idx_a = np.arange(self.point, self.size)
108 | idx_b = np.arange(0, overflow)
109 | idx = np.concatenate([idx_a, idx_b])
110 |
111 | self.point = (self.point + inc) % self.size
112 |
113 |         # update replay size; don't increase it once it has reached self.size
114 | if self.current_size < self.size:
115 | self.current_size = min(self.size, self.current_size+inc)
116 |
117 | if inc == 1:
118 | idx = idx[0]
119 | return idx
120 |
--------------------------------------------------------------------------------
/mher/common/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa F403
2 | from mher.common.console_util import *
3 | from mher.common.dataset import Dataset
4 | from mher.common.math_util import *
5 | from mher.common.misc_util import *
6 |
--------------------------------------------------------------------------------
/mher/common/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/atari_wrappers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/atari_wrappers.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/cmd_util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/cmd_util.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/console_util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/console_util.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/import_util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/import_util.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/init_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/init_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/math_util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/math_util.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/misc_util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/misc_util.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/monitor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/monitor.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/mpi_adam.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/mpi_adam.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/mpi_moments.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/mpi_moments.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/retro_wrappers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/retro_wrappers.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/tf_util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/tf_util.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/tile_images.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/tile_images.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/__pycache__/wrappers.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/__pycache__/wrappers.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/cg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10):
3 | """
4 | Demmel p 312
5 | """
6 | p = b.copy()
7 | r = b.copy()
8 | x = np.zeros_like(b)
9 | rdotr = r.dot(r)
10 |
11 | fmtstr = "%10i %10.3g %10.3g"
12 | titlestr = "%10s %10s %10s"
13 | if verbose: print(titlestr % ("iter", "residual norm", "soln norm"))
14 |
15 | for i in range(cg_iters):
16 | if callback is not None:
17 | callback(x)
18 | if verbose: print(fmtstr % (i, rdotr, np.linalg.norm(x)))
19 | z = f_Ax(p)
20 | v = rdotr / p.dot(z)
21 | x += v*p
22 | r -= v*z
23 | newrdotr = r.dot(r)
24 | mu = newrdotr/rdotr
25 | p = r + mu*p
26 |
27 | rdotr = newrdotr
28 | if rdotr < residual_tol:
29 | break
30 |
31 | if callback is not None:
32 | callback(x)
33 | if verbose: print(fmtstr % (i+1, rdotr, np.linalg.norm(x))) # pylint: disable=W0631
34 | return x
35 |
--------------------------------------------------------------------------------
/mher/common/cmd_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for command line
3 | """
4 | import os
5 | import gym
6 | import argparse
7 | from mher.common import logger
8 |
9 | def common_arg_parser():
10 | """
11 |     Create the commonly used argument parser for training
12 | """
13 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
14 | parser.add_argument('--env', help='environment ID', type=str, default='FetchReach-v1')
15 | parser.add_argument('--seed', help='set seed', type=int, default=None)
16 | parser.add_argument('--alg', help='Algorithm', type=str, default='her')
17 | parser.add_argument('--random_init', help='Random init epochs before training',default=0, type=int)
18 | parser.add_argument('--num_epoch', type=int, default=100)
19 | parser.add_argument('--num_timesteps', type=float, default=1e6)
20 | parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default='mlp', type=str)
21 |     parser.add_argument('--num_env', help='Number of environments run in parallel. Default is 1', default=1, type=int)
22 | parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
23 | parser.add_argument('--policy_save_interval', default=10, type=int)
24 |     parser.add_argument('--load_path', help='Path to load trained model from', default=None, type=str)
25 | parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str)
26 | parser.add_argument('--play_episodes', help='Number of episodes to play after training', default=1, type=int)
27 | parser.add_argument('--play_no_training', default=False, action='store_true')
28 | return parser
29 |
30 | def parse_unknown_args(args):
31 | """
32 | Parse arguments not consumed by arg parser into a dictionary
33 | """
34 | retval = {}
35 | preceded_by_key = False
36 | for arg in args:
37 | if arg.startswith('--'):
38 | if '=' in arg:
39 | key = arg.split('=')[0][2:]
40 | value = arg.split('=')[1]
41 | retval[key] = value
42 | else:
43 | key = arg[2:]
44 | preceded_by_key = True
45 | elif preceded_by_key:
46 | retval[key] = arg
47 | preceded_by_key = False
48 |
49 | return retval
50 |
51 |
52 | def parse_cmdline_kwargs(args):
53 | '''
54 | convert a list of '='-spaced command-line arguments to a dictionary, evaluating python objects when possible
55 | '''
56 | def parse(v):
57 | assert isinstance(v, str)
58 | try:
59 | return eval(v)
60 | except (NameError, SyntaxError):
61 | return v
62 | return {k: parse(v) for k,v in parse_unknown_args(args).items()}
63 |
64 | def preprocess_kwargs(args):
65 | arg_parser = common_arg_parser()
66 | args, unknown_args = arg_parser.parse_known_args(args)
67 | extra_args = parse_cmdline_kwargs(unknown_args)
68 | return args, extra_args
--------------------------------------------------------------------------------
/mher/common/console_util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from contextlib import contextmanager
3 | import numpy as np
4 | import time
5 | import shlex
6 | import subprocess
7 |
8 | # ================================================================
9 | # Misc
10 | # ================================================================
11 |
12 | def fmt_row(width, row, header=False):
13 | out = " | ".join(fmt_item(x, width) for x in row)
14 | if header: out = out + "\n" + "-"*len(out)
15 | return out
16 |
17 | def fmt_item(x, l):
18 | if isinstance(x, np.ndarray):
19 | assert x.ndim==0
20 | x = x.item()
21 | if isinstance(x, (float, np.float32, np.float64)):
22 | v = abs(x)
23 | if (v < 1e-4 or v > 1e+4) and v > 0:
24 | rep = "%7.2e" % x
25 | else:
26 | rep = "%7.5f" % x
27 | else: rep = str(x)
28 | return " "*(l - len(rep)) + rep
29 |
30 | color2num = dict(
31 | gray=30,
32 | red=31,
33 | green=32,
34 | yellow=33,
35 | blue=34,
36 | magenta=35,
37 | cyan=36,
38 | white=37,
39 | crimson=38
40 | )
41 |
42 | def colorize(string, color='green', bold=False, highlight=False):
43 | attr = []
44 | num = color2num[color]
45 | if highlight: num += 10
46 | attr.append(str(num))
47 | if bold: attr.append('1')
48 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)
49 |
50 | def print_cmd(cmd, dry=False):
51 | if isinstance(cmd, str): # for shell=True
52 | pass
53 | else:
54 | cmd = ' '.join(shlex.quote(arg) for arg in cmd)
55 | print(colorize(('CMD: ' if not dry else 'DRY: ') + cmd))
56 |
57 |
58 | def get_git_commit(cwd=None):
59 | return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=cwd).decode('utf8')
60 |
61 | def get_git_commit_message(cwd=None):
62 | return subprocess.check_output(['git', 'show', '-s', '--format=%B', 'HEAD'], cwd=cwd).decode('utf8')
63 |
64 | def ccap(cmd, dry=False, env=None, **kwargs):
65 | print_cmd(cmd, dry)
66 | if not dry:
67 | subprocess.check_call(cmd, env=env, **kwargs)
68 |
69 |
70 | MESSAGE_DEPTH = 0
71 |
72 | @contextmanager
73 | def timed(msg):
74 | global MESSAGE_DEPTH #pylint: disable=W0603
75 | print(colorize('\t'*MESSAGE_DEPTH + '=: ' + msg, color='magenta'))
76 | tstart = time.time()
77 | MESSAGE_DEPTH += 1
78 | yield
79 | MESSAGE_DEPTH -= 1
80 | print(colorize('\t'*MESSAGE_DEPTH + "done in %.3f seconds"%(time.time() - tstart), color='magenta'))
81 |
--------------------------------------------------------------------------------
/mher/common/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Dataset(object):
4 | def __init__(self, data_map, deterministic=False, shuffle=True):
5 | self.data_map = data_map
6 | self.deterministic = deterministic
7 | self.enable_shuffle = shuffle
8 | self.n = next(iter(data_map.values())).shape[0]
9 | self._next_id = 0
10 | self.shuffle()
11 |
12 | def shuffle(self):
13 | if self.deterministic:
14 | return
15 | perm = np.arange(self.n)
16 | np.random.shuffle(perm)
17 |
18 | for key in self.data_map:
19 | self.data_map[key] = self.data_map[key][perm]
20 |
21 | self._next_id = 0
22 |
23 | def next_batch(self, batch_size):
24 | if self._next_id >= self.n and self.enable_shuffle:
25 | self.shuffle()
26 |
27 | cur_id = self._next_id
28 | cur_batch_size = min(batch_size, self.n - self._next_id)
29 | self._next_id += cur_batch_size
30 |
31 | data_map = dict()
32 | for key in self.data_map:
33 | data_map[key] = self.data_map[key][cur_id:cur_id+cur_batch_size]
34 | return data_map
35 |
36 | def iterate_once(self, batch_size):
37 | if self.enable_shuffle: self.shuffle()
38 |
39 | while self._next_id <= self.n - batch_size:
40 | yield self.next_batch(batch_size)
41 | self._next_id = 0
42 |
43 | def subset(self, num_elements, deterministic=True):
44 | data_map = dict()
45 | for key in self.data_map:
46 | data_map[key] = self.data_map[key][:num_elements]
47 | return Dataset(data_map, deterministic)
48 |
49 |
50 | def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True):
51 | assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both'
52 | arrays = tuple(map(np.asarray, arrays))
53 | n = arrays[0].shape[0]
54 | assert all(a.shape[0] == n for a in arrays[1:])
55 | inds = np.arange(n)
56 | if shuffle: np.random.shuffle(inds)
57 | sections = np.arange(0, n, batch_size)[1:] if num_batches is None else num_batches
58 | for batch_inds in np.array_split(inds, sections):
59 | if include_final_partial_batch or len(batch_inds) == batch_size:
60 | yield tuple(a[batch_inds] for a in arrays)
61 |
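# Usage sketch (editorial illustration, not part of the file above). Assumes the
# module is importable as mher.common.dataset; the arrays are made up.
import numpy as np
from mher.common.dataset import Dataset, iterbatches

data = {'obs': np.arange(10).reshape(10, 1), 'act': np.arange(10)}
dset = Dataset(data, shuffle=True)
for batch in dset.iterate_once(4):          # aligned mini-batches; the final partial batch is dropped
    print(batch['obs'].shape, batch['act'].shape)   # (4, 1) (4,)

for obs_b, act_b in iterbatches((data['obs'], data['act']), batch_size=4):
    print(obs_b.shape, act_b.shape)         # includes the final partial batch of 2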
--------------------------------------------------------------------------------
/mher/common/import_util.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 |
3 |
4 | def get_alg_module(alg, submodule=None):
5 | submodule = submodule or alg
6 | try:
7 | # first try to import the alg module from mher
8 | alg_module = import_module('.'.join(['mher', alg, submodule]))
9 | except ImportError:
10 | # then from rl_algs
11 | alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
12 |
13 | return alg_module
--------------------------------------------------------------------------------
/mher/common/init_utils.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from collections import defaultdict
3 |
4 | def init_mpi_import():
5 | '''
6 |     Import mpi4py's MPI (used for multi-process training); returns None if mpi4py is not installed.
7 | '''
8 | try:
9 | from mpi4py import MPI
10 | except ImportError:
11 | MPI = None
12 | return MPI
13 |
14 |
15 | def init_environment_import():
16 | '''
17 |     Import optional environment packages (currently commented out) and collect the registered gym environments, grouped by type.
18 | '''
19 | # try:
20 | # import pybullet_envs
21 | # except ImportError:
22 | # pybullet_envs = None
23 |
24 | # try:
25 | # import roboschool
26 | # except ImportError:
27 | # roboschool = None
28 |
29 |     # support multiworld
30 | # try:
31 | # import multiworld
32 | # multiworld.register_all_envs()
33 | # except ImportError:
34 | # multiworld = None
35 |
36 | _game_envs = defaultdict(set)
37 | for env in gym.envs.registry.all():
38 | # TODO: solve this with regexes
39 | try:
40 | env_type = env.entry_point.split(':')[0].split('.')[-1]
41 | _game_envs[env_type].add(env.id)
42 |         except Exception:  # skip envs whose entry_point cannot be parsed
43 | pass
44 | return _game_envs
--------------------------------------------------------------------------------
/mher/common/input.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from gym.spaces import Discrete, Box, MultiDiscrete
4 |
5 | def observation_placeholder(ob_space, batch_size=None, name='Ob'):
6 | '''
7 |     Create a placeholder for feeding observations, sized appropriately for the observation space
8 |
9 | Parameters:
10 | ----------
11 |
12 | ob_space: gym.Space observation space
13 |
14 | batch_size: int size of the batch to be fed into input. Can be left None in most cases.
15 |
16 | name: str name of the placeholder
17 |
18 | Returns:
19 | -------
20 |
21 | tensorflow placeholder tensor
22 | '''
23 |
24 | assert isinstance(ob_space, Discrete) or isinstance(ob_space, Box) or isinstance(ob_space, MultiDiscrete), \
25 |         'Can only deal with Discrete, Box and MultiDiscrete observation spaces for now'
26 |
27 | dtype = ob_space.dtype
28 | if dtype == np.int8:
29 | dtype = np.uint8
30 |
31 | return tf.placeholder(shape=(batch_size,) + ob_space.shape, dtype=dtype, name=name)
32 |
33 |
34 | def observation_input(ob_space, batch_size=None, name='Ob'):
35 | '''
36 |     Create a placeholder for feeding observations, sized appropriately for the observation space, and add an
37 |     input encoder of the appropriate type.
38 | '''
39 |
40 | placeholder = observation_placeholder(ob_space, batch_size, name)
41 | return placeholder, encode_observation(ob_space, placeholder)
42 |
43 | def encode_observation(ob_space, placeholder):
44 | '''
45 | Encode input in the way that is appropriate to the observation space
46 |
47 | Parameters:
48 | ----------
49 |
50 | ob_space: gym.Space observation space
51 |
52 | placeholder: tf.placeholder observation input placeholder
53 | '''
54 | if isinstance(ob_space, Discrete):
55 | return tf.to_float(tf.one_hot(placeholder, ob_space.n))
56 | elif isinstance(ob_space, Box):
57 | return tf.to_float(placeholder)
58 | elif isinstance(ob_space, MultiDiscrete):
59 | placeholder = tf.cast(placeholder, tf.int32)
60 | one_hots = [tf.to_float(tf.one_hot(placeholder[..., i], ob_space.nvec[i])) for i in range(placeholder.shape[-1])]
61 | return tf.concat(one_hots, axis=-1)
62 | else:
63 | raise NotImplementedError
64 |
65 |
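# Usage sketch (editorial illustration, not part of the file above). Assumes the
# module is importable as mher.common.input and a TF1-style graph session.
import numpy as np
import tensorflow as tf
from gym.spaces import Box
from mher.common.input import observation_input

ob_space = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
ob_ph, encoded = observation_input(ob_space, batch_size=None, name='Ob')
with tf.Session() as sess:
    out = sess.run(encoded, feed_dict={ob_ph: np.zeros((2, 3), dtype=np.float32)})
    print(out.shape)                        # (2, 3): Box observations are simply cast to float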
--------------------------------------------------------------------------------
/mher/common/math_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.signal
3 |
4 |
5 | def discount(x, gamma):
6 | """
7 | computes discounted sums along 0th dimension of x.
8 |
9 | inputs
10 | ------
11 | x: ndarray
12 | gamma: float
13 |
14 | outputs
15 | -------
16 | y: ndarray with same shape as x, satisfying
17 |
18 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k],
19 | where k = len(x) - t - 1
20 |
21 | """
22 | assert x.ndim >= 1
23 | return scipy.signal.lfilter([1],[1,-gamma],x[::-1], axis=0)[::-1]
24 |
25 | def explained_variance(ypred,y):
26 | """
27 | Computes fraction of variance that ypred explains about y.
28 | Returns 1 - Var[y-ypred] / Var[y]
29 |
30 | interpretation:
31 | ev=0 => might as well have predicted zero
32 | ev=1 => perfect prediction
33 | ev<0 => worse than just predicting zero
34 |
35 | """
36 | assert y.ndim == 1 and ypred.ndim == 1
37 | vary = np.var(y)
38 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary
39 |
40 | def explained_variance_2d(ypred, y):
41 | assert y.ndim == 2 and ypred.ndim == 2
42 | vary = np.var(y, axis=0)
43 |     out = 1 - np.var(y-ypred, axis=0)/vary  # per-column residual variance, consistent with vary
44 | out[vary < 1e-10] = 0
45 | return out
46 |
47 | def ncc(ypred, y):
48 | return np.corrcoef(ypred, y)[1,0]
49 |
50 | def flatten_arrays(arrs):
51 | return np.concatenate([arr.flat for arr in arrs])
52 |
53 | def unflatten_vector(vec, shapes):
54 | i=0
55 | arrs = []
56 | for shape in shapes:
57 | size = np.prod(shape)
58 | arr = vec[i:i+size].reshape(shape)
59 | arrs.append(arr)
60 | i += size
61 | return arrs
62 |
63 | def discount_with_boundaries(X, New, gamma):
64 | """
65 |     X: array of floats, time-major (time x features, or 1d over time)
66 |     New: array with the same leading (time) dimension, indicating when a new episode has started
67 | """
68 | Y = np.zeros_like(X)
69 | T = X.shape[0]
70 | Y[T-1] = X[T-1]
71 | for t in range(T-2, -1, -1):
72 | Y[t] = X[t] + gamma * Y[t+1] * (1 - New[t+1])
73 | return Y
74 |
75 | def test_discount_with_boundaries():
76 | gamma=0.9
77 | x = np.array([1.0, 2.0, 3.0, 4.0], 'float32')
78 | starts = [1.0, 0.0, 0.0, 1.0]
79 | y = discount_with_boundaries(x, starts, gamma)
80 | assert np.allclose(y, [
81 | 1 + gamma * 2 + gamma**2 * 3,
82 | 2 + gamma * 3,
83 | 3,
84 | 4
85 | ])
86 |
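# Usage sketch (editorial illustration, not part of the file above). Assumes the
# module is importable as mher.common.math_util.
import numpy as np
from mher.common.math_util import discount, explained_variance

print(discount(np.array([1.0, 1.0, 1.0]), 0.9))   # [2.71 1.9  1.  ], i.e. [1+0.9+0.81, 1+0.9, 1]

y = np.array([1.0, 2.0, 3.0, 4.0])
print(explained_variance(y + 0.1, y))             # 1.0: the residual is constant
print(explained_variance(np.zeros(4), y))         # 0.0: no better than predicting zero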
--------------------------------------------------------------------------------
/mher/common/mpi_adam.py:
--------------------------------------------------------------------------------
1 | import mher.common.tf_util as U
2 | import tensorflow as tf
3 | import numpy as np
4 | try:
5 | from mpi4py import MPI
6 | except ImportError:
7 | MPI = None
8 |
9 |
10 | class MpiAdam(object):
11 | def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None):
12 | self.var_list = var_list
13 | self.beta1 = beta1
14 | self.beta2 = beta2
15 | self.epsilon = epsilon
16 | self.scale_grad_by_procs = scale_grad_by_procs
17 | size = sum(U.numel(v) for v in var_list)
18 | self.m = np.zeros(size, 'float32')
19 | self.v = np.zeros(size, 'float32')
20 | self.t = 0
21 | self.setfromflat = U.SetFromFlat(var_list)
22 | self.getflat = U.GetFlat(var_list)
23 | self.comm = MPI.COMM_WORLD if comm is None and MPI is not None else comm
24 |
25 | def update(self, localg, stepsize):
26 | if self.t % 100 == 0:
27 | self.check_synced()
28 | localg = localg.astype('float32')
29 | if self.comm is not None:
30 | globalg = np.zeros_like(localg)
31 | self.comm.Allreduce(localg, globalg, op=MPI.SUM)
32 | if self.scale_grad_by_procs:
33 | globalg /= self.comm.Get_size()
34 | else:
35 | globalg = np.copy(localg)
36 |
37 | self.t += 1
38 | a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t)
39 | self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
40 | self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
41 | step = (- a) * self.m / (np.sqrt(self.v) + self.epsilon)
42 | self.setfromflat(self.getflat() + step)
43 |
44 | def sync(self):
45 | if self.comm is None:
46 | return
47 | theta = self.getflat()
48 | self.comm.Bcast(theta, root=0)
49 | self.setfromflat(theta)
50 |
51 | def check_synced(self):
52 | if self.comm is None:
53 | return
54 | if self.comm.Get_rank() == 0: # this is root
55 | theta = self.getflat()
56 | self.comm.Bcast(theta, root=0)
57 | else:
58 | thetalocal = self.getflat()
59 | thetaroot = np.empty_like(thetalocal)
60 | self.comm.Bcast(thetaroot, root=0)
61 | assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal)
62 |
63 | @U.in_session
64 | def test_MpiAdam():
65 | np.random.seed(0)
66 | tf.set_random_seed(0)
67 |
68 | a = tf.Variable(np.random.randn(3).astype('float32'))
69 | b = tf.Variable(np.random.randn(2,5).astype('float32'))
70 | loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))
71 |
72 | stepsize = 1e-2
73 | update_op = tf.train.AdamOptimizer(stepsize).minimize(loss)
74 | do_update = U.function([], loss, updates=[update_op])
75 |
76 | tf.get_default_session().run(tf.global_variables_initializer())
77 | losslist_ref = []
78 | for i in range(10):
79 | l = do_update()
80 | print(i, l)
81 | losslist_ref.append(l)
82 |
83 |
84 |
85 | tf.set_random_seed(0)
86 | tf.get_default_session().run(tf.global_variables_initializer())
87 |
88 | var_list = [a,b]
89 | lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)])
90 | adam = MpiAdam(var_list)
91 |
92 | losslist_test = []
93 | for i in range(10):
94 | l,g = lossandgrad()
95 | adam.update(g, stepsize)
96 | print(i,l)
97 | losslist_test.append(l)
98 |
99 | np.testing.assert_allclose(np.array(losslist_ref), np.array(losslist_test), atol=1e-4)
100 |
101 |
102 | if __name__ == '__main__':
103 | test_MpiAdam()
104 |
--------------------------------------------------------------------------------
/mher/common/mpi_adam_optimizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from mher.common import tf_util as U
4 | from mher.common.tests.test_with_mpi import with_mpi
5 | from mher import logger
6 | try:
7 | from mpi4py import MPI
8 | except ImportError:
9 | MPI = None
10 |
11 | class MpiAdamOptimizer(tf.train.AdamOptimizer):
12 | """Adam optimizer that averages gradients across mpi processes."""
13 | def __init__(self, comm, grad_clip=None, mpi_rank_weight=1, **kwargs):
14 | self.comm = comm
15 | self.grad_clip = grad_clip
16 | self.mpi_rank_weight = mpi_rank_weight
17 | tf.train.AdamOptimizer.__init__(self, **kwargs)
18 | def compute_gradients(self, loss, var_list, **kwargs):
19 | grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs)
20 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
21 | flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0) * self.mpi_rank_weight
22 | shapes = [v.shape.as_list() for g, v in grads_and_vars]
23 | sizes = [int(np.prod(s)) for s in shapes]
24 |
25 | total_weight = np.zeros(1, np.float32)
26 | self.comm.Allreduce(np.array([self.mpi_rank_weight], dtype=np.float32), total_weight, op=MPI.SUM)
27 | total_weight = total_weight[0]
28 |
29 | buf = np.zeros(sum(sizes), np.float32)
30 | countholder = [0] # Counts how many times _collect_grads has been called
31 | stat = tf.reduce_sum(grads_and_vars[0][1]) # sum of first variable
32 | def _collect_grads(flat_grad, np_stat):
33 | if self.grad_clip is not None:
34 | gradnorm = np.linalg.norm(flat_grad)
35 |                 if gradnorm > 1:  # note: clips to unit L2 norm; grad_clip acts only as an on/off switch here
36 | flat_grad /= gradnorm
37 | logger.logkv_mean('gradnorm', gradnorm)
38 | logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
39 | self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
40 | np.divide(buf, float(total_weight), out=buf)
41 | if countholder[0] % 100 == 0:
42 | check_synced(np_stat, self.comm)
43 | countholder[0] += 1
44 | return buf
45 |
46 | avg_flat_grad = tf.py_func(_collect_grads, [flat_grad, stat], tf.float32)
47 | avg_flat_grad.set_shape(flat_grad.shape)
48 | avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
49 | avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
50 | for g, (_, v) in zip(avg_grads, grads_and_vars)]
51 | return avg_grads_and_vars
52 |
53 | def check_synced(localval, comm=None):
54 | """
55 | It's common to forget to initialize your variables to the same values, or
56 | (less commonly) if you update them in some other way than adam, to get them out of sync.
57 | This function checks that variables on all MPI workers are the same, and raises
58 | an AssertionError otherwise
59 |
60 | Arguments:
61 | comm: MPI communicator
62 | localval: list of local variables (list of variables on current worker to be compared with the other workers)
63 | """
64 | comm = comm or MPI.COMM_WORLD
65 | vals = comm.gather(localval)
66 | if comm.rank == 0:
67 | assert all(val==vals[0] for val in vals[1:]),\
68 | 'MpiAdamOptimizer detected that different workers have different weights: {}'.format(vals)
69 |
70 | @with_mpi(timeout=5)
71 | def test_nonfreeze():
72 | np.random.seed(0)
73 | tf.set_random_seed(0)
74 |
75 | a = tf.Variable(np.random.randn(3).astype('float32'))
76 | b = tf.Variable(np.random.randn(2,5).astype('float32'))
77 | loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b))
78 |
79 | stepsize = 1e-2
80 | # for some reason the session config with inter_op_parallelism_threads was causing
81 | # nested sess.run calls to freeze
82 | config = tf.ConfigProto(inter_op_parallelism_threads=1)
83 | sess = U.get_session(config=config)
84 | update_op = MpiAdamOptimizer(comm=MPI.COMM_WORLD, learning_rate=stepsize).minimize(loss)
85 | sess.run(tf.global_variables_initializer())
86 | losslist_ref = []
87 | for i in range(100):
88 | l,_ = sess.run([loss, update_op])
89 | print(i, l)
90 | losslist_ref.append(l)
91 |
--------------------------------------------------------------------------------
/mher/common/mpi_fork.py:
--------------------------------------------------------------------------------
1 | import os, subprocess, sys
2 |
3 | def mpi_fork(n, bind_to_core=False):
4 | """Re-launches the current script with workers
5 | Returns "parent" for original parent, "child" for MPI children
6 | """
7 | if n<=1:
8 | return "child"
9 | if os.getenv("IN_MPI") is None:
10 | env = os.environ.copy()
11 | env.update(
12 | MKL_NUM_THREADS="1",
13 | OMP_NUM_THREADS="1",
14 | IN_MPI="1"
15 | )
16 | args = ["mpirun", "-np", str(n)]
17 | if bind_to_core:
18 | args += ["-bind-to", "core"]
19 | args += [sys.executable] + sys.argv
20 | subprocess.check_call(args, env=env)
21 | return "parent"
22 | else:
23 | return "child"
24 |
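# Usage sketch (editorial illustration, not part of the file above). The usual
# pattern: the original process re-launches itself under mpirun and then exits.
import sys
from mher.common.mpi_fork import mpi_fork

if mpi_fork(4) == "parent":                # spawns `mpirun -np 4 python <this script>`
    sys.exit(0)
# from here on the script runs once in each of the 4 MPI workers
print("hello from an MPI worker")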
--------------------------------------------------------------------------------
/mher/common/mpi_moments.py:
--------------------------------------------------------------------------------
1 | from mpi4py import MPI
2 | import numpy as np
3 | from mher.common import zipsame
4 |
5 |
6 | def mpi_mean(x, axis=0, comm=None, keepdims=False):
7 | x = np.asarray(x)
8 | assert x.ndim > 0
9 | if comm is None: comm = MPI.COMM_WORLD
10 | xsum = x.sum(axis=axis, keepdims=keepdims)
11 | n = xsum.size
12 | localsum = np.zeros(n+1, x.dtype)
13 | localsum[:n] = xsum.ravel()
14 | localsum[n] = x.shape[axis]
15 | # globalsum = np.zeros_like(localsum)
16 | # comm.Allreduce(localsum, globalsum, op=MPI.SUM)
17 | globalsum = comm.allreduce(localsum, op=MPI.SUM)
18 | return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
19 |
20 | def mpi_moments(x, axis=0, comm=None, keepdims=False):
21 | x = np.asarray(x)
22 | assert x.ndim > 0
23 | mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True)
24 | sqdiffs = np.square(x - mean)
25 | meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True)
26 | assert count1 == count
27 | std = np.sqrt(meansqdiff)
28 | if not keepdims:
29 | newshape = mean.shape[:axis] + mean.shape[axis+1:]
30 | mean = mean.reshape(newshape)
31 | std = std.reshape(newshape)
32 | return mean, std, count
33 |
34 |
35 | def test_runningmeanstd():
36 | import subprocess
37 | subprocess.check_call(['mpirun', '-np', '3',
38 | 'python','-c',
39 | 'from mher.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()'])
40 |
41 | def _helper_runningmeanstd():
42 | comm = MPI.COMM_WORLD
43 | np.random.seed(0)
44 | for (triple,axis) in [
45 | ((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0),
46 | ((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0),
47 | ((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1),
48 | ]:
49 |
50 |
51 | x = np.concatenate(triple, axis=axis)
52 | ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]]
53 |
54 |
55 | ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis)
56 |
57 | for (a1,a2) in zipsame(ms1, ms2):
58 | print(a1, a2)
59 | assert np.allclose(a1, a2)
60 | print("ok!")
61 |
62 |
--------------------------------------------------------------------------------
/mher/common/mpi_running_mean_std.py:
--------------------------------------------------------------------------------
1 | try:
2 | from mpi4py import MPI
3 | except ImportError:
4 | MPI = None
5 |
6 | import tensorflow as tf, mher.common.tf_util as U, numpy as np
7 |
8 | class RunningMeanStd(object):
9 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
10 | def __init__(self, epsilon=1e-2, shape=()):
11 |
12 | self._sum = tf.get_variable(
13 | dtype=tf.float64,
14 | shape=shape,
15 | initializer=tf.constant_initializer(0.0),
16 | name="runningsum", trainable=False)
17 | self._sumsq = tf.get_variable(
18 | dtype=tf.float64,
19 | shape=shape,
20 | initializer=tf.constant_initializer(epsilon),
21 | name="runningsumsq", trainable=False)
22 | self._count = tf.get_variable(
23 | dtype=tf.float64,
24 | shape=(),
25 | initializer=tf.constant_initializer(epsilon),
26 | name="count", trainable=False)
27 | self.shape = shape
28 |
29 | self.mean = tf.to_float(self._sum / self._count)
30 | self.std = tf.sqrt( tf.maximum( tf.to_float(self._sumsq / self._count) - tf.square(self.mean) , 1e-2 ))
31 |
32 | newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
33 | newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var')
34 | newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
35 | self.incfiltparams = U.function([newsum, newsumsq, newcount], [],
36 | updates=[tf.assign_add(self._sum, newsum),
37 | tf.assign_add(self._sumsq, newsumsq),
38 | tf.assign_add(self._count, newcount)])
39 |
40 |
41 | def update(self, x):
42 | x = x.astype('float64')
43 | n = int(np.prod(self.shape))
44 | totalvec = np.zeros(n*2+1, 'float64')
45 | addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')])
46 |         if MPI is not None: MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
47 |         else: totalvec = addvec  # no MPI: fall back to the local sums so the update still applies
48 | self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n])
49 |
50 | @U.in_session
51 | def test_runningmeanstd():
52 | for (x1, x2, x3) in [
53 | (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
54 | (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),
55 | ]:
56 |
57 | rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
58 | U.initialize()
59 |
60 | x = np.concatenate([x1, x2, x3], axis=0)
61 | ms1 = [x.mean(axis=0), x.std(axis=0)]
62 | rms.update(x1)
63 | rms.update(x2)
64 | rms.update(x3)
65 | ms2 = [rms.mean.eval(), rms.std.eval()]
66 |
67 | assert np.allclose(ms1, ms2)
68 |
69 | @U.in_session
70 | def test_dist():
71 | np.random.seed(0)
72 | p1,p2,p3=(np.random.randn(3,1), np.random.randn(4,1), np.random.randn(5,1))
73 | q1,q2,q3=(np.random.randn(6,1), np.random.randn(7,1), np.random.randn(8,1))
74 |
75 | # p1,p2,p3=(np.random.randn(3), np.random.randn(4), np.random.randn(5))
76 | # q1,q2,q3=(np.random.randn(6), np.random.randn(7), np.random.randn(8))
77 |
78 | comm = MPI.COMM_WORLD
79 | assert comm.Get_size()==2
80 | if comm.Get_rank()==0:
81 | x1,x2,x3 = p1,p2,p3
82 | elif comm.Get_rank()==1:
83 | x1,x2,x3 = q1,q2,q3
84 | else:
85 | assert False
86 |
87 | rms = RunningMeanStd(epsilon=0.0, shape=(1,))
88 | U.initialize()
89 |
90 | rms.update(x1)
91 | rms.update(x2)
92 | rms.update(x3)
93 |
94 | bigvec = np.concatenate([p1,p2,p3,q1,q2,q3])
95 |
96 | def checkallclose(x,y):
97 | print(x,y)
98 | return np.allclose(x,y)
99 |
100 | assert checkallclose(
101 | bigvec.mean(axis=0),
102 | rms.mean.eval(),
103 | )
104 | assert checkallclose(
105 | bigvec.std(axis=0),
106 | rms.std.eval(),
107 | )
108 |
109 |
110 | if __name__ == "__main__":
111 | # Run with mpirun -np 2 python
112 | test_dist()
113 |
--------------------------------------------------------------------------------
/mher/common/mpi_util.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import os, numpy as np
3 | import platform
4 | import shutil
5 | import subprocess
6 | import warnings
7 | import sys
8 |
9 | try:
10 | from mpi4py import MPI
11 | except ImportError:
12 | MPI = None
13 |
14 |
15 | def sync_from_root(sess, variables, comm=None):
16 | """
17 | Send the root node's parameters to every worker.
18 | Arguments:
19 | sess: the TensorFlow session.
20 | variables: all parameter variables including optimizer's
21 | """
22 | if comm is None: comm = MPI.COMM_WORLD
23 | import tensorflow as tf
24 | values = comm.bcast(sess.run(variables))
25 | sess.run([tf.assign(var, val)
26 | for (var, val) in zip(variables, values)])
27 |
28 | def gpu_count():
29 | """
30 | Count the GPUs on this machine.
31 | """
32 | if shutil.which('nvidia-smi') is None:
33 | return 0
34 | output = subprocess.check_output(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
35 | return max(0, len(output.split(b'\n')) - 2)
36 |
37 | def setup_mpi_gpus():
38 | """
39 | Set CUDA_VISIBLE_DEVICES to MPI rank if not already set
40 | """
41 | if 'CUDA_VISIBLE_DEVICES' not in os.environ:
42 |         if sys.platform == 'darwin': # This assumes that if you're on OSX you're just
43 | ids = [] # doing a smoke test and don't want GPUs
44 | else:
45 | lrank, _lsize = get_local_rank_size(MPI.COMM_WORLD)
46 | ids = [lrank]
47 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, ids))
48 |
49 | def get_local_rank_size(comm):
50 | """
51 | Returns the rank of each process on its machine
52 | The processes on a given machine will be assigned ranks
53 | 0, 1, 2, ..., N-1,
54 | where N is the number of processes on this machine.
55 |
56 | Useful if you want to assign one gpu per machine
57 | """
58 | this_node = platform.node()
59 | ranks_nodes = comm.allgather((comm.Get_rank(), this_node))
60 | node2rankssofar = defaultdict(int)
61 | local_rank = None
62 | for (rank, node) in ranks_nodes:
63 | if rank == comm.Get_rank():
64 | local_rank = node2rankssofar[node]
65 | node2rankssofar[node] += 1
66 | assert local_rank is not None
67 | return local_rank, node2rankssofar[this_node]
68 |
69 | def share_file(comm, path):
70 | """
71 | Copies the file from rank 0 to all other ranks
72 | Puts it in the same place on all machines
73 | """
74 | localrank, _ = get_local_rank_size(comm)
75 | if comm.Get_rank() == 0:
76 | with open(path, 'rb') as fh:
77 | data = fh.read()
78 | comm.bcast(data)
79 | else:
80 | data = comm.bcast(None)
81 | if localrank == 0:
82 | os.makedirs(os.path.dirname(path), exist_ok=True)
83 | with open(path, 'wb') as fh:
84 | fh.write(data)
85 | comm.Barrier()
86 |
87 | def dict_gather(comm, d, op='mean', assert_all_have_data=True):
88 | """
89 | Perform a reduction operation over dicts
90 | """
91 | if comm is None: return d
92 | alldicts = comm.allgather(d)
93 | size = comm.size
94 | k2li = defaultdict(list)
95 | for d in alldicts:
96 | for (k,v) in d.items():
97 | k2li[k].append(v)
98 | result = {}
99 | for (k,li) in k2li.items():
100 | if assert_all_have_data:
101 | assert len(li)==size, "only %i out of %i MPI workers have sent '%s'" % (len(li), size, k)
102 | if op=='mean':
103 | result[k] = np.mean(li, axis=0)
104 | elif op=='sum':
105 | result[k] = np.sum(li, axis=0)
106 | else:
107 | assert 0, op
108 | return result
109 |
110 | def mpi_weighted_mean(comm, local_name2valcount):
111 | """
112 | Perform a weighted average over dicts that are each on a different node
113 | Input: local_name2valcount: dict mapping key -> (value, count)
114 | Returns: key -> mean
115 | """
116 | all_name2valcount = comm.gather(local_name2valcount)
117 | if comm.rank == 0:
118 | name2sum = defaultdict(float)
119 | name2count = defaultdict(float)
120 | for n2vc in all_name2valcount:
121 | for (name, (val, count)) in n2vc.items():
122 | try:
123 | val = float(val)
124 | except ValueError:
125 | if comm.rank == 0:
126 | warnings.warn('WARNING: tried to compute mean on non-float {}={}'.format(name, val))
127 | else:
128 | name2sum[name] += val * count
129 | name2count[name] += count
130 | return {name : name2sum[name] / name2count[name] for name in name2sum}
131 | else:
132 | return {}
133 |
134 |
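# Usage sketch (editorial illustration, not part of the file above). Assumes the
# module is importable as mher.common.mpi_util. dict_gather degrades gracefully
# when no communicator is given, so logging code can be shared with single-process runs.
from mher.common.mpi_util import dict_gather

stats = {'success_rate': 0.5, 'q_loss': 1.2}
print(dict_gather(None, stats, op='mean'))        # {'success_rate': 0.5, 'q_loss': 1.2}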
--------------------------------------------------------------------------------
/mher/common/plot/plot.py:
--------------------------------------------------------------------------------
1 | # DEPRECATED, use mher.common.plot_util instead
2 |
3 | import os
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import json
7 | import seaborn as sns; sns.set()
8 | import glob2
9 | import argparse
10 |
11 |
12 | def smooth_reward_curve(x, y):
13 | halfwidth = int(np.ceil(len(x) / 60)) # Halfwidth of our smoothing convolution
14 | k = halfwidth
15 | xsmoo = x
16 | ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='same') / np.convolve(np.ones_like(y), np.ones(2 * k + 1),
17 | mode='same')
18 | return xsmoo, ysmoo
19 |
20 |
21 | def load_results(file):
22 | if not os.path.exists(file):
23 | return None
24 | with open(file, 'r') as f:
25 | lines = [line for line in f]
26 | if len(lines) < 2:
27 | return None
28 | keys = [name.strip() for name in lines[0].split(',')]
29 | data = np.genfromtxt(file, delimiter=',', skip_header=1, filling_values=0.)
30 | if data.ndim == 1:
31 | data = data.reshape(1, -1)
32 | assert data.ndim == 2
33 | assert data.shape[-1] == len(keys)
34 | result = {}
35 | for idx, key in enumerate(keys):
36 | result[key] = data[:, idx]
37 | return result
38 |
39 |
40 | def pad(xs, value=np.nan):
41 | maxlen = np.max([len(x) for x in xs])
42 |
43 | padded_xs = []
44 | for x in xs:
45 | if x.shape[0] >= maxlen:
46 |             padded_xs.append(x)
47 |             continue
48 |         padding = np.ones((maxlen - x.shape[0],) + x.shape[1:]) * value
49 |         x_padded = np.concatenate([x, padding], axis=0)
50 |         assert x_padded.shape[1:] == x.shape[1:]
51 |         assert x_padded.shape[0] == maxlen
52 |         padded_xs.append(x_padded)
53 | return np.array(padded_xs)
54 |
55 |
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument('dir', type=str)
58 | parser.add_argument('--smooth', type=int, default=1)
59 | args = parser.parse_args()
60 |
61 | # Load all data.
62 | data = {}
63 | paths = [os.path.abspath(os.path.join(path, '..')) for path in glob2.glob(os.path.join(args.dir, '**', 'progress.csv'))]
64 | for curr_path in paths:
65 | if not os.path.isdir(curr_path):
66 | continue
67 | results = load_results(os.path.join(curr_path, 'progress.csv'))
68 | if not results:
69 | print('skipping {}'.format(curr_path))
70 | continue
71 | print('loading {} ({})'.format(curr_path, len(results['epoch'])))
72 | with open(os.path.join(curr_path, 'params.json'), 'r') as f:
73 | params = json.load(f)
74 |
75 | success_rate = np.array(results['test/success_rate'])
76 | epoch = np.array(results['epoch']) + 1
77 | env_id = params['env_name']
78 | replay_strategy = params['replay_strategy']
79 |
80 | if replay_strategy == 'future':
81 | config = 'her'
82 | else:
83 | config = 'ddpg'
84 | if 'Dense' in env_id:
85 | config += '-dense'
86 | else:
87 | config += '-sparse'
88 | env_id = env_id.replace('Dense', '')
89 |
90 | # Process and smooth data.
91 | assert success_rate.shape == epoch.shape
92 | x = epoch
93 | y = success_rate
94 | if args.smooth:
95 | x, y = smooth_reward_curve(epoch, success_rate)
96 | assert x.shape == y.shape
97 |
98 | if env_id not in data:
99 | data[env_id] = {}
100 | if config not in data[env_id]:
101 | data[env_id][config] = []
102 | data[env_id][config].append((x, y))
103 |
104 | # Plot data.
105 | for env_id in sorted(data.keys()):
106 | print('exporting {}'.format(env_id))
107 | plt.clf()
108 |
109 | for config in sorted(data[env_id].keys()):
110 | xs, ys = zip(*data[env_id][config])
111 | xs, ys = pad(xs), pad(ys)
112 | assert xs.shape == ys.shape
113 |
114 | plt.plot(xs[0], np.nanmedian(ys, axis=0), label=config)
115 | plt.fill_between(xs[0], np.nanpercentile(ys, 25, axis=0), np.nanpercentile(ys, 75, axis=0), alpha=0.25)
116 | plt.title(env_id)
117 | plt.xlabel('Epoch')
118 | plt.ylabel('Median Success Rate')
119 | plt.legend()
120 | plt.savefig(os.path.join(args.dir, 'fig_{}.png'.format(env_id)))
121 |
--------------------------------------------------------------------------------
/mher/common/plot/results_plotter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
4 |
5 | import matplotlib.pyplot as plt
6 | plt.rcParams['svg.fonttype'] = 'none'
7 |
8 | from mher.common import plot_util
9 |
10 | X_TIMESTEPS = 'timesteps'
11 | X_EPISODES = 'episodes'
12 | X_WALLTIME = 'walltime_hrs'
13 | Y_REWARD = 'reward'
14 | Y_TIMESTEPS = 'timesteps'
15 | POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
16 | EPISODES_WINDOW = 100
17 | COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
18 | 'brown', 'orange', 'teal', 'coral', 'lightblue', 'lime', 'lavender', 'turquoise',
19 | 'darkgreen', 'tan', 'salmon', 'gold', 'darkred', 'darkblue']
20 |
21 | def rolling_window(a, window):
22 | shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
23 | strides = a.strides + (a.strides[-1],)
24 | return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
25 |
26 | def window_func(x, y, window, func):
27 | yw = rolling_window(y, window)
28 | yw_func = func(yw, axis=-1)
29 | return x[window-1:], yw_func
30 |
31 | def ts2xy(ts, xaxis, yaxis):
32 | if xaxis == X_TIMESTEPS:
33 | x = np.cumsum(ts.l.values)
34 | elif xaxis == X_EPISODES:
35 | x = np.arange(len(ts))
36 | elif xaxis == X_WALLTIME:
37 | x = ts.t.values / 3600.
38 | else:
39 | raise NotImplementedError
40 | if yaxis == Y_REWARD:
41 | y = ts.r.values
42 | elif yaxis == Y_TIMESTEPS:
43 | y = ts.l.values
44 | else:
45 | raise NotImplementedError
46 | return x, y
47 |
48 | def plot_curves(xy_list, xaxis, yaxis, title):
49 | fig = plt.figure(figsize=(8,2))
50 | maxx = max(xy[0][-1] for xy in xy_list)
51 | minx = 0
52 | for (i, (x, y)) in enumerate(xy_list):
53 | color = COLORS[i % len(COLORS)]
54 | plt.scatter(x, y, s=2)
55 | x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean) #So returns average of last EPISODE_WINDOW episodes
56 | plt.plot(x, y_mean, color=color)
57 | plt.xlim(minx, maxx)
58 | plt.title(title)
59 | plt.xlabel(xaxis)
60 | plt.ylabel(yaxis)
61 | plt.tight_layout()
62 | fig.canvas.mpl_connect('resize_event', lambda event: plt.tight_layout())
63 | plt.grid(True)
64 |
65 |
66 | def split_by_task(taskpath):
67 | if type(taskpath) == dict:
68 | return taskpath['dirname'].split('/')[-1].split('-')[0]
69 | else:
70 | return taskpath.dirname.split('/')[-1].split('-')[0]
71 |
72 | def plot_results(dirs, num_timesteps=10e6, xaxis=X_TIMESTEPS, yaxis=Y_REWARD, title='', split_fn=split_by_task, average_group=False):
73 | results = plot_util.load_results(dirs)
74 |     plot_util.plot_results(results, xy_fn=lambda r: ts2xy(r.monitor, xaxis, yaxis), split_fn=split_fn, resample=0, average_group=average_group)
75 |
76 | # ['monitor'] resample=int(1e6)
77 |
78 | # Example usage in jupyter-notebook
79 | # from mher.results_plotter import plot_results
80 | # %matplotlib inline
81 | # plot_results("./log")
82 | # Here ./log is a directory containing the monitor.csv files
83 |
84 | def main():
85 | import argparse
86 | import os
87 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
88 | parser.add_argument('--dirs', help='List of log directories', nargs = '*', default=['./log'])
89 | parser.add_argument('--num_timesteps', type=int, default=int(10e6))
90 |     parser.add_argument('--xaxis', help = 'Variable on X-axis', default = X_TIMESTEPS)
91 |     parser.add_argument('--yaxis', help = 'Variable on Y-axis', default = Y_REWARD)
92 |     parser.add_argument('--task_name', help = 'Title of plot', default = 'Breakout')
93 |     parser.add_argument('--average_group', help = 'Average curves within each group', type=bool, default = False)
94 | args = parser.parse_args()
95 | args.dirs = [os.path.abspath(dir) for dir in args.dirs]
96 | plot_results(args.dirs, args.num_timesteps, args.xaxis, args.yaxis, args.task_name, average_group=args.average_group)
97 | plt.show()
98 |
99 | if __name__ == '__main__':
100 | main()
101 |
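# Usage sketch (editorial illustration, not part of the file above). Assumes the
# module is importable as mher.common.plot.results_plotter.
import numpy as np
from mher.common.plot.results_plotter import window_func

x = np.arange(5)
y = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
xs, y_mean = window_func(x, y, 3, np.mean)        # rolling mean over a window of 3
print(xs)                                         # [2 3 4]
print(y_mean)                                     # [2. 3. 4.]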
--------------------------------------------------------------------------------
/mher/common/runners.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from abc import ABC, abstractmethod
3 |
4 | class AbstractEnvRunner(ABC):
5 | def __init__(self, *, env, model, nsteps):
6 | self.env = env
7 | self.model = model
8 | self.nenv = nenv = env.num_envs if hasattr(env, 'num_envs') else 1
9 | self.batch_ob_shape = (nenv*nsteps,) + env.observation_space.shape
10 | self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=env.observation_space.dtype.name)
11 | self.obs[:] = env.reset()
12 | self.nsteps = nsteps
13 | self.states = model.initial_state
14 | self.dones = [False for _ in range(nenv)]
15 |
16 | @abstractmethod
17 | def run(self):
18 | raise NotImplementedError
19 |
20 |
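# Usage sketch (editorial illustration, not part of the file above): a minimal
# concrete runner. Assumes `env` is a vectorized env with a batched step() and
# `model` exposes an `initial_state` attribute; the names here are hypothetical.
import numpy as np
from mher.common.runners import AbstractEnvRunner

class RandomRunner(AbstractEnvRunner):
    def run(self):
        all_obs, all_rews = [], []
        for _ in range(self.nsteps):
            actions = [self.env.action_space.sample() for _ in range(self.nenv)]
            self.obs[:], rews, self.dones, _ = self.env.step(actions)
            all_obs.append(self.obs.copy())
            all_rews.append(rews)
        return np.asarray(all_obs), np.asarray(all_rews)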
--------------------------------------------------------------------------------
/mher/common/schedules.py:
--------------------------------------------------------------------------------
1 | """This file is used for specifying various schedules that evolve over
2 | time throughout the execution of the algorithm, such as:
3 | - learning rate for the optimizer
4 | - exploration epsilon for the epsilon greedy exploration strategy
5 |  - beta parameter used in prioritized experience replay
6 |
7 | Each schedule has a function `value(t)` which returns the current value
8 | of the parameter given the timestep t of the optimization procedure.
9 | """
10 |
11 |
12 | class Schedule(object):
13 | def value(self, t):
14 | """Value of the schedule at time t"""
15 | raise NotImplementedError()
16 |
17 |
18 | class ConstantSchedule(object):
19 | def __init__(self, value):
20 | """Value remains constant over time.
21 |
22 | Parameters
23 | ----------
24 | value: float
25 | Constant value of the schedule
26 | """
27 | self._v = value
28 |
29 | def value(self, t):
30 | """See Schedule.value"""
31 | return self._v
32 |
33 |
34 | def linear_interpolation(l, r, alpha):
35 | return l + alpha * (r - l)
36 |
37 |
38 | class PiecewiseSchedule(object):
39 | def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None):
40 | """Piecewise schedule.
41 |
42 | endpoints: [(int, int)]
43 |             list of pairs `(time, value)` meaning that the schedule should output
44 | `value` when `t==time`. All the values for time must be sorted in
45 | an increasing order. When t is between two times, e.g. `(time_a, value_a)`
46 | and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs
47 | `interpolation(value_a, value_b, alpha)` where alpha is a fraction of
48 | time passed between `time_a` and `time_b` for time `t`.
49 | interpolation: lambda float, float, float: float
50 | a function that takes value to the left and to the right of t according
51 | to the `endpoints`. Alpha is the fraction of distance from left endpoint to
52 | right endpoint that t has covered. See linear_interpolation for example.
53 | outside_value: float
54 |             if the value is requested outside of all the intervals specified in
55 | `endpoints` this value is returned. If None then AssertionError is
56 | raised when outside value is requested.
57 | """
58 | idxes = [e[0] for e in endpoints]
59 | assert idxes == sorted(idxes)
60 | self._interpolation = interpolation
61 | self._outside_value = outside_value
62 | self._endpoints = endpoints
63 |
64 | def value(self, t):
65 | """See Schedule.value"""
66 | for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]):
67 | if l_t <= t and t < r_t:
68 | alpha = float(t - l_t) / (r_t - l_t)
69 | return self._interpolation(l, r, alpha)
70 |
71 | # t does not belong to any of the pieces, so doom.
72 | assert self._outside_value is not None
73 | return self._outside_value
74 |
75 |
76 | class LinearSchedule(object):
77 | def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
78 | """Linear interpolation between initial_p and final_p over
79 | schedule_timesteps. After this many timesteps pass final_p is
80 | returned.
81 |
82 | Parameters
83 | ----------
84 | schedule_timesteps: int
85 | Number of timesteps for which to linearly anneal initial_p
86 | to final_p
87 | initial_p: float
88 | initial output value
89 | final_p: float
90 | final output value
91 | """
92 | self.schedule_timesteps = schedule_timesteps
93 | self.final_p = final_p
94 | self.initial_p = initial_p
95 |
96 | def value(self, t):
97 | """See Schedule.value"""
98 | fraction = min(float(t) / self.schedule_timesteps, 1.0)
99 | return self.initial_p + fraction * (self.final_p - self.initial_p)
100 |
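# Usage sketch (editorial illustration, not part of the file above). Assumes the
# module is importable as mher.common.schedules.
from mher.common.schedules import LinearSchedule, PiecewiseSchedule

lin = LinearSchedule(schedule_timesteps=100, final_p=0.1, initial_p=1.0)
print(lin.value(0), lin.value(50), lin.value(200))    # 1.0 0.55 0.1

pw = PiecewiseSchedule([(0, 1.0), (10, 0.5), (20, 0.5)], outside_value=0.5)
print(pw.value(5), pw.value(15), pw.value(100))       # 0.75 0.5 0.5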
--------------------------------------------------------------------------------
/mher/common/test_mpi_util.py:
--------------------------------------------------------------------------------
1 | from mher.common import mpi_util
2 | from mher import logger
3 | from mher.common.tests.test_with_mpi import with_mpi
4 | try:
5 | from mpi4py import MPI
6 | except ImportError:
7 | MPI = None
8 |
9 | @with_mpi()
10 | def test_mpi_weighted_mean():
11 | comm = MPI.COMM_WORLD
12 | with logger.scoped_configure(comm=comm):
13 | if comm.rank == 0:
14 | name2valcount = {'a' : (10, 2), 'b' : (20,3)}
15 | elif comm.rank == 1:
16 | name2valcount = {'a' : (19, 1), 'c' : (42,3)}
17 | else:
18 | raise NotImplementedError
19 | d = mpi_util.mpi_weighted_mean(comm, name2valcount)
20 | correctval = {'a' : (10 * 2 + 19) / 3.0, 'b' : 20, 'c' : 42}
21 | if comm.rank == 0:
22 | assert d == correctval, '{} != {}'.format(d, correctval)
23 |
24 | for name, (val, count) in name2valcount.items():
25 | for _ in range(count):
26 | logger.logkv_mean(name, val)
27 | d2 = logger.dumpkvs()
28 | if comm.rank == 0:
29 | assert d2 == correctval
30 |
--------------------------------------------------------------------------------
/mher/common/tests/__init__.py:
--------------------------------------------------------------------------------
1 | import os, pytest
2 | mark_slow = pytest.mark.skipif(not os.getenv('RUNSLOW'), reason='slow')
--------------------------------------------------------------------------------
/mher/common/tests/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/tests/envs/__init__.py
--------------------------------------------------------------------------------
/mher/common/tests/envs/fixed_sequence_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import Env
3 | from gym.spaces import Discrete
4 |
5 |
6 | class FixedSequenceEnv(Env):
7 | def __init__(
8 | self,
9 | n_actions=10,
10 | episode_len=100
11 | ):
12 | self.action_space = Discrete(n_actions)
13 | self.observation_space = Discrete(1)
14 | self.np_random = np.random.RandomState(0)
15 | self.episode_len = episode_len
16 | self.sequence = [self.np_random.randint(0, self.action_space.n)
17 | for _ in range(self.episode_len)]
18 | self.time = 0
19 |
20 |
21 | def reset(self):
22 | self.time = 0
23 | return 0
24 |
25 | def step(self, actions):
26 | rew = self._get_reward(actions)
27 | self._choose_next_state()
28 | done = False
29 | if self.episode_len and self.time >= self.episode_len:
30 | done = True
31 |
32 | return 0, rew, done, {}
33 |
34 | def seed(self, seed=None):
35 | self.np_random.seed(seed)
36 |
37 | def _choose_next_state(self):
38 | self.time += 1
39 |
40 | def _get_reward(self, actions):
41 | return 1 if actions == self.sequence[self.time] else 0
42 |
43 |
44 |
--------------------------------------------------------------------------------
/mher/common/tests/envs/identity_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from abc import abstractmethod
3 | from gym import Env
4 | from gym.spaces import MultiDiscrete, Discrete, Box
5 | from collections import deque
6 |
7 | class IdentityEnv(Env):
8 | def __init__(
9 | self,
10 | episode_len=None,
11 | delay=0,
12 | zero_first_rewards=True
13 | ):
14 |
15 | self.observation_space = self.action_space
16 | self.episode_len = episode_len
17 | self.time = 0
18 | self.delay = delay
19 | self.zero_first_rewards = zero_first_rewards
20 | self.q = deque(maxlen=delay+1)
21 |
22 | def reset(self):
23 | self.q.clear()
24 | for _ in range(self.delay + 1):
25 | self.q.append(self.action_space.sample())
26 | self.time = 0
27 |
28 | return self.q[-1]
29 |
30 | def step(self, actions):
31 | rew = self._get_reward(self.q.popleft(), actions)
32 | if self.zero_first_rewards and self.time < self.delay:
33 | rew = 0
34 | self.q.append(self.action_space.sample())
35 | self.time += 1
36 | done = self.episode_len is not None and self.time >= self.episode_len
37 | return self.q[-1], rew, done, {}
38 |
39 | def seed(self, seed=None):
40 | self.action_space.seed(seed)
41 |
42 | @abstractmethod
43 | def _get_reward(self, state, actions):
44 | raise NotImplementedError
45 |
46 |
47 | class DiscreteIdentityEnv(IdentityEnv):
48 | def __init__(
49 | self,
50 | dim,
51 | episode_len=None,
52 | delay=0,
53 | zero_first_rewards=True
54 | ):
55 |
56 | self.action_space = Discrete(dim)
57 | super().__init__(episode_len=episode_len, delay=delay, zero_first_rewards=zero_first_rewards)
58 |
59 | def _get_reward(self, state, actions):
60 | return 1 if state == actions else 0
61 |
62 | class MultiDiscreteIdentityEnv(IdentityEnv):
63 | def __init__(
64 | self,
65 | dims,
66 | episode_len=None,
67 | delay=0,
68 | ):
69 |
70 | self.action_space = MultiDiscrete(dims)
71 | super().__init__(episode_len=episode_len, delay=delay)
72 |
73 | def _get_reward(self, state, actions):
74 | return 1 if all(state == actions) else 0
75 |
76 |
77 | class BoxIdentityEnv(IdentityEnv):
78 | def __init__(
79 | self,
80 | shape,
81 | episode_len=None,
82 | ):
83 |
84 | self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32)
85 | super().__init__(episode_len=episode_len)
86 |
87 | def _get_reward(self, state, actions):
88 | diff = actions - state
89 | diff = diff[:]
90 | return -0.5 * np.dot(diff, diff)
91 |
--------------------------------------------------------------------------------
/mher/common/tests/envs/identity_env_test.py:
--------------------------------------------------------------------------------
1 | from mher.common.tests.envs.identity_env import DiscreteIdentityEnv
2 |
3 |
4 | def test_discrete_nodelay():
5 | nsteps = 100
6 | eplen = 50
7 | env = DiscreteIdentityEnv(10, episode_len=eplen)
8 | ob = env.reset()
9 | for t in range(nsteps):
10 | action = env.action_space.sample()
11 | next_ob, rew, done, info = env.step(action)
12 | assert rew == (1 if action == ob else 0)
13 | if (t + 1) % eplen == 0:
14 | assert done
15 | next_ob = env.reset()
16 | else:
17 | assert not done
18 | ob = next_ob
19 |
20 | def test_discrete_delay1():
21 | eplen = 50
22 | env = DiscreteIdentityEnv(10, episode_len=eplen, delay=1)
23 | ob = env.reset()
24 | prev_ob = None
25 | for t in range(eplen):
26 | action = env.action_space.sample()
27 | next_ob, rew, done, info = env.step(action)
28 | if t > 0:
29 | assert rew == (1 if action == prev_ob else 0)
30 | else:
31 | assert rew == 0
32 | prev_ob = ob
33 | ob = next_ob
34 | if t < eplen - 1:
35 | assert not done
36 | assert done
37 |
--------------------------------------------------------------------------------
/mher/common/tests/envs/mnist_env.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import numpy as np
3 | import tempfile
4 | from gym import Env
5 | from gym.spaces import Discrete, Box
6 |
7 |
8 |
9 | class MnistEnv(Env):
10 | def __init__(
11 | self,
12 | episode_len=None,
13 | no_images=None
14 | ):
15 | import filelock
16 | from tensorflow.examples.tutorials.mnist import input_data
17 | # we could use temporary directory for this with a context manager and
18 |         # TemporaryDirectory, but then each test that uses mnist would re-download the data.
19 |         # This way the data is not cleaned up, but we only download it once per machine.
20 | mnist_path = osp.join(tempfile.gettempdir(), 'MNIST_data')
21 | with filelock.FileLock(mnist_path + '.lock'):
22 | self.mnist = input_data.read_data_sets(mnist_path)
23 |
24 | self.np_random = np.random.RandomState()
25 |
26 | self.observation_space = Box(low=0.0, high=1.0, shape=(28,28,1))
27 | self.action_space = Discrete(10)
28 | self.episode_len = episode_len
29 | self.time = 0
30 | self.no_images = no_images
31 |
32 | self.train_mode()
33 | self.reset()
34 |
35 | def reset(self):
36 | self._choose_next_state()
37 | self.time = 0
38 |
39 | return self.state[0]
40 |
41 | def step(self, actions):
42 | rew = self._get_reward(actions)
43 | self._choose_next_state()
44 | done = False
45 | if self.episode_len and self.time >= self.episode_len:
46 | rew = 0
47 | done = True
48 |
49 | return self.state[0], rew, done, {}
50 |
51 | def seed(self, seed=None):
52 | self.np_random.seed(seed)
53 |
54 | def train_mode(self):
55 | self.dataset = self.mnist.train
56 |
57 | def test_mode(self):
58 | self.dataset = self.mnist.test
59 |
60 | def _choose_next_state(self):
61 | max_index = (self.no_images if self.no_images is not None else self.dataset.num_examples) - 1
62 | index = self.np_random.randint(0, max_index)
63 | image = self.dataset.images[index].reshape(28,28,1)*255
64 | label = self.dataset.labels[index]
65 | self.state = (image, label)
66 | self.time += 1
67 |
68 | def _get_reward(self, actions):
69 | return 1 if self.state[1] == actions else 0
70 |
71 |
72 |
--------------------------------------------------------------------------------
/mher/common/tests/test_cartpole.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import gym
3 |
4 | from mher.run import get_learn_function
5 | from mher.common.tests.util import reward_per_episode_test
6 | from mher.common.tests import mark_slow
7 |
8 | common_kwargs = dict(
9 | total_timesteps=30000,
10 | network='mlp',
11 | gamma=1.0,
12 | seed=0,
13 | )
14 |
15 | learn_kwargs = {
16 | 'a2c' : dict(nsteps=32, value_network='copy', lr=0.05),
17 | 'acer': dict(value_network='copy'),
18 | 'acktr': dict(nsteps=32, value_network='copy', is_async=False),
19 | 'deepq': dict(total_timesteps=20000),
20 | 'ppo2': dict(value_network='copy'),
21 | 'trpo_mpi': {}
22 | }
23 |
24 | @mark_slow
25 | @pytest.mark.parametrize("alg", learn_kwargs.keys())
26 | def test_cartpole(alg):
27 | '''
28 | Test if the algorithm (with an mlp policy)
29 | can learn to balance the cartpole
30 | '''
31 |
32 | kwargs = common_kwargs.copy()
33 | kwargs.update(learn_kwargs[alg])
34 |
35 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
36 | def env_fn():
37 |
38 | env = gym.make('CartPole-v0')
39 | env.seed(0)
40 | return env
41 |
42 | reward_per_episode_test(env_fn, learn_fn, 100)
43 |
44 | if __name__ == '__main__':
45 | test_cartpole('acer')
46 |
--------------------------------------------------------------------------------
/mher/common/tests/test_doc_examples.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | try:
3 | import mujoco_py
4 | _mujoco_present = True
5 | except BaseException:
6 | mujoco_py = None
7 | _mujoco_present = False
8 |
9 |
10 | @pytest.mark.skipif(
11 | not _mujoco_present,
12 | reason='error loading mujoco - either mujoco / mujoco key not present, or LD_LIBRARY_PATH is not pointing to mujoco library'
13 | )
14 | def test_lstm_example():
15 | import tensorflow as tf
16 | from mher.common import policies, models, cmd_util
17 | from mher.common.vec_env.dummy_vec_env import DummyVecEnv
18 |
19 | # create vectorized environment
20 | venv = DummyVecEnv([lambda: cmd_util.make_mujoco_env('Reacher-v2', seed=0)])
21 |
22 | with tf.Session() as sess:
23 | # build policy based on lstm network with 128 units
24 | policy = policies.build_policy(venv, models.lstm(128))(nbatch=1, nsteps=1)
25 |
26 | # initialize tensorflow variables
27 | sess.run(tf.global_variables_initializer())
28 |
29 | # prepare environment variables
30 | ob = venv.reset()
31 | state = policy.initial_state
32 | done = [False]
33 | step_counter = 0
34 |
35 | # run a single episode until the end (i.e. until done)
36 | while True:
37 | action, _, state, _ = policy.step(ob, S=state, M=done)
38 | ob, reward, done, _ = venv.step(action)
39 | step_counter += 1
40 | if done:
41 | break
42 |
43 |
44 | assert step_counter > 5
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/mher/common/tests/test_env_after_learn.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import gym
3 | import tensorflow as tf
4 |
5 | from mher.common.vec_env.subproc_vec_env import SubprocVecEnv
6 | from mher.run import get_learn_function
7 | from mher.common.tf_util import make_session
8 |
9 | algos = ['a2c', 'acer', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
10 |
11 | @pytest.mark.parametrize('algo', algos)
12 | def test_env_after_learn(algo):
13 | def make_env():
14 | # acktr requires too much RAM, fails on travis
15 | env = gym.make('CartPole-v1' if algo == 'acktr' else 'PongNoFrameskip-v4')
16 | return env
17 |
18 | make_session(make_default=True, graph=tf.Graph())
19 | env = SubprocVecEnv([make_env])
20 |
21 | learn = get_learn_function(algo)
22 |
23 | # Commenting out the following line resolves the issue, though crash happens at env.reset().
24 | learn(network='mlp', env=env, total_timesteps=0, load_path=None, seed=None)
25 |
26 | env.reset()
27 | env.close()
28 |
--------------------------------------------------------------------------------
/mher/common/tests/test_fetchreach.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import gym
3 |
4 | from mher.run import get_learn_function
5 | from mher.common.tests.util import reward_per_episode_test
6 | from mher.common.tests import mark_slow
7 |
8 | pytest.importorskip('mujoco_py')
9 |
10 | common_kwargs = dict(
11 | network='mlp',
12 | seed=0,
13 | )
14 |
15 | learn_kwargs = {
16 | 'her': dict(total_timesteps=2000)
17 | }
18 |
19 | @mark_slow
20 | @pytest.mark.parametrize("alg", learn_kwargs.keys())
21 | def test_fetchreach(alg):
22 | '''
23 | Test if the algorithm (with an mlp policy)
24 | can learn the FetchReach task
25 | '''
26 |
27 | kwargs = common_kwargs.copy()
28 | kwargs.update(learn_kwargs[alg])
29 |
30 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
31 | def env_fn():
32 |
33 | env = gym.make('FetchReach-v1')
34 | env.seed(0)
35 | return env
36 |
37 | reward_per_episode_test(env_fn, learn_fn, -15)
38 |
39 | if __name__ == '__main__':
40 | test_fetchreach('her')
41 |
--------------------------------------------------------------------------------
/mher/common/tests/test_fixed_sequence.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from mher.common.tests.envs.fixed_sequence_env import FixedSequenceEnv
3 |
4 | from mher.common.tests.util import simple_test
5 | from mher.run import get_learn_function
6 | from mher.common.tests import mark_slow
7 |
8 |
9 | common_kwargs = dict(
10 | seed=0,
11 | total_timesteps=50000,
12 | )
13 |
14 | learn_kwargs = {
15 | 'a2c': {},
16 | 'ppo2': dict(nsteps=10, ent_coef=0.0, nminibatches=1),
17 | # TODO enable sequential models for trpo_mpi (proper handling of nbatch and nsteps)
18 | # github issue: https://github.com/openai/baselines/issues/188
19 | # 'trpo_mpi': lambda e, p: trpo_mpi.learn(policy_fn=p(env=e), env=e, max_timesteps=30000, timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.001)
20 | }
21 |
22 |
23 | alg_list = learn_kwargs.keys()
24 | rnn_list = ['lstm']
25 |
26 | @mark_slow
27 | @pytest.mark.parametrize("alg", alg_list)
28 | @pytest.mark.parametrize("rnn", rnn_list)
29 | def test_fixed_sequence(alg, rnn):
30 | '''
31 | Test if the algorithm (with a given policy)
32 |     can memorize a fixed sequence of actions (the observation is constant)
33 | '''
34 |
35 | kwargs = learn_kwargs[alg]
36 | kwargs.update(common_kwargs)
37 |
38 | env_fn = lambda: FixedSequenceEnv(n_actions=10, episode_len=5)
39 | learn = lambda e: get_learn_function(alg)(
40 | env=e,
41 | network=rnn,
42 | **kwargs
43 | )
44 |
45 | simple_test(env_fn, learn, 0.7)
46 |
47 |
48 | if __name__ == '__main__':
49 | test_fixed_sequence('ppo2', 'lstm')
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
/mher/common/tests/test_identity.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from mher.common.tests.envs.identity_env import DiscreteIdentityEnv, BoxIdentityEnv, MultiDiscreteIdentityEnv
3 | from mher.run import get_learn_function
4 | from mher.common.tests.util import simple_test
5 | from mher.common.tests import mark_slow
6 |
7 | common_kwargs = dict(
8 | total_timesteps=30000,
9 | network='mlp',
10 | gamma=0.9,
11 | seed=0,
12 | )
13 |
14 | learn_kwargs = {
15 | 'a2c' : {},
16 | 'acktr': {},
17 | 'deepq': {},
18 | 'ddpg': dict(layer_norm=True),
19 | 'ppo2': dict(lr=1e-3, nsteps=64, ent_coef=0.0),
20 | 'trpo_mpi': dict(timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.01)
21 | }
22 |
23 |
24 | algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
25 | algos_multidisc = ['a2c', 'acktr', 'ppo2', 'trpo_mpi']
26 | algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi']
27 |
28 | @mark_slow
29 | @pytest.mark.parametrize("alg", algos_disc)
30 | def test_discrete_identity(alg):
31 | '''
32 | Test if the algorithm (with an mlp policy)
33 | can learn an identity transformation (i.e. return observation as an action)
34 | '''
35 |
36 | kwargs = learn_kwargs[alg]
37 | kwargs.update(common_kwargs)
38 |
39 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
40 | env_fn = lambda: DiscreteIdentityEnv(10, episode_len=100)
41 | simple_test(env_fn, learn_fn, 0.9)
42 |
43 | @mark_slow
44 | @pytest.mark.parametrize("alg", algos_multidisc)
45 | def test_multidiscrete_identity(alg):
46 | '''
47 | Test if the algorithm (with an mlp policy)
48 | can learn an identity transformation (i.e. return observation as an action)
49 | '''
50 |
51 | kwargs = learn_kwargs[alg]
52 | kwargs.update(common_kwargs)
53 |
54 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
55 | env_fn = lambda: MultiDiscreteIdentityEnv((3,3), episode_len=100)
56 | simple_test(env_fn, learn_fn, 0.9)
57 |
58 | @mark_slow
59 | @pytest.mark.parametrize("alg", algos_cont)
60 | def test_continuous_identity(alg):
61 | '''
62 | Test if the algorithm (with an mlp policy)
63 | can learn an identity transformation (i.e. return observation as an action)
64 | to a required precision
65 | '''
66 |
67 | kwargs = learn_kwargs[alg]
68 | kwargs.update(common_kwargs)
69 | learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
70 |
71 | env_fn = lambda: BoxIdentityEnv((1,), episode_len=100)
72 | simple_test(env_fn, learn_fn, -0.1)
73 |
74 | if __name__ == '__main__':
75 | test_multidiscrete_identity('acktr')
76 |
77 |
--------------------------------------------------------------------------------
/mher/common/tests/test_mnist.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | # from mher.acer import acer_simple as acer
4 | from mher.common.tests.envs.mnist_env import MnistEnv
5 | from mher.common.tests.util import simple_test
6 | from mher.run import get_learn_function
7 | from mher.common.tests import mark_slow
8 |
9 | # TODO investigate a2c and ppo2 failures - is it due to bad hyperparameters for this problem?
10 | # GitHub issue https://github.com/openai/baselines/issues/189
11 | common_kwargs = {
12 | 'seed': 0,
13 | 'network':'cnn',
14 | 'gamma':0.9,
15 | 'pad':'SAME'
16 | }
17 |
18 | learn_args = {
19 | 'a2c': dict(total_timesteps=50000),
20 | 'acer': dict(total_timesteps=20000),
21 | 'deepq': dict(total_timesteps=5000),
22 | 'acktr': dict(total_timesteps=30000),
23 | 'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0),
24 | 'trpo_mpi': dict(total_timesteps=80000, timesteps_per_batch=100, cg_iters=10, lam=1.0, max_kl=0.001)
25 | }
26 |
27 |
28 | # tests pass, but are too slow on Travis. The same algorithms are covered
29 | # by other tests with less compute-hungry networks and by benchmarks
30 | @pytest.mark.skip
31 | @mark_slow
32 | @pytest.mark.parametrize("alg", learn_args.keys())
33 | def test_mnist(alg):
34 | '''
35 | Test if the algorithm can learn to classify MNIST digits.
36 | Uses CNN policy.
37 | '''
38 |
39 | learn_kwargs = learn_args[alg]
40 | learn_kwargs.update(common_kwargs)
41 |
42 | learn = get_learn_function(alg)
43 | learn_fn = lambda e: learn(env=e, **learn_kwargs)
44 | env_fn = lambda: MnistEnv(episode_len=100)
45 |
46 | simple_test(env_fn, learn_fn, 0.6)
47 |
48 | if __name__ == '__main__':
49 | test_mnist('acer')
50 |
--------------------------------------------------------------------------------
/mher/common/tests/test_plot_util.py:
--------------------------------------------------------------------------------
1 | # smoke tests of plot_util
2 | from mher.common import plot_util as pu
3 | from mher.common.tests.util import smoketest
4 |
5 |
6 | def test_plot_util():
7 | nruns = 4
8 | logdirs = [smoketest('--alg=ppo2 --env=CartPole-v0 --num_timesteps=10000') for _ in range(nruns)]
9 | data = pu.load_results(logdirs)
10 | assert len(data) == 4
11 |
12 | _, axes = pu.plot_results(data[:1]); assert len(axes) == 1
13 | _, axes = pu.plot_results(data, tiling='vertical'); assert axes.shape==(4,1)
14 | _, axes = pu.plot_results(data, tiling='horizontal'); assert axes.shape==(1,4)
15 | _, axes = pu.plot_results(data, tiling='symmetric'); assert axes.shape==(2,2)
16 | _, axes = pu.plot_results(data, split_fn=lambda _: ''); assert len(axes) == 1
17 |
18 |
--------------------------------------------------------------------------------
/mher/common/tests/test_schedules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mher.common.schedules import ConstantSchedule, PiecewiseSchedule
4 |
5 |
6 | def test_piecewise_schedule():
7 | ps = PiecewiseSchedule([(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)], outside_value=500)
8 |
9 | assert np.isclose(ps.value(-10), 500)
10 | assert np.isclose(ps.value(0), 150)
11 | assert np.isclose(ps.value(5), 200)
12 | assert np.isclose(ps.value(9), 80)
13 | assert np.isclose(ps.value(50), 50)
14 | assert np.isclose(ps.value(80), 50)
15 | assert np.isclose(ps.value(150), 0)
16 | assert np.isclose(ps.value(175), -25)
17 | assert np.isclose(ps.value(201), 500)
18 | assert np.isclose(ps.value(500), 500)
19 |
20 | assert np.isclose(ps.value(200 - 1e-10), -50)
21 |
22 |
23 | def test_constant_schedule():
24 | cs = ConstantSchedule(5)
25 | for i in range(-100, 100):
26 | assert np.isclose(cs.value(i), 5)
27 |
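
The expected values in test_piecewise_schedule follow from plain linear interpolation between neighbouring endpoints. A minimal sketch (not repository code; `piecewise` is an illustrative stand-in for the schedule's default interpolation) reproducing two of the assertions:

    def piecewise(t, endpoints, outside_value):
        # linearly interpolate between the two endpoints that bracket t;
        # outside the covered range, fall back to outside_value
        for (l_t, l_v), (r_t, r_v) in zip(endpoints[:-1], endpoints[1:]):
            if l_t <= t < r_t:
                alpha = (t - l_t) / (r_t - l_t)
                return l_v + alpha * (r_v - l_v)
        return outside_value

    endpoints = [(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)]
    print(piecewise(0, endpoints, 500))    # 150.0: halfway between 100 and 200
    print(piecewise(175, endpoints, 500))  # -25.0: 3/4 of the way from 50 down to -50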
--------------------------------------------------------------------------------
/mher/common/tests/test_segment_tree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from mher.common.segment_tree import SumSegmentTree, MinSegmentTree
4 |
5 |
6 | def test_tree_set():
7 | tree = SumSegmentTree(4)
8 |
9 | tree[2] = 1.0
10 | tree[3] = 3.0
11 |
12 | assert np.isclose(tree.sum(), 4.0)
13 | assert np.isclose(tree.sum(0, 2), 0.0)
14 | assert np.isclose(tree.sum(0, 3), 1.0)
15 | assert np.isclose(tree.sum(2, 3), 1.0)
16 | assert np.isclose(tree.sum(2, -1), 1.0)
17 | assert np.isclose(tree.sum(2, 4), 4.0)
18 |
19 |
20 | def test_tree_set_overlap():
21 | tree = SumSegmentTree(4)
22 |
23 | tree[2] = 1.0
24 | tree[2] = 3.0
25 |
26 | assert np.isclose(tree.sum(), 3.0)
27 | assert np.isclose(tree.sum(2, 3), 3.0)
28 | assert np.isclose(tree.sum(2, -1), 3.0)
29 | assert np.isclose(tree.sum(2, 4), 3.0)
30 | assert np.isclose(tree.sum(1, 2), 0.0)
31 |
32 |
33 | def test_prefixsum_idx():
34 | tree = SumSegmentTree(4)
35 |
36 | tree[2] = 1.0
37 | tree[3] = 3.0
38 |
39 | assert tree.find_prefixsum_idx(0.0) == 2
40 | assert tree.find_prefixsum_idx(0.5) == 2
41 | assert tree.find_prefixsum_idx(0.99) == 2
42 | assert tree.find_prefixsum_idx(1.01) == 3
43 | assert tree.find_prefixsum_idx(3.00) == 3
44 | assert tree.find_prefixsum_idx(4.00) == 3
45 |
46 |
47 | def test_prefixsum_idx2():
48 | tree = SumSegmentTree(4)
49 |
50 | tree[0] = 0.5
51 | tree[1] = 1.0
52 | tree[2] = 1.0
53 | tree[3] = 3.0
54 |
55 | assert tree.find_prefixsum_idx(0.00) == 0
56 | assert tree.find_prefixsum_idx(0.55) == 1
57 | assert tree.find_prefixsum_idx(0.99) == 1
58 | assert tree.find_prefixsum_idx(1.51) == 2
59 | assert tree.find_prefixsum_idx(3.00) == 3
60 | assert tree.find_prefixsum_idx(5.50) == 3
61 |
62 |
63 | def test_max_interval_tree():
64 | tree = MinSegmentTree(4)
65 |
66 | tree[0] = 1.0
67 | tree[2] = 0.5
68 | tree[3] = 3.0
69 |
70 | assert np.isclose(tree.min(), 0.5)
71 | assert np.isclose(tree.min(0, 2), 1.0)
72 | assert np.isclose(tree.min(0, 3), 0.5)
73 | assert np.isclose(tree.min(0, -1), 0.5)
74 | assert np.isclose(tree.min(2, 4), 0.5)
75 | assert np.isclose(tree.min(3, 4), 3.0)
76 |
77 | tree[2] = 0.7
78 |
79 | assert np.isclose(tree.min(), 0.7)
80 | assert np.isclose(tree.min(0, 2), 1.0)
81 | assert np.isclose(tree.min(0, 3), 0.7)
82 | assert np.isclose(tree.min(0, -1), 0.7)
83 | assert np.isclose(tree.min(2, 4), 0.7)
84 | assert np.isclose(tree.min(3, 4), 3.0)
85 |
86 | tree[2] = 4.0
87 |
88 | assert np.isclose(tree.min(), 1.0)
89 | assert np.isclose(tree.min(0, 2), 1.0)
90 | assert np.isclose(tree.min(0, 3), 1.0)
91 | assert np.isclose(tree.min(0, -1), 1.0)
92 | assert np.isclose(tree.min(2, 4), 3.0)
93 | assert np.isclose(tree.min(2, 3), 4.0)
94 | assert np.isclose(tree.min(2, -1), 4.0)
95 | assert np.isclose(tree.min(3, 4), 3.0)
96 |
97 |
98 | if __name__ == '__main__':
99 | test_tree_set()
100 | test_tree_set_overlap()
101 | test_prefixsum_idx()
102 | test_prefixsum_idx2()
103 | test_max_interval_tree()
104 |
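
For reference, the prefix-sum queries asserted above behave like the following linear scan (a sketch, not repository code; the real SumSegmentTree answers the same query in O(log n), and values are assumed non-negative, as with replay priorities): find_prefixsum_idx(p) returns the highest index whose preceding elements sum to at most p.

    def find_prefixsum_idx_reference(values, prefixsum):
        # highest index idx such that sum(values[:idx]) <= prefixsum,
        # capped at the last index
        running = 0.0
        idx = 0
        for i, v in enumerate(values):
            if running + v > prefixsum:
                return i
            running += v
            idx = i
        return idx

    print(find_prefixsum_idx_reference([0.0, 0.0, 1.0, 3.0], 0.99))  # 2, matches the test above
    print(find_prefixsum_idx_reference([0.5, 1.0, 1.0, 3.0], 1.51))  # 2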
--------------------------------------------------------------------------------
/mher/common/tests/test_serialization.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gym
3 | import tempfile
4 | import pytest
5 | import tensorflow as tf
6 | import numpy as np
7 |
8 | from mher.common.tests.envs.mnist_env import MnistEnv
9 | from mher.common.vec_env.dummy_vec_env import DummyVecEnv
10 | from mher.run import get_learn_function
11 | from mher.common.tf_util import make_session, get_session
12 |
13 | from functools import partial
14 |
15 |
16 | learn_kwargs = {
17 | 'deepq': {},
18 | 'a2c': {},
19 | 'acktr': {},
20 | 'acer': {},
21 | 'ppo2': {'nminibatches': 1, 'nsteps': 10},
22 | 'trpo_mpi': {},
23 | }
24 |
25 | network_kwargs = {
26 | 'mlp': {},
27 | 'cnn': {'pad': 'SAME'},
28 | 'lstm': {},
29 | 'cnn_lnlstm': {'pad': 'SAME'}
30 | }
31 |
32 |
33 | @pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
34 | @pytest.mark.parametrize("network_fn", network_kwargs.keys())
35 | def test_serialization(learn_fn, network_fn):
36 | '''
37 | Test if the trained model can be serialized
38 | '''
39 |
40 |
41 | if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']:
42 | # TODO make acktr work with recurrent policies
43 | # and test
44 | # github issue: https://github.com/openai/baselines/issues/660
45 | return
46 |
47 | def make_env():
48 | env = MnistEnv(episode_len=100)
49 | env.seed(10)
50 | return env
51 |
52 | env = DummyVecEnv([make_env])
53 | ob = env.reset().copy()
54 | learn = get_learn_function(learn_fn)
55 |
56 | kwargs = {}
57 | kwargs.update(network_kwargs[network_fn])
58 | kwargs.update(learn_kwargs[learn_fn])
59 |
60 |
61 | learn = partial(learn, env=env, network=network_fn, seed=0, **kwargs)
62 |
63 | with tempfile.TemporaryDirectory() as td:
64 | model_path = os.path.join(td, 'serialization_test_model')
65 |
66 | with tf.Graph().as_default(), make_session().as_default():
67 | model = learn(total_timesteps=100)
68 | model.save(model_path)
69 | mean1, std1 = _get_action_stats(model, ob)
70 | variables_dict1 = _serialize_variables()
71 |
72 | with tf.Graph().as_default(), make_session().as_default():
73 | model = learn(total_timesteps=0, load_path=model_path)
74 | mean2, std2 = _get_action_stats(model, ob)
75 | variables_dict2 = _serialize_variables()
76 |
77 | for k, v in variables_dict1.items():
78 | np.testing.assert_allclose(v, variables_dict2[k], atol=0.01,
79 | err_msg='saved and loaded variable {} value mismatch'.format(k))
80 |
81 | np.testing.assert_allclose(mean1, mean2, atol=0.5)
82 | np.testing.assert_allclose(std1, std2, atol=0.5)
83 |
84 |
85 | @pytest.mark.parametrize("learn_fn", learn_kwargs.keys())
86 | @pytest.mark.parametrize("network_fn", ['mlp'])
87 | def test_coexistence(learn_fn, network_fn):
88 | '''
89 | Test if more than one model can exist at a time
90 | '''
91 |
92 | if learn_fn == 'deepq':
93 | # TODO enable multiple DQN models to be useable at the same time
94 | # github issue https://github.com/openai/baselines/issues/656
95 | return
96 |
97 | if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
98 | # TODO make acktr work with recurrent policies
99 | # and test
100 | # github issue: https://github.com/openai/baselines/issues/660
101 | return
102 |
103 | env = DummyVecEnv([lambda: gym.make('CartPole-v0')])
104 | learn = get_learn_function(learn_fn)
105 |
106 | kwargs = {}
107 | kwargs.update(network_kwargs[network_fn])
108 | kwargs.update(learn_kwargs[learn_fn])
109 |
110 | learn = partial(learn, env=env, network=network_fn, total_timesteps=0, **kwargs)
111 | make_session(make_default=True, graph=tf.Graph())
112 | model1 = learn(seed=1)
113 | make_session(make_default=True, graph=tf.Graph())
114 | model2 = learn(seed=2)
115 |
116 | model1.step(env.observation_space.sample())
117 | model2.step(env.observation_space.sample())
118 |
119 |
120 |
121 | def _serialize_variables():
122 | sess = get_session()
123 | variables = tf.trainable_variables()
124 | values = sess.run(variables)
125 | return {var.name: value for var, value in zip(variables, values)}
126 |
127 |
128 | def _get_action_stats(model, ob):
129 | ntrials = 1000
130 | if model.initial_state is None or model.initial_state == []:
131 | actions = np.array([model.step(ob)[0] for _ in range(ntrials)])
132 | else:
133 | actions = np.array([model.step(ob, S=model.initial_state, M=[False])[0] for _ in range(ntrials)])
134 |
135 | mean = np.mean(actions, axis=0)
136 | std = np.std(actions, axis=0)
137 |
138 | return mean, std
139 |
140 |
--------------------------------------------------------------------------------
/mher/common/tests/test_tf_util.py:
--------------------------------------------------------------------------------
1 | # tests for tf_util
2 | import tensorflow as tf
3 | from mher.common.tf_util import (
4 | function,
5 | initialize,
6 | single_threaded_session
7 | )
8 |
9 |
10 | def test_function():
11 | with tf.Graph().as_default():
12 | x = tf.placeholder(tf.int32, (), name="x")
13 | y = tf.placeholder(tf.int32, (), name="y")
14 | z = 3 * x + 2 * y
15 | lin = function([x, y], z, givens={y: 0})
16 |
17 | with single_threaded_session():
18 | initialize()
19 |
20 | assert lin(2) == 6
21 | assert lin(x=3) == 9
22 | assert lin(2, 2) == 10
23 | assert lin(x=2, y=3) == 12
24 |
25 |
26 | def test_multikwargs():
27 | with tf.Graph().as_default():
28 | x = tf.placeholder(tf.int32, (), name="x")
29 | with tf.variable_scope("other"):
30 | x2 = tf.placeholder(tf.int32, (), name="x")
31 | z = 3 * x + 2 * x2
32 |
33 | lin = function([x, x2], z, givens={x2: 0})
34 | with single_threaded_session():
35 | initialize()
36 | assert lin(2) == 6
37 | assert lin(2, 2) == 10
38 |
39 |
40 | if __name__ == '__main__':
41 | test_function()
42 | test_multikwargs()
43 |
--------------------------------------------------------------------------------
/mher/common/tests/test_with_mpi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import subprocess
4 | import cloudpickle
5 | import base64
6 | import pytest
7 | from functools import wraps
8 |
9 | try:
10 | from mpi4py import MPI
11 | except ImportError:
12 | MPI = None
13 |
14 | def with_mpi(nproc=2, timeout=30, skip_if_no_mpi=True):
15 | def outer_thunk(fn):
16 | @wraps(fn)
17 | def thunk(*args, **kwargs):
18 | serialized_fn = base64.b64encode(cloudpickle.dumps(lambda: fn(*args, **kwargs)))
19 | subprocess.check_call([
20 | 'mpiexec','-n', str(nproc),
21 | sys.executable,
22 | '-m', 'mher.common.tests.test_with_mpi',
23 | serialized_fn
24 | ], env=os.environ, timeout=timeout)
25 |
26 | if skip_if_no_mpi:
27 | return pytest.mark.skipif(MPI is None, reason="MPI not present")(thunk)
28 | else:
29 | return thunk
30 |
31 | return outer_thunk
32 |
33 |
34 | if __name__ == '__main__':
35 | if len(sys.argv) > 1:
36 | fn = cloudpickle.loads(base64.b64decode(sys.argv[1]))
37 | assert callable(fn)
38 | fn()
39 |
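
A hedged usage sketch of the decorator above (the test body is illustrative, not from this repository): the wrapped function is cloud-pickled and re-executed under mpiexec with nproc ranks, so it can assert on MPI state directly.

    from mher.common.tests.test_with_mpi import with_mpi

    try:
        from mpi4py import MPI
    except ImportError:
        MPI = None

    @with_mpi(nproc=2, timeout=60)
    def test_two_ranks():
        # runs inside each of the two mpiexec-launched worker processes
        assert MPI.COMM_WORLD.Get_size() == 2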
--------------------------------------------------------------------------------
/mher/common/tests/util.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from mher.common.vec_env.dummy_vec_env import DummyVecEnv
4 |
5 | N_TRIALS = 10000
6 | N_EPISODES = 100
7 |
8 | _sess_config = tf.ConfigProto(
9 | allow_soft_placement=True,
10 | intra_op_parallelism_threads=1,
11 | inter_op_parallelism_threads=1
12 | )
13 |
14 | def simple_test(env_fn, learn_fn, min_reward_fraction, n_trials=N_TRIALS):
15 | def seeded_env_fn():
16 | env = env_fn()
17 | env.seed(0)
18 | return env
19 |
20 | np.random.seed(0)
21 | env = DummyVecEnv([seeded_env_fn])
22 | with tf.Graph().as_default(), tf.Session(config=_sess_config).as_default():
23 | tf.set_random_seed(0)
24 | model = learn_fn(env)
25 | sum_rew = 0
26 | done = True
27 | for i in range(n_trials):
28 | if done:
29 | obs = env.reset()
30 | state = model.initial_state
31 | if state is not None:
32 | a, v, state, _ = model.step(obs, S=state, M=[False])
33 | else:
34 | a, v, _, _ = model.step(obs)
35 | obs, rew, done, _ = env.step(a)
36 | sum_rew += float(rew)
37 | print("Reward in {} trials is {}".format(n_trials, sum_rew))
38 | assert sum_rew > min_reward_fraction * n_trials, \
39 | 'sum of rewards {} is less than {} of the total number of trials {}'.format(sum_rew, min_reward_fraction, n_trials)
40 |
41 | def reward_per_episode_test(env_fn, learn_fn, min_avg_reward, n_trials=N_EPISODES):
42 | env = DummyVecEnv([env_fn])
43 | with tf.Graph().as_default(), tf.Session(config=_sess_config).as_default():
44 | model = learn_fn(env)
45 |         # use the n_trials argument instead of a hard-coded episode count
46 |         observations, actions, rewards = rollout(env, model, n_trials)
47 |         rewards = [sum(r) for r in rewards]
48 |         avg_rew = sum(rewards) / n_trials
49 |         print("Average reward in {} episodes is {}".format(n_trials, avg_rew))
50 | assert avg_rew > min_avg_reward, \
51 | 'average reward in {} episodes ({}) is less than {}'.format(n_trials, avg_rew, min_avg_reward)
52 |
53 | def rollout(env, model, n_trials):
54 | rewards = []
55 | actions = []
56 | observations = []
57 | for i in range(n_trials):
58 | obs = env.reset()
59 | state = model.initial_state if hasattr(model, 'initial_state') else None
60 | episode_rew = []
61 | episode_actions = []
62 | episode_obs = []
63 | while True:
64 | if state is not None:
65 | a, v, state, _ = model.step(obs, S=state, M=[False])
66 | else:
67 | a,v, _, _ = model.step(obs)
68 |
69 | obs, rew, done, _ = env.step(a)
70 | episode_rew.append(rew)
71 | episode_actions.append(a)
72 | episode_obs.append(obs)
73 | if done:
74 | break
75 | rewards.append(episode_rew)
76 | actions.append(episode_actions)
77 | observations.append(episode_obs)
78 | return observations, actions, rewards
79 |
80 |
81 | def smoketest(argstr, **kwargs):
82 | import tempfile
83 | import subprocess
84 | import os
85 | argstr = 'python -m mher.run ' + argstr
86 |     for key, value in kwargs.items():
87 | argstr += ' --{}={}'.format(key, value)
88 | tempdir = tempfile.mkdtemp()
89 | env = os.environ.copy()
90 | env['OPENAI_LOGDIR'] = tempdir
91 | subprocess.run(argstr.split(' '), env=env)
92 | return tempdir
93 |
--------------------------------------------------------------------------------
/mher/common/tile_images.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def tile_images(img_nhwc):
4 | """
5 | Tile N images into one big PxQ image
6 | (P,Q) are chosen to be as close as possible, and if N
7 | is square, then P=Q.
8 |
9 | input: img_nhwc, list or array of images, ndim=4 once turned into array
10 | n = batch index, h = height, w = width, c = channel
11 | returns:
12 | bigim_HWc, ndarray with ndim=3
13 | """
14 | img_nhwc = np.asarray(img_nhwc)
15 | N, h, w, c = img_nhwc.shape
16 | H = int(np.ceil(np.sqrt(N)))
17 | W = int(np.ceil(float(N)/H))
18 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)])
19 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c)
20 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)
21 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c)
22 | return img_Hh_Ww_c
23 |
24 |
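
A small shape check (a sketch, not repository code): ten 32x48 RGB frames are padded with blank images to a 4x3 grid, so the tiled result is 128x144x3.

    import numpy as np
    from mher.common.tile_images import tile_images

    frames = np.random.randint(0, 255, size=(10, 32, 48, 3), dtype=np.uint8)
    big = tile_images(frames)
    # H = ceil(sqrt(10)) = 4 rows, W = ceil(10 / 4) = 3 columns; 2 cells stay black
    assert big.shape == (4 * 32, 3 * 48, 3)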
--------------------------------------------------------------------------------
/mher/common/vec_env/__init__.py:
--------------------------------------------------------------------------------
1 | from .vec_env import AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, VecEnvObservationWrapper, CloudpickleWrapper
2 | from .dummy_vec_env import DummyVecEnv
3 | from .shmem_vec_env import ShmemVecEnv
4 | from .subproc_vec_env import SubprocVecEnv
5 | from .vec_frame_stack import VecFrameStack
6 | from .vec_monitor import VecMonitor
7 | from .vec_normalize import VecNormalize
8 | from .vec_remove_dict_obs import VecExtractDictObs
9 |
10 | __all__ = ['AlreadySteppingError', 'NotSteppingError', 'VecEnv', 'VecEnvWrapper', 'VecEnvObservationWrapper', 'CloudpickleWrapper', 'DummyVecEnv', 'ShmemVecEnv', 'SubprocVecEnv', 'VecFrameStack', 'VecMonitor', 'VecNormalize', 'VecExtractDictObs']
11 |
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/dummy_vec_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/dummy_vec_env.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/shmem_vec_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/shmem_vec_env.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/subproc_vec_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/subproc_vec_env.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/vec_env.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/vec_env.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/vec_frame_stack.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/vec_frame_stack.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/vec_monitor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/vec_monitor.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/vec_normalize.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/vec_normalize.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/vec_remove_dict_obs.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/vec_remove_dict_obs.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/__pycache__/vec_video_recorder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/common/vec_env/__pycache__/vec_video_recorder.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/common/vec_env/dummy_vec_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .vec_env import VecEnv
3 | from .util import copy_obs_dict, dict_to_obs, obs_space_info
4 |
5 | class DummyVecEnv(VecEnv):
6 | """
7 |     VecEnv that runs multiple environments sequentially, that is,
8 |     the step and reset commands are sent to one environment at a time.
9 | Useful when debugging and when num_env == 1 (in the latter case,
10 | avoids communication overhead)
11 | """
12 | def __init__(self, env_fns):
13 | """
14 | Arguments:
15 |
16 |         env_fns: iterable of callable functions that build environments
17 | """
18 | self.envs = [fn() for fn in env_fns]
19 | env = self.envs[0]
20 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
21 | obs_space = env.observation_space
22 | self.keys, shapes, dtypes = obs_space_info(obs_space)
23 |
24 | self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }
25 | self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
26 | self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
27 | self.buf_infos = [{} for _ in range(self.num_envs)]
28 | self.actions = None
29 | self.spec = self.envs[0].spec
30 |
31 | def step_async(self, actions):
32 | listify = True
33 | try:
34 | if len(actions) == self.num_envs:
35 | listify = False
36 | except TypeError:
37 | pass
38 |
39 | if not listify:
40 | self.actions = actions
41 | else:
42 | assert self.num_envs == 1, "actions {} is either not a list or has a wrong size - cannot match to {} environments".format(actions, self.num_envs)
43 | self.actions = [actions]
44 |
45 | def step_wait(self):
46 | for e in range(self.num_envs):
47 | action = self.actions[e]
48 | obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
49 | # if self.buf_dones[e]: # here we don't need to reset because we reset in the main code file
50 | # obs = self.envs[e].reset()
51 | self._save_obs(e, obs)
52 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
53 | self.buf_infos.copy())
54 |
55 | def reset(self):
56 | for e in range(self.num_envs):
57 | obs = self.envs[e].reset()
58 | self._save_obs(e, obs)
59 | return self._obs_from_buf()
60 |
61 | def _save_obs(self, e, obs):
62 | for k in self.keys:
63 | if k is None:
64 | self.buf_obs[k][e] = obs
65 | else:
66 | self.buf_obs[k][e] = obs[k]
67 |
68 | def _obs_from_buf(self):
69 | return dict_to_obs(copy_obs_dict(self.buf_obs))
70 |
71 | def get_images(self):
72 | return [env.render(mode='rgb_array') for env in self.envs]
73 |
74 | def render(self, mode='human'):
75 | if self.num_envs == 1:
76 | return self.envs[0].render(mode=mode)
77 | else:
78 | return super().render(mode=mode)
79 |
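
A minimal usage sketch (CartPole is just an illustrative env, not part of this file): DummyVecEnv takes zero-argument constructors and exposes batched reset/step.

    import gym
    import numpy as np
    from mher.common.vec_env.dummy_vec_env import DummyVecEnv

    venv = DummyVecEnv([lambda: gym.make('CartPole-v0') for _ in range(2)])
    obs = venv.reset()   # shape (2, 4): one row per environment
    actions = np.array([venv.action_space.sample() for _ in range(venv.num_envs)])
    obs, rews, dones, infos = venv.step(actions)
    # note: unlike the subprocess-based VecEnvs below, this DummyVecEnv does not
    # auto-reset finished episodes (the reset call in step_wait is commented out)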
--------------------------------------------------------------------------------
/mher/common/vec_env/shmem_vec_env.py:
--------------------------------------------------------------------------------
1 | """
2 | An interface for asynchronous vectorized environments.
3 | """
4 |
5 | import multiprocessing as mp
6 | import numpy as np
7 | from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
8 | import ctypes
9 | from mher.common import logger
10 |
11 | from .util import dict_to_obs, obs_space_info, obs_to_dict
12 |
13 | _NP_TO_CT = {np.float32: ctypes.c_float,
14 | np.int32: ctypes.c_int32,
15 | np.int8: ctypes.c_int8,
16 | np.uint8: ctypes.c_char,
17 | np.bool: ctypes.c_bool}
18 |
19 |
20 | class ShmemVecEnv(VecEnv):
21 | """
22 |     Optimized version of SubprocVecEnv that uses shared memory to communicate observations.
23 | """
24 |
25 | def __init__(self, env_fns, spaces=None, context='spawn'):
26 | """
27 |         If you don't pass spaces (observation and action spaces), a dummy
28 |         environment is created to infer them.
29 | """
30 | ctx = mp.get_context(context)
31 | if spaces:
32 | observation_space, action_space = spaces
33 | else:
34 | logger.log('Creating dummy env object to get spaces')
35 | with logger.scoped_configure(format_strs=[]):
36 | dummy = env_fns[0]()
37 | observation_space, action_space = dummy.observation_space, dummy.action_space
38 | dummy.close()
39 | del dummy
40 | VecEnv.__init__(self, len(env_fns), observation_space, action_space)
41 | self.obs_keys, self.obs_shapes, self.obs_dtypes = obs_space_info(observation_space)
42 | self.obs_bufs = [
43 | {k: ctx.Array(_NP_TO_CT[self.obs_dtypes[k].type], int(np.prod(self.obs_shapes[k]))) for k in self.obs_keys}
44 | for _ in env_fns]
45 | self.parent_pipes = []
46 | self.procs = []
47 | with clear_mpi_env_vars():
48 | for env_fn, obs_buf in zip(env_fns, self.obs_bufs):
49 | wrapped_fn = CloudpickleWrapper(env_fn)
50 | parent_pipe, child_pipe = ctx.Pipe()
51 | proc = ctx.Process(target=_subproc_worker,
52 | args=(child_pipe, parent_pipe, wrapped_fn, obs_buf, self.obs_shapes, self.obs_dtypes, self.obs_keys))
53 | proc.daemon = True
54 | self.procs.append(proc)
55 | self.parent_pipes.append(parent_pipe)
56 | proc.start()
57 | child_pipe.close()
58 | self.waiting_step = False
59 | self.viewer = None
60 |
61 | def reset(self):
62 | if self.waiting_step:
63 | logger.warn('Called reset() while waiting for the step to complete')
64 | self.step_wait()
65 | for pipe in self.parent_pipes:
66 | pipe.send(('reset', None))
67 | return self._decode_obses([pipe.recv() for pipe in self.parent_pipes])
68 |
69 | def step_async(self, actions):
70 | assert len(actions) == len(self.parent_pipes)
71 | for pipe, act in zip(self.parent_pipes, actions):
72 | pipe.send(('step', act))
73 | self.waiting_step = True
74 |
75 | def step_wait(self):
76 | outs = [pipe.recv() for pipe in self.parent_pipes]
77 | self.waiting_step = False
78 | obs, rews, dones, infos = zip(*outs)
79 | return self._decode_obses(obs), np.array(rews), np.array(dones), infos
80 |
81 | def close_extras(self):
82 | if self.waiting_step:
83 | self.step_wait()
84 | for pipe in self.parent_pipes:
85 | pipe.send(('close', None))
86 | for pipe in self.parent_pipes:
87 | pipe.recv()
88 | pipe.close()
89 | for proc in self.procs:
90 | proc.join()
91 |
92 | def get_images(self, mode='human'):
93 | for pipe in self.parent_pipes:
94 | pipe.send(('render', None))
95 | return [pipe.recv() for pipe in self.parent_pipes]
96 |
97 | def _decode_obses(self, obs):
98 | result = {}
99 | for k in self.obs_keys:
100 |
101 | bufs = [b[k] for b in self.obs_bufs]
102 | o = [np.frombuffer(b.get_obj(), dtype=self.obs_dtypes[k]).reshape(self.obs_shapes[k]) for b in bufs]
103 | result[k] = np.array(o)
104 | return dict_to_obs(result)
105 |
106 |
107 | def _subproc_worker(pipe, parent_pipe, env_fn_wrapper, obs_bufs, obs_shapes, obs_dtypes, keys):
108 | """
109 | Control a single environment instance using IPC and
110 | shared memory.
111 | """
112 | def _write_obs(maybe_dict_obs):
113 | flatdict = obs_to_dict(maybe_dict_obs)
114 | for k in keys:
115 | dst = obs_bufs[k].get_obj()
116 | dst_np = np.frombuffer(dst, dtype=obs_dtypes[k]).reshape(obs_shapes[k]) # pylint: disable=W0212
117 | np.copyto(dst_np, flatdict[k])
118 |
119 | env = env_fn_wrapper.x()
120 | parent_pipe.close()
121 | try:
122 | while True:
123 | cmd, data = pipe.recv()
124 | if cmd == 'reset':
125 | pipe.send(_write_obs(env.reset()))
126 | elif cmd == 'step':
127 | obs, reward, done, info = env.step(data)
128 | if done:
129 | obs = env.reset()
130 | pipe.send((_write_obs(obs), reward, done, info))
131 | elif cmd == 'render':
132 | pipe.send(env.render(mode='rgb_array'))
133 | elif cmd == 'close':
134 | pipe.send(None)
135 | break
136 | else:
137 | raise RuntimeError('Got unrecognized cmd %s' % cmd)
138 | except KeyboardInterrupt:
139 | print('ShmemVecEnv worker: got KeyboardInterrupt')
140 | finally:
141 | env.close()
142 |
--------------------------------------------------------------------------------
/mher/common/vec_env/subproc_vec_env.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 |
3 | import numpy as np
4 | from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
5 |
6 |
7 | def worker(remote, parent_remote, env_fn_wrappers):
8 | def step_env(env, action):
9 | ob, reward, done, info = env.step(action)
10 | if done:
11 | ob = env.reset()
12 | return ob, reward, done, info
13 |
14 | parent_remote.close()
15 | envs = [env_fn_wrapper() for env_fn_wrapper in env_fn_wrappers.x]
16 | try:
17 | while True:
18 | cmd, data = remote.recv()
19 | if cmd == 'step':
20 | remote.send([step_env(env, action) for env, action in zip(envs, data)])
21 | elif cmd == 'reset':
22 | remote.send([env.reset() for env in envs])
23 | elif cmd == 'render':
24 | remote.send([env.render(mode='rgb_array') for env in envs])
25 | elif cmd == 'close':
26 | remote.close()
27 | break
28 | elif cmd == 'get_spaces_spec':
29 | remote.send(CloudpickleWrapper((envs[0].observation_space, envs[0].action_space, envs[0].spec)))
30 | else:
31 | raise NotImplementedError
32 | except KeyboardInterrupt:
33 | print('SubprocVecEnv worker: got KeyboardInterrupt')
34 | finally:
35 | for env in envs:
36 | env.close()
37 |
38 |
39 | class SubprocVecEnv(VecEnv):
40 | """
41 |     VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes.
42 | Recommended to use when num_envs > 1 and step() can be a bottleneck.
43 | """
44 | def __init__(self, env_fns, spaces=None, context='spawn', in_series=1):
45 | """
46 | Arguments:
47 |
48 | env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
49 | in_series: number of environments to run in series in a single process
50 | (e.g. when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series)
51 | """
52 | self.waiting = False
53 | self.closed = False
54 | self.in_series = in_series
55 | nenvs = len(env_fns)
56 | assert nenvs % in_series == 0, "Number of envs must be divisible by number of envs to run in series"
57 | self.nremotes = nenvs // in_series
58 | env_fns = np.array_split(env_fns, self.nremotes)
59 | ctx = mp.get_context(context)
60 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.nremotes)])
61 | self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
62 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
63 | for p in self.ps:
64 | p.daemon = True # if the main process crashes, we should not cause things to hang
65 | with clear_mpi_env_vars():
66 | p.start()
67 | for remote in self.work_remotes:
68 | remote.close()
69 |
70 | self.remotes[0].send(('get_spaces_spec', None))
71 | observation_space, action_space, self.spec = self.remotes[0].recv().x
72 | self.viewer = None
73 | VecEnv.__init__(self, nenvs, observation_space, action_space)
74 |
75 | def step_async(self, actions):
76 | self._assert_not_closed()
77 | actions = np.array_split(actions, self.nremotes)
78 | for remote, action in zip(self.remotes, actions):
79 | remote.send(('step', action))
80 | self.waiting = True
81 |
82 | def step_wait(self):
83 | self._assert_not_closed()
84 | results = [remote.recv() for remote in self.remotes]
85 | results = _flatten_list(results)
86 | self.waiting = False
87 | obs, rews, dones, infos = zip(*results)
88 | return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos
89 |
90 | def reset(self):
91 | self._assert_not_closed()
92 | for remote in self.remotes:
93 | remote.send(('reset', None))
94 | obs = [remote.recv() for remote in self.remotes]
95 | obs = _flatten_list(obs)
96 | return _flatten_obs(obs)
97 |
98 | def close_extras(self):
99 | self.closed = True
100 | if self.waiting:
101 | for remote in self.remotes:
102 | remote.recv()
103 | for remote in self.remotes:
104 | remote.send(('close', None))
105 | for p in self.ps:
106 | p.join()
107 |
108 | def get_images(self):
109 | self._assert_not_closed()
110 | for pipe in self.remotes:
111 | pipe.send(('render', None))
112 | imgs = [pipe.recv() for pipe in self.remotes]
113 | imgs = _flatten_list(imgs)
114 | return imgs
115 |
116 | def _assert_not_closed(self):
117 | assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
118 |
119 | def __del__(self):
120 | if not self.closed:
121 | self.close()
122 |
123 | def _flatten_obs(obs):
124 | assert isinstance(obs, (list, tuple))
125 | assert len(obs) > 0
126 |
127 | if isinstance(obs[0], dict):
128 | keys = obs[0].keys()
129 | return {k: np.stack([o[k] for o in obs]) for k in keys}
130 | else:
131 | return np.stack(obs)
132 |
133 | def _flatten_list(l):
134 | assert isinstance(l, (list, tuple))
135 | assert len(l) > 0
136 | assert all([len(l_) > 0 for l_ in l])
137 |
138 | return [l__ for l_ in l for l__ in l_]
139 |
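
A usage sketch of the in_series option (CartPole again as an illustrative env): twelve constructors with in_series=3 start four worker processes, each stepping three environments back to back.

    import gym
    from mher.common.vec_env.subproc_vec_env import SubprocVecEnv

    if __name__ == '__main__':   # needed because workers are started with the 'spawn' context
        env_fns = [lambda: gym.make('CartPole-v0') for _ in range(12)]
        venv = SubprocVecEnv(env_fns, in_series=3)
        assert venv.num_envs == 12 and venv.nremotes == 4
        venv.close()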
--------------------------------------------------------------------------------
/mher/common/vec_env/test_vec_env.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for asynchronous vectorized environments.
3 | """
4 |
5 | import gym
6 | import numpy as np
7 | import pytest
8 | from .dummy_vec_env import DummyVecEnv
9 | from .shmem_vec_env import ShmemVecEnv
10 | from .subproc_vec_env import SubprocVecEnv
11 | from mher.common.tests.test_with_mpi import with_mpi
12 |
13 |
14 | def assert_venvs_equal(venv1, venv2, num_steps):
15 | """
16 | Compare two environments over num_steps steps and make sure
17 | that the observations produced by each are the same when given
18 | the same actions.
19 | """
20 | assert venv1.num_envs == venv2.num_envs
21 | assert venv1.observation_space.shape == venv2.observation_space.shape
22 | assert venv1.observation_space.dtype == venv2.observation_space.dtype
23 | assert venv1.action_space.shape == venv2.action_space.shape
24 | assert venv1.action_space.dtype == venv2.action_space.dtype
25 |
26 | try:
27 | obs1, obs2 = venv1.reset(), venv2.reset()
28 | assert np.array(obs1).shape == np.array(obs2).shape
29 | assert np.array(obs1).shape == (venv1.num_envs,) + venv1.observation_space.shape
30 | assert np.allclose(obs1, obs2)
31 | venv1.action_space.seed(1337)
32 | for _ in range(num_steps):
33 | actions = np.array([venv1.action_space.sample() for _ in range(venv1.num_envs)])
34 | for venv in [venv1, venv2]:
35 | venv.step_async(actions)
36 | outs1 = venv1.step_wait()
37 | outs2 = venv2.step_wait()
38 | for out1, out2 in zip(outs1[:3], outs2[:3]):
39 | assert np.array(out1).shape == np.array(out2).shape
40 | assert np.allclose(out1, out2)
41 | assert list(outs1[3]) == list(outs2[3])
42 | finally:
43 | venv1.close()
44 | venv2.close()
45 |
46 |
47 | @pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv))
48 | @pytest.mark.parametrize('dtype', ('uint8', 'float32'))
49 | def test_vec_env(klass, dtype): # pylint: disable=R0914
50 | """
51 | Test that a vectorized environment is equivalent to
52 |     DummyVecEnv, since DummyVecEnv is less
53 |     error prone.
54 | """
55 | num_envs = 3
56 | num_steps = 100
57 | shape = (3, 8)
58 |
59 | def make_fn(seed):
60 | """
61 | Get an environment constructor with a seed.
62 | """
63 | return lambda: SimpleEnv(seed, shape, dtype)
64 | fns = [make_fn(i) for i in range(num_envs)]
65 | env1 = DummyVecEnv(fns)
66 | env2 = klass(fns)
67 | assert_venvs_equal(env1, env2, num_steps=num_steps)
68 |
69 |
70 | @pytest.mark.parametrize('dtype', ('uint8', 'float32'))
71 | @pytest.mark.parametrize('num_envs_in_series', (3, 4, 6))
72 | def test_sync_sampling(dtype, num_envs_in_series):
73 | """
74 | Test that a SubprocVecEnv running with envs in series
75 | outputs the same as DummyVecEnv.
76 | """
77 | num_envs = 12
78 | num_steps = 100
79 | shape = (3, 8)
80 |
81 | def make_fn(seed):
82 | """
83 | Get an environment constructor with a seed.
84 | """
85 | return lambda: SimpleEnv(seed, shape, dtype)
86 | fns = [make_fn(i) for i in range(num_envs)]
87 | env1 = DummyVecEnv(fns)
88 | env2 = SubprocVecEnv(fns, in_series=num_envs_in_series)
89 | assert_venvs_equal(env1, env2, num_steps=num_steps)
90 |
91 |
92 | @pytest.mark.parametrize('dtype', ('uint8', 'float32'))
93 | @pytest.mark.parametrize('num_envs_in_series', (3, 4, 6))
94 | def test_sync_sampling_sanity(dtype, num_envs_in_series):
95 | """
96 | Test that a SubprocVecEnv running with envs in series
97 | outputs the same as SubprocVecEnv without running in series.
98 | """
99 | num_envs = 12
100 | num_steps = 100
101 | shape = (3, 8)
102 |
103 | def make_fn(seed):
104 | """
105 | Get an environment constructor with a seed.
106 | """
107 | return lambda: SimpleEnv(seed, shape, dtype)
108 | fns = [make_fn(i) for i in range(num_envs)]
109 | env1 = SubprocVecEnv(fns)
110 | env2 = SubprocVecEnv(fns, in_series=num_envs_in_series)
111 | assert_venvs_equal(env1, env2, num_steps=num_steps)
112 |
113 |
114 | class SimpleEnv(gym.Env):
115 | """
116 | An environment with a pre-determined observation space
117 | and RNG seed.
118 | """
119 |
120 | def __init__(self, seed, shape, dtype):
121 | np.random.seed(seed)
122 | self._dtype = dtype
123 | self._start_obs = np.array(np.random.randint(0, 0x100, size=shape),
124 | dtype=dtype)
125 | self._max_steps = seed + 1
126 | self._cur_obs = None
127 | self._cur_step = 0
128 | # this is 0xFF instead of 0x100 because the Box space includes
129 | # the high end, while randint does not
130 | self.action_space = gym.spaces.Box(low=0, high=0xFF, shape=shape, dtype=dtype)
131 | self.observation_space = self.action_space
132 |
133 | def step(self, action):
134 | self._cur_obs += np.array(action, dtype=self._dtype)
135 | self._cur_step += 1
136 | done = self._cur_step >= self._max_steps
137 | reward = self._cur_step / self._max_steps
138 | return self._cur_obs, reward, done, {'foo': 'bar' + str(reward)}
139 |
140 | def reset(self):
141 | self._cur_obs = self._start_obs
142 | self._cur_step = 0
143 | return self._cur_obs
144 |
145 | def render(self, mode=None):
146 | raise NotImplementedError
147 |
148 |
149 |
150 | @with_mpi()
151 | def test_mpi_with_subprocvecenv():
152 | shape = (2,3,4)
153 | nenv = 1
154 | venv = SubprocVecEnv([lambda: SimpleEnv(0, shape, 'float32')] * nenv)
155 | ob = venv.reset()
156 | venv.close()
157 | assert ob.shape == (nenv,) + shape
158 |
159 |
--------------------------------------------------------------------------------
/mher/common/vec_env/test_video_recorder.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for asynchronous vectorized environments.
3 | """
4 |
5 | import gym
6 | import pytest
7 | import os
8 | import glob
9 | import tempfile
10 |
11 | from .dummy_vec_env import DummyVecEnv
12 | from .shmem_vec_env import ShmemVecEnv
13 | from .subproc_vec_env import SubprocVecEnv
14 | from .vec_video_recorder import VecVideoRecorder
15 |
16 | @pytest.mark.parametrize('klass', (DummyVecEnv, ShmemVecEnv, SubprocVecEnv))
17 | @pytest.mark.parametrize('num_envs', (1, 4))
18 | @pytest.mark.parametrize('video_length', (10, 100))
19 | @pytest.mark.parametrize('video_interval', (1, 50))
20 | def test_video_recorder(klass, num_envs, video_length, video_interval):
21 | """
22 |     Wrap an existing VecEnv with VecVideoRecorder,
23 |     make (video_interval + video_length + 1) steps,
24 |     then check that the recorded video files are present
25 | """
26 |
27 | def make_fn():
28 | env = gym.make('PongNoFrameskip-v4')
29 | return env
30 | fns = [make_fn for _ in range(num_envs)]
31 | env = klass(fns)
32 |
33 | with tempfile.TemporaryDirectory() as video_path:
34 | env = VecVideoRecorder(env, video_path, record_video_trigger=lambda x: x % video_interval == 0, video_length=video_length)
35 |
36 | env.reset()
37 | for _ in range(video_interval + video_length + 1):
38 | env.step([0] * num_envs)
39 | env.close()
40 |
41 |
42 | recorded_video = glob.glob(os.path.join(video_path, "*.mp4"))
43 |
44 |         # one video from the first recording trigger and one from the second
45 | assert len(recorded_video) == 2
46 | # Files are not empty
47 | assert all(os.stat(p).st_size != 0 for p in recorded_video)
48 |
49 |
50 |
--------------------------------------------------------------------------------
/mher/common/vec_env/util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for dealing with vectorized environments.
3 | """
4 |
5 | from collections import OrderedDict
6 |
7 | import gym
8 | import numpy as np
9 |
10 |
11 | def copy_obs_dict(obs):
12 | """
13 | Deep-copy an observation dict.
14 | """
15 | return {k: np.copy(v) for k, v in obs.items()}
16 |
17 |
18 | def dict_to_obs(obs_dict):
19 | """
20 | Convert an observation dict into a raw array if the
21 | original observation space was not a Dict space.
22 | """
23 | if set(obs_dict.keys()) == {None}:
24 | return obs_dict[None]
25 | return obs_dict
26 |
27 |
28 | def obs_space_info(obs_space):
29 | """
30 | Get dict-structured information about a gym.Space.
31 |
32 | Returns:
33 | A tuple (keys, shapes, dtypes):
34 | keys: a list of dict keys.
35 | shapes: a dict mapping keys to shapes.
36 | dtypes: a dict mapping keys to dtypes.
37 | """
38 | if isinstance(obs_space, gym.spaces.Dict):
39 | assert isinstance(obs_space.spaces, OrderedDict)
40 | subspaces = obs_space.spaces
41 | elif isinstance(obs_space, gym.spaces.Tuple):
42 | assert isinstance(obs_space.spaces, tuple)
43 | subspaces = {i: obs_space.spaces[i] for i in range(len(obs_space.spaces))}
44 | else:
45 | subspaces = {None: obs_space}
46 | keys = []
47 | shapes = {}
48 | dtypes = {}
49 | for key, box in subspaces.items():
50 | keys.append(key)
51 | shapes[key] = box.shape
52 | dtypes[key] = box.dtype
53 | return keys, shapes, dtypes
54 |
55 |
56 | def obs_to_dict(obs):
57 | """
58 | Convert an observation into a dict.
59 | """
60 | if isinstance(obs, dict):
61 | return obs
62 | return {None: obs}
63 |
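
A short sketch of obs_space_info on the two common cases (not repository code; the Dict keys follow the usual goal-env naming, and the helper itself is the one defined above):

    import gym
    import numpy as np
    from mher.common.vec_env.util import obs_space_info

    # plain Box space: a single anonymous key `None`
    keys, shapes, dtypes = obs_space_info(gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32))
    # keys == [None], shapes[None] == (4,)

    # Dict space, as used by goal-based envs
    goal_space = gym.spaces.Dict({
        'observation': gym.spaces.Box(-np.inf, np.inf, shape=(10,), dtype=np.float32),
        'desired_goal': gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float32),
    })
    keys, shapes, dtypes = obs_space_info(goal_space)
    # shapes['desired_goal'] == (3,), dtypes['observation'] == np.float32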
--------------------------------------------------------------------------------
/mher/common/vec_env/vec_frame_stack.py:
--------------------------------------------------------------------------------
1 | from .vec_env import VecEnvWrapper
2 | import numpy as np
3 | from gym import spaces
4 |
5 |
6 | class VecFrameStack(VecEnvWrapper):
7 | def __init__(self, venv, nstack):
8 | self.venv = venv
9 | self.nstack = nstack
10 | wos = venv.observation_space # wrapped ob space
11 | low = np.repeat(wos.low, self.nstack, axis=-1)
12 | high = np.repeat(wos.high, self.nstack, axis=-1)
13 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
14 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
15 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
16 |
17 | def step_wait(self):
18 | obs, rews, news, infos = self.venv.step_wait()
19 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1)
20 | for (i, new) in enumerate(news):
21 | if new:
22 | self.stackedobs[i] = 0
23 | self.stackedobs[..., -obs.shape[-1]:] = obs
24 | return self.stackedobs, rews, news, infos
25 |
26 | def reset(self):
27 | obs = self.venv.reset()
28 | self.stackedobs[...] = 0
29 | self.stackedobs[..., -obs.shape[-1]:] = obs
30 | return self.stackedobs
31 |
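
A self-contained shape sketch (the tiny FrameEnv below is hypothetical, only standing in for an image-observation env): with nstack=4, each 8x8x1 frame becomes an 8x8x4 stack that rolls along the channel axis.

    import gym
    import numpy as np
    from mher.common.vec_env import DummyVecEnv, VecFrameStack

    class FrameEnv(gym.Env):
        # hypothetical stand-in for an image-based environment
        observation_space = gym.spaces.Box(0, 255, shape=(8, 8, 1), dtype=np.uint8)
        action_space = gym.spaces.Discrete(2)
        def reset(self):
            return np.full((8, 8, 1), 7, dtype=np.uint8)
        def step(self, action):
            return np.full((8, 8, 1), 9, dtype=np.uint8), 0.0, False, {}

    venv = VecFrameStack(DummyVecEnv([lambda: FrameEnv()]), nstack=4)
    obs = venv.reset()                       # shape (1, 8, 8, 4); only the newest (last) channel is filled
    obs, _, _, _ = venv.step(np.array([0]))  # the stack rolls left; the last two channels now hold frames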
--------------------------------------------------------------------------------
/mher/common/vec_env/vec_monitor.py:
--------------------------------------------------------------------------------
1 | from . import VecEnvWrapper
2 | from mher.common.monitor import ResultsWriter
3 | import numpy as np
4 | import time
5 | from collections import deque
6 |
7 | class VecMonitor(VecEnvWrapper):
8 | def __init__(self, venv, filename=None, keep_buf=0, info_keywords=()):
9 | VecEnvWrapper.__init__(self, venv)
10 | self.eprets = None
11 | self.eplens = None
12 | self.epcount = 0
13 | self.tstart = time.time()
14 | if filename:
15 | self.results_writer = ResultsWriter(filename, header={'t_start': self.tstart},
16 | extra_keys=info_keywords)
17 | else:
18 | self.results_writer = None
19 | self.info_keywords = info_keywords
20 | self.keep_buf = keep_buf
21 | if self.keep_buf:
22 | self.epret_buf = deque([], maxlen=keep_buf)
23 | self.eplen_buf = deque([], maxlen=keep_buf)
24 |
25 | def reset(self):
26 | obs = self.venv.reset()
27 | self.eprets = np.zeros(self.num_envs, 'f')
28 | self.eplens = np.zeros(self.num_envs, 'i')
29 | return obs
30 |
31 | def step_wait(self):
32 | obs, rews, dones, infos = self.venv.step_wait()
33 | self.eprets += rews
34 | self.eplens += 1
35 |
36 | newinfos = list(infos[:])
37 | for i in range(len(dones)):
38 | if dones[i]:
39 | info = infos[i].copy()
40 | ret = self.eprets[i]
41 | eplen = self.eplens[i]
42 | epinfo = {'r': ret, 'l': eplen, 't': round(time.time() - self.tstart, 6)}
43 | for k in self.info_keywords:
44 | epinfo[k] = info[k]
45 | info['episode'] = epinfo
46 | if self.keep_buf:
47 | self.epret_buf.append(ret)
48 | self.eplen_buf.append(eplen)
49 | self.epcount += 1
50 | self.eprets[i] = 0
51 | self.eplens[i] = 0
52 | if self.results_writer:
53 | self.results_writer.write_row(epinfo)
54 | newinfos[i] = info
55 | return obs, rews, dones, newinfos
56 |
--------------------------------------------------------------------------------
/mher/common/vec_env/vec_normalize.py:
--------------------------------------------------------------------------------
1 | from . import VecEnvWrapper
2 | import numpy as np
3 |
4 | class VecNormalize(VecEnvWrapper):
5 | """
6 | A vectorized wrapper that normalizes the observations
7 | and returns from an environment.
8 | """
9 |
10 | def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8, use_tf=False):
11 | VecEnvWrapper.__init__(self, venv)
12 | if use_tf:
13 | from mher.common.running_mean_std import TfRunningMeanStd
14 | self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='ob_rms') if ob else None
15 | self.ret_rms = TfRunningMeanStd(shape=(), scope='ret_rms') if ret else None
16 | else:
17 | from mher.common.running_mean_std import RunningMeanStd
18 | self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
19 | self.ret_rms = RunningMeanStd(shape=()) if ret else None
20 | self.clipob = clipob
21 | self.cliprew = cliprew
22 | self.ret = np.zeros(self.num_envs)
23 | self.gamma = gamma
24 | self.epsilon = epsilon
25 |
26 | def step_wait(self):
27 | obs, rews, news, infos = self.venv.step_wait()
28 | self.ret = self.ret * self.gamma + rews
29 | obs = self._obfilt(obs)
30 | if self.ret_rms:
31 | self.ret_rms.update(self.ret)
32 | rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
33 | self.ret[news] = 0.
34 | return obs, rews, news, infos
35 |
36 | def _obfilt(self, obs):
37 | if self.ob_rms:
38 | self.ob_rms.update(obs)
39 | obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
40 | return obs
41 | else:
42 | return obs
43 |
44 | def reset(self):
45 | self.ret = np.zeros(self.num_envs)
46 | obs = self.venv.reset()
47 | return self._obfilt(obs)
48 |
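
A minimal wrapping sketch (CartPole as an illustrative env, not repository code): observations are normalized by a running mean/std and clipped to +-clipob, rewards are scaled by the running std of the discounted return and clipped to +-cliprew.

    import gym
    import numpy as np
    from mher.common.vec_env import DummyVecEnv, VecNormalize

    venv = VecNormalize(DummyVecEnv([lambda: gym.make('CartPole-v0')]), ob=True, ret=True)
    obs = venv.reset()
    for _ in range(5):
        action = np.array([venv.action_space.sample()])
        obs, rews, dones, infos = venv.step(action)
        # obs and rews are normalized; the running statistics update on every step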
--------------------------------------------------------------------------------
/mher/common/vec_env/vec_remove_dict_obs.py:
--------------------------------------------------------------------------------
1 | from .vec_env import VecEnvObservationWrapper
2 |
3 | class VecExtractDictObs(VecEnvObservationWrapper):
4 | def __init__(self, venv, key):
5 | self.key = key
6 | super().__init__(venv=venv,
7 | observation_space=venv.observation_space.spaces[self.key])
8 |
9 | def process(self, obs):
10 | return obs[self.key]
11 |
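
A sketch of flattening goal-env dict observations to a single array (FetchReach-v1 is named after the tests elsewhere in the repo and needs the robotics dependencies; 'observation' is the standard goal-env key):

    import gym
    from mher.common.vec_env import DummyVecEnv, VecExtractDictObs

    venv = DummyVecEnv([lambda: gym.make('FetchReach-v1')])
    venv = VecExtractDictObs(venv, key='observation')
    obs = venv.reset()   # a plain per-env array instead of a dict of arrays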
--------------------------------------------------------------------------------
/mher/common/vec_env/vec_video_recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from mher import logger
3 | from mher.common.vec_env import VecEnvWrapper
4 | from gym.wrappers.monitoring import video_recorder
5 |
6 |
7 | class VecVideoRecorder(VecEnvWrapper):
8 | """
9 | Wrap VecEnv to record rendered image as mp4 video.
10 | """
11 |
12 | def __init__(self, venv, directory, record_video_trigger, video_length=200):
13 | """
14 | # Arguments
15 | venv: VecEnv to wrap
16 | directory: Where to save videos
17 | record_video_trigger:
18 | Function that defines when to start recording.
19 |                 The function takes the current step number,
20 | and returns whether we should start recording or not.
21 | video_length: Length of recorded video
22 | """
23 |
24 | VecEnvWrapper.__init__(self, venv)
25 | self.record_video_trigger = record_video_trigger
26 | self.video_recorder = None
27 |
28 | self.directory = os.path.abspath(directory)
29 | if not os.path.exists(self.directory): os.mkdir(self.directory)
30 |
31 | self.file_prefix = "vecenv"
32 | self.file_infix = '{}'.format(os.getpid())
33 | self.step_id = 0
34 | self.video_length = video_length
35 |
36 | self.recording = False
37 | self.recorded_frames = 0
38 |
39 | def reset(self):
40 | obs = self.venv.reset()
41 |
42 | self.start_video_recorder()
43 |
44 | return obs
45 |
46 | def start_video_recorder(self):
47 | self.close_video_recorder()
48 |
49 | base_path = os.path.join(self.directory, '{}.video.{}.video{:06}'.format(self.file_prefix, self.file_infix, self.step_id))
50 | self.video_recorder = video_recorder.VideoRecorder(
51 | env=self.venv,
52 | base_path=base_path,
53 | metadata={'step_id': self.step_id}
54 | )
55 |
56 | self.video_recorder.capture_frame()
57 | self.recorded_frames = 1
58 | self.recording = True
59 |
60 | def _video_enabled(self):
61 | return self.record_video_trigger(self.step_id)
62 |
63 | def step_wait(self):
64 | obs, rews, dones, infos = self.venv.step_wait()
65 |
66 | self.step_id += 1
67 | if self.recording:
68 | self.video_recorder.capture_frame()
69 | self.recorded_frames += 1
70 | if self.recorded_frames > self.video_length:
71 | logger.info("Saving video to ", self.video_recorder.path)
72 | self.close_video_recorder()
73 | elif self._video_enabled():
74 | self.start_video_recorder()
75 |
76 | return obs, rews, dones, infos
77 |
78 | def close_video_recorder(self):
79 | if self.recording:
80 | self.video_recorder.close()
81 | self.recording = False
82 | self.recorded_frames = 0
83 |
84 | def close(self):
85 | VecEnvWrapper.close(self)
86 | self.close_video_recorder()
87 |
88 | def __del__(self):
89 | self.close()
90 |
--------------------------------------------------------------------------------
/mher/common/wrappers.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | class TimeLimit(gym.Wrapper):
4 | def __init__(self, env, max_episode_steps=None):
5 | super(TimeLimit, self).__init__(env)
6 | self._max_episode_steps = max_episode_steps
7 | self._elapsed_steps = 0
8 |
9 | def step(self, ac):
10 | observation, reward, done, info = self.env.step(ac)
11 | self._elapsed_steps += 1
12 | if self._elapsed_steps >= self._max_episode_steps:
13 | done = True
14 | info['TimeLimit.truncated'] = True
15 | return observation, reward, done, info
16 |
17 | def reset(self, **kwargs):
18 | self._elapsed_steps = 0
19 | return self.env.reset(**kwargs)
20 |
21 | class ClipActionsWrapper(gym.Wrapper):
22 | def step(self, action):
23 | import numpy as np
24 | action = np.nan_to_num(action)
25 | action = np.clip(action, self.action_space.low, self.action_space.high)
26 | return self.env.step(action)
27 |
28 | def reset(self, **kwargs):
29 | return self.env.reset(**kwargs)
30 |
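
A usage sketch of the two wrappers above (Pendulum-v0 is just an illustrative continuous-action env, not part of this file):

    import gym
    from mher.common.wrappers import TimeLimit, ClipActionsWrapper

    env = gym.make('Pendulum-v0').unwrapped        # drop gym's built-in time limit first
    env = ClipActionsWrapper(TimeLimit(env, max_episode_steps=50))
    obs = env.reset()
    # out-of-range (or NaN) actions are clipped to the action space before stepping
    obs, rew, done, info = env.step(env.action_space.sample() * 10.0)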
--------------------------------------------------------------------------------
/mher/default_cfg.py:
--------------------------------------------------------------------------------
1 | DEFAULT_ENV_PARAMS = {
2 | 'SawyerPush-v0':{
3 | 'n_cycles':10,
4 | 'n_batches':5,
5 | 'n_test_rollouts':50,
6 | 'batch_size':64,
7 | 'rollout_batch_size':1
8 | },
9 | 'SawyerReachXYZEnv-v1':{
10 | 'n_cycles':5,
11 | 'n_batches':2,
12 | 'n_test_rollouts':50,
13 | 'batch_size':64
14 | },
15 | 'FetchReach-v1': {
16 | 'n_cycles': 10,
17 | 'n_test_rollouts': 20,
18 | 'n_batches': 2,
19 | 'batch_size': 64,
20 | },
21 | # 'FetchPush-v1': {
22 | # 'n_cycles': 10,
23 | # 'n_test_rollouts': 20,
24 | # 'n_batches': 10,
25 | # 'batch_size': 256,
26 | # },
27 | }
28 |
29 |
30 | DEFAULT_PARAMS = {
31 | # algorithm
32 | 'algo':'ddpg',
33 | # env
34 | 'max_u': 1., # max absolute value of actions on different coordinates
35 | # ddpg
36 | 'layers': 3, # number of layers in the critic/actor networks
37 | 'hidden': 256, # number of neurons in each hidden layers
38 | 'Q_lr': 0.001, # critic learning rate
39 | 'pi_lr': 0.001, # actor learning rate
40 | 'polyak': 0.95, # polyak averaging coefficient
41 | 'action_l2': 1.0, # quadratic penalty on actions (before rescaling by max_u)
42 | 'clip_obs': 200.,
43 | 'relative_goals': False,
44 | 'clip_pos_returns': True,
45 | 'clip_return': True,
46 |
47 | # sac
48 | 'sac_alpha':0.03,
49 |
50 | # buffer
51 | 'buffer_size': int(1E6), # for experience replay
52 | 'sampler': 'random',
53 |
54 | # training
55 | 'n_cycles': 50, # per epoch
56 | 'rollout_batch_size': 2, # per mpi thread
57 | 'n_batches': 40, # training batches per cycle
58 | 'batch_size': 1024, # per MPI thread, measured in transitions and reduced to an even multiple of chunk_length
59 | 'n_test_rollouts': 10, # number of test rollouts per epoch, each consisting of rollout_batch_size rollouts
60 | 'test_with_polyak': False, # run test episodes with the target network
61 | # playing
62 | 'play_episodes':1, # number of running test episodes
63 | # saving
64 | 'policy_save_interval': 10,
65 | # exploration
66 | 'random_eps': 0.3, # percentage of time a random action is taken
67 | 'noise_eps': 0.2, # std of gaussian noise added to not-completely-random actions as a percentage of max_u
68 | # HER
69 | 'replay_strategy': 'future', # supported modes: future, none
70 | 'relabel_p': 0.8, # relabeling probability
71 | # normalization
72 | 'norm_eps': 1e-4, # epsilon used for observation normalization
73 | 'norm_clip': 5, # normalized observations are clipped to this value
74 |
75 | # random init episode
76 | 'random_init':100, # for dynamic n-step, this should be bigger
77 |
78 | # prioritized experience replay
79 | 'alpah': 0.6,
80 | 'beta': 0.4,
81 | 'eps': 1e-5,
82 |
83 | # n step hindsight experience
84 | 'nstep':3,
85 | 'use_nstep':False,
86 |
87 | # lambda n-step
88 | 'use_lambda_nstep':False,
89 | 'lamb':0.7,
90 |
91 | # dynamic n-step
92 | 'use_dynamic_nstep':False,
93 | 'alpha':0.5,
94 | 'dynamic_batchsize':512, # warm up the dynamic model
95 | 'dynamic_init':500,
96 |
97 | # whether to disable HER relabeling
98 | 'no_her':False # no HER relabeling; used for vanilla DDPG and n-step DDPG
99 | }
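
These defaults are meant to be merged with the per-env overrides above; a minimal sketch of that merge, following the same pattern /mher/play.py uses (the env name is only an example):

    from mher.default_cfg import DEFAULT_PARAMS, DEFAULT_ENV_PARAMS

    params = dict(DEFAULT_PARAMS)                         # global defaults
    env_name = 'FetchReach-v1'
    params.update(DEFAULT_ENV_PARAMS.get(env_name, {}))   # env-specific values take precedence
    params['env_name'] = env_name
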
--------------------------------------------------------------------------------
/mher/envs/__pycache__/env_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/envs/__pycache__/env_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/envs/__pycache__/make_env_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/envs/__pycache__/make_env_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/envs/env_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Utility functions for environments
3 | '''
4 | import re
5 |
6 | import gym
7 |
8 |
9 | def simple_goal_subtract(a, b):
10 | assert a.shape == b.shape
11 | return a - b
12 |
13 | def g_to_ag(o, env_id):
14 | if env_id == 'FetchReach':
15 | ag = o[:,0:3]
16 | elif env_id in ['FetchPush','FetchSlide', 'FetchPickAndPlace']:
17 | ag = o[:,3:6]
18 | else:
19 | raise NotImplementedError
20 | return ag
21 |
22 | CACHED_ENVS = {}
23 | def cached_make_env(make_env):
24 | """
25 | Only creates a new environment from the provided function if one has not yet already been
26 | created. This is useful here because we need to infer certain properties of the env, e.g.
27 | its observation and action spaces, without any intend of actually using it.
28 | """
29 | if make_env not in CACHED_ENVS:
30 | env = make_env()
31 | CACHED_ENVS[make_env] = env
32 | return CACHED_ENVS[make_env]
33 |
34 | def get_rewardfun(params, tmp_env):
35 | tmp_env.reset()
36 | def reward_fun(ag_2, g, info): # vectorized
37 | return tmp_env.compute_reward(achieved_goal=ag_2, desired_goal=g, info=info)
38 | return reward_fun
39 |
40 | def get_env_type(args, _game_envs):
41 | env_id = args.env
42 | # Re-parse the gym registry, since we could have new envs since last time.
43 | for env in gym.envs.registry.all():
44 | try:
45 | env_type = env.entry_point.split(':')[0].split('.')[-1]
46 | _game_envs[env_type].add(env.id) # This is a set so add is idempotent
47 | except:
48 | pass
49 |
50 | if env_id in _game_envs.keys():
51 | env_type = env_id
52 | env_id = [g for g in _game_envs[env_type]][0]
53 | else:
54 | env_type = None
55 | for g, e in _game_envs.items():
56 | if env_id in e:
57 | env_type = g
58 | break
59 | if ':' in env_id:
60 | env_type = re.sub(r':.*', '', env_id)
61 | assert env_type is not None, 'env_id {} is not recognized in env types {}'.format(env_id, _game_envs.keys())
62 |
63 | return env_type, env_id
64 |
65 | def obs_to_goal_fun(env):
66 | # only Fetch, Hand, Point2D and Sawyer envs are supported for now
67 | from gym.envs.robotics import FetchEnv, hand_env
68 | from multiworld.envs.mujoco.sawyer_xyz import (sawyer_push_nips,
69 | sawyer_reach)
70 | from multiworld.envs.pygame import point2d
71 |
72 | if isinstance(env.env, FetchEnv):
73 | obs_dim = env.observation_space['observation'].shape[0]
74 | goal_dim = env.observation_space['desired_goal'].shape[0]
75 | temp_dim = env.sim.data.get_site_xpos('robot0:grip').shape[0]
76 | def obs_to_goal(observation):
77 | observation = observation.reshape(-1, obs_dim)
78 | if env.has_object:
79 | goal = observation[:, temp_dim:temp_dim + goal_dim]
80 | else:
81 | goal = observation[:, :goal_dim]
82 | return goal.copy()
83 | elif isinstance(env.env, hand_env.HandEnv):
84 | goal_dim = env.observation_space['desired_goal'].shape[0]
85 | def obs_to_goal(observation):
86 | goal = observation[:, -goal_dim:]
87 | return goal.copy()
88 | elif isinstance(env.env.env, point2d.Point2DEnv):
89 | def obs_to_goal(observation):
90 | return observation.copy()
91 | elif isinstance(env.env.env, sawyer_push_nips.SawyerPushAndReachXYEnv):
92 | assert env.env.env.observation_space['observation'].shape == env.env.env.observation_space['achieved_goal'].shape, \
93 | "This environment's observation space doesn't equal goal space"
94 | def obs_to_goal(observation):
95 | return observation
96 | elif isinstance(env.env.env, sawyer_reach.SawyerReachXYZEnv):
97 | def obs_to_goal(observation):
98 | return observation
99 | else:
100 | raise NotImplementedError('Unsupported env type: {}'.format(env))
101 |
102 | return obs_to_goal
103 |
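
A small sketch of how cached_make_env and get_rewardfun are meant to be combined (the env id is only an example; the empty params dict is a placeholder, since get_rewardfun does not read it):

    import gym
    from mher.envs.env_utils import cached_make_env, get_rewardfun

    make_env = lambda: gym.make('FetchReach-v1')
    tmp_env = cached_make_env(make_env)           # first call builds the env ...
    assert tmp_env is cached_make_env(make_env)   # ... later calls return the cached instance

    reward_fun = get_rewardfun({}, tmp_env)
    # reward_fun(ag_2=achieved_goals, g=desired_goals, info={}) computes the env's
    # vectorized reward for a batch of achieved/desired goal pairs
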
--------------------------------------------------------------------------------
/mher/envs/make_env_utils.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | import sys
4 |
5 | import gym
6 | import tensorflow as tf
7 | from gym.wrappers import FilterObservation, FlattenObservation
8 | from mher.common import logger, retro_wrappers, set_global_seeds
9 | from mher.common.init_utils import init_mpi_import
10 | from mher.common.monitor import Monitor
11 | from mher.common.tf_util import get_session
12 | from mher.common.vec_env import VecEnv, VecFrameStack, VecNormalize
13 | from mher.common.vec_env.dummy_vec_env import DummyVecEnv
14 | from mher.common.vec_env.subproc_vec_env import SubprocVecEnv
15 | from mher.common.wrappers import ClipActionsWrapper
16 | from mher.envs.env_utils import get_env_type
17 |
18 | MPI = init_mpi_import()
19 |
20 | def build_env(args, _game_envs):
21 | ncpu = multiprocessing.cpu_count()
22 | if sys.platform == 'darwin': ncpu //= 2
23 | alg = args.alg
24 | seed = args.seed
25 |
26 | env_type, env_id = get_env_type(args, _game_envs)
27 | config = tf.ConfigProto(allow_soft_placement=True,
28 | intra_op_parallelism_threads=1,
29 | inter_op_parallelism_threads=1)
30 | config.gpu_options.allow_growth = True
31 | get_session(config=config)
32 |
33 | reward_scale = args.reward_scale if hasattr(args, 'reward_scale') else 1
34 | flatten_dict_observations = alg not in {'her'}
35 | env = make_vec_env(env_id, env_type, args.num_env or 1, seed,
36 | reward_scale=reward_scale,
37 | flatten_dict_observations=flatten_dict_observations)
38 |
39 | if env_type == 'mujoco':
40 | env = VecNormalize(env, use_tf=True)
41 | # build one simple env without vector wrapper
42 | tmp_env = make_env(env_id, env_type, seed=seed,
43 | reward_scale=reward_scale,
44 | flatten_dict_observations=flatten_dict_observations,
45 | logger_dir=logger.get_dir())
46 |
47 | return env, tmp_env
48 |
49 | def make_vec_env(env_id, env_type, num_env, seed,
50 | wrapper_kwargs=None,
51 | env_kwargs=None,
52 | start_index=0,
53 | reward_scale=1.0,
54 | flatten_dict_observations=True,
55 | gamestate=None,
56 | initializer=None,
57 | force_dummy=False):
58 | """
59 | Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
60 | """
61 | wrapper_kwargs = wrapper_kwargs or {}
62 | env_kwargs = env_kwargs or {}
63 | mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
64 | seed = seed + 10000 * mpi_rank if seed is not None else None
65 | logger_dir = logger.get_dir()
66 | def make_thunk(rank, initializer=None):
67 | return lambda: make_env(
68 | env_id=env_id,
69 | env_type=env_type,
70 | mpi_rank=mpi_rank,
71 | subrank=rank,
72 | seed=seed,
73 | reward_scale=reward_scale,
74 | gamestate=gamestate,
75 | flatten_dict_observations=flatten_dict_observations,
76 | wrapper_kwargs=wrapper_kwargs,
77 | env_kwargs=env_kwargs,
78 | logger_dir=logger_dir,
79 | initializer=initializer
80 | )
81 | set_global_seeds(seed)
82 | if not force_dummy and num_env > 1:
83 | return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)])
84 | else:
85 | return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)])
86 |
87 |
88 | def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None,
89 | flatten_dict_observations=True, wrapper_kwargs=None, env_kwargs=None, logger_dir=None, initializer=None):
90 | if initializer is not None:
91 | initializer(mpi_rank=mpi_rank, subrank=subrank)
92 |
93 | wrapper_kwargs = wrapper_kwargs or {}
94 | env_kwargs = env_kwargs or {}
95 | if ':' in env_id:
96 | import importlib
97 | import re
98 | module_name = re.sub(':.*','',env_id)
99 | env_id = re.sub('.*:', '', env_id)
100 | importlib.import_module(module_name)
101 |
102 | env = gym.make(env_id, **env_kwargs)
103 | # if env_id.startswith('Sawyer'):
104 | # from mher.algos.multi_world_wrapper import SawyerGoalWrapper
105 | # env = SawyerGoalWrapper(env)
106 | # if (env_id.startswith('Sawyer') or env_id.startswith('Point2D')) and not hasattr(env, '_max_episode_steps'):
107 | # env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
108 |
109 | if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
110 | env = FlattenObservation(env)
111 |
112 | env.seed(seed + subrank if seed is not None else None)
113 | env = Monitor(env,
114 | logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
115 | allow_early_resets=True)
116 |
117 | if isinstance(env.action_space, gym.spaces.Box):
118 | env = ClipActionsWrapper(env)
119 |
120 | if reward_scale != 1:
121 | env = retro_wrappers.RewardScaler(env, reward_scale)
122 | return env
123 |
124 | def make_mujoco_env(env_id, seed, reward_scale=1.0):
125 | """
126 | Create a wrapped, monitored gym.Env for MuJoCo.
127 | """
128 | rank = MPI.COMM_WORLD.Get_rank()
129 | myseed = seed + 1000 * rank if seed is not None else None
130 | set_global_seeds(myseed)
131 | env = gym.make(env_id)
132 | logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank))
133 | env = Monitor(env, logger_path, allow_early_resets=True)
134 | env.seed(seed)
135 | if reward_scale != 1.0:
136 | from mher.common.retro_wrappers import RewardScaler
137 | env = RewardScaler(env, reward_scale)
138 | return env
139 |
140 | def make_robotics_env(env_id, seed, rank=0):
141 | """
142 | Create a wrapped, monitored gym.Env for MuJoCo robotics tasks.
143 | """
144 | set_global_seeds(seed)
145 | env = gym.make(env_id)
146 | env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
147 | env = Monitor(
148 | env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
149 | info_keywords=('is_success',))
150 | env.seed(seed)
151 | return env
152 |
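
A minimal sketch of building a small vectorized env directly through make_vec_env (env id and type are only examples; flatten_dict_observations=False keeps the goal-dict observations that HER needs):

    from mher.envs.make_env_utils import make_vec_env

    # two FetchReach-v1 workers in subprocesses, keeping dict observations
    venv = make_vec_env('FetchReach-v1', 'robotics', num_env=2, seed=0,
                        flatten_dict_observations=False)
    obs = venv.reset()
    venv.close()
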
--------------------------------------------------------------------------------
/mher/envs/wrappers/__pycache__/wrapper_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/envs/wrappers/__pycache__/wrapper_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/mher/envs/wrappers/multi_world_wrapper.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import gym
4 | import multiworld
5 | import numpy as np
6 | from gym.core import Wrapper
7 |
8 |
9 | # for point env
10 | class PointGoalWrapper(Wrapper):
11 | def __init__(self, env):
12 | Wrapper.__init__(self, env=env)
13 | self.env = env
14 | self.action_space = env.action_space
15 | self.observation_space = env.observation_space
16 |
17 | def reset(self):
18 | return self.env.reset()
19 |
20 | def step(self, action):
21 | # repack the multiworld observation dict into the standard goal-env keys
22 | obs_dict, reward, done, info = self.env.step(action)
23 | obs = {
24 | 'observation':obs_dict['observation'],
25 | 'desired_goal':obs_dict['desired_goal'],
26 | 'achieved_goal':obs_dict['achieved_goal']
27 | }
28 | return obs, reward, done, info
29 |
30 | def render(self, mode='human'):
31 | return self.env.render()
32 |
33 | def compute_reward(self, achieved_goal, desired_goal, info=None):
34 | obs = {
35 | 'state_achieved_goal': achieved_goal,
36 | 'state_desired_goal':desired_goal
37 | }
38 | action = np.array([])
39 | return self.env.compute_reward(action, obs)
40 |
41 | def sample_goal(self):
42 | goal_dict = self.env.sample_goal()
43 | return goal_dict['desired_goal']
44 |
45 | # for sawyer env
46 | class SawyerGoalWrapper(Wrapper):
47 | reward_type_dict = {
48 | 'dense':'hand_distance',
49 | 'sparse':'hand_success'
50 | }
51 | observation_keys = ['observation', 'desired_goal', 'achieved_goal']
52 |
53 | def __init__(self, env, reward_type='sparse'):
54 | Wrapper.__init__(self, env=env)
55 | self.env = env
56 | self.action_space = env.action_space
57 | # observation
58 | for key in list(env.observation_space.spaces.keys()):
59 | if key not in self.observation_keys:
60 | del env.observation_space.spaces[key]
61 |
62 | self.observation_space = env.observation_space
63 | self.reward_type = reward_type
64 | self.env.reward_type = self.reward_type_dict[self.reward_type]
65 | # self.env.indicator_threshold = 0.03
66 |
67 | def reset(self):
68 | return self.env.reset()
69 |
70 | def step(self, action):
71 | obs_dict, reward, done, info = self.env.step(action)
72 | obs = {
73 | 'observation':obs_dict['observation'],
74 | 'desired_goal':obs_dict['desired_goal'],
75 | 'achieved_goal':obs_dict['achieved_goal']
76 | }
77 | if 'hand_success' in info.keys():
78 | info['is_success'] = info['hand_success']
79 | if 'success' in info.keys():
80 | info['is_success'] = info['success']
81 |
82 | return obs, reward, done, info
83 |
84 | def render(self, mode='human'):
85 | return self.env.render()
86 |
87 | def compute_reward(self, achieved_goal, desired_goal, info):
88 | obs = {
89 | 'state_achieved_goal': achieved_goal,
90 | 'state_desired_goal':desired_goal
91 | }
92 | action = np.array([])
93 | return self.env.compute_rewards(action, obs)
94 |
95 | def sample_goal(self):
96 | goal_dict = self.env.sample_goal()
97 | return goal_dict['desired_goal']
98 |
--------------------------------------------------------------------------------
/mher/envs/wrappers/wrapper_utils.py:
--------------------------------------------------------------------------------
1 |
2 | def recurse_attribute(obj, attr, max_depth=3):
3 | '''Recursively search a wrapped env (following .env) for an attribute.'''
4 | tmp_obj = obj
5 | depth = 0
6 | while depth < max_depth and not hasattr(tmp_obj, attr):
7 | tmp_obj = tmp_obj.env
8 | depth += 1
9 | if hasattr(tmp_obj, attr):
10 | return getattr(tmp_obj, attr)
11 | else:
12 | return None
13 |
14 |
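
A tiny usage sketch (the wrapper stack is only an example): gym.Wrapper does not forward private attributes, so this helper walks the .env chain explicitly.

    import gym
    from mher.common.wrappers import ClipActionsWrapper
    from mher.envs.wrappers.wrapper_utils import recurse_attribute

    env = ClipActionsWrapper(gym.make('Pendulum-v0'))   # gym.make adds a TimeLimit underneath
    # '_max_episode_steps' lives on the inner TimeLimit wrapper, not on the outer one
    max_steps = recurse_attribute(env, '_max_episode_steps', max_depth=3)
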
--------------------------------------------------------------------------------
/mher/play.py:
--------------------------------------------------------------------------------
1 | # DEPRECATED, use --play flag to mher.run instead
2 | import pickle
3 |
4 | import click
5 | import numpy as np
6 |
7 | import mher.config as config
8 | from mher.rollouts.rollout import RolloutWorker
9 | from mher.common import logger, set_global_seeds
10 | from mher.common.vec_env import VecEnv
11 |
12 |
13 | @click.command()
14 | @click.argument('policy_file', type=str)
15 | @click.option('--seed', type=int, default=0)
16 | @click.option('--n_test_rollouts', type=int, default=10)
17 | @click.option('--render', type=int, default=1)
18 |
19 | def main(policy_file, seed, n_test_rollouts, render):
20 | set_global_seeds(seed)
21 |
22 | # Load policy.
23 | with open(policy_file, 'rb') as f:
24 | policy = pickle.load(f)
25 | env_name = policy.info['env_name']
26 |
27 | # Prepare params.
28 | params = config.DEFAULT_PARAMS
29 | if env_name in config.DEFAULT_ENV_PARAMS:
30 | params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in
31 | params['env_name'] = env_name
32 | params = config.prepare_params(params)
33 | config.log_params(params, logger=logger)
34 | dims = config.configure_dims(params)
35 |
36 | eval_params = {
37 | 'exploit': True,
38 | 'use_target_net': params['test_with_polyak'],
39 | 'compute_Q': True,
40 | 'rollout_batch_size': 1,
41 | 'render': bool(render),
42 | }
43 | for name in ['T', 'gamma', 'noise_eps', 'random_eps']:
44 | eval_params[name] = params[name]
45 |
46 | evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
47 | evaluator.seed(seed)
48 |
49 | # Run evaluation.
50 | evaluator.clear_history()
51 | for _ in range(n_test_rollouts):
52 | evaluator.generate_rollouts()
53 |
54 | # record logs
55 | for key, val in evaluator.logs('test'):
56 | logger.record_tabular(key, np.mean(val))
57 | logger.dump_tabular()
58 |
59 |
60 | # playing with a model and an environment
61 | def play(model, env, episodes=1):
62 | logger.log("Running trained model")
63 | obs = env.reset()
64 | state = model.initial_state if hasattr(model, 'initial_state') else None
65 | dones = np.zeros((1,))
66 |
67 | episode_rew = np.zeros((episodes, env.num_envs)) if isinstance(env, VecEnv) else np.zeros((episodes, 1))
68 | ep_num = 0
69 | while ep_num < episodes:
70 | actions, _, _, _ = model.step(obs)
71 |
72 | obs, rew, done, _ = env.step(actions)
73 | episode_rew[ep_num] += rew
74 | env.render()
75 | done_any = done.any() if isinstance(done, np.ndarray) else done
76 | if done_any:
77 | logger.log('episode_rew={}'.format(episode_rew[ep_num]))
78 | ep_num += 1
79 | obs = env.reset()
80 | average_reward = np.mean(episode_rew)
81 | logger.log('Total average test reward:{}'.format(average_reward))
82 | return average_reward
83 |
84 |
85 | if __name__ == '__main__':
86 | main()
87 |
--------------------------------------------------------------------------------
/mher/plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import json
5 | import math
6 | from numpy.core.fromnumeric import size
7 | from numpy.lib.function_base import i0
8 | from numpy.lib.npyio import save
9 | from numpy.ma.core import right_shift
10 | import seaborn as sns; sns.set()
11 | import glob2
12 | import argparse
13 | plt.rcParams['pdf.fonttype'] = 42
14 | plt.rcParams['ps.fonttype'] = 42
15 |
16 |
17 | smooth = True
18 |
19 | def smooth_reward_curve(x, y):
20 | halfwidth = 2
21 | k = halfwidth
22 | xsmoo = x
23 | ysmoo = np.convolve(y, np.ones(2 * k + 1), mode='same') / np.convolve(np.ones_like(y), np.ones(2 * k + 1),
24 | mode='same')
25 | return xsmoo, ysmoo
26 |
27 |
28 | def load_results(file):
29 | if not os.path.exists(file):
30 | return None
31 | with open(file, 'r') as f:
32 | lines = [line for line in f]
33 | if len(lines) < 2:
34 | return None
35 | keys = [name.strip() for name in lines[0].split(',')]
36 | try:
37 | data = np.genfromtxt(file, delimiter=',', skip_header=1, filling_values=0.)
38 | except Exception:
39 | return None
40 | if data.ndim == 1:
41 | data = data.reshape(1, -1)
42 | assert data.ndim == 2
43 | assert data.shape[-1] == len(keys)
44 | result = {}
45 | for idx, key in enumerate(keys):
46 | result[key] = data[:, idx]
47 | return result
48 |
49 |
50 | def pad(xs, value=np.nan):
51 | maxlen = np.max([len(x) for x in xs])
52 |
53 | padded_xs = []
54 | for x in xs:
55 | if x.shape[0] >= maxlen:
56 | padded_xs.append(x)
57 | continue
58 | padding = np.ones((maxlen - x.shape[0],) + x.shape[1:]) * value
59 | x_padded = np.concatenate([x, padding], axis=0)
60 | assert x_padded.shape[1:] == x.shape[1:]
61 | assert x_padded.shape[0] == maxlen
62 | padded_xs.append(x_padded)
63 | return np.array(padded_xs)
64 |
65 |
66 | # Load all data.
67 | def load_data(dir, key='test/success_rate', filename='progress.csv'):
68 | data = []
69 | # find all */progress.csv under dir
70 | paths = [os.path.abspath(os.path.join(path, '..')) for path in glob2.glob(os.path.join(dir, '**', filename))]
71 | for curr_path in paths:
72 | if not os.path.isdir(curr_path):
73 | continue
74 | results = load_results(os.path.join(curr_path, filename))
75 | if not results:
76 | print('skipping {}'.format(curr_path))
77 | continue
78 | print('loading {} ({})'.format(curr_path, len(results['epoch'])))
79 |
80 | success_rate = np.array(results[key])[:50]
81 | epoch = np.array(results['epoch'])[:50] + 1
82 |
83 | # Process and smooth data.
84 | assert success_rate.shape == epoch.shape
85 | x = epoch
86 | y = success_rate
87 | if smooth:
88 | x, y = smooth_reward_curve(epoch, success_rate)
89 | assert x.shape == y.shape
90 | data.append((x, y))
91 | return data
92 |
93 | def load_datas(dirs, key='test/success_rate', filename='progress.csv'):
94 | datas = []
95 | for dir in dirs:
96 | data = load_data(dir, key, filename)
97 | datas.append(data)
98 | return datas
99 |
100 | # Plot datas
101 | def plot_datas(datas, labels, info, fontsize=15, i=0, j=0):
102 | title, xlabel, ylabel = info
103 | for data, label in zip(datas, labels):
104 | try:
105 | xs, ys = zip(*data)
106 | except Exception:
107 | continue
108 | xs, ys = pad(xs), pad(ys)
109 | assert xs.shape == ys.shape
110 |
111 | plt.plot(xs[0], np.nanmedian(ys, axis=0), label=label)
112 | plt.fill_between(xs[0], np.nanpercentile(ys, 25, axis=0), np.nanpercentile(ys, 75, axis=0), alpha=0.25)
113 | plt.title(title, fontsize=fontsize)
114 | plt.xlabel(xlabel, fontsize=fontsize)
115 | plt.ylabel(ylabel, fontsize=fontsize)
116 | plt.legend(fontsize=fontsize-3, loc=4, bbox_to_anchor=(0.5, 0.06, 0.5, 0.5))
117 | plt.xticks(fontsize=fontsize-3)
118 | plt.yticks(fontsize=fontsize-4)
119 |
120 | def plot_main(dirs, labels, info, key='test/success_rate', filename='progress.csv', save_dir='./test.png'):
121 | plt.figure(dpi=300, figsize=(5,4))
122 | datas = load_datas(dirs, key, filename)
123 |
124 | plot_datas(datas, labels, info)
125 | plt.subplots_adjust(left=0.14, right=0.98, bottom=0.15, top=0.92, hspace=0.3, wspace=0.15)
126 | plt.savefig(save_dir)
127 |
128 |
129 | if __name__ == '__main__':
130 | data_dirs = ['', '']
131 | save_dir = ''
132 | legend = ['HER', 'CHER']
133 | infos = ['title', 'Epoch', 'Median success rate']
134 | plot_main(data_dirs, legend, infos, key='test/mean_Q', save_dir=save_dir)
--------------------------------------------------------------------------------
/mher/rollouts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YangRui2015/Modular_HER/77acca83d6849d140ab893ec1b472b71e1da08d4/mher/rollouts/__init__.py
--------------------------------------------------------------------------------
/mher/run.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import multiprocessing
3 | import os
4 | import os.path as osp
5 | import re
6 | import sys
7 |
8 | import gym
9 | import numpy as np
10 | import tensorflow as tf
11 |
12 | from mher import config
13 | from mher.rollouts.rollout import RolloutWorker
14 | from mher.common import logger, set_global_seeds, tf_util
15 | from mher.common.cmd_util import preprocess_kwargs
16 | from mher.common.import_util import get_alg_module
17 | from mher.common.init_utils import init_environment_import, init_mpi_import
18 | from mher.common.logger import configure_logger
19 | from mher.envs.make_env_utils import build_env
20 | from mher.play import play
21 | from mher.train import train
22 |
23 | MPI = init_mpi_import()
24 | _game_envs = init_environment_import()
25 |
26 | def prepare(args):
27 | ## make save dir
28 | if args.save_path:
29 | os.makedirs(os.path.expanduser(args.save_path), exist_ok=True)
30 | # configure logger, disable logging in child MPI processes (with rank > 0)
31 | if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
32 | configure_logger(args.log_path)
33 | else:
34 | configure_logger(args.log_path, format_strs=[])
35 | # Seed everything.
36 | rank = MPI.COMM_WORLD.Get_rank() if MPI is not None else 0
37 | rank_seed = args.seed + 1000000 * rank if args.seed is not None else None
38 | set_global_seeds(rank_seed)
39 | return rank
40 |
41 | def main(args):
42 | # process argparse arguments and extra parameters
43 | args, extra_args = preprocess_kwargs(args)
44 | rank = prepare(args)
45 | env, tmp_env = build_env(args, _game_envs)
46 | params = config.process_params(env, tmp_env, rank, args, extra_args)
47 | dims = config.configure_dims(tmp_env, params)
48 |
49 | # define objects
50 | sampler = config.configure_sampler(dims, params)
51 | buffer = config.configure_buffer(dims, params, sampler)
52 | policy = config.configure_algorithm(dims=dims, params=params, buffer=buffer)
53 | rollout_params, eval_params = config.configure_rollout(params)
54 |
55 | if args.load_path is not None:
56 | tf_util.load_variables(args.load_path)
57 |
58 | rollout_worker = RolloutWorker(env, policy, dims, logger, monitor=True, **rollout_params)
59 | evaluator = RolloutWorker(env, policy, dims, logger, **eval_params)
60 |
61 | n_epochs = config.configure_epoch(args.num_epoch, params)
62 | policy = train(
63 | policy=policy,
64 | rollout_worker=rollout_worker,
65 | save_path=args.save_path,
66 | evaluator=evaluator,
67 | n_epochs=n_epochs,
68 | n_test_rollouts=params['n_test_rollouts'],
69 | n_cycles=params['n_cycles'],
70 | n_batches=params['n_batches'],
71 | policy_save_interval=params['policy_save_interval'],
72 | random_init=params['random_init']
73 | )
74 |
75 | if args.play_episodes or args.play_no_training:
76 | play(policy, env, episodes=args.play_episodes)
77 | env.close()
78 |
79 |
80 | if __name__ == '__main__':
81 | main(sys.argv)
82 |
--------------------------------------------------------------------------------
/mher/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from mher.samplers.sampler import RandomSampler
2 | from mher.samplers.her_sampler import HER_Sampler
3 | from mher.samplers.nstep_sampler import Nstep_Sampler, Nstep_HER_Sampler
4 | from mher.samplers.prioritized_sampler import PrioritizedSampler, PrioritizedHERSampler
5 |
6 |
7 |
--------------------------------------------------------------------------------
/mher/samplers/her_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mher.samplers.sampler import RelabelSampler
3 |
4 |
5 | class HER_Sampler(RelabelSampler):
6 | valid_strategy = ['future', 'last', 'random', 'episode', 'cut']
7 | def __init__(self, T, reward_fun, batch_size, relabel_p, strategy, *args):
8 | super(HER_Sampler, self).__init__(T, reward_fun, batch_size, relabel_p)
9 | self.strategy = strategy
10 | self.cur_L = 1
11 | self.inc_L = T / 500
12 |
13 | def _get_relabel_ag(self, episode_batch, episode_idxs, t_samples, num_episodes):
14 | relabel_indexes = self._relabel_idxs()
15 | if self.strategy == 'future' or self.strategy not in self.valid_strategy:
16 | future_offset = (np.random.uniform(size=self.batch_size) * (self.T - t_samples)).astype(int)
17 | future_t = (t_samples + 1 + future_offset)[relabel_indexes]
18 | future_ag = episode_batch['ag'][episode_idxs[relabel_indexes], future_t]
19 | elif self.strategy == 'last':
20 | future_ag = episode_batch['ag'][episode_idxs[relabel_indexes], -1]
21 | elif self.strategy == 'episode':
22 | random_t_samples = np.random.randint(self.T, size=self.batch_size)[relabel_indexes]
23 | future_ag = episode_batch['ag'][episode_idxs[relabel_indexes], random_t_samples]
24 | elif self.strategy == 'cut':
25 | # clip the future window to the current horizon cur_L (grown by inc_L below)
26 | future_offset = (np.random.uniform(size=self.batch_size) * np.minimum(int(self.cur_L), (self.T - t_samples))).astype(int)
27 | future_t = (t_samples + 1 + future_offset)[relabel_indexes]
28 | future_ag = episode_batch['ag'][episode_idxs[relabel_indexes], future_t]
29 | self.cur_L += self.inc_L
30 | else: # self.strategy == 'random'
31 | random_episode_idxs = np.random.randint(0, num_episodes, self.batch_size)[relabel_indexes]
32 | random_t_samples = np.random.randint(self.T, size=self.batch_size)[relabel_indexes]
33 | future_ag = episode_batch['ag'][random_episode_idxs, random_t_samples]
34 | return future_ag, relabel_indexes
35 |
36 | def sample(self, episode_batch):
37 | transitions, info = self._sample_transitions(episode_batch)
38 | relabel_ag, relabel_indexes = self._get_relabel_ag(episode_batch, info['episode_idxs'], info['t_samples'], info['num_episodes'])
39 | transitions = self.relabel_transition(transitions, relabel_indexes, relabel_ag)
40 | transitions = self.reshape_transitions(transitions)
41 | return transitions
42 |
43 | class ClipHER_Sampler(HER_Sampler):
44 | def __init__(self, T, reward_fun, batch_size, relabel_p, num_epoch=200, *args):
45 | super(ClipHER_Sampler, self).__init__(T, reward_fun, batch_size, relabel_p, 'future', *args)
46 | self.cur_L = 1
47 | self.inc_L = T / num_epoch
48 |
49 | def _get_relabel_ag(self, episode_batch, episode_idxs, t_samples, num_episodes):
50 | relabel_indexes = self._relabel_idxs()
51 | future_offset = (np.random.uniform(size=self.batch_size) * np.minimum(int(self.cur_L), (self.T - t_samples))).astype(int)
52 | future_t = (t_samples + 1 + future_offset)[relabel_indexes]
53 | future_ag = episode_batch['ag'][episode_idxs[relabel_indexes], future_t]
54 | return future_ag, relabel_indexes
55 |
56 | def sample(self, episode_batch):
57 | transitions, info = self._sample_transitions(episode_batch)
58 | relabel_ag, relabel_indexes = self._get_relabel_ag(episode_batch, info['episode_idxs'], info['t_samples'], info['num_episodes'])
59 | transitions = self.relabel_transition(transitions, relabel_indexes, relabel_ag)
60 | transitions = self.reshape_transitions(transitions)
61 | self.cur_L += self.inc_L
62 | return transitions
63 |
64 |
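
A worked numeric sketch of the index arithmetic behind the 'future' strategy (values are illustrative; episode_batch['ag'] is assumed to hold T+1 achieved goals per episode, as in standard HER buffers):

    import numpy as np

    T, batch_size = 5, 4
    t_samples = np.array([0, 2, 4, 1])        # timesteps drawn by _sample_transitions
    future_offset = (np.random.uniform(size=batch_size) * (T - t_samples)).astype(int)
    future_t = t_samples + 1 + future_offset  # always in (t, T], i.e. strictly after the transition
    # episode_batch['ag'][episode_idxs, future_t] then replaces the goal for the relabeled
    # subset of the batch (the boolean mask returned by _relabel_idxs)
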
--------------------------------------------------------------------------------
/mher/samplers/nstep_sampler.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numpy as np
3 |
4 | from mher.samplers.sampler import RelabelSampler
5 | from mher.samplers.her_sampler import HER_Sampler
6 |
7 |
8 | class Nstep_Sampler(RelabelSampler):
9 | def __init__(self, T, reward_fun, batch_size, replay_p, nstep, gamma, *args):
10 | super(Nstep_Sampler, self).__init__(T, reward_fun, batch_size, replay_p, *args)
11 | self.nstep = nstep
12 | self.gamma = gamma
13 |
14 | def _sample_nstep_transitions(self, episode_batch):
15 | transitions, info = self._sample_transitions(episode_batch)
16 | episode_idxs, t_samples = info['episode_idxs'], info['t_samples']
17 | transitions['r'] = self.recompute_reward(transitions)
18 | transition_lis = [transitions]
19 | nstep_masks = [np.ones(self.batch_size)]
20 | for i in range(1, self.nstep):
21 | t_samples_i = t_samples + i
22 | out_range_idxs = np.where(t_samples_i > self.T-1)
23 | t_samples_i[out_range_idxs] = self.T - 1
24 | transitions = self._get_transitions(episode_batch, episode_idxs, t_samples_i)
25 | transition_lis.append(transitions)
26 | mask = np.ones(self.batch_size) * pow(self.gamma, i)
27 | mask[out_range_idxs] = 0
28 | nstep_masks.append(mask)
29 | return transition_lis, nstep_masks, info
30 |
31 | def _recompute_nstep_reward(self, transition_lis):
32 | for i in range(len(transition_lis)):
33 | transition_lis[i]['r'] = self.recompute_reward(transition_lis[i])
34 | return transition_lis
35 |
36 | # process to get final transitions
37 | def _get_out_transitions(self, transition_lis, nstep_masks):
38 | out_transitions = copy.deepcopy(transition_lis[0])
39 | final_gamma = np.ones(self.batch_size) * pow(self.gamma, self.nstep) # gamma
40 | for i in range(1, self.nstep):
41 | out_transitions['r'] += nstep_masks[i] * transition_lis[i]['r']
42 | final_gamma[np.where((nstep_masks[i] == 0) & (final_gamma == pow(self.gamma, self.nstep)))] = pow(self.gamma, i)
43 | out_transitions['o_2'] = transition_lis[-1]['o_2'].copy()
44 | out_transitions['gamma'] = final_gamma.copy()
45 | return out_transitions
46 |
47 | def sample(self, episode_batch):
48 | transition_lis, nstep_masks, _ = self._sample_nstep_transitions(episode_batch)
49 | transition_lis = self._recompute_nstep_reward(transition_lis)
50 | out_transitions = self._get_out_transitions(transition_lis, nstep_masks)
51 | out_transitions = self.reshape_transitions(out_transitions)
52 | return out_transitions
53 |
54 |
55 | class Nstep_HER_Sampler(Nstep_Sampler, HER_Sampler):
56 | def __init__(self, T, reward_fun, batch_size, relabel_p, nstep, gamma, strategy):
57 | super().__init__(T, reward_fun, batch_size, relabel_p, nstep, gamma, strategy)
58 |
59 | def relabel_nstep_transitions(self, episode_batch, transition_lis, info):
60 | relabel_ag, relabel_indexes = self._get_relabel_ag(episode_batch, info['episode_idxs'], info['t_samples'], info['num_episodes'])
61 | for i in range(len(transition_lis)):
62 | transitions = transition_lis[i]
63 | transitions = self.relabel_transition(transitions, relabel_indexes, relabel_ag)
64 | transition_lis[i] = transitions
65 | return transition_lis
66 |
67 | def sample(self, episode_batch):
68 | transition_lis, nstep_masks, info = self._sample_nstep_transitions(episode_batch)
69 | transition_lis = self.relabel_nstep_transitions(episode_batch, transition_lis, info)
70 | out_transitions = self._get_out_transitions(transition_lis, nstep_masks)
71 | out_transitions = self.reshape_transitions(out_transitions)
72 | return out_transitions
73 |
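
A small sketch of the n-step target these masks implement, for a single in-range transition (values are illustrative):

    import numpy as np

    gamma, nstep = 0.98, 3
    r = np.array([-1.0, -1.0, 0.0])             # rewards at t, t+1, t+2
    masks = np.array([1.0, gamma, gamma ** 2])  # nstep_masks[i]; entries become 0 past the episode end

    nstep_return = np.sum(masks * r)            # r_t + gamma * r_{t+1} + gamma^2 * r_{t+2}
    final_gamma = gamma ** nstep                # discount applied to the bootstrapped value of o_{t+3}
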
--------------------------------------------------------------------------------
/mher/samplers/prioritized_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mher.samplers.sampler import Sampler
3 | from mher.samplers.her_sampler import HER_Sampler
4 | from mher.common.segment_tree import SumSegmentTree, MinSegmentTree
5 |
6 |
7 | class PrioritizedSampler(Sampler):
8 | def __init__(self, T, reward_fun, batch_size, size_in_transitions, alpha, beta, eps, *args):
9 | '''beta: float To what degree to use importance weights
10 | (0 - no corrections, 1 - full correction)'''
11 | super(PrioritizedSampler, self).__init__(T, reward_fun, batch_size, *args)
12 | assert alpha >= 0 and beta >= 0
13 | self.alpha = alpha
14 | self.beta = beta
15 | self.eps = eps
16 |
17 | capacity = 1
18 | while capacity < size_in_transitions:
19 | capacity *= 2
20 | self.sum_tree = SumSegmentTree(capacity)
21 | self.min_tree = MinSegmentTree(capacity)
22 | self.capacity = size_in_transitions
23 | self._max_priority = 1.0
24 | self.n_transitions_stored = 0
25 |
26 | def update_new_priorities(self, episode_idxs):
27 | N = len(episode_idxs) * self.T
28 | priority_array = np.zeros(N) + self._max_priority
29 | episode_idxs_repeat = (episode_idxs * self.T).repeat(self.T) + np.arange(self.T)
30 | self.update_priorities(episode_idxs_repeat, priority_array)
31 | self.n_transitions_stored += len(episode_idxs) * self.T
32 | self.n_transitions_stored = min(self.n_transitions_stored, self.capacity)
33 |
34 | def update_priorities(self, idxes, priorities):
35 | """Update priorities of sampled transitions"""
36 | assert len(idxes) == len(priorities) and np.all(priorities >= 0)
37 | priorities += self.eps # avoid zero
38 | new_priority = np.power(priorities.flatten(), self.alpha)
39 | self.sum_tree.set_items(idxes, new_priority)
40 | self.min_tree.set_items(idxes, new_priority)
41 | self._max_priority = max(np.max(priorities), self._max_priority)
42 |
43 | def _sample_idxes(self):
44 | culm_sums = np.random.random(size=self.batch_size) * self.sum_tree.sum()
45 | idxes = np.zeros(self.batch_size)
46 | for i in range(self.batch_size):
47 | idxes[i] = self.sum_tree.find_prefixsum_idx(culm_sums[i])
48 | episode_idxs = idxes // self.T
49 | t_samples = idxes % self.T
50 | return episode_idxs.astype(np.int), t_samples.astype(np.int), idxes.astype(np.int)
51 |
52 | def priority_sample(self, episode_batch):
53 | episode_idxs, t_samples, idxes = self._sample_idxes()
54 | p_min = self.min_tree.min() / self.sum_tree.sum()
55 | transitions = self._get_transitions(episode_batch, episode_idxs, t_samples)
56 | p_samples = self.sum_tree.get_items(idxes) / self.sum_tree.sum()
57 | weights = np.power(p_samples / p_min, - self.beta)
58 | transitions['w'] = weights
59 | info = {
60 | 'episode_idxs': episode_idxs,
61 | 't_samples': t_samples,
62 | 'idxes': idxes,
63 | 'num_episodes': episode_batch['u'].shape[0]
64 | }
65 | return transitions, info
66 |
67 | def sample(self, episode_batch):
68 | transitions, info = self.priority_sample(episode_batch)
69 | transitions['r'] = self.recompute_reward(transitions)
70 | transitions = self.reshape_transitions(transitions)
71 | return (transitions, info['idxes'])
72 |
73 |
74 | class PrioritizedHERSampler(PrioritizedSampler, HER_Sampler):
75 | '''relabeling after prioritized sampling tends to work poorly'''
76 | def __init__(self, T, reward_fun, batch_size, size_in_transitions, alpha, beta, eps, relabel_p, strategy):
77 | super().__init__(T, reward_fun, batch_size, size_in_transitions, alpha, beta, eps, relabel_p, strategy)
78 |
79 | def sample(self, episode_batch):
80 | transitions, info = self.priority_sample(episode_batch)
81 | relabel_ag, relabel_indexes = self._get_relabel_ag(episode_batch, info['episode_idxs'], info['t_samples'], info['num_episodes'])
82 | transitions = self.relabel_transition(transitions, relabel_indexes, relabel_ag)
83 | transitions = self.reshape_transitions(transitions)
84 | return (transitions, info['idxes'])
85 |
86 |
87 |
88 |
89 |
90 |
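
A numeric sketch of the sampling probabilities and importance weights produced in priority_sample (values are illustrative):

    import numpy as np

    alpha, beta, eps = 0.6, 0.4, 1e-5
    priorities = np.array([1.0, 0.2, 0.05]) + eps           # e.g. TD-error magnitudes
    p = priorities ** alpha / np.sum(priorities ** alpha)   # sampling probability ~ priority^alpha
    weights = (p / p.min()) ** (-beta)                      # IS correction; the rarest transition gets weight 1
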
--------------------------------------------------------------------------------
/mher/samplers/sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Sampler:
5 | def __init__(self, T, reward_fun, batch_size):
6 | self.T = T
7 | self.reward_fun = reward_fun
8 | self.batch_size = batch_size
9 |
10 | def _get_transitions(self, episode_batch, episode_idxs, t_samples):
11 | return {key: episode_batch[key][episode_idxs, t_samples].copy()
12 | for key in episode_batch.keys()}
13 |
14 | def _sample_transitions(self, episode_batch):
15 | num_episodes = episode_batch['u'].shape[0]
16 | episode_idxs = np.random.randint(0, num_episodes, self.batch_size)
17 | t_samples = np.random.randint(self.T, size=self.batch_size)
18 | transitions = self._get_transitions(episode_batch, episode_idxs, t_samples)
19 | info = {
20 | 'num_episodes': num_episodes,
21 | 'episode_idxs':episode_idxs,
22 | 't_samples':t_samples
23 | }
24 | return transitions, info
25 |
26 | def recompute_reward(self, transitions):
27 | # Reconstruct info dictionary for reward computation.
28 | info = {}
29 | for key, value in transitions.items():
30 | if key.startswith('info_'):
31 | info[key.replace('info_', '')] = value
32 | # Re-compute reward since we may have substituted the goal.
33 | reward_params = {k: transitions[k] for k in ['ag_2', 'g']}
34 | reward_params['info'] = info
35 | return self.reward_fun(**reward_params)
36 |
37 | def reshape_transitions(self, transitions):
38 | transitions = {k: transitions[k].reshape(self.batch_size, *transitions[k].shape[1:])
39 | for k in transitions.keys()}
40 | assert(transitions['u'].shape[0] == self.batch_size)
41 | return transitions
42 |
43 | def sample(self, episode_batch):
44 | raise NotImplementedError
45 |
46 | class RandomSampler(Sampler):
47 | def sample(self, episode_batch):
48 | transitions, _ = self._sample_transitions(episode_batch)
49 | transitions['r'] = self.recompute_reward(transitions)
50 | transitions = self.reshape_transitions(transitions)
51 | return transitions
52 |
53 | class RelabelSampler(Sampler):
54 | def __init__(self, T, reward_fun, batch_size, relabel_p):
55 | '''relabel_p defines the probability for relabeling'''
56 | super(RelabelSampler, self).__init__(T, reward_fun, batch_size)
57 | self.relabel_p = relabel_p
58 |
59 | def _relabel_idxs(self):
60 | return (np.random.uniform(size=self.batch_size) < self.relabel_p)
61 |
62 | def relabel_transition(self, transitions, relabel_indexes, relabel_ag):
63 | assert relabel_indexes.sum() == len(relabel_ag)
64 | transitions['g'][relabel_indexes] = relabel_ag
65 | transitions['r'] = self.recompute_reward(transitions)
66 | return transitions
67 |
--------------------------------------------------------------------------------
/mher/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 |
5 | import click
6 | import numpy as np
7 | from mpi4py import MPI
8 |
9 | import mher.config as config
10 | from mher.common import logger
11 | from mher.common.mpi_moments import mpi_moments
12 | from mher.rollouts.rollout import RolloutWorker
13 |
14 |
15 | def mpi_average(value):
16 | if not isinstance(value, list):
17 | value = [value]
18 | if not any(value):
19 | value = [0.]
20 | return mpi_moments(np.array(value))[0]
21 |
22 |
23 | def train(*, policy, rollout_worker, evaluator, n_epochs, n_test_rollouts, n_cycles,
24 | n_batches, policy_save_interval, save_path, random_init, **kwargs):
25 | rank = MPI.COMM_WORLD.Get_rank()
26 | if save_path:
27 | latest_policy_path = os.path.join(save_path, 'policy_latest.pkl')
28 | best_policy_path = os.path.join(save_path, 'policy_best.pkl')
29 | periodic_policy_path = os.path.join(save_path, 'policy_{}.pkl')
30 |
31 | # random_init buffer and o/g/u stat
32 | if random_init:
33 | logger.info('Random initializing ...')
34 | rollout_worker.clear_history()
35 | for epi in range(int(random_init) // rollout_worker.rollout_batch_size):
36 | episode = rollout_worker.generate_rollouts(random_ac=True)
37 | policy.store_episode(episode)
38 | if policy.use_dynamic_nstep and policy.n_step > 1:
39 | policy.update_dynamic_model(init=True)
40 |
41 | best_success_rate = -1
42 | logger.info('Start training...')
43 | # num_timesteps = n_epochs * n_cycles * rollout_length * number of rollout workers
44 | for epoch in range(n_epochs):
45 | time_start = time.time()
46 | # train
47 | rollout_worker.clear_history()
48 | for i in range(n_cycles):
49 | policy.dynamic_batch = False
50 | episode = rollout_worker.generate_rollouts()
51 | policy.store_episode(episode)
52 | for j in range(n_batches):
53 | policy.train()
54 | policy.update_target_net()
55 |
56 | # test
57 | evaluator.clear_history()
58 | for _ in range(n_test_rollouts):
59 | evaluator.generate_rollouts()
60 |
61 | # record logs
62 | time_end = time.time()
63 | logger.record_tabular('epoch', epoch)
64 | logger.record_tabular('epoch time(min)', (time_end - time_start)/60)
65 | for key, val in evaluator.logs('test'):
66 | logger.record_tabular(key, mpi_average(val))
67 | for key, val in rollout_worker.logs('train'):
68 | logger.record_tabular(key, mpi_average(val))
69 | for key, val in policy.logs_stats():
70 | logger.record_tabular(key, mpi_average(val))
71 |
72 | if rank == 0:
73 | logger.dump_tabular()
74 |
75 | # save the policy if it's better than the previous ones
76 | success_rate = mpi_average(evaluator.current_success_rate())
77 | if rank == 0 and success_rate > best_success_rate and save_path:
78 | best_success_rate = success_rate
79 | logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
80 | policy.save(best_policy_path)
81 | policy.save(latest_policy_path)
82 | if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_path:
83 | policy_path = periodic_policy_path.format(epoch)
84 | logger.info('Saving periodic policy to {} ...'.format(policy_path))
85 | policy.save(policy_path)
86 |
87 | # make sure that different threads have different seeds
88 | local_uniform = np.random.uniform(size=(1,))
89 | root_uniform = local_uniform.copy()
90 | MPI.COMM_WORLD.Bcast(root_uniform, root=0)
91 | if rank != 0:
92 | assert local_uniform[0] != root_uniform[0]
93 |
94 | if rank == 0 and save_path:
95 | policy_path = periodic_policy_path.format(epoch)
96 | logger.info('Saving final policy to {} ...'.format(policy_path))
97 | policy.save(policy_path)
98 |
99 | return policy
100 |
101 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import re
2 | from setuptools import setup, find_packages
3 | import sys
4 |
5 | if sys.version_info.major != 3:
6 | print('This Python is only compatible with Python 3, but you are running '
7 | 'Python {}. The installation will likely fail.'.format(sys.version_info.major))
8 |
9 |
10 | extras = {
11 | 'test': [
12 | 'filelock',
13 | 'pytest',
14 | 'pytest-forked',
15 | 'atari-py',
16 | 'matplotlib',
17 | 'pandas'
18 | ],
19 | 'mpi': [
20 | 'mpi4py'
21 | ]
22 | }
23 |
24 | all_deps = []
25 | for group_name in extras:
26 | all_deps += extras[group_name]
27 |
28 | extras['all'] = all_deps
29 |
30 | setup(name='mher',
31 | # packages=[package for package in find_packages()
32 | # if package.startswith('baselines')],
33 |
34 | packages = find_packages(),
35 | install_requires=[
36 | 'gym>=0.15.4, <0.16.0',
37 | 'scipy',
38 | 'tqdm',
39 | 'joblib',
40 | 'cloudpickle',
41 | 'click',
42 | 'opencv-python'
43 | ],
44 | extras_require=extras,
45 | description='Modular HER: based on OpenAI baselines',
46 | author='RuiYang',
47 | url='https://github.com/YangRui2015/Modular_HER',
48 | author_email='yangrui19@mails.tsinghua.edu.cn',
49 | version='1.0')
50 |
51 |
52 | # ensure there is some tensorflow build with version above 1.4
53 | import pkg_resources
54 | tf_pkg = None
55 | for tf_pkg_name in ['tensorflow', 'tensorflow-gpu', 'tf-nightly', 'tf-nightly-gpu']:
56 | try:
57 | tf_pkg = pkg_resources.get_distribution(tf_pkg_name)
58 | except pkg_resources.DistributionNotFound:
59 | pass
60 | assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
61 | from distutils.version import LooseVersion
62 | assert LooseVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= LooseVersion('1.4.0')
63 |
--------------------------------------------------------------------------------