├── .gitignore
├── README.md
├── dnc
│ ├── algos
│ │ ├── batch_polopt.py
│ │ ├── npo.py
│ │ └── trpo.py
│ ├── envs
│ │ ├── __init__.py
│ │ ├── ant.py
│ │ ├── assets
│ │ │ ├── ant.xml
│ │ │ ├── catch.xml
│ │ │ ├── jaco_meshes
│ │ │ │ ├── jaco_link_1.stl
│ │ │ │ ├── jaco_link_2.stl
│ │ │ │ ├── jaco_link_3.stl
│ │ │ │ ├── jaco_link_4.stl
│ │ │ │ ├── jaco_link_5.stl
│ │ │ │ ├── jaco_link_base.stl
│ │ │ │ ├── jaco_link_finger_1.stl
│ │ │ │ ├── jaco_link_finger_2.stl
│ │ │ │ ├── jaco_link_finger_3.stl
│ │ │ │ └── jaco_link_hand.stl
│ │ │ ├── lob.xml
│ │ │ └── picker.xml
│ │ ├── base.py
│ │ ├── catch.py
│ │ ├── lob.py
│ │ └── picker.py
│ └── sampler
│   └── policy_sampler.py
├── examples
│ ├── dnc_pick.py
│ ├── trpo_pick.py
│ └── trpo_pick_nonoise.py
├── scripts
│ └── sim_policy.py
├── setup.py
└── videos
  ├── catching.gif
  └── lobbing.gif
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | *.pkl
3 |
4 | .vscode/
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # Environments
89 | .env
90 | .venv
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # mkdocs documentation
105 | /site
106 |
107 | # mypy
108 | .mypy_cache/
109 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Divide-and-Conquer Reinforcement Learning
2 |
3 | This repository contains code accompanying the paper [Divide-and-Conquer Reinforcement Learning (Ghosh et al., ICLR 2018)](https://arxiv.org/abs/1711.09874). It includes code for the DnC algorithm and the Mujoco environments used for the empirical evaluation. Please see [the project website](http://dibyaghosh.com/dnc/) for videos and further details.
4 |
5 |
6 |
7 | ### Dependencies
8 |
9 | This codebase requires a valid installation of `rllab`. Please refer to the [rllab repository](https://github.com/rll/rllab) for installation instructions.
10 |
11 | The environments are built in Mujoco 1.31: follow the instructions [here](https://github.com/openai/mujoco-py/tree/0.5) to install Mujoco 1.31 if you have not already done so. A valid Mujoco license is required to run any of the environments.
12 |
13 | ### Usage
14 |
15 | Sample scripts for working with DnC and the provided environments can be found in the [examples](examples/) directory. In particular, a sample script for running DnC is located [here](examples/dnc_pick.py).
16 |
17 | ```bash
18 | source activate rllab_env
19 | python examples/dnc_pick.py
20 | ```
21 |
22 | Environments are located in the [dnc/envs/](dnc/envs/) directory, and the DnC implementation can be found at [dnc/algos/](dnc/algos).
23 |
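For reference, the following condensed sketch (adapted from [examples/dnc_pick.py](examples/dnc_pick.py); all class names and arguments appear in that script) shows how the environments and the DnC implementation fit together:

```python
from sandbox.rocky.tf.envs.base import TfEnv
from rllab.envs.normalized_env import normalize
from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline

import dnc.envs as dnc_envs
import dnc.algos.trpo as dnc_trpo

# Central task plus k partitions of its initial state distribution (via k-means clustering)
base_env = dnc_envs.create_stochastic('pick')
partitions = [TfEnv(normalize(e)) for e in dnc_envs.create_env_partitions(base_env, k=4)]

algo = dnc_trpo.TRPO(
    env=TfEnv(normalize(base_env)),
    partitions=partitions,
    policy_class=GaussianMLPPolicy,
    policy_kwargs=dict(min_std=1e-2, hidden_sizes=(150, 100, 50)),
    baseline_class=LinearFeatureBaseline,
    batch_size=20000,   # samples per partition per iteration
    max_path_length=50,
    n_itr=500,
)
algo.train()
```

The full example additionally sets `discount`, `step_size`, and `force_batch_sampler`, and wraps the call in rllab's `run_experiment_lite` for logging and parallel sampling.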
24 | ### Contact
25 | To ask questions or report issues, please open an issue on the [issues tracker](https://github.com/dibyaghosh/dnc/issues).
26 |
27 | ### Citing
28 |
29 | If you use DnC, please cite the following paper:
30 |
31 | - Dibya Ghosh, Avi Singh, Aravind Rajeswaran, Vikash Kumar, Sergey Levine. "[Divide-and-Conquer Reinforcement Learning](https://arxiv.org/abs/1711.09874)". _Proceedings of the International Conference on Learning Representations (ICLR), 2018._
32 |
--------------------------------------------------------------------------------
/dnc/algos/batch_polopt.py:
--------------------------------------------------------------------------------
1 | from rllab.algos.base import RLAlgorithm
2 | import rllab.misc.logger as logger
3 |
4 | from dnc.sampler.policy_sampler import Sampler
5 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 | import time
10 |
11 |
12 | class BatchPolopt(RLAlgorithm):
13 | """
14 | Base class for batch sampling-based policy optimization methods.
15 | This includes various policy gradient methods such as VPG, NPG, PPO, and TRPO.
16 | """
17 |
18 | def __init__(
19 | self,
20 | # DnC Options
21 | env,
22 | partitions,
23 | policy_class,
24 | policy_kwargs=dict(),
25 | baseline_class=LinearFeatureBaseline,
26 | baseline_kwargs=dict(),
27 | distillation_period=50,
28 | # Base BatchPolopt Options
29 | scope=None,
30 | n_itr=500,
31 | start_itr=0,
32 | batch_size=5000, # Batch size for each partition
33 | max_path_length=500,
34 | discount=0.99,
35 | gae_lambda=1,
36 | plot=False,
37 | pause_for_plot=False,
38 | center_adv=True,
39 | positive_adv=False,
40 | store_paths=False,
41 | whole_paths=True,
42 | fixed_horizon=False,
43 | force_batch_sampler=False,
44 | **kwargs
45 | ):
46 | """
47 | DnC options
48 |
49 | :param env: Central environment trying to solve
50 | :param partitions: A list of environments to use as partitions for central environment
51 | :param policy_class: The policy class to use for global and local policies (for example GaussianMLPPolicy from sandbox.rocky.tf.policies.gaussian_mlp_policy)
52 | :param policy_kwargs: A dictionary of additional parameters used for policy (beyond name and env_spec)
53 | :param baseline_class: The baseline class used for local policies (for example LinearFeatureBaseline from rllab.baselines.linear_feature_baseline)
54 | :param baseline_kwargs: A dictionary of additional parameters used for the baseline (beyond env_spec)
55 | :param distillation_period: How often (in iterations) to distill the local policies into the global policy and reset them to it
56 |
57 | Base RLLAB options
58 |
59 | :param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms simultaneously, each using different environments and policies
60 | :param n_itr: Number of iterations.
61 | :param start_itr: Starting iteration.
62 | :param batch_size: Number of samples per iteration.
63 | :param max_path_length: Maximum length of a single rollout.
64 | :param discount: Discount.
65 | :param gae_lambda: Lambda used for generalized advantage estimation.
66 | :param plot: Plot evaluation run after each iteration.
67 | :param pause_for_plot: Whether to pause before continuing when plotting.
68 | :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
69 | :param positive_adv: Whether to shift the advantages so that they are always positive. When used in
70 | conjunction with center_adv the advantages will be standardized before shifting.
71 | :param store_paths: Whether to save all paths data to the snapshot.
72 | :return:
73 | """
74 |
75 | self.env = env
76 | self.env_partitions = partitions
77 | self.n_parts = len(self.env_partitions)
78 |
79 | self.policy = policy_class(name='central_policy', env_spec=env.spec, **policy_kwargs)
80 |
81 | self.local_policies = [
82 | policy_class(name='local_policy_%d' % (n), env_spec=env.spec, **policy_kwargs) for n in range(self.n_parts)
83 | ]
84 |
85 | self.baseline = baseline_class(env_spec=env.spec, **baseline_kwargs)
86 |
87 | self.local_baselines = [
88 | baseline_class(env_spec=env.spec, **baseline_kwargs) for n in range(self.n_parts)
89 | ]
90 |
91 | self.distillation_period = distillation_period
92 |
93 | # Housekeeping
94 |
95 | self.scope = scope
96 | self.n_itr = n_itr
97 | self.start_itr = start_itr
98 | self.batch_size = batch_size
99 | self.max_path_length = max_path_length
100 | self.discount = discount
101 | self.gae_lambda = gae_lambda
102 | self.plot = plot
103 | self.pause_for_plot = pause_for_plot
104 | self.center_adv = center_adv
105 | self.positive_adv = positive_adv
106 | self.store_paths = store_paths
107 | self.whole_paths = whole_paths
108 | self.fixed_horizon = fixed_horizon
109 |
110 | self.local_samplers = [
111 | Sampler(
112 | env=env,
113 | policy=policy,
114 | baseline=baseline,
115 | scope=scope,
116 | batch_size=batch_size,
117 | max_path_length=max_path_length,
118 | discount=discount,
119 | gae_lambda=gae_lambda,
120 | center_adv=center_adv,
121 | positive_adv=positive_adv,
122 | whole_paths=whole_paths,
123 | fixed_horizon=fixed_horizon,
124 | force_batch_sampler=force_batch_sampler
125 | ) for env, policy, baseline in zip(self.env_partitions, self.local_policies, self.local_baselines)
126 | ]
127 |
128 | self.global_sampler = Sampler(
129 | env=self.env,
130 | policy=self.policy,
131 | baseline=self.baseline,
132 | scope=scope,
133 | batch_size=batch_size,
134 | max_path_length=max_path_length,
135 | discount=discount,
136 | gae_lambda=gae_lambda,
137 | center_adv=center_adv,
138 | positive_adv=positive_adv,
139 | whole_paths=whole_paths,
140 | fixed_horizon=fixed_horizon,
141 | force_batch_sampler=force_batch_sampler
142 | )
143 |
144 | self.init_opt()
145 |
146 | def start_worker(self):
147 | for sampler in self.local_samplers:
148 | sampler.start_worker()
149 | self.global_sampler.start_worker()
150 |
151 | def shutdown_worker(self):
152 | for sampler in self.local_samplers:
153 | sampler.shutdown_worker()
154 | self.global_sampler.shutdown_worker()
155 |
156 | def train(self, sess=None):
157 | if sess is None:
158 | config = tf.ConfigProto()
159 | config.gpu_options.allow_growth = True
160 | sess = tf.Session(config=config)
161 | sess.__enter__()
162 | sess.run(tf.initialize_all_variables())
163 | else:
164 | sess.run(tf.initialize_variables(list(tf.get_variable(name) for name in sess.run(tf.report_uninitialized_variables()))))
165 |
166 | self.start_worker()
167 | start_time = time.time()
168 |
169 | for itr in range(self.start_itr, self.n_itr):
170 | itr_start_time = time.time()
171 |
172 | with logger.prefix('itr #%d | ' % itr):
173 | all_paths = []
174 | logger.log("Obtaining samples...")
175 | for sampler in self.local_samplers:
176 | all_paths.append(sampler.obtain_samples(itr))
177 |
178 | logger.log("Processing samples...")
179 | all_samples_data = []
180 | for n, (sampler, paths) in enumerate(zip(self.local_samplers, all_paths)):
181 | with logger.tabular_prefix(str(n)):
182 | all_samples_data.append(sampler.process_samples(itr, paths))
183 |
184 | logger.log("Logging diagnostics...")
185 | self.log_diagnostics(all_paths)
186 |
187 | logger.log("Optimizing policy...")
188 | self.optimize_policy(itr, all_samples_data)
189 |
190 | logger.log("Saving snapshot...")
191 | params = self.get_itr_snapshot(itr, all_samples_data) # , **kwargs)
192 | logger.save_itr_params(itr, params)
193 |
194 | logger.log("Saved")
195 | logger.record_tabular('Time', time.time() - start_time)
196 | logger.record_tabular('ItrTime', time.time() - itr_start_time)
197 | logger.dump_tabular(with_prefix=False)
198 |
199 | self.shutdown_worker()
200 |
201 | def log_diagnostics(self, all_paths):
202 | for n, (env, policy, baseline, paths) in enumerate(zip(self.env_partitions, self.local_policies, self.local_baselines, all_paths)):
203 | with logger.tabular_prefix(str(n)):
204 | env.log_diagnostics(paths)
205 | policy.log_diagnostics(paths)
206 | baseline.log_diagnostics(paths)
207 |
208 | def init_opt(self):
209 | """
210 | Initialize the optimization procedure. If using tensorflow, this may
211 | include declaring all the variables and compiling functions
212 | """
213 | raise NotImplementedError()
214 |
215 | def get_itr_snapshot(self, itr, samples_data):
216 | """
217 | Returns all the data that should be saved in the snapshot for this
218 | iteration.
219 | """
220 |
221 | d = dict()
222 | for n, (policy, env) in enumerate(zip(self.local_policies, self.env_partitions)):
223 | d['policy%d' % n] = policy
224 | d['env%d' % n] = env
225 |
226 | d['policy'] = self.policy
227 | d['env'] = self.env
228 |
229 | return d
230 |
231 | def optimize_policy(self, itr, samples_data):
232 | """
233 | Runs the optimization procedure
234 | """
235 | raise NotImplementedError()
236 |
--------------------------------------------------------------------------------
/dnc/algos/npo.py:
--------------------------------------------------------------------------------
1 | # Base imports
2 | import numpy as np
3 | import tensorflow as tf
4 | from dnc.algos.batch_polopt import BatchPolopt
5 |
6 | # Optimizers
7 | from sandbox.rocky.tf.optimizers.penalty_lbfgs_optimizer import PenaltyLbfgsOptimizer
8 | from sandbox.rocky.tf.optimizers.first_order_optimizer import FirstOrderOptimizer
9 |
10 | # Utilities
11 | from rllab.misc.ext import sliced_fun
12 | from rllab.misc.overrides import overrides
13 | from sandbox.rocky.tf.misc import tensor_utils
14 | from rllab.misc import ext
15 |
16 | # Logging
17 | import rllab.misc.logger as logger
18 |
19 |
20 | # Convenience Function
21 | def default(variable, defaultValue):
22 | return variable if variable is not None else defaultValue
23 |
24 |
25 | class NPO(BatchPolopt):
26 | """
27 | Natural Policy Optimization.
28 | """
29 |
30 | def __init__(
31 | self,
32 | optimizer_class=None,
33 | optimizer_args=None,
34 | step_size=0.01,
35 | penalty=0.0,
36 | **kwargs
37 | ):
38 |
39 | self.optimizer_class = default(optimizer_class, PenaltyLbfgsOptimizer)
40 | self.optimizer_args = default(optimizer_args, dict())
41 |
42 | self.penalty = penalty
43 | self.constrain_together = penalty > 0
44 |
45 | self.step_size = step_size
46 |
47 | self.metrics = []
48 | super(NPO, self).__init__(**kwargs)
49 |
50 | @overrides
51 | def init_opt(self):
52 |
53 | ###############################
54 | #
55 | # Variable Definitions
56 | #
57 | ###############################
58 |
59 | all_task_dist_info_vars = []
60 | all_obs_vars = []
61 |
62 | for i, policy in enumerate(self.local_policies):
63 |
64 | task_obs_var = self.env_partitions[i].observation_space.new_tensor_variable('obs%d' % i, extra_dims=1)
65 | task_dist_info_vars = []
66 |
67 | for j, other_policy in enumerate(self.local_policies):
68 |
69 | state_info_vars = dict() # Not handling recurrent policies
70 | dist_info_vars = other_policy.dist_info_sym(task_obs_var, state_info_vars)
71 | task_dist_info_vars.append(dist_info_vars)
72 |
73 | all_obs_vars.append(task_obs_var)
74 | all_task_dist_info_vars.append(task_dist_info_vars)
75 |
76 | obs_var = self.env.observation_space.new_tensor_variable('obs', extra_dims=1)
77 | action_var = self.env.action_space.new_tensor_variable('action', extra_dims=1)
78 | advantage_var = tensor_utils.new_tensor('advantage', ndim=1, dtype=tf.float32)
79 |
80 | old_dist_info_vars = {
81 | k: tf.placeholder(tf.float32, shape=[None] + list(shape), name='old_%s' % k)
82 | for k, shape in self.policy.distribution.dist_info_specs
83 | }
84 |
85 | old_dist_info_vars_list = [old_dist_info_vars[k] for k in self.policy.distribution.dist_info_keys]
86 |
87 | input_list = [obs_var, action_var, advantage_var] + old_dist_info_vars_list + all_obs_vars
88 |
89 | ###############################
90 | #
91 | # Local Policy Optimization
92 | #
93 | ###############################
94 |
95 | self.optimizers = []
96 | self.metrics = []
97 |
98 | for n, policy in enumerate(self.local_policies):
99 |
100 | state_info_vars = dict()
101 | dist_info_vars = policy.dist_info_sym(obs_var, state_info_vars)
102 | dist = policy.distribution
103 |
104 | kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
105 | lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars)
106 | surr_loss = - tf.reduce_mean(lr * advantage_var)
107 |
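# Optional DnC penalty: keep this local policy close (symmetric KL, evaluated on the
# other partitions' observations) to the other local policies, so that they remain
# easy to distill into a single global policy.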
108 | if self.constrain_together:
109 | additional_loss = Metrics.kl_on_others(n, dist, all_task_dist_info_vars)
110 | else:
111 | additional_loss = tf.constant(0.0)
112 |
113 | local_loss = surr_loss + self.penalty * additional_loss
114 |
115 | kl_metric = tensor_utils.compile_function(inputs=input_list, outputs=additional_loss, log_name="KLPenalty%d" % n)
116 | self.metrics.append(kl_metric)
117 |
118 | mean_kl_constraint = tf.reduce_mean(kl)
119 |
120 | optimizer = self.optimizer_class(**self.optimizer_args)
121 | optimizer.update_opt(
122 | loss=local_loss,
123 | target=policy,
124 | leq_constraint=(mean_kl_constraint, self.step_size),
125 | inputs=input_list,
126 | constraint_name="mean_kl_%d" % n,
127 | )
128 | self.optimizers.append(optimizer)
129 |
130 | ###############################
131 | #
132 | # Global Policy Optimization
133 | #
134 | ###############################
135 |
136 | # Behaviour Cloning Loss
137 |
138 | state_info_vars = dict()
139 | center_dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
140 | behaviour_cloning_loss = tf.losses.mean_squared_error(action_var, center_dist_info_vars['mean'])
141 | self.center_optimizer = FirstOrderOptimizer(max_epochs=1, verbose=True, batch_size=1000)
142 | self.center_optimizer.update_opt(behaviour_cloning_loss, self.policy, [obs_var, action_var])
143 |
144 | # TRPO Loss
145 |
146 | kl = dist.kl_sym(old_dist_info_vars, center_dist_info_vars)
147 | lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, center_dist_info_vars)
148 | center_trpo_loss = - tf.reduce_mean(lr * advantage_var)
149 | mean_kl_constraint = tf.reduce_mean(kl)
150 |
151 | optimizer = self.optimizer_class(**self.optimizer_args)
152 | optimizer.update_opt(
153 | loss=center_trpo_loss,
154 | target=self.policy,
155 | leq_constraint=(mean_kl_constraint, self.step_size),
156 | inputs=[obs_var, action_var, advantage_var] + old_dist_info_vars_list,
157 | constraint_name="mean_kl_center",
158 | )
159 |
160 | self.center_trpo_optimizer = optimizer
161 |
162 | # Reset Local Policies to Global Policy
163 |
164 | assignment_operations = []
165 |
166 | for policy in self.local_policies:
167 | for param_local, param_center in zip(policy.get_params_internal(), self.policy.get_params_internal()):
168 | if 'std' not in param_local.name:
169 | assignment_operations.append(tf.assign(param_local, param_center))
170 |
171 | self.reset_to_center = tf.group(*assignment_operations)
172 |
173 | return dict()
174 |
175 | def optimize_local_policies(self, itr, all_samples_data):
176 |
177 | dist_info_keys = self.policy.distribution.dist_info_keys
178 | for n, optimizer in enumerate(self.optimizers):
179 |
180 | obs_act_adv_values = tuple(ext.extract(all_samples_data[n], "observations", "actions", "advantages"))
181 | dist_info_list = tuple([all_samples_data[n]["agent_infos"][k] for k in dist_info_keys])
182 | all_task_obs_values = tuple([samples_data["observations"] for samples_data in all_samples_data])
183 |
184 | all_input_values = obs_act_adv_values + dist_info_list + all_task_obs_values
185 | optimizer.optimize(all_input_values)
186 |
187 | kl_penalty = sliced_fun(self.metrics[n], 1)(all_input_values)
188 | logger.record_tabular('KLPenalty%d' % n, kl_penalty)
189 |
190 | def optimize_global_policy(self, itr, all_samples_data):
191 |
192 | all_observations = np.concatenate([samples_data['observations'] for samples_data in all_samples_data])
193 | all_actions = np.concatenate([samples_data['agent_infos']['mean'] for samples_data in all_samples_data])
194 |
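# Distillation: behaviour-clone the global policy onto the local policies' mean actions;
# run many epochs (30) on distillation iterations and a single epoch otherwise, then
# take one TRPO step on the global policy's own samples.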
195 | num_itrs = 1 if itr % self.distillation_period != 0 else 30
196 |
197 | for _ in range(num_itrs):
198 | self.center_optimizer.optimize([all_observations, all_actions])
199 |
200 | paths = self.global_sampler.obtain_samples(itr)
201 | samples_data = self.global_sampler.process_samples(itr, paths)
202 |
203 | obs_values = tuple(ext.extract(samples_data, "observations", "actions", "advantages"))
204 | dist_info_list = [samples_data["agent_infos"][k] for k in self.policy.distribution.dist_info_keys]
205 |
206 | all_input_values = obs_values + tuple(dist_info_list)
207 |
208 | self.center_trpo_optimizer.optimize(all_input_values)
209 | self.env.log_diagnostics(paths)
210 |
211 | @overrides
212 | def optimize_policy(self, itr, all_samples_data):
213 |
214 | self.optimize_local_policies(itr, all_samples_data)
215 | self.optimize_global_policy(itr, all_samples_data)
216 |
217 | if itr % self.distillation_period == 0:
218 | sess = tf.get_default_session()
219 | sess.run(self.reset_to_center)
220 | logger.log('Reset Local Policies to Global Policies')
221 |
222 | return dict()
223 |
224 | ############################
225 | #
226 | # KL Divergence
227 | #
228 | ############################
229 |
230 |
231 | class Metrics:
232 | @staticmethod
233 | def symmetric_kl(dist, info_vars_1, info_vars_2):
234 | side1 = tf.reduce_mean(dist.kl_sym(info_vars_2, info_vars_1))
235 | side2 = tf.reduce_mean(dist.kl_sym(info_vars_1, info_vars_2))
236 | return (side1 + side2) / 2
237 |
238 | @staticmethod
239 | def kl_on_others(n, dist, dist_info_vars):
240 | # \sum_{j=1} E_{\sim S_j}[D_{kl}(\pi_j || \pi_i)]
241 | if len(dist_info_vars) < 2:
242 | return 0
243 |
244 | kl_with_others = 0
245 | for i in range(len(dist_info_vars)):
246 | if i != n:
247 | kl_with_others += Metrics.symmetric_kl(dist, dist_info_vars[i][i], dist_info_vars[i][n])
248 |
249 | return kl_with_others / (len(dist_info_vars) - 1)
250 |
--------------------------------------------------------------------------------
/dnc/algos/trpo.py:
--------------------------------------------------------------------------------
1 | from dnc.algos.npo import NPO
2 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer
3 |
4 |
5 | class TRPO(NPO):
6 | """
7 | Trust Region Policy Optimization
8 |
9 | Please refer to the following classes for parameters:
10 | - dnc.algos.batch_polopt
11 | - dnc.algos.npo
12 |
13 | """
14 |
15 | def __init__(
16 | self,
17 | **kwargs
18 | ):
19 | super(TRPO, self).__init__(
20 | optimizer_class=ConjugateGradientOptimizer,
21 | optimizer_args=dict(),
22 | **kwargs
23 | )
24 |
--------------------------------------------------------------------------------
/dnc/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from dnc.envs.picker import PickerEnv
2 | from dnc.envs.lob import LobberEnv
3 | from dnc.envs.ant import AntEnv
4 | from dnc.envs.catch import CatchEnv
5 |
6 | from dnc.envs.base import create_env_partitions
7 |
8 | import numpy as np
9 |
10 | _envs = {
11 | 'pick': PickerEnv,
12 | 'lob': LobberEnv,
13 | 'catch': CatchEnv,
14 | 'ant': AntEnv
15 | }
16 |
17 | _stochastic_params = {
18 | 'pick': dict(goal_args=('noisy',(.6,.2),.1)),
19 | 'lob': dict(box_center=(0,0), box_noise=0.4),
20 | 'catch': dict(start_pos=(.1,1.7), start_noise=0.2),
21 | 'ant': dict(angle_range=(0,2*np.pi)),
22 | }
23 |
24 | _deterministic_params = {
25 | 'pick': dict(goal_args=('noisy',(.6,.2),0)),
26 | 'lob': dict(box_center=(0,0), box_noise=0),
27 | 'catch': dict(start_pos=(.1,1.7), start_noise=0),
28 | 'ant': dict(angle_range=(-1e-4,1e-4)),
29 | }
30 |
31 | def create_stochastic(name):
32 | assert name in _stochastic_params
33 | return _envs[name](**_stochastic_params[name])
34 |
35 | def create_deterministic(name):
36 | assert name in _deterministic_params
37 | return _envs[name](**_deterministic_params[name])
38 |
39 | def test_env(env, n_rolls=5, n_steps=50):
40 | for i in range(n_rolls):
41 | env.reset()
42 | for t in range(n_steps):
43 | env.step(env.action_space.sample())
44 | env.render()
45 | env.render(close=True)
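# Example (quick visual sanity check; requires a working MuJoCo viewer):
#   test_env(create_stochastic('pick'))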
--------------------------------------------------------------------------------
/dnc/envs/ant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from dnc.envs.base import KMeansEnv
4 | from rllab.envs.base import Step
5 |
6 | from rllab.core.serializable import Serializable
7 | from rllab.misc.overrides import overrides
8 | from rllab.misc import logger
9 |
10 | import os.path as osp
11 |
12 |
13 | class AntEnv(KMeansEnv, Serializable):
14 |
15 | FILE = osp.join(osp.abspath(osp.dirname(__file__)), 'assets/ant.xml')
16 |
17 | def __init__(self, angle_range=(0,2*np.pi), frame_skip=2, *args, **kwargs):
18 | self.goal = np.array([0,0])
19 | self.angle_range = angle_range
20 |
21 | super(AntEnv, self).__init__(frame_skip=frame_skip, *args, **kwargs)
22 | Serializable.__init__(self, angle_range, frame_skip, *args, **kwargs)
23 |
24 | def get_current_obs(self):
25 | current_position = self.get_body_com("torso")
26 | return np.concatenate([
27 | self.model.data.qpos.flat,
28 | self.model.data.qvel.flat,
29 | np.clip(self.model.data.cfrc_ext, -1, 1).flat,
30 | self.get_body_xmat("torso").flat,
31 | self.goal,
32 | current_position,
33 | current_position[:2]-self.goal
34 | ]).reshape(-1)
35 |
36 | def step(self, action):
37 | self.forward_dynamics(action)
38 |
39 | # Base Reward
40 | compos = self.get_body_com("torso")
41 | dist = np.linalg.norm((compos[:2] - self.goal)) / 5
42 | forward_reward = 1 - dist
43 |
44 | # Control and Contact Costs
45 | lb, ub = self.action_bounds
46 | scaling = (ub - lb) * 0.5
47 | ctrl_cost = 0.5 * 1e-2 * np.sum(np.square(action / scaling))
48 | contact_cost = 0.5 * 1e-3 * np.sum(
49 | np.square(np.clip(self.model.data.cfrc_ext, -1, 1))
50 | )
51 |
52 | reward = forward_reward - ctrl_cost - contact_cost
53 |
54 | state = self._state
55 | notdone = all([
56 | np.isfinite(state).all(),
57 | not self.touching('torso_geom','floor'),
58 | state[2] >= 0.2,
59 | state[2] <= 1.0,
60 | ])
61 |
62 | done = not notdone
63 | ob = self.get_current_obs()
64 | return Step(ob, float(reward), done, distance=dist, task=self.goal)
65 |
66 | @overrides
67 | def reset(self, init_state=None, reset_args=None, **kwargs):
68 |
69 | qpos = self.init_qpos.copy().reshape(-1)
70 | qvel = self.init_qvel.copy().reshape(-1) + np.random.uniform(low=-0.005,
71 | high=0.005, size=self.model.nv)
72 |
73 | qvel[9:12] = 0
74 |
75 | self.goal = self.propose()
76 | qpos[-7:-5] = self.goal
77 |
78 | self.set_state(qpos.reshape(-1), qvel)
79 |
80 | self.current_com = self.model.data.com_subtree[0]
81 | self.dcom = np.zeros_like(self.current_com)
82 | return self.get_current_obs()
83 |
84 | def viewer_setup(self):
85 | self.viewer.cam.trackbodyid = -1
86 | self.viewer.cam.distance = 20.0
87 | self.viewer.cam.azimuth = +90.0
88 | self.viewer.cam.elevation = -20
89 |
90 | def retrieve_centers(self,full_states):
91 | return full_states[:,15:17]
92 |
93 | def propose_original(self):
94 | angle = self.angle_range[0] + (np.random.rand()*(self.angle_range[1]-self.angle_range[0]))
95 | magnitude = 5
96 |
97 | return np.array([
98 | magnitude * np.cos(angle),
99 | magnitude * np.sin(angle)
100 | ])
101 |
102 | @overrides
103 | def log_diagnostics(self, paths,prefix=''):
104 | min_distances = np.array([
105 | np.min(path["env_infos"]['distance'])
106 | for path in paths
107 | ])
108 |
109 | final_distances = np.array([
110 | path["env_infos"]['distance'][-1]
111 | for path in paths
112 | ])
113 | avgPct = lambda x: round(np.mean(x)*100,2)
114 |
115 | logger.record_tabular(prefix+'AverageMinDistanceToGoal', np.mean(min_distances))
116 | logger.record_tabular(prefix+'MinMinDistanceToGoal', np.min(min_distances))
117 |
118 | logger.record_tabular(prefix+'AverageFinalDistanceToGoal', np.mean(final_distances))
119 | logger.record_tabular(prefix+'MinFinalDistanceToGoal', np.min(final_distances))
120 |
121 | logger.record_tabular(prefix+'PctInGoal', avgPct(final_distances < .2))
122 |
--------------------------------------------------------------------------------
/dnc/envs/assets/ant.xml:
--------------------------------------------------------------------------------
[MuJoCo XML model for the ant goal-reaching task (used by dnc.envs.ant.AntEnv); markup omitted]
--------------------------------------------------------------------------------
/dnc/envs/assets/catch.xml:
--------------------------------------------------------------------------------
[MuJoCo XML scene for the Jaco-arm catching task (used by dnc.envs.catch.CatchEnv); markup omitted]
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_1.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_2.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_3.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_4.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_4.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_5.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_5.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_base.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_base.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_finger_1.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_finger_1.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_finger_2.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_finger_2.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_finger_3.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_finger_3.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/jaco_meshes/jaco_link_hand.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_hand.stl
--------------------------------------------------------------------------------
/dnc/envs/assets/lob.xml:
--------------------------------------------------------------------------------
[MuJoCo XML scene for the Jaco-arm lobbing task (used by dnc.envs.lob.LobberEnv); markup omitted]
--------------------------------------------------------------------------------
/dnc/envs/assets/picker.xml:
--------------------------------------------------------------------------------
[MuJoCo XML scene for the Jaco-arm block-picking task (used by dnc.envs.picker.PickerEnv); markup omitted]
--------------------------------------------------------------------------------
/dnc/envs/base.py:
--------------------------------------------------------------------------------
1 | import rllab.envs.mujoco.mujoco_env as mujoco_env
2 | from rllab.core.serializable import Serializable
3 | from sklearn.cluster import KMeans
4 |
5 | import numpy as np
6 |
7 | class MujocoEnv(mujoco_env.MujocoEnv):
8 | def __init__(self, frame_skip=1, *args, **kwargs):
9 | self.bd_index = None
10 | super().__init__(*args, **kwargs)
11 | self.frame_skip = frame_skip
12 | self.geom_names_to_indices = {name:index for index,name in enumerate(self.model.geom_names)}
13 |
14 |
15 | def set_state(self, qpos, qvel):
16 | assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
17 | self.model.data.qpos = qpos
18 | self.model.data.qvel = qvel
19 | self.model._compute_subtree() # pylint: disable=W0212
20 | self.model.forward()
21 |
22 | def get_body_com(self, body_name):
23 | # Speeds up getting body positions
24 |
25 | if self.bd_index is None:
26 | self.bd_index = {name:index for index,name in enumerate(self.model.body_names)}
27 |
28 | idx = self.bd_index[body_name]
29 | return self.model.data.com_subtree[idx]
30 |
31 | def touching(self, geom1_name, geom2_name):
32 | idx1 = self.geom_names_to_indices[geom1_name]
33 | idx2 = self.geom_names_to_indices[geom2_name]
34 | for c in self.model.data.contact:
35 | if (c.geom1 == idx1 and c.geom2 == idx2) or (c.geom1 == idx2 and c.geom2 == idx1):
36 | return True
37 | return False
38 |
39 | def touching_group(self, geom1_name, geom2_names):
40 | idx1 = self.geom_names_to_indices[geom1_name]
41 | idx2s = set([self.geom_names_to_indices[geom2_name] for geom2_name in geom2_names])
42 |
43 | for c in self.model.data.contact:
44 | if (c.geom1 == idx1 and c.geom2 in idx2s) or (c.geom1 in idx2s and c.geom2 == idx1):
45 | return True
46 | return False
47 |
48 | def viewer_setup(self):
49 | """
50 | This method is called when the viewer is initialized and after every reset
51 | Optionally implement this method, if you need to tinker with camera position
52 | and so forth.
53 | """
54 | pass
55 |
56 | def get_viewer(self):
57 | if self.viewer is None:
58 | viewer = super().get_viewer()
59 | self.viewer_setup()
60 | return viewer
61 | else:
62 | return self.viewer
63 |
64 |
65 | class KMeansEnv(MujocoEnv):
66 | def __init__(self,kmeans_args=None,*args,**kwargs):
67 | if kmeans_args is None:
68 | self.kmeans = False
69 | else:
70 | self.kmeans = True
71 | self.kmeans_centers = kmeans_args['centers']
72 | self.kmeans_index = kmeans_args['index']
73 |
74 | super(KMeansEnv, self).__init__(*args, **kwargs)
75 |
76 | def propose_original(self):
77 | raise NotImplementedError()
78 |
79 | def propose_kmeans(self):
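# Rejection sampling: redraw from the original proposal distribution until the
# sample's nearest k-means center is this partition's assigned center.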
80 | while True:
81 | proposal = self.propose_original()
82 | distances = np.linalg.norm(self.kmeans_centers-proposal,axis=1)
83 | if np.argmin(distances) == self.kmeans_index:
84 | return proposal
85 |
86 | def propose(self):
87 | if self.kmeans:
88 | return self.propose_kmeans()
89 | else:
90 | return self.propose_original()
91 |
92 | def create_partitions(self,n=10000,k=3):
93 | X = np.array([self.reset() for i in range(n)])
94 | kmeans = KMeans(n_clusters=k).fit(X)
95 | return self.retrieve_centers(kmeans.cluster_centers_)
96 |
97 | def retrieve_centers(self,full_states):
98 | raise NotImplementedError()
99 |
100 | def get_param_values(self):
101 | if self.kmeans:
102 | return dict(kmeans=True, centers=self.kmeans_centers, index=self.kmeans_index)
103 | else:
104 | return dict(kmeans=False)
105 |
106 | def set_param_values(self, params):
107 | self.kmeans = params['kmeans']
108 | if self.kmeans:
109 | self.kmeans_centers = params['centers']
110 | self.kmeans_index = params['index']
111 |
112 |
113 | def create_env_partitions(env, k=4):
114 |
115 | assert isinstance(env, KMeansEnv)
116 | cluster_centers = env.create_partitions(k=k)
117 |
118 | envs = [env.clone(env) for i in range(k)]
119 | for i,local_env in enumerate(envs):
120 | local_env.kmeans = True
121 | local_env.kmeans_centers = cluster_centers
122 | local_env.kmeans_index = i
123 |
124 | return envs
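# Example (as in examples/dnc_pick.py, where env = dnc.envs.create_stochastic('pick')):
#   partitions = create_env_partitions(env, k=4)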
125 |
--------------------------------------------------------------------------------
/dnc/envs/catch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from dnc.envs.base import KMeansEnv
3 |
4 | from rllab.core.serializable import Serializable
5 | from rllab.envs.base import Step
6 | from rllab.misc import logger
7 | from rllab.misc.overrides import overrides
8 |
9 | import os.path as osp
10 |
11 |
12 | class CatchEnv(KMeansEnv, Serializable):
13 |
14 | FILE = osp.join(osp.abspath(osp.dirname(__file__)), 'assets/catch.xml')
15 |
16 | def __init__(
17 | self, start_pos=(.1,1.7), start_noise=0.2, frame_skip=2,
18 | *args, **kwargs):
19 |
20 | self.start_pos = start_pos
21 | self.start_noise = start_noise
22 |
23 | super(CatchEnv, self).__init__(frame_skip=frame_skip, *args, **kwargs)
24 | Serializable.__init__(self, start_pos, start_noise, frame_skip, *args, **kwargs)
25 |
26 | def get_current_obs(self):
27 | return np.concatenate([
28 | self.model.data.qpos.flat[:],
29 | self.model.data.qvel.flat[:],
30 | ]).reshape(-1)
31 |
32 | def step(self, action):
33 | self.forward_dynamics(action)
34 |
35 | ball_position = self.get_body_com("object")
36 |
37 | mitt_position = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3")
38 | mitt_position = mitt_position/3.
39 | difference = np.linalg.norm((ball_position - mitt_position))
40 |
41 | reward = 1 if (difference < .15 and ball_position[2] > .05) else 0
42 |
43 | done = False
44 | return Step(self.get_current_obs(), float(reward), done, distance=difference, in_hand=(reward == 1))
45 |
46 | @overrides
47 | def reset(self):
48 |
49 | qpos = self.init_qpos.copy().reshape(-1)
50 | qvel = self.init_qvel.copy().reshape(-1)
51 |
52 | qpos[6] += 1
53 | qpos[10] += 1.5
54 |
55 | qvel[:3] = [-7, 0, 0]
56 |
57 | self.current_start = self.propose() # Proposal
58 | qpos[1:3] = self.current_start
59 |
60 | self.set_state(qpos.reshape(-1), qvel)
61 |
62 | self.current_com = self.model.data.com_subtree[0]
63 | self.dcom = np.zeros_like(self.current_com)
64 | self.last_location = self.get_body_com("object")
65 |
66 | return self.get_current_obs()
67 |
68 | def retrieve_centers(self,full_states):
69 | return full_states[:, 1:3]
70 |
71 | def propose_original(self):
72 | return (np.random.rand(2) - 0.5) * 2 * self.start_noise + np.array(self.start_pos)
73 |
74 | def viewer_setup(self):
75 | self.viewer.cam.trackbodyid = -1
76 | self.viewer.cam.distance = 3.5
77 | self.viewer.cam.azimuth = -180
78 | self.viewer.cam.elevation = -20
79 |
80 | @overrides
81 | def log_diagnostics(self, paths,prefix=''):
82 | progs = np.array([
83 | np.sum(path['env_infos']['in_hand'] > 0)
84 | for path in paths
85 | ]+[1])
86 |
87 |
88 | logger.record_tabular(prefix+'PercentWithReward', 100*np.mean(progs > 0))
89 | logger.record_tabular(prefix+'PercentOK', 100*np.mean(progs > 5))
90 | logger.record_tabular(prefix+'PercentGood', 100*np.mean(progs > 15))
91 | logger.record_tabular(prefix+'AverageTimeInHand', self.frame_skip*.01*np.mean(progs[progs.nonzero()]))
92 |
93 | if __name__ == "__main__":
94 | env = CatchEnv()
95 | for itr in range(5):
96 | env.reset()
97 | for t in range(50):
98 | env.step(env.action_space.sample())
99 | env.render()
--------------------------------------------------------------------------------
/dnc/envs/lob.py:
--------------------------------------------------------------------------------
1 | from dnc.envs.base import KMeansEnv
2 | import numpy as np
3 |
4 | from rllab.core.serializable import Serializable
5 | from rllab.envs.base import Step
6 | from rllab.misc.overrides import overrides
7 | from rllab.misc import logger
8 |
9 | import os.path as osp
10 |
11 |
12 | class LobberEnv(KMeansEnv, Serializable):
13 |
14 | FILE = osp.join(osp.abspath(osp.dirname(__file__)), 'assets/lob.xml')
15 |
16 | def __init__(self, box_center=(0,0), box_noise=0.4, frame_skip=5, *args, **kwargs):
17 |
18 | self.box_center = box_center
19 | self.box_noise = box_noise
20 |
21 | super(LobberEnv, self).__init__(frame_skip=frame_skip, *args, **kwargs)
22 | Serializable.__init__(self, box_center, box_noise, frame_skip, *args, **kwargs)
23 |
24 | def get_current_obs(self):
25 | finger_com = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3")
26 | finger_com = finger_com / 3.
27 |
28 | return np.concatenate([
29 | self.model.data.qpos.flat[:],
30 | self.model.data.qvel.flat[:],
31 | finger_com,
32 | self.relativeBoxPosition,
33 | ]).reshape(-1)
34 |
35 | def step(self,action):
36 |
37 | self.model.data.ctrl = action
38 |
39 | # Taking Steps in the Environment
40 |
41 | reward = 0
42 | for _ in range(self.frame_skip):
43 | self.model.step()
44 | step_reward = self.timestep_reward()
45 | reward += step_reward
46 |
47 | # Reached the End of Trajectory
48 |
49 | done = False
50 | onGround = self.touching_group("geom_object", ["ground", "goal_wall1", "goal_wall2", "goal_wall3", "goal_wall4"])
51 | if onGround and self.numClose > 10:
52 | reward += self.final_reward()
53 | done = True
54 |
55 | ob = self.get_current_obs()
56 | new_com = self.model.data.com_subtree[0]
57 | self.dcom = new_com - self.current_com
58 | self.current_com = new_com
59 |
60 | # Recording Metrics
61 |
62 | obj_position = self.get_body_com("object")
63 | goal_position = self.get_body_com("goal")
64 | distance = np.linalg.norm((goal_position - obj_position)[:2])
65 | normalizedDistance = distance / self.init_block_goal_dist
66 |
67 | return Step(ob, float(reward), done, distance=distance, norm_distance=normalizedDistance)
68 |
69 | @overrides
70 | def reset(self):
71 | self.numClose = 0
72 |
73 | qpos = self.init_qpos.copy().reshape(-1)
74 | qvel = self.init_qvel.copy().reshape(-1) + np.random.uniform(low=-0.005,
75 | high=0.005, size=self.model.nv)
76 |
77 | qpos[1] = -1
78 | qpos[9:12] = np.array((0.6, 0.2,0.03))
79 | qvel[9:12] = 0
80 |
81 | self.relativeBoxPosition = self.propose() # Proposal
82 | qpos[-2:] += self.relativeBoxPosition
83 |
84 | self.set_state(qpos.reshape(-1), qvel)
85 |
86 | # Save initial distance between object and goal
87 | obj_position = self.get_body_com("object")
88 | goal_position = self.get_body_com("goal")
89 | self.init_block_goal_dist = np.linalg.norm(obj_position - goal_position)
90 |
91 | self.current_com = self.model.data.com_subtree[0]
92 | self.dcom = np.zeros_like(self.current_com)
93 | return self.get_current_obs()
94 |
95 | def timestep_reward(self):
96 | obj_position = self.get_body_com("object")
97 |
98 | if obj_position[2] < 0.08:
99 | return 0
100 |
101 | finger_com = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3")
102 | finger_com = finger_com / 3.
103 |
104 | vec_1 = obj_position - finger_com
105 | dist_1 = np.linalg.norm(vec_1)
106 |
107 | if dist_1 < .1 and obj_position[0] > .2:
108 | self.numClose += 1
109 | return obj_position[2]
110 | else:
111 | return 0
112 |
113 | def final_reward(self):
114 | obj_position = self.get_body_com("object")
115 | goal_position = self.get_body_com("goal")
116 |
117 | vec_2 = obj_position - goal_position
118 | dist_2 = np.linalg.norm(vec_2[:2])
119 | normalized_dist_2 = dist_2 / self.init_block_goal_dist
120 | clipped_dist_2 = min(1.0, normalized_dist_2)
121 |
122 | if dist_2 < .1:
123 | return 40
124 |
125 | reward = 1 - clipped_dist_2
126 |
127 | return 40 * reward
128 |
129 | def retrieve_centers(self,full_states):
130 | return full_states[:,16:18]-self.init_qpos.copy().reshape(-1)[-2:]
131 |
132 | def propose_original(self):
133 | return np.array(self.box_center) + 2*(np.random.random(2)-0.5)*self.box_noise
134 |
135 | def viewer_setup(self):
136 | self.viewer.cam.trackbodyid = -1
137 | self.viewer.cam.distance = 4.0
138 | self.viewer.cam.azimuth = +60.0
139 | self.viewer.cam.elevation = -30
140 |
141 | @overrides
142 | def log_diagnostics(self, paths, prefix=''):
143 |
144 | progs = np.array([
145 | path['env_infos']['norm_distance'][-1] for path in paths
146 | ])
147 |
148 | inGoal = np.array([
149 | path['env_infos']['distance'][-1] < .1 for path in paths
150 | ])
151 |
152 | avgPct = lambda x: round(np.mean(x)*100,3)
153 |
154 | logger.record_tabular(prefix+'PctInGoal', avgPct(inGoal))
155 | logger.record_tabular(prefix+'AverageFinalDistance', np.mean(progs))
156 | logger.record_tabular(prefix+'MinFinalDistance', np.min(progs ))
--------------------------------------------------------------------------------
/dnc/envs/picker.py:
--------------------------------------------------------------------------------
1 | from dnc.envs.base import KMeansEnv
2 | from rllab.core.serializable import Serializable
3 | import numpy as np
4 |
5 | from rllab.envs.base import Step
6 | from rllab.misc.overrides import overrides
7 | from rllab.misc import logger
8 |
9 | import os.path as osp
10 |
11 | class PickerEnv(KMeansEnv, Serializable):
12 | """
13 | Picking a block, where the block position is randomized over a square region
14 |
15 | goal_args is of form ('noisy', center_of_box, half-width of box)
16 |
17 | """
18 | FILE = osp.join(osp.abspath(osp.dirname(__file__)), 'assets/picker.xml')
19 |
20 | def __init__(self, goal_args=('noisy', (.6,.2), .1), frame_skip=5, *args, **kwargs):
21 |
22 | self.goal_args = goal_args
23 |
24 | super(PickerEnv, self).__init__(frame_skip=frame_skip, *args, **kwargs)
25 | Serializable.__init__(self, goal_args, frame_skip, *args, **kwargs)
26 |
27 | def get_current_obs(self):
28 | finger_com = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3")
29 | finger_com = finger_com / 3.
30 |
31 | return np.concatenate([
32 | self.model.data.qpos.flat[:],
33 | self.model.data.qvel.flat[:],
34 | finger_com,
35 | ]).reshape(-1)
36 |
37 | def step(self,action):
38 | self.model.data.ctrl = action
39 |
40 | reward = 0
41 | timesInHand = 0
42 |
43 | for _ in range(self.frame_skip):
44 | self.model.step()
45 | step_reward = self.reward()
46 | timesInHand += step_reward > 0
47 | reward += step_reward
48 |
49 | done = reward == 0 and self.numClose > 0  # Stop if the block has been flung away
50 |
51 | ob = self.get_current_obs()
52 |
53 | new_com = self.model.data.com_subtree[0]
54 | self.dcom = new_com - self.current_com
55 | self.current_com = new_com
56 |
57 | return Step(ob, float(reward), done,timeInHand=timesInHand)
58 |
59 | def reward(self):
60 | obj_position = self.get_body_com("object")
61 |
62 | if obj_position[2] < 0.08:
63 | return 0
64 |
65 | finger_com = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3")
66 | finger_com = finger_com / 3.
67 |
68 | vec_1 = obj_position - finger_com
69 | dist_1 = np.linalg.norm(vec_1)
70 |
71 | if dist_1 < .1 and obj_position[0] > .2:
72 | self.numClose += 1
73 | return obj_position[2]
74 | else:
75 | return 0
76 |
77 | def sample_position(self,goal_type,center=(0.6,0.2),noise=0):
78 | if goal_type == 'fixed':
79 | return [center[0],center[1],.03]
80 | elif goal_type == 'noisy':
81 | x,y = center
82 | return [x+(np.random.rand()-0.5)*2*noise,y+(np.random.rand()-0.5)*2*noise,.03]
83 | else:
84 | raise NotImplementedError()
85 |
86 | def retrieve_centers(self,full_states):
87 | return full_states[:,9:12]
88 |
89 | def propose_original(self):
90 | return self.sample_position(*self.goal_args)
91 |
92 | @overrides
93 | def reset(self):
94 | qpos = self.init_qpos.copy().reshape(-1)
95 | qvel = self.init_qvel.copy().reshape(-1) + np.random.uniform(low=-0.005,
96 | high=0.005, size=self.model.nv)
97 |
98 | qpos[1] = -1
99 |
100 | self.position = self.propose() # Proposal
101 | qpos[9:12] = self.position
102 | qvel[9:12] = 0
103 |
104 | self.set_state(qpos.reshape(-1), qvel)
105 |
106 | self.numClose = 0
107 |
108 | self.current_com = self.model.data.com_subtree[0]
109 | self.dcom = np.zeros_like(self.current_com)
110 | return self.get_current_obs()
111 |
112 | def viewer_setup(self):
113 | self.viewer.cam.trackbodyid = -1
114 | self.viewer.cam.distance = 4.0
115 | self.viewer.cam.azimuth = +0.0
116 | self.viewer.cam.elevation = -40
117 |
118 | @overrides
119 | def log_diagnostics(self, paths, prefix=''):
120 |
121 | timeOffGround = np.array([
122 | np.sum(path['env_infos']['timeInHand'])*.01
123 | for path in paths])
124 |
125 | timeInAir = timeOffGround[timeOffGround.nonzero()]
126 |
127 | if len(timeInAir) == 0:
128 | timeInAir = [0]
129 |
130 | avgPct = lambda x: round(np.mean(x) * 100, 2)
131 |
132 | logger.record_tabular(prefix+'PctPicked', avgPct(timeOffGround > .3))
133 | logger.record_tabular(prefix+'PctReceivedReward', avgPct(timeOffGround > 0))
134 |
135 | logger.record_tabular(prefix+'AverageTimeInAir',np.mean(timeOffGround))
136 | logger.record_tabular(prefix+'MaxTimeInAir',np.max(timeOffGround ))
--------------------------------------------------------------------------------
/dnc/sampler/policy_sampler.py:
--------------------------------------------------------------------------------
1 | from sandbox.rocky.tf.algos.batch_polopt import BatchPolopt
2 |
3 |
4 | class Sampler(BatchPolopt):
5 | """
6 | Creates a dummy class for sampling:
7 | please refer to sandbox.rocky.tf.algos.batch_polopt:BatchPolopt
8 | for all options
9 | """
10 | def init_opt(self):
11 | pass
12 |
--------------------------------------------------------------------------------
/examples/dnc_pick.py:
--------------------------------------------------------------------------------
1 | # Environment Imports
2 | from sandbox.rocky.tf.envs.base import TfEnv
3 | from rllab.envs.normalized_env import normalize
4 | import dnc.envs as dnc_envs
5 |
6 | # Algo Imports
7 |
8 | import dnc.algos.trpo as dnc_trpo
9 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
10 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
11 |
12 | # Experiment Imports
13 |
14 | from rllab.misc.instrument import stub, run_experiment_lite
15 |
16 | def run_task(args,*_):
17 |
18 | base_env = dnc_envs.create_stochastic('pick')
19 | base_partitions = dnc_envs.create_env_partitions(base_env, k=4)
20 |
21 | env = TfEnv(normalize(base_env))
22 | partitions = [TfEnv(normalize(part_env)) for part_env in base_partitions]
23 |
24 | policy_class = GaussianMLPPolicy
25 | policy_kwargs = dict(
26 | min_std=1e-2,
27 | hidden_sizes=(150, 100, 50),
28 | )
29 |
30 | baseline_class = LinearFeatureBaseline
31 |
32 | algo = dnc_trpo.TRPO(
33 | env=env,
34 | partitions=partitions,
35 | policy_class=policy_class,
36 | policy_kwargs=policy_kwargs,
37 | baseline_class=baseline_class,
38 | batch_size=20000,
39 | n_itr=500,
40 | force_batch_sampler=True,
41 | max_path_length=50,
42 | discount=1,
43 | step_size=0.02,
44 | )
45 |
46 | algo.train()
47 |
48 | run_experiment_lite(
49 | run_task,
50 | log_dir='data/dnc/pick',
51 | n_parallel=12,
52 | snapshot_mode="last",
53 | seed=1,
54 | variant=dict(),
55 | use_cloudpickle=True,
56 | )
57 |
--------------------------------------------------------------------------------
/examples/trpo_pick.py:
--------------------------------------------------------------------------------
1 | # Environment Imports
2 | from sandbox.rocky.tf.envs.base import TfEnv
3 | from rllab.envs.normalized_env import normalize
4 | import dnc.envs as dnc_envs
5 |
6 | # Algo Imports
7 |
8 | from sandbox.rocky.tf.algos.trpo import TRPO
9 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
10 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
11 |
12 | # Experiment Imports
13 |
14 | from rllab.misc.instrument import stub, run_experiment_lite
15 |
16 | def run_task(args,*_):
17 |
18 | env = TfEnv(normalize(dnc_envs.create_stochastic('pick'))) # Cannot be solved easily by TRPO
19 |
20 | policy = GaussianMLPPolicy(
21 | name="policy",
22 | env_spec=env.spec,
23 | min_std=1e-2,
24 | hidden_sizes=(150, 100, 50),
25 | )
26 |
27 | baseline = LinearFeatureBaseline(env_spec=env.spec)
28 |
29 | algo = TRPO(
30 | env=env,
31 | policy=policy,
32 | baseline=baseline,
33 | batch_size=50000,
34 | force_batch_sampler=True,
35 | max_path_length=50,
36 | discount=1,
37 | step_size=0.02,
38 | )
39 |
40 | algo.train()
41 |
42 |
43 | run_experiment_lite(
44 | run_task,
45 | log_dir='data/trpo/pick',
46 | n_parallel=12,
47 | snapshot_mode="last",
48 | seed=1,
49 | variant=dict(),
50 | use_cloudpickle=True
51 | )
52 |
--------------------------------------------------------------------------------
/examples/trpo_pick_nonoise.py:
--------------------------------------------------------------------------------
1 | # Environment Imports
2 | from sandbox.rocky.tf.envs.base import TfEnv
3 | from rllab.envs.normalized_env import normalize
4 | import dnc.envs as dnc_envs
5 |
6 | # Algo Imports
7 |
8 | from sandbox.rocky.tf.algos.trpo import TRPO
9 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
10 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
11 |
12 | # Experiment Imports
13 |
14 | from rllab.misc.instrument import stub, run_experiment_lite
15 |
16 | def run_task(args,*_):
17 |
18 | env = TfEnv(normalize(dnc_envs.create_deterministic('pick'))) # No stochasticity in initial state: should be solved easily by TRPO
19 |
20 | policy = GaussianMLPPolicy(
21 | name="policy",
22 | env_spec=env.spec,
23 | min_std=1e-2,
24 | hidden_sizes=(150, 100, 50),
25 | )
26 |
27 | baseline = LinearFeatureBaseline(env_spec=env.spec)
28 |
29 | algo = TRPO(
30 | env=env,
31 | policy=policy,
32 | baseline=baseline,
33 | batch_size=50000,
34 | force_batch_sampler=True,
35 | max_path_length=50,
36 | discount=.98,
37 | step_size=0.02,
38 | )
39 | algo.train()
40 |
41 |
42 | run_experiment_lite(
43 | run_task,
44 | log_dir='data/trpo_nonoise/pick',
45 | n_parallel=12,
46 | snapshot_mode="last",
47 | seed=1,
48 | variant=dict(),
49 | use_cloudpickle=True,
50 | )
51 |
--------------------------------------------------------------------------------
/scripts/sim_policy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import joblib
4 | import tensorflow as tf
5 |
6 | from rllab.misc.console import query_yes_no
7 | from rllab.sampler.utils import rollout
8 |
9 | if __name__ == "__main__":
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('file', type=str,
13 | help='path to the snapshot file')
14 | parser.add_argument('--number', type=int, default=-1,
15 | help='Which policy to play (-1 = center): default=-1')
16 |
17 | parser.add_argument('--max_path_length', type=int, default=1000,
18 | help='Max length of rollout')
19 | parser.add_argument('--speedup', type=float, default=1,
20 | help='Speedup')
21 | args = parser.parse_args()
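# Example invocation (hypothetical snapshot path produced by examples/dnc_pick.py):
#   python scripts/sim_policy.py data/dnc/pick/params.pkl --number 2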
22 |
23 | # If the snapshot file uses tensorflow, do:
24 | # import tensorflow as tf
25 | # with tf.Session():
26 | # [rest of the code]
27 | with tf.Session() as sess:
28 | data = joblib.load(args.file)
29 |
30 | if args.number == -1:
31 | policy = data['policy']
32 | env = data['env']
33 | else:
34 | policy = data['policy%d'%args.number]
35 | env = data['env%d'%args.number]
36 |
37 | while True:
38 | path = rollout(env, policy, max_path_length=args.max_path_length,
39 | animated=True, speedup=args.speedup)
40 | if not query_yes_no('Continue simulation?'):
41 | break
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # setup.py
2 | from setuptools import setup
3 |
4 | setup(
5 | name='dnc',
6 | version='1.0.0',
7 | packages=[],
8 | )
9 |
--------------------------------------------------------------------------------
/videos/catching.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/videos/catching.gif
--------------------------------------------------------------------------------
/videos/lobbing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/videos/lobbing.gif
--------------------------------------------------------------------------------