├── .gitignore ├── LICENSE ├── README.md ├── inverse_rl ├── __init__.py ├── algos │ ├── batch_polopt.py │ ├── irl_batch_polopt.py │ ├── irl_npo.py │ ├── irl_trpo.py │ ├── npo.py │ ├── penalty_lbfgs_optimizer.py │ └── trpo.py ├── envs │ ├── __init__.py │ ├── ant_env.py │ ├── assets │ │ └── twod_maze.xml │ ├── dynamic_mjc │ │ ├── __init__.py │ │ ├── mjc_models.py │ │ └── model_builder.py │ ├── env_utils.py │ ├── point_maze_env.py │ ├── pusher_env.py │ ├── twod_maze.py │ ├── twod_mjc_env.py │ ├── utils.py │ └── visual_pointmass.py ├── models │ ├── __init__.py │ ├── airl_state.py │ ├── architectures.py │ ├── fusion_manager.py │ ├── imitation_learning.py │ └── tf_util.py └── utils │ ├── __init__.py │ ├── general.py │ ├── hyper_sweep.py │ ├── hyperparametrized.py │ ├── log_utils.py │ └── math_utils.py ├── scripts ├── ant_data_collect.py ├── ant_irl.py ├── ant_transfer_disabled.py ├── pendulum_data_collect.py ├── pendulum_gail.py └── pendulum_irl.py └── tabular_maxent_irl ├── README.md ├── maxent_irl.py ├── q_iteration.py ├── simple_env.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Justin Fu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inverse RL 2 | 3 | Implementations for imitation learning / IRL algorithms in RLLAB 4 | 5 | Contains: 6 | - GAIL (https://arxiv.org/abs/1606.03476) 7 | - Guided Cost Learning, GAN formulation (https://arxiv.org/pdf/1611.03852.pdf) 8 | - Tabular MaxCausalEnt IRL (http://www.cs.cmu.edu/~bziebart/publications/thesis-bziebart.pdf) 9 | 10 | Setup 11 | --- 12 | This library requires: 13 | - rllab (https://github.com/openai/rllab) 14 | - Tensorflow 15 | 16 | Examples 17 | --- 18 | 19 | Running the Pendulum-v0 gym environment: 20 | 21 | 1) Collect expert data 22 | ``` 23 | python scripts/pendulum_data_collect.py 24 | ``` 25 | 26 | You should get an "AverageReturn" of around -100 to -150 27 | 28 | 2) Run imitation learning 29 | ``` 30 | python scripts/pendulum_irl.py 31 | ``` 32 | 33 | The "OriginalTaskAverageReturn" should reach around -100 to -150 -------------------------------------------------------------------------------- /inverse_rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinjfu/inverse_rl/9609933389459a3a54f5c01d652114ada90fa1b3/inverse_rl/__init__.py -------------------------------------------------------------------------------- /inverse_rl/algos/batch_polopt.py: -------------------------------------------------------------------------------- 1 | import time 2 | from rllab.algos.base import RLAlgorithm 3 | import rllab.misc.logger as logger 4 | from sandbox.rocky.tf.policies.base import Policy 5 | import tensorflow as tf 6 | from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler 7 | from sandbox.rocky.tf.samplers.vectorized_sampler import VectorizedSampler 8 | from rllab.sampler.utils import rollout 9 | 10 | from inverse_rl.utils.hyperparametrized import Hyperparametrized 11 | 12 | 13 | class BatchPolopt(RLAlgorithm, metaclass=Hyperparametrized): 14 | """ 15 | Base class for batch sampling-based policy optimization methods. 16 | This includes various policy gradient methods like vpg, npg, ppo, trpo, etc. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | env, 22 | policy, 23 | baseline, 24 | scope=None, 25 | n_itr=500, 26 | start_itr=0, 27 | batch_size=5000, 28 | max_path_length=500, 29 | discount=0.99, 30 | gae_lambda=1, 31 | plot=False, 32 | pause_for_plot=False, 33 | center_adv=True, 34 | positive_adv=False, 35 | store_paths=False, 36 | whole_paths=True, 37 | fixed_horizon=False, 38 | sampler_cls=None, 39 | sampler_args=None, 40 | force_batch_sampler=False, 41 | **kwargs 42 | ): 43 | """ 44 | :param env: Environment 45 | :param policy: Policy 46 | :type policy: Policy 47 | :param baseline: Baseline 48 | :param scope: Scope for identifying the algorithm.
Must be specified if running multiple algorithms 49 | simultaneously, each using different environments and policies 50 | :param n_itr: Number of iterations. 51 | :param start_itr: Starting iteration. 52 | :param batch_size: Number of samples per iteration. 53 | :param max_path_length: Maximum length of a single rollout. 54 | :param discount: Discount. 55 | :param gae_lambda: Lambda used for generalized advantage estimation. 56 | :param plot: Plot evaluation run after each iteration. 57 | :param pause_for_plot: Whether to pause before contiuing when plotting. 58 | :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1. 59 | :param positive_adv: Whether to shift the advantages so that they are always positive. When used in 60 | conjunction with center_adv the advantages will be standardized before shifting. 61 | :param store_paths: Whether to save all paths data to the snapshot. 62 | :return: 63 | """ 64 | self.env = env 65 | self.policy = policy 66 | self.baseline = baseline 67 | self.scope = scope 68 | self.n_itr = n_itr 69 | self.start_itr = start_itr 70 | self.batch_size = batch_size 71 | self.max_path_length = max_path_length 72 | self.discount = discount 73 | self.gae_lambda = gae_lambda 74 | self.plot = plot 75 | self.pause_for_plot = pause_for_plot 76 | self.center_adv = center_adv 77 | self.positive_adv = positive_adv 78 | self.store_paths = store_paths 79 | self.whole_paths = whole_paths 80 | self.fixed_horizon = fixed_horizon 81 | if sampler_cls is None: 82 | if self.policy.vectorized and not force_batch_sampler: 83 | sampler_cls = VectorizedSampler 84 | else: 85 | sampler_cls = BatchSampler 86 | if sampler_args is None: 87 | sampler_args = dict() 88 | self.sampler = sampler_cls(self, **sampler_args) 89 | self.init_opt() 90 | 91 | def start_worker(self): 92 | self.sampler.start_worker() 93 | 94 | def shutdown_worker(self): 95 | self.sampler.shutdown_worker() 96 | 97 | def obtain_samples(self, itr): 98 | return self.sampler.obtain_samples(itr) 99 | 100 | def process_samples(self, itr, paths): 101 | return self.sampler.process_samples(itr, paths) 102 | 103 | def train(self, sess=None): 104 | created_session = True if (sess is None) else False 105 | if sess is None: 106 | sess = tf.Session() 107 | sess.__enter__() 108 | 109 | sess.run(tf.global_variables_initializer()) 110 | self.start_worker() 111 | start_time = time.time() 112 | for itr in range(self.start_itr, self.n_itr): 113 | itr_start_time = time.time() 114 | with logger.prefix('itr #%d | ' % itr): 115 | logger.log("Obtaining samples...") 116 | paths = self.obtain_samples(itr) 117 | logger.log("Processing samples...") 118 | samples_data = self.process_samples(itr, paths) 119 | logger.log("Logging diagnostics...") 120 | self.log_diagnostics(paths) 121 | logger.log("Optimizing policy...") 122 | self.optimize_policy(itr, samples_data) 123 | logger.log("Saving snapshot...") 124 | params = self.get_itr_snapshot(itr, samples_data) # , **kwargs) 125 | if self.store_paths: 126 | params["paths"] = samples_data["paths"] 127 | logger.save_itr_params(itr, params) 128 | logger.log("Saved") 129 | logger.record_tabular('Time', time.time() - start_time) 130 | logger.record_tabular('ItrTime', time.time() - itr_start_time) 131 | logger.dump_tabular(with_prefix=False) 132 | if self.plot: 133 | rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length) 134 | if self.pause_for_plot: 135 | input("Plotting evaluation run: Press Enter to " 136 | "continue...") 137 | 
self.shutdown_worker() 138 | if created_session: 139 | sess.close() 140 | 141 | def log_diagnostics(self, paths): 142 | self.env.log_diagnostics(paths) 143 | self.policy.log_diagnostics(paths) 144 | self.baseline.log_diagnostics(paths) 145 | 146 | def init_opt(self): 147 | """ 148 | Initialize the optimization procedure. If using tensorflow, this may 149 | include declaring all the variables and compiling functions 150 | """ 151 | raise NotImplementedError 152 | 153 | def get_itr_snapshot(self, itr, samples_data): 154 | """ 155 | Returns all the data that should be saved in the snapshot for this 156 | iteration. 157 | """ 158 | raise NotImplementedError 159 | 160 | def optimize_policy(self, itr, samples_data): 161 | raise NotImplementedError 162 | -------------------------------------------------------------------------------- /inverse_rl/algos/irl_batch_polopt.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from rllab.algos.base import RLAlgorithm 4 | import rllab.misc.logger as logger 5 | import rllab.plotter as plotter 6 | from sandbox.rocky.tf.policies.base import Policy 7 | import tensorflow as tf 8 | from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler 9 | from sandbox.rocky.tf.samplers.vectorized_sampler import VectorizedSampler 10 | import numpy as np 11 | from collections import deque 12 | 13 | from inverse_rl.utils.hyperparametrized import Hyperparametrized 14 | 15 | 16 | class IRLBatchPolopt(RLAlgorithm, metaclass=Hyperparametrized): 17 | """ 18 | Base class for batch sampling-based policy optimization methods. 19 | This includes various policy gradient methods like vpg, npg, ppo, trpo, etc. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | env, 25 | policy, 26 | baseline, 27 | scope=None, 28 | n_itr=500, 29 | start_itr=0, 30 | batch_size=5000, 31 | max_path_length=500, 32 | discount=0.99, 33 | gae_lambda=1, 34 | plot=False, 35 | pause_for_plot=False, 36 | center_adv=True, 37 | positive_adv=False, 38 | store_paths=True, 39 | whole_paths=True, 40 | fixed_horizon=False, 41 | sampler_cls=None, 42 | sampler_args=None, 43 | force_batch_sampler=False, 44 | init_pol_params = None, 45 | irl_model=None, 46 | irl_model_wt=1.0, 47 | discrim_train_itrs=10, 48 | zero_environment_reward=False, 49 | init_irl_params=None, 50 | train_irl=True, 51 | key='', 52 | **kwargs 53 | ): 54 | """ 55 | :param env: Environment 56 | :param policy: Policy 57 | :type policy: Policy 58 | :param baseline: Baseline 59 | :param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms 60 | simultaneously, each using different environments and policies 61 | :param n_itr: Number of iterations. 62 | :param start_itr: Starting iteration. 63 | :param batch_size: Number of samples per iteration. 64 | :param max_path_length: Maximum length of a single rollout. 65 | :param discount: Discount. 66 | :param gae_lambda: Lambda used for generalized advantage estimation. 67 | :param plot: Plot evaluation run after each iteration. 68 | :param pause_for_plot: Whether to pause before contiuing when plotting. 69 | :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1. 70 | :param positive_adv: Whether to shift the advantages so that they are always positive. When used in 71 | conjunction with center_adv the advantages will be standardized before shifting. 72 | :param store_paths: Whether to save all paths data to the snapshot. 
73 | :return: 74 | """ 75 | self.env = env 76 | self.policy = policy 77 | self.baseline = baseline 78 | self.scope = scope 79 | self.n_itr = n_itr 80 | self.start_itr = start_itr 81 | self.batch_size = batch_size 82 | self.max_path_length = max_path_length 83 | self.discount = discount 84 | self.gae_lambda = gae_lambda 85 | self.plot = plot 86 | self.pause_for_plot = pause_for_plot 87 | self.center_adv = center_adv 88 | self.positive_adv = positive_adv 89 | self.store_paths = store_paths 90 | self.whole_paths = whole_paths 91 | self.fixed_horizon = fixed_horizon 92 | self.init_pol_params = init_pol_params 93 | self.init_irl_params = init_irl_params 94 | self.irl_model = irl_model 95 | self.irl_model_wt = irl_model_wt 96 | self.no_reward = zero_environment_reward 97 | self.discrim_train_itrs = discrim_train_itrs 98 | self.train_irl = train_irl 99 | self.__irl_params = None 100 | 101 | if self.irl_model_wt > 0: 102 | assert self.irl_model is not None, "Need to specify a IRL model" 103 | 104 | if sampler_cls is None: 105 | if self.policy.vectorized and not force_batch_sampler: 106 | print('using vec sampler') 107 | sampler_cls = VectorizedSampler 108 | else: 109 | print('using batch sampler') 110 | sampler_cls = BatchSampler 111 | if sampler_args is None: 112 | sampler_args = dict() 113 | self.sampler = sampler_cls(self, **sampler_args) 114 | self.init_opt() 115 | 116 | def start_worker(self): 117 | self.sampler.start_worker() 118 | if self.plot: 119 | plotter.init_plot(self.env, self.policy) 120 | 121 | def shutdown_worker(self): 122 | self.sampler.shutdown_worker() 123 | 124 | def obtain_samples(self, itr): 125 | return self.sampler.obtain_samples(itr) 126 | 127 | def process_samples(self, itr, paths): 128 | #processed = self.sampler.process_samples(itr, paths) 129 | return self.sampler.process_samples(itr, paths) 130 | 131 | def log_avg_returns(self, paths): 132 | undiscounted_returns = [sum(path["rewards"]) for path in paths] 133 | avg_return = np.mean(undiscounted_returns) 134 | return avg_return 135 | 136 | def get_irl_params(self): 137 | return self.__irl_params 138 | 139 | def compute_irl(self, paths, itr=0): 140 | if self.no_reward: 141 | tot_rew = 0 142 | for path in paths: 143 | tot_rew += np.sum(path['rewards']) 144 | path['rewards'] *= 0 145 | logger.record_tabular('OriginalTaskAverageReturn', tot_rew/float(len(paths))) 146 | 147 | if self.irl_model_wt <=0: 148 | return paths 149 | 150 | if self.train_irl: 151 | max_itrs = self.discrim_train_itrs 152 | lr=1e-3 153 | mean_loss = self.irl_model.fit(paths, policy=self.policy, itr=itr, max_itrs=max_itrs, lr=lr, 154 | logger=logger) 155 | 156 | logger.record_tabular('IRLLoss', mean_loss) 157 | self.__irl_params = self.irl_model.get_params() 158 | 159 | probs = self.irl_model.eval(paths, gamma=self.discount, itr=itr) 160 | 161 | logger.record_tabular('IRLRewardMean', np.mean(probs)) 162 | logger.record_tabular('IRLRewardMax', np.max(probs)) 163 | logger.record_tabular('IRLRewardMin', np.min(probs)) 164 | 165 | 166 | if self.irl_model.score_trajectories: 167 | # TODO: should I add to reward here or after advantage computation? 
168 | for i, path in enumerate(paths): 169 | path['rewards'][-1] += self.irl_model_wt * probs[i] 170 | else: 171 | for i, path in enumerate(paths): 172 | path['rewards'] += self.irl_model_wt * probs[i] 173 | return paths 174 | 175 | def train(self): 176 | sess = tf.get_default_session() 177 | sess.run(tf.global_variables_initializer()) 178 | if self.init_pol_params is not None: 179 | self.policy.set_param_values(self.init_pol_params) 180 | if self.init_irl_params is not None: 181 | self.irl_model.set_params(self.init_irl_params) 182 | self.start_worker() 183 | start_time = time.time() 184 | 185 | returns = [] 186 | for itr in range(self.start_itr, self.n_itr): 187 | itr_start_time = time.time() 188 | with logger.prefix('itr #%d | ' % itr): 189 | logger.log("Obtaining samples...") 190 | paths = self.obtain_samples(itr) 191 | 192 | logger.log("Processing samples...") 193 | paths = self.compute_irl(paths, itr=itr) 194 | returns.append(self.log_avg_returns(paths)) 195 | samples_data = self.process_samples(itr, paths) 196 | 197 | logger.log("Logging diagnostics...") 198 | self.log_diagnostics(paths) 199 | logger.log("Optimizing policy...") 200 | self.optimize_policy(itr, samples_data) 201 | logger.log("Saving snapshot...") 202 | params = self.get_itr_snapshot(itr, samples_data) # , **kwargs) 203 | if self.store_paths: 204 | params["paths"] = samples_data["paths"] 205 | logger.save_itr_params(itr, params) 206 | logger.log("Saved") 207 | logger.record_tabular('Time', time.time() - start_time) 208 | logger.record_tabular('ItrTime', time.time() - itr_start_time) 209 | logger.dump_tabular(with_prefix=False) 210 | if self.plot: 211 | self.update_plot() 212 | if self.pause_for_plot: 213 | input("Plotting evaluation run: Press Enter to " 214 | "continue...") 215 | self.shutdown_worker() 216 | return 217 | 218 | def log_diagnostics(self, paths): 219 | self.env.log_diagnostics(paths) 220 | self.policy.log_diagnostics(paths) 221 | self.baseline.log_diagnostics(paths) 222 | 223 | def init_opt(self): 224 | """ 225 | Initialize the optimization procedure. If using tensorflow, this may 226 | include declaring all the variables and compiling functions 227 | """ 228 | raise NotImplementedError 229 | 230 | def get_itr_snapshot(self, itr, samples_data): 231 | """ 232 | Returns all the data that should be saved in the snapshot for this 233 | iteration. 234 | """ 235 | raise NotImplementedError 236 | 237 | def optimize_policy(self, itr, samples_data): 238 | raise NotImplementedError 239 | 240 | def update_plot(self): 241 | if self.plot: 242 | plotter.update_plot(self.policy, self.max_path_length) 243 | -------------------------------------------------------------------------------- /inverse_rl/algos/irl_npo.py: -------------------------------------------------------------------------------- 1 | from rllab.misc import ext 2 | from rllab.misc.overrides import overrides 3 | import rllab.misc.logger as logger 4 | from inverse_rl.algos.penalty_lbfgs_optimizer import PenaltyLbfgsOptimizer 5 | from inverse_rl.algos.irl_batch_polopt import IRLBatchPolopt 6 | from sandbox.rocky.tf.misc import tensor_utils 7 | import tensorflow as tf 8 | import numpy as np 9 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer 10 | 11 | 12 | class IRLNPO(IRLBatchPolopt): 13 | """ 14 | Natural Policy Optimization. 
15 | """ 16 | 17 | def __init__( 18 | self, 19 | optimizer=None, 20 | optimizer_args=None, 21 | step_size=0.01, 22 | entropy_weight=1.0, 23 | **kwargs): 24 | if optimizer is None: 25 | if optimizer_args is None: 26 | optimizer_args = dict(name='lbfgs') 27 | optimizer = PenaltyLbfgsOptimizer(**optimizer_args) 28 | self.optimizer = optimizer 29 | self.step_size = step_size 30 | self.pol_ent_wt = entropy_weight 31 | super(IRLNPO, self).__init__(**kwargs) 32 | 33 | @overrides 34 | def init_opt(self): 35 | is_recurrent = int(self.policy.recurrent) 36 | obs_var = self.env.observation_space.new_tensor_variable( 37 | 'obs', 38 | extra_dims=1 + is_recurrent, 39 | ) 40 | action_var = self.env.action_space.new_tensor_variable( 41 | 'action', 42 | extra_dims=1 + is_recurrent, 43 | ) 44 | advantage_var = tensor_utils.new_tensor( 45 | 'advantage', 46 | ndim=1 + is_recurrent, 47 | dtype=tf.float32, 48 | ) 49 | 50 | input_list = [ 51 | obs_var, 52 | action_var, 53 | advantage_var, 54 | ] 55 | 56 | dist = self.policy.distribution 57 | 58 | old_dist_info_vars = { 59 | k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k) 60 | for k, shape in dist.dist_info_specs 61 | } 62 | old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys] 63 | 64 | state_info_vars = { 65 | k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k) 66 | for k, shape in self.policy.state_info_specs 67 | } 68 | state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys] 69 | 70 | if is_recurrent: 71 | valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid") 72 | else: 73 | valid_var = None 74 | 75 | dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars) 76 | kl = dist.kl_sym(old_dist_info_vars, dist_info_vars) 77 | lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars) 78 | 79 | if self.pol_ent_wt > 0: 80 | if 'log_std' in dist_info_vars: 81 | log_std = dist_info_vars['log_std'] 82 | ent = tf.reduce_sum(log_std + tf.log(tf.sqrt(2 * np.pi * np.e)), reduction_indices=-1) 83 | elif 'prob' in dist_info_vars: 84 | prob = dist_info_vars['prob'] 85 | ent = -tf.reduce_sum(prob*tf.log(prob), reduction_indices=-1) 86 | else: 87 | raise NotImplementedError() 88 | ent = tf.stop_gradient(ent) 89 | adv = advantage_var + self.pol_ent_wt*ent 90 | else: 91 | adv = advantage_var 92 | 93 | 94 | if is_recurrent: 95 | mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var) 96 | surr_loss = - tf.reduce_sum(lr * adv * valid_var) / tf.reduce_sum(valid_var) 97 | else: 98 | mean_kl = tf.reduce_mean(kl) 99 | surr_loss = - tf.reduce_mean(lr * adv) 100 | 101 | input_list += state_info_vars_list + old_dist_info_vars_list 102 | if is_recurrent: 103 | input_list.append(valid_var) 104 | 105 | self.optimizer.update_opt( 106 | loss=surr_loss, 107 | target=self.policy, 108 | leq_constraint=(mean_kl, self.step_size), 109 | inputs=input_list, 110 | constraint_name="mean_kl" 111 | ) 112 | return dict() 113 | 114 | @overrides 115 | def optimize_policy(self, itr, samples_data): 116 | all_input_values = tuple(ext.extract( 117 | samples_data, 118 | "observations", "actions", "advantages", 119 | )) 120 | 121 | agent_infos = samples_data["agent_infos"] 122 | state_info_list = [agent_infos[k] for k in self.policy.state_info_keys] 123 | dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys] 124 | all_input_values += tuple(state_info_list) + tuple(dist_info_list) 125 | if 
self.policy.recurrent: 126 | all_input_values += (samples_data["valids"],) 127 | logger.log("Computing loss before") 128 | loss_before = self.optimizer.loss(all_input_values) 129 | logger.log("Computing KL before") 130 | mean_kl_before = self.optimizer.constraint_val(all_input_values) 131 | logger.log("Optimizing") 132 | self.optimizer.optimize(all_input_values) 133 | logger.log("Computing KL after") 134 | mean_kl = self.optimizer.constraint_val(all_input_values) 135 | logger.log("Computing loss after") 136 | loss_after = self.optimizer.loss(all_input_values) 137 | logger.record_tabular('LossBefore', loss_before) 138 | logger.record_tabular('LossAfter', loss_after) 139 | logger.record_tabular('MeanKLBefore', mean_kl_before) 140 | logger.record_tabular('MeanKL', mean_kl) 141 | logger.record_tabular('dLoss', loss_before - loss_after) 142 | return dict() 143 | 144 | @overrides 145 | def get_itr_snapshot(self, itr, samples_data): 146 | return dict( 147 | itr=itr, 148 | policy=self.policy, 149 | policy_params=self.policy.get_param_values(), 150 | irl_params=self.get_irl_params(), 151 | baseline=self.baseline, 152 | env=self.env, 153 | ) 154 | -------------------------------------------------------------------------------- /inverse_rl/algos/irl_trpo.py: -------------------------------------------------------------------------------- 1 | from inverse_rl.algos.irl_npo import IRLNPO 2 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer 3 | 4 | 5 | class IRLTRPO(IRLNPO): 6 | """ 7 | Trust Region Policy Optimization 8 | """ 9 | 10 | def __init__( 11 | self, 12 | optimizer=None, 13 | optimizer_args=None, 14 | **kwargs): 15 | if optimizer is None: 16 | if optimizer_args is None: 17 | optimizer_args = dict() 18 | optimizer = ConjugateGradientOptimizer(**optimizer_args) 19 | super(IRLTRPO, self).__init__(optimizer=optimizer, **kwargs) 20 | -------------------------------------------------------------------------------- /inverse_rl/algos/npo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rllab.misc import ext 4 | from rllab.misc.overrides import overrides 5 | import rllab.misc.logger as logger 6 | from sandbox.rocky.tf.optimizers.penalty_lbfgs_optimizer import PenaltyLbfgsOptimizer 7 | from sandbox.rocky.tf.misc import tensor_utils 8 | import tensorflow as tf 9 | 10 | from inverse_rl.algos.batch_polopt import BatchPolopt 11 | 12 | 13 | class NPO(BatchPolopt): 14 | """ 15 | Natural Policy Optimization. 
16 | """ 17 | 18 | def __init__( 19 | self, 20 | optimizer=None, 21 | optimizer_args=None, 22 | step_size=0.01, 23 | entropy_weight=0.0, 24 | **kwargs): 25 | if optimizer is None: 26 | if optimizer_args is None: 27 | optimizer_args = dict() 28 | optimizer = PenaltyLbfgsOptimizer(**optimizer_args) 29 | self.optimizer = optimizer 30 | self.step_size = step_size 31 | self.pol_ent_wt = entropy_weight 32 | super(NPO, self).__init__(**kwargs) 33 | 34 | @overrides 35 | def init_opt(self): 36 | is_recurrent = int(self.policy.recurrent) 37 | obs_var = self.env.observation_space.new_tensor_variable( 38 | 'obs', 39 | extra_dims=1 + is_recurrent, 40 | ) 41 | action_var = self.env.action_space.new_tensor_variable( 42 | 'action', 43 | extra_dims=1 + is_recurrent, 44 | ) 45 | advantage_var = tensor_utils.new_tensor( 46 | 'advantage', 47 | ndim=1 + is_recurrent, 48 | dtype=tf.float32, 49 | ) 50 | dist = self.policy.distribution 51 | 52 | old_dist_info_vars = { 53 | k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k) 54 | for k, shape in dist.dist_info_specs 55 | } 56 | old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys] 57 | 58 | state_info_vars = { 59 | k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k) 60 | for k, shape in self.policy.state_info_specs 61 | } 62 | state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys] 63 | 64 | if is_recurrent: 65 | valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid") 66 | else: 67 | valid_var = None 68 | 69 | dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars) 70 | kl = dist.kl_sym(old_dist_info_vars, dist_info_vars) 71 | lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars) 72 | 73 | if self.pol_ent_wt > 0: 74 | if 'log_std' in dist_info_vars: 75 | log_std = dist_info_vars['log_std'] 76 | ent = tf.reduce_sum(log_std + tf.log(tf.sqrt(2 * np.pi * np.e)), reduction_indices=-1) 77 | elif 'prob' in dist_info_vars: 78 | prob = dist_info_vars['prob'] 79 | ent = - tf.reduce_sum(prob*tf.log(prob), reduction_indices=-1) 80 | else: 81 | raise NotImplementedError() 82 | ent = tf.stop_gradient(ent) 83 | adv = advantage_var + self.pol_ent_wt*ent 84 | else: 85 | adv = advantage_var 86 | 87 | if is_recurrent: 88 | mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var) 89 | surr_loss = - tf.reduce_sum(lr * adv * valid_var) / tf.reduce_sum(valid_var) 90 | else: 91 | mean_kl = tf.reduce_mean(kl) 92 | surr_loss = - tf.reduce_mean(lr * adv) 93 | 94 | input_list = [ 95 | obs_var, 96 | action_var, 97 | advantage_var, 98 | ] + state_info_vars_list + old_dist_info_vars_list 99 | if is_recurrent: 100 | input_list.append(valid_var) 101 | 102 | self.optimizer.update_opt( 103 | loss=surr_loss, 104 | target=self.policy, 105 | leq_constraint=(mean_kl, self.step_size), 106 | inputs=input_list, 107 | constraint_name="mean_kl" 108 | ) 109 | return dict() 110 | 111 | @overrides 112 | def optimize_policy(self, itr, samples_data): 113 | all_input_values = tuple(ext.extract( 114 | samples_data, 115 | "observations", "actions", "advantages" 116 | )) 117 | agent_infos = samples_data["agent_infos"] 118 | state_info_list = [agent_infos[k] for k in self.policy.state_info_keys] 119 | dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys] 120 | all_input_values += tuple(state_info_list) + tuple(dist_info_list) 121 | if self.policy.recurrent: 122 | all_input_values += 
(samples_data["valids"],) 123 | logger.log("Computing loss before") 124 | loss_before = self.optimizer.loss(all_input_values) 125 | logger.log("Computing KL before") 126 | mean_kl_before = self.optimizer.constraint_val(all_input_values) 127 | logger.log("Optimizing") 128 | self.optimizer.optimize(all_input_values) 129 | logger.log("Computing KL after") 130 | mean_kl = self.optimizer.constraint_val(all_input_values) 131 | logger.log("Computing loss after") 132 | loss_after = self.optimizer.loss(all_input_values) 133 | logger.record_tabular('LossBefore', loss_before) 134 | logger.record_tabular('LossAfter', loss_after) 135 | logger.record_tabular('MeanKLBefore', mean_kl_before) 136 | logger.record_tabular('MeanKL', mean_kl) 137 | logger.record_tabular('dLoss', loss_before - loss_after) 138 | return dict() 139 | 140 | @overrides 141 | def get_itr_snapshot(self, itr, samples_data): 142 | return dict( 143 | itr=itr, 144 | policy=self.policy, 145 | baseline=self.baseline, 146 | env=self.env, 147 | ) 148 | -------------------------------------------------------------------------------- /inverse_rl/algos/penalty_lbfgs_optimizer.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import scipy.optimize 3 | import numpy as np 4 | import tensorflow as tf 5 | from rllab.core.serializable import Serializable 6 | from rllab.misc import ext 7 | from rllab.misc import logger 8 | from sandbox.rocky.tf.misc import tensor_utils 9 | 10 | 11 | class PenaltyLbfgsOptimizer(Serializable): 12 | """ 13 | Performs constrained optimization via penalized L-BFGS. The penalty term is adaptively adjusted to make sure that 14 | the constraint is satisfied. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | name, 20 | max_opt_itr=20, 21 | initial_penalty=1.0, 22 | min_penalty=1e-2, 23 | max_penalty=1e6, 24 | increase_penalty_factor=2, 25 | decrease_penalty_factor=0.5, 26 | max_penalty_itr=10, 27 | adapt_penalty=True): 28 | Serializable.quick_init(self, locals()) 29 | self._name = name 30 | self._max_opt_itr = max_opt_itr 31 | self._penalty = initial_penalty 32 | self._initial_penalty = initial_penalty 33 | self._min_penalty = min_penalty 34 | self._max_penalty = max_penalty 35 | self._increase_penalty_factor = increase_penalty_factor 36 | self._decrease_penalty_factor = decrease_penalty_factor 37 | self._max_penalty_itr = max_penalty_itr 38 | self._adapt_penalty = adapt_penalty 39 | 40 | self._opt_fun = None 41 | self._target = None 42 | self._max_constraint_val = None 43 | self._constraint_name = None 44 | 45 | def update_opt(self, loss, target, leq_constraint, inputs, constraint_name="constraint", *args, **kwargs): 46 | """ 47 | :param loss: Symbolic expression for the loss function. 48 | :param target: A parameterized object to optimize over. It should implement methods of the 49 | :class:`rllab.core.paramerized.Parameterized` class. 50 | :param leq_constraint: A constraint provided as a tuple (f, epsilon), of the form f(*inputs) <= epsilon. 51 | :param inputs: A list of symbolic variables as inputs 52 | :return: No return value. 
53 | """ 54 | constraint_term, constraint_value = leq_constraint 55 | with tf.variable_scope(self._name): 56 | penalty_var = tf.placeholder(tf.float32, tuple(), name="penalty") 57 | penalized_loss = loss + penalty_var * constraint_term 58 | 59 | self._target = target 60 | self._max_constraint_val = constraint_value 61 | self._constraint_name = constraint_name 62 | 63 | def get_opt_output(): 64 | params = target.get_params(trainable=True) 65 | grads = tf.gradients(penalized_loss, params) 66 | for idx, (grad, param) in enumerate(zip(grads, params)): 67 | if grad is None: 68 | grads[idx] = tf.zeros_like(param) 69 | flat_grad = tensor_utils.flatten_tensor_variables(grads) 70 | return [ 71 | tf.cast(penalized_loss, tf.float64), 72 | tf.cast(flat_grad, tf.float64), 73 | ] 74 | 75 | self._opt_fun = ext.lazydict( 76 | f_loss=lambda: tensor_utils.compile_function(inputs, loss, log_name="f_loss"), 77 | f_constraint=lambda: tensor_utils.compile_function(inputs, constraint_term, log_name="f_constraint"), 78 | f_penalized_loss=lambda: tensor_utils.compile_function( 79 | inputs=inputs + [penalty_var], 80 | outputs=[penalized_loss, loss, constraint_term], 81 | log_name="f_penalized_loss", 82 | ), 83 | f_opt=lambda: tensor_utils.compile_function( 84 | inputs=inputs + [penalty_var], 85 | outputs=get_opt_output(), 86 | ) 87 | ) 88 | 89 | def loss(self, inputs): 90 | return self._opt_fun["f_loss"](*inputs) 91 | 92 | def constraint_val(self, inputs): 93 | return self._opt_fun["f_constraint"](*inputs) 94 | 95 | def optimize(self, inputs): 96 | 97 | inputs = tuple(inputs) 98 | 99 | try_penalty = np.clip( 100 | self._penalty, self._min_penalty, self._max_penalty) 101 | 102 | penalty_scale_factor = None 103 | f_opt = self._opt_fun["f_opt"] 104 | f_penalized_loss = self._opt_fun["f_penalized_loss"] 105 | 106 | def gen_f_opt(penalty): 107 | def f(flat_params): 108 | self._target.set_param_values(flat_params, trainable=True) 109 | return f_opt(*(inputs + (penalty,))) 110 | 111 | return f 112 | 113 | cur_params = self._target.get_param_values(trainable=True).astype('float64') 114 | opt_params = cur_params 115 | 116 | for penalty_itr in range(self._max_penalty_itr): 117 | logger.log('trying penalty=%.3f...' % try_penalty) 118 | 119 | itr_opt_params, _, _ = scipy.optimize.fmin_l_bfgs_b( 120 | func=gen_f_opt(try_penalty), x0=cur_params, 121 | maxiter=self._max_opt_itr 122 | ) 123 | 124 | _, try_loss, try_constraint_val = f_penalized_loss(*(inputs + (try_penalty,))) 125 | 126 | logger.log('penalty %f => loss %f, %s %f' % 127 | (try_penalty, try_loss, self._constraint_name, try_constraint_val)) 128 | 129 | # Either constraint satisfied, or we are at the last iteration already and no alternative parameter 130 | # satisfies the constraint 131 | if try_constraint_val < self._max_constraint_val or \ 132 | (penalty_itr == self._max_penalty_itr - 1 and opt_params is None): 133 | opt_params = itr_opt_params 134 | 135 | if not self._adapt_penalty: 136 | break 137 | 138 | # Decide scale factor on the first iteration, or if constraint violation yields numerical error 139 | if penalty_scale_factor is None or np.isnan(try_constraint_val): 140 | # Increase penalty if constraint violated, or if constraint term is NAN 141 | if try_constraint_val > self._max_constraint_val or np.isnan(try_constraint_val): 142 | penalty_scale_factor = self._increase_penalty_factor 143 | else: 144 | # Otherwise (i.e. 
constraint satisfied), shrink penalty 145 | penalty_scale_factor = self._decrease_penalty_factor 146 | opt_params = itr_opt_params 147 | else: 148 | if penalty_scale_factor > 1 and \ 149 | try_constraint_val <= self._max_constraint_val: 150 | break 151 | elif penalty_scale_factor < 1 and \ 152 | try_constraint_val >= self._max_constraint_val: 153 | break 154 | old_penalty = try_penalty 155 | try_penalty *= penalty_scale_factor 156 | try_penalty = np.clip(try_penalty, self._min_penalty, self._max_penalty) 157 | if try_penalty == old_penalty: 158 | break 159 | self._penalty = try_penalty 160 | 161 | self._target.set_param_values(opt_params, trainable=True) 162 | -------------------------------------------------------------------------------- /inverse_rl/algos/trpo.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from inverse_rl.algos.npo import NPO 4 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer 5 | 6 | 7 | class TRPO(NPO): 8 | """ 9 | Trust Region Policy Optimization 10 | """ 11 | 12 | def __init__( 13 | self, 14 | optimizer=None, 15 | optimizer_args=None, 16 | **kwargs): 17 | if optimizer is None: 18 | if optimizer_args is None: 19 | optimizer_args = dict() 20 | optimizer = ConjugateGradientOptimizer(**optimizer_args) 21 | super(TRPO, self).__init__(optimizer=optimizer, **kwargs) 22 | -------------------------------------------------------------------------------- /inverse_rl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from gym.envs import register 4 | 5 | LOGGER = logging.getLogger(__name__) 6 | 7 | _REGISTERED = False 8 | def register_custom_envs(): 9 | global _REGISTERED 10 | if _REGISTERED: 11 | return 12 | _REGISTERED = True 13 | 14 | LOGGER.info("Registering custom gym environments") 15 | register(id='ObjPusher-v0', entry_point='inverse_rl.envs.pusher_env:PusherEnv', kwargs={'sparse_reward': False}) 16 | register(id='TwoDMaze-v0', entry_point='inverse_rl.envs.twod_maze:TwoDMaze') 17 | register(id='PointMazeRight-v0', entry_point='inverse_rl.envs.point_maze_env:PointMazeEnv', 18 | kwargs={'sparse_reward': False, 'direction': 1}) 19 | register(id='PointMazeLeft-v0', entry_point='inverse_rl.envs.point_maze_env:PointMazeEnv', 20 | kwargs={'sparse_reward': False, 'direction': 0}) 21 | 22 | # A modified ant which flips over less and learns faster via TRPO 23 | register(id='CustomAnt-v0', entry_point='inverse_rl.envs.ant_env:CustomAntEnv', 24 | kwargs={'gear': 30, 'disabled': False}) 25 | register(id='DisabledAnt-v0', entry_point='inverse_rl.envs.ant_env:CustomAntEnv', 26 | kwargs={'gear': 30, 'disabled': True}) 27 | 28 | register(id='VisualPointMaze-v0', entry_point='inverse_rl.envs.visual_pointmass:VisualPointMazeEnv', 29 | kwargs={'sparse_reward': False, 'direction': 1}) 30 | -------------------------------------------------------------------------------- /inverse_rl/envs/ant_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | from inverse_rl.envs.dynamic_mjc.model_builder import MJCModel 5 | from rllab.misc import logger 6 | 7 | def ant_env(gear=150, eyes=True): 8 | mjcmodel = MJCModel('ant_maze') 9 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local") 10 | mjcmodel.root.option(timestep="0.01", integrator="RK4") 11 | 
mjcmodel.root.custom().numeric(data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0",name="init_qpos") 12 | asset = mjcmodel.root.asset() 13 | asset.texture(builtin="gradient",height="100",rgb1="1 1 1",rgb2="0 0 0",type="skybox",width="100") 14 | asset.texture(builtin="flat",height="1278",mark="cross",markrgb="1 1 1",name="texgeom",random="0.01",rgb1="0.8 0.6 0.4",rgb2="0.8 0.6 0.4",type="cube",width="127") 15 | asset.texture(builtin="checker",height="100",name="texplane",rgb1="0 0 0",rgb2="0.8 0.8 0.8",type="2d",width="100") 16 | asset.material(name="MatPlane",reflectance="0.5",shininess="1",specular="1",texrepeat="60 60",texture="texplane") 17 | asset.material(name="geom",texture="texgeom",texuniform="true") 18 | 19 | default = mjcmodel.root.default() 20 | default.joint(armature=1, damping=1, limited='true') 21 | default.geom(friction=[1.5,0.5,0.5], density=5.0, margin=0.01, condim=3, conaffinity=0, rgba="0.8 0.6 0.4 1") 22 | 23 | worldbody = mjcmodel.root.worldbody() 24 | worldbody.light(cutoff="100",diffuse=[.8,.8,.8],dir="-0 0 -1.3",directional="true",exponent="1",pos="0 0 1.3",specular=".1 .1 .1") 25 | worldbody.geom(conaffinity=1, condim=3, material="MatPlane",name="floor",pos="0 0 0",rgba="0.8 0.9 0.8 1",size="40 40 40",type="plane") 26 | 27 | ant = worldbody.body(name='torso', pos=[0, 0, 0.75]) 28 | ant.geom(name='torso_geom', pos=[0, 0, 0], size="0.25", type="sphere") 29 | ant.joint(armature="0", damping="0", limited="false", margin="0.01", name="root", pos=[0, 0, 0], type="free") 30 | 31 | if eyes: 32 | eye_z = 0.1 33 | eye_y = -.21 34 | eye_x_offset = 0.07 35 | # eyes 36 | ant.geom(fromto=[eye_x_offset,0,eye_z,eye_x_offset,eye_y,eye_z], name='eye1', size='0.03', type='capsule', rgba=[1,1,1,1]) 37 | ant.geom(fromto=[eye_x_offset,0,eye_z,eye_x_offset,eye_y-0.02,eye_z], name='eye1_', size='0.02', type='capsule', rgba=[0,0,0,1]) 38 | ant.geom(fromto=[-eye_x_offset,0,eye_z,-eye_x_offset,eye_y,eye_z], name='eye2', size='0.03', type='capsule', rgba=[1,1,1,1]) 39 | ant.geom(fromto=[-eye_x_offset,0,eye_z,-eye_x_offset,eye_y-0.02,eye_z], name='eye2_', size='0.02', type='capsule', rgba=[0,0,0,1]) 40 | # eyebrows 41 | ant.geom(fromto=[eye_x_offset-0.03,eye_y, eye_z+0.07, eye_x_offset+0.03, eye_y, eye_z+0.1], name='brow1', size='0.02', type='capsule', rgba=[0,0,0,1]) 42 | ant.geom(fromto=[-eye_x_offset+0.03,eye_y, eye_z+0.07, -eye_x_offset-0.03, eye_y, eye_z+0.1], name='brow2', size='0.02', type='capsule', rgba=[0,0,0,1]) 43 | 44 | front_left_leg = ant.body(name="front_left_leg", pos=[0, 0, 0]) 45 | front_left_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="aux_1_geom", size="0.08", type="capsule") 46 | aux_1 = front_left_leg.body(name="aux_1", pos=[0.2, 0.2, 0]) 47 | aux_1.joint(axis=[0, 0, 1], name="hip_1", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 48 | aux_1.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="left_leg_geom", size="0.08", type="capsule") 49 | ankle_1 = aux_1.body(pos=[0.2, 0.2, 0]) 50 | ankle_1.joint(axis=[-1, 1, 0], name="ankle_1", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge") 51 | ankle_1.geom(fromto=[0.0, 0.0, 0.0, 0.4, 0.4, 0.0], name="left_ankle_geom", size="0.08", type="capsule") 52 | 53 | front_right_leg = ant.body(name="front_right_leg", pos=[0, 0, 0]) 54 | front_right_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="aux_2_geom", size="0.08", type="capsule") 55 | aux_2 = front_right_leg.body(name="aux_2", pos=[-0.2, 0.2, 0]) 56 | aux_2.joint(axis=[0, 0, 1], name="hip_2", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 
57 | aux_2.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="right_leg_geom", size="0.08", type="capsule") 58 | ankle_2 = aux_2.body(pos=[-0.2, 0.2, 0]) 59 | ankle_2.joint(axis=[1, 1, 0], name="ankle_2", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge") 60 | ankle_2.geom(fromto=[0.0, 0.0, 0.0, -0.4, 0.4, 0.0], name="right_ankle_geom", size="0.08", type="capsule") 61 | 62 | back_left_leg = ant.body(name="back_left_leg", pos=[0, 0, 0]) 63 | back_left_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="aux_3_geom", size="0.08", type="capsule") 64 | aux_3 = back_left_leg.body(name="aux_3", pos=[-0.2, -0.2, 0]) 65 | aux_3.joint(axis=[0, 0, 1], name="hip_3", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 66 | aux_3.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="backleft_leg_geom", size="0.08", type="capsule") 67 | ankle_3 = aux_3.body(pos=[-0.2, -0.2, 0]) 68 | ankle_3.joint(axis=[-1, 1, 0], name="ankle_3", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge") 69 | ankle_3.geom(fromto=[0.0, 0.0, 0.0, -0.4, -0.4, 0.0], name="backleft_ankle_geom", size="0.08", type="capsule") 70 | 71 | back_right_leg = ant.body(name="back_right_leg", pos=[0, 0, 0]) 72 | back_right_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="aux_4_geom", size="0.08", type="capsule") 73 | aux_4 = back_right_leg.body(name="aux_4", pos=[0.2, -0.2, 0]) 74 | aux_4.joint(axis=[0, 0, 1], name="hip_4", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 75 | aux_4.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="backright_leg_geom", size="0.08", type="capsule") 76 | ankle_4 = aux_4.body(pos=[0.2, -0.2, 0]) 77 | ankle_4.joint(axis=[1, 1, 0], name="ankle_4", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge") 78 | ankle_4.geom(fromto=[0.0, 0.0, 0.0, 0.4, -0.4, 0.0], name="backright_ankle_geom", size="0.08", type="capsule") 79 | 80 | actuator = mjcmodel.root.actuator() 81 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_4", gear=gear) 82 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_4", gear=gear) 83 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_1", gear=gear) 84 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_1", gear=gear) 85 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_2", gear=gear) 86 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_2", gear=gear) 87 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_3", gear=gear) 88 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_3", gear=gear) 89 | return mjcmodel 90 | 91 | 92 | def angry_ant_crippled(gear=150): 93 | mjcmodel = MJCModel('ant_maze') 94 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local") 95 | mjcmodel.root.option(timestep="0.01", integrator="RK4") 96 | mjcmodel.root.custom().numeric(data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0",name="init_qpos") 97 | asset = mjcmodel.root.asset() 98 | asset.texture(builtin="gradient",height="100",rgb1="1 1 1",rgb2="0 0 0",type="skybox",width="100") 99 | asset.texture(builtin="flat",height="1278",mark="cross",markrgb="1 1 1",name="texgeom",random="0.01",rgb1="0.8 0.6 0.4",rgb2="0.8 0.6 0.4",type="cube",width="127") 100 | asset.texture(builtin="checker",height="100",name="texplane",rgb1="0 0 0",rgb2="0.8 0.8 0.8",type="2d",width="100") 101 | asset.material(name="MatPlane",reflectance="0.5",shininess="1",specular="1",texrepeat="60 60",texture="texplane") 102 | 
asset.material(name="geom",texture="texgeom",texuniform="true") 103 | 104 | 105 | 106 | default = mjcmodel.root.default() 107 | default.joint(armature=1, damping=1, limited='true') 108 | default.geom(friction=[1.5,0.5,0.5], density=5.0, margin=0.01, condim=3, conaffinity=0, rgba="0.8 0.6 0.4 1") 109 | 110 | worldbody = mjcmodel.root.worldbody() 111 | 112 | worldbody.geom(conaffinity=1, condim=3, material="MatPlane",name="floor",pos="0 0 0",rgba="0.8 0.9 0.8 1",size="40 40 40",type="plane") 113 | worldbody.light(cutoff="100",diffuse=[.8,.8,.8],dir="-0 0 -1.3",directional="true",exponent="1",pos="0 0 1.3",specular=".1 .1 .1") 114 | 115 | 116 | ant = worldbody.body(name='torso', pos=[0, 0, 0.75]) 117 | ant.geom(name='torso_geom', pos=[0, 0, 0], size="0.25", type="sphere") 118 | ant.joint(armature="0", damping="0", limited="false", margin="0.01", name="root", pos=[0, 0, 0], type="free") 119 | 120 | eye_z = 0.1 121 | eye_y = -.21 122 | eye_x_offset = 0.07 123 | # eyes 124 | ant.geom(fromto=[eye_x_offset,0,eye_z,eye_x_offset,eye_y,eye_z], name='eye1', size='0.03', type='capsule', rgba=[1,1,1,1]) 125 | ant.geom(fromto=[eye_x_offset,0,eye_z,eye_x_offset,eye_y-0.02,eye_z], name='eye1_', size='0.02', type='capsule', rgba=[0,0,0,1]) 126 | ant.geom(fromto=[-eye_x_offset,0,eye_z,-eye_x_offset,eye_y,eye_z], name='eye2', size='0.03', type='capsule', rgba=[1,1,1,1]) 127 | ant.geom(fromto=[-eye_x_offset,0,eye_z,-eye_x_offset,eye_y-0.02,eye_z], name='eye2_', size='0.02', type='capsule', rgba=[0,0,0,1]) 128 | # eyebrows 129 | ant.geom(fromto=[eye_x_offset-0.03,eye_y, eye_z+0.07, eye_x_offset+0.03, eye_y, eye_z+0.1], name='brow1', size='0.02', type='capsule', rgba=[0,0,0,1]) 130 | ant.geom(fromto=[-eye_x_offset+0.03,eye_y, eye_z+0.07, -eye_x_offset-0.03, eye_y, eye_z+0.1], name='brow2', size='0.02', type='capsule', rgba=[0,0,0,1]) 131 | 132 | 133 | 134 | 135 | front_left_leg = ant.body(name="front_left_leg", pos=[0, 0, 0]) 136 | front_left_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="aux_1_geom", size="0.08", type="capsule") 137 | aux_1 = front_left_leg.body(name="aux_1", pos=[0.2, 0.2, 0]) 138 | aux_1.joint(axis=[0, 0, 1], name="hip_1", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 139 | aux_1.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="left_leg_geom", size="0.08", type="capsule") 140 | ankle_1 = aux_1.body(pos=[0.2, 0.2, 0]) 141 | ankle_1.joint(axis=[-1, 1, 0], name="ankle_1", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge") 142 | ankle_1.geom(fromto=[0.0, 0.0, 0.0, 0.4, 0.4, 0.0], name="left_ankle_geom", size="0.08", type="capsule") 143 | 144 | front_right_leg = ant.body(name="front_right_leg", pos=[0, 0, 0]) 145 | front_right_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="aux_2_geom", size="0.08", type="capsule") 146 | aux_2 = front_right_leg.body(name="aux_2", pos=[-0.2, 0.2, 0]) 147 | aux_2.joint(axis=[0, 0, 1], name="hip_2", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 148 | aux_2.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="right_leg_geom", size="0.08", type="capsule") 149 | ankle_2 = aux_2.body(pos=[-0.2, 0.2, 0]) 150 | ankle_2.joint(axis=[1, 1, 0], name="ankle_2", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge") 151 | ankle_2.geom(fromto=[0.0, 0.0, 0.0, -0.4, 0.4, 0.0], name="right_ankle_geom", size="0.08", type="capsule") 152 | 153 | # Back left leg is crippled 154 | thigh_length = 0.1 #0.2 155 | ankle_length = 0.2 #0.4 156 | dark_red = [0.8,0.3,0.3,1.0] 157 | 158 | back_left_leg = ant.body(name="back_left_leg", pos=[0, 0, 0]) 159 | 
back_left_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="aux_3_geom", size="0.08", type="capsule", 160 | rgba=dark_red) 161 | aux_3 = back_left_leg.body(name="aux_3", pos=[-0.2, -0.2, 0]) 162 | aux_3.joint(axis=[0, 0, 1], name="hip_3", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 163 | aux_3.geom(fromto=[0.0, 0.0, 0.0, -thigh_length, -thigh_length, 0.0], name="backleft_leg_geom", size="0.08", type="capsule", 164 | rgba=dark_red) 165 | ankle_3 = aux_3.body(pos=[-thigh_length, -thigh_length, 0]) 166 | ankle_3.joint(axis=[-1, 1, 0], name="ankle_3", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge") 167 | ankle_3.geom(fromto=[0.0, 0.0, 0.0, -ankle_length, -ankle_length, 0.0], name="backleft_ankle_geom", size="0.08", type="capsule", 168 | rgba=dark_red) 169 | 170 | back_right_leg = ant.body(name="back_right_leg", pos=[0, 0, 0]) 171 | back_right_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="aux_4_geom", size="0.08", type="capsule", 172 | rgba=dark_red) 173 | aux_4 = back_right_leg.body(name="aux_4", pos=[0.2, -0.2, 0]) 174 | aux_4.joint(axis=[0, 0, 1], name="hip_4", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 175 | aux_4.geom(fromto=[0.0, 0.0, 0.0, thigh_length, -thigh_length, 0.0], name="backright_leg_geom", size="0.08", type="capsule", 176 | rgba=dark_red) 177 | ankle_4 = aux_4.body(pos=[thigh_length, -thigh_length, 0]) 178 | ankle_4.joint(axis=[1, 1, 0], name="ankle_4", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge") 179 | ankle_4.geom(fromto=[0.0, 0.0, 0.0, ankle_length, -ankle_length, 0.0], name="backright_ankle_geom", size="0.08", type="capsule", 180 | rgba=dark_red) 181 | 182 | actuator = mjcmodel.root.actuator() 183 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_1", gear=gear) 184 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_1", gear=gear) 185 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_2", gear=gear) 186 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_2", gear=gear) 187 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_3", gear=1) # cripple the joints 188 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_3", gear=1) # cripple the joints 189 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_4", gear=1) 190 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_4", gear=1) 191 | return mjcmodel 192 | 193 | 194 | class CustomAntEnv(mujoco_env.MujocoEnv, utils.EzPickle): 195 | """ 196 | A modified ant env with lower joint gear ratios so it flips less often and learns faster. 
197 | """ 198 | def __init__(self, max_timesteps=1000, disabled=False, gear=150): 199 | #mujoco_env.MujocoEnv.__init__(self, 'ant.xml', 5) 200 | utils.EzPickle.__init__(self) 201 | self.timesteps = 0 202 | self.max_timesteps=max_timesteps 203 | 204 | if disabled: 205 | model = angry_ant_crippled(gear=gear) 206 | else: 207 | model = ant_env(gear=gear) 208 | with model.asfile() as f: 209 | mujoco_env.MujocoEnv.__init__(self, f.name, 5) 210 | 211 | def _step(self, a): 212 | vel = self.model.data.qvel.flat[0] 213 | forward_reward = vel 214 | self.do_simulation(a, self.frame_skip) 215 | 216 | ctrl_cost = .01 * np.square(a).sum() 217 | contact_cost = 0.5 * 1e-3 * np.sum( 218 | np.square(np.clip(self.model.data.cfrc_ext, -1, 1))) 219 | state = self.state_vector() 220 | flipped = not (state[2] >= 0.2) 221 | flipped_rew = -1 if flipped else 0 222 | reward = forward_reward - ctrl_cost - contact_cost +flipped_rew 223 | 224 | self.timesteps += 1 225 | done = self.timesteps >= self.max_timesteps 226 | 227 | ob = self._get_obs() 228 | return ob, reward, done, dict( 229 | reward_forward=forward_reward, 230 | reward_ctrl=-ctrl_cost, 231 | reward_contact=-contact_cost, 232 | reward_flipped=flipped_rew) 233 | 234 | def _get_obs(self): 235 | return np.concatenate([ 236 | self.model.data.qpos.flat[2:], 237 | self.model.data.qvel.flat, 238 | np.clip(self.model.data.cfrc_ext, -1, 1).flat, 239 | ]) 240 | 241 | def reset_model(self): 242 | self.timesteps = 0 243 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) 244 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 245 | self.set_state(qpos, qvel) 246 | return self._get_obs() 247 | 248 | def viewer_setup(self): 249 | self.viewer.cam.distance = self.model.stat.extent * 0.5 250 | 251 | def log_diagnostics(self, paths): 252 | forward_rew = np.array([np.mean(traj['env_infos']['reward_forward']) for traj in paths]) 253 | reward_ctrl = np.array([np.mean(traj['env_infos']['reward_ctrl']) for traj in paths]) 254 | reward_cont = np.array([np.mean(traj['env_infos']['reward_contact']) for traj in paths]) 255 | reward_flip = np.array([np.mean(traj['env_infos']['reward_flipped']) for traj in paths]) 256 | 257 | logger.record_tabular('AvgRewardFwd', np.mean(forward_rew)) 258 | logger.record_tabular('AvgRewardCtrl', np.mean(reward_ctrl)) 259 | logger.record_tabular('AvgRewardContact', np.mean(reward_cont)) 260 | logger.record_tabular('AvgRewardFlipped', np.mean(reward_flip)) 261 | 262 | 263 | if __name__ == "__main__": 264 | env = CustomAntEnv(disabled=True, gear=30) 265 | 266 | for _ in range(1000): 267 | env.render() 268 | env.step(env.action_space.sample()) 269 | -------------------------------------------------------------------------------- /inverse_rl/envs/assets/twod_maze.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /inverse_rl/envs/dynamic_mjc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinjfu/inverse_rl/9609933389459a3a54f5c01d652114ada90fa1b3/inverse_rl/envs/dynamic_mjc/__init__.py -------------------------------------------------------------------------------- /inverse_rl/envs/dynamic_mjc/mjc_models.py: -------------------------------------------------------------------------------- 1 | from inverse_rl.envs.dynamic_mjc.model_builder import MJCModel 2 | import numpy as np 3 | 4 | 5 | def 
block_push(object_pos=(0,0,0), goal_pos=(0,0,0)): 6 | mjcmodel = MJCModel('block_push') 7 | mjcmodel.root.compiler(inertiafromgeom="true",angle="radian",coordinate="local") 8 | mjcmodel.root.option(timestep="0.01",gravity="0 0 0",iterations="20",integrator="Euler") 9 | default = mjcmodel.root.default() 10 | default.joint(armature='0.04', damping=1, limited='true') 11 | default.geom(friction=".8 .1 .1",density="300",margin="0.002",condim="1",contype="1",conaffinity="1") 12 | 13 | worldbody = mjcmodel.root.worldbody() 14 | 15 | palm = worldbody.body(name='palm', pos=[0,0,0]) 16 | palm.geom(name='palm_geom', type='capsule', fromto=[0,-0.1,0,0,0.1,0], size=.12) 17 | proximal1 = palm.body(name='proximal_1', pos=[0,0,0]) 18 | proximal1.joint(name='proximal_j_1', type='hinge', pos=[0,0,0], axis=[0,1,0], range=[-2.5,2.3]) 19 | proximal1.geom(type='capsule', fromto=[0,0,0,0.4,0,0], size=0.06, contype=1, conaffinity=1) 20 | distal1 = proximal1.body(name='distal_1', pos=[0.4,0,0]) 21 | distal1.joint(name = "distal_j_1", type = "hinge", pos = "0 0 0", axis = "0 1 0", range = "-2.3213 2.3", damping = "1.0") 22 | distal1.geom(type="capsule", fromto="0 0 0 0.4 0 0", size="0.06", contype="1", conaffinity="1") 23 | distal2 = distal1.body(name='distal_2', pos=[0.4,0,0]) 24 | distal2.joint(name="distal_j_2",type="hinge",pos="0 0 0",axis="0 1 0",range="-2.3213 2.3",damping="1.0") 25 | distal2.geom(type="capsule",fromto="0 0 0 0.4 0 0",size="0.06",contype="1",conaffinity="1") 26 | distal4 = distal2.body(name='distal_4', pos=[0.4,0,0]) 27 | distal4.site(name="tip arml",pos="0.1 0 -0.2",size="0.01") 28 | distal4.site(name="tip armr",pos="0.1 0 0.2",size="0.01") 29 | distal4.joint(name="distal_j_3",type="hinge",pos="0 0 0",axis="1 0 0",range="-3.3213 3.3",damping="0.5") 30 | distal4.geom(type="capsule",fromto="0 0 -0.2 0 0 0.2",size="0.04",contype="1",conaffinity="1") 31 | distal4.geom(type="capsule",fromto="0 0 -0.2 0.2 0 -0.2",size="0.04",contype="1",conaffinity="1") 32 | distal4.geom(type="capsule",fromto="0 0 0.2 0.2 0 0.2",size="0.04",contype="1",conaffinity="1") 33 | 34 | object = worldbody.body(name='object', pos=object_pos) 35 | object.geom(rgba="1. 1. 1. 1",type="box",size="0.05 0.05 0.05",density='0.00001',contype="1",conaffinity="1") 36 | object.joint(name="obj_slidez",type="slide",pos="0.025 0.025 0.025",axis="0 0 1",range="-10.3213 10.3",damping="0.5") 37 | object.joint(name="obj_slidex",type="slide",pos="0.025 0.025 0.025",axis="1 0 0",range="-10.3213 10.3",damping="0.5") 38 | distal10 = object.body(name='distal_10', pos=[0,0,0]) 39 | distal10.site(name='obj_pos', pos=[0.025,0.025,0.025], size=0.01) 40 | 41 | goal = worldbody.body(name='goal', pos=goal_pos) 42 | goal.geom(rgba="1. 0. 0. 
1",type="box",size="0.1 0.1 0.1",density='0.00001',contype="0",conaffinity="0") 43 | distal11 = goal.body(name='distal_11', pos=[0,0,0]) 44 | distal11.site(name='goal_pos', pos=[0.05,0.05,0.05], size=0.01) 45 | 46 | 47 | actuator = mjcmodel.root.actuator() 48 | actuator.motor(joint="proximal_j_1",ctrlrange="-2 2",ctrllimited="true") 49 | actuator.motor(joint="distal_j_1",ctrlrange="-2 2",ctrllimited="true") 50 | actuator.motor(joint="distal_j_2",ctrlrange="-2 2",ctrllimited="true") 51 | actuator.motor(joint="distal_j_3",ctrlrange="-2 2",ctrllimited="true") 52 | 53 | return mjcmodel 54 | 55 | 56 | EAST = 0 57 | WEST = 1 58 | NORTH = 2 59 | SOUTH = 3 60 | 61 | def twod_corridor(direction=EAST, length=1.2): 62 | mjcmodel = MJCModel('twod_corridor') 63 | mjcmodel.root.compiler(inertiafromgeom="true", angle="radian", coordinate="local") 64 | mjcmodel.root.option(timestep="0.01", gravity="0 0 0", iterations="20", integrator="Euler") 65 | default = mjcmodel.root.default() 66 | default.joint(damping=1, limited='false') 67 | default.geom(friction=".5 .1 .1", density="1000", margin="0.002", condim="1", contype="2", conaffinity="1") 68 | 69 | worldbody = mjcmodel.root.worldbody() 70 | 71 | particle = worldbody.body(name='particle', pos=[0,0,0]) 72 | particle.geom(name='particle_geom', type='sphere', size='0.03', rgba='0.0 0.0 1.0 1', contype=1) 73 | particle.site(name='particle_site', pos=[0,0,0], size=0.01) 74 | particle.joint(name='ball_x', type='slide', pos=[0,0,0], axis=[1,0,0]) 75 | particle.joint(name='ball_y', type='slide', pos=[0,0,0], axis=[0,1,0]) 76 | 77 | pos = np.array([0.0,0,0]) 78 | if direction == EAST or direction == WEST: 79 | pos[0] = length-0.1 80 | else: 81 | pos[1] = length-0.1 82 | if direction == WEST or direction == SOUTH: 83 | pos = -pos 84 | 85 | target = worldbody.body(name='target', pos=pos) 86 | target.geom(name='target_geom', conaffinity=2, type='sphere', size=0.02, rgba=[0,0.9,0.1,1]) 87 | 88 | # arena 89 | if direction == EAST: 90 | L = -0.1 91 | R = length 92 | U = 0.1 93 | D = -0.1 94 | elif direction == WEST: 95 | L = -length 96 | R = 0.1 97 | U = 0.1 98 | D = -0.1 99 | elif direction == SOUTH: 100 | L = -0.1 101 | R = 0.1 102 | U = 0.1 103 | D = -length 104 | elif direction == NORTH: 105 | L = -0.1 106 | R = 0.1 107 | U = length 108 | D = -0.1 109 | 110 | worldbody.geom(conaffinity=1, fromto=[L, D, .01, R, D, .01], name="sideS", rgba="0.9 0.4 0.6 1", size=.02, type="capsule") 111 | worldbody.geom(conaffinity=1, fromto=[R, D, .01, R, U, .01], name="sideE", rgba="0.9 0.4 0.6 1", size=".02", type="capsule") 112 | worldbody.geom(conaffinity=1, fromto=[L, U, .01, R, U, .01], name="sideN", rgba="0.9 0.4 0.6 1", size=".02", type="capsule") 113 | worldbody.geom(conaffinity=1, fromto=[L, D, .01, L, U, .01], name="sideW", rgba="0.9 0.4 0.6 1", size=".02", type="capsule") 114 | 115 | actuator = mjcmodel.root.actuator() 116 | actuator.motor(joint="ball_x", ctrlrange=[-1.0, 1.0], ctrllimited=True) 117 | actuator.motor(joint="ball_y", ctrlrange=[-1.0, 1.0], ctrllimited=True) 118 | 119 | return mjcmodel 120 | 121 | 122 | LEFT = 0 123 | RIGHT = 1 124 | 125 | def point_mass_maze(direction=RIGHT, length=1.2, borders=True): 126 | mjcmodel = MJCModel('twod_maze') 127 | mjcmodel.root.compiler(inertiafromgeom="true", angle="radian", coordinate="local") 128 | mjcmodel.root.option(timestep="0.01", gravity="0 0 0", iterations="20", integrator="Euler") 129 | default = mjcmodel.root.default() 130 | default.joint(damping=1, limited='false') 131 | default.geom(friction=".5 .1 .1", 
density="1000", margin="0.002", condim="1", contype="2", conaffinity="1") 132 | 133 | worldbody = mjcmodel.root.worldbody() 134 | 135 | particle = worldbody.body(name='particle', pos=[length/2,0,0]) 136 | particle.geom(name='particle_geom', type='sphere', size='0.03', rgba='0.0 0.0 1.0 1', contype=1) 137 | particle.site(name='particle_site', pos=[0,0,0], size=0.01) 138 | particle.joint(name='ball_x', type='slide', pos=[0,0,0], axis=[1,0,0]) 139 | particle.joint(name='ball_y', type='slide', pos=[0,0,0], axis=[0,1,0]) 140 | 141 | target = worldbody.body(name='target', pos=[length/2,length-0.1,0]) 142 | target.geom(name='target_geom', conaffinity=2, type='sphere', size=0.02, rgba=[0,0.9,0.1,1]) 143 | 144 | L = -0.1 145 | R = length 146 | U = length 147 | D = -0.1 148 | 149 | if borders: 150 | worldbody.geom(conaffinity=1, fromto=[L, D, .01, R, D, .01], name="sideS", rgba="0.9 0.4 0.6 1", size=".02", type="capsule") 151 | worldbody.geom(conaffinity=1, fromto=[R, D, .01, R, U, .01], name="sideE", rgba="0.9 0.4 0.6 1", size=".02", type="capsule") 152 | worldbody.geom(conaffinity=1, fromto=[L, U, .01, R, U, .01], name="sideN", rgba="0.9 0.4 0.6 1", size=".02", type="capsule") 153 | worldbody.geom(conaffinity=1, fromto=[L, D, .01, L, U, .01], name="sideW", rgba="0.9 0.4 0.6 1", size=".02", type="capsule") 154 | 155 | # arena 156 | if direction == LEFT: 157 | BL = -0.1 158 | BR = length * 2/3 159 | BH = length/2 160 | else: 161 | BL = length * 1/3 162 | BR = length 163 | BH = length/2 164 | 165 | worldbody.geom(conaffinity=1, fromto=[BL, BH, .01, BR, BH, .01], name="barrier", rgba="0.9 0.4 0.6 1", size=".02", type="capsule") 166 | 167 | actuator = mjcmodel.root.actuator() 168 | actuator.motor(joint="ball_x", ctrlrange=[-1.0, 1.0], ctrllimited=True) 169 | actuator.motor(joint="ball_y", ctrlrange=[-1.0, 1.0], ctrllimited=True) 170 | 171 | return mjcmodel 172 | 173 | 174 | def ant_maze(direction=RIGHT, length=6.0): 175 | mjcmodel = MJCModel('ant_maze') 176 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local") 177 | mjcmodel.root.option(timestep="0.01", gravity="0 0 -9.8", iterations="20", integrator="Euler") 178 | 179 | assets = mjcmodel.root.asset() 180 | assets.texture(builtin="gradient", height="100", rgb1="1 1 1", rgb2="0 0 0", type="skybox", width="100") 181 | assets.texture(builtin="flat", height="1278", mark="cross", markrgb="1 1 1", name="texgeom", random="0.01", rgb1="0.8 0.6 0.4", rgb2="0.8 0.6 0.4", type="cube", width="127") 182 | assets.texture(builtin="checker", height="100", name="texplane", rgb1="0 0 0", rgb2="0.8 0.8 0.8", type="2d", width="100") 183 | assets.material(name="MatPlane", reflectance="0.5", shininess="1", specular="1", texrepeat="60 60", texture="texplane") 184 | assets.material(name="geom", texture="texgeom", texuniform="true") 185 | 186 | default = mjcmodel.root.default() 187 | default.joint(armature="1", damping=1, limited='true') 188 | default.geom(friction="1 0.5 0.5", density="5.0", margin="0.01", condim="3", conaffinity="0") 189 | 190 | worldbody = mjcmodel.root.worldbody() 191 | 192 | ant = worldbody.body(name='ant', pos=[length/2, 1.0, 0.05]) 193 | ant.geom(name='torso_geom', pos=[0, 0, 0], size="0.25", type="sphere") 194 | ant.joint(armature="0", damping="0", limited="false", margin="0.01", name="root", pos=[0, 0, 0], type="free") 195 | 196 | front_left_leg = ant.body(name="front_left_leg", pos=[0, 0, 0]) 197 | front_left_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="aux_1_geom", size="0.08", type="capsule") 198 | aux_1 = 
front_left_leg.body(name="aux_1", pos=[0.2, 0.2, 0]) 199 | aux_1.joint(axis=[0, 0, 1], name="hip_1", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 200 | aux_1.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="left_leg_geom", size="0.08", type="capsule") 201 | ankle_1 = aux_1.body(pos=[0.2, 0.2, 0]) 202 | ankle_1.joint(axis=[-1, 1, 0], name="ankle_1", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge") 203 | ankle_1.geom(fromto=[0.0, 0.0, 0.0, 0.4, 0.4, 0.0], name="left_ankle_geom", size="0.08", type="capsule") 204 | 205 | front_right_leg = ant.body(name="front_right_leg", pos=[0, 0, 0]) 206 | front_right_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="aux_2_geom", size="0.08", type="capsule") 207 | aux_2 = front_right_leg.body(name="aux_2", pos=[-0.2, 0.2, 0]) 208 | aux_2.joint(axis=[0, 0, 1], name="hip_2", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 209 | aux_2.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="right_leg_geom", size="0.08", type="capsule") 210 | ankle_2 = aux_2.body(pos=[-0.2, 0.2, 0]) 211 | ankle_2.joint(axis=[1, 1, 0], name="ankle_2", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge") 212 | ankle_2.geom(fromto=[0.0, 0.0, 0.0, -0.4, 0.4, 0.0], name="right_ankle_geom", size="0.08", type="capsule") 213 | 214 | back_left_leg = ant.body(name="back_left_leg", pos=[0, 0, 0]) 215 | back_left_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="aux_3_geom", size="0.08", type="capsule") 216 | aux_3 = back_left_leg.body(name="aux_3", pos=[-0.2, -0.2, 0]) 217 | aux_3.joint(axis=[0, 0, 1], name="hip_3", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 218 | aux_3.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="backleft_leg_geom", size="0.08", type="capsule") 219 | ankle_3 = aux_3.body(pos=[-0.2, -0.2, 0]) 220 | ankle_3.joint(axis=[-1, 1, 0], name="ankle_3", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge") 221 | ankle_3.geom(fromto=[0.0, 0.0, 0.0, -0.4, -0.4, 0.0], name="backleft_ankle_geom", size="0.08", type="capsule") 222 | 223 | back_right_leg = ant.body(name="back_right_leg", pos=[0, 0, 0]) 224 | back_right_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="aux_4_geom", size="0.08", type="capsule") 225 | aux_4 = back_right_leg.body(name="aux_4", pos=[0.2, -0.2, 0]) 226 | aux_4.joint(axis=[0, 0, 1], name="hip_4", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 227 | aux_4.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="backright_leg_geom", size="0.08", type="capsule") 228 | ankle_4 = aux_4.body(pos=[0.2, -0.2, 0]) 229 | ankle_4.joint(axis=[1, 1, 0], name="ankle_4", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge") 230 | ankle_4.geom(fromto=[0.0, 0.0, 0.0, 0.4, -0.4, 0.0], name="backright_ankle_geom", size="0.08", type="capsule") 231 | 232 | target = worldbody.body(name='target', pos=[length/2,length-0.2,-0.5]) 233 | target.geom(name='target_geom', conaffinity=2, type='sphere', size=0.2, rgba=[0,0.9,0.1,1]) 234 | 235 | l = length/2 236 | h = 0.75 237 | w = 0.05 238 | 239 | worldbody.geom(conaffinity=1, name="sideS", rgba="0.9 0.4 0.6 1", size=[l, w, h], pos=[length/2, 0, 0], type="box") 240 | worldbody.geom(conaffinity=1, name="sideE", rgba="0.9 0.4 0.6 1", size=[w, l, h], pos=[length, length/2, 0], type="box") 241 | worldbody.geom(conaffinity=1, name="sideN", rgba="0.9 0.4 0.6 1", size=[l, w, h], pos=[length/2, length, 0], type="box") 242 | worldbody.geom(conaffinity=1, name="sideW", rgba="0.9 0.4 0.6 1", size=[w, l, h], pos=[0, length/2, 0], type="box") 243 | 244 | # arena 245 | if direction == LEFT: 246 | bx, by, bz = (length/3, 
length/2, 0) 247 | else: 248 | bx, by, bz = (length*2/3, length/2, 0) 249 | 250 | worldbody.geom(conaffinity=1, name="barrier", rgba="0.9 0.4 0.6 1", size=[l * 2/3, w, h], pos=[bx, by, bz], type="box") 251 | worldbody.geom(conaffinity="1", condim="3", material="MatPlane", name="floor", pos=[length/2, length/2, -h + w], 252 | rgba="0.8 0.9 0.8 1", size="40 40 40", type="plane") 253 | 254 | actuator = mjcmodel.root.actuator() 255 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_4", gear="50") 256 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_4", gear="50") 257 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_1", gear="50") 258 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_1", gear="50") 259 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_2", gear="50") 260 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_2", gear="50") 261 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_3", gear="50") 262 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_3", gear="50") 263 | 264 | return mjcmodel 265 | 266 | 267 | def ant_maze_corridor(direction=RIGHT, height=6.0, width=10.0): 268 | mjcmodel = MJCModel('ant_maze_corridor') 269 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local") 270 | mjcmodel.root.option(timestep="0.01", gravity="0 0 -9.8", iterations="20", integrator="Euler") 271 | 272 | assets = mjcmodel.root.asset() 273 | assets.texture(builtin="gradient", height="100", rgb1="1 1 1", rgb2="0 0 0", type="skybox", width="100") 274 | assets.texture(builtin="flat", height="1278", mark="cross", markrgb="1 1 1", name="texgeom", random="0.01", rgb1="0.8 0.6 0.4", rgb2="0.8 0.6 0.4", type="cube", width="127") 275 | assets.texture(builtin="checker", height="100", name="texplane", rgb1="0 0 0", rgb2="0.8 0.8 0.8", type="2d", width="100") 276 | assets.material(name="MatPlane", reflectance="0.5", shininess="1", specular="1", texrepeat="60 60", texture="texplane") 277 | assets.material(name="geom", texture="texgeom", texuniform="true") 278 | 279 | default = mjcmodel.root.default() 280 | default.joint(armature="1", damping=1, limited='true') 281 | default.geom(friction="1 0.5 0.5", density="5.0", margin="0.01", condim="3", conaffinity="0") 282 | 283 | worldbody = mjcmodel.root.worldbody() 284 | 285 | ant = worldbody.body(name='ant', pos=[height/2, 1.0, 0.05]) 286 | ant.geom(name='torso_geom', pos=[0, 0, 0], size="0.25", type="sphere") 287 | ant.joint(armature="0", damping="0", limited="false", margin="0.01", name="root", pos=[0, 0, 0], type="free") 288 | 289 | front_left_leg = ant.body(name="front_left_leg", pos=[0, 0, 0]) 290 | front_left_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="aux_1_geom", size="0.08", type="capsule") 291 | aux_1 = front_left_leg.body(name="aux_1", pos=[0.2, 0.2, 0]) 292 | aux_1.joint(axis=[0, 0, 1], name="hip_1", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 293 | aux_1.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="left_leg_geom", size="0.08", type="capsule") 294 | ankle_1 = aux_1.body(pos=[0.2, 0.2, 0]) 295 | ankle_1.joint(axis=[-1, 1, 0], name="ankle_1", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge") 296 | ankle_1.geom(fromto=[0.0, 0.0, 0.0, 0.4, 0.4, 0.0], name="left_ankle_geom", size="0.08", type="capsule") 297 | 298 | front_right_leg = ant.body(name="front_right_leg", pos=[0, 0, 0]) 299 | front_right_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], 
name="aux_2_geom", size="0.08", type="capsule") 300 | aux_2 = front_right_leg.body(name="aux_2", pos=[-0.2, 0.2, 0]) 301 | aux_2.joint(axis=[0, 0, 1], name="hip_2", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 302 | aux_2.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="right_leg_geom", size="0.08", type="capsule") 303 | ankle_2 = aux_2.body(pos=[-0.2, 0.2, 0]) 304 | ankle_2.joint(axis=[1, 1, 0], name="ankle_2", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge") 305 | ankle_2.geom(fromto=[0.0, 0.0, 0.0, -0.4, 0.4, 0.0], name="right_ankle_geom", size="0.08", type="capsule") 306 | 307 | back_left_leg = ant.body(name="back_left_leg", pos=[0, 0, 0]) 308 | back_left_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="aux_3_geom", size="0.08", type="capsule") 309 | aux_3 = back_left_leg.body(name="aux_3", pos=[-0.2, -0.2, 0]) 310 | aux_3.joint(axis=[0, 0, 1], name="hip_3", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 311 | aux_3.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="backleft_leg_geom", size="0.08", type="capsule") 312 | ankle_3 = aux_3.body(pos=[-0.2, -0.2, 0]) 313 | ankle_3.joint(axis=[-1, 1, 0], name="ankle_3", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge") 314 | ankle_3.geom(fromto=[0.0, 0.0, 0.0, -0.4, -0.4, 0.0], name="backleft_ankle_geom", size="0.08", type="capsule") 315 | 316 | back_right_leg = ant.body(name="back_right_leg", pos=[0, 0, 0]) 317 | back_right_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="aux_4_geom", size="0.08", type="capsule") 318 | aux_4 = back_right_leg.body(name="aux_4", pos=[0.2, -0.2, 0]) 319 | aux_4.joint(axis=[0, 0, 1], name="hip_4", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge") 320 | aux_4.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="backright_leg_geom", size="0.08", type="capsule") 321 | ankle_4 = aux_4.body(pos=[0.2, -0.2, 0]) 322 | ankle_4.joint(axis=[1, 1, 0], name="ankle_4", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge") 323 | ankle_4.geom(fromto=[0.0, 0.0, 0.0, 0.4, -0.4, 0.0], name="backright_ankle_geom", size="0.08", type="capsule") 324 | 325 | target = worldbody.body(name='target', pos=[height/2, width-1.0,-0.5]) 326 | target.geom(name='target_geom', conaffinity=2, type='sphere', size=0.2, rgba=[0,0.9,0.1,1]) 327 | 328 | l = height/2 329 | h = 0.75 330 | w = 0.05 331 | 332 | worldbody.geom(conaffinity=1, name="sideS", rgba="0.9 0.4 0.6 1", size=[l, w, h], pos=[height/2, 0, 0], type="box") 333 | worldbody.geom(conaffinity=1, name="sideE", rgba="0.9 0.4 0.6 1", size=[w, width/2, h], pos=[height, width/2, 0], type="box") 334 | worldbody.geom(conaffinity=1, name="sideN", rgba="0.9 0.4 0.6 1", size=[l, w, h], pos=[height/2, width, 0], type="box") 335 | worldbody.geom(conaffinity=1, name="sideW", rgba="0.9 0.4 0.6 1", size=[w, width/2, h], pos=[0, width/2, 0], type="box") 336 | 337 | # arena 338 | wall_ratio = .55#2.0/3 339 | if direction == LEFT: 340 | bx, by, bz = (height*(wall_ratio/2), width/2, 0) 341 | #bx, by, bz = (height/4, width/2, 0) 342 | # bx, by, bz = length * 5/3, length * 5/6 + w, 0 343 | # bx1, by1, bz1 = bx - length/12, by-l/6, bz 344 | else: 345 | bx, by, bz = (height*(1-wall_ratio/2), width/2, 0) 346 | #bx, by, bz = (height*(3/4), width/2, 0) 347 | # bx, by, bz = length / 3, length * 5/6 + w, 0 348 | # bx1, by1, bz1 = bx + length/12, by-l/6, bz 349 | 350 | worldbody.geom(conaffinity=1, name="barrier", rgba="0.9 0.4 0.6 1", size=[l*(wall_ratio), w, h], pos=[bx, by, bz], type="box") 351 | # worldbody.geom(conaffinity=1, name="barrier1", rgba="0.9 0.4 0.6 1", size=[w, l/2 - 2*w, 
h], pos=[length/2, length/2, bz], type="box") 352 | # worldbody.geom(conaffinity=1, name="barrier2", rgba="0.9 0.4 0.6 1", size=[l/6, w, h], pos=[bx1, by1, bz1], type="box") 353 | # worldbody.geom(conaffinity=1, name="barrier3", rgba="0.9 0.4 0.6 1", size=[w, l/6, h], pos=[bx, by, bz], type="box") 354 | # worldbody.geom(conaffinity=1, condim=3, name="floor", rgba="0.4 0.4 0.4 1", size=[l, l, w], pos=[length/2, length/2, -h], 355 | # type="box") 356 | worldbody.geom(conaffinity="1", condim="3", material="MatPlane", name="floor", pos=[height/2, height/2, -h + w], 357 | rgba="0.8 0.9 0.8 1", size="40 40 40", type="plane") 358 | 359 | actuator = mjcmodel.root.actuator() 360 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_4", gear="30") 361 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_4", gear="30") 362 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_1", gear="30") 363 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_1", gear="30") 364 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_2", gear="30") 365 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_2", gear="30") 366 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_3", gear="30") 367 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_3", gear="30") 368 | 369 | return mjcmodel 370 | 371 | 372 | def pusher(goal_pos=np.array([0.45, -0.05, -0.323])): 373 | mjcmodel = MJCModel('pusher') 374 | mjcmodel.root.compiler(inertiafromgeom="true", angle="radian", coordinate="local") 375 | mjcmodel.root.option(timestep="0.01", gravity="0 0 0", iterations="20", integrator="Euler") 376 | default = mjcmodel.root.default() 377 | default.joint(armature=0.04, damping=1, limited=False) 378 | default.geom(friction=[.8, .1, .1], density=300, margin=0.002, condim=1, contype=0, conaffinity=0) 379 | 380 | worldbody = mjcmodel.root.worldbody() 381 | worldbody.light(diffuse=[.5,.5,.5], pos=[0,0,3], dir=[0,0,-1]) 382 | worldbody.geom(name='table', type='plane', pos=[0,.5,-.325],size=[1,1,0.1], contype=1, conaffinity=1) 383 | 384 | r_shoulder_pan_link = worldbody.body(name='r_shoulder_pan_link', pos=[0,-.6,0]) 385 | r_shoulder_pan_link.geom(name='e1', type='sphere', rgba=[.6,.6,.6,1], pos=[-0.06,0.05,0.2], size=0.05) 386 | r_shoulder_pan_link.geom(name='e2', type='sphere', rgba=[.6,.6,.6,1], pos=[0.06,0.05,0.2], size=0.05) 387 | r_shoulder_pan_link.geom(name='e1p', type='sphere', rgba=[.1,.1,.1,1], pos=[-0.06,0.09,0.2], size=0.03) 388 | r_shoulder_pan_link.geom(name='e2p', type='sphere', rgba=[.1,.1,.1,1], pos=[0.06,0.09,0.2], size=0.03) 389 | r_shoulder_pan_link.geom(name='sp', type='capsule', fromto=[0,0,-0.4,0,0,0.2], size=0.1) 390 | 391 | r_shoulder_pan_link.joint(name='r_shoulder_pan_joint', type='hinge', pos=[0,0,0], axis=[0,0,1], 392 | range=[-2.2854, 1.714602], damping=1.0) 393 | 394 | r_shoulder_lift_link = r_shoulder_pan_link.body(name='r_shoulder_lift_link', pos=[0.1,0,0]) 395 | r_shoulder_lift_link.geom(name='s1', type='capsule', fromto="0 -0.1 0 0 0.1 0", size="0.1") 396 | r_shoulder_lift_link.joint(name="r_shoulder_lift_joint", type="hinge", pos="0 0 0", axis="0 1 0", 397 | range="-0.5236 1.3963", damping="1.0") 398 | 399 | r_upper_arm_roll_link = r_shoulder_lift_link.body(name='r_upper_arm_roll_link', pos=[0,0,0]) 400 | r_upper_arm_roll_link.geom(name="uar", type="capsule", fromto="-0.1 0 0 0.1 0 0", size="0.02") 401 | r_upper_arm_roll_link.joint(name="r_upper_arm_roll_joint", 
type="hinge", pos="0 0 0", axis="1 0 0", 402 | range="-1.5 1.7", damping="0.1") 403 | 404 | r_upper_arm_link = r_upper_arm_roll_link.body(name='r_upper_arm_link', pos=[0,0,0]) 405 | r_upper_arm_link.geom(name="ua", type="capsule", fromto="0 0 0 0.4 0 0", size="0.06") 406 | 407 | r_elbow_flex_link = r_upper_arm_link.body(name='r_elbow_flex_link', pos=[0.4,0,0]) 408 | r_elbow_flex_link.geom(name="ef", type="capsule", fromto="0 -0.02 0 0.0 0.02 0", size="0.06") 409 | r_elbow_flex_link.joint(name="r_elbow_flex_joint", type="hinge", pos="0 0 0", axis="0 1 0", range="-2.3213 0", 410 | damping="0.1") 411 | 412 | r_forearm_roll_link = r_elbow_flex_link.body(name='r_forearm_roll_link', pos=[0,0,0]) 413 | r_forearm_roll_link.geom(name="fr", type="capsule", fromto="-0.1 0 0 0.1 0 0", size="0.02") 414 | r_forearm_roll_link.joint(name="r_forearm_roll_joint", type="hinge", limited="true", pos="0 0 0", 415 | axis="1 0 0", damping=".1", range="-1.5 1.5") 416 | 417 | r_forearm_link = r_forearm_roll_link.body(name='r_forearm_link', pos=[0,0,0]) 418 | r_forearm_link.geom(name="fa", type="capsule", fromto="0 0 0 0.291 0 0", size="0.05") 419 | 420 | r_wrist_flex_link = r_forearm_link.body(name='r_wrist_flex_link', pos=[0.321,0,0]) 421 | r_wrist_flex_link.geom(name="wf", type="capsule", fromto="0 -0.02 0 0 0.02 0", size="0.01" ) 422 | r_wrist_flex_link.joint(name="r_wrist_flex_joint", type="hinge", pos="0 0 0", axis="0 1 0", 423 | range="-1.094 0", damping=".1") 424 | 425 | r_wrist_roll_link = r_wrist_flex_link.body(name='r_wrist_roll_link', pos=[0,0,0]) 426 | r_wrist_roll_link.joint(name="r_wrist_roll_joint", type="hinge", pos="0 0 0", limited="true", axis="1 0 0", 427 | damping="0.1", range="-1.5 1.5") 428 | r_wrist_roll_link.geom(type="capsule",fromto="0 -0.1 0. 0.0 +0.1 0",size="0.02",contype="1",conaffinity="1") 429 | r_wrist_roll_link.geom(type="capsule",fromto="0 -0.1 0. 0.1 -0.1 0",size="0.02",contype="1",conaffinity="1") 430 | r_wrist_roll_link.geom(type="capsule",fromto="0 +0.1 0. 
0.1 +0.1 0",size="0.02",contype="1",conaffinity="1") 431 | 432 | tips_arm = r_wrist_roll_link.body(name='tips_arm', pos=[0,0,0]) 433 | tips_arm.geom(name="tip_arml",type="sphere",pos="0.1 -0.1 0.",size="0.01") 434 | tips_arm.geom(name="tip_armr",type="sphere",pos="0.1 0.1 0.",size="0.01") 435 | 436 | #object_ = worldbody.body(name="object", pos=[0.45, -0.05, -0.275]) 437 | object_ = worldbody.body(name="object", pos=[0.0, 0.0, -0.275]) 438 | #object_.geom(rgba="1 1 1 0",type="sphere",size="0.05 0.05 0.05",density="0.00001",conaffinity="0") 439 | object_.geom(rgba="1 1 1 1",type="cylinder",size="0.05 0.05 0.05",density="0.00001",conaffinity="0", contype=1) 440 | object_.joint(name="obj_slidey",type="slide",pos="0 0 0",axis="0 1 0",range="-10.3213 10.3",damping="0.5") 441 | object_.joint(name="obj_slidex",type="slide",pos="0 0 0",axis="1 0 0",range="-10.3213 10.3",damping="0.5") 442 | 443 | goal = worldbody.body(name='goal', pos=goal_pos) 444 | goal.geom(rgba="1 0 0 1",type="cylinder",size="0.08 0.001 0.1",density='0.00001',contype="0",conaffinity="0") 445 | goal.joint(name="goal_slidey",type="slide",pos="0 0 0",axis="0 1 0",range="-10.3213 10.3",damping="0.5") 446 | goal.joint(name="goal_slidex",type="slide",pos="0 0 0",axis="1 0 0",range="-10.3213 10.3",damping="0.5") 447 | 448 | actuator = mjcmodel.root.actuator() 449 | actuator.motor(joint="r_shoulder_pan_joint",ctrlrange=[-2.0,2.0], ctrllimited=True) 450 | actuator.motor(joint="r_shoulder_lift_joint",ctrlrange=[-2.0,2.0], ctrllimited=True) 451 | actuator.motor(joint="r_upper_arm_roll_joint",ctrlrange=[-2.0,2.0], ctrllimited=True) 452 | actuator.motor(joint="r_elbow_flex_joint",ctrlrange=[-2.0,2.0], ctrllimited=True) 453 | actuator.motor(joint="r_forearm_roll_joint",ctrlrange=[-2.0,2.0], ctrllimited=True) 454 | actuator.motor(joint="r_wrist_flex_joint",ctrlrange=[-2.0,2.0], ctrllimited=True) 455 | actuator.motor(joint="r_wrist_roll_joint",ctrlrange=[-2.0,2.0], ctrllimited=True) 456 | 457 | return mjcmodel 458 | 459 | 460 | def swimmer(): 461 | mjcmodel = MJCModel('swimmer') 462 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local") 463 | mjcmodel.root.option(timestep=0.01, viscosity=0.1, density=4000, integrator="RK4", collision="predefined") 464 | default = mjcmodel.root.default() 465 | default.joint(armature=0.1) 466 | default.geom(rgba=[0.8, .6, .1, 1], condim=1, contype=1, conaffinity=1, material='geom') 467 | 468 | asset = mjcmodel.root.asset() 469 | asset.texture(builtin='gradient', height=100, rgb1=[1,1,1], rgb2=[0,0,0], type='skybox', width=100) 470 | asset.texture(builtin='flat', height=1278, mark='cross', markrgb=[1,1,1], name='texgeom', 471 | random=0.01, rgb1=[0.8,0.6,0.4], rgb2=[0.8,0.6,0.4], type='cube', width=127) 472 | asset.texture(builtin='checker', height=100, name='texplane',rgb1=[0,0,0], rgb2=[0.8,0.8,0.8], type='2d', width=100) 473 | asset.material(name='MatPlane', reflectance=0.5, shininess=1, specular=1, texrepeat=[30,30], texture='texplane') 474 | asset.material(name='geom', texture='texgeom', texuniform=True) 475 | 476 | worldbody = mjcmodel.root.worldbody() 477 | worldbody.light(cutoff=100, diffuse=[1,1,1], dir=[0,0,-1.3], directional=True, exponent=1, pos=[0,0,1.3], specular=[.1,.1,.1]) 478 | worldbody.geom(conaffinity=1, condim=3, material='MatPlane', name='floor', pos=[0,0,-0.1], rgba=[0.8,0.9,0.9,1], size=[40,40,0.1], type='plane') 479 | torso = worldbody.body(name='torso', pos=[0,0,0]) 480 | torso.geom(density=1000, fromto=[1.5,0,0,0.5,0,0], size=0.1, type='capsule') 481 | 
torso.joint(axis=[1,0,0], name='slider1', pos=[0,0,0], type='slide') 482 | torso.joint(axis=[0,1,0], name='slider2', pos=[0,0,0], type='slide') 483 | torso.joint(axis=[0,0,1], name='rot', pos=[0,0,0], type='hinge') 484 | mid = torso.body(name='mid', pos=[0.5,0,0]) 485 | mid.geom(density=1000, fromto=[0,0,0,-1,0,0], size=0.1, type='capsule') 486 | mid.joint(axis=[0,0,1], limited=True, name='rot2', pos=[0,0,0], range=[-100,100], type='hinge') 487 | back = mid.body(name='back', pos=[-1,0,0]) 488 | back.geom(density=1000, fromto=[0,0,0,-1,0,0], size=0.1, type='capsule') 489 | back.joint(axis=[0,0,1], limited=True, name='rot3', pos=[0,0,0], range=[-100,100], type='hinge') 490 | 491 | actuator = mjcmodel.root.actuator() 492 | actuator.motor(ctrllimited=True, ctrlrange=[-1,1], gear=150, joint='rot2') 493 | actuator.motor(ctrllimited=True, ctrlrange=[-1,1], gear=150, joint='rot3') 494 | return mjcmodel 495 | 496 | def swimmer_rllab(): 497 | mjcmodel = MJCModel('swimmer') 498 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local") 499 | mjcmodel.root.option(timestep=0.01, viscosity=0.1, density=4000, integrator="Euler", iterations=1000, collision="predefined") 500 | 501 | custom = mjcmodel.root.custom() 502 | custom.numeric(name='frame_skip', data=50) 503 | 504 | default = mjcmodel.root.default() 505 | #default.joint(armature=0.1) 506 | default.geom(rgba=[0.8, .6, .1, 1], condim=1, contype=1, conaffinity=1, material='geom') 507 | 508 | asset = mjcmodel.root.asset() 509 | asset.texture(builtin='gradient', height=100, rgb1=[1,1,1], rgb2=[0,0,0], type='skybox', width=100) 510 | asset.texture(builtin='flat', height=1278, mark='cross', markrgb=[1,1,1], name='texgeom', 511 | random=0.01, rgb1=[0.8,0.6,0.4], rgb2=[0.8,0.6,0.4], type='cube', width=127) 512 | asset.texture(builtin='checker', height=100, name='texplane',rgb1=[0,0,0], rgb2=[0.8,0.8,0.8], type='2d', width=100) 513 | asset.material(name='MatPlane', reflectance=0.5, shininess=1, specular=1, texrepeat=[30,30], texture='texplane') 514 | asset.material(name='geom', texture='texgeom', texuniform=True) 515 | 516 | worldbody = mjcmodel.root.worldbody() 517 | worldbody.light(cutoff=100, diffuse=[1,1,1], dir=[0,0,-1.3], directional=True, exponent=1, pos=[0,0,1.3], specular=[.1,.1,.1]) 518 | worldbody.geom(conaffinity=1, condim=3, material='MatPlane', name='floor', pos=[0,0,-0.1], rgba=[0.8,0.9,0.9,1], size=[40,40,0.1], type='plane') 519 | torso = worldbody.body(name='torso', pos=[0,0,0]) 520 | torso.geom(density=1000, fromto=[1.5,0,0,0.5,0,0], size=0.1, type='capsule') 521 | torso.joint(axis=[1,0,0], name='slider1', pos=[0,0,0], type='slide') 522 | torso.joint(axis=[0,1,0], name='slider2', pos=[0,0,0], type='slide') 523 | torso.joint(axis=[0,0,1], name='rot', pos=[0,0,0], type='hinge') 524 | mid = torso.body(name='mid', pos=[0.5,0,0]) 525 | mid.geom(density=1000, fromto=[0,0,0,-1,0,0], size=0.1, type='capsule') 526 | mid.joint(axis=[0,0,1], limited=True, name='rot2', pos=[0,0,0], range=[-100,100], type='hinge') 527 | back = mid.body(name='back', pos=[-1,0,0]) 528 | back.geom(density=1000, fromto=[0,0,0,-1,0,0], size=0.1, type='capsule') 529 | back.joint(axis=[0,0,1], limited=True, name='rot3', pos=[0,0,0], range=[-100,100], type='hinge') 530 | 531 | actuator = mjcmodel.root.actuator() 532 | actuator.motor(ctrllimited=True, ctrlrange=[-50,50], joint='rot2') 533 | actuator.motor(ctrllimited=True, ctrlrange=[-50,50], joint='rot3') 534 | return mjcmodel 535 | 
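All of the model generators above return an `MJCModel` object rather than a path to a static asset; the environments write the model out through `MJCModel.asfile()` (defined in `model_builder.py` below) and hand the resulting temporary XML file to `mujoco_env.MujocoEnv`. A minimal sketch of that pattern, assuming only this package and numpy are installed (generating and printing the XML does not require MuJoCo itself):

```python
from inverse_rl.envs.dynamic_mjc.mjc_models import point_mass_maze, RIGHT

# Build a parameterized maze model in memory.
model = point_mass_maze(direction=RIGHT, length=1.2)

# MJCModel.asfile() writes the XML to a temporary file that only lives
# inside the context manager; the envs pass f.name to MujocoEnv.__init__.
with model.asfile() as f:
    print(f.read())  # inspect the generated MuJoCo XML
```

Generating the XML on the fly is what lets functions such as `angry_ant_crippled(gear=...)` and `point_mass_maze(direction=...)` vary gear ratios and maze layouts per experiment instead of maintaining near-duplicate static asset files.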
-------------------------------------------------------------------------------- /inverse_rl/envs/dynamic_mjc/model_builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | model_builder.py 3 | A small library for programatically building MuJoCo XML files 4 | """ 5 | from contextlib import contextmanager 6 | import tempfile 7 | import numpy as np 8 | 9 | 10 | def default_model(name): 11 | """ 12 | Get a model with basic settings such as gravity and RK4 integration enabled 13 | """ 14 | model = MJCModel(name) 15 | root = model.root 16 | 17 | # Setup 18 | root.compiler(angle="radian", inertiafromgeom="true") 19 | default = root.default() 20 | default.joint(armature=1, damping=1, limited="true") 21 | default.geom(contype=0, friction='1 0.1 0.1', rgba='0.7 0.7 0 1') 22 | root.option(gravity="0 0 -9.81", integrator="RK4", timestep=0.01) 23 | return model 24 | 25 | def pointmass_model(name): 26 | """ 27 | Get a model with basic settings such as gravity and Euler integration enabled 28 | """ 29 | model = MJCModel(name) 30 | root = model.root 31 | 32 | # Setup 33 | root.compiler(angle="radian", inertiafromgeom="true", coordinate="local") 34 | default = root.default() 35 | default.joint(limited="false", damping=1) 36 | default.geom(contype=2, conaffinity="1", condim="1", friction=".5 .1 .1", density="1000", margin="0.002") 37 | root.option(timestep=0.01, gravity="0 0 0", iterations="20", integrator="Euler") 38 | return model 39 | 40 | 41 | class MJCModel(object): 42 | def __init__(self, name): 43 | self.name = name 44 | self.root = MJCTreeNode("mujoco").add_attr('model', name) 45 | 46 | @contextmanager 47 | def asfile(self): 48 | """ 49 | Usage: 50 | model = MJCModel('reacher') 51 | with model.asfile() as f: 52 | print f.read() # prints a dump of the model 53 | """ 54 | with tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) as f: 55 | self.root.write(f) 56 | f.seek(0) 57 | yield f 58 | 59 | def open(self): 60 | self.file = tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) 61 | self.root.write(self.file) 62 | self.file.seek(0) 63 | return self.file 64 | 65 | def close(self): 66 | self.file.close() 67 | 68 | def find_attr(self, attr, value): 69 | return self.root.find_attr(attr, value) 70 | 71 | def __getstate__(self): 72 | return {} 73 | 74 | def __setstate__(self, state): 75 | pass 76 | 77 | 78 | class MJCTreeNode(object): 79 | def __init__(self, name): 80 | self.name = name 81 | self.attrs = {} 82 | self.children = [] 83 | 84 | def add_attr(self, key, value): 85 | if isinstance(value, str): 86 | pass 87 | elif isinstance(value, list) or isinstance(value, np.ndarray): 88 | value = ' '.join([str(val).lower() for val in value]) 89 | else: 90 | value = str(value).lower() 91 | 92 | self.attrs[key] = value 93 | return self 94 | 95 | def __getattr__(self, name): 96 | def wrapper(**kwargs): 97 | newnode = MJCTreeNode(name) 98 | for (k, v) in kwargs.items(): 99 | newnode.add_attr(k, v) 100 | self.children.append(newnode) 101 | return newnode 102 | return wrapper 103 | 104 | def dfs(self): 105 | yield self 106 | if self.children: 107 | for child in self.children: 108 | for node in child.dfs(): 109 | yield node 110 | 111 | def find_attr(self, attr, value): 112 | """ Run DFS to find a matching attr """ 113 | if attr in self.attrs and self.attrs[attr] == value: 114 | return self 115 | for child in self.children: 116 | res = child.find_attr(attr, value) 117 | if res is not None: 118 | return res 119 | return None 120 | 121 | 122 | 
def write(self, ostream, tabs=0): 123 | contents = ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()]) 124 | if self.children: 125 | ostream.write('\t'*tabs) 126 | ostream.write('<%s %s>\n' % (self.name, contents)) 127 | for child in self.children: 128 | child.write(ostream, tabs=tabs+1) 129 | ostream.write('\t'*tabs) 130 | ostream.write('</%s>\n' % self.name) 131 | else: 132 | ostream.write('\t'*tabs) 133 | ostream.write('<%s %s/>\n' % (self.name, contents)) 134 | 135 | def __str__(self): 136 | s = "<"+self.name 137 | s += ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()]) 138 | return s+">" -------------------------------------------------------------------------------- /inverse_rl/envs/env_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | 4 | import numpy as np 5 | from gym import Wrapper 6 | import gym 7 | 8 | from rllab.core.serializable import Serializable 9 | from rllab.envs.base import Env, Step 10 | from rllab.envs.gym_env import GymEnv, FixedIntervalVideoSchedule, NoVideoSchedule, CappedCubicVideoSchedule, \ 11 | convert_gym_space 12 | from rllab.envs.proxy_env import ProxyEnv 13 | from rllab.misc import logger 14 | from rllab.misc.overrides import overrides 15 | 16 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets') 17 | 18 | 19 | def get_asset_xml(xml_name): 20 | return os.path.join(ENV_ASSET_DIR, xml_name) 21 | 22 | 23 | class RllabGymEnv(Env, Serializable): 24 | def __init__(self, env_name, wrappers=(), wrapper_args=(), 25 | record_video=True, video_schedule=None, log_dir=None, record_log=True, 26 | post_create_env_seed=None, 27 | force_reset=False): 28 | if log_dir is None: 29 | if logger.get_snapshot_dir() is None: 30 | logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.") 31 | else: 32 | log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log") 33 | Serializable.quick_init(self, locals()) 34 | 35 | env = gym.envs.make(env_name) 36 | if post_create_env_seed is not None: 37 | env.set_env_seed(post_create_env_seed) 38 | for i, wrapper in enumerate(wrappers): 39 | if wrapper_args and len(wrapper_args) == len(wrappers): 40 | env = wrapper(env, **wrapper_args[i]) 41 | else: 42 | env = wrapper(env) 43 | self.env = env 44 | self.env_id = env.spec.id 45 | 46 | assert not (not record_log and record_video) 47 | 48 | if log_dir is None or record_log is False: 49 | self.monitoring = False 50 | else: 51 | if not record_video: 52 | video_schedule = NoVideoSchedule() 53 | else: 54 | if video_schedule is None: 55 | video_schedule = CappedCubicVideoSchedule() 56 | self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True) 57 | self.monitoring = True 58 | 59 | self._observation_space = convert_gym_space(env.observation_space) 60 | logger.log("observation space: {}".format(self._observation_space)) 61 | self._action_space = convert_gym_space(env.action_space) 62 | logger.log("action space: {}".format(self._action_space)) 63 | self._horizon = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps') 64 | self._log_dir = log_dir 65 | self._force_reset = force_reset 66 | 67 | @property 68 | def observation_space(self): 69 | return self._observation_space 70 | 71 | @property 72 | def action_space(self): 73 | return self._action_space 74 | 75 | @property 76 | def horizon(self): 77 | return self._horizon 78 | 79 | def reset(self): 80 | if self._force_reset and self.monitoring: 81 | from 
gym.wrappers.monitoring import Monitor 82 | assert isinstance(self.env, Monitor) 83 | recorder = self.env.stats_recorder 84 | if recorder is not None: 85 | recorder.done = True 86 | return self.env.reset() 87 | 88 | def step(self, action): 89 | next_obs, reward, done, info = self.env.step(action) 90 | return Step(next_obs, reward, done, **info) 91 | 92 | def render(self): 93 | self.env.render() 94 | 95 | def terminate(self): 96 | if self.monitoring: 97 | self.env._close() 98 | if self._log_dir is not None: 99 | print(""" 100 | *************************** 101 | 102 | Training finished! You can upload results to OpenAI Gym by running the following command: 103 | 104 | python scripts/submit_gym.py %s 105 | 106 | *************************** 107 | """ % self._log_dir) 108 | 109 | class CustomGymEnv(RllabGymEnv): 110 | def __init__(self, env_name, gym_wrappers=(), 111 | register_fn=None, wrapper_args = (), record_log=False, record_video=False, 112 | post_create_env_seed=None): 113 | Serializable.quick_init(self, locals()) 114 | if register_fn is None: 115 | import inverse_rl.envs 116 | register_fn = inverse_rl.envs.register_custom_envs 117 | register_fn() # Force register 118 | self.env_name = env_name 119 | super(CustomGymEnv, self).__init__(env_name, wrappers=gym_wrappers, 120 | wrapper_args=wrapper_args, 121 | record_log=record_log, record_video=record_video, 122 | post_create_env_seed=post_create_env_seed, 123 | video_schedule=FixedIntervalVideoSchedule(50)) 124 | 125 | def _get_obs(self): 126 | return self.env._get_obs() 127 | 128 | @overrides 129 | def log_diagnostics(self, paths): 130 | get_inner_env(self.env).log_diagnostics(paths) 131 | 132 | @overrides 133 | def plot_trajs(self, paths, **kwargs): 134 | if hasattr(self.env, 'plot_trajs'): 135 | self.env.plot_trajs(paths, **kwargs) 136 | else: 137 | raise ValueError('Env %s has no traj plotting'%self.env) 138 | 139 | @overrides 140 | def plot_costs(self, *args, **kwargs): 141 | if hasattr(self.env, 'plot_costs'): 142 | self.env.plot_costs(*args, **kwargs) 143 | else: 144 | raise ValueError('Env %s has no traj plotting'% self.env) 145 | 146 | @property 147 | def wrapped_observation_space(self): 148 | if hasattr(self.env, 'wrapped_observation_space'): 149 | return self.env.wrapped_observation_space 150 | else: 151 | raise AttributeError() 152 | 153 | def get_param_values(self): 154 | return None 155 | 156 | def set_param_values(self, params): 157 | pass 158 | 159 | def get_viewer(self): 160 | return self.env._get_viewer() 161 | 162 | 163 | def get_inner_env(env): 164 | if isinstance(env, ProxyEnv): 165 | return get_inner_env(env.wrapped_env) 166 | elif isinstance(env, GymEnv): 167 | return get_inner_env(env.env) 168 | elif isinstance(env, Wrapper): 169 | return get_inner_env(env.env) 170 | return env 171 | 172 | 173 | def test_env(env, T=100): 174 | aspace = env.action_space 175 | env.reset() 176 | for t in range(T): 177 | o, r, done, infos = env.step(aspace.sample()) 178 | print('---T=%d---' % t) 179 | print('rew:', r) 180 | print('obs:', o) 181 | env.render() 182 | if done: 183 | break 184 | 185 | -------------------------------------------------------------------------------- /inverse_rl/envs/point_maze_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | 5 | from rllab.misc import logger 6 | 7 | from inverse_rl.envs.dynamic_mjc.mjc_models import point_mass_maze 8 | 9 | 10 | class 
PointMazeEnv(mujoco_env.MujocoEnv, utils.EzPickle): 11 | def __init__(self, direction=1, maze_length=0.6, sparse_reward=False, no_reward=False, episode_length=100): 12 | utils.EzPickle.__init__(self) 13 | self.sparse_reward = sparse_reward 14 | self.no_reward = no_reward 15 | self.max_episode_length = episode_length 16 | self.direction = direction 17 | self.length = maze_length 18 | 19 | self.episode_length = 0 20 | 21 | model = point_mass_maze(direction=self.direction, length=self.length) 22 | with model.asfile() as f: 23 | mujoco_env.MujocoEnv.__init__(self, f.name, 5) 24 | 25 | def _step(self, a): 26 | vec_dist = self.get_body_com("particle") - self.get_body_com("target") 27 | 28 | reward_dist = - np.linalg.norm(vec_dist) # particle to target 29 | reward_ctrl = - np.square(a).sum() 30 | if self.no_reward: 31 | reward = 0 32 | elif self.sparse_reward: 33 | if reward_dist <= 0.1: 34 | reward = 1 35 | else: 36 | reward = 0 37 | else: 38 | reward = reward_dist + 0.001 * reward_ctrl 39 | 40 | self.do_simulation(a, self.frame_skip) 41 | ob = self._get_obs() 42 | self.episode_length += 1 43 | done = self.episode_length >= self.max_episode_length 44 | return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) 45 | 46 | def viewer_setup(self): 47 | self.viewer.cam.trackbodyid = -1 48 | self.viewer.cam.distance = 4.0 49 | 50 | def reset_model(self): 51 | qpos = self.init_qpos 52 | self.episode_length = 0 53 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01) 54 | self.set_state(qpos, qvel) 55 | self.episode_length = 0 56 | return self._get_obs() 57 | 58 | def _get_obs(self): 59 | return np.concatenate([ 60 | self.get_body_com("particle"), 61 | #self.get_body_com("target"), 62 | ]) 63 | 64 | def plot_trajs(self, *args, **kwargs): 65 | pass 66 | 67 | def log_diagnostics(self, paths): 68 | rew_dist = np.array([traj['env_infos']['reward_dist'] for traj in paths]) 69 | rew_ctrl = np.array([traj['env_infos']['reward_ctrl'] for traj in paths]) 70 | 71 | logger.record_tabular('AvgObjectToGoalDist', -np.mean(rew_dist.mean())) 72 | logger.record_tabular('AvgControlCost', -np.mean(rew_ctrl.mean())) 73 | logger.record_tabular('AvgMinToGoalDist', np.mean(np.min(-rew_dist, axis=1))) 74 | 75 | -------------------------------------------------------------------------------- /inverse_rl/envs/pusher_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | 5 | #import mujoco_py 6 | #from mujoco_py.mjlib import mjlib 7 | from rllab.misc import logger 8 | 9 | from inverse_rl.envs.dynamic_mjc.mjc_models import pusher 10 | 11 | class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle): 12 | def __init__(self, sparse_reward=False, no_reward=False, episode_length=200): 13 | utils.EzPickle.__init__(self) 14 | self.sparse_reward = sparse_reward 15 | self.no_reward = no_reward 16 | self.max_episode_length = episode_length 17 | self.goal_pos = np.asarray([0.0, 0.0]) 18 | 19 | self.episode_length = 0 20 | 21 | model = pusher(goal_pos=[self.goal_pos[0], self.goal_pos[1], -.323]) 22 | with model.asfile() as f: 23 | mujoco_env.MujocoEnv.__init__(self, f.name, 5) 24 | 25 | 26 | def _step(self, a): 27 | vec_1 = self.get_body_com("object") - self.get_body_com("tips_arm") 28 | vec_2 = self.get_body_com("object") - self.get_body_com("goal") 29 | #print('pre_step:', self.get_body_com('object'), self.get_body_com("goal")) 30 | 31 | reward_near = - 
np.linalg.norm(vec_1) # arm to object 32 | reward_dist = - np.linalg.norm(vec_2[0:2]) # object to goal 33 | reward_ctrl = - np.square(a).sum() 34 | if self.no_reward: 35 | reward = 0 36 | elif self.sparse_reward: 37 | reward = reward_dist + 0.1 * reward_ctrl 38 | else: 39 | reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near 40 | #print(reward_near) 41 | #if (-reward_near ) <= 0.1: 42 | # reward = reward_dist 43 | #else: 44 | # reward = reward_dist + 0.5*reward_near 45 | 46 | #reward += 0.2 * reward_ctrl 47 | 48 | self.do_simulation(a, self.frame_skip) 49 | ob = self._get_obs() 50 | self.episode_length += 1 51 | done = self.episode_length >= self.max_episode_length 52 | #print('post_step:', self.get_body_com('object'), self.get_body_com("goal")) 53 | return ob, reward, done, dict(reward_dist=reward_dist, reward_near=reward_near, 54 | reward_ctrl=reward_ctrl) 55 | 56 | def viewer_setup(self): 57 | self.viewer.cam.trackbodyid = -1 58 | self.viewer.cam.distance = 4.0 59 | 60 | def reset_model(self): 61 | qpos = self.init_qpos 62 | self.episode_length = 0 63 | 64 | while True: 65 | self.cylinder_pos = np.concatenate([ 66 | self.np_random.uniform(low=-0.5, high=0, size=1), 67 | self.np_random.uniform(low=-0.5, high=0.5, size=1)]) 68 | #self.cylinder_pos = self.np_random.uniform(low=-1.5, high=1.5, size=2) 69 | cyl_dist = np.linalg.norm(self.cylinder_pos - self.goal_pos) 70 | if cyl_dist > 0.2 and cyl_dist < 0.4: 71 | break 72 | 73 | qpos[-4:-2] = self.cylinder_pos 74 | #qpos[-2:] = self.goal_pos 75 | qpos[-2] = self.goal_pos[0] 76 | qpos[-1] = self.goal_pos[1] 77 | qvel = self.init_qvel + self.np_random.uniform(low=-0.005, 78 | high=0.005, size=self.model.nv) 79 | qvel[-4:] = 0 80 | #print('qpos_pre:', self.model.data.qpos.flat[-4:]) 81 | #print('qpos_pre_bodycom:', self.get_body_com('object'), self.get_body_com("goal")) 82 | self.set_state(qpos, qvel) 83 | #print('qpos_post:', self.model.data.qpos.flat[-4:]) 84 | #print('qpos_post_bodycom:', self.get_body_com('object'), self.get_body_com("goal")) 85 | return self._get_obs() 86 | 87 | def _get_obs(self): 88 | return np.concatenate([ 89 | self.model.data.qpos.flat[:7], 90 | self.model.data.qvel.flat[:7], 91 | self.get_body_com("tips_arm"), 92 | self.get_body_com("object"), 93 | self.get_body_com("goal"), 94 | ]) 95 | 96 | def plot_trajs(self, *args, **kwargs): 97 | pass 98 | 99 | def log_diagnostics(self, paths): 100 | rew_near = np.array([traj['env_infos']['reward_near'] for traj in paths]) 101 | rew_dist = np.array([traj['env_infos']['reward_dist'] for traj in paths]) 102 | rew_ctrl = np.array([traj['env_infos']['reward_ctrl'] for traj in paths]) 103 | 104 | logger.record_tabular('AvgArmToObjectDist', -np.mean(rew_near)) 105 | logger.record_tabular('AvgObjectToGoalDist', -np.mean(rew_dist)) 106 | logger.record_tabular('AvgControlCost', -np.mean(rew_ctrl)) 107 | 108 | -------------------------------------------------------------------------------- /inverse_rl/envs/twod_maze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | 4 | from inverse_rl.envs.env_utils import get_asset_xml 5 | from inverse_rl.envs.twod_mjc_env import TwoDEnv 6 | 7 | from rllab.misc import logger as logger 8 | 9 | INIT_POS = np.array([0.15,0.15]) 10 | TARGET = np.array([0.15, -0.15]) 11 | DIST_THRESH = 0.12 12 | 13 | class TwoDMaze(TwoDEnv, utils.EzPickle): 14 | def __init__(self, verbose=False): 15 | self.verbose = verbose 16 | self.max_episode_length = 200 17 | 
self.episode_length = 0 18 | utils.EzPickle.__init__(self) 19 | TwoDEnv.__init__(self, get_asset_xml('twod_maze.xml'), 2, xbounds=[-0.3,0.3], ybounds=[-0.3,0.3]) 20 | 21 | def _step(self, a): 22 | self.do_simulation(a, self.frame_skip) 23 | ob = self._get_obs() 24 | pos = ob[0:2] 25 | dist = np.sum(np.abs(pos-TARGET)) #np.linalg.norm(pos - TARGET) 26 | reward = - (dist) 27 | 28 | reward_ctrl = - np.square(a).sum() 29 | reward += 1e-3 * reward_ctrl 30 | 31 | if self.verbose: 32 | print(pos, reward) 33 | self.episode_length += 1 34 | done = self.episode_length >= self.max_episode_length 35 | return ob, reward, done, {'distance': dist} 36 | 37 | def reset_model(self): 38 | self.episode_length = 0 39 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01) 40 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01) 41 | self.set_state(qpos, qvel) 42 | return self._get_obs() 43 | 44 | def _get_obs(self): 45 | #return np.concatenate([self.model.data.qpos, self.model.data.qvel]).ravel() 46 | return np.concatenate([self.model.data.qpos]).ravel() - INIT_POS 47 | 48 | def viewer_setup(self): 49 | v = self.viewer 50 | #v.cam.trackbodyid=0 51 | #v.cam.distance = v.model.stat.extent 52 | 53 | def log_diagnostics(self, paths): 54 | rew_dist = np.array([traj['env_infos']['distance'] for traj in paths]) 55 | 56 | logger.record_tabular('AvgObjectToGoalDist', np.mean(rew_dist)) 57 | logger.record_tabular('MinAvgObjectToGoalDist', np.mean(np.min(rew_dist, axis=1))) 58 | 59 | 60 | 61 | if __name__ == "__main__": 62 | from inverse_rl.utils.getch import getKey 63 | env = TwoDMaze(verbose=True) 64 | 65 | while True: 66 | key = getKey() 67 | a = np.array([0.0,0.0]) 68 | if key == 'w': 69 | a += np.array([0.0, 1.0]) 70 | elif key == 'a': 71 | a += np.array([-1.0, 0.0]) 72 | elif key == 's': 73 | a += np.array([0.0, -1.0]) 74 | elif key == 'd': 75 | a += np.array([1.0, 0.0]) 76 | elif key == 'q': 77 | break 78 | a *= 0.2 79 | env.step(a) 80 | env.render() -------------------------------------------------------------------------------- /inverse_rl/envs/twod_mjc_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rllab.core.serializable import Serializable 3 | 4 | from gym.envs.mujoco import mujoco_env 5 | from gym.spaces import Box 6 | 7 | from rllab.envs.base import Env 8 | from rllab.misc.overrides import overrides 9 | 10 | 11 | class MapConfig(Serializable): 12 | def __init__(self, xs, ys, xres, yres): 13 | Serializable.quick_init(self, locals()) 14 | self.xs = xs 15 | self.ys = ys 16 | self.xres = xres 17 | self.yres=yres 18 | 19 | def map_config(xs=(-0.3,0.3), ys=(-0.3,0.3), xres=50, yres=50): 20 | return MapConfig(xs,ys,xres,yres) 21 | 22 | def make_heat_map(eval_func, map_config): 23 | gps = get_dense_gridpoints(map_config) 24 | vals = np.zeros(map_config.xres*map_config.yres) 25 | for i, pnt in enumerate(gps): 26 | vals[i] = eval_func(pnt) 27 | return predictions_to_heatmap(vals, map_config) 28 | 29 | 30 | def get_dense_gridpoints(map_config): 31 | xl = np.linspace(map_config.xs[0], map_config.xs[1], num=map_config.xres) 32 | yl = np.linspace(map_config.ys[0], map_config.ys[1], num=map_config.yres) 33 | gridpoints = np.zeros((map_config.xres*map_config.yres, 2)) 34 | for i in range(map_config.xres): 35 | for j in range(map_config.yres): 36 | gridpoints[i+map_config.xres*j] = np.array((xl[i], yl[j])) 37 | return gridpoints 38 | 39 | 40 | def 
predictions_to_heatmap(predictions, map_config): 41 | map = np.zeros((map_config.xres, map_config.yres)) 42 | for i in range(map_config.xres): 43 | for j in range(map_config.yres): 44 | map[i,j] = predictions[i+map_config.xres*j] 45 | map = map/np.max(map) 46 | return map.T 47 | 48 | def make_density_map(paths, map_config): 49 | xs = np.linspace(map_config.xs[0], map_config.xs[1], num=map_config.xres+1) 50 | ys = np.linspace(map_config.ys[0], map_config.ys[1], num=map_config.yres+1) 51 | y = paths[:,0] 52 | x = paths[:,1] 53 | H, xedges, yedges = np.histogram2d(y, x, bins=(xs, ys)) 54 | H = H.astype(np.float) 55 | H = H/np.max(H) 56 | return H.T 57 | 58 | def plot_maps(combined_list=None, *heatmaps): 59 | import matplotlib.pyplot as plt 60 | combined = np.c_[heatmaps] 61 | if combined_list is not None: 62 | combined_list.append(combined) 63 | combined = np.concatenate(combined_list) 64 | else: 65 | combined_list = [] 66 | plt.figure() 67 | plt.imshow(combined, cmap='afmhot', interpolation='none') 68 | plt.show() 69 | return combined_list 70 | 71 | 72 | class TwoDEnv(mujoco_env.MujocoEnv): 73 | def __init__(self, model_path, frame_skip, xbounds, ybounds): 74 | super(TwoDEnv, self).__init__(model_path=model_path, frame_skip=frame_skip) 75 | assert isinstance(self.observation_space, Box) 76 | assert self.observation_space.shape == (2,) 77 | self.__map_config = map_config(xs=(xbounds[0], xbounds[1]), 78 | ys=(ybounds[0], ybounds[1])) 79 | 80 | @property 81 | def map_config(self): 82 | return self.__map_config 83 | 84 | @property 85 | def grid_flat_dim(self): 86 | return self.map_config.xres*self.map_config.yres 87 | 88 | def make_density_map(self, paths): 89 | return make_density_map(paths, self.map_config) 90 | 91 | def make_heatmap(self, eval_func): 92 | return make_heat_map(eval_func, self.map_config) 93 | 94 | def get_dense_gridpoints(self): 95 | return get_dense_gridpoints(self.map_config) 96 | 97 | def predictions_to_heatmap(self, predictions): 98 | return predictions_to_heatmap(predictions, self.map_config) 99 | 100 | def get_viewer(self): 101 | return self._get_viewer() 102 | 103 | @overrides 104 | def log_diagnostics(self, paths): 105 | pass 106 | -------------------------------------------------------------------------------- /inverse_rl/envs/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def flat_to_one_hot(val, ndim): 4 | """ 5 | 6 | >>> flat_to_one_hot(2, ndim=4) 7 | array([ 0., 0., 1., 0.]) 8 | >>> flat_to_one_hot(4, ndim=5) 9 | array([ 0., 0., 0., 0., 1.]) 10 | >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5) 11 | array([[ 0., 0., 1., 0., 0.], 12 | [ 0., 0., 0., 0., 1.], 13 | [ 0., 0., 0., 1., 0.]]) 14 | """ 15 | shape =np.array(val).shape 16 | v = np.zeros(shape + (ndim,)) 17 | if len(shape) == 1: 18 | v[np.arange(shape[0]), val] = 1.0 19 | else: 20 | v[val] = 1.0 21 | return v 22 | 23 | def one_hot_to_flat(val): 24 | """ 25 | >>> one_hot_to_flat(np.array([0,0,0,0,1])) 26 | 4 27 | >>> one_hot_to_flat(np.array([0,0,1,0])) 28 | 2 29 | >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]])) 30 | array([2, 0, 1]) 31 | """ 32 | idxs = np.array(np.where(val == 1.0))[-1] 33 | if len(val.shape) == 1: 34 | return int(idxs) 35 | return idxs -------------------------------------------------------------------------------- /inverse_rl/envs/visual_pointmass.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from gym import utils 4 | 
from gym.envs.mujoco import mujoco_env 5 | from gym.spaces import Box 6 | 7 | from inverse_rl.envs.dynamic_mjc.mjc_models import point_mass_maze 8 | from inverse_rl.envs.env_utils import get_asset_xml 9 | from inverse_rl.envs.twod_mjc_env import TwoDEnv 10 | 11 | from rllab.misc import logger as logger 12 | 13 | INIT_POS = np.array([0.15,0.15]) 14 | TARGET = np.array([0.15, -0.15]) 15 | DIST_THRESH = 0.12 16 | 17 | class VisualTwoDMaze(mujoco_env.MujocoEnv, utils.EzPickle): 18 | def __init__(self, verbose=False, width=64, height=64): 19 | self.verbose = verbose 20 | self.max_episode_length = 200 21 | self.episode_length = 0 22 | self.width = width 23 | self.height = height 24 | utils.EzPickle.__init__(self) 25 | super(VisualTwoDMaze, self).__init__(get_asset_xml('twod_maze.xml'), frame_skip=2) 26 | 27 | # calculate image dimensions 28 | #self._get_viewer().render() 29 | #data, width, height = self._get_viewer().get_image() 30 | self.observation_space = Box(0, 1, shape=(width, height, 3)) 31 | 32 | def _step(self, a): 33 | self.do_simulation(a, self.frame_skip) 34 | state = self._get_state() 35 | pos = state[0:2] 36 | dist = np.sum(np.abs(pos-TARGET)) #np.linalg.norm(pos - TARGET) 37 | reward = - (dist) 38 | 39 | reward_ctrl = - np.square(a).sum() 40 | reward += 1e-3 * reward_ctrl 41 | 42 | if self.verbose: 43 | print(pos, reward) 44 | self.episode_length += 1 45 | done = self.episode_length >= self.max_episode_length 46 | 47 | ob = self._get_obs() 48 | return ob, reward, done, {'distance': dist} 49 | 50 | def reset_model(self): 51 | self.episode_length = 0 52 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01) 53 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01) 54 | self.set_state(qpos, qvel) 55 | return self._get_obs() 56 | 57 | def _get_state(self): 58 | #return np.concatenate([self.model.data.qpos, self.model.data.qvel]).ravel() 59 | return np.concatenate([self.model.data.qpos]).ravel() - INIT_POS 60 | 61 | def _get_obs(self): 62 | self._get_viewer().render() 63 | data, width, height = self._get_viewer().get_image() 64 | image = np.fromstring(data, dtype='uint8').reshape(height, width, 3)[::-1, :, :] 65 | # reshape 66 | if self.grayscale: 67 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 68 | image = cv2.resize(image, (self.width, self.height), interpolation=cv2.INTER_AREA) 69 | 70 | # rescale image to float 71 | image = image.astype(np.float32)/255.0 72 | return image 73 | 74 | def viewer_setup(self): 75 | #v = self.viewer 76 | self.viewer.cam.trackbodyid = -1 77 | self.viewer.cam.distance = 1.0 78 | #v.cam.trackbodyid=0 79 | #v.cam.distance = v.model.stat.extent 80 | 81 | def log_diagnostics(self, paths): 82 | rew_dist = np.array([traj['env_infos']['distance'] for traj in paths]) 83 | logger.record_tabular('AvgObjectToGoalDist', np.mean(rew_dist)) 84 | logger.record_tabular('MinAvgObjectToGoalDist', np.mean(np.min(rew_dist, axis=1))) 85 | 86 | 87 | class VisualPointMazeEnv(mujoco_env.MujocoEnv, utils.EzPickle): 88 | def __init__(self, direction=1, maze_length=0.6, 89 | sparse_reward=False, no_reward=False, episode_length=100, grayscale=True, 90 | width=64, height=64): 91 | utils.EzPickle.__init__(self) 92 | self.sparse_reward = sparse_reward 93 | self.no_reward = no_reward 94 | self.max_episode_length = episode_length 95 | self.direction = direction 96 | self.length = maze_length 97 | 98 | self.width = width 99 | self.height = height 100 | self.grayscale=grayscale 101 | 102 | self.episode_length = 0 
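Both environments in this file return rendered camera images rather than state vectors; `_get_obs` converts the viewer's raw byte buffer into a float image in [0, 1]. Below is a minimal sketch of that post-processing applied to a synthetic frame, so it runs without a MuJoCo viewer; the frame contents and dimensions are made up for illustration.

```python
import cv2
import numpy as np

width, height, grayscale = 64, 64, True
frame = np.random.randint(0, 256, size=(480, 500, 3), dtype=np.uint8)  # stand-in for get_image() data

image = frame[::-1, :, :]                                  # flip row order, as _get_obs does
if grayscale:
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)        # drop to a single channel
image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
image = image.astype(np.float32) / 255.0                   # rescale to [0, 1]
print(image.shape, image.dtype)                            # (64, 64) float32
```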
103 | 104 | model = point_mass_maze(direction=self.direction, length=self.length, borders=False) 105 | with model.asfile() as f: 106 | mujoco_env.MujocoEnv.__init__(self, f.name, 5) 107 | 108 | if self.grayscale: 109 | self.observation_space = Box(0, 1, shape=(width, height)) 110 | else: 111 | self.observation_space = Box(0, 1, shape=(width, height, 3)) 112 | 113 | def _step(self, a): 114 | vec_dist = self.get_body_com("particle") - self.get_body_com("target") 115 | 116 | reward_dist = - np.linalg.norm(vec_dist) # particle to target 117 | reward_ctrl = - np.square(a).sum() 118 | if self.no_reward: 119 | reward = 0 120 | elif self.sparse_reward: 121 | if reward_dist <= 0.1: 122 | reward = 1 123 | else: 124 | reward = 0 125 | else: 126 | reward = reward_dist + 0.001 * reward_ctrl 127 | 128 | self.do_simulation(a, self.frame_skip) 129 | ob = self._get_obs() 130 | self.episode_length += 1 131 | done = self.episode_length >= self.max_episode_length 132 | return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) 133 | 134 | def viewer_setup(self): 135 | self.viewer.cam.trackbodyid = -1 136 | self.viewer.cam.distance = 1.0 137 | 138 | def reset_model(self): 139 | qpos = self.init_qpos 140 | self.episode_length = 0 141 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01) 142 | self.set_state(qpos, qvel) 143 | self.episode_length = 0 144 | return self._get_obs() 145 | 146 | def _get_state(self): 147 | return np.concatenate([ 148 | self.get_body_com("particle"), 149 | #self.get_body_com("target"), 150 | ]) 151 | 152 | def _get_obs(self): 153 | self._get_viewer().render() 154 | data, width, height = self._get_viewer().get_image() 155 | image = np.fromstring(data, dtype='uint8').reshape(height, width, 3)[::-1, :, :] 156 | # rescale image to float 157 | if self.grayscale: 158 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 159 | image = cv2.resize(image, (self.width, self.height), interpolation=cv2.INTER_AREA) 160 | image = image.astype(np.float32)/255.0 161 | return image 162 | 163 | def log_diagnostics(self, paths): 164 | rew_dist = np.array([traj['env_infos']['reward_dist'] for traj in paths]) 165 | rew_ctrl = np.array([traj['env_infos']['reward_ctrl'] for traj in paths]) 166 | 167 | logger.record_tabular('AvgObjectToGoalDist', -np.mean(rew_dist.mean())) 168 | logger.record_tabular('AvgControlCost', -np.mean(rew_ctrl.mean())) 169 | logger.record_tabular('AvgMinToGoalDist', np.mean(np.min(-rew_dist, axis=1))) 170 | 171 | 172 | if __name__ == "__main__": 173 | from inverse_rl.utils.getch import getKey 174 | #env = VisualTwoDMaze(verbose=True) 175 | env = VisualPointMazeEnv() 176 | 177 | while True: 178 | key = getKey() 179 | a = np.array([0.0,0.0]) 180 | if key == 'w': 181 | a += np.array([0.0, 1.0]) 182 | elif key == 'a': 183 | a += np.array([-1.0, 0.0]) 184 | elif key == 's': 185 | a += np.array([0.0, -1.0]) 186 | elif key == 'd': 187 | a += np.array([1.0, 0.0]) 188 | elif key == 'q': 189 | break 190 | a *= 0.2 191 | o,_,_,_ = env.step(a) 192 | print(o.shape, o) 193 | env.render() 194 | -------------------------------------------------------------------------------- /inverse_rl/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinjfu/inverse_rl/9609933389459a3a54f5c01d652114ada90fa1b3/inverse_rl/models/__init__.py -------------------------------------------------------------------------------- /inverse_rl/models/airl_state.py: 
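The `AIRL` class defined in this file scores each transition with f(s, a, s') = r(s, a) + gamma * V(s') - V(s) (the reward depends only on the state when `state_only=True`) and compares f against the sampling policy's log-probability to form a discriminator. As a reading aid, here is a minimal NumPy sketch of that computation; the function name and the scalar inputs are illustrative and not part of the repository.

```python
import numpy as np

def airl_discriminator(f, log_pi):
    """Sketch of the discriminator output built below in airl_state.py.

    f      : log p_tau, i.e. r(s, a) + gamma * V(s') - V(s)
    log_pi : log q_tau, the action log-probability under the sampling policy
    """
    log_pq = np.logaddexp(f, log_pi)   # same role as tf.reduce_logsumexp([log_p_tau, log_q_tau])
    D = np.exp(f - log_pq)             # D = exp(f) / (exp(f) + pi(a|s)), in (0, 1)
    # With score_discrim=True the policy reward is log D - log(1 - D) = f - log_pi.
    reward = f - log_pi
    return D, reward

D, rew = airl_discriminator(f=-1.0, log_pi=-3.0)
print(D, rew)   # D is about 0.88, reward = 2.0
```

The cross-entropy loss minimized in `fit` then pushes D toward 1 on expert transitions (the expert half of each batch is labeled 1) and toward 0 on policy samples.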
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from sandbox.rocky.tf.spaces.box import Box 4 | 5 | from inverse_rl.models.fusion_manager import RamFusionDistr 6 | from inverse_rl.models.imitation_learning import SingleTimestepIRL 7 | from inverse_rl.models.architectures import relu_net 8 | from inverse_rl.utils import TrainingIterator 9 | 10 | 11 | 12 | class AIRL(SingleTimestepIRL): 13 | """ 14 | 15 | 16 | Args: 17 | fusion (bool): Use trajectories from old iterations to train. 18 | state_only (bool): Fix the learned reward to only depend on state. 19 | score_discrim (bool): Use log D - log 1-D as reward (if true you should not need to use an entropy bonus) 20 | max_itrs (int): Number of training iterations to run per fit step. 21 | """ 22 | def __init__(self, env, 23 | expert_trajs=None, 24 | reward_arch=relu_net, 25 | reward_arch_args=None, 26 | value_fn_arch=relu_net, 27 | score_discrim=False, 28 | discount=1.0, 29 | state_only=False, 30 | max_itrs=100, 31 | fusion=False, 32 | name='airl'): 33 | super(AIRL, self).__init__() 34 | env_spec = env.spec 35 | if reward_arch_args is None: 36 | reward_arch_args = {} 37 | 38 | if fusion: 39 | self.fusion = RamFusionDistr(100, subsample_ratio=0.5) 40 | else: 41 | self.fusion = None 42 | self.dO = env_spec.observation_space.flat_dim 43 | self.dU = env_spec.action_space.flat_dim 44 | assert isinstance(env.action_space, Box) 45 | self.score_discrim = score_discrim 46 | self.gamma = discount 47 | assert value_fn_arch is not None 48 | self.set_demos(expert_trajs) 49 | self.state_only=state_only 50 | self.max_itrs=max_itrs 51 | 52 | # build energy model 53 | with tf.variable_scope(name) as _vs: 54 | # Should be batch_size x T x dO/dU 55 | self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs') 56 | self.nobs_t = tf.placeholder(tf.float32, [None, self.dO], name='nobs') 57 | self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act') 58 | self.nact_t = tf.placeholder(tf.float32, [None, self.dU], name='nact') 59 | self.labels = tf.placeholder(tf.float32, [None, 1], name='labels') 60 | self.lprobs = tf.placeholder(tf.float32, [None, 1], name='log_probs') 61 | self.lr = tf.placeholder(tf.float32, (), name='lr') 62 | 63 | with tf.variable_scope('discrim') as dvs: 64 | rew_input = self.obs_t 65 | if not self.state_only: 66 | rew_input = tf.concat([self.obs_t, self.act_t], axis=1) 67 | with tf.variable_scope('reward'): 68 | self.reward = reward_arch(rew_input, dout=1, **reward_arch_args) 69 | #energy_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name) 70 | 71 | # value function shaping 72 | with tf.variable_scope('vfn'): 73 | fitted_value_fn_n = value_fn_arch(self.nobs_t, dout=1) 74 | with tf.variable_scope('vfn', reuse=True): 75 | self.value_fn = fitted_value_fn = value_fn_arch(self.obs_t, dout=1) 76 | 77 | # Define log p_tau(a|s) = r + gamma * V(s') - V(s) 78 | self.qfn = self.reward + self.gamma*fitted_value_fn_n 79 | log_p_tau = self.reward + self.gamma*fitted_value_fn_n - fitted_value_fn 80 | 81 | log_q_tau = self.lprobs 82 | 83 | log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0) 84 | self.discrim_output = tf.exp(log_p_tau-log_pq) 85 | cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq)) 86 | 87 | self.loss = cent_loss 88 | tot_loss = self.loss 89 | self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(tot_loss) 90 | self._make_param_ops(_vs) 91 | 92 | def 
fit(self, paths, policy=None, batch_size=32, logger=None, lr=1e-3,**kwargs): 93 | 94 | if self.fusion is not None: 95 | old_paths = self.fusion.sample_paths(n=len(paths)) 96 | self.fusion.add_paths(paths) 97 | paths = paths+old_paths 98 | 99 | # eval samples under current policy 100 | self._compute_path_probs(paths, insert=True) 101 | 102 | # eval expert log probs under current policy 103 | self.eval_expert_probs(self.expert_trajs, policy, insert=True) 104 | 105 | self._insert_next_state(paths) 106 | self._insert_next_state(self.expert_trajs) 107 | obs, obs_next, acts, acts_next, path_probs = \ 108 | self.extract_paths(paths, 109 | keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs')) 110 | expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \ 111 | self.extract_paths(self.expert_trajs, 112 | keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs')) 113 | 114 | 115 | # Train discriminator 116 | for it in TrainingIterator(self.max_itrs, heartbeat=5): 117 | nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \ 118 | self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size) 119 | 120 | nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \ 121 | self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size) 122 | 123 | # Build feed dict 124 | labels = np.zeros((batch_size*2, 1)) 125 | labels[batch_size:] = 1.0 126 | obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0) 127 | nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0) 128 | act_batch = np.concatenate([act_batch, expert_act_batch], axis=0) 129 | nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0) 130 | lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32) 131 | feed_dict = { 132 | self.act_t: act_batch, 133 | self.obs_t: obs_batch, 134 | self.nobs_t: nobs_batch, 135 | self.nact_t: nact_batch, 136 | self.labels: labels, 137 | self.lprobs: lprobs_batch, 138 | self.lr: lr 139 | } 140 | 141 | loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict=feed_dict) 142 | it.record('loss', loss) 143 | if it.heartbeat: 144 | print(it.itr_message()) 145 | mean_loss = it.pop_mean('loss') 146 | print('\tLoss:%f' % mean_loss) 147 | 148 | if logger: 149 | logger.record_tabular('GCLDiscrimLoss', mean_loss) 150 | #obs_next = np.r_[obs_next, np.expand_dims(obs_next[-1], axis=0)] 151 | energy, logZ, dtau = tf.get_default_session().run([self.reward, self.value_fn, self.discrim_output], 152 | feed_dict={self.act_t: acts, self.obs_t: obs, self.nobs_t: obs_next, 153 | self.nact_t: acts_next, 154 | self.lprobs: np.expand_dims(path_probs, axis=1)}) 155 | energy = -energy 156 | logger.record_tabular('GCLLogZ', np.mean(logZ)) 157 | logger.record_tabular('GCLAverageEnergy', np.mean(energy)) 158 | logger.record_tabular('GCLAverageLogPtau', np.mean(-energy-logZ)) 159 | logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs)) 160 | logger.record_tabular('GCLMedianLogQtau', np.median(path_probs)) 161 | logger.record_tabular('GCLAverageDtau', np.mean(dtau)) 162 | 163 | 164 | #expert_obs_next = np.r_[expert_obs_next, np.expand_dims(expert_obs_next[-1], axis=0)] 165 | energy, logZ, dtau = tf.get_default_session().run([self.reward, self.value_fn, self.discrim_output], 166 | feed_dict={self.act_t: expert_acts, self.obs_t: expert_obs, 
self.nobs_t: expert_obs_next, 167 | self.nact_t: expert_acts_next, 168 | self.lprobs: np.expand_dims(expert_probs, axis=1)}) 169 | energy = -energy 170 | logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy)) 171 | logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy-logZ)) 172 | logger.record_tabular('GCLAverageExpertLogQtau', np.mean(expert_probs)) 173 | logger.record_tabular('GCLMedianExpertLogQtau', np.median(expert_probs)) 174 | logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau)) 175 | return mean_loss 176 | 177 | def eval(self, paths, **kwargs): 178 | """ 179 | Return bonus 180 | """ 181 | if self.score_discrim: 182 | self._compute_path_probs(paths, insert=True) 183 | obs, obs_next, acts, path_probs = self.extract_paths(paths, keys=('observations', 'observations_next', 'actions', 'a_logprobs')) 184 | path_probs = np.expand_dims(path_probs, axis=1) 185 | scores = tf.get_default_session().run(self.discrim_output, 186 | feed_dict={self.act_t: acts, self.obs_t: obs, 187 | self.nobs_t: obs_next, 188 | self.lprobs: path_probs}) 189 | score = np.log(scores) - np.log(1-scores) 190 | score = score[:,0] 191 | else: 192 | obs, acts = self.extract_paths(paths) 193 | reward = tf.get_default_session().run(self.reward, 194 | feed_dict={self.act_t: acts, self.obs_t: obs}) 195 | score = reward[:,0] 196 | return self.unpack(score, paths) 197 | 198 | def eval_single(self, obs): 199 | reward = tf.get_default_session().run(self.reward, 200 | feed_dict={self.obs_t: obs}) 201 | score = reward[:, 0] 202 | return score 203 | 204 | def debug_eval(self, paths, **kwargs): 205 | obs, acts = self.extract_paths(paths) 206 | reward, v, qfn = tf.get_default_session().run([self.reward, self.value_fn, 207 | self.qfn], 208 | feed_dict={self.act_t: acts, self.obs_t: obs}) 209 | return { 210 | 'reward': reward, 211 | 'value': v, 212 | 'qfn': qfn, 213 | } 214 | 215 | -------------------------------------------------------------------------------- /inverse_rl/models/architectures.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from inverse_rl.models.tf_util import relu_layer, linear 3 | 4 | 5 | def make_relu_net(layers=2, dout=1, d_hidden=32): 6 | def relu_net(x, last_layer_bias=True): 7 | out = x 8 | for i in range(layers): 9 | out = relu_layer(out, dout=d_hidden, name='l%d'%i) 10 | out = linear(out, dout=dout, name='lfinal', bias=last_layer_bias) 11 | return out 12 | return relu_net 13 | 14 | 15 | def relu_net(x, layers=2, dout=1, d_hidden=32): 16 | out = x 17 | for i in range(layers): 18 | out = relu_layer(out, dout=d_hidden, name='l%d'%i) 19 | out = linear(out, dout=dout, name='lfinal') 20 | return out 21 | 22 | 23 | def linear_net(x, dout=1): 24 | out = x 25 | out = linear(out, dout=dout, name='lfinal') 26 | return out 27 | 28 | 29 | def feedforward_energy(obs_act, ff_arch=relu_net): 30 | # for trajectories, using feedforward nets rather than RNNs 31 | dimOU = int(obs_act.get_shape()[2]) 32 | orig_shape = tf.shape(obs_act) 33 | 34 | obs_act = tf.reshape(obs_act, [-1, dimOU]) 35 | outputs = ff_arch(obs_act) 36 | dOut = int(outputs.get_shape()[-1]) 37 | 38 | new_shape = tf.stack([orig_shape[0],orig_shape[1], dOut]) 39 | outputs = tf.reshape(outputs, new_shape) 40 | return outputs 41 | 42 | 43 | def rnn_trajectory_energy(obs_act): 44 | """ 45 | Operates on trajectories 46 | """ 47 | # for trajectories 48 | dimOU = int(obs_act.get_shape()[2]) 49 | 50 | cell = tf.contrib.rnn.GRUCell(num_units=dimOU) 51 | cell_out = 
tf.contrib.rnn.OutputProjectionWrapper(cell, 1) 52 | outputs, hidden = tf.nn.dynamic_rnn(cell_out, obs_act, time_major=False, dtype=tf.float32) 53 | return outputs 54 | 55 | -------------------------------------------------------------------------------- /inverse_rl/models/fusion_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import joblib 3 | import re 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from rllab.misc.logger import get_snapshot_dir 8 | 9 | 10 | class FusionDistrManager(object): 11 | def add_paths(self, paths): 12 | raise NotImplementedError() 13 | 14 | def sample_paths(self, n): 15 | raise NotImplementedError() 16 | 17 | 18 | class PathsReader(object): 19 | ITR_REG = re.compile(r"itr_(?P[0-9]+)\.pkl") 20 | 21 | def __init__(self, path_dir): 22 | self.path_dir = path_dir 23 | 24 | def get_path_files(self): 25 | itr_files = [] 26 | for i, filename in enumerate(os.listdir(self.path_dir)): 27 | m = PathsReader.ITR_REG.match(filename) 28 | if m: 29 | itr_count = m.group('itr_count') 30 | itr_files.append((itr_count, filename)) 31 | 32 | itr_files = sorted(itr_files, key=lambda x: int(x[0]), reverse=True) 33 | for itr_file_and_count in itr_files: 34 | fname = os.path.join(self.path_dir, itr_file_and_count[1]) 35 | yield fname 36 | 37 | def __len__(self): 38 | return len(list(self.get_path_files())) 39 | 40 | 41 | class DiskFusionDistr(FusionDistrManager): 42 | def __init__(self, path_dir=None): 43 | if path_dir is None: 44 | path_dir = get_snapshot_dir() 45 | self.path_dir = path_dir 46 | self.paths_reader = PathsReader(path_dir) 47 | 48 | def add_paths(self, paths): 49 | raise NotImplementedError() 50 | 51 | def sample_paths(self, n): 52 | # load from disk! 53 | fnames = list(self.paths_reader.get_path_files()) 54 | N = len(fnames) 55 | sample_files = np.random.randint(0, N, size=(n)) 56 | #sample_hist = np.histogram(sample_files, range=(0, N)) 57 | #print(sample_hist) 58 | unique, counts = np.unique(sample_files, return_counts=True) 59 | unique_dict = dict(zip(unique, counts)) 60 | 61 | all_paths = [] 62 | for fidx in unique_dict: 63 | fname = fnames[fidx] 64 | n_samp = unique_dict[fidx] 65 | print(fname, n_samp) 66 | 67 | config = tf.ConfigProto() 68 | config.gpu_options.allow_growth = True 69 | with tf.Graph().as_default(): 70 | with tf.Session(config=config).as_default(): 71 | snapshot_dict = joblib.load(fname) 72 | paths = snapshot_dict['paths'] 73 | pidxs = np.random.randint(0, len(paths), size=(n_samp)) 74 | all_paths.extend([paths[pidx] for pidx in pidxs]) 75 | return all_paths 76 | 77 | 78 | class RamFusionDistr(FusionDistrManager): 79 | def __init__(self, buf_size, subsample_ratio=0.5): 80 | self.buf_size = buf_size 81 | self.buffer = [] 82 | self.subsample_ratio = subsample_ratio 83 | 84 | def add_paths(self, paths, subsample=True): 85 | if subsample: 86 | paths = paths[:int(len(paths)*self.subsample_ratio)] 87 | self.buffer.extend(paths) 88 | overflow = len(self.buffer)-self.buf_size 89 | while overflow > 0: 90 | #self.buffer = self.buffer[overflow:] 91 | N = len(self.buffer) 92 | probs = np.arange(N)+1 93 | probs = probs/float(np.sum(probs)) 94 | pidx = np.random.choice(np.arange(N), p=probs) 95 | self.buffer.pop(pidx) 96 | overflow -= 1 97 | 98 | def sample_paths(self, n): 99 | if len(self.buffer) == 0: 100 | return [] 101 | else: 102 | pidxs = np.random.randint(0, len(self.buffer), size=(n)) 103 | return [self.buffer[pidx] for pidx in pidxs] 104 | 105 | 106 | if __name__ == "__main__": 107 | 
#fm = DiskFusionDistr(path_dir='data_nobs/gridworld_random/gru1') 108 | #paths = fm.sample_paths(10) 109 | fm = RamFusionDistr(10) 110 | fm.add_paths([1,2,3,4,5,6,7,8,9,10,11,12,13]) 111 | print(fm.buffer) 112 | print(fm.sample_paths(5)) 113 | -------------------------------------------------------------------------------- /inverse_rl/models/imitation_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from inverse_rl.models.architectures import feedforward_energy, relu_net 5 | from inverse_rl.models.tf_util import discounted_reduce_sum 6 | from inverse_rl.utils.general import TrainingIterator 7 | from inverse_rl.utils.hyperparametrized import Hyperparametrized 8 | from inverse_rl.utils.math_utils import gauss_log_pdf, categorical_log_pdf 9 | from sandbox.rocky.tf.misc import tensor_utils 10 | 11 | LOG_REG = 1e-8 12 | DIST_GAUSSIAN = 'gaussian' 13 | DIST_CATEGORICAL = 'categorical' 14 | 15 | class ImitationLearning(object, metaclass=Hyperparametrized): 16 | def __init__(self): 17 | pass 18 | 19 | def set_demos(self, paths): 20 | if paths is not None: 21 | self.expert_trajs = paths 22 | self.expert_trajs_extracted = self.extract_paths(paths) 23 | 24 | @staticmethod 25 | def _compute_path_probs(paths, pol_dist_type=None, insert=True, 26 | insert_key='a_logprobs'): 27 | """ 28 | Returns a N x T matrix of action probabilities 29 | """ 30 | if insert_key in paths[0]: 31 | return np.array([path[insert_key] for path in paths]) 32 | 33 | if pol_dist_type is None: 34 | # try to infer distribution type 35 | path0 = paths[0] 36 | if 'log_std' in path0['agent_infos']: 37 | pol_dist_type = DIST_GAUSSIAN 38 | elif 'prob' in path0['agent_infos']: 39 | pol_dist_type = DIST_CATEGORICAL 40 | else: 41 | raise NotImplementedError() 42 | 43 | # compute path probs 44 | Npath = len(paths) 45 | actions = [path['actions'] for path in paths] 46 | if pol_dist_type == DIST_GAUSSIAN: 47 | params = [(path['agent_infos']['mean'], path['agent_infos']['log_std']) for path in paths] 48 | path_probs = [gauss_log_pdf(params[i], actions[i]) for i in range(Npath)] 49 | elif pol_dist_type == DIST_CATEGORICAL: 50 | params = [(path['agent_infos']['prob'],) for path in paths] 51 | path_probs = [categorical_log_pdf(params[i], actions[i]) for i in range(Npath)] 52 | else: 53 | raise NotImplementedError("Unknown distribution type") 54 | 55 | if insert: 56 | for i, path in enumerate(paths): 57 | path[insert_key] = path_probs[i] 58 | 59 | return np.array(path_probs) 60 | 61 | @staticmethod 62 | def _insert_next_state(paths, pad_val=0.0): 63 | for path in paths: 64 | if 'observations_next' in path: 65 | continue 66 | nobs = path['observations'][1:] 67 | nact = path['actions'][1:] 68 | nobs = np.r_[nobs, pad_val*np.expand_dims(np.ones_like(nobs[0]), axis=0)] 69 | nact = np.r_[nact, pad_val*np.expand_dims(np.ones_like(nact[0]), axis=0)] 70 | path['observations_next'] = nobs 71 | path['actions_next'] = nact 72 | return paths 73 | 74 | @staticmethod 75 | def extract_paths(paths, keys=('observations', 'actions'), stack=True): 76 | if stack: 77 | return [np.stack([t[key] for t in paths]).astype(np.float32) for key in keys] 78 | else: 79 | return [np.concatenate([t[key] for t in paths]).astype(np.float32) for key in keys] 80 | 81 | @staticmethod 82 | def sample_batch(*args, batch_size=32): 83 | N = args[0].shape[0] 84 | batch_idxs = np.random.randint(0, N, batch_size) # trajectories are negatives 85 | return [data[batch_idxs] for data in args] 
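The static helpers above operate on rllab-style path dictionaries. As a reading aid, the sketch below shows the minimal layout `_compute_path_probs` expects for a Gaussian policy and what it returns; the shapes and the all-zero data are synthetic, and running it assumes the rllab/TensorFlow setup described in the top-level README (a categorical policy would store a `prob` array in `agent_infos` instead).

```python
import numpy as np
from inverse_rl.models.imitation_learning import ImitationLearning

T, dO, dU = 5, 3, 2
path = {
    'observations': np.zeros((T, dO), dtype=np.float32),
    'actions':      np.zeros((T, dU), dtype=np.float32),
    # A Gaussian rllab policy records its distribution parameters per timestep here.
    'agent_infos': {
        'mean':    np.zeros((T, dU), dtype=np.float32),
        'log_std': np.zeros((T, dU), dtype=np.float32),
    },
}

# Returns an (n_paths, T) array of per-timestep log pi(a_t | s_t); with
# insert=True the result is also cached on each path under 'a_logprobs'.
logps = ImitationLearning._compute_path_probs([path], insert=True)
print(logps.shape)            # (1, 5)
print('a_logprobs' in path)   # True
```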
86 | 87 | def fit(self, paths, **kwargs): 88 | raise NotImplementedError() 89 | 90 | def eval(self, paths, **kwargs): 91 | raise NotImplementedError() 92 | 93 | def _make_param_ops(self, vs): 94 | self._params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name) 95 | assert len(self._params)>0 96 | self._assign_plc = [tf.placeholder(tf.float32, shape=param.get_shape(), name='assign_%s'%param.name.replace('/','_').replace(':','_')) for param in self._params] 97 | self._assign_ops = [tf.assign(self._params[i], self._assign_plc[i]) for i in range(len(self._params))] 98 | 99 | def get_params(self): 100 | params = tf.get_default_session().run(self._params) 101 | assert len(params) == len(self._params) 102 | return params 103 | 104 | def set_params(self, params): 105 | tf.get_default_session().run(self._assign_ops, feed_dict={ 106 | self._assign_plc[i]: params[i] for i in range(len(self._params)) 107 | }) 108 | 109 | 110 | class TrajectoryIRL(ImitationLearning): 111 | """ 112 | Base class for models that score entire trajectories at once 113 | """ 114 | @property 115 | def score_trajectories(self): 116 | return True 117 | 118 | def eval_expert_probs(self, expert_paths, policy, insert=False): 119 | """ 120 | Evaluate expert policy probability under current policy 121 | """ 122 | if policy.recurrent: 123 | policy.reset([True]*len(expert_paths)) 124 | expert_obs = self.extract_paths(expert_paths, keys=('observations',))[0] 125 | agent_infos = [] 126 | for t in range(expert_obs.shape[1]): 127 | a, infos = policy.get_actions(expert_obs[:, t]) 128 | agent_infos.append(infos) 129 | agent_infos_stack = tensor_utils.stack_tensor_dict_list(agent_infos) 130 | for key in agent_infos_stack: 131 | agent_infos_stack[key] = np.transpose(agent_infos_stack[key], axes=[1,0,2]) 132 | agent_infos_transpose = tensor_utils.split_tensor_dict_list(agent_infos_stack) 133 | for i, path in enumerate(expert_paths): 134 | path['agent_infos'] = agent_infos_transpose[i] 135 | else: 136 | for path in expert_paths: 137 | actions, agent_infos = policy.get_actions(path['observations']) 138 | path['agent_infos'] = agent_infos 139 | return self._compute_path_probs(expert_paths, insert=insert) 140 | 141 | 142 | 143 | class SingleTimestepIRL(ImitationLearning): 144 | """ 145 | Base class for models that score single timesteps at once 146 | """ 147 | @staticmethod 148 | def extract_paths(paths, keys=('observations', 'actions'), stack=False): 149 | return ImitationLearning.extract_paths(paths, keys=keys, stack=stack) 150 | 151 | @staticmethod 152 | def unpack(data, paths): 153 | lengths = [path['observations'].shape[0] for path in paths] 154 | unpacked = [] 155 | idx = 0 156 | for l in lengths: 157 | unpacked.append(data[idx:idx+l]) 158 | idx += l 159 | return unpacked 160 | 161 | @property 162 | def score_trajectories(self): 163 | return False 164 | 165 | def eval_expert_probs(self, expert_paths, policy, insert=False): 166 | """ 167 | Evaluate expert policy probability under current policy 168 | """ 169 | for traj in expert_paths: 170 | if 'agent_infos' in traj: 171 | del traj['agent_infos'] 172 | if 'a_logprobs' in traj: 173 | del traj['a_logprobs'] 174 | 175 | if isinstance(policy, np.ndarray): 176 | return self._compute_path_probs(expert_paths, insert=insert) 177 | elif hasattr(policy, 'recurrent') and policy.recurrent: 178 | policy.reset([True]*len(expert_paths)) 179 | expert_obs = self.extract_paths(expert_paths, keys=('observations',), stack=True)[0] 180 | agent_infos = [] 181 | for t in 
range(expert_obs.shape[1]): 182 | a, infos = policy.get_actions(expert_obs[:, t]) 183 | agent_infos.append(infos) 184 | agent_infos_stack = tensor_utils.stack_tensor_dict_list(agent_infos) 185 | for key in agent_infos_stack: 186 | agent_infos_stack[key] = np.transpose(agent_infos_stack[key], axes=[1,0,2]) 187 | agent_infos_transpose = tensor_utils.split_tensor_dict_list(agent_infos_stack) 188 | for i, path in enumerate(expert_paths): 189 | path['agent_infos'] = agent_infos_transpose[i] 190 | else: 191 | for path in expert_paths: 192 | actions, agent_infos = policy.get_actions(path['observations']) 193 | path['agent_infos'] = agent_infos 194 | return self._compute_path_probs(expert_paths, insert=insert) 195 | 196 | 197 | class GAIL(SingleTimestepIRL): 198 | """ 199 | Generative adverserial imitation learning 200 | See https://arxiv.org/pdf/1606.03476.pdf 201 | 202 | This version consumes single timesteps. 203 | """ 204 | def __init__(self, env_spec, expert_trajs=None, 205 | discrim_arch=relu_net, 206 | discrim_arch_args={}, 207 | name='gail'): 208 | super(GAIL, self).__init__() 209 | self.dO = env_spec.observation_space.flat_dim 210 | self.dU = env_spec.action_space.flat_dim 211 | self.set_demos(expert_trajs) 212 | 213 | # build energy model 214 | with tf.variable_scope(name) as vs: 215 | # Should be batch_size x T x dO/dU 216 | self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs') 217 | self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act') 218 | self.labels = tf.placeholder(tf.float32, [None, 1], name='labels') 219 | self.lr = tf.placeholder(tf.float32, (), name='lr') 220 | 221 | obs_act = tf.concat([self.obs_t, self.act_t], axis=1) 222 | logits = discrim_arch(obs_act, **discrim_arch_args) 223 | self.predictions = tf.nn.sigmoid(logits) 224 | self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=self.labels)) 225 | self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss) 226 | self._make_param_ops(vs) 227 | 228 | 229 | def fit(self, trajs, batch_size=32, max_itrs=100, **kwargs): 230 | obs, acts = self.extract_paths(trajs) 231 | expert_obs, expert_acts = self.expert_trajs_extracted 232 | 233 | # Train discriminator 234 | for it in TrainingIterator(max_itrs, heartbeat=5): 235 | obs_batch, act_batch = self.sample_batch(obs, acts, batch_size=batch_size) 236 | expert_obs_batch, expert_act_batch = self.sample_batch(expert_obs, expert_acts, batch_size=batch_size) 237 | labels = np.zeros((batch_size*2, 1)) 238 | labels[batch_size:] = 1.0 239 | obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0) 240 | act_batch = np.concatenate([act_batch, expert_act_batch], axis=0) 241 | 242 | loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict={ 243 | self.act_t: act_batch, 244 | self.obs_t: obs_batch, 245 | self.labels: labels, 246 | self.lr: 1e-3 247 | }) 248 | 249 | it.record('loss', loss) 250 | if it.heartbeat: 251 | print(it.itr_message()) 252 | mean_loss = it.pop_mean('loss') 253 | print('\tLoss:%f' % mean_loss) 254 | return mean_loss 255 | 256 | def eval(self, paths, **kwargs): 257 | """ 258 | Return bonus 259 | """ 260 | obs, acts = self.extract_paths(paths) 261 | scores = tf.get_default_session().run(self.predictions, 262 | feed_dict={self.act_t: acts, self.obs_t: obs}) 263 | 264 | # reward = log D(s, a) 265 | scores = np.log(scores[:,0]+LOG_REG) 266 | return self.unpack(scores, paths) 267 | 268 | 269 | class AIRLStateAction(SingleTimestepIRL): 270 | """ 271 | This version consumes 
single timesteps. 272 | """ 273 | def __init__(self, env_spec, expert_trajs=None, 274 | discrim_arch=relu_net, 275 | discrim_arch_args={}, 276 | l2_reg=0, 277 | discount=1.0, 278 | name='gcl'): 279 | super(AIRLStateAction, self).__init__() 280 | self.dO = env_spec.observation_space.flat_dim 281 | self.dU = env_spec.action_space.flat_dim 282 | self.set_demos(expert_trajs) 283 | 284 | # build energy model 285 | with tf.variable_scope(name) as _vs: 286 | # Should be batch_size x T x dO/dU 287 | self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs') 288 | self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act') 289 | self.labels = tf.placeholder(tf.float32, [None, 1], name='labels') 290 | self.lprobs = tf.placeholder(tf.float32, [None, 1], name='log_probs') 291 | self.lr = tf.placeholder(tf.float32, (), name='lr') 292 | 293 | obs_act = tf.concat([self.obs_t, self.act_t], axis=1) 294 | with tf.variable_scope('discrim') as dvs: 295 | with tf.variable_scope('energy'): 296 | self.energy = discrim_arch(obs_act, **discrim_arch_args) 297 | # we do not learn a separate log Z(s) because it is impossible to separate from the energy 298 | # In a discrete domain we can explicitly normalize to calculate log Z(s) 299 | log_p_tau = -self.energy 300 | discrim_vars = tf.get_collection('reg_vars', scope=dvs.name) 301 | 302 | log_q_tau = self.lprobs 303 | 304 | if l2_reg > 0: 305 | reg_loss = l2_reg*tf.reduce_sum([tf.reduce_sum(tf.square(var)) for var in discrim_vars]) 306 | else: 307 | reg_loss = 0 308 | 309 | log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0) 310 | self.d_tau = tf.exp(log_p_tau-log_pq) 311 | cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq)) 312 | 313 | self.loss = cent_loss + reg_loss 314 | self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss) 315 | self._make_param_ops(_vs) 316 | 317 | 318 | def fit(self, paths, policy=None, batch_size=32, max_itrs=100, logger=None, lr=1e-3,**kwargs): 319 | #self._compute_path_probs(paths, insert=True) 320 | self.eval_expert_probs(paths, policy, insert=True) 321 | self.eval_expert_probs(self.expert_trajs, policy, insert=True) 322 | obs, acts, path_probs = self.extract_paths(paths, keys=('observations', 'actions', 'a_logprobs')) 323 | expert_obs, expert_acts, expert_probs = self.extract_paths(self.expert_trajs, keys=('observations', 'actions', 'a_logprobs')) 324 | 325 | # Train discriminator 326 | for it in TrainingIterator(max_itrs, heartbeat=5): 327 | obs_batch, act_batch, lprobs_batch = \ 328 | self.sample_batch(obs, acts, path_probs, batch_size=batch_size) 329 | 330 | expert_obs_batch, expert_act_batch, expert_lprobs_batch = \ 331 | self.sample_batch(expert_obs, expert_acts, expert_probs, batch_size=batch_size) 332 | 333 | labels = np.zeros((batch_size*2, 1)) 334 | labels[batch_size:] = 1.0 335 | obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0) 336 | act_batch = np.concatenate([act_batch, expert_act_batch], axis=0) 337 | lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32) 338 | 339 | loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict={ 340 | self.act_t: act_batch, 341 | self.obs_t: obs_batch, 342 | self.labels: labels, 343 | self.lprobs: lprobs_batch, 344 | self.lr: lr 345 | }) 346 | 347 | it.record('loss', loss) 348 | if it.heartbeat: 349 | print(it.itr_message()) 350 | mean_loss = it.pop_mean('loss') 351 | print('\tLoss:%f' % mean_loss) 352 | if 
logger: 353 | energy = tf.get_default_session().run(self.energy, 354 | feed_dict={self.act_t: acts, self.obs_t: obs}) 355 | logger.record_tabular('IRLAverageEnergy', np.mean(energy)) 356 | logger.record_tabular('IRLAverageLogQtau', np.mean(path_probs)) 357 | logger.record_tabular('IRLMedianLogQtau', np.median(path_probs)) 358 | 359 | energy = tf.get_default_session().run(self.energy, 360 | feed_dict={self.act_t: expert_acts, self.obs_t: expert_obs}) 361 | logger.record_tabular('IRLAverageExpertEnergy', np.mean(energy)) 362 | #logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy-logZ)) 363 | logger.record_tabular('IRLAverageExpertLogQtau', np.mean(expert_probs)) 364 | logger.record_tabular('IRLMedianExpertLogQtau', np.median(expert_probs)) 365 | return mean_loss 366 | 367 | 368 | def eval(self, paths, **kwargs): 369 | """ 370 | Return bonus 371 | """ 372 | obs, acts = self.extract_paths(paths) 373 | 374 | energy = tf.get_default_session().run(self.energy, 375 | feed_dict={self.act_t: acts, self.obs_t: obs}) 376 | energy = -energy[:,0] 377 | return self.unpack(energy, paths) 378 | 379 | 380 | class GAN_GCL(TrajectoryIRL): 381 | """ 382 | Guided cost learning, GAN formulation with learned partition function 383 | See https://arxiv.org/pdf/1611.03852.pdf 384 | """ 385 | def __init__(self, env_spec, expert_trajs=None, 386 | discrim_arch=feedforward_energy, 387 | discrim_arch_args={}, 388 | l2_reg = 0, 389 | discount = 1.0, 390 | init_itrs = None, 391 | score_dtau=False, 392 | state_only=False, 393 | name='trajprior'): 394 | super(GAN_GCL, self).__init__() 395 | self.dO = env_spec.observation_space.flat_dim 396 | self.dU = env_spec.action_space.flat_dim 397 | self.score_dtau = score_dtau 398 | self.set_demos(expert_trajs) 399 | 400 | # build energy model 401 | with tf.variable_scope(name) as vs: 402 | # Should be batch_size x T x dO/dU 403 | self.obs_t = tf.placeholder(tf.float32, [None, None, self.dO], name='obs') 404 | self.act_t = tf.placeholder(tf.float32, [None, None, self.dU], name='act') 405 | self.traj_logprobs = tf.placeholder(tf.float32, [None, None], name='traj_probs') 406 | self.labels = tf.placeholder(tf.float32, [None, 1], name='labels') 407 | self.lr = tf.placeholder(tf.float32, (), name='lr') 408 | 409 | if state_only: 410 | obs_act = self.obs_t 411 | else: 412 | obs_act = tf.concat([self.obs_t, self.act_t], axis=2) 413 | 414 | with tf.variable_scope('discrim') as vs2: 415 | self.energy = discrim_arch(obs_act, **discrim_arch_args) 416 | discrim_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs2.name) 417 | 418 | self.energy_timestep = self.energy 419 | # Don't train separate log Z because we can't fully separate it from the energy function 420 | if discount >= 1.0: 421 | log_p_tau = tf.reduce_sum(-self.energy, axis=1) 422 | else: 423 | log_p_tau = discounted_reduce_sum(-self.energy, discount=discount, axis=1) 424 | log_q_tau = tf.reduce_sum(self.traj_logprobs, axis=1, keep_dims=True) 425 | 426 | # numerical stability trick 427 | log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0) 428 | self.d_tau = tf.exp(log_p_tau-log_pq) 429 | cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq)) 430 | 431 | if l2_reg > 0: 432 | reg_loss = l2_reg*tf.reduce_sum([tf.reduce_sum(tf.square(var)) for var in discrim_vars]) 433 | else: 434 | reg_loss = 0 435 | 436 | #self.predictions = tf.nn.sigmoid(logits) 437 | self.loss = cent_loss + reg_loss 438 | self.step = 
tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss) 439 | self._make_param_ops(vs) 440 | 441 | @property 442 | def score_trajectories(self): 443 | return False 444 | 445 | 446 | def fit(self, paths, policy=None, batch_size=32, max_itrs=100, logger=None, lr=1e-3,**kwargs): 447 | self._compute_path_probs(paths, insert=True) 448 | self.eval_expert_probs(self.expert_trajs, policy, insert=True) 449 | obs, acts, path_probs = self.extract_paths(paths, keys=('observations', 'actions', 'a_logprobs')) 450 | expert_obs, expert_acts, expert_probs = self.extract_paths(self.expert_trajs, keys=('observations', 'actions', 'a_logprobs')) 451 | 452 | # Train discriminator 453 | for it in TrainingIterator(max_itrs, heartbeat=5): 454 | obs_batch, act_batch, lprobs_batch = \ 455 | self.sample_batch(obs, acts, path_probs, batch_size=batch_size) 456 | 457 | expert_obs_batch, expert_act_batch, expert_lprobs_batch = \ 458 | self.sample_batch(expert_obs, expert_acts, expert_probs, batch_size=batch_size) 459 | T = expert_obs_batch.shape[1] 460 | 461 | labels = np.zeros((batch_size*2, 1)) 462 | labels[batch_size:] = 1.0 463 | obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0) 464 | act_batch = np.concatenate([act_batch, expert_act_batch], axis=0) 465 | lprobs_batch = np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0) 466 | 467 | loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict={ 468 | self.act_t: act_batch, 469 | self.obs_t: obs_batch, 470 | self.labels: labels, 471 | self.traj_logprobs: lprobs_batch, 472 | self.lr: lr, 473 | }) 474 | 475 | it.record('loss', loss) 476 | if it.heartbeat: 477 | print(it.itr_message()) 478 | mean_loss = it.pop_mean('loss') 479 | print('\tLoss:%f' % mean_loss) 480 | 481 | if logger: 482 | energy, dtau = tf.get_default_session().run([self.energy_timestep, self.d_tau], 483 | feed_dict={self.act_t: acts, self.obs_t: obs, 484 | self.traj_logprobs: path_probs}) 485 | #logger.record_tabular('GCLLogZ', logZ) 486 | logger.record_tabular('IRLAverageEnergy', np.mean(energy)) 487 | #logger.record_tabular('GCLAverageLogPtau', np.mean(-energy)) 488 | logger.record_tabular('IRLAverageLogQtau', np.mean(path_probs)) 489 | logger.record_tabular('IRLMedianLogQtau', np.median(path_probs)) 490 | logger.record_tabular('IRLAverageDtau', np.mean(dtau)) 491 | 492 | energy, dtau = tf.get_default_session().run([self.energy_timestep, self.d_tau], 493 | feed_dict={self.act_t: expert_acts, self.obs_t: expert_obs, 494 | self.traj_logprobs: expert_probs}) 495 | logger.record_tabular('IRLAverageExpertEnergy', np.mean(energy)) 496 | #logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy)) 497 | logger.record_tabular('IRLAverageExpertLogQtau', np.mean(expert_probs)) 498 | logger.record_tabular('IRLMedianExpertLogQtau', np.median(expert_probs)) 499 | logger.record_tabular('IRLAverageExpertDtau', np.mean(dtau)) 500 | return mean_loss 501 | 502 | def eval(self, paths, **kwargs): 503 | """ 504 | Return bonus 505 | """ 506 | obs, acts = self.extract_paths(paths) 507 | 508 | scores = tf.get_default_session().run(self.energy, 509 | feed_dict={self.act_t: acts, self.obs_t: obs}) 510 | scores = -scores[:,:,0] 511 | return scores 512 | 513 | -------------------------------------------------------------------------------- /inverse_rl/models/tf_util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | REG_VARS = 'reg_vars' 5 | 6 | def linear(X, dout, name, bias=True): 7 | 
with tf.variable_scope(name): 8 | dX = int(X.get_shape()[-1]) 9 | W = tf.get_variable('W', shape=(dX, dout)) 10 | tf.add_to_collection(REG_VARS, W) 11 | if bias: 12 | b = tf.get_variable('b', initializer=tf.constant(np.zeros(dout).astype(np.float32))) 13 | else: 14 | b = 0 15 | return tf.matmul(X, W)+b 16 | 17 | def discounted_reduce_sum(X, discount, axis=-1): 18 | if discount != 1.0: 19 | disc = tf.cumprod(discount*tf.ones_like(X), axis=axis) 20 | else: 21 | disc = 1.0 22 | return tf.reduce_sum(X*disc, axis=axis) 23 | 24 | def assert_shape(tens, shape): 25 | assert tens.get_shape().is_compatible_with(shape) 26 | 27 | def relu_layer(X, dout, name): 28 | return tf.nn.relu(linear(X, dout, name)) 29 | 30 | def softplus_layer(X, dout, name): 31 | return tf.nn.softplus(linear(X, dout, name)) 32 | 33 | def tanh_layer(X, dout, name): 34 | return tf.nn.tanh(linear(X, dout, name)) 35 | 36 | def get_session_config(): 37 | session_config = tf.ConfigProto() 38 | session_config.gpu_options.allow_growth = True 39 | #session_config.gpu_options.per_process_gpu_memory_fraction = 0.2 40 | return session_config 41 | 42 | 43 | def load_prior_params(pkl_fname): 44 | import joblib 45 | with tf.Session(config=get_session_config()): 46 | params = joblib.load(pkl_fname) 47 | tf.reset_default_graph() 48 | #joblib.dump(params, file_name, compress=3) 49 | params = params['irl_params'] 50 | #print(params) 51 | assert params is not None 52 | return params 53 | -------------------------------------------------------------------------------- /inverse_rl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from inverse_rl.utils.general import * 2 | -------------------------------------------------------------------------------- /inverse_rl/utils/general.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | def flatten_list(lol): 5 | return [ a for b in lol for a in b ] 6 | 7 | class TrainingIterator(object): 8 | def __init__(self, itrs, heartbeat=float('inf')): 9 | self.itrs = itrs 10 | self.heartbeat_time = heartbeat 11 | self.__vals = {} 12 | 13 | def random_idx(self, N, size): 14 | return np.random.randint(0, N, size=size) 15 | 16 | @property 17 | def itr(self): 18 | return self.__itr 19 | 20 | @property 21 | def heartbeat(self): 22 | return self.__heartbeat 23 | 24 | @property 25 | def elapsed(self): 26 | assert self.heartbeat, 'elapsed is only valid when heartbeat=True' 27 | return self.__elapsed 28 | 29 | def itr_message(self): 30 | return '==> Itr %d/%d (elapsed:%.2f)' % (self.itr+1, self.itrs, self.elapsed) 31 | 32 | def record(self, key, value): 33 | if key in self.__vals: 34 | self.__vals[key].append(value) 35 | else: 36 | self.__vals[key] = [value] 37 | 38 | def pop(self, key): 39 | vals = self.__vals.get(key, []) 40 | del self.__vals[key] 41 | return vals 42 | 43 | def pop_mean(self, key): 44 | return np.mean(self.pop(key)) 45 | 46 | def __iter__(self): 47 | prev_time = time.time() 48 | self.__heartbeat = False 49 | for i in range(self.itrs): 50 | self.__itr = i 51 | cur_time = time.time() 52 | if (cur_time-prev_time) > self.heartbeat_time or i==(self.itrs-1): 53 | self.__heartbeat = True 54 | self.__elapsed = cur_time-prev_time 55 | prev_time = cur_time 56 | yield self 57 | self.__heartbeat = False -------------------------------------------------------------------------------- /inverse_rl/utils/hyper_sweep.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Usage 3 | 4 | args = { 5 | 'param1': [1e-3, 1e-2, 1e-2], 6 | 'param2': [1,5,10,20], 7 | } 8 | 9 | run_sweep_parallel(func, args) 10 | 11 | or 12 | 13 | run_sweep_serial(func, args) 14 | 15 | """ 16 | import itertools 17 | import multiprocessing 18 | import random 19 | from datetime import datetime 20 | 21 | class Sweeper(object): 22 | def __init__(self, hyper_config, repeat): 23 | self.hyper_config = hyper_config 24 | self.repeat = repeat 25 | 26 | def __iter__(self): 27 | count = 0 28 | for _ in range(self.repeat): 29 | for config in itertools.product(*[val for val in self.hyper_config.values()]): 30 | kwargs = {key:config[i] for i, key in enumerate(self.hyper_config.keys())} 31 | timestamp = datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 32 | kwargs['exp_name'] = "%s_%d" % (timestamp, count) 33 | count += 1 34 | yield kwargs 35 | 36 | 37 | 38 | def run_sweep_serial(run_method, params, repeat=1): 39 | sweeper = Sweeper(params, repeat) 40 | for config in sweeper: 41 | import tensorflow as tf; tf.reset_default_graph() 42 | run_method(**config) 43 | 44 | def kwargs_wrapper(args_method): 45 | import tensorflow as tf; tf.reset_default_graph() 46 | args, method = args_method 47 | return method(**args) 48 | 49 | 50 | def run_sweep_parallel(run_method, params, repeat=1, num_cpu=multiprocessing.cpu_count()): 51 | sweeper = Sweeper(params, repeat) 52 | pool = multiprocessing.Pool(num_cpu) 53 | exp_args = [] 54 | for config in sweeper: 55 | exp_args.append((config, run_method)) 56 | random.shuffle(exp_args) 57 | pool.map(kwargs_wrapper, exp_args) 58 | 59 | 60 | def example_run_method(exp_name, param1, param2='a', param3=3, param4=4): 61 | import time 62 | time.sleep(1.0) 63 | print(exp_name, param1, param2, param3, param4) 64 | 65 | 66 | if __name__ == "__main__": 67 | sweep_op = { 68 | 'param1': [1e-3, 1e-2, 1e-1], 69 | 'param2': [1,5,10,20], 70 | 'param3': [True, False] 71 | } 72 | run_sweep_parallel(example_run_method, sweep_op, repeat=2) -------------------------------------------------------------------------------- /inverse_rl/utils/hyperparametrized.py: -------------------------------------------------------------------------------- 1 | CLSNAME = '__clsname__' 2 | _HYPER_ = '__hyper__' 3 | _HYPERNAME_ = '__hyper_clsname__' 4 | 5 | 6 | def extract_hyperparams(obj): 7 | if any([isinstance(obj, type_) for type_ in (int, float, str)]): 8 | return obj 9 | elif isinstance(type(obj), Hyperparametrized): 10 | hypers = getattr(obj, _HYPER_) 11 | hypers[CLSNAME] = getattr(obj, _HYPERNAME_) 12 | for attr in hypers: 13 | hypers[attr] = extract_hyperparams(hypers[attr]) 14 | return hypers 15 | return type(obj).__name__ 16 | 17 | class Hyperparametrized(type): 18 | def __new__(self, clsname, bases, clsdict): 19 | old_init = clsdict.get('__init__', bases[0].__init__) 20 | def init_wrapper(inst, *args, **kwargs): 21 | hyper = getattr(inst, _HYPER_, {}) 22 | hyper.update(kwargs) 23 | setattr(inst, _HYPER_, hyper) 24 | 25 | if getattr(inst, _HYPERNAME_, None) is None: 26 | setattr(inst, _HYPERNAME_, clsname) 27 | return old_init(inst, *args, **kwargs) 28 | clsdict['__init__'] = init_wrapper 29 | 30 | cls = super(Hyperparametrized, self).__new__(self, clsname, bases, clsdict) 31 | return cls 32 | 33 | 34 | class HyperparamWrapper(object, metaclass=Hyperparametrized): 35 | def __init__(self, **hyper_kwargs): 36 | pass 37 | 38 | if __name__ == "__main__": 39 | class Algo1(object, metaclass=Hyperparametrized): 40 | def 
__init__(self, hyper1=1.0, hyper2=2.0, model1=None): 41 | pass 42 | 43 | 44 | class Algo2(Algo1): 45 | def __init__(self, hyper3=5.0, **kwargs): 46 | super(Algo2, self).__init__(**kwargs) 47 | 48 | 49 | class Model1(object, metaclass=Hyperparametrized): 50 | def __init__(self, hyper1=None): 51 | pass 52 | 53 | 54 | def get_params_json(**kwargs): 55 | hyper_dict = extract_hyperparams(HyperparamWrapper(**kwargs)) 56 | del hyper_dict[CLSNAME] 57 | return hyper_dict 58 | 59 | m1 = Model1(hyper1='Test') 60 | a1 = Algo2(hyper1=1.0, hyper2=5.0, hyper3=10.0, model1=m1) 61 | 62 | print( isinstance(type(a1), Hyperparametrized)) 63 | print(get_params_json(a1=a1)) 64 | -------------------------------------------------------------------------------- /inverse_rl/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import joblib 4 | import json 5 | import contextlib 6 | 7 | import rllab.misc.logger as rllablogger 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | from inverse_rl.utils.hyperparametrized import extract_hyperparams 12 | 13 | @contextlib.contextmanager 14 | def rllab_logdir(algo=None, dirname=None): 15 | if dirname: 16 | rllablogger.set_snapshot_dir(dirname) 17 | dirname = rllablogger.get_snapshot_dir() 18 | rllablogger.add_tabular_output(os.path.join(dirname, 'progress.csv')) 19 | if algo: 20 | with open(os.path.join(dirname, 'params.json'), 'w') as f: 21 | params = extract_hyperparams(algo) 22 | json.dump(params, f) 23 | yield dirname 24 | rllablogger.remove_tabular_output(os.path.join(dirname, 'progress.csv')) 25 | 26 | 27 | def get_expert_fnames(log_dir, n=5): 28 | print('Looking for paths') 29 | import re 30 | itr_reg = re.compile(r"itr_(?P[0-9]+)\.pkl") 31 | 32 | itr_files = [] 33 | for i, filename in enumerate(os.listdir(log_dir)): 34 | m = itr_reg.match(filename) 35 | if m: 36 | itr_count = m.group('itr_count') 37 | itr_files.append((itr_count, filename)) 38 | 39 | itr_files = sorted(itr_files, key=lambda x: int(x[0]), reverse=True)[:n] 40 | for itr_file_and_count in itr_files: 41 | fname = os.path.join(log_dir, itr_file_and_count[1]) 42 | print('Loading %s' % fname) 43 | yield fname 44 | 45 | 46 | def load_experts(fname, max_files=float('inf'), min_return=None): 47 | config = tf.ConfigProto() 48 | config.gpu_options.allow_growth = True 49 | if hasattr(fname, '__iter__'): 50 | paths = [] 51 | for fname_ in fname: 52 | tf.reset_default_graph() 53 | with tf.Session(config=config): 54 | snapshot_dict = joblib.load(fname_) 55 | paths.extend(snapshot_dict['paths']) 56 | else: 57 | with tf.Session(config=config): 58 | snapshot_dict = joblib.load(fname) 59 | paths = snapshot_dict['paths'] 60 | tf.reset_default_graph() 61 | 62 | trajs = [] 63 | for path in paths: 64 | obses = path['observations'] 65 | actions = path['actions'] 66 | returns = path['returns'] 67 | total_return = np.sum(returns) 68 | if (min_return is None) or (total_return >= min_return): 69 | traj = {'observations': obses, 'actions': actions} 70 | trajs.append(traj) 71 | random.shuffle(trajs) 72 | print('Loaded %d trajectories' % len(trajs)) 73 | return trajs 74 | 75 | 76 | def load_latest_experts(logdir, n=5, min_return=None): 77 | return load_experts(get_expert_fnames(logdir, n=n), min_return=min_return) 78 | 79 | 80 | def load_latest_experts_multiple_runs(logdir, n=5): 81 | paths = [] 82 | for i, dirname in enumerate(os.listdir(logdir)): 83 | dirname = os.path.join(logdir, dirname) 84 | if os.path.isdir(dirname): 85 | 
print('Loading experts from %s' % dirname) 86 | paths.extend(load_latest_experts(dirname, n=n)) 87 | return paths 88 | -------------------------------------------------------------------------------- /inverse_rl/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy as sp 3 | import scipy.stats 4 | 5 | def rle(inarray): 6 | """ run length encoding. Partial credit to R rle function. 7 | Multi datatype arrays catered for including non Numpy 8 | returns: tuple (runlengths, startpositions, values) """ 9 | ia = np.array(inarray) # force numpy 10 | n = len(ia) 11 | if n == 0: 12 | return (None, None, None) 13 | else: 14 | y = np.array(ia[1:] != ia[:-1]) # pairwise unequal (string safe) 15 | i = np.append(np.where(y), n - 1) # must include last element posi 16 | z = np.diff(np.append(-1, i)) # run lengths 17 | p = np.cumsum(np.append(0, z))[:-1] # positions 18 | return(z, p, ia[i]) 19 | 20 | def split_list_by_lengths(values, lengths): 21 | """ 22 | 23 | >>> split_list_by_lengths([0,0,0,1,1,1,2,2,2], [2,2,5]) 24 | [[0, 0], [0, 1], [1, 1, 2, 2, 2]] 25 | """ 26 | assert np.sum(lengths) == len(values) 27 | idxs = np.cumsum(lengths) 28 | idxs = np.insert(idxs, 0, 0) 29 | return [ values[idxs[i]:idxs[i+1] ] for i in range(len(idxs)-1)] 30 | 31 | def clip_sing(X, clip_val=1): 32 | U, E, V = np.linalg.svd(X, full_matrices=False) 33 | E = np.clip(E, -clip_val, clip_val) 34 | return U.dot(np.diag(E)).dot(V) 35 | 36 | def gauss_log_pdf(params, x): 37 | mean, log_diag_std = params 38 | N, d = mean.shape 39 | cov = np.square(np.exp(log_diag_std)) 40 | diff = x-mean 41 | exp_term = -0.5 * np.sum(np.square(diff)/cov, axis=1) 42 | norm_term = -0.5*d*np.log(2*np.pi) 43 | var_term = -0.5 * np.sum(np.log(cov), axis=1) 44 | log_probs = norm_term + var_term + exp_term 45 | return log_probs #sp.stats.multivariate_normal.logpdf(x, mean=mean, cov=cov) 46 | 47 | def categorical_log_pdf(params, x, one_hot=True): 48 | if not one_hot: 49 | raise NotImplementedError() 50 | probs = params[0] 51 | return np.log(np.max(probs * x, axis=1)) 52 | 53 | -------------------------------------------------------------------------------- /scripts/ant_data_collect.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from inverse_rl.algos.trpo import TRPO 4 | from inverse_rl.models.tf_util import get_session_config 5 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 6 | from sandbox.rocky.tf.envs.base import TfEnv 7 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 8 | 9 | from inverse_rl.envs.env_utils import CustomGymEnv 10 | from inverse_rl.utils.log_utils import rllab_logdir 11 | from inverse_rl.utils.hyper_sweep import run_sweep_parallel, run_sweep_serial 12 | 13 | 14 | def main(exp_name, ent_wt=1.0): 15 | tf.reset_default_graph() 16 | env = TfEnv(CustomGymEnv('CustomAnt-v0', record_video=False, record_log=False)) 17 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32)) 18 | with tf.Session(config=get_session_config()) as sess: 19 | algo = TRPO( 20 | env=env, 21 | sess=sess, 22 | policy=policy, 23 | n_itr=1500, 24 | batch_size=20000, 25 | max_path_length=500, 26 | discount=0.99, 27 | store_paths=True, 28 | entropy_weight=ent_wt, 29 | baseline=LinearFeatureBaseline(env_spec=env.spec), 30 | exp_name=exp_name, 31 | ) 32 | with rllab_logdir(algo=algo, dirname='data/ant_data_collect/%s'%exp_name): 33 | 
algo.train() 34 | 35 | if __name__ == "__main__": 36 | params_dict = { 37 | 'ent_wt': [0.1] 38 | } 39 | run_sweep_parallel(main, params_dict, repeat=4) 40 | -------------------------------------------------------------------------------- /scripts/ant_irl.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 4 | from sandbox.rocky.tf.envs.base import TfEnv 5 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 6 | from rllab.envs.gym_env import GymEnv 7 | 8 | from inverse_rl.envs.env_utils import CustomGymEnv 9 | from inverse_rl.algos.irl_trpo import IRLTRPO 10 | from inverse_rl.models.airl_state import * 11 | from inverse_rl.utils.log_utils import rllab_logdir, load_latest_experts, load_latest_experts_multiple_runs 12 | from inverse_rl.utils.hyper_sweep import run_sweep_parallel, run_sweep_serial 13 | 14 | def main(exp_name=None, fusion=False): 15 | env = TfEnv(CustomGymEnv('CustomAnt-v0', record_video=False, record_log=False)) 16 | 17 | # load ~2 iterations worth of data from each forward RL experiment as demos 18 | experts = load_latest_experts_multiple_runs('data/ant_data_collect', n=2) 19 | 20 | irl_model = AIRL(env=env, expert_trajs=experts, state_only=True, fusion=fusion, max_itrs=10) 21 | 22 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32)) 23 | algo = IRLTRPO( 24 | env=env, 25 | policy=policy, 26 | irl_model=irl_model, 27 | n_itr=1000, 28 | batch_size=10000, 29 | max_path_length=500, 30 | discount=0.99, 31 | store_paths=True, 32 | irl_model_wt=1.0, 33 | entropy_weight=0.1, 34 | zero_environment_reward=True, 35 | baseline=LinearFeatureBaseline(env_spec=env.spec), 36 | ) 37 | with rllab_logdir(algo=algo, dirname='data/ant_state_irl/%s' % exp_name): 38 | with tf.Session(): 39 | algo.train() 40 | 41 | if __name__ == "__main__": 42 | params_dict = { 43 | 'fusion': [True] 44 | } 45 | run_sweep_parallel(main, params_dict, repeat=3) 46 | 47 | -------------------------------------------------------------------------------- /scripts/ant_transfer_disabled.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 5 | from sandbox.rocky.tf.envs.base import TfEnv 6 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 7 | from rllab.envs.gym_env import GymEnv 8 | 9 | 10 | from inverse_rl.algos.irl_trpo import IRLTRPO 11 | from inverse_rl.envs.env_utils import CustomGymEnv 12 | from inverse_rl.models.airl_state import * 13 | from inverse_rl.models.tf_util import load_prior_params 14 | from inverse_rl.utils.log_utils import rllab_logdir, load_latest_experts 15 | from inverse_rl.utils.hyper_sweep import run_sweep_parallel, run_sweep_serial 16 | 17 | 18 | DATA_DIR = 'data/ant_state_irl' 19 | def main(exp_name, params_folder=None): 20 | env = TfEnv(CustomGymEnv('DisabledAnt-v0', record_video=False, record_log=False)) 21 | 22 | irl_itr = 100 # earlier IRL iterations overfit less; 100 seems to work well. 
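For context, this transfer script consumes the output of the two Ant scripts above: `ant_data_collect.py` writes expert rollouts under `data/ant_data_collect/`, `ant_irl.py` trains state-only AIRL on those demonstrations and snapshots `itr_*.pkl` files under `data/ant_state_irl/`, and `load_prior_params` below restores the learned reward from one of those snapshots. A typical end-to-end sequence is:

```
python scripts/ant_data_collect.py
python scripts/ant_irl.py
python scripts/ant_transfer_disabled.py
```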
23 | params_file = os.path.join(DATA_DIR, '%s/itr_%d.pkl' % (params_folder, irl_itr)) 24 | prior_params = load_prior_params(params_file) 25 | 26 | irl_model = AIRL(env=env, expert_trajs=None, state_only=True) 27 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32)) 28 | algo = IRLTRPO( 29 | init_irl_params=prior_params, 30 | env=env, 31 | policy=policy, 32 | irl_model=irl_model, 33 | n_itr=1000, 34 | batch_size=10000, 35 | max_path_length=500, 36 | discount=0.99, 37 | store_paths=False, 38 | train_irl=False, 39 | irl_model_wt=1.0, 40 | entropy_weight=0.1, 41 | zero_environment_reward=True, 42 | baseline=LinearFeatureBaseline(env_spec=env.spec), 43 | log_params_folder=params_folder, 44 | log_experiment_name=exp_name, 45 | ) 46 | with rllab_logdir(algo=algo, dirname='data/ant_transfer/%s'%exp_name): 47 | with tf.Session(): 48 | algo.train() 49 | 50 | if __name__ == "__main__": 51 | import os 52 | params_folders = os.listdir(DATA_DIR) 53 | params_dict = { 54 | 'params_folder': params_folders, 55 | } 56 | run_sweep_parallel(main, params_dict, repeat=3) 57 | 58 | -------------------------------------------------------------------------------- /scripts/pendulum_data_collect.py: -------------------------------------------------------------------------------- 1 | from sandbox.rocky.tf.algos.trpo import TRPO 2 | from sandbox.rocky.tf.envs.base import TfEnv 3 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 4 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 5 | from rllab.envs.gym_env import GymEnv 6 | 7 | from inverse_rl.utils.log_utils import rllab_logdir 8 | 9 | def main(): 10 | env = TfEnv(GymEnv('Pendulum-v0', record_video=False, record_log=False)) 11 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32)) 12 | algo = TRPO( 13 | env=env, 14 | policy=policy, 15 | n_itr=200, 16 | batch_size=1000, 17 | max_path_length=100, 18 | discount=0.99, 19 | store_paths=True, 20 | baseline=LinearFeatureBaseline(env_spec=env.spec) 21 | ) 22 | 23 | with rllab_logdir(algo=algo, dirname='data/pendulum'): 24 | algo.train() 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /scripts/pendulum_gail.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 4 | from sandbox.rocky.tf.envs.base import TfEnv 5 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 6 | from rllab.envs.gym_env import GymEnv 7 | 8 | 9 | from inverse_rl.algos.irl_trpo import IRLTRPO 10 | from inverse_rl.models.imitation_learning import GAIL 11 | from inverse_rl.utils.log_utils import rllab_logdir, load_latest_experts 12 | 13 | def main(): 14 | env = TfEnv(GymEnv('Pendulum-v0', record_video=False, record_log=False)) 15 | 16 | experts = load_latest_experts('data/pendulum', n=5) 17 | 18 | irl_model = GAIL(env_spec=env.spec, expert_trajs=experts) 19 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32)) 20 | algo = IRLTRPO( 21 | env=env, 22 | policy=policy, 23 | irl_model=irl_model, 24 | n_itr=200, 25 | batch_size=1000, 26 | max_path_length=100, 27 | discount=0.99, 28 | store_paths=True, 29 | discrim_train_itrs=50, 30 | irl_model_wt=1.0, 31 | entropy_weight=0.0, # GAIL should not use entropy unless for exploration 32 | zero_environment_reward=True, 33 | 
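        # zero_environment_reward=True drops the gym reward entirely, so with irl_model_wt=1.0
        # the policy is optimized purely against the GAIL discriminator's imitation signal.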
baseline=LinearFeatureBaseline(env_spec=env.spec) 34 | ) 35 | 36 | with rllab_logdir(algo=algo, dirname='data/pendulum_gail'): 37 | with tf.Session(): 38 | algo.train() 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /scripts/pendulum_irl.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy 4 | from sandbox.rocky.tf.envs.base import TfEnv 5 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline 6 | from rllab.envs.gym_env import GymEnv 7 | 8 | 9 | from inverse_rl.algos.irl_trpo import IRLTRPO 10 | from inverse_rl.models.imitation_learning import AIRLStateAction 11 | from inverse_rl.utils.log_utils import rllab_logdir, load_latest_experts 12 | 13 | def main(): 14 | env = TfEnv(GymEnv('Pendulum-v0', record_video=False, record_log=False)) 15 | 16 | experts = load_latest_experts('data/pendulum', n=5) 17 | 18 | irl_model = AIRLStateAction(env_spec=env.spec, expert_trajs=experts) 19 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32)) 20 | algo = IRLTRPO( 21 | env=env, 22 | policy=policy, 23 | irl_model=irl_model, 24 | n_itr=200, 25 | batch_size=1000, 26 | max_path_length=100, 27 | discount=0.99, 28 | store_paths=True, 29 | discrim_train_itrs=50, 30 | irl_model_wt=1.0, 31 | entropy_weight=0.1, # this should be 1.0 but 0.1 seems to work better 32 | zero_environment_reward=True, 33 | baseline=LinearFeatureBaseline(env_spec=env.spec) 34 | ) 35 | 36 | with rllab_logdir(algo=algo, dirname='data/pendulum_gcl'): 37 | with tf.Session(): 38 | algo.train() 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /tabular_maxent_irl/README.md: -------------------------------------------------------------------------------- 1 | # Tabular MaxEnt IRL 2 | 3 | This folder contains a self-enclosed version of tabular maximum entropy (MaxEnt) IRL. 4 | 5 | To run on a randomly generated MDP: 6 | ``` 7 | python maxent_irl.py 8 | ``` 9 | 10 | 11 | -------------------------------------------------------------------------------- /tabular_maxent_irl/maxent_irl.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implements Maximum Entropy IRL using dynamic programming. 
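A minimal end-to-end sketch, mirroring the __main__ block at the bottom of this file
(the ent_wt/discount/lr values are simply the ones used there):

    from simple_env import random_env
    from q_iteration import q_iteration

    env = random_env(16, 4, seed=1, terminate=False, t_sparsity=0.8)
    true_q = q_iteration(env, K=150, ent_wt=1.0, gamma=0.9)                  # expert soft-Q
    demo_visits = compute_visitation(env, true_q, ent_wt=1.0, T=5, discount=0.9)
    reward, q = tabular_maxent_irl(env, demo_visits, lr=0.01, num_itrs=1000,
                                   ent_wt=1.0, state_only=True, discount=0.9)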
This 3 | 4 | Simply call tabular_maxent_irl(env, expert_visitations) 5 | The expert visitations can be generated via the compute_visitations function on an expert q_function (exact), 6 | or using compute_visitations_demo on demos (approximate) 7 | 8 | """ 9 | import numpy as np 10 | from utils import one_hot_to_flat, flat_to_one_hot 11 | from q_iteration import q_iteration, logsumexp, get_policy 12 | from utils import TrainingIterator 13 | from utils import gd_momentum_optimizer, adam_optimizer 14 | 15 | 16 | def compute_visitation(env, q_fn, ent_wt=1.0, T=50, discount=1.0): 17 | pol_probs = get_policy(q_fn, ent_wt=ent_wt) 18 | 19 | dim_obs = env.observation_space.flat_dim 20 | dim_act = env.action_space.flat_dim 21 | state_visitation = np.expand_dims(env.initial_state_distribution, axis=1) 22 | t_matrix = env.transition_matrix # S x A x S 23 | sa_visit_t = np.zeros((dim_obs, dim_act, T)) 24 | 25 | for i in range(T): 26 | sa_visit = state_visitation * pol_probs 27 | sa_visit_t[:,:,i] = sa_visit #(discount**i) * sa_visit 28 | # sum-out (SA)S 29 | new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix) 30 | state_visitation = np.expand_dims(new_state_visitation, axis=1) 31 | return np.sum(sa_visit_t, axis=2) / float(T) 32 | 33 | 34 | def compute_vistation_demos(env, demos): 35 | dim_obs = env.observation_space.flat_dim 36 | dim_act = env.action_space.flat_dim 37 | counts = np.zeros((dim_obs, dim_act)) 38 | 39 | for demo in demos: 40 | obs = demo['observations'] 41 | act = demo['actions'] 42 | state_ids = one_hot_to_flat(obs) 43 | T = len(state_ids) 44 | for t in range(T): 45 | counts[state_ids[t], act[t]] += 1 46 | return counts / float(np.sum(counts)) 47 | 48 | 49 | def sample_states(env, q_fn, visitation_probs, n_sample, ent_wt): 50 | dS, dA = visitation_probs.shape 51 | samples = np.random.choice(np.arange(dS*dA), size=n_sample, p=visitation_probs.reshape(dS*dA)) 52 | policy = get_policy(q_fn, ent_wt=ent_wt) 53 | observations = samples // dA 54 | actions = samples % dA 55 | a_logprobs = np.log(policy[observations, actions]) 56 | 57 | observations_next = [] 58 | for i in range(n_sample): 59 | t_distr = env.tabular_trans_distr(observations[i], actions[i]) 60 | next_state = flat_to_one_hot(np.random.choice(np.arange(len(t_distr)), p=t_distr), ndim=dS) 61 | observations_next.append(next_state) 62 | observations_next = np.array(observations_next) 63 | 64 | return {'observations': flat_to_one_hot(observations, ndim=dS), 65 | 'actions': flat_to_one_hot(actions, ndim=dA), 66 | 'a_logprobs': a_logprobs, 67 | 'observations_next': observations_next} 68 | 69 | 70 | def tabular_maxent_irl(env, demo_visitations, num_itrs=50, ent_wt=1.0, lr=1e-3, state_only=False, 71 | discount=0.99, T=5): 72 | dim_obs = env.observation_space.flat_dim 73 | dim_act = env.action_space.flat_dim 74 | 75 | # Initialize policy and reward function 76 | reward_fn = np.zeros((dim_obs, dim_act)) 77 | q_rew = np.zeros((dim_obs, dim_act)) 78 | 79 | update = adam_optimizer(lr) 80 | 81 | for it in TrainingIterator(num_itrs, heartbeat=1.0): 82 | q_itrs = 20 if it.itr>5 else 100 83 | ### compute policy in closed form 84 | q_rew = q_iteration(env, reward_matrix=reward_fn, ent_wt=ent_wt, warmstart_q=q_rew, K=q_itrs, gamma=discount) 85 | 86 | ### update reward 87 | # need to count how often the policy will visit a particular (s, a) pair 88 | pol_visitations = compute_visitation(env, q_rew, ent_wt=ent_wt, T=T, discount=discount) 89 | 90 | grad = -(demo_visitations - pol_visitations) 91 | it.record('VisitationInfNormError', 
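              # = || pol_visitations - demo_visitations ||_inf, i.e. how far the current
              # soft-optimal policy's state-action visitation is from the demonstrations;
              # (up to sign) this same difference is the MaxEnt gradient applied to reward_fn below.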
np.max(np.abs(grad))) 92 | if state_only: 93 | grad = np.sum(grad, axis=1, keepdims=True) 94 | reward_fn = update(reward_fn, grad) 95 | 96 | if it.heartbeat: 97 | print(it.itr_message()) 98 | print('\tVisitationError:',it.pop_mean('VisitationInfNormError')) 99 | return reward_fn, q_rew 100 | 101 | 102 | if __name__ == "__main__": 103 | # test IRL 104 | from q_iteration import q_iteration 105 | from simple_env import random_env 106 | np.set_printoptions(suppress=True) 107 | 108 | # Environment parameters 109 | env = random_env(16, 4, seed=1, terminate=False, t_sparsity=0.8) 110 | dS = env.spec.observation_space.flat_dim 111 | dU = env.spec.action_space.flat_dim 112 | dO = 8 113 | ent_wt = 1.0 114 | discount = 0.9 115 | obs_matrix = np.random.randn(dS, dO) 116 | 117 | # Compute optimal policy for double checking 118 | true_q = q_iteration(env, K=150, ent_wt=ent_wt, gamma=discount) 119 | true_sa_visits = compute_visitation(env, true_q, ent_wt=ent_wt, T=5, discount=discount) 120 | expert_pol = get_policy(true_q, ent_wt=ent_wt) 121 | 122 | # Run MaxEnt IRL State-only 123 | learned_rew, learned_q = tabular_maxent_irl(env, true_sa_visits, lr=0.01, num_itrs=1000, 124 | ent_wt=ent_wt, state_only=True, 125 | discount=discount) 126 | learned_pol = get_policy(learned_q, ent_wt=ent_wt) 127 | 128 | 129 | # Normalize reward (if state_only=True, reward is accurate up to a constant) 130 | adjusted_rew = learned_rew - np.mean(learned_rew) + np.mean(env.rew_matrix) 131 | 132 | diff_rew = np.abs(env.rew_matrix - adjusted_rew) 133 | diff_pol = np.abs(expert_pol - learned_pol) 134 | print('----- Results State Only -----') 135 | print('InfNormRewError', np.max(diff_rew)) 136 | print('InfNormPolicyError', np.max(diff_pol)) 137 | print('AvdDiffRew', np.mean(diff_rew)) 138 | print('AvgDiffPol', np.mean(diff_pol)) 139 | print('True Reward', env.rew_matrix) 140 | print('Learned Reward', adjusted_rew) 141 | 142 | 143 | # Run MaxEnt IRL State-Action 144 | learned_rew, learned_q = tabular_maxent_irl(env, true_sa_visits, lr=0.01, num_itrs=1000, 145 | ent_wt=ent_wt, state_only=False, 146 | discount=discount) 147 | learned_pol = get_policy(learned_q, ent_wt=ent_wt) 148 | 149 | # Normalize reward (if state_only=True, reward is accurate up to a constant) 150 | adjusted_rew = learned_rew - np.mean(learned_rew) + np.mean(env.rew_matrix) 151 | 152 | diff_rew = np.abs(env.rew_matrix - adjusted_rew) 153 | diff_pol = np.abs(expert_pol - learned_pol) 154 | print('----- Results State-Action -----') 155 | print('InfNormRewError', np.max(diff_rew)) 156 | print('InfNormPolicyError', np.max(diff_pol)) 157 | print('AvdDiffRew', np.mean(diff_rew)) 158 | print('AvgDiffPol', np.mean(diff_pol)) 159 | print('True Reward', env.rew_matrix) 160 | print('Learned Reward', adjusted_rew) 161 | -------------------------------------------------------------------------------- /tabular_maxent_irl/q_iteration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Use q-iteration to solve for an optimal policy 3 | 4 | Usage: q_iteration(env, gamma=discount factor, ent_wt= entropy bonus) 5 | """ 6 | import numpy as np 7 | from scipy.misc import logsumexp as sp_lse 8 | 9 | def softmax(q, alpha=1.0): 10 | q = (1.0/alpha)*q 11 | q = q-np.max(q) 12 | probs = np.exp(q) 13 | probs = probs/np.sum(probs) 14 | return probs 15 | 16 | def logsumexp(q, alpha=1.0, axis=1): 17 | return alpha*sp_lse((1.0/alpha)*q, axis=axis) 18 | 19 | def get_policy(q_fn, ent_wt=1.0): 20 | """ 21 | Return a policy by normalizing a Q-function 22 | 
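    Concretely, this is the soft-optimal (Boltzmann) policy
        pi(a|s) = exp((Q(s,a) - V(s)) / ent_wt),  V(s) = ent_wt * logsumexp(Q(s,:) / ent_wt),
    computed below via logsumexp and the exponentiated advantage.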
""" 23 | v_rew = logsumexp(q_fn, alpha=ent_wt) 24 | adv_rew = q_fn - np.expand_dims(v_rew, axis=1) 25 | pol_probs = np.exp((1.0/ent_wt)*adv_rew) 26 | assert np.all(np.isclose(np.sum(pol_probs, axis=1), 1.0)), str(pol_probs) 27 | return pol_probs 28 | 29 | def q_iteration(env, reward_matrix=None, K=50, gamma=0.99, ent_wt=0.1, warmstart_q=None, policy=None): 30 | """ 31 | Perform tabular soft Q-iteration 32 | 33 | If policy is given, this computes Q_pi rather than Q_star 34 | """ 35 | dim_obs = env.observation_space.flat_dim 36 | dim_act = env.action_space.flat_dim 37 | if reward_matrix is None: 38 | reward_matrix = env.rew_matrix 39 | if warmstart_q is None: 40 | q_fn = np.zeros((dim_obs, dim_act)) 41 | else: 42 | q_fn = warmstart_q 43 | 44 | t_matrix = env.transition_matrix 45 | for k in range(K): 46 | if policy is None: 47 | v_fn = logsumexp(q_fn, alpha=ent_wt) 48 | else: 49 | v_fn = np.sum((q_fn - np.log(policy))*policy, axis=1) 50 | new_q = reward_matrix + gamma*t_matrix.dot(v_fn) 51 | q_fn = new_q 52 | return q_fn 53 | 54 | -------------------------------------------------------------------------------- /tabular_maxent_irl/simple_env.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | from matplotlib import cm 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from matplotlib.patches import Rectangle 8 | from rllab.envs.base import Env 9 | from rllab.misc import logger 10 | from rllab.spaces import Box 11 | from rllab.spaces import Discrete 12 | 13 | 14 | from utils import flat_to_one_hot, np_seed 15 | 16 | class DiscreteEnv(Env): 17 | def __init__(self, transition_matrix, reward, init_state, terminate_on_reward=False): 18 | super(DiscreteEnv, self).__init__() 19 | dX, dA, dXX = transition_matrix.shape 20 | self.nstates = dX 21 | self.nactions = dA 22 | self.transitions = transition_matrix 23 | self.init_state = init_state 24 | self.reward = reward 25 | self.terminate_on_reward = terminate_on_reward 26 | 27 | self.__observation_space = Box(0, 1, shape=(self.nstates,)) 28 | #max_A = 0 29 | #for trans in self.transitions: 30 | # max_A = max(max_A, len(self.transitions[trans])) 31 | self.__action_space = Discrete(dA) 32 | 33 | def reset(self): 34 | self.cur_state = self.init_state 35 | obs = flat_to_one_hot(self.cur_state, ndim=self.nstates) 36 | return obs 37 | 38 | def step(self, a): 39 | transition_probs = self.transitions[self.cur_state, a] 40 | next_state = np.random.choice(np.arange(self.nstates), p=transition_probs) 41 | r = self.reward[self.cur_state, a, next_state] 42 | self.cur_state = next_state 43 | obs = flat_to_one_hot(self.cur_state, ndim=self.nstates) 44 | 45 | done = False 46 | if self.terminate_on_reward and r>0: 47 | done = True 48 | return obs, r, done, {} 49 | 50 | def tabular_trans_distr(self, s, a): 51 | return self.transitions[s, a] 52 | 53 | def reward_fn(self, s, a): 54 | return self.reward[s, a] 55 | 56 | def log_diagnostics(self, paths): 57 | #Ntraj = len(paths) 58 | #acts = np.array([traj['actions'] for traj in paths]) 59 | obs = np.array([np.sum(traj['observations'], axis=0) for traj in paths]) 60 | 61 | state_count = np.sum(obs, axis=0) 62 | #state_count = np.mean(state_count, axis=0) 63 | state_freq = state_count/float(np.sum(state_count)) 64 | for state in range(self.nstates): 65 | logger.record_tabular('AvgStateFreq%d'%state, state_freq[state]) 66 | 67 | @property 68 | def transition_matrix(self): 69 | return self.transitions 70 | 71 | @property 72 | def rew_matrix(self): 73 | 
return self.reward 74 | 75 | @property 76 | def initial_state_distribution(self): 77 | return flat_to_one_hot(self.init_state, ndim=self.nstates) 78 | 79 | @property 80 | def action_space(self): 81 | return self.__action_space 82 | 83 | @property 84 | def observation_space(self): 85 | return self.__observation_space 86 | 87 | 88 | def random_env(Nstates, Nact, seed=None, terminate=False, t_sparsity=0.75): 89 | assert Nstates >= 2 90 | if seed is None: 91 | seed = 0 92 | reward_state=0 93 | start_state=1 94 | with np_seed(seed): 95 | transition_matrix = np.random.rand(Nstates, Nact, Nstates) 96 | transition_matrix = np.exp(transition_matrix) 97 | for s in range(Nstates): 98 | for a in range(Nact): 99 | zero_idxs = np.random.randint(0, Nstates, size=int(Nstates*t_sparsity)) 100 | transition_matrix[s, a, zero_idxs] = 0.0 101 | 102 | transition_matrix = transition_matrix/np.sum(transition_matrix, axis=2, keepdims=True) 103 | reward = np.zeros((Nstates, Nact)) 104 | reward[reward_state, :] = 1.0 105 | #reward = np.random.randn(Nstates,1 ) + reward 106 | 107 | stable_action = seed % Nact #np.random.randint(0, Nact) 108 | transition_matrix[reward_state, stable_action] = np.zeros(Nstates) 109 | transition_matrix[reward_state, stable_action, reward_state] = 1 110 | return DiscreteEnv(transition_matrix, reward=reward, init_state=start_state, terminate_on_reward=terminate) 111 | 112 | 113 | if __name__ == '__main__': 114 | env = random_env(5, 2, seed=0) 115 | print(env.transitions) 116 | print(env.transitions[0,0]) 117 | print(env.transitions[0,1]) 118 | env.reset() 119 | for _ in range(100): 120 | print(env.step(env.action_space.sample())) 121 | 122 | -------------------------------------------------------------------------------- /tabular_maxent_irl/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import scipy as sp 4 | import scipy.stats 5 | import contextlib 6 | 7 | 8 | def flat_to_one_hot(val, ndim): 9 | """ 10 | 11 | >>> flat_to_one_hot(2, ndim=4) 12 | array([ 0., 0., 1., 0.]) 13 | >>> flat_to_one_hot(4, ndim=5) 14 | array([ 0., 0., 0., 0., 1.]) 15 | >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5) 16 | array([[ 0., 0., 1., 0., 0.], 17 | [ 0., 0., 0., 0., 1.], 18 | [ 0., 0., 0., 1., 0.]]) 19 | """ 20 | shape =np.array(val).shape 21 | v = np.zeros(shape + (ndim,)) 22 | if len(shape) == 1: 23 | v[np.arange(shape[0]), val] = 1.0 24 | else: 25 | v[val] = 1.0 26 | return v 27 | 28 | def one_hot_to_flat(val): 29 | """ 30 | >>> one_hot_to_flat(np.array([0,0,0,0,1])) 31 | 4 32 | >>> one_hot_to_flat(np.array([0,0,1,0])) 33 | 2 34 | >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]])) 35 | array([2, 0, 1]) 36 | """ 37 | idxs = np.array(np.where(val == 1.0))[-1] 38 | if len(val.shape) == 1: 39 | return int(idxs) 40 | return idxs 41 | 42 | 43 | def flatten_list(lol): 44 | return [ a for b in lol for a in b ] 45 | 46 | class TrainingIterator(object): 47 | def __init__(self, itrs, heartbeat=float('inf')): 48 | self.itrs = itrs 49 | self.heartbeat_time = heartbeat 50 | self.__vals = {} 51 | 52 | def random_idx(self, N, size): 53 | return np.random.randint(0, N, size=size) 54 | 55 | @property 56 | def itr(self): 57 | return self.__itr 58 | 59 | @property 60 | def heartbeat(self): 61 | return self.__heartbeat 62 | 63 | @property 64 | def elapsed(self): 65 | assert self.heartbeat, 'elapsed is only valid when heartbeat=True' 66 | return self.__elapsed 67 | 68 | def itr_message(self): 69 | return '==> Itr %d/%d 
(elapsed:%.2f)' % (self.itr+1, self.itrs, self.elapsed) 70 | 71 | def record(self, key, value): 72 | if key in self.__vals: 73 | self.__vals[key].append(value) 74 | else: 75 | self.__vals[key] = [value] 76 | 77 | def pop(self, key): 78 | vals = self.__vals.get(key, []) 79 | del self.__vals[key] 80 | return vals 81 | 82 | def pop_mean(self, key): 83 | return np.mean(self.pop(key)) 84 | 85 | def __iter__(self): 86 | prev_time = time.time() 87 | self.__heartbeat = False 88 | for i in range(self.itrs): 89 | self.__itr = i 90 | cur_time = time.time() 91 | if (cur_time-prev_time) > self.heartbeat_time or i==(self.itrs-1): 92 | self.__heartbeat = True 93 | self.__elapsed = cur_time-prev_time 94 | prev_time = cur_time 95 | yield self 96 | self.__heartbeat = False 97 | 98 | 99 | def gd_optimizer(lr, lr_sched=None): 100 | if lr_sched is None: 101 | lr_sched = {} 102 | 103 | itr = 0 104 | def update(x, grad): 105 | nonlocal itr, lr 106 | if itr in lr_sched: 107 | lr *= lr_sched[itr] 108 | new_x = x - lr * grad 109 | itr += 1 110 | return new_x 111 | return update 112 | 113 | 114 | def gd_momentum_optimizer(lr, momentum=0.9, lr_sched=None): 115 | if lr_sched is None: 116 | lr_sched = {} 117 | 118 | itr = 0 119 | prev_grad = None 120 | def update(x, grad): 121 | nonlocal itr, lr, prev_grad 122 | if itr in lr_sched: 123 | lr *= lr_sched[itr] 124 | 125 | if prev_grad is None: 126 | grad = grad 127 | else: 128 | grad = grad + momentum * prev_grad 129 | new_x = x - lr * grad 130 | prev_grad = grad 131 | itr += 1 132 | return new_x 133 | return update 134 | 135 | 136 | def adam_optimizer(lr, beta1=0.9, beta2=0.999, eps=1e-8): 137 | itr = 0 138 | pm = None 139 | pv = None 140 | def update(x, grad): 141 | nonlocal itr, lr, pm, pv 142 | if pm is None: 143 | pm = np.zeros_like(grad) 144 | pv = np.zeros_like(grad) 145 | 146 | pm = beta1 * pm + (1-beta1)*grad 147 | pv = beta2 * pv + (1-beta2)*(grad*grad) 148 | mhat = pm/(1-beta1**(itr+1)) 149 | vhat = pv/(1-beta2**(itr+1)) 150 | update_vec = mhat / (np.sqrt(vhat)+eps) 151 | new_x = x - lr * update_vec 152 | itr += 1 153 | return new_x 154 | return update 155 | 156 | 157 | @contextlib.contextmanager 158 | def np_seed(seed): 159 | """ A context for np random seeds """ 160 | if seed is None: 161 | yield 162 | else: 163 | old_state = np.random.get_state() 164 | np.random.seed(seed) 165 | yield 166 | np.random.set_state(old_state) 167 | 168 | --------------------------------------------------------------------------------
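A quick usage sketch for the closure-style optimizers in `utils.py` above: calling e.g.
`adam_optimizer(lr)` returns a stateful `update(x, grad)` function, which is exactly how
`tabular_maxent_irl` applies its reward updates. The toy quadratic below is only an illustration:

    import numpy as np
    from utils import adam_optimizer

    update = adam_optimizer(lr=1e-2)
    x = np.zeros(3)
    for _ in range(500):
        grad = 2.0 * (x - 1.0)   # gradient of ||x - 1||^2
        x = update(x, grad)      # x drifts toward [1, 1, 1]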