├── .gitignore ├── README.md ├── dnc ├── algos │ ├── batch_polopt.py │ ├── npo.py │ └── trpo.py ├── envs │ ├── __init__.py │ ├── ant.py │ ├── assets │ │ ├── ant.xml │ │ ├── catch.xml │ │ ├── jaco_meshes │ │ │ ├── jaco_link_1.stl │ │ │ ├── jaco_link_2.stl │ │ │ ├── jaco_link_3.stl │ │ │ ├── jaco_link_4.stl │ │ │ ├── jaco_link_5.stl │ │ │ ├── jaco_link_base.stl │ │ │ ├── jaco_link_finger_1.stl │ │ │ ├── jaco_link_finger_2.stl │ │ │ ├── jaco_link_finger_3.stl │ │ │ └── jaco_link_hand.stl │ │ ├── lob.xml │ │ └── picker.xml │ ├── base.py │ ├── catch.py │ ├── lob.py │ └── picker.py └── sampler │ └── policy_sampler.py ├── examples ├── dnc_pick.py ├── trpo_pick.py └── trpo_pick_nonoise.py ├── scripts └── sim_policy.py ├── setup.py └── videos ├── catching.gif └── lobbing.gif /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | *.pkl 3 | 4 | .vscode/ 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Divide-and-Conquer Reinforcement Learning 2 | 3 | This repository contains code accompanying the paper, [Divide-and-Conquer Reinforcement Learning (Ghosh et al., ICLR 2018)](https://arxiv.org/abs/1711.09874). It includes code for the DnC algorithm and the Mujoco environments used for the empirical evaluation. Please see [the project website](http://dibyaghosh.com/dnc/) for videos and further details. 4 | 5 | 6 | 7 | ### Dependencies 8 | 9 | This codebase requires a valid installation of `rllab`. Please refer to the [rllab repository](https://github.com/rll/rllab) for installation instructions.
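For reference, a conda-based `rllab` setup usually looks roughly like the sketch below. This is only an illustrative example under assumptions (the environment name `rllab_env` is chosen to match the activation command in the Usage section); consult the rllab documentation for the authoritative steps.

```bash
# Minimal sketch, assuming a conda-based rllab install; see the rllab docs for exact steps
git clone https://github.com/rll/rllab.git
cd rllab
conda env create -f environment.yml   # creates the rllab conda environment
source activate rllab_env             # substitute the name of the environment you created
```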
10 | 11 | The environments are built in Mujoco 1.31: follow the instructions [here](https://github.com/openai/mujoco-py/tree/0.5) to install Mujoco 1.31 if you have not already done so. You are required to have a Mujoco license to run any of the environments. 12 | 13 | ### Usage 14 | 15 | Sample scripts for working with DnC and the provided environments can be found in the [examples](examples/) directory. In particular, a sample script for running DnC is located [here](examples/dnc_pick.py). 16 | 17 | ```bash 18 | source activate rllab_env 19 | python examples/dnc_pick.py 20 | ``` 21 | 22 | Environments are located in the [dnc/envs/](dnc/envs/) directory, and the DnC implementation can be found at [dnc/algos/](dnc/algos). 23 | 24 | ### Contact 25 | To ask questions or report issues, please open an issue on the [issues tracker](https://github.com/dibyaghosh/dnc/issues). 26 | 27 | ### Citing 28 | 29 | If you use DnC, please cite the following paper: 30 | 31 | - Dibya Ghosh, Avi Singh, Aravind Rajeswaran, Vikash Kumar, Sergey Levine. "[Divide-and-Conquer Reinforcement Learning](https://arxiv.org/abs/1711.09874)". _Proceedings of the International Conference on Learning Representations (ICLR), 2018._ 32 | -------------------------------------------------------------------------------- /dnc/algos/batch_polopt.py: -------------------------------------------------------------------------------- 1 | from rllab.algos.base import RLAlgorithm 2 | import rllab.misc.logger as logger 3 | 4 | from dnc.sampler.policy_sampler import Sampler 5 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | import time 10 | 11 | 12 | class BatchPolopt(RLAlgorithm): 13 | """ 14 | Base class for batch sampling-based policy optimization methods. 15 | This includes various policy gradient methods like vpg, npg, ppo, trpo, etc. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | # DnC Options 21 | env, 22 | partitions, 23 | policy_class, 24 | policy_kwargs=dict(), 25 | baseline_class=LinearFeatureBaseline, 26 | baseline_kwargs=dict(), 27 | distillation_period=50, 28 | # Base BatchPolopt Options 29 | scope=None, 30 | n_itr=500, 31 | start_itr=0, 32 | batch_size=5000, # Batch size for each partition 33 | max_path_length=500, 34 | discount=0.99, 35 | gae_lambda=1, 36 | plot=False, 37 | pause_for_plot=False, 38 | center_adv=True, 39 | positive_adv=False, 40 | store_paths=False, 41 | whole_paths=True, 42 | fixed_horizon=False, 43 | force_batch_sampler=False, 44 | **kwargs 45 | ): 46 | """ 47 | DnC options 48 | 49 | :param env: The central environment to solve 50 | :param partitions: A list of environments to use as partitions of the central environment 51 | :param policy_class: The policy class to use for global and local policies (for example GaussianMLPPolicy from sandbox.rocky.tf.policies.gaussian_mlp_policy) 52 | :param policy_kwargs: A dictionary of additional parameters used for the policy (beyond name and env_spec) 53 | :param baseline_class: The baseline class used for local policies (for example LinearFeatureBaseline from rllab.baselines.linear_feature_baseline) 54 | :param baseline_kwargs: A dictionary of additional parameters used for the baseline (beyond env_spec) 55 | :param distillation_period: How often to distill local policies into the global policy, and reset 56 | 57 | Base RLLAB options 58 | 59 | :param scope: Scope for identifying the algorithm.
Must be specified if running multiple algorithms simultaneously, each using different environments and policies 60 | :param n_itr: Number of iterations. 61 | :param start_itr: Starting iteration. 62 | :param batch_size: Number of samples per iteration. 63 | :param max_path_length: Maximum length of a single rollout. 64 | :param discount: Discount. 65 | :param gae_lambda: Lambda used for generalized advantage estimation. 66 | :param plot: Plot evaluation run after each iteration. 67 | :param pause_for_plot: Whether to pause before contiuing when plotting. 68 | :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1. 69 | :param positive_adv: Whether to shift the advantages so that they are always positive. When used in 70 | conjunction with center_adv the advantages will be standardized before shifting. 71 | :param store_paths: Whether to save all paths data to the snapshot. 72 | :return: 73 | """ 74 | 75 | self.env = env 76 | self.env_partitions = partitions 77 | self.n_parts = len(self.env_partitions) 78 | 79 | self.policy = policy_class(name='central_policy', env_spec=env.spec, **policy_kwargs) 80 | 81 | self.local_policies = [ 82 | policy_class(name='local_policy_%d' % (n), env_spec=env.spec, **policy_kwargs) for n in range(self.n_parts) 83 | ] 84 | 85 | self.baseline = baseline_class(env_spec=env.spec, **baseline_kwargs) 86 | 87 | self.local_baselines = [ 88 | baseline_class(env_spec=env.spec, **baseline_kwargs) for n in range(self.n_parts) 89 | ] 90 | 91 | self.distillation_period = distillation_period 92 | 93 | # Housekeeping 94 | 95 | self.scope = scope 96 | self.n_itr = n_itr 97 | self.start_itr = start_itr 98 | self.batch_size = batch_size 99 | self.max_path_length = max_path_length 100 | self.discount = discount 101 | self.gae_lambda = gae_lambda 102 | self.plot = plot 103 | self.pause_for_plot = pause_for_plot 104 | self.center_adv = center_adv 105 | self.positive_adv = positive_adv 106 | self.store_paths = store_paths 107 | self.whole_paths = whole_paths 108 | self.fixed_horizon = fixed_horizon 109 | 110 | self.local_samplers = [ 111 | Sampler( 112 | env=env, 113 | policy=policy, 114 | baseline=baseline, 115 | scope=scope, 116 | batch_size=batch_size, 117 | max_path_length=max_path_length, 118 | discount=discount, 119 | gae_lambda=gae_lambda, 120 | center_adv=center_adv, 121 | positive_adv=positive_adv, 122 | whole_paths=whole_paths, 123 | fixed_horizon=fixed_horizon, 124 | force_batch_sampler=force_batch_sampler 125 | ) for env, policy, baseline in zip(self.env_partitions, self.local_policies, self.local_baselines) 126 | ] 127 | 128 | self.global_sampler = Sampler( 129 | env=self.env, 130 | policy=self.policy, 131 | baseline=self.baseline, 132 | scope=scope, 133 | batch_size=batch_size, 134 | max_path_length=max_path_length, 135 | discount=discount, 136 | gae_lambda=gae_lambda, 137 | center_adv=center_adv, 138 | positive_adv=positive_adv, 139 | whole_paths=whole_paths, 140 | fixed_horizon=fixed_horizon, 141 | force_batch_sampler=force_batch_sampler 142 | ) 143 | 144 | self.init_opt() 145 | 146 | def start_worker(self): 147 | for sampler in self.local_samplers: 148 | sampler.start_worker() 149 | self.global_sampler.start_worker() 150 | 151 | def shutdown_worker(self): 152 | for sampler in self.local_samplers: 153 | sampler.shutdown_worker() 154 | self.global_sampler.shutdown_worker() 155 | 156 | def train(self, sess=None): 157 | if sess is None: 158 | config = tf.ConfigProto() 159 | config.gpu_options.allow_growth = True 160 | sess 
= tf.Session(config=config) 161 | sess.__enter__() 162 | sess.run(tf.initialize_all_variables()) 163 | else: 164 | sess.run(tf.initialize_variables(list(tf.get_variable(name) for name in sess.run(tf.report_uninitialized_variables())))) 165 | 166 | self.start_worker() 167 | start_time = time.time() 168 | 169 | for itr in range(self.start_itr, self.n_itr): 170 | itr_start_time = time.time() 171 | 172 | with logger.prefix('itr #%d | ' % itr): 173 | all_paths = [] 174 | logger.log("Obtaining samples...") 175 | for sampler in self.local_samplers: 176 | all_paths.append(sampler.obtain_samples(itr)) 177 | 178 | logger.log("Processing samples...") 179 | all_samples_data = [] 180 | for n, (sampler, paths) in enumerate(zip(self.local_samplers, all_paths)): 181 | with logger.tabular_prefix(str(n)): 182 | all_samples_data.append(sampler.process_samples(itr, paths)) 183 | 184 | logger.log("Logging diagnostics...") 185 | self.log_diagnostics(all_paths,) 186 | 187 | logger.log("Optimizing policy...") 188 | self.optimize_policy(itr, all_samples_data) 189 | 190 | logger.log("Saving snapshot...") 191 | params = self.get_itr_snapshot(itr, all_samples_data) # , **kwargs) 192 | logger.save_itr_params(itr, params) 193 | 194 | logger.log("Saved") 195 | logger.record_tabular('Time', time.time() - start_time) 196 | logger.record_tabular('ItrTime', time.time() - itr_start_time) 197 | logger.dump_tabular(with_prefix=False) 198 | 199 | self.shutdown_worker() 200 | 201 | def log_diagnostics(self, all_paths): 202 | for n, (env, policy, baseline, paths) in enumerate(zip(self.env_partitions, self.local_policies, self.local_baselines, all_paths)): 203 | with logger.tabular_prefix(str(n)): 204 | env.log_diagnostics(paths) 205 | policy.log_diagnostics(paths) 206 | baseline.log_diagnostics(paths) 207 | 208 | def init_opt(self): 209 | """ 210 | Initialize the optimization procedure. If using tensorflow, this may 211 | include declaring all the variables and compiling functions 212 | """ 213 | raise NotImplementedError() 214 | 215 | def get_itr_snapshot(self, itr, samples_data): 216 | """ 217 | Returns all the data that should be saved in the snapshot for this 218 | iteration. 
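Local policies and their environments are saved as policy0/env0, policy1/env1, ..., and the central policy and environment as 'policy' and 'env'; these are the keys that scripts/sim_policy.py expects when replaying a snapshot.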
219 | """ 220 | 221 | d = dict() 222 | for n, (policy, env) in enumerate(zip(self.local_policies, self.env_partitions)): 223 | d['policy%d' % n] = policy 224 | d['env%d' % n] = env 225 | 226 | d['policy'] = self.policy 227 | d['env'] = self.env 228 | 229 | return d 230 | 231 | def optimize_policy(self, itr, samples_data): 232 | """ 233 | Runs the optimization procedure 234 | """ 235 | raise NotImplementedError() 236 | -------------------------------------------------------------------------------- /dnc/algos/npo.py: -------------------------------------------------------------------------------- 1 | # Base imports 2 | import numpy as np 3 | import tensorflow as tf 4 | from dnc.algos.batch_polopt import BatchPolopt 5 | 6 | # Optimizers 7 | from sandbox.rocky.tf.optimizers.penalty_lbfgs_optimizer import PenaltyLbfgsOptimizer 8 | from sandbox.rocky.tf.optimizers.first_order_optimizer import FirstOrderOptimizer 9 | 10 | # Utilities 11 | from rllab.misc.ext import sliced_fun 12 | from rllab.misc.overrides import overrides 13 | from sandbox.rocky.tf.misc import tensor_utils 14 | from rllab.misc import ext 15 | 16 | # Logging 17 | import rllab.misc.logger as logger 18 | 19 | 20 | # Convenience Function 21 | def default(variable, defaultValue): 22 | return variable if variable is not None else defaultValue 23 | 24 | 25 | class NPO(BatchPolopt): 26 | """ 27 | Natural Policy Optimization. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | optimizer_class=None, 33 | optimizer_args=None, 34 | step_size=0.01, 35 | penalty=0.0, 36 | **kwargs 37 | ): 38 | 39 | self.optimizer_class = default(optimizer_class, PenaltyLbfgsOptimizer) 40 | self.optimizer_args = default(optimizer_args, dict()) 41 | 42 | self.penalty = penalty 43 | self.constrain_together = penalty > 0 44 | 45 | self.step_size = step_size 46 | 47 | self.metrics = [] 48 | super(NPO, self).__init__(**kwargs) 49 | 50 | @overrides 51 | def init_opt(self): 52 | 53 | ############################### 54 | # 55 | # Variable Definitions 56 | # 57 | ############################### 58 | 59 | all_task_dist_info_vars = [] 60 | all_obs_vars = [] 61 | 62 | for i, policy in enumerate(self.local_policies): 63 | 64 | task_obs_var = self.env_partitions[i].observation_space.new_tensor_variable('obs%d' % i, extra_dims=1) 65 | task_dist_info_vars = [] 66 | 67 | for j, other_policy in enumerate(self.local_policies): 68 | 69 | state_info_vars = dict() # Not handling recurrent policies 70 | dist_info_vars = other_policy.dist_info_sym(task_obs_var, state_info_vars) 71 | task_dist_info_vars.append(dist_info_vars) 72 | 73 | all_obs_vars.append(task_obs_var) 74 | all_task_dist_info_vars.append(task_dist_info_vars) 75 | 76 | obs_var = self.env.observation_space.new_tensor_variable('obs', extra_dims=1) 77 | action_var = self.env.action_space.new_tensor_variable('action', extra_dims=1) 78 | advantage_var = tensor_utils.new_tensor('advantage', ndim=1, dtype=tf.float32) 79 | 80 | old_dist_info_vars = { 81 | k: tf.placeholder(tf.float32, shape=[None] + list(shape), name='old_%s' % k) 82 | for k, shape in self.policy.distribution.dist_info_specs 83 | } 84 | 85 | old_dist_info_vars_list = [old_dist_info_vars[k] for k in self.policy.distribution.dist_info_keys] 86 | 87 | input_list = [obs_var, action_var, advantage_var] + old_dist_info_vars_list + all_obs_vars 88 | 89 | ############################### 90 | # 91 | # Local Policy Optimization 92 | # 93 | ############################### 94 | 95 | self.optimizers = [] 96 | self.metrics = [] 97 | 98 | for n, policy in 
enumerate(self.local_policies): 99 | 100 | state_info_vars = dict() 101 | dist_info_vars = policy.dist_info_sym(obs_var, state_info_vars) 102 | dist = policy.distribution 103 | 104 | kl = dist.kl_sym(old_dist_info_vars, dist_info_vars) 105 | lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars) 106 | surr_loss = - tf.reduce_mean(lr * advantage_var) 107 | 108 | if self.constrain_together: 109 | additional_loss = Metrics.kl_on_others(n, dist, all_task_dist_info_vars) 110 | else: 111 | additional_loss = tf.constant(0.0) 112 | 113 | local_loss = surr_loss + self.penalty * additional_loss 114 | 115 | kl_metric = tensor_utils.compile_function(inputs=input_list, outputs=additional_loss, log_name="KLPenalty%d" % n) 116 | self.metrics.append(kl_metric) 117 | 118 | mean_kl_constraint = tf.reduce_mean(kl) 119 | 120 | optimizer = self.optimizer_class(**self.optimizer_args) 121 | optimizer.update_opt( 122 | loss=local_loss, 123 | target=policy, 124 | leq_constraint=(mean_kl_constraint, self.step_size), 125 | inputs=input_list, 126 | constraint_name="mean_kl_%d" % n, 127 | ) 128 | self.optimizers.append(optimizer) 129 | 130 | ############################### 131 | # 132 | # Global Policy Optimization 133 | # 134 | ############################### 135 | 136 | # Behaviour Cloning Loss 137 | 138 | state_info_vars = dict() 139 | center_dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars) 140 | behaviour_cloning_loss = tf.losses.mean_squared_error(action_var, center_dist_info_vars['mean']) 141 | self.center_optimizer = FirstOrderOptimizer(max_epochs=1, verbose=True, batch_size=1000) 142 | self.center_optimizer.update_opt(behaviour_cloning_loss, self.policy, [obs_var, action_var]) 143 | 144 | # TRPO Loss 145 | 146 | kl = dist.kl_sym(old_dist_info_vars, center_dist_info_vars) 147 | lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, center_dist_info_vars) 148 | center_trpo_loss = - tf.reduce_mean(lr * advantage_var) 149 | mean_kl_constraint = tf.reduce_mean(kl) 150 | 151 | optimizer = self.optimizer_class(**self.optimizer_args) 152 | optimizer.update_opt( 153 | loss=center_trpo_loss, 154 | target=self.policy, 155 | leq_constraint=(mean_kl_constraint, self.step_size), 156 | inputs=[obs_var, action_var, advantage_var] + old_dist_info_vars_list, 157 | constraint_name="mean_kl_center", 158 | ) 159 | 160 | self.center_trpo_optimizer = optimizer 161 | 162 | # Reset Local Policies to Global Policy 163 | 164 | assignment_operations = [] 165 | 166 | for policy in self.local_policies: 167 | for param_local, param_center in zip(policy.get_params_internal(), self.policy.get_params_internal()): 168 | if 'std' not in param_local.name: 169 | assignment_operations.append(tf.assign(param_local, param_center)) 170 | 171 | self.reset_to_center = tf.group(*assignment_operations) 172 | 173 | return dict() 174 | 175 | def optimize_local_policies(self, itr, all_samples_data): 176 | 177 | dist_info_keys = self.policy.distribution.dist_info_keys 178 | for n, optimizer in enumerate(self.optimizers): 179 | 180 | obs_act_adv_values = tuple(ext.extract(all_samples_data[n], "observations", "actions", "advantages")) 181 | dist_info_list = tuple([all_samples_data[n]["agent_infos"][k] for k in dist_info_keys]) 182 | all_task_obs_values = tuple([samples_data["observations"] for samples_data in all_samples_data]) 183 | 184 | all_input_values = obs_act_adv_values + dist_info_list + all_task_obs_values 185 | optimizer.optimize(all_input_values) 186 | 187 | kl_penalty = sliced_fun(self.metrics[n], 
1)(all_input_values) 188 | logger.record_tabular('KLPenalty%d' % n, kl_penalty) 189 | 190 | def optimize_global_policy(self, itr, all_samples_data): 191 | 192 | all_observations = np.concatenate([samples_data['observations'] for samples_data in all_samples_data]) 193 | all_actions = np.concatenate([samples_data['agent_infos']['mean'] for samples_data in all_samples_data]) 194 | 195 | num_itrs = 1 if itr % self.distillation_period != 0 else 30 196 | 197 | for _ in range(num_itrs): 198 | self.center_optimizer.optimize([all_observations, all_actions]) 199 | 200 | paths = self.global_sampler.obtain_samples(itr) 201 | samples_data = self.global_sampler.process_samples(itr, paths) 202 | 203 | obs_values = tuple(ext.extract(samples_data, "observations", "actions", "advantages")) 204 | dist_info_list = [samples_data["agent_infos"][k] for k in self.policy.distribution.dist_info_keys] 205 | 206 | all_input_values = obs_values + tuple(dist_info_list) 207 | 208 | self.center_trpo_optimizer.optimize(all_input_values) 209 | self.env.log_diagnostics(paths) 210 | 211 | @overrides 212 | def optimize_policy(self, itr, all_samples_data): 213 | 214 | self.optimize_local_policies(itr, all_samples_data) 215 | self.optimize_global_policy(itr, all_samples_data) 216 | 217 | if itr % self.distillation_period == 0: 218 | sess = tf.get_default_session() 219 | sess.run(self.reset_to_center) 220 | logger.log('Reset Local Policies to Global Policies') 221 | 222 | return dict() 223 | 224 | ############################ 225 | # 226 | # KL Divergence 227 | # 228 | ############################ 229 | 230 | 231 | class Metrics: 232 | @staticmethod 233 | def symmetric_kl(dist, info_vars_1, info_vars_2): 234 | side1 = tf.reduce_mean(dist.kl_sym(info_vars_2, info_vars_1)) 235 | side2 = tf.reduce_mean(dist.kl_sym(info_vars_1, info_vars_2)) 236 | return (side1 + side2) / 2 237 | 238 | @staticmethod 239 | def kl_on_others(n, dist, dist_info_vars): 240 | # \sum_{j=1} E_{\sim S_j}[D_{kl}(\pi_j || \pi_i)] 241 | if len(dist_info_vars) < 2: 242 | return 0 243 | 244 | kl_with_others = 0 245 | for i in range(len(dist_info_vars)): 246 | if i != n: 247 | kl_with_others += Metrics.symmetric_kl(dist, dist_info_vars[i][i], dist_info_vars[i][n]) 248 | 249 | return kl_with_others / (len(dist_info_vars) - 1) 250 | -------------------------------------------------------------------------------- /dnc/algos/trpo.py: -------------------------------------------------------------------------------- 1 | from dnc.algos.npo import NPO 2 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer 3 | 4 | 5 | class TRPO(NPO): 6 | """ 7 | Trust Region Policy Optimization 8 | 9 | Please refer to the following classes for parameters: 10 | - dnc.algos.batch_polopt 11 | - dnc.algos.npo 12 | 13 | """ 14 | 15 | def __init__( 16 | self, 17 | **kwargs 18 | ): 19 | super(TRPO, self).__init__( 20 | optimizer_class=ConjugateGradientOptimizer, 21 | optimizer_args=dict(), 22 | **kwargs 23 | ) 24 | -------------------------------------------------------------------------------- /dnc/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from dnc.envs.picker import PickerEnv 2 | from dnc.envs.lob import LobberEnv 3 | from dnc.envs.ant import AntEnv 4 | from dnc.envs.catch import CatchEnv 5 | 6 | from dnc.envs.base import create_env_partitions 7 | 8 | import numpy as np 9 | 10 | _envs = { 11 | 'pick': PickerEnv, 12 | 'lob': LobberEnv, 13 | 'catch': CatchEnv, 14 | 'ant': AntEnv 15 | } 16 | 
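# Per-task initial-state distributions: the 'stochastic' settings below give the full (noisy) task distributions used with DnC, while the 'deterministic' settings remove the initial-state noise so that a single policy can solve the task without partitioning (compare examples/trpo_pick.py and examples/trpo_pick_nonoise.py).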
17 | _stochastic_params = { 18 | 'pick': dict(goal_args=('noisy',(.6,.2),.1)), 19 | 'lob': dict(box_center=(0,0), box_noise=0.4), 20 | 'catch': dict(start_pos=(.1,1.7), start_noise=0.2), 21 | 'ant': dict(angle_range=(0,2*np.pi)), 22 | } 23 | 24 | _deterministic_params = { 25 | 'pick': dict(goal_args=('noisy',(.6,.2),0)), 26 | 'lob': dict(box_center=(0,0), box_noise=0), 27 | 'catch': dict(start_pos=(.1,1.7), start_noise=0), 28 | 'ant': dict(angle_range=(-1e-4,1e-4)), 29 | } 30 | 31 | def create_stochastic(name): 32 | assert name in _stochastic_params 33 | return _envs[name](**_stochastic_params[name]) 34 | 35 | def create_deterministic(name): 36 | assert name in _deterministic_params 37 | return _envs[name](**_deterministic_params[name]) 38 | 39 | def test_env(env, n_rolls=5, n_steps=50): 40 | for i in range(n_rolls): 41 | env.reset() 42 | for t in range(n_steps): 43 | env.step(env.action_space.sample()) 44 | env.render() 45 | env.render(close=True) -------------------------------------------------------------------------------- /dnc/envs/ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from dnc.envs.base import KMeansEnv 4 | from rllab.envs.base import Step 5 | 6 | from rllab.core.serializable import Serializable 7 | from rllab.misc.overrides import overrides 8 | from rllab.misc import logger 9 | 10 | import os.path as osp 11 | 12 | 13 | class AntEnv(KMeansEnv, Serializable): 14 | 15 | FILE = osp.join(osp.abspath(osp.dirname(__file__)), 'assets/ant.xml') 16 | 17 | def __init__(self, angle_range=(0,2*np.pi), frame_skip=2, *args, **kwargs): 18 | self.goal = np.array([0,0]) 19 | self.angle_range = angle_range 20 | 21 | super(AntEnv, self).__init__(frame_skip=frame_skip, *args, **kwargs) 22 | Serializable.__init__(self, angle_range, frame_skip, *args, **kwargs) 23 | 24 | def get_current_obs(self): 25 | current_position = self.get_body_com("torso") 26 | return np.concatenate([ 27 | self.model.data.qpos.flat, 28 | self.model.data.qvel.flat, 29 | np.clip(self.model.data.cfrc_ext, -1, 1).flat, 30 | self.get_body_xmat("torso").flat, 31 | self.goal, 32 | current_position, 33 | current_position[:2]-self.goal 34 | ]).reshape(-1) 35 | 36 | def step(self, action): 37 | self.forward_dynamics(action) 38 | 39 | # Base Reward 40 | compos = self.get_body_com("torso") 41 | dist = np.linalg.norm((compos[:2] - self.goal)) / 5 42 | forward_reward = 1 - dist 43 | 44 | # Control and Contact Costs 45 | lb, ub = self.action_bounds 46 | scaling = (ub - lb) * 0.5 47 | ctrl_cost = 0.5 * 1e-2 * np.sum(np.square(action / scaling)) 48 | contact_cost = 0.5 * 1e-3 * np.sum( 49 | np.square(np.clip(self.model.data.cfrc_ext, -1, 1)) 50 | ) 51 | 52 | reward = forward_reward - ctrl_cost - contact_cost 53 | 54 | state = self._state 55 | notdone = all([ 56 | np.isfinite(state).all(), 57 | not self.touching('torso_geom','floor'), 58 | state[2] >= 0.2, 59 | state[2] <= 1.0, 60 | ]) 61 | 62 | done = not notdone 63 | ob = self.get_current_obs() 64 | return Step(ob, float(reward), done, distance=dist, task=self.goal) 65 | 66 | @overrides 67 | def reset(self, init_state=None, reset_args=None, **kwargs): 68 | 69 | qpos = self.init_qpos.copy().reshape(-1) 70 | qvel = self.init_qvel.copy().reshape(-1) + np.random.uniform(low=-0.005, 71 | high=0.005, size=self.model.nv) 72 | 73 | qvel[9:12] = 0 74 | 75 | self.goal = self.propose() 76 | qpos[-7:-5] = self.goal 77 | 78 | self.set_state(qpos.reshape(-1), qvel) 79 | 80 | self.current_com = self.model.data.com_subtree[0] 81 | 
self.dcom = np.zeros_like(self.current_com) 82 | return self.get_current_obs() 83 | 84 | def viewer_setup(self): 85 | self.viewer.cam.trackbodyid = -1 86 | self.viewer.cam.distance = 20.0 87 | self.viewer.cam.azimuth = +90.0 88 | self.viewer.cam.elevation = -20 89 | 90 | def retrieve_centers(self,full_states): 91 | return full_states[:,15:17] 92 | 93 | def propose_original(self): 94 | angle = self.angle_range[0] + (np.random.rand()*(self.angle_range[1]-self.angle_range[0])) 95 | magnitude = 5 96 | 97 | return np.array([ 98 | magnitude * np.cos(angle), 99 | magnitude * np.sin(angle) 100 | ]) 101 | 102 | @overrides 103 | def log_diagnostics(self, paths, prefix=''): 104 | min_distances = np.array([ 105 | np.min(path["env_infos"]['distance']) 106 | for path in paths 107 | ]) 108 | 109 | final_distances = np.array([ 110 | path["env_infos"]['distance'][-1] 111 | for path in paths 112 | ]) 113 | avgPct = lambda x: round(np.mean(x)*100,2) 114 | 115 | logger.record_tabular(prefix+'AverageMinDistanceToGoal', np.mean(min_distances)) 116 | logger.record_tabular(prefix+'MinMinDistanceToGoal', np.min(min_distances)) 117 | 118 | logger.record_tabular(prefix+'AverageFinalDistanceToGoal', np.mean(final_distances)) 119 | logger.record_tabular(prefix+'MinFinalDistanceToGoal', np.min(final_distances)) 120 | 121 | logger.record_tabular(prefix+'PctInGoal', avgPct(final_distances < .2)) 122 | -------------------------------------------------------------------------------- /dnc/envs/assets/ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 91 | -------------------------------------------------------------------------------- /dnc/envs/assets/catch.xml: -------------------------------------------------------------------------------- 1 | 18 | 19 | 20 | 21 | 178 | -------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_1.stl -------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_2.stl -------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_3.stl -------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_4.stl -------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_5.stl
-------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_base.stl -------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_finger_1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_finger_1.stl -------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_finger_2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_finger_2.stl -------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_finger_3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_finger_3.stl -------------------------------------------------------------------------------- /dnc/envs/assets/jaco_meshes/jaco_link_hand.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/dnc/envs/assets/jaco_meshes/jaco_link_hand.stl -------------------------------------------------------------------------------- /dnc/envs/assets/lob.xml: -------------------------------------------------------------------------------- 1 | 18 | 19 | 20 | 21 | 183 | -------------------------------------------------------------------------------- /dnc/envs/assets/picker.xml: -------------------------------------------------------------------------------- 1 | 18 | 19 | 20 | 21 | 163 | -------------------------------------------------------------------------------- /dnc/envs/base.py: -------------------------------------------------------------------------------- 1 | import rllab.envs.mujoco.mujoco_env as mujoco_env 2 | from rllab.core.serializable import Serializable 3 | from sklearn.cluster import KMeans 4 | 5 | import numpy as np 6 | 7 | class MujocoEnv(mujoco_env.MujocoEnv): 8 | def __init__(self, frame_skip=1, *args, **kwargs): 9 | self.bd_index = None 10 | super().__init__(*args, **kwargs) 11 | self.frame_skip = frame_skip 12 | self.geom_names_to_indices = {name:index for index,name in enumerate(self.model.geom_names)} 13 | 14 | 15 | def set_state(self, qpos, qvel): 16 | assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) 17 | self.model.data.qpos = qpos 18 | self.model.data.qvel = qvel 19 | self.model._compute_subtree() # pylint: disable=W0212 20 | self.model.forward() 21 | 22 | def get_body_com(self, body_name): 23 | # Speeds up getting body positions 24 | 25 | if self.bd_index is None: 26 | self.bd_index = {name:index for index,name in enumerate(self.model.body_names)} 27 | 28 | idx = self.bd_index[body_name] 29 | return self.model.data.com_subtree[idx] 30 | 31 | def touching(self, geom1_name, geom2_name): 32 | idx1 = self.geom_names_to_indices[geom1_name] 33 | idx2 = 
self.geom_names_to_indices[geom2_name] 34 | for c in self.model.data.contact: 35 | if (c.geom1 == idx1 and c.geom2 == idx2) or (c.geom1 == idx2 and c.geom2 == idx1): 36 | return True 37 | return False 38 | 39 | def touching_group(self, geom1_name, geom2_names): 40 | idx1 = self.geom_names_to_indices[geom1_name] 41 | idx2s = set([self.geom_names_to_indices[geom2_name] for geom2_name in geom2_names]) 42 | 43 | for c in self.model.data.contact: 44 | if (c.geom1 == idx1 and c.geom2 in idx2s) or (c.geom1 in idx2s and c.geom2 == idx1): 45 | return True 46 | return False 47 | 48 | def viewer_setup(self): 49 | """ 50 | This method is called when the viewer is initialized and after every reset 51 | Optionally implement this method, if you need to tinker with camera position 52 | and so forth. 53 | """ 54 | pass 55 | 56 | def get_viewer(self): 57 | if self.viewer is None: 58 | viewer = super().get_viewer() 59 | self.viewer_setup() 60 | return viewer 61 | else: 62 | return self.viewer 63 | 64 | 65 | class KMeansEnv(MujocoEnv): 66 | def __init__(self,kmeans_args=None,*args,**kwargs): 67 | if kmeans_args is None: 68 | self.kmeans = False 69 | else: 70 | self.kmeans = True 71 | self.kmeans_centers = kmeans_args['centers'] 72 | self.kmeans_index = kmeans_args['index'] 73 | 74 | super(KMeansEnv, self).__init__(*args, **kwargs) 75 | 76 | def propose_original(self): 77 | raise NotImplementedError() 78 | 79 | def propose_kmeans(self): 80 | while True: 81 | proposal = self.propose_original() 82 | distances = np.linalg.norm(self.kmeans_centers-proposal,axis=1) 83 | if np.argmin(distances) == self.kmeans_index: 84 | return proposal 85 | 86 | def propose(self): 87 | if self.kmeans: 88 | return self.propose_kmeans() 89 | else: 90 | return self.propose_original() 91 | 92 | def create_partitions(self,n=10000,k=3): 93 | X = np.array([self.reset() for i in range(n)]) 94 | kmeans = KMeans(n_clusters=k).fit(X) 95 | return self.retrieve_centers(kmeans.cluster_centers_) 96 | 97 | def retrieve_centers(self,full_states): 98 | raise NotImplementedError() 99 | 100 | def get_param_values(self): 101 | if self.kmeans: 102 | return dict(kmeans=True, centers=self.kmeans_centers, index=self.kmeans_index) 103 | else: 104 | return dict(kmeans=False) 105 | 106 | def set_param_values(self, params): 107 | self.kmeans = params['kmeans'] 108 | if self.kmeans: 109 | self.kmeans_centers = params['centers'] 110 | self.kmeans_index = params['index'] 111 | 112 | 113 | def create_env_partitions(env, k=4): 114 | 115 | assert isinstance(env, KMeansEnv) 116 | cluster_centers = env.create_partitions(k=k) 117 | 118 | envs = [env.clone(env) for i in range(k)] 119 | for i,local_env in enumerate(envs): 120 | local_env.kmeans = True 121 | local_env.kmeans_centers = cluster_centers 122 | local_env.kmeans_index = i 123 | 124 | return envs 125 | -------------------------------------------------------------------------------- /dnc/envs/catch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dnc.envs.base import KMeansEnv 3 | 4 | from rllab.core.serializable import Serializable 5 | from rllab.envs.base import Step 6 | from rllab.misc import logger 7 | from rllab.misc.overrides import overrides 8 | 9 | import os.path as osp 10 | 11 | 12 | class CatchEnv(KMeansEnv, Serializable): 13 | 14 | FILE = osp.join(osp.abspath(osp.dirname(__file__)), 'assets/catch.xml') 15 | 16 | def __init__( 17 | self, start_pos=(.1,1.7), start_noise=0.2, frame_skip=2, 18 | *args, **kwargs): 19 | 20 | self.start_pos = 
start_pos 21 | self.start_noise = start_noise 22 | 23 | super(CatchEnv, self).__init__(frame_skip=frame_skip, *args, **kwargs) 24 | Serializable.__init__(self, start_pos, start_noise, frame_skip, *args, **kwargs) 25 | 26 | def get_current_obs(self): 27 | return np.concatenate([ 28 | self.model.data.qpos.flat[:], 29 | self.model.data.qvel.flat[:], 30 | ]).reshape(-1) 31 | 32 | def step(self, action): 33 | self.forward_dynamics(action) 34 | 35 | ball_position = self.get_body_com("object") 36 | 37 | mitt_position = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3") 38 | mitt_position = mitt_position/3. 39 | difference = np.linalg.norm((ball_position - mitt_position)) 40 | 41 | reward = 1 if (difference < .15 and ball_position[2] > .05) else 0 42 | 43 | done = False 44 | return Step(self.get_current_obs(), float(reward), done, distance=difference, in_hand=(reward == 1)) 45 | 46 | @overrides 47 | def reset(self): 48 | 49 | qpos = self.init_qpos.copy().reshape(-1) 50 | qvel = self.init_qvel.copy().reshape(-1) 51 | 52 | qpos[6] += 1 53 | qpos[10] += 1.5 54 | 55 | qvel[:3] = [-7, 0, 0] 56 | 57 | self.current_start = self.propose() # Proposal 58 | qpos[1:3] = self.current_start 59 | 60 | self.set_state(qpos.reshape(-1), qvel) 61 | 62 | self.current_com = self.model.data.com_subtree[0] 63 | self.dcom = np.zeros_like(self.current_com) 64 | self.last_location = self.get_body_com("object") 65 | 66 | return self.get_current_obs() 67 | 68 | def retrieve_centers(self,full_states): 69 | return full_states[:, 1:3] 70 | 71 | def propose_original(self): 72 | return (np.random.rand(2) - 0.5) * 2 * self.start_noise + np.array(self.start_pos) 73 | 74 | def viewer_setup(self): 75 | self.viewer.cam.trackbodyid = -1 76 | self.viewer.cam.distance = 3.5 77 | self.viewer.cam.azimuth = -180 78 | self.viewer.cam.elevation = -20 79 | 80 | @overrides 81 | def log_diagnostics(self, paths,prefix=''): 82 | progs = np.array([ 83 | np.sum(path['env_infos']['in_hand'] > 0) 84 | for path in paths 85 | ]+[1]) 86 | 87 | 88 | logger.record_tabular(prefix+'PercentWithReward', 100*np.mean(progs > 0)) 89 | logger.record_tabular(prefix+'PercentOK', 100*np.mean(progs > 5)) 90 | logger.record_tabular(prefix+'PercentGood', 100*np.mean(progs > 15)) 91 | logger.record_tabular(prefix+'AverageTimeInHand', self.frame_skip*.01*np.mean(progs[progs.nonzero()])) 92 | 93 | if __name__ == "__main__": 94 | env = CatchEnv() 95 | for itr in range(5): 96 | env.reset() 97 | for t in range(50): 98 | env.step(env.action_space.sample()) 99 | env.render() -------------------------------------------------------------------------------- /dnc/envs/lob.py: -------------------------------------------------------------------------------- 1 | from dnc.envs.base import KMeansEnv 2 | import numpy as np 3 | 4 | from rllab.core.serializable import Serializable 5 | from rllab.envs.base import Step 6 | from rllab.misc.overrides import overrides 7 | from rllab.misc import logger 8 | 9 | import os.path as osp 10 | 11 | 12 | class LobberEnv(KMeansEnv, Serializable): 13 | 14 | FILE = osp.join(osp.abspath(osp.dirname(__file__)), 'assets/lob.xml') 15 | 16 | def __init__(self, box_center=(0,0), box_noise=0.4, frame_skip=5, *args, **kwargs): 17 | 18 | self.box_center = box_center 19 | self.box_noise = box_noise 20 | 21 | super(LobberEnv, self).__init__(frame_skip=frame_skip, *args, **kwargs) 22 | Serializable.__init__(self, box_center, box_noise, frame_skip, *args, **kwargs) 23 | 24 | def 
get_current_obs(self): 25 | finger_com = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3") 26 | finger_com = finger_com / 3. 27 | 28 | return np.concatenate([ 29 | self.model.data.qpos.flat[:], 30 | self.model.data.qvel.flat[:], 31 | finger_com, 32 | self.relativeBoxPosition, 33 | ]).reshape(-1) 34 | 35 | def step(self,action): 36 | 37 | self.model.data.ctrl = action 38 | 39 | # Taking Steps in the Environment 40 | 41 | reward = 0 42 | for _ in range(self.frame_skip): 43 | self.model.step() 44 | step_reward = self.timestep_reward() 45 | reward += step_reward 46 | 47 | # Reached the End of Trajectory 48 | 49 | done = False 50 | onGround = self.touching_group("geom_object", ["ground", "goal_wall1", "goal_wall2", "goal_wall3", "goal_wall4"]) 51 | if onGround and self.numClose > 10: 52 | reward += self.final_reward() 53 | done = True 54 | 55 | ob = self.get_current_obs() 56 | new_com = self.model.data.com_subtree[0] 57 | self.dcom = new_com - self.current_com 58 | self.current_com = new_com 59 | 60 | # Recording Metrics 61 | 62 | obj_position = self.get_body_com("object") 63 | goal_position = self.get_body_com("goal") 64 | distance = np.linalg.norm((goal_position - obj_position)[:2]) 65 | normalizedDistance = distance / self.init_block_goal_dist 66 | 67 | return Step(ob, float(reward), done, distance=distance, norm_distance=normalizedDistance) 68 | 69 | @overrides 70 | def reset(self): 71 | self.numClose = 0 72 | 73 | qpos = self.init_qpos.copy().reshape(-1) 74 | qvel = self.init_qvel.copy().reshape(-1) + np.random.uniform(low=-0.005, 75 | high=0.005, size=self.model.nv) 76 | 77 | qpos[1] = -1 78 | qpos[9:12] = np.array((0.6, 0.2,0.03)) 79 | qvel[9:12] = 0 80 | 81 | self.relativeBoxPosition = self.propose() # Proposal 82 | qpos[-2:] += self.relativeBoxPosition 83 | 84 | self.set_state(qpos.reshape(-1), qvel) 85 | 86 | # Save initial distance between object and goal 87 | obj_position = self.get_body_com("object") 88 | goal_position = self.get_body_com("goal") 89 | self.init_block_goal_dist = np.linalg.norm(obj_position - goal_position) 90 | 91 | self.current_com = self.model.data.com_subtree[0] 92 | self.dcom = np.zeros_like(self.current_com) 93 | return self.get_current_obs() 94 | 95 | def timestep_reward(self): 96 | obj_position = self.get_body_com("object") 97 | 98 | if obj_position[2] < 0.08: 99 | return 0 100 | 101 | finger_com = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3") 102 | finger_com = finger_com / 3. 
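# Shaped lift reward (called from the step loop above): while the object is held within 0.1 of the fingertip centroid and past x = 0.2, each step is rewarded with the object's height and the close-contact counter is incremented; otherwise the step earns no reward.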
103 | 104 | vec_1 = obj_position - finger_com 105 | dist_1 = np.linalg.norm(vec_1) 106 | 107 | if dist_1 < .1 and obj_position[0] > .2: 108 | self.numClose += 1 109 | return obj_position[2] 110 | else: 111 | return 0 112 | 113 | def final_reward(self): 114 | obj_position = self.get_body_com("object") 115 | goal_position = self.get_body_com("goal") 116 | 117 | vec_2 = obj_position - goal_position 118 | dist_2 = np.linalg.norm(vec_2[:2]) 119 | normalized_dist_2 = dist_2 / self.init_block_goal_dist 120 | clipped_dist_2 = min(1.0, normalized_dist_2) 121 | 122 | if dist_2 < .1: 123 | return 40 124 | 125 | reward = 1 - clipped_dist_2 126 | 127 | return 40 * reward 128 | 129 | def retrieve_centers(self,full_states): 130 | return full_states[:,16:18]-self.init_qpos.copy().reshape(-1)[-2:] 131 | 132 | def propose_original(self): 133 | return np.array(self.box_center) + 2*(np.random.random(2)-0.5)*self.box_noise 134 | 135 | def viewer_setup(self): 136 | self.viewer.cam.trackbodyid = -1 137 | self.viewer.cam.distance = 4.0 138 | self.viewer.cam.azimuth = +60.0 139 | self.viewer.cam.elevation = -30 140 | 141 | @overrides 142 | def log_diagnostics(self, paths, prefix=''): 143 | 144 | progs = np.array([ 145 | path['env_infos']['norm_distance'][-1] for path in paths 146 | ]) 147 | 148 | inGoal = np.array([ 149 | path['env_infos']['distance'][-1] < .1 for path in paths 150 | ]) 151 | 152 | avgPct = lambda x: round(np.mean(x)*100,3) 153 | 154 | logger.record_tabular(prefix+'PctInGoal', avgPct(inGoal)) 155 | logger.record_tabular(prefix+'AverageFinalDistance', np.mean(progs)) 156 | logger.record_tabular(prefix+'MinFinalDistance', np.min(progs )) -------------------------------------------------------------------------------- /dnc/envs/picker.py: -------------------------------------------------------------------------------- 1 | from dnc.envs.base import KMeansEnv 2 | from rllab.core.serializable import Serializable 3 | import numpy as np 4 | 5 | from rllab.envs.base import Step 6 | from rllab.misc.overrides import overrides 7 | from rllab.misc import logger 8 | 9 | import os.path as osp 10 | 11 | class PickerEnv(KMeansEnv, Serializable): 12 | """ 13 | Picking a block, where the block position is randomized over a square region 14 | 15 | goal_args is of form ('noisy', center_of_box, half-width of box) 16 | 17 | """ 18 | FILE = osp.join(osp.abspath(osp.dirname(__file__)), 'assets/picker.xml') 19 | 20 | def __init__(self, goal_args=('noisy', (.6,.2), .1), frame_skip=5, *args, **kwargs): 21 | 22 | self.goal_args = goal_args 23 | 24 | super(PickerEnv, self).__init__(frame_skip=frame_skip, *args, **kwargs) 25 | Serializable.__init__(self, goal_args, frame_skip, *args, **kwargs) 26 | 27 | def get_current_obs(self): 28 | finger_com = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3") 29 | finger_com = finger_com / 3. 
30 | 31 | return np.concatenate([ 32 | self.model.data.qpos.flat[:], 33 | self.model.data.qvel.flat[:], 34 | finger_com, 35 | ]).reshape(-1) 36 | 37 | def step(self,action): 38 | self.model.data.ctrl = action 39 | 40 | reward = 0 41 | timesInHand = 0 42 | 43 | for _ in range(self.frame_skip): 44 | self.model.step() 45 | step_reward = self.reward() 46 | timesInHand += step_reward > 0 47 | reward += step_reward 48 | 49 | done = reward == 0 and self.numClose > 0 # Stop it if the block is flinged 50 | 51 | ob = self.get_current_obs() 52 | 53 | new_com = self.model.data.com_subtree[0] 54 | self.dcom = new_com - self.current_com 55 | self.current_com = new_com 56 | 57 | return Step(ob, float(reward), done,timeInHand=timesInHand) 58 | 59 | def reward(self): 60 | obj_position = self.get_body_com("object") 61 | 62 | if obj_position[2] < 0.08: 63 | return 0 64 | 65 | finger_com = self.get_body_com("jaco_link_finger_1") + self.get_body_com("jaco_link_finger_2") + self.get_body_com("jaco_link_finger_3") 66 | finger_com = finger_com / 3. 67 | 68 | vec_1 = obj_position - finger_com 69 | dist_1 = np.linalg.norm(vec_1) 70 | 71 | if dist_1 < .1 and obj_position[0] > .2: 72 | self.numClose += 1 73 | return obj_position[2] 74 | else: 75 | return 0 76 | 77 | def sample_position(self,goal_type,center=(0.6,0.2),noise=0): 78 | if goal_type == 'fixed': 79 | return [center[0],center[1],.03] 80 | elif goal_type == 'noisy': 81 | x,y = center 82 | return [x+(np.random.rand()-0.5)*2*noise,y+(np.random.rand()-0.5)*2*noise,.03] 83 | else: 84 | raise NotImplementedError() 85 | 86 | def retrieve_centers(self,full_states): 87 | return full_states[:,9:12] 88 | 89 | def propose_original(self): 90 | return self.sample_position(*self.goal_args) 91 | 92 | @overrides 93 | def reset(self): 94 | qpos = self.init_qpos.copy().reshape(-1) 95 | qvel = self.init_qvel.copy().reshape(-1) + np.random.uniform(low=-0.005, 96 | high=0.005, size=self.model.nv) 97 | 98 | qpos[1] = -1 99 | 100 | self.position = self.propose() # Proposal 101 | qpos[9:12] = self.position 102 | qvel[9:12] = 0 103 | 104 | self.set_state(qpos.reshape(-1), qvel) 105 | 106 | self.numClose = 0 107 | 108 | self.current_com = self.model.data.com_subtree[0] 109 | self.dcom = np.zeros_like(self.current_com) 110 | return self.get_current_obs() 111 | 112 | def viewer_setup(self): 113 | self.viewer.cam.trackbodyid = -1 114 | self.viewer.cam.distance = 4.0 115 | self.viewer.cam.azimuth = +0.0 116 | self.viewer.cam.elevation = -40 117 | 118 | @overrides 119 | def log_diagnostics(self, paths, prefix=''): 120 | 121 | timeOffGround = np.array([ 122 | np.sum(path['env_infos']['timeInHand'])*.01 123 | for path in paths]) 124 | 125 | timeInAir = timeOffGround[timeOffGround.nonzero()] 126 | 127 | if len(timeInAir) == 0: 128 | timeInAir = [0] 129 | 130 | avgPct = lambda x: round(np.mean(x) * 100, 2) 131 | 132 | logger.record_tabular(prefix+'PctPicked', avgPct(timeOffGround > .3)) 133 | logger.record_tabular(prefix+'PctReceivedReward', avgPct(timeOffGround > 0)) 134 | 135 | logger.record_tabular(prefix+'AverageTimeInAir',np.mean(timeOffGround)) 136 | logger.record_tabular(prefix+'MaxTimeInAir',np.max(timeOffGround )) -------------------------------------------------------------------------------- /dnc/sampler/policy_sampler.py: -------------------------------------------------------------------------------- 1 | from sandbox.rocky.tf.algos.batch_polopt import BatchPolopt 2 | 3 | 4 | class Sampler(BatchPolopt): 5 | """ 6 | Creates a dummy class for sampling: 7 | please refer to 
sandbox.rocky.tf.algos.batch_polopt:BatchPolopt 8 | for all options 9 | """ 10 | def init_opt(self): 11 | pass 12 | -------------------------------------------------------------------------------- /examples/dnc_pick.py: -------------------------------------------------------------------------------- 1 | # Environment Imports 2 | from sandbox.rocky.tf.envs.base import TfEnv 3 | from rllab.envs.normalized_env import normalize 4 | import dnc.envs as dnc_envs 5 | 6 | # Algo Imports 7 | 8 | import dnc.algos.trpo as dnc_trpo 9 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 10 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 11 | 12 | # Experiment Imports 13 | 14 | from rllab.misc.instrument import stub, run_experiment_lite 15 | 16 | def run_task(args,*_): 17 | 18 | base_env = dnc_envs.create_stochastic('pick') 19 | base_partitions = dnc_envs.create_env_partitions(base_env, k=4) 20 | 21 | env = TfEnv(normalize(base_env)) 22 | partitions = [TfEnv(normalize(part_env)) for part_env in base_partitions] 23 | 24 | policy_class = GaussianMLPPolicy 25 | policy_kwargs = dict( 26 | min_std=1e-2, 27 | hidden_sizes=(150, 100, 50), 28 | ) 29 | 30 | baseline_class = LinearFeatureBaseline 31 | 32 | algo = dnc_trpo.TRPO( 33 | env=env, 34 | partitions=partitions, 35 | policy_class=policy_class, 36 | policy_kwargs=policy_kwargs, 37 | baseline_class=baseline_class, 38 | batch_size=20000, 39 | n_itr=500, 40 | force_batch_sampler=True, 41 | max_path_length=50, 42 | discount=1, 43 | step_size=0.02, 44 | ) 45 | 46 | algo.train() 47 | 48 | run_experiment_lite( 49 | run_task, 50 | log_dir='data/dnc/pick', 51 | n_parallel=12, 52 | snapshot_mode="last", 53 | seed=1, 54 | variant=dict(), 55 | use_cloudpickle=True, 56 | ) 57 | -------------------------------------------------------------------------------- /examples/trpo_pick.py: -------------------------------------------------------------------------------- 1 | # Environment Imports 2 | from sandbox.rocky.tf.envs.base import TfEnv 3 | from rllab.envs.normalized_env import normalize 4 | import dnc.envs as dnc_envs 5 | 6 | # Algo Imports 7 | 8 | from sandbox.rocky.tf.algos.trpo import TRPO 9 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 10 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 11 | 12 | # Experiment Imports 13 | 14 | from rllab.misc.instrument import stub, run_experiment_lite 15 | 16 | def run_task(args,*_): 17 | 18 | env = TfEnv(normalize(dnc_envs.create_stochastic('pick'))) # Cannot be solved easily by TRPO 19 | 20 | policy = GaussianMLPPolicy( 21 | name="policy", 22 | env_spec=env.spec, 23 | min_std=1e-2, 24 | hidden_sizes=(150, 100, 50), 25 | ) 26 | 27 | baseline = LinearFeatureBaseline(env_spec=env.spec) 28 | 29 | algo = TRPO( 30 | env=env, 31 | policy=policy, 32 | baseline=baseline, 33 | batch_size=50000, 34 | force_batch_sampler=True, 35 | max_path_length=50, 36 | discount=1, 37 | step_size=0.02, 38 | ) 39 | 40 | algo.train() 41 | 42 | 43 | run_experiment_lite( 44 | run_task, 45 | log_dir='data/trpo/pick', 46 | n_parallel=12, 47 | snapshot_mode="last", 48 | seed=1, 49 | variant=dict(), 50 | use_cloudpickle=True 51 | ) 52 | -------------------------------------------------------------------------------- /examples/trpo_pick_nonoise.py: -------------------------------------------------------------------------------- 1 | # Environment Imports 2 | from sandbox.rocky.tf.envs.base import TfEnv 3 | from rllab.envs.normalized_env import normalize 4 | 
import dnc.envs as dnc_envs 5 | 6 | # Algo Imports 7 | 8 | from sandbox.rocky.tf.algos.trpo import TRPO 9 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 10 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 11 | 12 | # Experiment Imports 13 | 14 | from rllab.misc.instrument import stub, run_experiment_lite 15 | 16 | def run_task(args,*_): 17 | 18 | env = TfEnv(normalize(dnc_envs.create_deterministic('pick'))) # No stochasticity in initial state: should be solved easily by TRPO 19 | 20 | policy = GaussianMLPPolicy( 21 | name="policy", 22 | env_spec=env.spec, 23 | min_std=1e-2, 24 | hidden_sizes=(150, 100, 50), 25 | ) 26 | 27 | baseline = LinearFeatureBaseline(env_spec=env.spec) 28 | 29 | algo = TRPO( 30 | env=env, 31 | policy=policy, 32 | baseline=baseline, 33 | batch_size=50000, 34 | force_batch_sampler=True, 35 | max_path_length=50, 36 | discount=.98, 37 | step_size=0.02, 38 | ) 39 | algo.train() 40 | 41 | 42 | run_experiment_lite( 43 | run_task, 44 | log_dir='data/trpo_nonoise/pick', 45 | n_parallel=12, 46 | snapshot_mode="last", 47 | seed=1, 48 | variant=dict(), 49 | use_cloudpickle=True, 50 | ) 51 | -------------------------------------------------------------------------------- /scripts/sim_policy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import joblib 4 | import tensorflow as tf 5 | 6 | from rllab.misc.console import query_yes_no 7 | from rllab.sampler.utils import rollout 8 | 9 | if __name__ == "__main__": 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('file', type=str, 13 | help='path to the snapshot file') 14 | parser.add_argument('--number', type=int, default=-1, 15 | help='Which policy to play (-1 = center): default=-1') 16 | 17 | parser.add_argument('--max_path_length', type=int, default=1000, 18 | help='Max length of rollout') 19 | parser.add_argument('--speedup', type=float, default=1, 20 | help='Speedup') 21 | args = parser.parse_args() 22 | 23 | # If the snapshot file use tensorflow, do: 24 | # import tensorflow as tf 25 | # with tf.Session(): 26 | # [rest of the code] 27 | with tf.Session() as sess: 28 | data = joblib.load(args.file) 29 | 30 | if args.number == -1: 31 | policy = data['policy'] 32 | env = data['env'] 33 | else: 34 | policy = data['policy%d'%args.number] 35 | env = data['env%d'%args.number] 36 | 37 | while True: 38 | path = rollout(env, policy, max_path_length=args.max_path_length, 39 | animated=True, speedup=args.speedup) 40 | if not query_yes_no('Continue simulation?'): 41 | break -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # setup.py 2 | from setuptools import setup 3 | 4 | setup( 5 | name='dnc', 6 | version='1.0.0', 7 | packages=[], 8 | ) 9 | -------------------------------------------------------------------------------- /videos/catching.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/videos/catching.gif -------------------------------------------------------------------------------- /videos/lobbing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dibyaghosh/dnc/e5069d8934bd9d41760941f21b5d20ee05f57afe/videos/lobbing.gif 
--------------------------------------------------------------------------------