├── .gitignore
├── LICENSE
├── README.md
├── inverse_rl
│   ├── __init__.py
│   ├── algos
│   │   ├── batch_polopt.py
│   │   ├── irl_batch_polopt.py
│   │   ├── irl_npo.py
│   │   ├── irl_trpo.py
│   │   ├── npo.py
│   │   ├── penalty_lbfgs_optimizer.py
│   │   └── trpo.py
│   ├── envs
│   │   ├── __init__.py
│   │   ├── ant_env.py
│   │   ├── assets
│   │   │   └── twod_maze.xml
│   │   ├── dynamic_mjc
│   │   │   ├── __init__.py
│   │   │   ├── mjc_models.py
│   │   │   └── model_builder.py
│   │   ├── env_utils.py
│   │   ├── point_maze_env.py
│   │   ├── pusher_env.py
│   │   ├── twod_maze.py
│   │   ├── twod_mjc_env.py
│   │   ├── utils.py
│   │   └── visual_pointmass.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── airl_state.py
│   │   ├── architectures.py
│   │   ├── fusion_manager.py
│   │   ├── imitation_learning.py
│   │   └── tf_util.py
│   └── utils
│       ├── __init__.py
│       ├── general.py
│       ├── hyper_sweep.py
│       ├── hyperparametrized.py
│       ├── log_utils.py
│       └── math_utils.py
├── scripts
│   ├── ant_data_collect.py
│   ├── ant_irl.py
│   ├── ant_transfer_disabled.py
│   ├── pendulum_data_collect.py
│   ├── pendulum_gail.py
│   └── pendulum_irl.py
└── tabular_maxent_irl
    ├── README.md
    ├── maxent_irl.py
    ├── q_iteration.py
    ├── simple_env.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # dotenv
85 | .env
86 |
87 | # virtualenv
88 | .venv
89 | venv/
90 | ENV/
91 |
92 | # Spyder project settings
93 | .spyderproject
94 | .spyproject
95 |
96 | # Rope project settings
97 | .ropeproject
98 |
99 | # mkdocs documentation
100 | /site
101 |
102 | # mypy
103 | .mypy_cache/
104 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Justin Fu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Inverse RL
2 |
3 | Implementations of imitation learning / IRL algorithms built on rllab
4 |
5 | Contains:
6 | - GAIL (https://arxiv.org/abs/1606.03476)
7 | - Guided Cost Learning, GAN formulation (https://arxiv.org/pdf/1611.03852.pdf)
8 | - Tabular MaxCausalEnt IRL (http://www.cs.cmu.edu/~bziebart/publications/thesis-bziebart.pdf)
9 |
10 | Setup
11 | ---
12 | This library requires:
13 | - rllab (https://github.com/openai/rllab)
14 | - Tensorflow
15 |
16 | Examples
17 | ---
18 |
19 | Running the Pendulum-v0 gym environment:
20 |
21 | 1) Collect expert data
22 | ```
23 | python scripts/pendulum_data_collect.py
24 | ```
25 |
26 | You should get an "AverageReturn" of around -100 to -150
27 |
28 | 2) Run imitation learning
29 | ```
30 | python scripts/pendulum_irl.py
31 | ```
32 |
33 | The "OriginalTaskAverageReturn" should reach around -100 to -150
34 |
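35 | Programmatic usage
36 | ---
37 | 
38 | The scripts above wire an `IRLTRPO` instance (`inverse_rl/algos/irl_trpo.py`) to an IRL
39 | model and an rllab policy. The snippet below is a minimal sketch of that wiring; the
40 | `GAIL` model class and the `load_latest_experts` helper are assumed to live in
41 | `inverse_rl/models/imitation_learning.py` and `inverse_rl/utils/log_utils.py`, so check
42 | those files for the exact constructor signatures. `scripts/pendulum_gail.py` runs the
43 | GAIL variant end-to-end.
44 | 
45 | ```python
46 | import tensorflow as tf
47 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
48 | from rllab.envs.gym_env import GymEnv
49 | from sandbox.rocky.tf.envs.base import TfEnv
50 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
51 | 
52 | from inverse_rl.algos.irl_trpo import IRLTRPO
53 | from inverse_rl.models.imitation_learning import GAIL        # assumed class name
54 | from inverse_rl.utils.log_utils import load_latest_experts   # assumed helper
55 | 
56 | env = TfEnv(GymEnv('Pendulum-v0', record_video=False, record_log=False))
57 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32))
58 | 
59 | # Expert paths previously saved by scripts/pendulum_data_collect.py
60 | experts = load_latest_experts('data/pendulum', n=5)
61 | irl_model = GAIL(env_spec=env.spec, expert_trajs=experts)
62 | 
63 | algo = IRLTRPO(
64 |     env=env,
65 |     policy=policy,
66 |     baseline=LinearFeatureBaseline(env_spec=env.spec),
67 |     irl_model=irl_model,
68 |     irl_model_wt=1.0,
69 |     zero_environment_reward=True,  # train only on the learned reward
70 |     n_itr=200,
71 |     batch_size=1000,
72 |     max_path_length=100,
73 |     discount=0.99,
74 |     step_size=0.01,
75 |     entropy_weight=0.1,
76 | )
77 | 
78 | with tf.Session():
79 |     algo.train()  # IRLBatchPolopt.train() picks up the default TF session
80 | ```
81 | 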
--------------------------------------------------------------------------------
/inverse_rl/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinjfu/inverse_rl/9609933389459a3a54f5c01d652114ada90fa1b3/inverse_rl/__init__.py
--------------------------------------------------------------------------------
/inverse_rl/algos/batch_polopt.py:
--------------------------------------------------------------------------------
1 | import time
2 | from rllab.algos.base import RLAlgorithm
3 | import rllab.misc.logger as logger
4 | from sandbox.rocky.tf.policies.base import Policy
5 | import tensorflow as tf
6 | from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler
7 | from sandbox.rocky.tf.samplers.vectorized_sampler import VectorizedSampler
8 | from rllab.sampler.utils import rollout
9 |
10 | from inverse_rl.utils.hyperparametrized import Hyperparametrized
11 |
12 |
13 | class BatchPolopt(RLAlgorithm, metaclass=Hyperparametrized):
14 | """
15 | Base class for batch sampling-based policy optimization methods.
16 | This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
17 | """
18 |
19 | def __init__(
20 | self,
21 | env,
22 | policy,
23 | baseline,
24 | scope=None,
25 | n_itr=500,
26 | start_itr=0,
27 | batch_size=5000,
28 | max_path_length=500,
29 | discount=0.99,
30 | gae_lambda=1,
31 | plot=False,
32 | pause_for_plot=False,
33 | center_adv=True,
34 | positive_adv=False,
35 | store_paths=False,
36 | whole_paths=True,
37 | fixed_horizon=False,
38 | sampler_cls=None,
39 | sampler_args=None,
40 | force_batch_sampler=False,
41 | **kwargs
42 | ):
43 | """
44 | :param env: Environment
45 | :param policy: Policy
46 | :type policy: Policy
47 | :param baseline: Baseline
48 | :param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms
49 | simultaneously, each using different environments and policies
50 | :param n_itr: Number of iterations.
51 | :param start_itr: Starting iteration.
52 | :param batch_size: Number of samples per iteration.
53 | :param max_path_length: Maximum length of a single rollout.
54 | :param discount: Discount.
55 | :param gae_lambda: Lambda used for generalized advantage estimation.
56 | :param plot: Plot evaluation run after each iteration.
57 |         :param pause_for_plot: Whether to pause before continuing when plotting.
58 | :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
59 | :param positive_adv: Whether to shift the advantages so that they are always positive. When used in
60 | conjunction with center_adv the advantages will be standardized before shifting.
61 | :param store_paths: Whether to save all paths data to the snapshot.
62 | :return:
63 | """
64 | self.env = env
65 | self.policy = policy
66 | self.baseline = baseline
67 | self.scope = scope
68 | self.n_itr = n_itr
69 | self.start_itr = start_itr
70 | self.batch_size = batch_size
71 | self.max_path_length = max_path_length
72 | self.discount = discount
73 | self.gae_lambda = gae_lambda
74 | self.plot = plot
75 | self.pause_for_plot = pause_for_plot
76 | self.center_adv = center_adv
77 | self.positive_adv = positive_adv
78 | self.store_paths = store_paths
79 | self.whole_paths = whole_paths
80 | self.fixed_horizon = fixed_horizon
81 | if sampler_cls is None:
82 | if self.policy.vectorized and not force_batch_sampler:
83 | sampler_cls = VectorizedSampler
84 | else:
85 | sampler_cls = BatchSampler
86 | if sampler_args is None:
87 | sampler_args = dict()
88 | self.sampler = sampler_cls(self, **sampler_args)
89 | self.init_opt()
90 |
91 | def start_worker(self):
92 | self.sampler.start_worker()
93 |
94 | def shutdown_worker(self):
95 | self.sampler.shutdown_worker()
96 |
97 | def obtain_samples(self, itr):
98 | return self.sampler.obtain_samples(itr)
99 |
100 | def process_samples(self, itr, paths):
101 | return self.sampler.process_samples(itr, paths)
102 |
103 | def train(self, sess=None):
104 | created_session = True if (sess is None) else False
105 | if sess is None:
106 | sess = tf.Session()
107 | sess.__enter__()
108 |
109 | sess.run(tf.global_variables_initializer())
110 | self.start_worker()
111 | start_time = time.time()
112 | for itr in range(self.start_itr, self.n_itr):
113 | itr_start_time = time.time()
114 | with logger.prefix('itr #%d | ' % itr):
115 | logger.log("Obtaining samples...")
116 | paths = self.obtain_samples(itr)
117 | logger.log("Processing samples...")
118 | samples_data = self.process_samples(itr, paths)
119 | logger.log("Logging diagnostics...")
120 | self.log_diagnostics(paths)
121 | logger.log("Optimizing policy...")
122 | self.optimize_policy(itr, samples_data)
123 | logger.log("Saving snapshot...")
124 | params = self.get_itr_snapshot(itr, samples_data) # , **kwargs)
125 | if self.store_paths:
126 | params["paths"] = samples_data["paths"]
127 | logger.save_itr_params(itr, params)
128 | logger.log("Saved")
129 | logger.record_tabular('Time', time.time() - start_time)
130 | logger.record_tabular('ItrTime', time.time() - itr_start_time)
131 | logger.dump_tabular(with_prefix=False)
132 | if self.plot:
133 | rollout(self.env, self.policy, animated=True, max_path_length=self.max_path_length)
134 | if self.pause_for_plot:
135 | input("Plotting evaluation run: Press Enter to "
136 | "continue...")
137 | self.shutdown_worker()
138 | if created_session:
139 | sess.close()
140 |
141 | def log_diagnostics(self, paths):
142 | self.env.log_diagnostics(paths)
143 | self.policy.log_diagnostics(paths)
144 | self.baseline.log_diagnostics(paths)
145 |
146 | def init_opt(self):
147 | """
148 | Initialize the optimization procedure. If using tensorflow, this may
149 | include declaring all the variables and compiling functions
150 | """
151 | raise NotImplementedError
152 |
153 | def get_itr_snapshot(self, itr, samples_data):
154 | """
155 | Returns all the data that should be saved in the snapshot for this
156 | iteration.
157 | """
158 | raise NotImplementedError
159 |
160 | def optimize_policy(self, itr, samples_data):
161 | raise NotImplementedError
162 |
--------------------------------------------------------------------------------
/inverse_rl/algos/irl_batch_polopt.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | from rllab.algos.base import RLAlgorithm
4 | import rllab.misc.logger as logger
5 | import rllab.plotter as plotter
6 | from sandbox.rocky.tf.policies.base import Policy
7 | import tensorflow as tf
8 | from sandbox.rocky.tf.samplers.batch_sampler import BatchSampler
9 | from sandbox.rocky.tf.samplers.vectorized_sampler import VectorizedSampler
10 | import numpy as np
11 | from collections import deque
12 |
13 | from inverse_rl.utils.hyperparametrized import Hyperparametrized
14 |
15 |
16 | class IRLBatchPolopt(RLAlgorithm, metaclass=Hyperparametrized):
17 | """
18 | Base class for batch sampling-based policy optimization methods.
19 | This includes various policy gradient methods like vpg, npg, ppo, trpo, etc.
20 | """
21 |
22 | def __init__(
23 | self,
24 | env,
25 | policy,
26 | baseline,
27 | scope=None,
28 | n_itr=500,
29 | start_itr=0,
30 | batch_size=5000,
31 | max_path_length=500,
32 | discount=0.99,
33 | gae_lambda=1,
34 | plot=False,
35 | pause_for_plot=False,
36 | center_adv=True,
37 | positive_adv=False,
38 | store_paths=True,
39 | whole_paths=True,
40 | fixed_horizon=False,
41 | sampler_cls=None,
42 | sampler_args=None,
43 | force_batch_sampler=False,
44 |             init_pol_params=None,
45 | irl_model=None,
46 | irl_model_wt=1.0,
47 | discrim_train_itrs=10,
48 | zero_environment_reward=False,
49 | init_irl_params=None,
50 | train_irl=True,
51 | key='',
52 | **kwargs
53 | ):
54 | """
55 | :param env: Environment
56 | :param policy: Policy
57 | :type policy: Policy
58 | :param baseline: Baseline
59 | :param scope: Scope for identifying the algorithm. Must be specified if running multiple algorithms
60 | simultaneously, each using different environments and policies
61 | :param n_itr: Number of iterations.
62 | :param start_itr: Starting iteration.
63 | :param batch_size: Number of samples per iteration.
64 | :param max_path_length: Maximum length of a single rollout.
65 | :param discount: Discount.
66 | :param gae_lambda: Lambda used for generalized advantage estimation.
67 | :param plot: Plot evaluation run after each iteration.
68 |         :param pause_for_plot: Whether to pause before continuing when plotting.
69 | :param center_adv: Whether to rescale the advantages so that they have mean 0 and standard deviation 1.
70 | :param positive_adv: Whether to shift the advantages so that they are always positive. When used in
71 | conjunction with center_adv the advantages will be standardized before shifting.
72 | :param store_paths: Whether to save all paths data to the snapshot.
73 | :return:
74 | """
75 | self.env = env
76 | self.policy = policy
77 | self.baseline = baseline
78 | self.scope = scope
79 | self.n_itr = n_itr
80 | self.start_itr = start_itr
81 | self.batch_size = batch_size
82 | self.max_path_length = max_path_length
83 | self.discount = discount
84 | self.gae_lambda = gae_lambda
85 | self.plot = plot
86 | self.pause_for_plot = pause_for_plot
87 | self.center_adv = center_adv
88 | self.positive_adv = positive_adv
89 | self.store_paths = store_paths
90 | self.whole_paths = whole_paths
91 | self.fixed_horizon = fixed_horizon
92 | self.init_pol_params = init_pol_params
93 | self.init_irl_params = init_irl_params
94 | self.irl_model = irl_model
95 | self.irl_model_wt = irl_model_wt
96 | self.no_reward = zero_environment_reward
97 | self.discrim_train_itrs = discrim_train_itrs
98 | self.train_irl = train_irl
99 | self.__irl_params = None
100 |
101 | if self.irl_model_wt > 0:
102 |             assert self.irl_model is not None, "Need to specify an IRL model"
103 |
104 | if sampler_cls is None:
105 | if self.policy.vectorized and not force_batch_sampler:
106 | print('using vec sampler')
107 | sampler_cls = VectorizedSampler
108 | else:
109 | print('using batch sampler')
110 | sampler_cls = BatchSampler
111 | if sampler_args is None:
112 | sampler_args = dict()
113 | self.sampler = sampler_cls(self, **sampler_args)
114 | self.init_opt()
115 |
116 | def start_worker(self):
117 | self.sampler.start_worker()
118 | if self.plot:
119 | plotter.init_plot(self.env, self.policy)
120 |
121 | def shutdown_worker(self):
122 | self.sampler.shutdown_worker()
123 |
124 | def obtain_samples(self, itr):
125 | return self.sampler.obtain_samples(itr)
126 |
127 | def process_samples(self, itr, paths):
128 | #processed = self.sampler.process_samples(itr, paths)
129 | return self.sampler.process_samples(itr, paths)
130 |
131 | def log_avg_returns(self, paths):
132 | undiscounted_returns = [sum(path["rewards"]) for path in paths]
133 | avg_return = np.mean(undiscounted_returns)
134 | return avg_return
135 |
136 | def get_irl_params(self):
137 | return self.__irl_params
138 |
139 | def compute_irl(self, paths, itr=0):
140 | if self.no_reward:
141 | tot_rew = 0
142 | for path in paths:
143 | tot_rew += np.sum(path['rewards'])
144 | path['rewards'] *= 0
145 | logger.record_tabular('OriginalTaskAverageReturn', tot_rew/float(len(paths)))
146 |
147 |         if self.irl_model_wt <= 0:
148 | return paths
149 |
150 | if self.train_irl:
151 | max_itrs = self.discrim_train_itrs
152 |             lr = 1e-3
153 | mean_loss = self.irl_model.fit(paths, policy=self.policy, itr=itr, max_itrs=max_itrs, lr=lr,
154 | logger=logger)
155 |
156 | logger.record_tabular('IRLLoss', mean_loss)
157 | self.__irl_params = self.irl_model.get_params()
158 |
159 | probs = self.irl_model.eval(paths, gamma=self.discount, itr=itr)
160 |
161 | logger.record_tabular('IRLRewardMean', np.mean(probs))
162 | logger.record_tabular('IRLRewardMax', np.max(probs))
163 | logger.record_tabular('IRLRewardMin', np.min(probs))
164 |
165 |
166 | if self.irl_model.score_trajectories:
167 | # TODO: should I add to reward here or after advantage computation?
168 | for i, path in enumerate(paths):
169 | path['rewards'][-1] += self.irl_model_wt * probs[i]
170 | else:
171 | for i, path in enumerate(paths):
172 | path['rewards'] += self.irl_model_wt * probs[i]
173 | return paths
174 |
175 | def train(self):
176 | sess = tf.get_default_session()
177 | sess.run(tf.global_variables_initializer())
178 | if self.init_pol_params is not None:
179 | self.policy.set_param_values(self.init_pol_params)
180 | if self.init_irl_params is not None:
181 | self.irl_model.set_params(self.init_irl_params)
182 | self.start_worker()
183 | start_time = time.time()
184 |
185 | returns = []
186 | for itr in range(self.start_itr, self.n_itr):
187 | itr_start_time = time.time()
188 | with logger.prefix('itr #%d | ' % itr):
189 | logger.log("Obtaining samples...")
190 | paths = self.obtain_samples(itr)
191 |
192 | logger.log("Processing samples...")
193 | paths = self.compute_irl(paths, itr=itr)
194 | returns.append(self.log_avg_returns(paths))
195 | samples_data = self.process_samples(itr, paths)
196 |
197 | logger.log("Logging diagnostics...")
198 | self.log_diagnostics(paths)
199 | logger.log("Optimizing policy...")
200 | self.optimize_policy(itr, samples_data)
201 | logger.log("Saving snapshot...")
202 | params = self.get_itr_snapshot(itr, samples_data) # , **kwargs)
203 | if self.store_paths:
204 | params["paths"] = samples_data["paths"]
205 | logger.save_itr_params(itr, params)
206 | logger.log("Saved")
207 | logger.record_tabular('Time', time.time() - start_time)
208 | logger.record_tabular('ItrTime', time.time() - itr_start_time)
209 | logger.dump_tabular(with_prefix=False)
210 | if self.plot:
211 | self.update_plot()
212 | if self.pause_for_plot:
213 | input("Plotting evaluation run: Press Enter to "
214 | "continue...")
215 | self.shutdown_worker()
216 | return
217 |
218 | def log_diagnostics(self, paths):
219 | self.env.log_diagnostics(paths)
220 | self.policy.log_diagnostics(paths)
221 | self.baseline.log_diagnostics(paths)
222 |
223 | def init_opt(self):
224 | """
225 | Initialize the optimization procedure. If using tensorflow, this may
226 | include declaring all the variables and compiling functions
227 | """
228 | raise NotImplementedError
229 |
230 | def get_itr_snapshot(self, itr, samples_data):
231 | """
232 | Returns all the data that should be saved in the snapshot for this
233 | iteration.
234 | """
235 | raise NotImplementedError
236 |
237 | def optimize_policy(self, itr, samples_data):
238 | raise NotImplementedError
239 |
240 | def update_plot(self):
241 | if self.plot:
242 | plotter.update_plot(self.policy, self.max_path_length)
243 |
--------------------------------------------------------------------------------
/inverse_rl/algos/irl_npo.py:
--------------------------------------------------------------------------------
1 | from rllab.misc import ext
2 | from rllab.misc.overrides import overrides
3 | import rllab.misc.logger as logger
4 | from inverse_rl.algos.penalty_lbfgs_optimizer import PenaltyLbfgsOptimizer
5 | from inverse_rl.algos.irl_batch_polopt import IRLBatchPolopt
6 | from sandbox.rocky.tf.misc import tensor_utils
7 | import tensorflow as tf
8 | import numpy as np
9 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer
10 |
11 |
12 | class IRLNPO(IRLBatchPolopt):
13 | """
14 | Natural Policy Optimization.
15 | """
16 |
17 | def __init__(
18 | self,
19 | optimizer=None,
20 | optimizer_args=None,
21 | step_size=0.01,
22 | entropy_weight=1.0,
23 | **kwargs):
24 | if optimizer is None:
25 | if optimizer_args is None:
26 | optimizer_args = dict(name='lbfgs')
27 | optimizer = PenaltyLbfgsOptimizer(**optimizer_args)
28 | self.optimizer = optimizer
29 | self.step_size = step_size
30 | self.pol_ent_wt = entropy_weight
31 | super(IRLNPO, self).__init__(**kwargs)
32 |
33 | @overrides
34 | def init_opt(self):
35 | is_recurrent = int(self.policy.recurrent)
36 | obs_var = self.env.observation_space.new_tensor_variable(
37 | 'obs',
38 | extra_dims=1 + is_recurrent,
39 | )
40 | action_var = self.env.action_space.new_tensor_variable(
41 | 'action',
42 | extra_dims=1 + is_recurrent,
43 | )
44 | advantage_var = tensor_utils.new_tensor(
45 | 'advantage',
46 | ndim=1 + is_recurrent,
47 | dtype=tf.float32,
48 | )
49 |
50 | input_list = [
51 | obs_var,
52 | action_var,
53 | advantage_var,
54 | ]
55 |
56 | dist = self.policy.distribution
57 |
58 | old_dist_info_vars = {
59 | k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k)
60 | for k, shape in dist.dist_info_specs
61 | }
62 | old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]
63 |
64 | state_info_vars = {
65 | k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k)
66 | for k, shape in self.policy.state_info_specs
67 | }
68 | state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]
69 |
70 | if is_recurrent:
71 | valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid")
72 | else:
73 | valid_var = None
74 |
75 | dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
76 | kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
77 | lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars)
78 |
79 | if self.pol_ent_wt > 0:
80 | if 'log_std' in dist_info_vars:
81 | log_std = dist_info_vars['log_std']
82 | ent = tf.reduce_sum(log_std + tf.log(tf.sqrt(2 * np.pi * np.e)), reduction_indices=-1)
83 | elif 'prob' in dist_info_vars:
84 | prob = dist_info_vars['prob']
85 | ent = -tf.reduce_sum(prob*tf.log(prob), reduction_indices=-1)
86 | else:
87 | raise NotImplementedError()
88 | ent = tf.stop_gradient(ent)
89 | adv = advantage_var + self.pol_ent_wt*ent
90 | else:
91 | adv = advantage_var
92 |
93 |
94 | if is_recurrent:
95 | mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
96 | surr_loss = - tf.reduce_sum(lr * adv * valid_var) / tf.reduce_sum(valid_var)
97 | else:
98 | mean_kl = tf.reduce_mean(kl)
99 | surr_loss = - tf.reduce_mean(lr * adv)
100 |
101 | input_list += state_info_vars_list + old_dist_info_vars_list
102 | if is_recurrent:
103 | input_list.append(valid_var)
104 |
105 | self.optimizer.update_opt(
106 | loss=surr_loss,
107 | target=self.policy,
108 | leq_constraint=(mean_kl, self.step_size),
109 | inputs=input_list,
110 | constraint_name="mean_kl"
111 | )
112 | return dict()
113 |
114 | @overrides
115 | def optimize_policy(self, itr, samples_data):
116 | all_input_values = tuple(ext.extract(
117 | samples_data,
118 | "observations", "actions", "advantages",
119 | ))
120 |
121 | agent_infos = samples_data["agent_infos"]
122 | state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
123 | dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
124 | all_input_values += tuple(state_info_list) + tuple(dist_info_list)
125 | if self.policy.recurrent:
126 | all_input_values += (samples_data["valids"],)
127 | logger.log("Computing loss before")
128 | loss_before = self.optimizer.loss(all_input_values)
129 | logger.log("Computing KL before")
130 | mean_kl_before = self.optimizer.constraint_val(all_input_values)
131 | logger.log("Optimizing")
132 | self.optimizer.optimize(all_input_values)
133 | logger.log("Computing KL after")
134 | mean_kl = self.optimizer.constraint_val(all_input_values)
135 | logger.log("Computing loss after")
136 | loss_after = self.optimizer.loss(all_input_values)
137 | logger.record_tabular('LossBefore', loss_before)
138 | logger.record_tabular('LossAfter', loss_after)
139 | logger.record_tabular('MeanKLBefore', mean_kl_before)
140 | logger.record_tabular('MeanKL', mean_kl)
141 | logger.record_tabular('dLoss', loss_before - loss_after)
142 | return dict()
143 |
144 | @overrides
145 | def get_itr_snapshot(self, itr, samples_data):
146 | return dict(
147 | itr=itr,
148 | policy=self.policy,
149 | policy_params=self.policy.get_param_values(),
150 | irl_params=self.get_irl_params(),
151 | baseline=self.baseline,
152 | env=self.env,
153 | )
154 |
--------------------------------------------------------------------------------
/inverse_rl/algos/irl_trpo.py:
--------------------------------------------------------------------------------
1 | from inverse_rl.algos.irl_npo import IRLNPO
2 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer
3 |
4 |
5 | class IRLTRPO(IRLNPO):
6 | """
7 | Trust Region Policy Optimization
8 | """
9 |
10 | def __init__(
11 | self,
12 | optimizer=None,
13 | optimizer_args=None,
14 | **kwargs):
15 | if optimizer is None:
16 | if optimizer_args is None:
17 | optimizer_args = dict()
18 | optimizer = ConjugateGradientOptimizer(**optimizer_args)
19 | super(IRLTRPO, self).__init__(optimizer=optimizer, **kwargs)
20 |
--------------------------------------------------------------------------------
/inverse_rl/algos/npo.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rllab.misc import ext
4 | from rllab.misc.overrides import overrides
5 | import rllab.misc.logger as logger
6 | from sandbox.rocky.tf.optimizers.penalty_lbfgs_optimizer import PenaltyLbfgsOptimizer
7 | from sandbox.rocky.tf.misc import tensor_utils
8 | import tensorflow as tf
9 |
10 | from inverse_rl.algos.batch_polopt import BatchPolopt
11 |
12 |
13 | class NPO(BatchPolopt):
14 | """
15 | Natural Policy Optimization.
16 | """
17 |
18 | def __init__(
19 | self,
20 | optimizer=None,
21 | optimizer_args=None,
22 | step_size=0.01,
23 | entropy_weight=0.0,
24 | **kwargs):
25 | if optimizer is None:
26 | if optimizer_args is None:
27 | optimizer_args = dict()
28 | optimizer = PenaltyLbfgsOptimizer(**optimizer_args)
29 | self.optimizer = optimizer
30 | self.step_size = step_size
31 | self.pol_ent_wt = entropy_weight
32 | super(NPO, self).__init__(**kwargs)
33 |
34 | @overrides
35 | def init_opt(self):
36 | is_recurrent = int(self.policy.recurrent)
37 | obs_var = self.env.observation_space.new_tensor_variable(
38 | 'obs',
39 | extra_dims=1 + is_recurrent,
40 | )
41 | action_var = self.env.action_space.new_tensor_variable(
42 | 'action',
43 | extra_dims=1 + is_recurrent,
44 | )
45 | advantage_var = tensor_utils.new_tensor(
46 | 'advantage',
47 | ndim=1 + is_recurrent,
48 | dtype=tf.float32,
49 | )
50 | dist = self.policy.distribution
51 |
52 | old_dist_info_vars = {
53 | k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name='old_%s' % k)
54 | for k, shape in dist.dist_info_specs
55 | }
56 | old_dist_info_vars_list = [old_dist_info_vars[k] for k in dist.dist_info_keys]
57 |
58 | state_info_vars = {
59 | k: tf.placeholder(tf.float32, shape=[None] * (1 + is_recurrent) + list(shape), name=k)
60 | for k, shape in self.policy.state_info_specs
61 | }
62 | state_info_vars_list = [state_info_vars[k] for k in self.policy.state_info_keys]
63 |
64 | if is_recurrent:
65 | valid_var = tf.placeholder(tf.float32, shape=[None, None], name="valid")
66 | else:
67 | valid_var = None
68 |
69 | dist_info_vars = self.policy.dist_info_sym(obs_var, state_info_vars)
70 | kl = dist.kl_sym(old_dist_info_vars, dist_info_vars)
71 | lr = dist.likelihood_ratio_sym(action_var, old_dist_info_vars, dist_info_vars)
72 |
73 | if self.pol_ent_wt > 0:
74 | if 'log_std' in dist_info_vars:
75 | log_std = dist_info_vars['log_std']
76 | ent = tf.reduce_sum(log_std + tf.log(tf.sqrt(2 * np.pi * np.e)), reduction_indices=-1)
77 | elif 'prob' in dist_info_vars:
78 | prob = dist_info_vars['prob']
79 | ent = - tf.reduce_sum(prob*tf.log(prob), reduction_indices=-1)
80 | else:
81 | raise NotImplementedError()
82 | ent = tf.stop_gradient(ent)
83 | adv = advantage_var + self.pol_ent_wt*ent
84 | else:
85 | adv = advantage_var
86 |
87 | if is_recurrent:
88 | mean_kl = tf.reduce_sum(kl * valid_var) / tf.reduce_sum(valid_var)
89 | surr_loss = - tf.reduce_sum(lr * adv * valid_var) / tf.reduce_sum(valid_var)
90 | else:
91 | mean_kl = tf.reduce_mean(kl)
92 | surr_loss = - tf.reduce_mean(lr * adv)
93 |
94 | input_list = [
95 | obs_var,
96 | action_var,
97 | advantage_var,
98 | ] + state_info_vars_list + old_dist_info_vars_list
99 | if is_recurrent:
100 | input_list.append(valid_var)
101 |
102 | self.optimizer.update_opt(
103 | loss=surr_loss,
104 | target=self.policy,
105 | leq_constraint=(mean_kl, self.step_size),
106 | inputs=input_list,
107 | constraint_name="mean_kl"
108 | )
109 | return dict()
110 |
111 | @overrides
112 | def optimize_policy(self, itr, samples_data):
113 | all_input_values = tuple(ext.extract(
114 | samples_data,
115 | "observations", "actions", "advantages"
116 | ))
117 | agent_infos = samples_data["agent_infos"]
118 | state_info_list = [agent_infos[k] for k in self.policy.state_info_keys]
119 | dist_info_list = [agent_infos[k] for k in self.policy.distribution.dist_info_keys]
120 | all_input_values += tuple(state_info_list) + tuple(dist_info_list)
121 | if self.policy.recurrent:
122 | all_input_values += (samples_data["valids"],)
123 | logger.log("Computing loss before")
124 | loss_before = self.optimizer.loss(all_input_values)
125 | logger.log("Computing KL before")
126 | mean_kl_before = self.optimizer.constraint_val(all_input_values)
127 | logger.log("Optimizing")
128 | self.optimizer.optimize(all_input_values)
129 | logger.log("Computing KL after")
130 | mean_kl = self.optimizer.constraint_val(all_input_values)
131 | logger.log("Computing loss after")
132 | loss_after = self.optimizer.loss(all_input_values)
133 | logger.record_tabular('LossBefore', loss_before)
134 | logger.record_tabular('LossAfter', loss_after)
135 | logger.record_tabular('MeanKLBefore', mean_kl_before)
136 | logger.record_tabular('MeanKL', mean_kl)
137 | logger.record_tabular('dLoss', loss_before - loss_after)
138 | return dict()
139 |
140 | @overrides
141 | def get_itr_snapshot(self, itr, samples_data):
142 | return dict(
143 | itr=itr,
144 | policy=self.policy,
145 | baseline=self.baseline,
146 | env=self.env,
147 | )
148 |
--------------------------------------------------------------------------------
/inverse_rl/algos/penalty_lbfgs_optimizer.py:
--------------------------------------------------------------------------------
1 | import scipy
2 | import scipy.optimize
3 | import numpy as np
4 | import tensorflow as tf
5 | from rllab.core.serializable import Serializable
6 | from rllab.misc import ext
7 | from rllab.misc import logger
8 | from sandbox.rocky.tf.misc import tensor_utils
9 |
10 |
11 | class PenaltyLbfgsOptimizer(Serializable):
12 | """
13 | Performs constrained optimization via penalized L-BFGS. The penalty term is adaptively adjusted to make sure that
14 | the constraint is satisfied.
15 | """
16 |
17 | def __init__(
18 | self,
19 | name,
20 | max_opt_itr=20,
21 | initial_penalty=1.0,
22 | min_penalty=1e-2,
23 | max_penalty=1e6,
24 | increase_penalty_factor=2,
25 | decrease_penalty_factor=0.5,
26 | max_penalty_itr=10,
27 | adapt_penalty=True):
28 | Serializable.quick_init(self, locals())
29 | self._name = name
30 | self._max_opt_itr = max_opt_itr
31 | self._penalty = initial_penalty
32 | self._initial_penalty = initial_penalty
33 | self._min_penalty = min_penalty
34 | self._max_penalty = max_penalty
35 | self._increase_penalty_factor = increase_penalty_factor
36 | self._decrease_penalty_factor = decrease_penalty_factor
37 | self._max_penalty_itr = max_penalty_itr
38 | self._adapt_penalty = adapt_penalty
39 |
40 | self._opt_fun = None
41 | self._target = None
42 | self._max_constraint_val = None
43 | self._constraint_name = None
44 |
45 | def update_opt(self, loss, target, leq_constraint, inputs, constraint_name="constraint", *args, **kwargs):
46 | """
47 | :param loss: Symbolic expression for the loss function.
48 | :param target: A parameterized object to optimize over. It should implement methods of the
49 | :class:`rllab.core.paramerized.Parameterized` class.
50 | :param leq_constraint: A constraint provided as a tuple (f, epsilon), of the form f(*inputs) <= epsilon.
51 | :param inputs: A list of symbolic variables as inputs
52 | :return: No return value.
53 | """
54 | constraint_term, constraint_value = leq_constraint
55 | with tf.variable_scope(self._name):
56 | penalty_var = tf.placeholder(tf.float32, tuple(), name="penalty")
57 | penalized_loss = loss + penalty_var * constraint_term
58 |
59 | self._target = target
60 | self._max_constraint_val = constraint_value
61 | self._constraint_name = constraint_name
62 |
63 | def get_opt_output():
64 | params = target.get_params(trainable=True)
65 | grads = tf.gradients(penalized_loss, params)
66 | for idx, (grad, param) in enumerate(zip(grads, params)):
67 | if grad is None:
68 | grads[idx] = tf.zeros_like(param)
69 | flat_grad = tensor_utils.flatten_tensor_variables(grads)
70 | return [
71 | tf.cast(penalized_loss, tf.float64),
72 | tf.cast(flat_grad, tf.float64),
73 | ]
74 |
75 | self._opt_fun = ext.lazydict(
76 | f_loss=lambda: tensor_utils.compile_function(inputs, loss, log_name="f_loss"),
77 | f_constraint=lambda: tensor_utils.compile_function(inputs, constraint_term, log_name="f_constraint"),
78 | f_penalized_loss=lambda: tensor_utils.compile_function(
79 | inputs=inputs + [penalty_var],
80 | outputs=[penalized_loss, loss, constraint_term],
81 | log_name="f_penalized_loss",
82 | ),
83 | f_opt=lambda: tensor_utils.compile_function(
84 | inputs=inputs + [penalty_var],
85 | outputs=get_opt_output(),
86 | )
87 | )
88 |
89 | def loss(self, inputs):
90 | return self._opt_fun["f_loss"](*inputs)
91 |
92 | def constraint_val(self, inputs):
93 | return self._opt_fun["f_constraint"](*inputs)
94 |
95 | def optimize(self, inputs):
96 |
97 | inputs = tuple(inputs)
98 |
99 | try_penalty = np.clip(
100 | self._penalty, self._min_penalty, self._max_penalty)
101 |
102 | penalty_scale_factor = None
103 | f_opt = self._opt_fun["f_opt"]
104 | f_penalized_loss = self._opt_fun["f_penalized_loss"]
105 |
106 | def gen_f_opt(penalty):
107 | def f(flat_params):
108 | self._target.set_param_values(flat_params, trainable=True)
109 | return f_opt(*(inputs + (penalty,)))
110 |
111 | return f
112 |
113 | cur_params = self._target.get_param_values(trainable=True).astype('float64')
114 | opt_params = cur_params
115 |
116 | for penalty_itr in range(self._max_penalty_itr):
117 | logger.log('trying penalty=%.3f...' % try_penalty)
118 |
119 | itr_opt_params, _, _ = scipy.optimize.fmin_l_bfgs_b(
120 | func=gen_f_opt(try_penalty), x0=cur_params,
121 | maxiter=self._max_opt_itr
122 | )
123 |
124 | _, try_loss, try_constraint_val = f_penalized_loss(*(inputs + (try_penalty,)))
125 |
126 | logger.log('penalty %f => loss %f, %s %f' %
127 | (try_penalty, try_loss, self._constraint_name, try_constraint_val))
128 |
129 | # Either constraint satisfied, or we are at the last iteration already and no alternative parameter
130 | # satisfies the constraint
131 | if try_constraint_val < self._max_constraint_val or \
132 | (penalty_itr == self._max_penalty_itr - 1 and opt_params is None):
133 | opt_params = itr_opt_params
134 |
135 | if not self._adapt_penalty:
136 | break
137 |
138 | # Decide scale factor on the first iteration, or if constraint violation yields numerical error
139 | if penalty_scale_factor is None or np.isnan(try_constraint_val):
140 | # Increase penalty if constraint violated, or if constraint term is NAN
141 | if try_constraint_val > self._max_constraint_val or np.isnan(try_constraint_val):
142 | penalty_scale_factor = self._increase_penalty_factor
143 | else:
144 | # Otherwise (i.e. constraint satisfied), shrink penalty
145 | penalty_scale_factor = self._decrease_penalty_factor
146 | opt_params = itr_opt_params
147 | else:
148 | if penalty_scale_factor > 1 and \
149 | try_constraint_val <= self._max_constraint_val:
150 | break
151 | elif penalty_scale_factor < 1 and \
152 | try_constraint_val >= self._max_constraint_val:
153 | break
154 | old_penalty = try_penalty
155 | try_penalty *= penalty_scale_factor
156 | try_penalty = np.clip(try_penalty, self._min_penalty, self._max_penalty)
157 | if try_penalty == old_penalty:
158 | break
159 | self._penalty = try_penalty
160 |
161 | self._target.set_param_values(opt_params, trainable=True)
162 |
--------------------------------------------------------------------------------
/inverse_rl/algos/trpo.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from inverse_rl.algos.npo import NPO
4 | from sandbox.rocky.tf.optimizers.conjugate_gradient_optimizer import ConjugateGradientOptimizer
5 |
6 |
7 | class TRPO(NPO):
8 | """
9 | Trust Region Policy Optimization
10 | """
11 |
12 | def __init__(
13 | self,
14 | optimizer=None,
15 | optimizer_args=None,
16 | **kwargs):
17 | if optimizer is None:
18 | if optimizer_args is None:
19 | optimizer_args = dict()
20 | optimizer = ConjugateGradientOptimizer(**optimizer_args)
21 | super(TRPO, self).__init__(optimizer=optimizer, **kwargs)
22 |
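23 | # Usage sketch (comments only): TRPO is typically constructed with an rllab policy
24 | # and baseline and then trained via BatchPolopt.train(), which opens its own
25 | # tf.Session when none is active. The rllab imports below describe an assumed
26 | # caller setup, mirroring what scripts/pendulum_data_collect.py is expected to do.
27 | #
28 | #   from sandbox.rocky.tf.envs.base import TfEnv
29 | #   from rllab.envs.gym_env import GymEnv
30 | #   from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
31 | #   from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
32 | #
33 | #   env = TfEnv(GymEnv('Pendulum-v0', record_video=False, record_log=False))
34 | #   policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32))
35 | #   algo = TRPO(env=env, policy=policy,
36 | #               baseline=LinearFeatureBaseline(env_spec=env.spec),
37 | #               n_itr=200, batch_size=1000, max_path_length=100,
38 | #               discount=0.99, step_size=0.01, store_paths=True)
39 | #   algo.train()  # store_paths=True saves sampled paths into each snapshot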
--------------------------------------------------------------------------------
/inverse_rl/envs/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from gym.envs import register
4 |
5 | LOGGER = logging.getLogger(__name__)
6 |
7 | _REGISTERED = False
8 | def register_custom_envs():
9 | global _REGISTERED
10 | if _REGISTERED:
11 | return
12 | _REGISTERED = True
13 |
14 | LOGGER.info("Registering custom gym environments")
15 | register(id='ObjPusher-v0', entry_point='inverse_rl.envs.pusher_env:PusherEnv', kwargs={'sparse_reward': False})
16 | register(id='TwoDMaze-v0', entry_point='inverse_rl.envs.twod_maze:TwoDMaze')
17 | register(id='PointMazeRight-v0', entry_point='inverse_rl.envs.point_maze_env:PointMazeEnv',
18 | kwargs={'sparse_reward': False, 'direction': 1})
19 | register(id='PointMazeLeft-v0', entry_point='inverse_rl.envs.point_maze_env:PointMazeEnv',
20 | kwargs={'sparse_reward': False, 'direction': 0})
21 |
22 | # A modified ant which flips over less and learns faster via TRPO
23 | register(id='CustomAnt-v0', entry_point='inverse_rl.envs.ant_env:CustomAntEnv',
24 | kwargs={'gear': 30, 'disabled': False})
25 | register(id='DisabledAnt-v0', entry_point='inverse_rl.envs.ant_env:CustomAntEnv',
26 | kwargs={'gear': 30, 'disabled': True})
27 |
28 | register(id='VisualPointMaze-v0', entry_point='inverse_rl.envs.visual_pointmass:VisualPointMazeEnv',
29 | kwargs={'sparse_reward': False, 'direction': 1})
30 |
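31 | # Usage sketch: the IDs above only become visible to gym.make() after
32 | # register_custom_envs() has been called, e.g.
33 | #
34 | #   import gym
35 | #   from inverse_rl.envs import register_custom_envs
36 | #
37 | #   register_custom_envs()
38 | #   env = gym.make('CustomAnt-v0')   # or 'PointMazeLeft-v0', 'DisabledAnt-v0', ...
39 | #   obs = env.reset()
40 | #   obs, rew, done, info = env.step(env.action_space.sample())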
--------------------------------------------------------------------------------
/inverse_rl/envs/ant_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from gym.envs.mujoco import mujoco_env
4 | from inverse_rl.envs.dynamic_mjc.model_builder import MJCModel
5 | from rllab.misc import logger
6 |
7 | def ant_env(gear=150, eyes=True):
8 | mjcmodel = MJCModel('ant_maze')
9 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local")
10 | mjcmodel.root.option(timestep="0.01", integrator="RK4")
11 | mjcmodel.root.custom().numeric(data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0",name="init_qpos")
12 | asset = mjcmodel.root.asset()
13 | asset.texture(builtin="gradient",height="100",rgb1="1 1 1",rgb2="0 0 0",type="skybox",width="100")
14 | asset.texture(builtin="flat",height="1278",mark="cross",markrgb="1 1 1",name="texgeom",random="0.01",rgb1="0.8 0.6 0.4",rgb2="0.8 0.6 0.4",type="cube",width="127")
15 | asset.texture(builtin="checker",height="100",name="texplane",rgb1="0 0 0",rgb2="0.8 0.8 0.8",type="2d",width="100")
16 | asset.material(name="MatPlane",reflectance="0.5",shininess="1",specular="1",texrepeat="60 60",texture="texplane")
17 | asset.material(name="geom",texture="texgeom",texuniform="true")
18 |
19 | default = mjcmodel.root.default()
20 | default.joint(armature=1, damping=1, limited='true')
21 | default.geom(friction=[1.5,0.5,0.5], density=5.0, margin=0.01, condim=3, conaffinity=0, rgba="0.8 0.6 0.4 1")
22 |
23 | worldbody = mjcmodel.root.worldbody()
24 | worldbody.light(cutoff="100",diffuse=[.8,.8,.8],dir="-0 0 -1.3",directional="true",exponent="1",pos="0 0 1.3",specular=".1 .1 .1")
25 | worldbody.geom(conaffinity=1, condim=3, material="MatPlane",name="floor",pos="0 0 0",rgba="0.8 0.9 0.8 1",size="40 40 40",type="plane")
26 |
27 | ant = worldbody.body(name='torso', pos=[0, 0, 0.75])
28 | ant.geom(name='torso_geom', pos=[0, 0, 0], size="0.25", type="sphere")
29 | ant.joint(armature="0", damping="0", limited="false", margin="0.01", name="root", pos=[0, 0, 0], type="free")
30 |
31 | if eyes:
32 | eye_z = 0.1
33 | eye_y = -.21
34 | eye_x_offset = 0.07
35 | # eyes
36 | ant.geom(fromto=[eye_x_offset,0,eye_z,eye_x_offset,eye_y,eye_z], name='eye1', size='0.03', type='capsule', rgba=[1,1,1,1])
37 | ant.geom(fromto=[eye_x_offset,0,eye_z,eye_x_offset,eye_y-0.02,eye_z], name='eye1_', size='0.02', type='capsule', rgba=[0,0,0,1])
38 | ant.geom(fromto=[-eye_x_offset,0,eye_z,-eye_x_offset,eye_y,eye_z], name='eye2', size='0.03', type='capsule', rgba=[1,1,1,1])
39 | ant.geom(fromto=[-eye_x_offset,0,eye_z,-eye_x_offset,eye_y-0.02,eye_z], name='eye2_', size='0.02', type='capsule', rgba=[0,0,0,1])
40 | # eyebrows
41 | ant.geom(fromto=[eye_x_offset-0.03,eye_y, eye_z+0.07, eye_x_offset+0.03, eye_y, eye_z+0.1], name='brow1', size='0.02', type='capsule', rgba=[0,0,0,1])
42 | ant.geom(fromto=[-eye_x_offset+0.03,eye_y, eye_z+0.07, -eye_x_offset-0.03, eye_y, eye_z+0.1], name='brow2', size='0.02', type='capsule', rgba=[0,0,0,1])
43 |
44 | front_left_leg = ant.body(name="front_left_leg", pos=[0, 0, 0])
45 | front_left_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="aux_1_geom", size="0.08", type="capsule")
46 | aux_1 = front_left_leg.body(name="aux_1", pos=[0.2, 0.2, 0])
47 | aux_1.joint(axis=[0, 0, 1], name="hip_1", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
48 | aux_1.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="left_leg_geom", size="0.08", type="capsule")
49 | ankle_1 = aux_1.body(pos=[0.2, 0.2, 0])
50 | ankle_1.joint(axis=[-1, 1, 0], name="ankle_1", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge")
51 | ankle_1.geom(fromto=[0.0, 0.0, 0.0, 0.4, 0.4, 0.0], name="left_ankle_geom", size="0.08", type="capsule")
52 |
53 | front_right_leg = ant.body(name="front_right_leg", pos=[0, 0, 0])
54 | front_right_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="aux_2_geom", size="0.08", type="capsule")
55 | aux_2 = front_right_leg.body(name="aux_2", pos=[-0.2, 0.2, 0])
56 | aux_2.joint(axis=[0, 0, 1], name="hip_2", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
57 | aux_2.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="right_leg_geom", size="0.08", type="capsule")
58 | ankle_2 = aux_2.body(pos=[-0.2, 0.2, 0])
59 | ankle_2.joint(axis=[1, 1, 0], name="ankle_2", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge")
60 | ankle_2.geom(fromto=[0.0, 0.0, 0.0, -0.4, 0.4, 0.0], name="right_ankle_geom", size="0.08", type="capsule")
61 |
62 | back_left_leg = ant.body(name="back_left_leg", pos=[0, 0, 0])
63 | back_left_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="aux_3_geom", size="0.08", type="capsule")
64 | aux_3 = back_left_leg.body(name="aux_3", pos=[-0.2, -0.2, 0])
65 | aux_3.joint(axis=[0, 0, 1], name="hip_3", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
66 | aux_3.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="backleft_leg_geom", size="0.08", type="capsule")
67 | ankle_3 = aux_3.body(pos=[-0.2, -0.2, 0])
68 | ankle_3.joint(axis=[-1, 1, 0], name="ankle_3", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge")
69 | ankle_3.geom(fromto=[0.0, 0.0, 0.0, -0.4, -0.4, 0.0], name="backleft_ankle_geom", size="0.08", type="capsule")
70 |
71 | back_right_leg = ant.body(name="back_right_leg", pos=[0, 0, 0])
72 | back_right_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="aux_4_geom", size="0.08", type="capsule")
73 | aux_4 = back_right_leg.body(name="aux_4", pos=[0.2, -0.2, 0])
74 | aux_4.joint(axis=[0, 0, 1], name="hip_4", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
75 | aux_4.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="backright_leg_geom", size="0.08", type="capsule")
76 | ankle_4 = aux_4.body(pos=[0.2, -0.2, 0])
77 | ankle_4.joint(axis=[1, 1, 0], name="ankle_4", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge")
78 | ankle_4.geom(fromto=[0.0, 0.0, 0.0, 0.4, -0.4, 0.0], name="backright_ankle_geom", size="0.08", type="capsule")
79 |
80 | actuator = mjcmodel.root.actuator()
81 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_4", gear=gear)
82 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_4", gear=gear)
83 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_1", gear=gear)
84 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_1", gear=gear)
85 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_2", gear=gear)
86 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_2", gear=gear)
87 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_3", gear=gear)
88 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_3", gear=gear)
89 | return mjcmodel
90 |
91 |
92 | def angry_ant_crippled(gear=150):
93 | mjcmodel = MJCModel('ant_maze')
94 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local")
95 | mjcmodel.root.option(timestep="0.01", integrator="RK4")
96 | mjcmodel.root.custom().numeric(data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0",name="init_qpos")
97 | asset = mjcmodel.root.asset()
98 | asset.texture(builtin="gradient",height="100",rgb1="1 1 1",rgb2="0 0 0",type="skybox",width="100")
99 | asset.texture(builtin="flat",height="1278",mark="cross",markrgb="1 1 1",name="texgeom",random="0.01",rgb1="0.8 0.6 0.4",rgb2="0.8 0.6 0.4",type="cube",width="127")
100 | asset.texture(builtin="checker",height="100",name="texplane",rgb1="0 0 0",rgb2="0.8 0.8 0.8",type="2d",width="100")
101 | asset.material(name="MatPlane",reflectance="0.5",shininess="1",specular="1",texrepeat="60 60",texture="texplane")
102 | asset.material(name="geom",texture="texgeom",texuniform="true")
103 |
104 |
105 |
106 | default = mjcmodel.root.default()
107 | default.joint(armature=1, damping=1, limited='true')
108 | default.geom(friction=[1.5,0.5,0.5], density=5.0, margin=0.01, condim=3, conaffinity=0, rgba="0.8 0.6 0.4 1")
109 |
110 | worldbody = mjcmodel.root.worldbody()
111 |
112 | worldbody.geom(conaffinity=1, condim=3, material="MatPlane",name="floor",pos="0 0 0",rgba="0.8 0.9 0.8 1",size="40 40 40",type="plane")
113 | worldbody.light(cutoff="100",diffuse=[.8,.8,.8],dir="-0 0 -1.3",directional="true",exponent="1",pos="0 0 1.3",specular=".1 .1 .1")
114 |
115 |
116 | ant = worldbody.body(name='torso', pos=[0, 0, 0.75])
117 | ant.geom(name='torso_geom', pos=[0, 0, 0], size="0.25", type="sphere")
118 | ant.joint(armature="0", damping="0", limited="false", margin="0.01", name="root", pos=[0, 0, 0], type="free")
119 |
120 | eye_z = 0.1
121 | eye_y = -.21
122 | eye_x_offset = 0.07
123 | # eyes
124 | ant.geom(fromto=[eye_x_offset,0,eye_z,eye_x_offset,eye_y,eye_z], name='eye1', size='0.03', type='capsule', rgba=[1,1,1,1])
125 | ant.geom(fromto=[eye_x_offset,0,eye_z,eye_x_offset,eye_y-0.02,eye_z], name='eye1_', size='0.02', type='capsule', rgba=[0,0,0,1])
126 | ant.geom(fromto=[-eye_x_offset,0,eye_z,-eye_x_offset,eye_y,eye_z], name='eye2', size='0.03', type='capsule', rgba=[1,1,1,1])
127 | ant.geom(fromto=[-eye_x_offset,0,eye_z,-eye_x_offset,eye_y-0.02,eye_z], name='eye2_', size='0.02', type='capsule', rgba=[0,0,0,1])
128 | # eyebrows
129 | ant.geom(fromto=[eye_x_offset-0.03,eye_y, eye_z+0.07, eye_x_offset+0.03, eye_y, eye_z+0.1], name='brow1', size='0.02', type='capsule', rgba=[0,0,0,1])
130 | ant.geom(fromto=[-eye_x_offset+0.03,eye_y, eye_z+0.07, -eye_x_offset-0.03, eye_y, eye_z+0.1], name='brow2', size='0.02', type='capsule', rgba=[0,0,0,1])
131 |
132 |
133 |
134 |
135 | front_left_leg = ant.body(name="front_left_leg", pos=[0, 0, 0])
136 | front_left_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="aux_1_geom", size="0.08", type="capsule")
137 | aux_1 = front_left_leg.body(name="aux_1", pos=[0.2, 0.2, 0])
138 | aux_1.joint(axis=[0, 0, 1], name="hip_1", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
139 | aux_1.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="left_leg_geom", size="0.08", type="capsule")
140 | ankle_1 = aux_1.body(pos=[0.2, 0.2, 0])
141 | ankle_1.joint(axis=[-1, 1, 0], name="ankle_1", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge")
142 | ankle_1.geom(fromto=[0.0, 0.0, 0.0, 0.4, 0.4, 0.0], name="left_ankle_geom", size="0.08", type="capsule")
143 |
144 | front_right_leg = ant.body(name="front_right_leg", pos=[0, 0, 0])
145 | front_right_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="aux_2_geom", size="0.08", type="capsule")
146 | aux_2 = front_right_leg.body(name="aux_2", pos=[-0.2, 0.2, 0])
147 | aux_2.joint(axis=[0, 0, 1], name="hip_2", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
148 | aux_2.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="right_leg_geom", size="0.08", type="capsule")
149 | ankle_2 = aux_2.body(pos=[-0.2, 0.2, 0])
150 | ankle_2.joint(axis=[1, 1, 0], name="ankle_2", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge")
151 | ankle_2.geom(fromto=[0.0, 0.0, 0.0, -0.4, 0.4, 0.0], name="right_ankle_geom", size="0.08", type="capsule")
152 |
153 | # Back left leg is crippled
154 | thigh_length = 0.1 #0.2
155 | ankle_length = 0.2 #0.4
156 | dark_red = [0.8,0.3,0.3,1.0]
157 |
158 | back_left_leg = ant.body(name="back_left_leg", pos=[0, 0, 0])
159 | back_left_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="aux_3_geom", size="0.08", type="capsule",
160 | rgba=dark_red)
161 | aux_3 = back_left_leg.body(name="aux_3", pos=[-0.2, -0.2, 0])
162 | aux_3.joint(axis=[0, 0, 1], name="hip_3", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
163 | aux_3.geom(fromto=[0.0, 0.0, 0.0, -thigh_length, -thigh_length, 0.0], name="backleft_leg_geom", size="0.08", type="capsule",
164 | rgba=dark_red)
165 | ankle_3 = aux_3.body(pos=[-thigh_length, -thigh_length, 0])
166 | ankle_3.joint(axis=[-1, 1, 0], name="ankle_3", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge")
167 | ankle_3.geom(fromto=[0.0, 0.0, 0.0, -ankle_length, -ankle_length, 0.0], name="backleft_ankle_geom", size="0.08", type="capsule",
168 | rgba=dark_red)
169 |
170 | back_right_leg = ant.body(name="back_right_leg", pos=[0, 0, 0])
171 | back_right_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="aux_4_geom", size="0.08", type="capsule",
172 | rgba=dark_red)
173 | aux_4 = back_right_leg.body(name="aux_4", pos=[0.2, -0.2, 0])
174 | aux_4.joint(axis=[0, 0, 1], name="hip_4", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
175 | aux_4.geom(fromto=[0.0, 0.0, 0.0, thigh_length, -thigh_length, 0.0], name="backright_leg_geom", size="0.08", type="capsule",
176 | rgba=dark_red)
177 | ankle_4 = aux_4.body(pos=[thigh_length, -thigh_length, 0])
178 | ankle_4.joint(axis=[1, 1, 0], name="ankle_4", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge")
179 | ankle_4.geom(fromto=[0.0, 0.0, 0.0, ankle_length, -ankle_length, 0.0], name="backright_ankle_geom", size="0.08", type="capsule",
180 | rgba=dark_red)
181 |
182 | actuator = mjcmodel.root.actuator()
183 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_1", gear=gear)
184 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_1", gear=gear)
185 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_2", gear=gear)
186 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_2", gear=gear)
187 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_3", gear=1) # cripple the joints
188 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_3", gear=1) # cripple the joints
189 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_4", gear=1)
190 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_4", gear=1)
191 | return mjcmodel
192 |
193 |
194 | class CustomAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
195 | """
196 | A modified ant env with lower joint gear ratios so it flips less often and learns faster.
197 | """
198 | def __init__(self, max_timesteps=1000, disabled=False, gear=150):
199 | #mujoco_env.MujocoEnv.__init__(self, 'ant.xml', 5)
200 | utils.EzPickle.__init__(self)
201 | self.timesteps = 0
202 | self.max_timesteps=max_timesteps
203 |
204 | if disabled:
205 | model = angry_ant_crippled(gear=gear)
206 | else:
207 | model = ant_env(gear=gear)
208 | with model.asfile() as f:
209 | mujoco_env.MujocoEnv.__init__(self, f.name, 5)
210 |
211 | def _step(self, a):
212 | vel = self.model.data.qvel.flat[0]
213 | forward_reward = vel
214 | self.do_simulation(a, self.frame_skip)
215 |
216 | ctrl_cost = .01 * np.square(a).sum()
217 | contact_cost = 0.5 * 1e-3 * np.sum(
218 | np.square(np.clip(self.model.data.cfrc_ext, -1, 1)))
219 | state = self.state_vector()
220 | flipped = not (state[2] >= 0.2)
221 | flipped_rew = -1 if flipped else 0
222 |         reward = forward_reward - ctrl_cost - contact_cost + flipped_rew
223 |
224 | self.timesteps += 1
225 | done = self.timesteps >= self.max_timesteps
226 |
227 | ob = self._get_obs()
228 | return ob, reward, done, dict(
229 | reward_forward=forward_reward,
230 | reward_ctrl=-ctrl_cost,
231 | reward_contact=-contact_cost,
232 | reward_flipped=flipped_rew)
233 |
234 | def _get_obs(self):
235 | return np.concatenate([
236 | self.model.data.qpos.flat[2:],
237 | self.model.data.qvel.flat,
238 | np.clip(self.model.data.cfrc_ext, -1, 1).flat,
239 | ])
240 |
241 | def reset_model(self):
242 | self.timesteps = 0
243 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
244 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
245 | self.set_state(qpos, qvel)
246 | return self._get_obs()
247 |
248 | def viewer_setup(self):
249 | self.viewer.cam.distance = self.model.stat.extent * 0.5
250 |
251 | def log_diagnostics(self, paths):
252 | forward_rew = np.array([np.mean(traj['env_infos']['reward_forward']) for traj in paths])
253 | reward_ctrl = np.array([np.mean(traj['env_infos']['reward_ctrl']) for traj in paths])
254 | reward_cont = np.array([np.mean(traj['env_infos']['reward_contact']) for traj in paths])
255 | reward_flip = np.array([np.mean(traj['env_infos']['reward_flipped']) for traj in paths])
256 |
257 | logger.record_tabular('AvgRewardFwd', np.mean(forward_rew))
258 | logger.record_tabular('AvgRewardCtrl', np.mean(reward_ctrl))
259 | logger.record_tabular('AvgRewardContact', np.mean(reward_cont))
260 | logger.record_tabular('AvgRewardFlipped', np.mean(reward_flip))
261 |
262 |
263 | if __name__ == "__main__":
264 | env = CustomAntEnv(disabled=True, gear=30)
265 |
266 | for _ in range(1000):
267 | env.render()
268 | env.step(env.action_space.sample())
269 |
--------------------------------------------------------------------------------
/inverse_rl/envs/assets/twod_maze.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/inverse_rl/envs/dynamic_mjc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinjfu/inverse_rl/9609933389459a3a54f5c01d652114ada90fa1b3/inverse_rl/envs/dynamic_mjc/__init__.py
--------------------------------------------------------------------------------
/inverse_rl/envs/dynamic_mjc/mjc_models.py:
--------------------------------------------------------------------------------
1 | from inverse_rl.envs.dynamic_mjc.model_builder import MJCModel
2 | import numpy as np
3 |
4 |
5 | def block_push(object_pos=(0,0,0), goal_pos=(0,0,0)):
6 | mjcmodel = MJCModel('block_push')
7 | mjcmodel.root.compiler(inertiafromgeom="true",angle="radian",coordinate="local")
8 | mjcmodel.root.option(timestep="0.01",gravity="0 0 0",iterations="20",integrator="Euler")
9 | default = mjcmodel.root.default()
10 | default.joint(armature='0.04', damping=1, limited='true')
11 | default.geom(friction=".8 .1 .1",density="300",margin="0.002",condim="1",contype="1",conaffinity="1")
12 |
13 | worldbody = mjcmodel.root.worldbody()
14 |
15 | palm = worldbody.body(name='palm', pos=[0,0,0])
16 | palm.geom(name='palm_geom', type='capsule', fromto=[0,-0.1,0,0,0.1,0], size=.12)
17 | proximal1 = palm.body(name='proximal_1', pos=[0,0,0])
18 | proximal1.joint(name='proximal_j_1', type='hinge', pos=[0,0,0], axis=[0,1,0], range=[-2.5,2.3])
19 | proximal1.geom(type='capsule', fromto=[0,0,0,0.4,0,0], size=0.06, contype=1, conaffinity=1)
20 | distal1 = proximal1.body(name='distal_1', pos=[0.4,0,0])
21 | distal1.joint(name = "distal_j_1", type = "hinge", pos = "0 0 0", axis = "0 1 0", range = "-2.3213 2.3", damping = "1.0")
22 | distal1.geom(type="capsule", fromto="0 0 0 0.4 0 0", size="0.06", contype="1", conaffinity="1")
23 | distal2 = distal1.body(name='distal_2', pos=[0.4,0,0])
24 | distal2.joint(name="distal_j_2",type="hinge",pos="0 0 0",axis="0 1 0",range="-2.3213 2.3",damping="1.0")
25 | distal2.geom(type="capsule",fromto="0 0 0 0.4 0 0",size="0.06",contype="1",conaffinity="1")
26 | distal4 = distal2.body(name='distal_4', pos=[0.4,0,0])
27 | distal4.site(name="tip arml",pos="0.1 0 -0.2",size="0.01")
28 | distal4.site(name="tip armr",pos="0.1 0 0.2",size="0.01")
29 | distal4.joint(name="distal_j_3",type="hinge",pos="0 0 0",axis="1 0 0",range="-3.3213 3.3",damping="0.5")
30 | distal4.geom(type="capsule",fromto="0 0 -0.2 0 0 0.2",size="0.04",contype="1",conaffinity="1")
31 | distal4.geom(type="capsule",fromto="0 0 -0.2 0.2 0 -0.2",size="0.04",contype="1",conaffinity="1")
32 | distal4.geom(type="capsule",fromto="0 0 0.2 0.2 0 0.2",size="0.04",contype="1",conaffinity="1")
33 |
34 | object = worldbody.body(name='object', pos=object_pos)
35 | object.geom(rgba="1. 1. 1. 1",type="box",size="0.05 0.05 0.05",density='0.00001',contype="1",conaffinity="1")
36 | object.joint(name="obj_slidez",type="slide",pos="0.025 0.025 0.025",axis="0 0 1",range="-10.3213 10.3",damping="0.5")
37 | object.joint(name="obj_slidex",type="slide",pos="0.025 0.025 0.025",axis="1 0 0",range="-10.3213 10.3",damping="0.5")
38 | distal10 = object.body(name='distal_10', pos=[0,0,0])
39 | distal10.site(name='obj_pos', pos=[0.025,0.025,0.025], size=0.01)
40 |
41 | goal = worldbody.body(name='goal', pos=goal_pos)
42 | goal.geom(rgba="1. 0. 0. 1",type="box",size="0.1 0.1 0.1",density='0.00001',contype="0",conaffinity="0")
43 | distal11 = goal.body(name='distal_11', pos=[0,0,0])
44 | distal11.site(name='goal_pos', pos=[0.05,0.05,0.05], size=0.01)
45 |
46 |
47 | actuator = mjcmodel.root.actuator()
48 | actuator.motor(joint="proximal_j_1",ctrlrange="-2 2",ctrllimited="true")
49 | actuator.motor(joint="distal_j_1",ctrlrange="-2 2",ctrllimited="true")
50 | actuator.motor(joint="distal_j_2",ctrlrange="-2 2",ctrllimited="true")
51 | actuator.motor(joint="distal_j_3",ctrlrange="-2 2",ctrllimited="true")
52 |
53 | return mjcmodel
54 |
55 |
56 | EAST = 0
57 | WEST = 1
58 | NORTH = 2
59 | SOUTH = 3
60 |
61 | def twod_corridor(direction=EAST, length=1.2):
62 | mjcmodel = MJCModel('twod_corridor')
63 | mjcmodel.root.compiler(inertiafromgeom="true", angle="radian", coordinate="local")
64 | mjcmodel.root.option(timestep="0.01", gravity="0 0 0", iterations="20", integrator="Euler")
65 | default = mjcmodel.root.default()
66 | default.joint(damping=1, limited='false')
67 | default.geom(friction=".5 .1 .1", density="1000", margin="0.002", condim="1", contype="2", conaffinity="1")
68 |
69 | worldbody = mjcmodel.root.worldbody()
70 |
71 | particle = worldbody.body(name='particle', pos=[0,0,0])
72 | particle.geom(name='particle_geom', type='sphere', size='0.03', rgba='0.0 0.0 1.0 1', contype=1)
73 | particle.site(name='particle_site', pos=[0,0,0], size=0.01)
74 | particle.joint(name='ball_x', type='slide', pos=[0,0,0], axis=[1,0,0])
75 | particle.joint(name='ball_y', type='slide', pos=[0,0,0], axis=[0,1,0])
76 |
77 | pos = np.array([0.0,0,0])
78 | if direction == EAST or direction == WEST:
79 | pos[0] = length-0.1
80 | else:
81 | pos[1] = length-0.1
82 | if direction == WEST or direction == SOUTH:
83 | pos = -pos
84 |
85 | target = worldbody.body(name='target', pos=pos)
86 | target.geom(name='target_geom', conaffinity=2, type='sphere', size=0.02, rgba=[0,0.9,0.1,1])
87 |
88 | # arena
89 | if direction == EAST:
90 | L = -0.1
91 | R = length
92 | U = 0.1
93 | D = -0.1
94 | elif direction == WEST:
95 | L = -length
96 | R = 0.1
97 | U = 0.1
98 | D = -0.1
99 | elif direction == SOUTH:
100 | L = -0.1
101 | R = 0.1
102 | U = 0.1
103 | D = -length
104 | elif direction == NORTH:
105 | L = -0.1
106 | R = 0.1
107 | U = length
108 | D = -0.1
109 |
110 | worldbody.geom(conaffinity=1, fromto=[L, D, .01, R, D, .01], name="sideS", rgba="0.9 0.4 0.6 1", size=.02, type="capsule")
111 | worldbody.geom(conaffinity=1, fromto=[R, D, .01, R, U, .01], name="sideE", rgba="0.9 0.4 0.6 1", size=".02", type="capsule")
112 | worldbody.geom(conaffinity=1, fromto=[L, U, .01, R, U, .01], name="sideN", rgba="0.9 0.4 0.6 1", size=".02", type="capsule")
113 | worldbody.geom(conaffinity=1, fromto=[L, D, .01, L, U, .01], name="sideW", rgba="0.9 0.4 0.6 1", size=".02", type="capsule")
114 |
115 | actuator = mjcmodel.root.actuator()
116 | actuator.motor(joint="ball_x", ctrlrange=[-1.0, 1.0], ctrllimited=True)
117 | actuator.motor(joint="ball_y", ctrlrange=[-1.0, 1.0], ctrllimited=True)
118 |
119 | return mjcmodel
120 |
121 |
122 | LEFT = 0
123 | RIGHT = 1
124 |
125 | def point_mass_maze(direction=RIGHT, length=1.2, borders=True):
126 | mjcmodel = MJCModel('twod_maze')
127 | mjcmodel.root.compiler(inertiafromgeom="true", angle="radian", coordinate="local")
128 | mjcmodel.root.option(timestep="0.01", gravity="0 0 0", iterations="20", integrator="Euler")
129 | default = mjcmodel.root.default()
130 | default.joint(damping=1, limited='false')
131 | default.geom(friction=".5 .1 .1", density="1000", margin="0.002", condim="1", contype="2", conaffinity="1")
132 |
133 | worldbody = mjcmodel.root.worldbody()
134 |
135 | particle = worldbody.body(name='particle', pos=[length/2,0,0])
136 | particle.geom(name='particle_geom', type='sphere', size='0.03', rgba='0.0 0.0 1.0 1', contype=1)
137 | particle.site(name='particle_site', pos=[0,0,0], size=0.01)
138 | particle.joint(name='ball_x', type='slide', pos=[0,0,0], axis=[1,0,0])
139 | particle.joint(name='ball_y', type='slide', pos=[0,0,0], axis=[0,1,0])
140 |
141 | target = worldbody.body(name='target', pos=[length/2,length-0.1,0])
142 | target.geom(name='target_geom', conaffinity=2, type='sphere', size=0.02, rgba=[0,0.9,0.1,1])
143 |
144 | L = -0.1
145 | R = length
146 | U = length
147 | D = -0.1
148 |
149 | if borders:
150 | worldbody.geom(conaffinity=1, fromto=[L, D, .01, R, D, .01], name="sideS", rgba="0.9 0.4 0.6 1", size=".02", type="capsule")
151 | worldbody.geom(conaffinity=1, fromto=[R, D, .01, R, U, .01], name="sideE", rgba="0.9 0.4 0.6 1", size=".02", type="capsule")
152 | worldbody.geom(conaffinity=1, fromto=[L, U, .01, R, U, .01], name="sideN", rgba="0.9 0.4 0.6 1", size=".02", type="capsule")
153 | worldbody.geom(conaffinity=1, fromto=[L, D, .01, L, U, .01], name="sideW", rgba="0.9 0.4 0.6 1", size=".02", type="capsule")
154 |
155 | # arena
156 | if direction == LEFT:
157 | BL = -0.1
158 | BR = length * 2/3
159 | BH = length/2
160 | else:
161 | BL = length * 1/3
162 | BR = length
163 | BH = length/2
164 |
165 | worldbody.geom(conaffinity=1, fromto=[BL, BH, .01, BR, BH, .01], name="barrier", rgba="0.9 0.4 0.6 1", size=".02", type="capsule")
166 |
167 | actuator = mjcmodel.root.actuator()
168 | actuator.motor(joint="ball_x", ctrlrange=[-1.0, 1.0], ctrllimited=True)
169 | actuator.motor(joint="ball_y", ctrlrange=[-1.0, 1.0], ctrllimited=True)
170 |
171 | return mjcmodel
172 |
173 |
174 | def ant_maze(direction=RIGHT, length=6.0):
175 | mjcmodel = MJCModel('ant_maze')
176 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local")
177 | mjcmodel.root.option(timestep="0.01", gravity="0 0 -9.8", iterations="20", integrator="Euler")
178 |
179 | assets = mjcmodel.root.asset()
180 | assets.texture(builtin="gradient", height="100", rgb1="1 1 1", rgb2="0 0 0", type="skybox", width="100")
181 | assets.texture(builtin="flat", height="1278", mark="cross", markrgb="1 1 1", name="texgeom", random="0.01", rgb1="0.8 0.6 0.4", rgb2="0.8 0.6 0.4", type="cube", width="127")
182 | assets.texture(builtin="checker", height="100", name="texplane", rgb1="0 0 0", rgb2="0.8 0.8 0.8", type="2d", width="100")
183 | assets.material(name="MatPlane", reflectance="0.5", shininess="1", specular="1", texrepeat="60 60", texture="texplane")
184 | assets.material(name="geom", texture="texgeom", texuniform="true")
185 |
186 | default = mjcmodel.root.default()
187 | default.joint(armature="1", damping=1, limited='true')
188 | default.geom(friction="1 0.5 0.5", density="5.0", margin="0.01", condim="3", conaffinity="0")
189 |
190 | worldbody = mjcmodel.root.worldbody()
191 |
192 | ant = worldbody.body(name='ant', pos=[length/2, 1.0, 0.05])
193 | ant.geom(name='torso_geom', pos=[0, 0, 0], size="0.25", type="sphere")
194 | ant.joint(armature="0", damping="0", limited="false", margin="0.01", name="root", pos=[0, 0, 0], type="free")
195 |
196 | front_left_leg = ant.body(name="front_left_leg", pos=[0, 0, 0])
197 | front_left_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="aux_1_geom", size="0.08", type="capsule")
198 | aux_1 = front_left_leg.body(name="aux_1", pos=[0.2, 0.2, 0])
199 | aux_1.joint(axis=[0, 0, 1], name="hip_1", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
200 | aux_1.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="left_leg_geom", size="0.08", type="capsule")
201 | ankle_1 = aux_1.body(pos=[0.2, 0.2, 0])
202 | ankle_1.joint(axis=[-1, 1, 0], name="ankle_1", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge")
203 | ankle_1.geom(fromto=[0.0, 0.0, 0.0, 0.4, 0.4, 0.0], name="left_ankle_geom", size="0.08", type="capsule")
204 |
205 | front_right_leg = ant.body(name="front_right_leg", pos=[0, 0, 0])
206 | front_right_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="aux_2_geom", size="0.08", type="capsule")
207 | aux_2 = front_right_leg.body(name="aux_2", pos=[-0.2, 0.2, 0])
208 | aux_2.joint(axis=[0, 0, 1], name="hip_2", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
209 | aux_2.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="right_leg_geom", size="0.08", type="capsule")
210 | ankle_2 = aux_2.body(pos=[-0.2, 0.2, 0])
211 | ankle_2.joint(axis=[1, 1, 0], name="ankle_2", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge")
212 | ankle_2.geom(fromto=[0.0, 0.0, 0.0, -0.4, 0.4, 0.0], name="right_ankle_geom", size="0.08", type="capsule")
213 |
214 | back_left_leg = ant.body(name="back_left_leg", pos=[0, 0, 0])
215 | back_left_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="aux_3_geom", size="0.08", type="capsule")
216 | aux_3 = back_left_leg.body(name="aux_3", pos=[-0.2, -0.2, 0])
217 | aux_3.joint(axis=[0, 0, 1], name="hip_3", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
218 | aux_3.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="backleft_leg_geom", size="0.08", type="capsule")
219 | ankle_3 = aux_3.body(pos=[-0.2, -0.2, 0])
220 | ankle_3.joint(axis=[-1, 1, 0], name="ankle_3", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge")
221 | ankle_3.geom(fromto=[0.0, 0.0, 0.0, -0.4, -0.4, 0.0], name="backleft_ankle_geom", size="0.08", type="capsule")
222 |
223 | back_right_leg = ant.body(name="back_right_leg", pos=[0, 0, 0])
224 | back_right_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="aux_4_geom", size="0.08", type="capsule")
225 | aux_4 = back_right_leg.body(name="aux_4", pos=[0.2, -0.2, 0])
226 | aux_4.joint(axis=[0, 0, 1], name="hip_4", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
227 | aux_4.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="backright_leg_geom", size="0.08", type="capsule")
228 | ankle_4 = aux_4.body(pos=[0.2, -0.2, 0])
229 | ankle_4.joint(axis=[1, 1, 0], name="ankle_4", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge")
230 | ankle_4.geom(fromto=[0.0, 0.0, 0.0, 0.4, -0.4, 0.0], name="backright_ankle_geom", size="0.08", type="capsule")
231 |
232 | target = worldbody.body(name='target', pos=[length/2,length-0.2,-0.5])
233 | target.geom(name='target_geom', conaffinity=2, type='sphere', size=0.2, rgba=[0,0.9,0.1,1])
234 |
235 | l = length/2
236 | h = 0.75
237 | w = 0.05
238 |
239 | worldbody.geom(conaffinity=1, name="sideS", rgba="0.9 0.4 0.6 1", size=[l, w, h], pos=[length/2, 0, 0], type="box")
240 | worldbody.geom(conaffinity=1, name="sideE", rgba="0.9 0.4 0.6 1", size=[w, l, h], pos=[length, length/2, 0], type="box")
241 | worldbody.geom(conaffinity=1, name="sideN", rgba="0.9 0.4 0.6 1", size=[l, w, h], pos=[length/2, length, 0], type="box")
242 | worldbody.geom(conaffinity=1, name="sideW", rgba="0.9 0.4 0.6 1", size=[w, l, h], pos=[0, length/2, 0], type="box")
243 |
244 | # arena
245 | if direction == LEFT:
246 | bx, by, bz = (length/3, length/2, 0)
247 | else:
248 | bx, by, bz = (length*2/3, length/2, 0)
249 |
250 | worldbody.geom(conaffinity=1, name="barrier", rgba="0.9 0.4 0.6 1", size=[l * 2/3, w, h], pos=[bx, by, bz], type="box")
251 | worldbody.geom(conaffinity="1", condim="3", material="MatPlane", name="floor", pos=[length/2, length/2, -h + w],
252 | rgba="0.8 0.9 0.8 1", size="40 40 40", type="plane")
253 |
254 | actuator = mjcmodel.root.actuator()
255 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_4", gear="50")
256 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_4", gear="50")
257 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_1", gear="50")
258 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_1", gear="50")
259 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_2", gear="50")
260 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_2", gear="50")
261 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_3", gear="50")
262 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_3", gear="50")
263 |
264 | return mjcmodel
265 |
266 |
267 | def ant_maze_corridor(direction=RIGHT, height=6.0, width=10.0):
268 | mjcmodel = MJCModel('ant_maze_corridor')
269 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local")
270 | mjcmodel.root.option(timestep="0.01", gravity="0 0 -9.8", iterations="20", integrator="Euler")
271 |
272 | assets = mjcmodel.root.asset()
273 | assets.texture(builtin="gradient", height="100", rgb1="1 1 1", rgb2="0 0 0", type="skybox", width="100")
274 | assets.texture(builtin="flat", height="1278", mark="cross", markrgb="1 1 1", name="texgeom", random="0.01", rgb1="0.8 0.6 0.4", rgb2="0.8 0.6 0.4", type="cube", width="127")
275 | assets.texture(builtin="checker", height="100", name="texplane", rgb1="0 0 0", rgb2="0.8 0.8 0.8", type="2d", width="100")
276 | assets.material(name="MatPlane", reflectance="0.5", shininess="1", specular="1", texrepeat="60 60", texture="texplane")
277 | assets.material(name="geom", texture="texgeom", texuniform="true")
278 |
279 | default = mjcmodel.root.default()
280 | default.joint(armature="1", damping=1, limited='true')
281 | default.geom(friction="1 0.5 0.5", density="5.0", margin="0.01", condim="3", conaffinity="0")
282 |
283 | worldbody = mjcmodel.root.worldbody()
284 |
285 | ant = worldbody.body(name='ant', pos=[height/2, 1.0, 0.05])
286 | ant.geom(name='torso_geom', pos=[0, 0, 0], size="0.25", type="sphere")
287 | ant.joint(armature="0", damping="0", limited="false", margin="0.01", name="root", pos=[0, 0, 0], type="free")
288 |
289 | front_left_leg = ant.body(name="front_left_leg", pos=[0, 0, 0])
290 | front_left_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="aux_1_geom", size="0.08", type="capsule")
291 | aux_1 = front_left_leg.body(name="aux_1", pos=[0.2, 0.2, 0])
292 | aux_1.joint(axis=[0, 0, 1], name="hip_1", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
293 | aux_1.geom(fromto=[0.0, 0.0, 0.0, 0.2, 0.2, 0.0], name="left_leg_geom", size="0.08", type="capsule")
294 | ankle_1 = aux_1.body(pos=[0.2, 0.2, 0])
295 | ankle_1.joint(axis=[-1, 1, 0], name="ankle_1", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge")
296 | ankle_1.geom(fromto=[0.0, 0.0, 0.0, 0.4, 0.4, 0.0], name="left_ankle_geom", size="0.08", type="capsule")
297 |
298 | front_right_leg = ant.body(name="front_right_leg", pos=[0, 0, 0])
299 | front_right_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="aux_2_geom", size="0.08", type="capsule")
300 | aux_2 = front_right_leg.body(name="aux_2", pos=[-0.2, 0.2, 0])
301 | aux_2.joint(axis=[0, 0, 1], name="hip_2", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
302 | aux_2.geom(fromto=[0.0, 0.0, 0.0, -0.2, 0.2, 0.0], name="right_leg_geom", size="0.08", type="capsule")
303 | ankle_2 = aux_2.body(pos=[-0.2, 0.2, 0])
304 | ankle_2.joint(axis=[1, 1, 0], name="ankle_2", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge")
305 | ankle_2.geom(fromto=[0.0, 0.0, 0.0, -0.4, 0.4, 0.0], name="right_ankle_geom", size="0.08", type="capsule")
306 |
307 | back_left_leg = ant.body(name="back_left_leg", pos=[0, 0, 0])
308 | back_left_leg.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="aux_3_geom", size="0.08", type="capsule")
309 | aux_3 = back_left_leg.body(name="aux_3", pos=[-0.2, -0.2, 0])
310 | aux_3.joint(axis=[0, 0, 1], name="hip_3", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
311 | aux_3.geom(fromto=[0.0, 0.0, 0.0, -0.2, -0.2, 0.0], name="backleft_leg_geom", size="0.08", type="capsule")
312 | ankle_3 = aux_3.body(pos=[-0.2, -0.2, 0])
313 | ankle_3.joint(axis=[-1, 1, 0], name="ankle_3", pos=[0.0, 0.0, 0.0], range=[-70, -30], type="hinge")
314 | ankle_3.geom(fromto=[0.0, 0.0, 0.0, -0.4, -0.4, 0.0], name="backleft_ankle_geom", size="0.08", type="capsule")
315 |
316 | back_right_leg = ant.body(name="back_right_leg", pos=[0, 0, 0])
317 | back_right_leg.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="aux_4_geom", size="0.08", type="capsule")
318 | aux_4 = back_right_leg.body(name="aux_4", pos=[0.2, -0.2, 0])
319 | aux_4.joint(axis=[0, 0, 1], name="hip_4", pos=[0.0, 0.0, 0.0], range=[-30, 30], type="hinge")
320 | aux_4.geom(fromto=[0.0, 0.0, 0.0, 0.2, -0.2, 0.0], name="backright_leg_geom", size="0.08", type="capsule")
321 | ankle_4 = aux_4.body(pos=[0.2, -0.2, 0])
322 | ankle_4.joint(axis=[1, 1, 0], name="ankle_4", pos=[0.0, 0.0, 0.0], range=[30, 70], type="hinge")
323 | ankle_4.geom(fromto=[0.0, 0.0, 0.0, 0.4, -0.4, 0.0], name="backright_ankle_geom", size="0.08", type="capsule")
324 |
325 | target = worldbody.body(name='target', pos=[height/2, width-1.0,-0.5])
326 | target.geom(name='target_geom', conaffinity=2, type='sphere', size=0.2, rgba=[0,0.9,0.1,1])
327 |
328 | l = height/2
329 | h = 0.75
330 | w = 0.05
331 |
332 | worldbody.geom(conaffinity=1, name="sideS", rgba="0.9 0.4 0.6 1", size=[l, w, h], pos=[height/2, 0, 0], type="box")
333 | worldbody.geom(conaffinity=1, name="sideE", rgba="0.9 0.4 0.6 1", size=[w, width/2, h], pos=[height, width/2, 0], type="box")
334 | worldbody.geom(conaffinity=1, name="sideN", rgba="0.9 0.4 0.6 1", size=[l, w, h], pos=[height/2, width, 0], type="box")
335 | worldbody.geom(conaffinity=1, name="sideW", rgba="0.9 0.4 0.6 1", size=[w, width/2, h], pos=[0, width/2, 0], type="box")
336 |
337 | # arena
338 | wall_ratio = .55#2.0/3
339 | if direction == LEFT:
340 | bx, by, bz = (height*(wall_ratio/2), width/2, 0)
341 | #bx, by, bz = (height/4, width/2, 0)
342 | # bx, by, bz = length * 5/3, length * 5/6 + w, 0
343 | # bx1, by1, bz1 = bx - length/12, by-l/6, bz
344 | else:
345 | bx, by, bz = (height*(1-wall_ratio/2), width/2, 0)
346 | #bx, by, bz = (height*(3/4), width/2, 0)
347 | # bx, by, bz = length / 3, length * 5/6 + w, 0
348 | # bx1, by1, bz1 = bx + length/12, by-l/6, bz
349 |
350 | worldbody.geom(conaffinity=1, name="barrier", rgba="0.9 0.4 0.6 1", size=[l*(wall_ratio), w, h], pos=[bx, by, bz], type="box")
351 | # worldbody.geom(conaffinity=1, name="barrier1", rgba="0.9 0.4 0.6 1", size=[w, l/2 - 2*w, h], pos=[length/2, length/2, bz], type="box")
352 | # worldbody.geom(conaffinity=1, name="barrier2", rgba="0.9 0.4 0.6 1", size=[l/6, w, h], pos=[bx1, by1, bz1], type="box")
353 | # worldbody.geom(conaffinity=1, name="barrier3", rgba="0.9 0.4 0.6 1", size=[w, l/6, h], pos=[bx, by, bz], type="box")
354 | # worldbody.geom(conaffinity=1, condim=3, name="floor", rgba="0.4 0.4 0.4 1", size=[l, l, w], pos=[length/2, length/2, -h],
355 | # type="box")
356 | worldbody.geom(conaffinity="1", condim="3", material="MatPlane", name="floor", pos=[height/2, height/2, -h + w],
357 | rgba="0.8 0.9 0.8 1", size="40 40 40", type="plane")
358 |
359 | actuator = mjcmodel.root.actuator()
360 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_4", gear="30")
361 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_4", gear="30")
362 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_1", gear="30")
363 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_1", gear="30")
364 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_2", gear="30")
365 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_2", gear="30")
366 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="hip_3", gear="30")
367 | actuator.motor(ctrllimited="true", ctrlrange="-1.0 1.0", joint="ankle_3", gear="30")
368 |
369 | return mjcmodel
370 |
371 |
372 | def pusher(goal_pos=np.array([0.45, -0.05, -0.323])):
373 | mjcmodel = MJCModel('pusher')
374 | mjcmodel.root.compiler(inertiafromgeom="true", angle="radian", coordinate="local")
375 | mjcmodel.root.option(timestep="0.01", gravity="0 0 0", iterations="20", integrator="Euler")
376 | default = mjcmodel.root.default()
377 | default.joint(armature=0.04, damping=1, limited=False)
378 | default.geom(friction=[.8, .1, .1], density=300, margin=0.002, condim=1, contype=0, conaffinity=0)
379 |
380 | worldbody = mjcmodel.root.worldbody()
381 | worldbody.light(diffuse=[.5,.5,.5], pos=[0,0,3], dir=[0,0,-1])
382 | worldbody.geom(name='table', type='plane', pos=[0,.5,-.325],size=[1,1,0.1], contype=1, conaffinity=1)
383 |
384 | r_shoulder_pan_link = worldbody.body(name='r_shoulder_pan_link', pos=[0,-.6,0])
385 | r_shoulder_pan_link.geom(name='e1', type='sphere', rgba=[.6,.6,.6,1], pos=[-0.06,0.05,0.2], size=0.05)
386 | r_shoulder_pan_link.geom(name='e2', type='sphere', rgba=[.6,.6,.6,1], pos=[0.06,0.05,0.2], size=0.05)
387 | r_shoulder_pan_link.geom(name='e1p', type='sphere', rgba=[.1,.1,.1,1], pos=[-0.06,0.09,0.2], size=0.03)
388 | r_shoulder_pan_link.geom(name='e2p', type='sphere', rgba=[.1,.1,.1,1], pos=[0.06,0.09,0.2], size=0.03)
389 | r_shoulder_pan_link.geom(name='sp', type='capsule', fromto=[0,0,-0.4,0,0,0.2], size=0.1)
390 |
391 | r_shoulder_pan_link.joint(name='r_shoulder_pan_joint', type='hinge', pos=[0,0,0], axis=[0,0,1],
392 | range=[-2.2854, 1.714602], damping=1.0)
393 |
394 | r_shoulder_lift_link = r_shoulder_pan_link.body(name='r_shoulder_lift_link', pos=[0.1,0,0])
395 | r_shoulder_lift_link.geom(name='s1', type='capsule', fromto="0 -0.1 0 0 0.1 0", size="0.1")
396 | r_shoulder_lift_link.joint(name="r_shoulder_lift_joint", type="hinge", pos="0 0 0", axis="0 1 0",
397 | range="-0.5236 1.3963", damping="1.0")
398 |
399 | r_upper_arm_roll_link = r_shoulder_lift_link.body(name='r_upper_arm_roll_link', pos=[0,0,0])
400 | r_upper_arm_roll_link.geom(name="uar", type="capsule", fromto="-0.1 0 0 0.1 0 0", size="0.02")
401 | r_upper_arm_roll_link.joint(name="r_upper_arm_roll_joint", type="hinge", pos="0 0 0", axis="1 0 0",
402 | range="-1.5 1.7", damping="0.1")
403 |
404 | r_upper_arm_link = r_upper_arm_roll_link.body(name='r_upper_arm_link', pos=[0,0,0])
405 | r_upper_arm_link.geom(name="ua", type="capsule", fromto="0 0 0 0.4 0 0", size="0.06")
406 |
407 | r_elbow_flex_link = r_upper_arm_link.body(name='r_elbow_flex_link', pos=[0.4,0,0])
408 | r_elbow_flex_link.geom(name="ef", type="capsule", fromto="0 -0.02 0 0.0 0.02 0", size="0.06")
409 | r_elbow_flex_link.joint(name="r_elbow_flex_joint", type="hinge", pos="0 0 0", axis="0 1 0", range="-2.3213 0",
410 | damping="0.1")
411 |
412 | r_forearm_roll_link = r_elbow_flex_link.body(name='r_forearm_roll_link', pos=[0,0,0])
413 | r_forearm_roll_link.geom(name="fr", type="capsule", fromto="-0.1 0 0 0.1 0 0", size="0.02")
414 | r_forearm_roll_link.joint(name="r_forearm_roll_joint", type="hinge", limited="true", pos="0 0 0",
415 | axis="1 0 0", damping=".1", range="-1.5 1.5")
416 |
417 | r_forearm_link = r_forearm_roll_link.body(name='r_forearm_link', pos=[0,0,0])
418 | r_forearm_link.geom(name="fa", type="capsule", fromto="0 0 0 0.291 0 0", size="0.05")
419 |
420 | r_wrist_flex_link = r_forearm_link.body(name='r_wrist_flex_link', pos=[0.321,0,0])
421 | r_wrist_flex_link.geom(name="wf", type="capsule", fromto="0 -0.02 0 0 0.02 0", size="0.01" )
422 | r_wrist_flex_link.joint(name="r_wrist_flex_joint", type="hinge", pos="0 0 0", axis="0 1 0",
423 | range="-1.094 0", damping=".1")
424 |
425 | r_wrist_roll_link = r_wrist_flex_link.body(name='r_wrist_roll_link', pos=[0,0,0])
426 | r_wrist_roll_link.joint(name="r_wrist_roll_joint", type="hinge", pos="0 0 0", limited="true", axis="1 0 0",
427 | damping="0.1", range="-1.5 1.5")
428 | r_wrist_roll_link.geom(type="capsule",fromto="0 -0.1 0. 0.0 +0.1 0",size="0.02",contype="1",conaffinity="1")
429 | r_wrist_roll_link.geom(type="capsule",fromto="0 -0.1 0. 0.1 -0.1 0",size="0.02",contype="1",conaffinity="1")
430 | r_wrist_roll_link.geom(type="capsule",fromto="0 +0.1 0. 0.1 +0.1 0",size="0.02",contype="1",conaffinity="1")
431 |
432 | tips_arm = r_wrist_roll_link.body(name='tips_arm', pos=[0,0,0])
433 | tips_arm.geom(name="tip_arml",type="sphere",pos="0.1 -0.1 0.",size="0.01")
434 | tips_arm.geom(name="tip_armr",type="sphere",pos="0.1 0.1 0.",size="0.01")
435 |
436 | #object_ = worldbody.body(name="object", pos=[0.45, -0.05, -0.275])
437 | object_ = worldbody.body(name="object", pos=[0.0, 0.0, -0.275])
438 | #object_.geom(rgba="1 1 1 0",type="sphere",size="0.05 0.05 0.05",density="0.00001",conaffinity="0")
439 | object_.geom(rgba="1 1 1 1",type="cylinder",size="0.05 0.05 0.05",density="0.00001",conaffinity="0", contype=1)
440 | object_.joint(name="obj_slidey",type="slide",pos="0 0 0",axis="0 1 0",range="-10.3213 10.3",damping="0.5")
441 | object_.joint(name="obj_slidex",type="slide",pos="0 0 0",axis="1 0 0",range="-10.3213 10.3",damping="0.5")
442 |
443 | goal = worldbody.body(name='goal', pos=goal_pos)
444 | goal.geom(rgba="1 0 0 1",type="cylinder",size="0.08 0.001 0.1",density='0.00001',contype="0",conaffinity="0")
445 | goal.joint(name="goal_slidey",type="slide",pos="0 0 0",axis="0 1 0",range="-10.3213 10.3",damping="0.5")
446 | goal.joint(name="goal_slidex",type="slide",pos="0 0 0",axis="1 0 0",range="-10.3213 10.3",damping="0.5")
447 |
448 | actuator = mjcmodel.root.actuator()
449 | actuator.motor(joint="r_shoulder_pan_joint",ctrlrange=[-2.0,2.0], ctrllimited=True)
450 | actuator.motor(joint="r_shoulder_lift_joint",ctrlrange=[-2.0,2.0], ctrllimited=True)
451 | actuator.motor(joint="r_upper_arm_roll_joint",ctrlrange=[-2.0,2.0], ctrllimited=True)
452 | actuator.motor(joint="r_elbow_flex_joint",ctrlrange=[-2.0,2.0], ctrllimited=True)
453 | actuator.motor(joint="r_forearm_roll_joint",ctrlrange=[-2.0,2.0], ctrllimited=True)
454 | actuator.motor(joint="r_wrist_flex_joint",ctrlrange=[-2.0,2.0], ctrllimited=True)
455 | actuator.motor(joint="r_wrist_roll_joint",ctrlrange=[-2.0,2.0], ctrllimited=True)
456 |
457 | return mjcmodel
458 |
459 |
460 | def swimmer():
461 | mjcmodel = MJCModel('swimmer')
462 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local")
463 | mjcmodel.root.option(timestep=0.01, viscosity=0.1, density=4000, integrator="RK4", collision="predefined")
464 | default = mjcmodel.root.default()
465 | default.joint(armature=0.1)
466 | default.geom(rgba=[0.8, .6, .1, 1], condim=1, contype=1, conaffinity=1, material='geom')
467 |
468 | asset = mjcmodel.root.asset()
469 | asset.texture(builtin='gradient', height=100, rgb1=[1,1,1], rgb2=[0,0,0], type='skybox', width=100)
470 | asset.texture(builtin='flat', height=1278, mark='cross', markrgb=[1,1,1], name='texgeom',
471 | random=0.01, rgb1=[0.8,0.6,0.4], rgb2=[0.8,0.6,0.4], type='cube', width=127)
472 | asset.texture(builtin='checker', height=100, name='texplane',rgb1=[0,0,0], rgb2=[0.8,0.8,0.8], type='2d', width=100)
473 | asset.material(name='MatPlane', reflectance=0.5, shininess=1, specular=1, texrepeat=[30,30], texture='texplane')
474 | asset.material(name='geom', texture='texgeom', texuniform=True)
475 |
476 | worldbody = mjcmodel.root.worldbody()
477 | worldbody.light(cutoff=100, diffuse=[1,1,1], dir=[0,0,-1.3], directional=True, exponent=1, pos=[0,0,1.3], specular=[.1,.1,.1])
478 | worldbody.geom(conaffinity=1, condim=3, material='MatPlane', name='floor', pos=[0,0,-0.1], rgba=[0.8,0.9,0.9,1], size=[40,40,0.1], type='plane')
479 | torso = worldbody.body(name='torso', pos=[0,0,0])
480 | torso.geom(density=1000, fromto=[1.5,0,0,0.5,0,0], size=0.1, type='capsule')
481 | torso.joint(axis=[1,0,0], name='slider1', pos=[0,0,0], type='slide')
482 | torso.joint(axis=[0,1,0], name='slider2', pos=[0,0,0], type='slide')
483 | torso.joint(axis=[0,0,1], name='rot', pos=[0,0,0], type='hinge')
484 | mid = torso.body(name='mid', pos=[0.5,0,0])
485 | mid.geom(density=1000, fromto=[0,0,0,-1,0,0], size=0.1, type='capsule')
486 | mid.joint(axis=[0,0,1], limited=True, name='rot2', pos=[0,0,0], range=[-100,100], type='hinge')
487 | back = mid.body(name='back', pos=[-1,0,0])
488 | back.geom(density=1000, fromto=[0,0,0,-1,0,0], size=0.1, type='capsule')
489 | back.joint(axis=[0,0,1], limited=True, name='rot3', pos=[0,0,0], range=[-100,100], type='hinge')
490 |
491 | actuator = mjcmodel.root.actuator()
492 | actuator.motor(ctrllimited=True, ctrlrange=[-1,1], gear=150, joint='rot2')
493 | actuator.motor(ctrllimited=True, ctrlrange=[-1,1], gear=150, joint='rot3')
494 | return mjcmodel
495 |
496 | def swimmer_rllab():
497 | mjcmodel = MJCModel('swimmer')
498 | mjcmodel.root.compiler(inertiafromgeom="true", angle="degree", coordinate="local")
499 | mjcmodel.root.option(timestep=0.01, viscosity=0.1, density=4000, integrator="Euler", iterations=1000, collision="predefined")
500 |
501 | custom = mjcmodel.root.custom()
502 | custom.numeric(name='frame_skip', data=50)
503 |
504 | default = mjcmodel.root.default()
505 | #default.joint(armature=0.1)
506 | default.geom(rgba=[0.8, .6, .1, 1], condim=1, contype=1, conaffinity=1, material='geom')
507 |
508 | asset = mjcmodel.root.asset()
509 | asset.texture(builtin='gradient', height=100, rgb1=[1,1,1], rgb2=[0,0,0], type='skybox', width=100)
510 | asset.texture(builtin='flat', height=1278, mark='cross', markrgb=[1,1,1], name='texgeom',
511 | random=0.01, rgb1=[0.8,0.6,0.4], rgb2=[0.8,0.6,0.4], type='cube', width=127)
512 | asset.texture(builtin='checker', height=100, name='texplane',rgb1=[0,0,0], rgb2=[0.8,0.8,0.8], type='2d', width=100)
513 | asset.material(name='MatPlane', reflectance=0.5, shininess=1, specular=1, texrepeat=[30,30], texture='texplane')
514 | asset.material(name='geom', texture='texgeom', texuniform=True)
515 |
516 | worldbody = mjcmodel.root.worldbody()
517 | worldbody.light(cutoff=100, diffuse=[1,1,1], dir=[0,0,-1.3], directional=True, exponent=1, pos=[0,0,1.3], specular=[.1,.1,.1])
518 | worldbody.geom(conaffinity=1, condim=3, material='MatPlane', name='floor', pos=[0,0,-0.1], rgba=[0.8,0.9,0.9,1], size=[40,40,0.1], type='plane')
519 | torso = worldbody.body(name='torso', pos=[0,0,0])
520 | torso.geom(density=1000, fromto=[1.5,0,0,0.5,0,0], size=0.1, type='capsule')
521 | torso.joint(axis=[1,0,0], name='slider1', pos=[0,0,0], type='slide')
522 | torso.joint(axis=[0,1,0], name='slider2', pos=[0,0,0], type='slide')
523 | torso.joint(axis=[0,0,1], name='rot', pos=[0,0,0], type='hinge')
524 | mid = torso.body(name='mid', pos=[0.5,0,0])
525 | mid.geom(density=1000, fromto=[0,0,0,-1,0,0], size=0.1, type='capsule')
526 | mid.joint(axis=[0,0,1], limited=True, name='rot2', pos=[0,0,0], range=[-100,100], type='hinge')
527 | back = mid.body(name='back', pos=[-1,0,0])
528 | back.geom(density=1000, fromto=[0,0,0,-1,0,0], size=0.1, type='capsule')
529 | back.joint(axis=[0,0,1], limited=True, name='rot3', pos=[0,0,0], range=[-100,100], type='hinge')
530 |
531 | actuator = mjcmodel.root.actuator()
532 | actuator.motor(ctrllimited=True, ctrlrange=[-50,50], joint='rot2')
533 | actuator.motor(ctrllimited=True, ctrlrange=[-50,50], joint='rot3')
534 | return mjcmodel
535 |
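These factories are consumed by the environments in this package by writing the model to a temporary file (model.asfile()) and handing the file path to MujocoEnv. A minimal sketch (not part of the repository) of inspecting the generated XML directly:

from inverse_rl.envs.dynamic_mjc.mjc_models import point_mass_maze, RIGHT

model = point_mass_maze(direction=RIGHT, length=1.2, borders=True)
with model.asfile() as f:      # temporary .xml file containing the generated model
    xml = f.read()
print(xml[:120])               # begins with <mujoco model="twod_maze"> ...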
--------------------------------------------------------------------------------
/inverse_rl/envs/dynamic_mjc/model_builder.py:
--------------------------------------------------------------------------------
1 | """
2 | model_builder.py
3 | A small library for programmatically building MuJoCo XML files
4 | """
5 | from contextlib import contextmanager
6 | import tempfile
7 | import numpy as np
8 |
9 |
10 | def default_model(name):
11 | """
12 | Get a model with basic settings such as gravity and RK4 integration enabled
13 | """
14 | model = MJCModel(name)
15 | root = model.root
16 |
17 | # Setup
18 | root.compiler(angle="radian", inertiafromgeom="true")
19 | default = root.default()
20 | default.joint(armature=1, damping=1, limited="true")
21 | default.geom(contype=0, friction='1 0.1 0.1', rgba='0.7 0.7 0 1')
22 | root.option(gravity="0 0 -9.81", integrator="RK4", timestep=0.01)
23 | return model
24 |
25 | def pointmass_model(name):
26 | """
27 | Get a point-mass model with basic settings (zero gravity, Euler integration) enabled
28 | """
29 | model = MJCModel(name)
30 | root = model.root
31 |
32 | # Setup
33 | root.compiler(angle="radian", inertiafromgeom="true", coordinate="local")
34 | default = root.default()
35 | default.joint(limited="false", damping=1)
36 | default.geom(contype=2, conaffinity="1", condim="1", friction=".5 .1 .1", density="1000", margin="0.002")
37 | root.option(timestep=0.01, gravity="0 0 0", iterations="20", integrator="Euler")
38 | return model
39 |
40 |
41 | class MJCModel(object):
42 | def __init__(self, name):
43 | self.name = name
44 | self.root = MJCTreeNode("mujoco").add_attr('model', name)
45 |
46 | @contextmanager
47 | def asfile(self):
48 | """
49 | Usage:
50 | model = MJCModel('reacher')
51 | with model.asfile() as f:
52 | print(f.read())  # prints a dump of the model
53 | """
54 | with tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True) as f:
55 | self.root.write(f)
56 | f.seek(0)
57 | yield f
58 |
59 | def open(self):
60 | self.file = tempfile.NamedTemporaryFile(mode='w+', suffix='.xml', delete=True)
61 | self.root.write(self.file)
62 | self.file.seek(0)
63 | return self.file
64 |
65 | def close(self):
66 | self.file.close()
67 |
68 | def find_attr(self, attr, value):
69 | return self.root.find_attr(attr, value)
70 |
71 | def __getstate__(self):
72 | return {}
73 |
74 | def __setstate__(self, state):
75 | pass
76 |
77 |
78 | class MJCTreeNode(object):
79 | def __init__(self, name):
80 | self.name = name
81 | self.attrs = {}
82 | self.children = []
83 |
84 | def add_attr(self, key, value):
85 | if isinstance(value, str):
86 | pass
87 | elif isinstance(value, list) or isinstance(value, np.ndarray):
88 | value = ' '.join([str(val).lower() for val in value])
89 | else:
90 | value = str(value).lower()
91 |
92 | self.attrs[key] = value
93 | return self
94 |
95 | def __getattr__(self, name):
96 | def wrapper(**kwargs):
97 | newnode = MJCTreeNode(name)
98 | for (k, v) in kwargs.items():
99 | newnode.add_attr(k, v)
100 | self.children.append(newnode)
101 | return newnode
102 | return wrapper
103 |
104 | def dfs(self):
105 | yield self
106 | if self.children:
107 | for child in self.children:
108 | for node in child.dfs():
109 | yield node
110 |
111 | def find_attr(self, attr, value):
112 | """ Run DFS to find a matching attr """
113 | if attr in self.attrs and self.attrs[attr] == value:
114 | return self
115 | for child in self.children:
116 | res = child.find_attr(attr, value)
117 | if res is not None:
118 | return res
119 | return None
120 |
121 |
122 | def write(self, ostream, tabs=0):
123 | contents = ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()])
124 | if self.children:
125 | ostream.write('\t'*tabs)
126 | ostream.write('<%s %s>\n' % (self.name, contents))
127 | for child in self.children:
128 | child.write(ostream, tabs=tabs+1)
129 | ostream.write('\t'*tabs)
130 | ostream.write('</%s>\n' % self.name)
131 | else:
132 | ostream.write('\t'*tabs)
133 | ostream.write('<%s %s/>\n' % (self.name, contents))
134 |
135 | def __str__(self):
136 | s = "<"+self.name
137 | s += ' '.join(['%s="%s"'%(k,v) for (k,v) in self.attrs.items()])
138 | return s+">"
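A minimal sketch (not part of the repository) of building a model from scratch with MJCModel: element names such as worldbody, body, geom and joint are not predefined methods but are created on demand by MJCTreeNode.__getattr__ and serialized by write():

from inverse_rl.envs.dynamic_mjc.model_builder import MJCModel

model = MJCModel('ball_demo')
model.root.compiler(angle='radian', inertiafromgeom='true')
world = model.root.worldbody()
ball = world.body(name='ball', pos=[0, 0, 0.1])
ball.geom(type='sphere', size=0.05, rgba=[0, 0, 1, 1])
ball.joint(name='ball_z', type='slide', axis=[0, 0, 1])

with model.asfile() as f:
    print(f.read())                      # <mujoco model="ball_demo"> ... </mujoco>

print(model.find_attr('name', 'ball'))   # DFS lookup of the body defined above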
--------------------------------------------------------------------------------
/inverse_rl/envs/env_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import namedtuple
3 |
4 | import numpy as np
5 | from gym import Wrapper
6 | import gym
7 |
8 | from rllab.core.serializable import Serializable
9 | from rllab.envs.base import Env, Step
10 | from rllab.envs.gym_env import GymEnv, FixedIntervalVideoSchedule, NoVideoSchedule, CappedCubicVideoSchedule, \
11 | convert_gym_space
12 | from rllab.envs.proxy_env import ProxyEnv
13 | from rllab.misc import logger
14 | from rllab.misc.overrides import overrides
15 |
16 | ENV_ASSET_DIR = os.path.join(os.path.dirname(__file__), 'assets')
17 |
18 |
19 | def get_asset_xml(xml_name):
20 | return os.path.join(ENV_ASSET_DIR, xml_name)
21 |
22 |
23 | class RllabGymEnv(Env, Serializable):
24 | def __init__(self, env_name, wrappers=(), wrapper_args=(),
25 | record_video=True, video_schedule=None, log_dir=None, record_log=True,
26 | post_create_env_seed=None,
27 | force_reset=False):
28 | if log_dir is None:
29 | if logger.get_snapshot_dir() is None:
30 | logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
31 | else:
32 | log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
33 | Serializable.quick_init(self, locals())
34 |
35 | env = gym.envs.make(env_name)
36 | if post_create_env_seed is not None:
37 | env.set_env_seed(post_create_env_seed)
38 | for i, wrapper in enumerate(wrappers):
39 | if wrapper_args and len(wrapper_args) == len(wrappers):
40 | env = wrapper(env, **wrapper_args[i])
41 | else:
42 | env = wrapper(env)
43 | self.env = env
44 | self.env_id = env.spec.id
45 |
46 | assert not (not record_log and record_video)
47 |
48 | if log_dir is None or record_log is False:
49 | self.monitoring = False
50 | else:
51 | if not record_video:
52 | video_schedule = NoVideoSchedule()
53 | else:
54 | if video_schedule is None:
55 | video_schedule = CappedCubicVideoSchedule()
56 | self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
57 | self.monitoring = True
58 |
59 | self._observation_space = convert_gym_space(env.observation_space)
60 | logger.log("observation space: {}".format(self._observation_space))
61 | self._action_space = convert_gym_space(env.action_space)
62 | logger.log("action space: {}".format(self._action_space))
63 | self._horizon = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
64 | self._log_dir = log_dir
65 | self._force_reset = force_reset
66 |
67 | @property
68 | def observation_space(self):
69 | return self._observation_space
70 |
71 | @property
72 | def action_space(self):
73 | return self._action_space
74 |
75 | @property
76 | def horizon(self):
77 | return self._horizon
78 |
79 | def reset(self):
80 | if self._force_reset and self.monitoring:
81 | from gym.wrappers.monitoring import Monitor
82 | assert isinstance(self.env, Monitor)
83 | recorder = self.env.stats_recorder
84 | if recorder is not None:
85 | recorder.done = True
86 | return self.env.reset()
87 |
88 | def step(self, action):
89 | next_obs, reward, done, info = self.env.step(action)
90 | return Step(next_obs, reward, done, **info)
91 |
92 | def render(self):
93 | self.env.render()
94 |
95 | def terminate(self):
96 | if self.monitoring:
97 | self.env._close()
98 | if self._log_dir is not None:
99 | print("""
100 | ***************************
101 |
102 | Training finished! You can upload results to OpenAI Gym by running the following command:
103 |
104 | python scripts/submit_gym.py %s
105 |
106 | ***************************
107 | """ % self._log_dir)
108 |
109 | class CustomGymEnv(RllabGymEnv):
110 | def __init__(self, env_name, gym_wrappers=(),
111 | register_fn=None, wrapper_args = (), record_log=False, record_video=False,
112 | post_create_env_seed=None):
113 | Serializable.quick_init(self, locals())
114 | if register_fn is None:
115 | import inverse_rl.envs
116 | register_fn = inverse_rl.envs.register_custom_envs
117 | register_fn() # Force register
118 | self.env_name = env_name
119 | super(CustomGymEnv, self).__init__(env_name, wrappers=gym_wrappers,
120 | wrapper_args=wrapper_args,
121 | record_log=record_log, record_video=record_video,
122 | post_create_env_seed=post_create_env_seed,
123 | video_schedule=FixedIntervalVideoSchedule(50))
124 |
125 | def _get_obs(self):
126 | return self.env._get_obs()
127 |
128 | @overrides
129 | def log_diagnostics(self, paths):
130 | get_inner_env(self.env).log_diagnostics(paths)
131 |
132 | @overrides
133 | def plot_trajs(self, paths, **kwargs):
134 | if hasattr(self.env, 'plot_trajs'):
135 | self.env.plot_trajs(paths, **kwargs)
136 | else:
137 | raise ValueError('Env %s has no traj plotting'%self.env)
138 |
139 | @overrides
140 | def plot_costs(self, *args, **kwargs):
141 | if hasattr(self.env, 'plot_costs'):
142 | self.env.plot_costs(*args, **kwargs)
143 | else:
144 | raise ValueError('Env %s has no cost plotting' % self.env)
145 |
146 | @property
147 | def wrapped_observation_space(self):
148 | if hasattr(self.env, 'wrapped_observation_space'):
149 | return self.env.wrapped_observation_space
150 | else:
151 | raise AttributeError()
152 |
153 | def get_param_values(self):
154 | return None
155 |
156 | def set_param_values(self, params):
157 | pass
158 |
159 | def get_viewer(self):
160 | return self.env._get_viewer()
161 |
162 |
163 | def get_inner_env(env):
164 | if isinstance(env, ProxyEnv):
165 | return get_inner_env(env.wrapped_env)
166 | elif isinstance(env, GymEnv):
167 | return get_inner_env(env.env)
168 | elif isinstance(env, Wrapper):
169 | return get_inner_env(env.env)
170 | return env
171 |
172 |
173 | def test_env(env, T=100):
174 | aspace = env.action_space
175 | env.reset()
176 | for t in range(T):
177 | o, r, done, infos = env.step(aspace.sample())
178 | print('---T=%d---' % t)
179 | print('rew:', r)
180 | print('obs:', o)
181 | env.render()
182 | if done:
183 | break
184 |
185 |
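A minimal sketch (not part of the repository) of the two standalone helpers above; 'Pendulum-v0' is only an example id from the gym versions this code targets:

import gym
from inverse_rl.envs.env_utils import get_inner_env, test_env

env = gym.make('Pendulum-v0')        # usually comes back wrapped (e.g. in a TimeLimit Wrapper)
inner = get_inner_env(env)           # recursively unwraps ProxyEnv / GymEnv / gym.Wrapper layers
print(type(env).__name__, '->', type(inner).__name__)

test_env(inner, T=5)                 # random actions; prints reward and observation each step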
--------------------------------------------------------------------------------
/inverse_rl/envs/point_maze_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from gym.envs.mujoco import mujoco_env
4 |
5 | from rllab.misc import logger
6 |
7 | from inverse_rl.envs.dynamic_mjc.mjc_models import point_mass_maze
8 |
9 |
10 | class PointMazeEnv(mujoco_env.MujocoEnv, utils.EzPickle):
11 | def __init__(self, direction=1, maze_length=0.6, sparse_reward=False, no_reward=False, episode_length=100):
12 | utils.EzPickle.__init__(self)
13 | self.sparse_reward = sparse_reward
14 | self.no_reward = no_reward
15 | self.max_episode_length = episode_length
16 | self.direction = direction
17 | self.length = maze_length
18 |
19 | self.episode_length = 0
20 |
21 | model = point_mass_maze(direction=self.direction, length=self.length)
22 | with model.asfile() as f:
23 | mujoco_env.MujocoEnv.__init__(self, f.name, 5)
24 |
25 | def _step(self, a):
26 | vec_dist = self.get_body_com("particle") - self.get_body_com("target")
27 |
28 | reward_dist = - np.linalg.norm(vec_dist) # particle to target
29 | reward_ctrl = - np.square(a).sum()
30 | if self.no_reward:
31 | reward = 0
32 | elif self.sparse_reward:
33 | if -reward_dist <= 0.1:  # reward_dist is a negative distance; reward only near the target
34 | reward = 1
35 | else:
36 | reward = 0
37 | else:
38 | reward = reward_dist + 0.001 * reward_ctrl
39 |
40 | self.do_simulation(a, self.frame_skip)
41 | ob = self._get_obs()
42 | self.episode_length += 1
43 | done = self.episode_length >= self.max_episode_length
44 | return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
45 |
46 | def viewer_setup(self):
47 | self.viewer.cam.trackbodyid = -1
48 | self.viewer.cam.distance = 4.0
49 |
50 | def reset_model(self):
51 | qpos = self.init_qpos
52 | self.episode_length = 0
53 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
54 | self.set_state(qpos, qvel)
55 | self.episode_length = 0
56 | return self._get_obs()
57 |
58 | def _get_obs(self):
59 | return np.concatenate([
60 | self.get_body_com("particle"),
61 | #self.get_body_com("target"),
62 | ])
63 |
64 | def plot_trajs(self, *args, **kwargs):
65 | pass
66 |
67 | def log_diagnostics(self, paths):
68 | rew_dist = np.array([traj['env_infos']['reward_dist'] for traj in paths])
69 | rew_ctrl = np.array([traj['env_infos']['reward_ctrl'] for traj in paths])
70 |
71 | logger.record_tabular('AvgObjectToGoalDist', -np.mean(rew_dist.mean()))
72 | logger.record_tabular('AvgControlCost', -np.mean(rew_ctrl.mean()))
73 | logger.record_tabular('AvgMinToGoalDist', np.mean(np.min(-rew_dist, axis=1)))
74 |
75 |
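A minimal sketch (not part of the repository) of the reward variants exposed by the constructor flags above; assumes MuJoCo is available:

from inverse_rl.envs.point_maze_env import PointMazeEnv

# Dense reward: negative particle-to-target distance plus a small control penalty.
dense_env = PointMazeEnv(direction=1, maze_length=0.6, sparse_reward=False)
# Sparse variant: reward determined by the sparse_reward branch in _step above.
sparse_env = PointMazeEnv(direction=1, maze_length=0.6, sparse_reward=True)

for label, env in (('dense', dense_env), ('sparse', sparse_env)):
    env.reset()
    ob, rew, done, info = env.step(env.action_space.sample())
    print(label, 'reward:', rew, 'dist to target:', -info['reward_dist'])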
--------------------------------------------------------------------------------
/inverse_rl/envs/pusher_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 | from gym.envs.mujoco import mujoco_env
4 |
5 | #import mujoco_py
6 | #from mujoco_py.mjlib import mjlib
7 | from rllab.misc import logger
8 |
9 | from inverse_rl.envs.dynamic_mjc.mjc_models import pusher
10 |
11 | class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
12 | def __init__(self, sparse_reward=False, no_reward=False, episode_length=200):
13 | utils.EzPickle.__init__(self)
14 | self.sparse_reward = sparse_reward
15 | self.no_reward = no_reward
16 | self.max_episode_length = episode_length
17 | self.goal_pos = np.asarray([0.0, 0.0])
18 |
19 | self.episode_length = 0
20 |
21 | model = pusher(goal_pos=[self.goal_pos[0], self.goal_pos[1], -.323])
22 | with model.asfile() as f:
23 | mujoco_env.MujocoEnv.__init__(self, f.name, 5)
24 |
25 |
26 | def _step(self, a):
27 | vec_1 = self.get_body_com("object") - self.get_body_com("tips_arm")
28 | vec_2 = self.get_body_com("object") - self.get_body_com("goal")
29 | #print('pre_step:', self.get_body_com('object'), self.get_body_com("goal"))
30 |
31 | reward_near = - np.linalg.norm(vec_1) # arm to object
32 | reward_dist = - np.linalg.norm(vec_2[0:2]) # object to goal
33 | reward_ctrl = - np.square(a).sum()
34 | if self.no_reward:
35 | reward = 0
36 | elif self.sparse_reward:
37 | reward = reward_dist + 0.1 * reward_ctrl
38 | else:
39 | reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near
40 | #print(reward_near)
41 | #if (-reward_near ) <= 0.1:
42 | # reward = reward_dist
43 | #else:
44 | # reward = reward_dist + 0.5*reward_near
45 |
46 | #reward += 0.2 * reward_ctrl
47 |
48 | self.do_simulation(a, self.frame_skip)
49 | ob = self._get_obs()
50 | self.episode_length += 1
51 | done = self.episode_length >= self.max_episode_length
52 | #print('post_step:', self.get_body_com('object'), self.get_body_com("goal"))
53 | return ob, reward, done, dict(reward_dist=reward_dist, reward_near=reward_near,
54 | reward_ctrl=reward_ctrl)
55 |
56 | def viewer_setup(self):
57 | self.viewer.cam.trackbodyid = -1
58 | self.viewer.cam.distance = 4.0
59 |
60 | def reset_model(self):
61 | qpos = self.init_qpos
62 | self.episode_length = 0
63 |
64 | while True:
65 | self.cylinder_pos = np.concatenate([
66 | self.np_random.uniform(low=-0.5, high=0, size=1),
67 | self.np_random.uniform(low=-0.5, high=0.5, size=1)])
68 | #self.cylinder_pos = self.np_random.uniform(low=-1.5, high=1.5, size=2)
69 | cyl_dist = np.linalg.norm(self.cylinder_pos - self.goal_pos)
70 | if cyl_dist > 0.2 and cyl_dist < 0.4:
71 | break
72 |
73 | qpos[-4:-2] = self.cylinder_pos
74 | #qpos[-2:] = self.goal_pos
75 | qpos[-2] = self.goal_pos[0]
76 | qpos[-1] = self.goal_pos[1]
77 | qvel = self.init_qvel + self.np_random.uniform(low=-0.005,
78 | high=0.005, size=self.model.nv)
79 | qvel[-4:] = 0
80 | #print('qpos_pre:', self.model.data.qpos.flat[-4:])
81 | #print('qpos_pre_bodycom:', self.get_body_com('object'), self.get_body_com("goal"))
82 | self.set_state(qpos, qvel)
83 | #print('qpos_post:', self.model.data.qpos.flat[-4:])
84 | #print('qpos_post_bodycom:', self.get_body_com('object'), self.get_body_com("goal"))
85 | return self._get_obs()
86 |
87 | def _get_obs(self):
88 | return np.concatenate([
89 | self.model.data.qpos.flat[:7],
90 | self.model.data.qvel.flat[:7],
91 | self.get_body_com("tips_arm"),
92 | self.get_body_com("object"),
93 | self.get_body_com("goal"),
94 | ])
95 |
96 | def plot_trajs(self, *args, **kwargs):
97 | pass
98 |
99 | def log_diagnostics(self, paths):
100 | rew_near = np.array([traj['env_infos']['reward_near'] for traj in paths])
101 | rew_dist = np.array([traj['env_infos']['reward_dist'] for traj in paths])
102 | rew_ctrl = np.array([traj['env_infos']['reward_ctrl'] for traj in paths])
103 |
104 | logger.record_tabular('AvgArmToObjectDist', -np.mean(rew_near))
105 | logger.record_tabular('AvgObjectToGoalDist', -np.mean(rew_dist))
106 | logger.record_tabular('AvgControlCost', -np.mean(rew_ctrl))
107 |
108 |
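reset_model above rejection-samples the cylinder position until it starts between 0.2 and 0.4 away from the goal; a small check (not part of the repository) of that reset distribution:

import numpy as np
from inverse_rl.envs.pusher_env import PusherEnv

env = PusherEnv()
dists = []
for _ in range(20):
    env.reset()
    dists.append(np.linalg.norm(env.cylinder_pos - env.goal_pos))
print(min(dists), max(dists))    # both should fall inside (0.2, 0.4)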
--------------------------------------------------------------------------------
/inverse_rl/envs/twod_maze.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gym import utils
3 |
4 | from inverse_rl.envs.env_utils import get_asset_xml
5 | from inverse_rl.envs.twod_mjc_env import TwoDEnv
6 |
7 | from rllab.misc import logger as logger
8 |
9 | INIT_POS = np.array([0.15,0.15])
10 | TARGET = np.array([0.15, -0.15])
11 | DIST_THRESH = 0.12
12 |
13 | class TwoDMaze(TwoDEnv, utils.EzPickle):
14 | def __init__(self, verbose=False):
15 | self.verbose = verbose
16 | self.max_episode_length = 200
17 | self.episode_length = 0
18 | utils.EzPickle.__init__(self)
19 | TwoDEnv.__init__(self, get_asset_xml('twod_maze.xml'), 2, xbounds=[-0.3,0.3], ybounds=[-0.3,0.3])
20 |
21 | def _step(self, a):
22 | self.do_simulation(a, self.frame_skip)
23 | ob = self._get_obs()
24 | pos = ob[0:2]
25 | dist = np.sum(np.abs(pos-TARGET)) #np.linalg.norm(pos - TARGET)
26 | reward = - (dist)
27 |
28 | reward_ctrl = - np.square(a).sum()
29 | reward += 1e-3 * reward_ctrl
30 |
31 | if self.verbose:
32 | print(pos, reward)
33 | self.episode_length += 1
34 | done = self.episode_length >= self.max_episode_length
35 | return ob, reward, done, {'distance': dist}
36 |
37 | def reset_model(self):
38 | self.episode_length = 0
39 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
40 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
41 | self.set_state(qpos, qvel)
42 | return self._get_obs()
43 |
44 | def _get_obs(self):
45 | #return np.concatenate([self.model.data.qpos, self.model.data.qvel]).ravel()
46 | return np.concatenate([self.model.data.qpos]).ravel() - INIT_POS
47 |
48 | def viewer_setup(self):
49 | v = self.viewer
50 | #v.cam.trackbodyid=0
51 | #v.cam.distance = v.model.stat.extent
52 |
53 | def log_diagnostics(self, paths):
54 | rew_dist = np.array([traj['env_infos']['distance'] for traj in paths])
55 |
56 | logger.record_tabular('AvgObjectToGoalDist', np.mean(rew_dist))
57 | logger.record_tabular('MinAvgObjectToGoalDist', np.mean(np.min(rew_dist, axis=1)))
58 |
59 |
60 |
61 | if __name__ == "__main__":
62 | from inverse_rl.utils.getch import getKey
63 | env = TwoDMaze(verbose=True)
64 |
65 | while True:
66 | key = getKey()
67 | a = np.array([0.0,0.0])
68 | if key == 'w':
69 | a += np.array([0.0, 1.0])
70 | elif key == 'a':
71 | a += np.array([-1.0, 0.0])
72 | elif key == 's':
73 | a += np.array([0.0, -1.0])
74 | elif key == 'd':
75 | a += np.array([1.0, 0.0])
76 | elif key == 'q':
77 | break
78 | a *= 0.2
79 | env.step(a)
80 | env.render()
--------------------------------------------------------------------------------
/inverse_rl/envs/twod_mjc_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rllab.core.serializable import Serializable
3 |
4 | from gym.envs.mujoco import mujoco_env
5 | from gym.spaces import Box
6 |
7 | from rllab.envs.base import Env
8 | from rllab.misc.overrides import overrides
9 |
10 |
11 | class MapConfig(Serializable):
12 | def __init__(self, xs, ys, xres, yres):
13 | Serializable.quick_init(self, locals())
14 | self.xs = xs
15 | self.ys = ys
16 | self.xres = xres
17 | self.yres=yres
18 |
19 | def map_config(xs=(-0.3,0.3), ys=(-0.3,0.3), xres=50, yres=50):
20 | return MapConfig(xs,ys,xres,yres)
21 |
22 | def make_heat_map(eval_func, map_config):
23 | gps = get_dense_gridpoints(map_config)
24 | vals = np.zeros(map_config.xres*map_config.yres)
25 | for i, pnt in enumerate(gps):
26 | vals[i] = eval_func(pnt)
27 | return predictions_to_heatmap(vals, map_config)
28 |
29 |
30 | def get_dense_gridpoints(map_config):
31 | xl = np.linspace(map_config.xs[0], map_config.xs[1], num=map_config.xres)
32 | yl = np.linspace(map_config.ys[0], map_config.ys[1], num=map_config.yres)
33 | gridpoints = np.zeros((map_config.xres*map_config.yres, 2))
34 | for i in range(map_config.xres):
35 | for j in range(map_config.yres):
36 | gridpoints[i+map_config.xres*j] = np.array((xl[i], yl[j]))
37 | return gridpoints
38 |
39 |
40 | def predictions_to_heatmap(predictions, map_config):
41 | map = np.zeros((map_config.xres, map_config.yres))
42 | for i in range(map_config.xres):
43 | for j in range(map_config.yres):
44 | map[i,j] = predictions[i+map_config.xres*j]
45 | map = map/np.max(map)
46 | return map.T
47 |
48 | def make_density_map(paths, map_config):
49 | xs = np.linspace(map_config.xs[0], map_config.xs[1], num=map_config.xres+1)
50 | ys = np.linspace(map_config.ys[0], map_config.ys[1], num=map_config.yres+1)
51 | y = paths[:,0]
52 | x = paths[:,1]
53 | H, xedges, yedges = np.histogram2d(y, x, bins=(xs, ys))
54 | H = H.astype(np.float)
55 | H = H/np.max(H)
56 | return H.T
57 |
58 | def plot_maps(combined_list=None, *heatmaps):
59 | import matplotlib.pyplot as plt
60 | combined = np.c_[heatmaps]
61 | if combined_list is not None:
62 | combined_list.append(combined)
63 | combined = np.concatenate(combined_list)
64 | else:
65 | combined_list = []
66 | plt.figure()
67 | plt.imshow(combined, cmap='afmhot', interpolation='none')
68 | plt.show()
69 | return combined_list
70 |
71 |
72 | class TwoDEnv(mujoco_env.MujocoEnv):
73 | def __init__(self, model_path, frame_skip, xbounds, ybounds):
74 | super(TwoDEnv, self).__init__(model_path=model_path, frame_skip=frame_skip)
75 | assert isinstance(self.observation_space, Box)
76 | assert self.observation_space.shape == (2,)
77 | self.__map_config = map_config(xs=(xbounds[0], xbounds[1]),
78 | ys=(ybounds[0], ybounds[1]))
79 |
80 | @property
81 | def map_config(self):
82 | return self.__map_config
83 |
84 | @property
85 | def grid_flat_dim(self):
86 | return self.map_config.xres*self.map_config.yres
87 |
88 | def make_density_map(self, paths):
89 | return make_density_map(paths, self.map_config)
90 |
91 | def make_heatmap(self, eval_func):
92 | return make_heat_map(eval_func, self.map_config)
93 |
94 | def get_dense_gridpoints(self):
95 | return get_dense_gridpoints(self.map_config)
96 |
97 | def predictions_to_heatmap(self, predictions):
98 | return predictions_to_heatmap(predictions, self.map_config)
99 |
100 | def get_viewer(self):
101 | return self._get_viewer()
102 |
103 | @overrides
104 | def log_diagnostics(self, paths):
105 | pass
106 |
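A minimal sketch (not part of the repository) of the grid helpers above; the random positions stand in for visited 2D states collected from rollouts:

import numpy as np
from inverse_rl.envs.twod_mjc_env import map_config, make_density_map, make_heat_map

cfg = map_config(xs=(-0.3, 0.3), ys=(-0.3, 0.3), xres=50, yres=50)

# Visitation density from a batch of 2D positions (one row per point).
positions = np.random.uniform(-0.3, 0.3, size=(500, 2))
density = make_density_map(positions, cfg)                 # normalized 2D histogram

# Heatmap of an arbitrary scalar function evaluated over the dense grid.
heat = make_heat_map(lambda pnt: -np.linalg.norm(pnt), cfg)
print(density.shape, heat.shape)                           # both 50 x 50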
--------------------------------------------------------------------------------
/inverse_rl/envs/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def flat_to_one_hot(val, ndim):
4 | """
5 |
6 | >>> flat_to_one_hot(2, ndim=4)
7 | array([ 0., 0., 1., 0.])
8 | >>> flat_to_one_hot(4, ndim=5)
9 | array([ 0., 0., 0., 0., 1.])
10 | >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5)
11 | array([[ 0., 0., 1., 0., 0.],
12 | [ 0., 0., 0., 0., 1.],
13 | [ 0., 0., 0., 1., 0.]])
14 | """
15 |     shape = np.array(val).shape
16 | v = np.zeros(shape + (ndim,))
17 | if len(shape) == 1:
18 | v[np.arange(shape[0]), val] = 1.0
19 | else:
20 | v[val] = 1.0
21 | return v
22 |
23 | def one_hot_to_flat(val):
24 | """
25 | >>> one_hot_to_flat(np.array([0,0,0,0,1]))
26 | 4
27 | >>> one_hot_to_flat(np.array([0,0,1,0]))
28 | 2
29 | >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]]))
30 | array([2, 0, 1])
31 | """
32 | idxs = np.array(np.where(val == 1.0))[-1]
33 | if len(val.shape) == 1:
34 | return int(idxs)
35 | return idxs
--------------------------------------------------------------------------------
/inverse_rl/envs/visual_pointmass.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from gym import utils
4 | from gym.envs.mujoco import mujoco_env
5 | from gym.spaces import Box
6 |
7 | from inverse_rl.envs.dynamic_mjc.mjc_models import point_mass_maze
8 | from inverse_rl.envs.env_utils import get_asset_xml
9 | from inverse_rl.envs.twod_mjc_env import TwoDEnv
10 |
11 | from rllab.misc import logger as logger
12 |
13 | INIT_POS = np.array([0.15,0.15])
14 | TARGET = np.array([0.15, -0.15])
15 | DIST_THRESH = 0.12
16 |
17 | class VisualTwoDMaze(mujoco_env.MujocoEnv, utils.EzPickle):
18 | def __init__(self, verbose=False, width=64, height=64):
19 | self.verbose = verbose
20 | self.max_episode_length = 200
21 | self.episode_length = 0
22 | self.width = width
23 | self.height = height
24 | utils.EzPickle.__init__(self)
25 | super(VisualTwoDMaze, self).__init__(get_asset_xml('twod_maze.xml'), frame_skip=2)
26 |
27 | # calculate image dimensions
28 | #self._get_viewer().render()
29 | #data, width, height = self._get_viewer().get_image()
30 | self.observation_space = Box(0, 1, shape=(width, height, 3))
31 |
32 | def _step(self, a):
33 | self.do_simulation(a, self.frame_skip)
34 | state = self._get_state()
35 | pos = state[0:2]
36 | dist = np.sum(np.abs(pos-TARGET)) #np.linalg.norm(pos - TARGET)
37 | reward = - (dist)
38 |
39 | reward_ctrl = - np.square(a).sum()
40 | reward += 1e-3 * reward_ctrl
41 |
42 | if self.verbose:
43 | print(pos, reward)
44 | self.episode_length += 1
45 | done = self.episode_length >= self.max_episode_length
46 |
47 | ob = self._get_obs()
48 | return ob, reward, done, {'distance': dist}
49 |
50 | def reset_model(self):
51 | self.episode_length = 0
52 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
53 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
54 | self.set_state(qpos, qvel)
55 | return self._get_obs()
56 |
57 | def _get_state(self):
58 | #return np.concatenate([self.model.data.qpos, self.model.data.qvel]).ravel()
59 | return np.concatenate([self.model.data.qpos]).ravel() - INIT_POS
60 |
61 | def _get_obs(self):
62 | self._get_viewer().render()
63 | data, width, height = self._get_viewer().get_image()
64 | image = np.fromstring(data, dtype='uint8').reshape(height, width, 3)[::-1, :, :]
65 | # reshape
66 |         if getattr(self, 'grayscale', False):  # VisualTwoDMaze never sets self.grayscale; default to RGB
67 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
68 | image = cv2.resize(image, (self.width, self.height), interpolation=cv2.INTER_AREA)
69 |
70 | # rescale image to float
71 | image = image.astype(np.float32)/255.0
72 | return image
73 |
74 | def viewer_setup(self):
75 | #v = self.viewer
76 | self.viewer.cam.trackbodyid = -1
77 | self.viewer.cam.distance = 1.0
78 | #v.cam.trackbodyid=0
79 | #v.cam.distance = v.model.stat.extent
80 |
81 | def log_diagnostics(self, paths):
82 | rew_dist = np.array([traj['env_infos']['distance'] for traj in paths])
83 | logger.record_tabular('AvgObjectToGoalDist', np.mean(rew_dist))
84 | logger.record_tabular('MinAvgObjectToGoalDist', np.mean(np.min(rew_dist, axis=1)))
85 |
86 |
87 | class VisualPointMazeEnv(mujoco_env.MujocoEnv, utils.EzPickle):
88 | def __init__(self, direction=1, maze_length=0.6,
89 | sparse_reward=False, no_reward=False, episode_length=100, grayscale=True,
90 | width=64, height=64):
91 | utils.EzPickle.__init__(self)
92 | self.sparse_reward = sparse_reward
93 | self.no_reward = no_reward
94 | self.max_episode_length = episode_length
95 | self.direction = direction
96 | self.length = maze_length
97 |
98 | self.width = width
99 | self.height = height
100 |         self.grayscale = grayscale
101 |
102 | self.episode_length = 0
103 |
104 | model = point_mass_maze(direction=self.direction, length=self.length, borders=False)
105 | with model.asfile() as f:
106 | mujoco_env.MujocoEnv.__init__(self, f.name, 5)
107 |
108 | if self.grayscale:
109 | self.observation_space = Box(0, 1, shape=(width, height))
110 | else:
111 | self.observation_space = Box(0, 1, shape=(width, height, 3))
112 |
113 | def _step(self, a):
114 | vec_dist = self.get_body_com("particle") - self.get_body_com("target")
115 |
116 | reward_dist = - np.linalg.norm(vec_dist) # particle to target
117 | reward_ctrl = - np.square(a).sum()
118 | if self.no_reward:
119 | reward = 0
120 | elif self.sparse_reward:
121 |             if -reward_dist <= 0.1:  # reward_dist is a negative distance, so compare the distance itself
122 | reward = 1
123 | else:
124 | reward = 0
125 | else:
126 | reward = reward_dist + 0.001 * reward_ctrl
127 |
128 | self.do_simulation(a, self.frame_skip)
129 | ob = self._get_obs()
130 | self.episode_length += 1
131 | done = self.episode_length >= self.max_episode_length
132 | return ob, reward, done, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
133 |
134 | def viewer_setup(self):
135 | self.viewer.cam.trackbodyid = -1
136 | self.viewer.cam.distance = 1.0
137 |
138 | def reset_model(self):
139 | qpos = self.init_qpos
140 | self.episode_length = 0
141 | qvel = self.init_qvel + self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
142 | self.set_state(qpos, qvel)
143 | self.episode_length = 0
144 | return self._get_obs()
145 |
146 | def _get_state(self):
147 | return np.concatenate([
148 | self.get_body_com("particle"),
149 | #self.get_body_com("target"),
150 | ])
151 |
152 | def _get_obs(self):
153 | self._get_viewer().render()
154 | data, width, height = self._get_viewer().get_image()
155 | image = np.fromstring(data, dtype='uint8').reshape(height, width, 3)[::-1, :, :]
156 | # rescale image to float
157 | if self.grayscale:
158 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
159 | image = cv2.resize(image, (self.width, self.height), interpolation=cv2.INTER_AREA)
160 | image = image.astype(np.float32)/255.0
161 | return image
162 |
163 | def log_diagnostics(self, paths):
164 | rew_dist = np.array([traj['env_infos']['reward_dist'] for traj in paths])
165 | rew_ctrl = np.array([traj['env_infos']['reward_ctrl'] for traj in paths])
166 |
167 | logger.record_tabular('AvgObjectToGoalDist', -np.mean(rew_dist.mean()))
168 | logger.record_tabular('AvgControlCost', -np.mean(rew_ctrl.mean()))
169 | logger.record_tabular('AvgMinToGoalDist', np.mean(np.min(-rew_dist, axis=1)))
170 |
171 |
172 | if __name__ == "__main__":
173 | from inverse_rl.utils.getch import getKey
174 | #env = VisualTwoDMaze(verbose=True)
175 | env = VisualPointMazeEnv()
176 |
177 | while True:
178 | key = getKey()
179 | a = np.array([0.0,0.0])
180 | if key == 'w':
181 | a += np.array([0.0, 1.0])
182 | elif key == 'a':
183 | a += np.array([-1.0, 0.0])
184 | elif key == 's':
185 | a += np.array([0.0, -1.0])
186 | elif key == 'd':
187 | a += np.array([1.0, 0.0])
188 | elif key == 'q':
189 | break
190 | a *= 0.2
191 | o,_,_,_ = env.step(a)
192 | print(o.shape, o)
193 | env.render()
194 |
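
A short interaction sketch (not from the repository) assuming MuJoCo, the legacy gym MujocoEnv API this code targets, and a working viewer for offscreen rendering; it only checks that observations are images scaled to [0, 1]:

import numpy as np
from inverse_rl.envs.visual_pointmass import VisualPointMazeEnv

env = VisualPointMazeEnv(grayscale=True, width=64, height=64, episode_length=100)
obs = env.reset()
print(obs.shape, obs.dtype, obs.min(), obs.max())   # (64, 64) float32, values in [0, 1]

for _ in range(10):
    obs, rew, done, info = env.step(np.random.uniform(-0.2, 0.2, size=2))
    if done:
        obs = env.reset()
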
--------------------------------------------------------------------------------
/inverse_rl/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinjfu/inverse_rl/9609933389459a3a54f5c01d652114ada90fa1b3/inverse_rl/models/__init__.py
--------------------------------------------------------------------------------
/inverse_rl/models/airl_state.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from sandbox.rocky.tf.spaces.box import Box
4 |
5 | from inverse_rl.models.fusion_manager import RamFusionDistr
6 | from inverse_rl.models.imitation_learning import SingleTimestepIRL
7 | from inverse_rl.models.architectures import relu_net
8 | from inverse_rl.utils import TrainingIterator
9 |
10 |
11 |
12 | class AIRL(SingleTimestepIRL):
13 |     """
14 |     Adversarial inverse reinforcement learning (AIRL) discriminator with a shaping value function.
15 | 
16 |     Args:
17 |         fusion (bool): Use trajectories from old iterations to train.
18 |         state_only (bool): Restrict the learned reward to depend only on the state.
19 |         score_discrim (bool): Use log D - log(1 - D) as the policy reward (if True, an entropy bonus should not be needed).
20 |         max_itrs (int): Number of discriminator training iterations per fit() call.
21 |     """
22 | def __init__(self, env,
23 | expert_trajs=None,
24 | reward_arch=relu_net,
25 | reward_arch_args=None,
26 | value_fn_arch=relu_net,
27 | score_discrim=False,
28 | discount=1.0,
29 | state_only=False,
30 | max_itrs=100,
31 | fusion=False,
32 | name='airl'):
33 | super(AIRL, self).__init__()
34 | env_spec = env.spec
35 | if reward_arch_args is None:
36 | reward_arch_args = {}
37 |
38 | if fusion:
39 | self.fusion = RamFusionDistr(100, subsample_ratio=0.5)
40 | else:
41 | self.fusion = None
42 | self.dO = env_spec.observation_space.flat_dim
43 | self.dU = env_spec.action_space.flat_dim
44 | assert isinstance(env.action_space, Box)
45 | self.score_discrim = score_discrim
46 | self.gamma = discount
47 | assert value_fn_arch is not None
48 | self.set_demos(expert_trajs)
49 |         self.state_only = state_only
50 |         self.max_itrs = max_itrs
51 |
52 | # build energy model
53 | with tf.variable_scope(name) as _vs:
54 |             # Placeholders hold flattened single-timestep batches: (batch_size, dO) and (batch_size, dU)
55 | self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs')
56 | self.nobs_t = tf.placeholder(tf.float32, [None, self.dO], name='nobs')
57 | self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act')
58 | self.nact_t = tf.placeholder(tf.float32, [None, self.dU], name='nact')
59 | self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
60 | self.lprobs = tf.placeholder(tf.float32, [None, 1], name='log_probs')
61 | self.lr = tf.placeholder(tf.float32, (), name='lr')
62 |
63 | with tf.variable_scope('discrim') as dvs:
64 | rew_input = self.obs_t
65 | if not self.state_only:
66 | rew_input = tf.concat([self.obs_t, self.act_t], axis=1)
67 | with tf.variable_scope('reward'):
68 | self.reward = reward_arch(rew_input, dout=1, **reward_arch_args)
69 | #energy_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)
70 |
71 | # value function shaping
72 | with tf.variable_scope('vfn'):
73 | fitted_value_fn_n = value_fn_arch(self.nobs_t, dout=1)
74 | with tf.variable_scope('vfn', reuse=True):
75 | self.value_fn = fitted_value_fn = value_fn_arch(self.obs_t, dout=1)
76 |
77 | # Define log p_tau(a|s) = r + gamma * V(s') - V(s)
78 | self.qfn = self.reward + self.gamma*fitted_value_fn_n
79 | log_p_tau = self.reward + self.gamma*fitted_value_fn_n - fitted_value_fn
80 |
81 | log_q_tau = self.lprobs
82 |
83 | log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
84 | self.discrim_output = tf.exp(log_p_tau-log_pq)
85 | cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq))
86 |
87 | self.loss = cent_loss
88 | tot_loss = self.loss
89 | self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(tot_loss)
90 | self._make_param_ops(_vs)
91 |
92 | def fit(self, paths, policy=None, batch_size=32, logger=None, lr=1e-3,**kwargs):
93 |
94 | if self.fusion is not None:
95 | old_paths = self.fusion.sample_paths(n=len(paths))
96 | self.fusion.add_paths(paths)
97 | paths = paths+old_paths
98 |
99 | # eval samples under current policy
100 | self._compute_path_probs(paths, insert=True)
101 |
102 | # eval expert log probs under current policy
103 | self.eval_expert_probs(self.expert_trajs, policy, insert=True)
104 |
105 | self._insert_next_state(paths)
106 | self._insert_next_state(self.expert_trajs)
107 | obs, obs_next, acts, acts_next, path_probs = \
108 | self.extract_paths(paths,
109 | keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))
110 | expert_obs, expert_obs_next, expert_acts, expert_acts_next, expert_probs = \
111 | self.extract_paths(self.expert_trajs,
112 | keys=('observations', 'observations_next', 'actions', 'actions_next', 'a_logprobs'))
113 |
114 |
115 | # Train discriminator
116 | for it in TrainingIterator(self.max_itrs, heartbeat=5):
117 | nobs_batch, obs_batch, nact_batch, act_batch, lprobs_batch = \
118 | self.sample_batch(obs_next, obs, acts_next, acts, path_probs, batch_size=batch_size)
119 |
120 | nexpert_obs_batch, expert_obs_batch, nexpert_act_batch, expert_act_batch, expert_lprobs_batch = \
121 | self.sample_batch(expert_obs_next, expert_obs, expert_acts_next, expert_acts, expert_probs, batch_size=batch_size)
122 |
123 | # Build feed dict
124 | labels = np.zeros((batch_size*2, 1))
125 | labels[batch_size:] = 1.0
126 | obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
127 | nobs_batch = np.concatenate([nobs_batch, nexpert_obs_batch], axis=0)
128 | act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
129 | nact_batch = np.concatenate([nact_batch, nexpert_act_batch], axis=0)
130 | lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32)
131 | feed_dict = {
132 | self.act_t: act_batch,
133 | self.obs_t: obs_batch,
134 | self.nobs_t: nobs_batch,
135 | self.nact_t: nact_batch,
136 | self.labels: labels,
137 | self.lprobs: lprobs_batch,
138 | self.lr: lr
139 | }
140 |
141 | loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict=feed_dict)
142 | it.record('loss', loss)
143 | if it.heartbeat:
144 | print(it.itr_message())
145 | mean_loss = it.pop_mean('loss')
146 | print('\tLoss:%f' % mean_loss)
147 |
148 | if logger:
149 | logger.record_tabular('GCLDiscrimLoss', mean_loss)
150 | #obs_next = np.r_[obs_next, np.expand_dims(obs_next[-1], axis=0)]
151 | energy, logZ, dtau = tf.get_default_session().run([self.reward, self.value_fn, self.discrim_output],
152 | feed_dict={self.act_t: acts, self.obs_t: obs, self.nobs_t: obs_next,
153 | self.nact_t: acts_next,
154 | self.lprobs: np.expand_dims(path_probs, axis=1)})
155 | energy = -energy
156 | logger.record_tabular('GCLLogZ', np.mean(logZ))
157 | logger.record_tabular('GCLAverageEnergy', np.mean(energy))
158 | logger.record_tabular('GCLAverageLogPtau', np.mean(-energy-logZ))
159 | logger.record_tabular('GCLAverageLogQtau', np.mean(path_probs))
160 | logger.record_tabular('GCLMedianLogQtau', np.median(path_probs))
161 | logger.record_tabular('GCLAverageDtau', np.mean(dtau))
162 |
163 |
164 | #expert_obs_next = np.r_[expert_obs_next, np.expand_dims(expert_obs_next[-1], axis=0)]
165 | energy, logZ, dtau = tf.get_default_session().run([self.reward, self.value_fn, self.discrim_output],
166 | feed_dict={self.act_t: expert_acts, self.obs_t: expert_obs, self.nobs_t: expert_obs_next,
167 | self.nact_t: expert_acts_next,
168 | self.lprobs: np.expand_dims(expert_probs, axis=1)})
169 | energy = -energy
170 | logger.record_tabular('GCLAverageExpertEnergy', np.mean(energy))
171 | logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy-logZ))
172 | logger.record_tabular('GCLAverageExpertLogQtau', np.mean(expert_probs))
173 | logger.record_tabular('GCLMedianExpertLogQtau', np.median(expert_probs))
174 | logger.record_tabular('GCLAverageExpertDtau', np.mean(dtau))
175 | return mean_loss
176 |
177 | def eval(self, paths, **kwargs):
178 | """
179 | Return bonus
180 | """
181 | if self.score_discrim:
182 | self._compute_path_probs(paths, insert=True)
183 | obs, obs_next, acts, path_probs = self.extract_paths(paths, keys=('observations', 'observations_next', 'actions', 'a_logprobs'))
184 | path_probs = np.expand_dims(path_probs, axis=1)
185 | scores = tf.get_default_session().run(self.discrim_output,
186 | feed_dict={self.act_t: acts, self.obs_t: obs,
187 | self.nobs_t: obs_next,
188 | self.lprobs: path_probs})
189 | score = np.log(scores) - np.log(1-scores)
190 | score = score[:,0]
191 | else:
192 | obs, acts = self.extract_paths(paths)
193 | reward = tf.get_default_session().run(self.reward,
194 | feed_dict={self.act_t: acts, self.obs_t: obs})
195 | score = reward[:,0]
196 | return self.unpack(score, paths)
197 |
198 | def eval_single(self, obs):
199 | reward = tf.get_default_session().run(self.reward,
200 | feed_dict={self.obs_t: obs})
201 | score = reward[:, 0]
202 | return score
203 |
204 | def debug_eval(self, paths, **kwargs):
205 | obs, acts = self.extract_paths(paths)
206 | reward, v, qfn = tf.get_default_session().run([self.reward, self.value_fn,
207 | self.qfn],
208 | feed_dict={self.act_t: acts, self.obs_t: obs})
209 | return {
210 | 'reward': reward,
211 | 'value': v,
212 | 'qfn': qfn,
213 | }
214 |
215 |
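
The discriminator built above is D = exp(f) / (exp(f) + pi(a|s)) with f = r(s,a) + gamma*V(s') - V(s), computed via reduce_logsumexp for stability. The plain-numpy sketch below (not from the repository; the f and log-probability values are made up) checks that this logsumexp form equals the direct sigmoid(f - log pi) form, which is also why the score_discrim reward log D - log(1 - D) reduces to f - log pi(a|s):

import numpy as np

def airl_discrim(f, log_q):
    # Mirrors the TF graph: log_pq = logsumexp([f, log_q]); D = exp(f - log_pq).
    log_pq = np.logaddexp(f, log_q)
    return np.exp(f - log_pq)

f = np.array([2.0, -1.0, 0.5])              # f = r + gamma * V(s') - V(s)
log_q = np.log(np.array([0.3, 0.1, 0.5]))   # log pi(a|s) under the current policy

d = airl_discrim(f, log_q)
assert np.allclose(d, 1.0 / (1.0 + np.exp(-(f - log_q))))   # D = sigmoid(f - log pi)
assert np.allclose(np.log(d) - np.log(1.0 - d), f - log_q)  # the score_discrim reward
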
--------------------------------------------------------------------------------
/inverse_rl/models/architectures.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from inverse_rl.models.tf_util import relu_layer, linear
3 |
4 |
5 | def make_relu_net(layers=2, dout=1, d_hidden=32):
6 | def relu_net(x, last_layer_bias=True):
7 | out = x
8 | for i in range(layers):
9 | out = relu_layer(out, dout=d_hidden, name='l%d'%i)
10 | out = linear(out, dout=dout, name='lfinal', bias=last_layer_bias)
11 | return out
12 | return relu_net
13 |
14 |
15 | def relu_net(x, layers=2, dout=1, d_hidden=32):
16 | out = x
17 | for i in range(layers):
18 | out = relu_layer(out, dout=d_hidden, name='l%d'%i)
19 | out = linear(out, dout=dout, name='lfinal')
20 | return out
21 |
22 |
23 | def linear_net(x, dout=1):
24 | out = x
25 | out = linear(out, dout=dout, name='lfinal')
26 | return out
27 |
28 |
29 | def feedforward_energy(obs_act, ff_arch=relu_net):
30 | # for trajectories, using feedforward nets rather than RNNs
31 | dimOU = int(obs_act.get_shape()[2])
32 | orig_shape = tf.shape(obs_act)
33 |
34 | obs_act = tf.reshape(obs_act, [-1, dimOU])
35 | outputs = ff_arch(obs_act)
36 | dOut = int(outputs.get_shape()[-1])
37 |
38 | new_shape = tf.stack([orig_shape[0],orig_shape[1], dOut])
39 | outputs = tf.reshape(outputs, new_shape)
40 | return outputs
41 |
42 |
43 | def rnn_trajectory_energy(obs_act):
44 | """
45 | Operates on trajectories
46 | """
47 | # for trajectories
48 | dimOU = int(obs_act.get_shape()[2])
49 |
50 | cell = tf.contrib.rnn.GRUCell(num_units=dimOU)
51 | cell_out = tf.contrib.rnn.OutputProjectionWrapper(cell, 1)
52 | outputs, hidden = tf.nn.dynamic_rnn(cell_out, obs_act, time_major=False, dtype=tf.float32)
53 | return outputs
54 |
55 |
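
A small TF1-style sketch (not part of the repository) of the contract relu_net is used with above: it maps a 2-D batch of flattened inputs to a (batch, dout) output and creates its variables under the enclosing variable_scope. The placeholder size 6 and the random batch are stand-ins:

import numpy as np
import tensorflow as tf
from inverse_rl.models.architectures import relu_net

tf.reset_default_graph()
obs_act = tf.placeholder(tf.float32, [None, 6], name='obs_act')
with tf.variable_scope('reward'):
    score = relu_net(obs_act, layers=2, dout=1, d_hidden=32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(score, feed_dict={obs_act: np.random.randn(5, 6)})
    print(out.shape)   # (5, 1)
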
--------------------------------------------------------------------------------
/inverse_rl/models/fusion_manager.py:
--------------------------------------------------------------------------------
1 | import os
2 | import joblib
3 | import re
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from rllab.misc.logger import get_snapshot_dir
8 |
9 |
10 | class FusionDistrManager(object):
11 | def add_paths(self, paths):
12 | raise NotImplementedError()
13 |
14 | def sample_paths(self, n):
15 | raise NotImplementedError()
16 |
17 |
18 | class PathsReader(object):
19 |     ITR_REG = re.compile(r"itr_(?P<itr_count>[0-9]+)\.pkl")
20 |
21 | def __init__(self, path_dir):
22 | self.path_dir = path_dir
23 |
24 | def get_path_files(self):
25 | itr_files = []
26 | for i, filename in enumerate(os.listdir(self.path_dir)):
27 | m = PathsReader.ITR_REG.match(filename)
28 | if m:
29 | itr_count = m.group('itr_count')
30 | itr_files.append((itr_count, filename))
31 |
32 | itr_files = sorted(itr_files, key=lambda x: int(x[0]), reverse=True)
33 | for itr_file_and_count in itr_files:
34 | fname = os.path.join(self.path_dir, itr_file_and_count[1])
35 | yield fname
36 |
37 | def __len__(self):
38 | return len(list(self.get_path_files()))
39 |
40 |
41 | class DiskFusionDistr(FusionDistrManager):
42 | def __init__(self, path_dir=None):
43 | if path_dir is None:
44 | path_dir = get_snapshot_dir()
45 | self.path_dir = path_dir
46 | self.paths_reader = PathsReader(path_dir)
47 |
48 | def add_paths(self, paths):
49 | raise NotImplementedError()
50 |
51 | def sample_paths(self, n):
52 | # load from disk!
53 | fnames = list(self.paths_reader.get_path_files())
54 | N = len(fnames)
55 | sample_files = np.random.randint(0, N, size=(n))
56 | #sample_hist = np.histogram(sample_files, range=(0, N))
57 | #print(sample_hist)
58 | unique, counts = np.unique(sample_files, return_counts=True)
59 | unique_dict = dict(zip(unique, counts))
60 |
61 | all_paths = []
62 | for fidx in unique_dict:
63 | fname = fnames[fidx]
64 | n_samp = unique_dict[fidx]
65 | print(fname, n_samp)
66 |
67 | config = tf.ConfigProto()
68 | config.gpu_options.allow_growth = True
69 | with tf.Graph().as_default():
70 | with tf.Session(config=config).as_default():
71 | snapshot_dict = joblib.load(fname)
72 | paths = snapshot_dict['paths']
73 | pidxs = np.random.randint(0, len(paths), size=(n_samp))
74 | all_paths.extend([paths[pidx] for pidx in pidxs])
75 | return all_paths
76 |
77 |
78 | class RamFusionDistr(FusionDistrManager):
79 | def __init__(self, buf_size, subsample_ratio=0.5):
80 | self.buf_size = buf_size
81 | self.buffer = []
82 | self.subsample_ratio = subsample_ratio
83 |
84 | def add_paths(self, paths, subsample=True):
85 | if subsample:
86 | paths = paths[:int(len(paths)*self.subsample_ratio)]
87 | self.buffer.extend(paths)
88 | overflow = len(self.buffer)-self.buf_size
89 | while overflow > 0:
90 | #self.buffer = self.buffer[overflow:]
91 | N = len(self.buffer)
92 | probs = np.arange(N)+1
93 | probs = probs/float(np.sum(probs))
94 | pidx = np.random.choice(np.arange(N), p=probs)
95 | self.buffer.pop(pidx)
96 | overflow -= 1
97 |
98 | def sample_paths(self, n):
99 | if len(self.buffer) == 0:
100 | return []
101 | else:
102 | pidxs = np.random.randint(0, len(self.buffer), size=(n))
103 | return [self.buffer[pidx] for pidx in pidxs]
104 |
105 |
106 | if __name__ == "__main__":
107 | #fm = DiskFusionDistr(path_dir='data_nobs/gridworld_random/gru1')
108 | #paths = fm.sample_paths(10)
109 | fm = RamFusionDistr(10)
110 | fm.add_paths([1,2,3,4,5,6,7,8,9,10,11,12,13])
111 | print(fm.buffer)
112 | print(fm.sample_paths(5))
113 |
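
A sketch (not from the repository, with made-up stand-in path dicts) of how AIRL.fit above mixes the replay buffer with fresh rollouts when fusion=True: sample old paths first, then add the new ones, of which only subsample_ratio are kept. It assumes the repo's rllab/TF dependencies are importable:

from inverse_rl.models.fusion_manager import RamFusionDistr

fusion = RamFusionDistr(buf_size=100, subsample_ratio=0.5)

new_paths = [{'observations': 'obs_%d' % i} for i in range(10)]   # stand-ins for rllab path dicts
old_paths = fusion.sample_paths(n=len(new_paths))                 # [] on the first iteration
fusion.add_paths(new_paths)                                       # keeps the first 5 (subsample_ratio=0.5)
training_paths = new_paths + old_paths
print(len(training_paths), len(fusion.buffer))                    # 10 5
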
--------------------------------------------------------------------------------
/inverse_rl/models/imitation_learning.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from inverse_rl.models.architectures import feedforward_energy, relu_net
5 | from inverse_rl.models.tf_util import discounted_reduce_sum
6 | from inverse_rl.utils.general import TrainingIterator
7 | from inverse_rl.utils.hyperparametrized import Hyperparametrized
8 | from inverse_rl.utils.math_utils import gauss_log_pdf, categorical_log_pdf
9 | from sandbox.rocky.tf.misc import tensor_utils
10 |
11 | LOG_REG = 1e-8
12 | DIST_GAUSSIAN = 'gaussian'
13 | DIST_CATEGORICAL = 'categorical'
14 |
15 | class ImitationLearning(object, metaclass=Hyperparametrized):
16 | def __init__(self):
17 | pass
18 |
19 | def set_demos(self, paths):
20 | if paths is not None:
21 | self.expert_trajs = paths
22 | self.expert_trajs_extracted = self.extract_paths(paths)
23 |
24 | @staticmethod
25 | def _compute_path_probs(paths, pol_dist_type=None, insert=True,
26 | insert_key='a_logprobs'):
27 | """
28 | Returns a N x T matrix of action probabilities
29 | """
30 | if insert_key in paths[0]:
31 | return np.array([path[insert_key] for path in paths])
32 |
33 | if pol_dist_type is None:
34 | # try to infer distribution type
35 | path0 = paths[0]
36 | if 'log_std' in path0['agent_infos']:
37 | pol_dist_type = DIST_GAUSSIAN
38 | elif 'prob' in path0['agent_infos']:
39 | pol_dist_type = DIST_CATEGORICAL
40 | else:
41 | raise NotImplementedError()
42 |
43 | # compute path probs
44 | Npath = len(paths)
45 | actions = [path['actions'] for path in paths]
46 | if pol_dist_type == DIST_GAUSSIAN:
47 | params = [(path['agent_infos']['mean'], path['agent_infos']['log_std']) for path in paths]
48 | path_probs = [gauss_log_pdf(params[i], actions[i]) for i in range(Npath)]
49 | elif pol_dist_type == DIST_CATEGORICAL:
50 | params = [(path['agent_infos']['prob'],) for path in paths]
51 | path_probs = [categorical_log_pdf(params[i], actions[i]) for i in range(Npath)]
52 | else:
53 | raise NotImplementedError("Unknown distribution type")
54 |
55 | if insert:
56 | for i, path in enumerate(paths):
57 | path[insert_key] = path_probs[i]
58 |
59 | return np.array(path_probs)
60 |
61 | @staticmethod
62 | def _insert_next_state(paths, pad_val=0.0):
63 | for path in paths:
64 | if 'observations_next' in path:
65 | continue
66 | nobs = path['observations'][1:]
67 | nact = path['actions'][1:]
68 | nobs = np.r_[nobs, pad_val*np.expand_dims(np.ones_like(nobs[0]), axis=0)]
69 | nact = np.r_[nact, pad_val*np.expand_dims(np.ones_like(nact[0]), axis=0)]
70 | path['observations_next'] = nobs
71 | path['actions_next'] = nact
72 | return paths
73 |
74 | @staticmethod
75 | def extract_paths(paths, keys=('observations', 'actions'), stack=True):
76 | if stack:
77 | return [np.stack([t[key] for t in paths]).astype(np.float32) for key in keys]
78 | else:
79 | return [np.concatenate([t[key] for t in paths]).astype(np.float32) for key in keys]
80 |
81 | @staticmethod
82 | def sample_batch(*args, batch_size=32):
83 | N = args[0].shape[0]
84 | batch_idxs = np.random.randint(0, N, batch_size) # trajectories are negatives
85 | return [data[batch_idxs] for data in args]
86 |
87 | def fit(self, paths, **kwargs):
88 | raise NotImplementedError()
89 |
90 | def eval(self, paths, **kwargs):
91 | raise NotImplementedError()
92 |
93 | def _make_param_ops(self, vs):
94 | self._params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)
95 | assert len(self._params)>0
96 | self._assign_plc = [tf.placeholder(tf.float32, shape=param.get_shape(), name='assign_%s'%param.name.replace('/','_').replace(':','_')) for param in self._params]
97 | self._assign_ops = [tf.assign(self._params[i], self._assign_plc[i]) for i in range(len(self._params))]
98 |
99 | def get_params(self):
100 | params = tf.get_default_session().run(self._params)
101 | assert len(params) == len(self._params)
102 | return params
103 |
104 | def set_params(self, params):
105 | tf.get_default_session().run(self._assign_ops, feed_dict={
106 | self._assign_plc[i]: params[i] for i in range(len(self._params))
107 | })
108 |
109 |
110 | class TrajectoryIRL(ImitationLearning):
111 | """
112 | Base class for models that score entire trajectories at once
113 | """
114 | @property
115 | def score_trajectories(self):
116 | return True
117 |
118 | def eval_expert_probs(self, expert_paths, policy, insert=False):
119 | """
120 | Evaluate expert policy probability under current policy
121 | """
122 | if policy.recurrent:
123 | policy.reset([True]*len(expert_paths))
124 | expert_obs = self.extract_paths(expert_paths, keys=('observations',))[0]
125 | agent_infos = []
126 | for t in range(expert_obs.shape[1]):
127 | a, infos = policy.get_actions(expert_obs[:, t])
128 | agent_infos.append(infos)
129 | agent_infos_stack = tensor_utils.stack_tensor_dict_list(agent_infos)
130 | for key in agent_infos_stack:
131 | agent_infos_stack[key] = np.transpose(agent_infos_stack[key], axes=[1,0,2])
132 | agent_infos_transpose = tensor_utils.split_tensor_dict_list(agent_infos_stack)
133 | for i, path in enumerate(expert_paths):
134 | path['agent_infos'] = agent_infos_transpose[i]
135 | else:
136 | for path in expert_paths:
137 | actions, agent_infos = policy.get_actions(path['observations'])
138 | path['agent_infos'] = agent_infos
139 | return self._compute_path_probs(expert_paths, insert=insert)
140 |
141 |
142 |
143 | class SingleTimestepIRL(ImitationLearning):
144 | """
145 | Base class for models that score single timesteps at once
146 | """
147 | @staticmethod
148 | def extract_paths(paths, keys=('observations', 'actions'), stack=False):
149 | return ImitationLearning.extract_paths(paths, keys=keys, stack=stack)
150 |
151 | @staticmethod
152 | def unpack(data, paths):
153 | lengths = [path['observations'].shape[0] for path in paths]
154 | unpacked = []
155 | idx = 0
156 | for l in lengths:
157 | unpacked.append(data[idx:idx+l])
158 | idx += l
159 | return unpacked
160 |
161 | @property
162 | def score_trajectories(self):
163 | return False
164 |
165 | def eval_expert_probs(self, expert_paths, policy, insert=False):
166 | """
167 | Evaluate expert policy probability under current policy
168 | """
169 | for traj in expert_paths:
170 | if 'agent_infos' in traj:
171 | del traj['agent_infos']
172 | if 'a_logprobs' in traj:
173 | del traj['a_logprobs']
174 |
175 | if isinstance(policy, np.ndarray):
176 | return self._compute_path_probs(expert_paths, insert=insert)
177 | elif hasattr(policy, 'recurrent') and policy.recurrent:
178 | policy.reset([True]*len(expert_paths))
179 | expert_obs = self.extract_paths(expert_paths, keys=('observations',), stack=True)[0]
180 | agent_infos = []
181 | for t in range(expert_obs.shape[1]):
182 | a, infos = policy.get_actions(expert_obs[:, t])
183 | agent_infos.append(infos)
184 | agent_infos_stack = tensor_utils.stack_tensor_dict_list(agent_infos)
185 | for key in agent_infos_stack:
186 | agent_infos_stack[key] = np.transpose(agent_infos_stack[key], axes=[1,0,2])
187 | agent_infos_transpose = tensor_utils.split_tensor_dict_list(agent_infos_stack)
188 | for i, path in enumerate(expert_paths):
189 | path['agent_infos'] = agent_infos_transpose[i]
190 | else:
191 | for path in expert_paths:
192 | actions, agent_infos = policy.get_actions(path['observations'])
193 | path['agent_infos'] = agent_infos
194 | return self._compute_path_probs(expert_paths, insert=insert)
195 |
196 |
197 | class GAIL(SingleTimestepIRL):
198 | """
199 |     Generative adversarial imitation learning
200 | See https://arxiv.org/pdf/1606.03476.pdf
201 |
202 | This version consumes single timesteps.
203 | """
204 | def __init__(self, env_spec, expert_trajs=None,
205 | discrim_arch=relu_net,
206 | discrim_arch_args={},
207 | name='gail'):
208 | super(GAIL, self).__init__()
209 | self.dO = env_spec.observation_space.flat_dim
210 | self.dU = env_spec.action_space.flat_dim
211 | self.set_demos(expert_trajs)
212 |
213 | # build energy model
214 | with tf.variable_scope(name) as vs:
215 |             # Placeholders hold flattened single-timestep batches: (batch_size, dO) and (batch_size, dU)
216 | self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs')
217 | self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act')
218 | self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
219 | self.lr = tf.placeholder(tf.float32, (), name='lr')
220 |
221 | obs_act = tf.concat([self.obs_t, self.act_t], axis=1)
222 | logits = discrim_arch(obs_act, **discrim_arch_args)
223 | self.predictions = tf.nn.sigmoid(logits)
224 | self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=self.labels))
225 | self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)
226 | self._make_param_ops(vs)
227 |
228 |
229 | def fit(self, trajs, batch_size=32, max_itrs=100, **kwargs):
230 | obs, acts = self.extract_paths(trajs)
231 | expert_obs, expert_acts = self.expert_trajs_extracted
232 |
233 | # Train discriminator
234 | for it in TrainingIterator(max_itrs, heartbeat=5):
235 | obs_batch, act_batch = self.sample_batch(obs, acts, batch_size=batch_size)
236 | expert_obs_batch, expert_act_batch = self.sample_batch(expert_obs, expert_acts, batch_size=batch_size)
237 | labels = np.zeros((batch_size*2, 1))
238 | labels[batch_size:] = 1.0
239 | obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
240 | act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
241 |
242 | loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict={
243 | self.act_t: act_batch,
244 | self.obs_t: obs_batch,
245 | self.labels: labels,
246 | self.lr: 1e-3
247 | })
248 |
249 | it.record('loss', loss)
250 | if it.heartbeat:
251 | print(it.itr_message())
252 | mean_loss = it.pop_mean('loss')
253 | print('\tLoss:%f' % mean_loss)
254 | return mean_loss
255 |
256 | def eval(self, paths, **kwargs):
257 | """
258 | Return bonus
259 | """
260 | obs, acts = self.extract_paths(paths)
261 | scores = tf.get_default_session().run(self.predictions,
262 | feed_dict={self.act_t: acts, self.obs_t: obs})
263 |
264 | # reward = log D(s, a)
265 | scores = np.log(scores[:,0]+LOG_REG)
266 | return self.unpack(scores, paths)
267 |
268 |
269 | class AIRLStateAction(SingleTimestepIRL):
270 | """
271 | This version consumes single timesteps.
272 | """
273 | def __init__(self, env_spec, expert_trajs=None,
274 | discrim_arch=relu_net,
275 | discrim_arch_args={},
276 | l2_reg=0,
277 | discount=1.0,
278 | name='gcl'):
279 | super(AIRLStateAction, self).__init__()
280 | self.dO = env_spec.observation_space.flat_dim
281 | self.dU = env_spec.action_space.flat_dim
282 | self.set_demos(expert_trajs)
283 |
284 | # build energy model
285 | with tf.variable_scope(name) as _vs:
286 |             # Placeholders hold flattened single-timestep batches: (batch_size, dO) and (batch_size, dU)
287 | self.obs_t = tf.placeholder(tf.float32, [None, self.dO], name='obs')
288 | self.act_t = tf.placeholder(tf.float32, [None, self.dU], name='act')
289 | self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
290 | self.lprobs = tf.placeholder(tf.float32, [None, 1], name='log_probs')
291 | self.lr = tf.placeholder(tf.float32, (), name='lr')
292 |
293 | obs_act = tf.concat([self.obs_t, self.act_t], axis=1)
294 | with tf.variable_scope('discrim') as dvs:
295 | with tf.variable_scope('energy'):
296 | self.energy = discrim_arch(obs_act, **discrim_arch_args)
297 | # we do not learn a separate log Z(s) because it is impossible to separate from the energy
298 | # In a discrete domain we can explicitly normalize to calculate log Z(s)
299 | log_p_tau = -self.energy
300 | discrim_vars = tf.get_collection('reg_vars', scope=dvs.name)
301 |
302 | log_q_tau = self.lprobs
303 |
304 | if l2_reg > 0:
305 | reg_loss = l2_reg*tf.reduce_sum([tf.reduce_sum(tf.square(var)) for var in discrim_vars])
306 | else:
307 | reg_loss = 0
308 |
309 | log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
310 | self.d_tau = tf.exp(log_p_tau-log_pq)
311 | cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq))
312 |
313 | self.loss = cent_loss + reg_loss
314 | self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)
315 | self._make_param_ops(_vs)
316 |
317 |
318 | def fit(self, paths, policy=None, batch_size=32, max_itrs=100, logger=None, lr=1e-3,**kwargs):
319 | #self._compute_path_probs(paths, insert=True)
320 | self.eval_expert_probs(paths, policy, insert=True)
321 | self.eval_expert_probs(self.expert_trajs, policy, insert=True)
322 | obs, acts, path_probs = self.extract_paths(paths, keys=('observations', 'actions', 'a_logprobs'))
323 | expert_obs, expert_acts, expert_probs = self.extract_paths(self.expert_trajs, keys=('observations', 'actions', 'a_logprobs'))
324 |
325 | # Train discriminator
326 | for it in TrainingIterator(max_itrs, heartbeat=5):
327 | obs_batch, act_batch, lprobs_batch = \
328 | self.sample_batch(obs, acts, path_probs, batch_size=batch_size)
329 |
330 | expert_obs_batch, expert_act_batch, expert_lprobs_batch = \
331 | self.sample_batch(expert_obs, expert_acts, expert_probs, batch_size=batch_size)
332 |
333 | labels = np.zeros((batch_size*2, 1))
334 | labels[batch_size:] = 1.0
335 | obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
336 | act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
337 | lprobs_batch = np.expand_dims(np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0), axis=1).astype(np.float32)
338 |
339 | loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict={
340 | self.act_t: act_batch,
341 | self.obs_t: obs_batch,
342 | self.labels: labels,
343 | self.lprobs: lprobs_batch,
344 | self.lr: lr
345 | })
346 |
347 | it.record('loss', loss)
348 | if it.heartbeat:
349 | print(it.itr_message())
350 | mean_loss = it.pop_mean('loss')
351 | print('\tLoss:%f' % mean_loss)
352 | if logger:
353 | energy = tf.get_default_session().run(self.energy,
354 | feed_dict={self.act_t: acts, self.obs_t: obs})
355 | logger.record_tabular('IRLAverageEnergy', np.mean(energy))
356 | logger.record_tabular('IRLAverageLogQtau', np.mean(path_probs))
357 | logger.record_tabular('IRLMedianLogQtau', np.median(path_probs))
358 |
359 | energy = tf.get_default_session().run(self.energy,
360 | feed_dict={self.act_t: expert_acts, self.obs_t: expert_obs})
361 | logger.record_tabular('IRLAverageExpertEnergy', np.mean(energy))
362 | #logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy-logZ))
363 | logger.record_tabular('IRLAverageExpertLogQtau', np.mean(expert_probs))
364 | logger.record_tabular('IRLMedianExpertLogQtau', np.median(expert_probs))
365 | return mean_loss
366 |
367 |
368 | def eval(self, paths, **kwargs):
369 | """
370 | Return bonus
371 | """
372 | obs, acts = self.extract_paths(paths)
373 |
374 | energy = tf.get_default_session().run(self.energy,
375 | feed_dict={self.act_t: acts, self.obs_t: obs})
376 | energy = -energy[:,0]
377 | return self.unpack(energy, paths)
378 |
379 |
380 | class GAN_GCL(TrajectoryIRL):
381 | """
382 | Guided cost learning, GAN formulation with learned partition function
383 | See https://arxiv.org/pdf/1611.03852.pdf
384 | """
385 | def __init__(self, env_spec, expert_trajs=None,
386 | discrim_arch=feedforward_energy,
387 | discrim_arch_args={},
388 | l2_reg = 0,
389 | discount = 1.0,
390 | init_itrs = None,
391 | score_dtau=False,
392 | state_only=False,
393 | name='trajprior'):
394 | super(GAN_GCL, self).__init__()
395 | self.dO = env_spec.observation_space.flat_dim
396 | self.dU = env_spec.action_space.flat_dim
397 | self.score_dtau = score_dtau
398 | self.set_demos(expert_trajs)
399 |
400 | # build energy model
401 | with tf.variable_scope(name) as vs:
402 | # Should be batch_size x T x dO/dU
403 | self.obs_t = tf.placeholder(tf.float32, [None, None, self.dO], name='obs')
404 | self.act_t = tf.placeholder(tf.float32, [None, None, self.dU], name='act')
405 | self.traj_logprobs = tf.placeholder(tf.float32, [None, None], name='traj_probs')
406 | self.labels = tf.placeholder(tf.float32, [None, 1], name='labels')
407 | self.lr = tf.placeholder(tf.float32, (), name='lr')
408 |
409 | if state_only:
410 | obs_act = self.obs_t
411 | else:
412 | obs_act = tf.concat([self.obs_t, self.act_t], axis=2)
413 |
414 | with tf.variable_scope('discrim') as vs2:
415 | self.energy = discrim_arch(obs_act, **discrim_arch_args)
416 | discrim_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs2.name)
417 |
418 | self.energy_timestep = self.energy
419 | # Don't train separate log Z because we can't fully separate it from the energy function
420 | if discount >= 1.0:
421 | log_p_tau = tf.reduce_sum(-self.energy, axis=1)
422 | else:
423 | log_p_tau = discounted_reduce_sum(-self.energy, discount=discount, axis=1)
424 | log_q_tau = tf.reduce_sum(self.traj_logprobs, axis=1, keep_dims=True)
425 |
426 | # numerical stability trick
427 | log_pq = tf.reduce_logsumexp([log_p_tau, log_q_tau], axis=0)
428 | self.d_tau = tf.exp(log_p_tau-log_pq)
429 | cent_loss = -tf.reduce_mean(self.labels*(log_p_tau-log_pq) + (1-self.labels)*(log_q_tau-log_pq))
430 |
431 | if l2_reg > 0:
432 | reg_loss = l2_reg*tf.reduce_sum([tf.reduce_sum(tf.square(var)) for var in discrim_vars])
433 | else:
434 | reg_loss = 0
435 |
436 | #self.predictions = tf.nn.sigmoid(logits)
437 | self.loss = cent_loss + reg_loss
438 | self.step = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss)
439 | self._make_param_ops(vs)
440 |
441 | @property
442 | def score_trajectories(self):
443 | return False
444 |
445 |
446 | def fit(self, paths, policy=None, batch_size=32, max_itrs=100, logger=None, lr=1e-3,**kwargs):
447 | self._compute_path_probs(paths, insert=True)
448 | self.eval_expert_probs(self.expert_trajs, policy, insert=True)
449 | obs, acts, path_probs = self.extract_paths(paths, keys=('observations', 'actions', 'a_logprobs'))
450 | expert_obs, expert_acts, expert_probs = self.extract_paths(self.expert_trajs, keys=('observations', 'actions', 'a_logprobs'))
451 |
452 | # Train discriminator
453 | for it in TrainingIterator(max_itrs, heartbeat=5):
454 | obs_batch, act_batch, lprobs_batch = \
455 | self.sample_batch(obs, acts, path_probs, batch_size=batch_size)
456 |
457 | expert_obs_batch, expert_act_batch, expert_lprobs_batch = \
458 | self.sample_batch(expert_obs, expert_acts, expert_probs, batch_size=batch_size)
459 | T = expert_obs_batch.shape[1]
460 |
461 | labels = np.zeros((batch_size*2, 1))
462 | labels[batch_size:] = 1.0
463 | obs_batch = np.concatenate([obs_batch, expert_obs_batch], axis=0)
464 | act_batch = np.concatenate([act_batch, expert_act_batch], axis=0)
465 | lprobs_batch = np.concatenate([lprobs_batch, expert_lprobs_batch], axis=0)
466 |
467 | loss, _ = tf.get_default_session().run([self.loss, self.step], feed_dict={
468 | self.act_t: act_batch,
469 | self.obs_t: obs_batch,
470 | self.labels: labels,
471 | self.traj_logprobs: lprobs_batch,
472 | self.lr: lr,
473 | })
474 |
475 | it.record('loss', loss)
476 | if it.heartbeat:
477 | print(it.itr_message())
478 | mean_loss = it.pop_mean('loss')
479 | print('\tLoss:%f' % mean_loss)
480 |
481 | if logger:
482 | energy, dtau = tf.get_default_session().run([self.energy_timestep, self.d_tau],
483 | feed_dict={self.act_t: acts, self.obs_t: obs,
484 | self.traj_logprobs: path_probs})
485 | #logger.record_tabular('GCLLogZ', logZ)
486 | logger.record_tabular('IRLAverageEnergy', np.mean(energy))
487 | #logger.record_tabular('GCLAverageLogPtau', np.mean(-energy))
488 | logger.record_tabular('IRLAverageLogQtau', np.mean(path_probs))
489 | logger.record_tabular('IRLMedianLogQtau', np.median(path_probs))
490 | logger.record_tabular('IRLAverageDtau', np.mean(dtau))
491 |
492 | energy, dtau = tf.get_default_session().run([self.energy_timestep, self.d_tau],
493 | feed_dict={self.act_t: expert_acts, self.obs_t: expert_obs,
494 | self.traj_logprobs: expert_probs})
495 | logger.record_tabular('IRLAverageExpertEnergy', np.mean(energy))
496 | #logger.record_tabular('GCLAverageExpertLogPtau', np.mean(-energy))
497 | logger.record_tabular('IRLAverageExpertLogQtau', np.mean(expert_probs))
498 | logger.record_tabular('IRLMedianExpertLogQtau', np.median(expert_probs))
499 | logger.record_tabular('IRLAverageExpertDtau', np.mean(dtau))
500 | return mean_loss
501 |
502 | def eval(self, paths, **kwargs):
503 | """
504 | Return bonus
505 | """
506 | obs, acts = self.extract_paths(paths)
507 |
508 | scores = tf.get_default_session().run(self.energy,
509 | feed_dict={self.act_t: acts, self.obs_t: obs})
510 | scores = -scores[:,:,0]
511 | return scores
512 |
513 |
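
The single-timestep models above flatten paths into one row per timestep and later split per-timestep scores back out by path length. A plain-numpy restatement (not from the repository; the fake paths and scores are stand-ins) of that extract_paths(stack=False) / unpack convention:

import numpy as np

# Two fake rllab-style paths of lengths 3 and 2.
paths = [
    {'observations': np.zeros((3, 4)), 'actions': np.zeros((3, 2))},
    {'observations': np.ones((2, 4)),  'actions': np.ones((2, 2))},
]

# extract_paths(..., stack=False): concatenate along time, one row per timestep.
obs = np.concatenate([p['observations'] for p in paths]).astype(np.float32)   # (5, 4)
scores = obs.sum(axis=1)                                                      # per-timestep scores, (5,)

# unpack(): split the flat scores back into one array per path.
lengths = [p['observations'].shape[0] for p in paths]
unpacked, idx = [], 0
for l in lengths:
    unpacked.append(scores[idx:idx + l])
    idx += l
print([u.shape for u in unpacked])   # [(3,), (2,)]
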
--------------------------------------------------------------------------------
/inverse_rl/models/tf_util.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | REG_VARS = 'reg_vars'
5 |
6 | def linear(X, dout, name, bias=True):
7 | with tf.variable_scope(name):
8 | dX = int(X.get_shape()[-1])
9 | W = tf.get_variable('W', shape=(dX, dout))
10 | tf.add_to_collection(REG_VARS, W)
11 | if bias:
12 | b = tf.get_variable('b', initializer=tf.constant(np.zeros(dout).astype(np.float32)))
13 | else:
14 | b = 0
15 | return tf.matmul(X, W)+b
16 |
17 | def discounted_reduce_sum(X, discount, axis=-1):
18 | if discount != 1.0:
19 | disc = tf.cumprod(discount*tf.ones_like(X), axis=axis)
20 | else:
21 | disc = 1.0
22 | return tf.reduce_sum(X*disc, axis=axis)
23 |
24 | def assert_shape(tens, shape):
25 | assert tens.get_shape().is_compatible_with(shape)
26 |
27 | def relu_layer(X, dout, name):
28 | return tf.nn.relu(linear(X, dout, name))
29 |
30 | def softplus_layer(X, dout, name):
31 | return tf.nn.softplus(linear(X, dout, name))
32 |
33 | def tanh_layer(X, dout, name):
34 | return tf.nn.tanh(linear(X, dout, name))
35 |
36 | def get_session_config():
37 | session_config = tf.ConfigProto()
38 | session_config.gpu_options.allow_growth = True
39 | #session_config.gpu_options.per_process_gpu_memory_fraction = 0.2
40 | return session_config
41 |
42 |
43 | def load_prior_params(pkl_fname):
44 | import joblib
45 | with tf.Session(config=get_session_config()):
46 | params = joblib.load(pkl_fname)
47 | tf.reset_default_graph()
48 | #joblib.dump(params, file_name, compress=3)
49 | params = params['irl_params']
50 | #print(params)
51 | assert params is not None
52 | return params
53 |
--------------------------------------------------------------------------------
/inverse_rl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from inverse_rl.utils.general import *
2 |
--------------------------------------------------------------------------------
/inverse_rl/utils/general.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 |
4 | def flatten_list(lol):
5 | return [ a for b in lol for a in b ]
6 |
7 | class TrainingIterator(object):
8 | def __init__(self, itrs, heartbeat=float('inf')):
9 | self.itrs = itrs
10 | self.heartbeat_time = heartbeat
11 | self.__vals = {}
12 |
13 | def random_idx(self, N, size):
14 | return np.random.randint(0, N, size=size)
15 |
16 | @property
17 | def itr(self):
18 | return self.__itr
19 |
20 | @property
21 | def heartbeat(self):
22 | return self.__heartbeat
23 |
24 | @property
25 | def elapsed(self):
26 | assert self.heartbeat, 'elapsed is only valid when heartbeat=True'
27 | return self.__elapsed
28 |
29 | def itr_message(self):
30 | return '==> Itr %d/%d (elapsed:%.2f)' % (self.itr+1, self.itrs, self.elapsed)
31 |
32 | def record(self, key, value):
33 | if key in self.__vals:
34 | self.__vals[key].append(value)
35 | else:
36 | self.__vals[key] = [value]
37 |
38 | def pop(self, key):
39 | vals = self.__vals.get(key, [])
40 | del self.__vals[key]
41 | return vals
42 |
43 | def pop_mean(self, key):
44 | return np.mean(self.pop(key))
45 |
46 | def __iter__(self):
47 | prev_time = time.time()
48 | self.__heartbeat = False
49 | for i in range(self.itrs):
50 | self.__itr = i
51 | cur_time = time.time()
52 | if (cur_time-prev_time) > self.heartbeat_time or i==(self.itrs-1):
53 | self.__heartbeat = True
54 | self.__elapsed = cur_time-prev_time
55 | prev_time = cur_time
56 | yield self
57 | self.__heartbeat = False
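
TrainingIterator is the loop driver used by the fit() methods earlier in this repo; a minimal usage sketch (not part of the repository; the random loss is a stand-in for a real training step):

import numpy as np
from inverse_rl.utils.general import TrainingIterator

# 100 iterations, printing a progress message at most every 5 seconds
# (and always on the final iteration).
for it in TrainingIterator(100, heartbeat=5):
    loss = float(np.random.rand())   # stand-in for a real training step
    it.record('loss', loss)
    if it.heartbeat:
        print(it.itr_message())
        print('\tloss: %f' % it.pop_mean('loss'))
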
--------------------------------------------------------------------------------
/inverse_rl/utils/hyper_sweep.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage
3 |
4 | args = {
5 | 'param1': [1e-3, 1e-2, 1e-2],
6 | 'param2': [1,5,10,20],
7 | }
8 |
9 | run_sweep_parallel(func, args)
10 |
11 | or
12 |
13 | run_sweep_serial(func, args)
14 |
15 | """
16 | import itertools
17 | import multiprocessing
18 | import random
19 | from datetime import datetime
20 |
21 | class Sweeper(object):
22 | def __init__(self, hyper_config, repeat):
23 | self.hyper_config = hyper_config
24 | self.repeat = repeat
25 |
26 | def __iter__(self):
27 | count = 0
28 | for _ in range(self.repeat):
29 | for config in itertools.product(*[val for val in self.hyper_config.values()]):
30 | kwargs = {key:config[i] for i, key in enumerate(self.hyper_config.keys())}
31 | timestamp = datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
32 | kwargs['exp_name'] = "%s_%d" % (timestamp, count)
33 | count += 1
34 | yield kwargs
35 |
36 |
37 |
38 | def run_sweep_serial(run_method, params, repeat=1):
39 | sweeper = Sweeper(params, repeat)
40 | for config in sweeper:
41 | import tensorflow as tf; tf.reset_default_graph()
42 | run_method(**config)
43 |
44 | def kwargs_wrapper(args_method):
45 | import tensorflow as tf; tf.reset_default_graph()
46 | args, method = args_method
47 | return method(**args)
48 |
49 |
50 | def run_sweep_parallel(run_method, params, repeat=1, num_cpu=multiprocessing.cpu_count()):
51 | sweeper = Sweeper(params, repeat)
52 | pool = multiprocessing.Pool(num_cpu)
53 | exp_args = []
54 | for config in sweeper:
55 | exp_args.append((config, run_method))
56 | random.shuffle(exp_args)
57 | pool.map(kwargs_wrapper, exp_args)
58 |
59 |
60 | def example_run_method(exp_name, param1, param2='a', param3=3, param4=4):
61 | import time
62 | time.sleep(1.0)
63 | print(exp_name, param1, param2, param3, param4)
64 |
65 |
66 | if __name__ == "__main__":
67 | sweep_op = {
68 | 'param1': [1e-3, 1e-2, 1e-1],
69 | 'param2': [1,5,10,20],
70 | 'param3': [True, False]
71 | }
72 | run_sweep_parallel(example_run_method, sweep_op, repeat=2)
--------------------------------------------------------------------------------
/inverse_rl/utils/hyperparametrized.py:
--------------------------------------------------------------------------------
1 | CLSNAME = '__clsname__'
2 | _HYPER_ = '__hyper__'
3 | _HYPERNAME_ = '__hyper_clsname__'
4 |
5 |
6 | def extract_hyperparams(obj):
7 | if any([isinstance(obj, type_) for type_ in (int, float, str)]):
8 | return obj
9 | elif isinstance(type(obj), Hyperparametrized):
10 | hypers = getattr(obj, _HYPER_)
11 | hypers[CLSNAME] = getattr(obj, _HYPERNAME_)
12 | for attr in hypers:
13 | hypers[attr] = extract_hyperparams(hypers[attr])
14 | return hypers
15 | return type(obj).__name__
16 |
17 | class Hyperparametrized(type):
18 | def __new__(self, clsname, bases, clsdict):
19 | old_init = clsdict.get('__init__', bases[0].__init__)
20 | def init_wrapper(inst, *args, **kwargs):
21 | hyper = getattr(inst, _HYPER_, {})
22 | hyper.update(kwargs)
23 | setattr(inst, _HYPER_, hyper)
24 |
25 | if getattr(inst, _HYPERNAME_, None) is None:
26 | setattr(inst, _HYPERNAME_, clsname)
27 | return old_init(inst, *args, **kwargs)
28 | clsdict['__init__'] = init_wrapper
29 |
30 | cls = super(Hyperparametrized, self).__new__(self, clsname, bases, clsdict)
31 | return cls
32 |
33 |
34 | class HyperparamWrapper(object, metaclass=Hyperparametrized):
35 | def __init__(self, **hyper_kwargs):
36 | pass
37 |
38 | if __name__ == "__main__":
39 | class Algo1(object, metaclass=Hyperparametrized):
40 | def __init__(self, hyper1=1.0, hyper2=2.0, model1=None):
41 | pass
42 |
43 |
44 | class Algo2(Algo1):
45 | def __init__(self, hyper3=5.0, **kwargs):
46 | super(Algo2, self).__init__(**kwargs)
47 |
48 |
49 | class Model1(object, metaclass=Hyperparametrized):
50 | def __init__(self, hyper1=None):
51 | pass
52 |
53 |
54 | def get_params_json(**kwargs):
55 | hyper_dict = extract_hyperparams(HyperparamWrapper(**kwargs))
56 | del hyper_dict[CLSNAME]
57 | return hyper_dict
58 |
59 | m1 = Model1(hyper1='Test')
60 | a1 = Algo2(hyper1=1.0, hyper2=5.0, hyper3=10.0, model1=m1)
61 |
62 | print( isinstance(type(a1), Hyperparametrized))
63 | print(get_params_json(a1=a1))
64 |
--------------------------------------------------------------------------------
/inverse_rl/utils/log_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import joblib
4 | import json
5 | import contextlib
6 |
7 | import rllab.misc.logger as rllablogger
8 | import tensorflow as tf
9 | import numpy as np
10 |
11 | from inverse_rl.utils.hyperparametrized import extract_hyperparams
12 |
13 | @contextlib.contextmanager
14 | def rllab_logdir(algo=None, dirname=None):
15 | if dirname:
16 | rllablogger.set_snapshot_dir(dirname)
17 | dirname = rllablogger.get_snapshot_dir()
18 | rllablogger.add_tabular_output(os.path.join(dirname, 'progress.csv'))
19 | if algo:
20 | with open(os.path.join(dirname, 'params.json'), 'w') as f:
21 | params = extract_hyperparams(algo)
22 | json.dump(params, f)
23 | yield dirname
24 | rllablogger.remove_tabular_output(os.path.join(dirname, 'progress.csv'))
25 |
26 |
27 | def get_expert_fnames(log_dir, n=5):
28 | print('Looking for paths')
29 | import re
30 |     itr_reg = re.compile(r"itr_(?P<itr_count>[0-9]+)\.pkl")
31 |
32 | itr_files = []
33 | for i, filename in enumerate(os.listdir(log_dir)):
34 | m = itr_reg.match(filename)
35 | if m:
36 | itr_count = m.group('itr_count')
37 | itr_files.append((itr_count, filename))
38 |
39 | itr_files = sorted(itr_files, key=lambda x: int(x[0]), reverse=True)[:n]
40 | for itr_file_and_count in itr_files:
41 | fname = os.path.join(log_dir, itr_file_and_count[1])
42 | print('Loading %s' % fname)
43 | yield fname
44 |
45 |
46 | def load_experts(fname, max_files=float('inf'), min_return=None):
47 | config = tf.ConfigProto()
48 | config.gpu_options.allow_growth = True
49 | if hasattr(fname, '__iter__'):
50 | paths = []
51 | for fname_ in fname:
52 | tf.reset_default_graph()
53 | with tf.Session(config=config):
54 | snapshot_dict = joblib.load(fname_)
55 | paths.extend(snapshot_dict['paths'])
56 | else:
57 | with tf.Session(config=config):
58 | snapshot_dict = joblib.load(fname)
59 | paths = snapshot_dict['paths']
60 | tf.reset_default_graph()
61 |
62 | trajs = []
63 | for path in paths:
64 | obses = path['observations']
65 | actions = path['actions']
66 | returns = path['returns']
67 | total_return = np.sum(returns)
68 | if (min_return is None) or (total_return >= min_return):
69 | traj = {'observations': obses, 'actions': actions}
70 | trajs.append(traj)
71 | random.shuffle(trajs)
72 | print('Loaded %d trajectories' % len(trajs))
73 | return trajs
74 |
75 |
76 | def load_latest_experts(logdir, n=5, min_return=None):
77 | return load_experts(get_expert_fnames(logdir, n=n), min_return=min_return)
78 |
79 |
80 | def load_latest_experts_multiple_runs(logdir, n=5):
81 | paths = []
82 | for i, dirname in enumerate(os.listdir(logdir)):
83 | dirname = os.path.join(logdir, dirname)
84 | if os.path.isdir(dirname):
85 | print('Loading experts from %s' % dirname)
86 | paths.extend(load_latest_experts(dirname, n=n))
87 | return paths
88 |
--------------------------------------------------------------------------------
/inverse_rl/utils/math_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy as sp
3 | import scipy.stats
4 |
5 | def rle(inarray):
6 |     """ Run-length encoding. Partial credit to the R rle function.
7 |         Handles arrays of any datatype, including non-NumPy sequences.
8 |         returns: tuple (runlengths, startpositions, values) """
9 | ia = np.array(inarray) # force numpy
10 | n = len(ia)
11 | if n == 0:
12 | return (None, None, None)
13 | else:
14 | y = np.array(ia[1:] != ia[:-1]) # pairwise unequal (string safe)
15 |         i = np.append(np.where(y), n - 1)   # must include last element position
16 | z = np.diff(np.append(-1, i)) # run lengths
17 | p = np.cumsum(np.append(0, z))[:-1] # positions
18 | return(z, p, ia[i])
19 |
20 | def split_list_by_lengths(values, lengths):
21 | """
22 |
23 | >>> split_list_by_lengths([0,0,0,1,1,1,2,2,2], [2,2,5])
24 | [[0, 0], [0, 1], [1, 1, 2, 2, 2]]
25 | """
26 | assert np.sum(lengths) == len(values)
27 | idxs = np.cumsum(lengths)
28 | idxs = np.insert(idxs, 0, 0)
29 | return [ values[idxs[i]:idxs[i+1] ] for i in range(len(idxs)-1)]
30 |
31 | def clip_sing(X, clip_val=1):
32 | U, E, V = np.linalg.svd(X, full_matrices=False)
33 | E = np.clip(E, -clip_val, clip_val)
34 | return U.dot(np.diag(E)).dot(V)
35 |
36 | def gauss_log_pdf(params, x):
37 | mean, log_diag_std = params
38 | N, d = mean.shape
39 | cov = np.square(np.exp(log_diag_std))
40 | diff = x-mean
41 | exp_term = -0.5 * np.sum(np.square(diff)/cov, axis=1)
42 | norm_term = -0.5*d*np.log(2*np.pi)
43 | var_term = -0.5 * np.sum(np.log(cov), axis=1)
44 | log_probs = norm_term + var_term + exp_term
45 | return log_probs #sp.stats.multivariate_normal.logpdf(x, mean=mean, cov=cov)
46 |
47 | def categorical_log_pdf(params, x, one_hot=True):
48 | if not one_hot:
49 | raise NotImplementedError()
50 | probs = params[0]
51 | return np.log(np.max(probs * x, axis=1))
52 |
53 |
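
A quick numerical check (not from the repository; the sizes and random inputs are stand-ins) that gauss_log_pdf matches scipy's multivariate normal log-density for the diagonal Gaussians used by the policy distributions:

import numpy as np
import scipy.stats
from inverse_rl.utils.math_utils import gauss_log_pdf

N, d = 4, 3
mean = np.random.randn(N, d)
log_std = 0.1 * np.random.randn(N, d)   # per-dimension log standard deviations
x = np.random.randn(N, d)

lp = gauss_log_pdf((mean, log_std), x)

# Reference: scipy logpdf with an explicit diagonal covariance exp(2 * log_std).
ref = np.array([
    scipy.stats.multivariate_normal.logpdf(x[i], mean=mean[i], cov=np.diag(np.exp(2 * log_std[i])))
    for i in range(N)
])
assert np.allclose(lp, ref)
print(lp)
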
--------------------------------------------------------------------------------
/scripts/ant_data_collect.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from inverse_rl.algos.trpo import TRPO
4 | from inverse_rl.models.tf_util import get_session_config
5 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
6 | from sandbox.rocky.tf.envs.base import TfEnv
7 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
8 |
9 | from inverse_rl.envs.env_utils import CustomGymEnv
10 | from inverse_rl.utils.log_utils import rllab_logdir
11 | from inverse_rl.utils.hyper_sweep import run_sweep_parallel, run_sweep_serial
12 |
13 |
14 | def main(exp_name, ent_wt=1.0):
15 | tf.reset_default_graph()
16 | env = TfEnv(CustomGymEnv('CustomAnt-v0', record_video=False, record_log=False))
17 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32))
18 | with tf.Session(config=get_session_config()) as sess:
19 | algo = TRPO(
20 | env=env,
21 | sess=sess,
22 | policy=policy,
23 | n_itr=1500,
24 | batch_size=20000,
25 | max_path_length=500,
26 | discount=0.99,
27 | store_paths=True,
28 | entropy_weight=ent_wt,
29 | baseline=LinearFeatureBaseline(env_spec=env.spec),
30 | exp_name=exp_name,
31 | )
32 | with rllab_logdir(algo=algo, dirname='data/ant_data_collect/%s'%exp_name):
33 | algo.train()
34 |
35 | if __name__ == "__main__":
36 | params_dict = {
37 | 'ent_wt': [0.1]
38 | }
39 | run_sweep_parallel(main, params_dict, repeat=4)
40 |
--------------------------------------------------------------------------------
/scripts/ant_irl.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
4 | from sandbox.rocky.tf.envs.base import TfEnv
5 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
6 | from rllab.envs.gym_env import GymEnv
7 |
8 | from inverse_rl.envs.env_utils import CustomGymEnv
9 | from inverse_rl.algos.irl_trpo import IRLTRPO
10 | from inverse_rl.models.airl_state import *
11 | from inverse_rl.utils.log_utils import rllab_logdir, load_latest_experts, load_latest_experts_multiple_runs
12 | from inverse_rl.utils.hyper_sweep import run_sweep_parallel, run_sweep_serial
13 |
14 | def main(exp_name=None, fusion=False):
15 | env = TfEnv(CustomGymEnv('CustomAnt-v0', record_video=False, record_log=False))
16 |
17 | # load ~2 iterations worth of data from each forward RL experiment as demos
18 | experts = load_latest_experts_multiple_runs('data/ant_data_collect', n=2)
19 |
20 | irl_model = AIRL(env=env, expert_trajs=experts, state_only=True, fusion=fusion, max_itrs=10)
21 |
22 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32))
23 | algo = IRLTRPO(
24 | env=env,
25 | policy=policy,
26 | irl_model=irl_model,
27 | n_itr=1000,
28 | batch_size=10000,
29 | max_path_length=500,
30 | discount=0.99,
31 | store_paths=True,
32 | irl_model_wt=1.0,
33 | entropy_weight=0.1,
34 | zero_environment_reward=True,
35 | baseline=LinearFeatureBaseline(env_spec=env.spec),
36 | )
37 | with rllab_logdir(algo=algo, dirname='data/ant_state_irl/%s' % exp_name):
38 | with tf.Session():
39 | algo.train()
40 |
41 | if __name__ == "__main__":
42 | params_dict = {
43 | 'fusion': [True]
44 | }
45 | run_sweep_parallel(main, params_dict, repeat=3)
46 |
47 |
--------------------------------------------------------------------------------
/scripts/ant_transfer_disabled.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 |
4 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
5 | from sandbox.rocky.tf.envs.base import TfEnv
6 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
7 | from rllab.envs.gym_env import GymEnv
8 |
9 |
10 | from inverse_rl.algos.irl_trpo import IRLTRPO
11 | from inverse_rl.envs.env_utils import CustomGymEnv
12 | from inverse_rl.models.airl_state import *
13 | from inverse_rl.models.tf_util import load_prior_params
14 | from inverse_rl.utils.log_utils import rllab_logdir, load_latest_experts
15 | from inverse_rl.utils.hyper_sweep import run_sweep_parallel, run_sweep_serial
16 |
17 |
18 | DATA_DIR = 'data/ant_state_irl'
19 | def main(exp_name, params_folder=None):
20 | env = TfEnv(CustomGymEnv('DisabledAnt-v0', record_video=False, record_log=False))
21 |
22 | irl_itr = 100 # earlier IRL iterations overfit less; 100 seems to work well.
23 | params_file = os.path.join(DATA_DIR, '%s/itr_%d.pkl' % (params_folder, irl_itr))
24 | prior_params = load_prior_params(params_file)
25 |
26 | irl_model = AIRL(env=env, expert_trajs=None, state_only=True)
27 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32))
28 | algo = IRLTRPO(
29 | init_irl_params=prior_params,
30 | env=env,
31 | policy=policy,
32 | irl_model=irl_model,
33 | n_itr=1000,
34 | batch_size=10000,
35 | max_path_length=500,
36 | discount=0.99,
37 | store_paths=False,
38 | train_irl=False,
39 | irl_model_wt=1.0,
40 | entropy_weight=0.1,
41 | zero_environment_reward=True,
42 | baseline=LinearFeatureBaseline(env_spec=env.spec),
43 | log_params_folder=params_folder,
44 | log_experiment_name=exp_name,
45 | )
46 | with rllab_logdir(algo=algo, dirname='data/ant_transfer/%s'%exp_name):
47 | with tf.Session():
48 | algo.train()
49 |
50 | if __name__ == "__main__":
51 | import os
52 | params_folders = os.listdir(DATA_DIR)
53 | params_dict = {
54 | 'params_folder': params_folders,
55 | }
56 | run_sweep_parallel(main, params_dict, repeat=3)
57 |
58 |
--------------------------------------------------------------------------------
/scripts/pendulum_data_collect.py:
--------------------------------------------------------------------------------
1 | from sandbox.rocky.tf.algos.trpo import TRPO
2 | from sandbox.rocky.tf.envs.base import TfEnv
3 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
4 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
5 | from rllab.envs.gym_env import GymEnv
6 |
7 | from inverse_rl.utils.log_utils import rllab_logdir
8 |
9 | def main():
10 | env = TfEnv(GymEnv('Pendulum-v0', record_video=False, record_log=False))
11 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32))
12 | algo = TRPO(
13 | env=env,
14 | policy=policy,
15 | n_itr=200,
16 | batch_size=1000,
17 | max_path_length=100,
18 | discount=0.99,
19 | store_paths=True,
20 | baseline=LinearFeatureBaseline(env_spec=env.spec)
21 | )
22 |
23 | with rllab_logdir(algo=algo, dirname='data/pendulum'):
24 | algo.train()
25 |
26 | if __name__ == "__main__":
27 | main()
28 |
--------------------------------------------------------------------------------
/scripts/pendulum_gail.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
4 | from sandbox.rocky.tf.envs.base import TfEnv
5 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
6 | from rllab.envs.gym_env import GymEnv
7 |
8 |
9 | from inverse_rl.algos.irl_trpo import IRLTRPO
10 | from inverse_rl.models.imitation_learning import GAIL
11 | from inverse_rl.utils.log_utils import rllab_logdir, load_latest_experts
12 |
13 | def main():
14 | env = TfEnv(GymEnv('Pendulum-v0', record_video=False, record_log=False))
15 |
16 | experts = load_latest_experts('data/pendulum', n=5)
17 |
18 | irl_model = GAIL(env_spec=env.spec, expert_trajs=experts)
19 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32))
20 | algo = IRLTRPO(
21 | env=env,
22 | policy=policy,
23 | irl_model=irl_model,
24 | n_itr=200,
25 | batch_size=1000,
26 | max_path_length=100,
27 | discount=0.99,
28 | store_paths=True,
29 | discrim_train_itrs=50,
30 | irl_model_wt=1.0,
31 | entropy_weight=0.0, # GAIL does not use an entropy bonus; set above zero only to aid exploration
32 | zero_environment_reward=True,
33 | baseline=LinearFeatureBaseline(env_spec=env.spec)
34 | )
35 |
36 | with rllab_logdir(algo=algo, dirname='data/pendulum_gail'):
37 | with tf.Session():
38 | algo.train()
39 |
40 | if __name__ == "__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
/scripts/pendulum_irl.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from sandbox.rocky.tf.policies.gaussian_mlp_policy import GaussianMLPPolicy
4 | from sandbox.rocky.tf.envs.base import TfEnv
5 | from rllab.baselines.linear_feature_baseline import LinearFeatureBaseline
6 | from rllab.envs.gym_env import GymEnv
7 |
8 |
9 | from inverse_rl.algos.irl_trpo import IRLTRPO
10 | from inverse_rl.models.imitation_learning import AIRLStateAction
11 | from inverse_rl.utils.log_utils import rllab_logdir, load_latest_experts
12 |
13 | def main():
14 | env = TfEnv(GymEnv('Pendulum-v0', record_video=False, record_log=False))
15 |
16 | experts = load_latest_experts('data/pendulum', n=5)
17 |
18 | irl_model = AIRLStateAction(env_spec=env.spec, expert_trajs=experts)
19 | policy = GaussianMLPPolicy(name='policy', env_spec=env.spec, hidden_sizes=(32, 32))
20 | algo = IRLTRPO(
21 | env=env,
22 | policy=policy,
23 | irl_model=irl_model,
24 | n_itr=200,
25 | batch_size=1000,
26 | max_path_length=100,
27 | discount=0.99,
28 | store_paths=True,
29 | discrim_train_itrs=50,
30 | irl_model_wt=1.0,
31 | entropy_weight=0.1, # this should be 1.0 but 0.1 seems to work better
32 | zero_environment_reward=True,
33 | baseline=LinearFeatureBaseline(env_spec=env.spec)
34 | )
35 |
36 | with rllab_logdir(algo=algo, dirname='data/pendulum_gcl'):
37 | with tf.Session():
38 | algo.train()
39 |
40 | if __name__ == "__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
/tabular_maxent_irl/README.md:
--------------------------------------------------------------------------------
1 | # Tabular MaxEnt IRL
2 |
3 | This folder contains a self-contained implementation of tabular maximum entropy (MaxEnt) IRL.
4 |
5 | To run on a randomly generated MDP:
6 | ```
7 | python maxent_irl.py
8 | ```
9 |
10 |
11 |
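12 | The solver can also be called programmatically from within this folder. A minimal
13 | sketch, mirroring the `__main__` block of `maxent_irl.py` (hyperparameters are illustrative):
14 | 
15 | ```
16 | from q_iteration import q_iteration
17 | from simple_env import random_env
18 | from maxent_irl import tabular_maxent_irl, compute_visitation
19 | 
20 | env = random_env(16, 4, seed=1, terminate=False, t_sparsity=0.8)
21 | expert_q = q_iteration(env, K=150, ent_wt=1.0, gamma=0.9)
22 | expert_visits = compute_visitation(env, expert_q, ent_wt=1.0, T=5, discount=0.9)
23 | learned_rew, learned_q = tabular_maxent_irl(env, expert_visits, lr=0.01, num_itrs=1000, ent_wt=1.0, state_only=True, discount=0.9)
24 | ```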
--------------------------------------------------------------------------------
/tabular_maxent_irl/maxent_irl.py:
--------------------------------------------------------------------------------
1 | """
2 | This implements Maximum Entropy IRL using dynamic programming. This
3 |
4 | Simply call tabular_maxent_irl(env, expert_visitations)
5 | The expert visitations can be generated via the compute_visitations function on an expert q_function (exact),
6 | or using compute_visitations_demo on demos (approximate)
7 |
8 | """
9 | import numpy as np
10 | from utils import one_hot_to_flat, flat_to_one_hot
11 | from q_iteration import q_iteration, logsumexp, get_policy
12 | from utils import TrainingIterator
13 | from utils import gd_momentum_optimizer, adam_optimizer
14 |
15 |
16 | def compute_visitation(env, q_fn, ent_wt=1.0, T=50, discount=1.0):
17 | pol_probs = get_policy(q_fn, ent_wt=ent_wt)
18 |
19 | dim_obs = env.observation_space.flat_dim
20 | dim_act = env.action_space.flat_dim
21 | state_visitation = np.expand_dims(env.initial_state_distribution, axis=1)
22 | t_matrix = env.transition_matrix # S x A x S
23 | sa_visit_t = np.zeros((dim_obs, dim_act, T))
24 |
25 | for i in range(T):
26 | sa_visit = state_visitation * pol_probs
27 | sa_visit_t[:,:,i] = sa_visit #(discount**i) * sa_visit
28 | # marginalize over (s, a) to get the next-state distribution
29 | new_state_visitation = np.einsum('ij,ijk->k', sa_visit, t_matrix)
30 | state_visitation = np.expand_dims(new_state_visitation, axis=1)
31 | return np.sum(sa_visit_t, axis=2) / float(T)
32 |
33 |
34 | def compute_vistation_demos(env, demos):
35 | dim_obs = env.observation_space.flat_dim
36 | dim_act = env.action_space.flat_dim
37 | counts = np.zeros((dim_obs, dim_act))
38 |
39 | for demo in demos:
40 | obs = demo['observations']
41 | act = demo['actions']
42 | state_ids = one_hot_to_flat(obs)
43 | T = len(state_ids)
44 | for t in range(T):
45 | counts[state_ids[t], act[t]] += 1
46 | return counts / float(np.sum(counts))
47 |
48 |
49 | def sample_states(env, q_fn, visitation_probs, n_sample, ent_wt):
50 | dS, dA = visitation_probs.shape
51 | samples = np.random.choice(np.arange(dS*dA), size=n_sample, p=visitation_probs.reshape(dS*dA))
52 | policy = get_policy(q_fn, ent_wt=ent_wt)
53 | observations = samples // dA
54 | actions = samples % dA
55 | a_logprobs = np.log(policy[observations, actions])
56 |
57 | observations_next = []
58 | for i in range(n_sample):
59 | t_distr = env.tabular_trans_distr(observations[i], actions[i])
60 | next_state = flat_to_one_hot(np.random.choice(np.arange(len(t_distr)), p=t_distr), ndim=dS)
61 | observations_next.append(next_state)
62 | observations_next = np.array(observations_next)
63 |
64 | return {'observations': flat_to_one_hot(observations, ndim=dS),
65 | 'actions': flat_to_one_hot(actions, ndim=dA),
66 | 'a_logprobs': a_logprobs,
67 | 'observations_next': observations_next}
68 |
69 |
70 | def tabular_maxent_irl(env, demo_visitations, num_itrs=50, ent_wt=1.0, lr=1e-3, state_only=False,
71 | discount=0.99, T=5):
72 | dim_obs = env.observation_space.flat_dim
73 | dim_act = env.action_space.flat_dim
74 |
75 | # Initialize policy and reward function
76 | reward_fn = np.zeros((dim_obs, dim_act))
77 | q_rew = np.zeros((dim_obs, dim_act))
78 |
79 | update = adam_optimizer(lr)
80 |
81 | for it in TrainingIterator(num_itrs, heartbeat=1.0):
82 | q_itrs = 20 if it.itr>5 else 100
83 | ### compute policy in closed form
84 | q_rew = q_iteration(env, reward_matrix=reward_fn, ent_wt=ent_wt, warmstart_q=q_rew, K=q_itrs, gamma=discount)
85 |
86 | ### update reward
87 | # need to count how often the policy will visit a particular (s, a) pair
88 | pol_visitations = compute_visitation(env, q_rew, ent_wt=ent_wt, T=T, discount=discount)
89 |
90 | grad = -(demo_visitations - pol_visitations)  # descent direction: policy visitations minus demo visitations
91 | it.record('VisitationInfNormError', np.max(np.abs(grad)))
92 | if state_only:
93 | grad = np.sum(grad, axis=1, keepdims=True)
94 | reward_fn = update(reward_fn, grad)
95 |
96 | if it.heartbeat:
97 | print(it.itr_message())
98 | print('\tVisitationError:', it.pop_mean('VisitationInfNormError'))
99 | return reward_fn, q_rew
100 |
101 |
102 | if __name__ == "__main__":
103 | # test IRL
104 | from q_iteration import q_iteration
105 | from simple_env import random_env
106 | np.set_printoptions(suppress=True)
107 |
108 | # Environment parameters
109 | env = random_env(16, 4, seed=1, terminate=False, t_sparsity=0.8)
110 | dS = env.spec.observation_space.flat_dim
111 | dU = env.spec.action_space.flat_dim
112 | dO = 8
113 | ent_wt = 1.0
114 | discount = 0.9
115 | obs_matrix = np.random.randn(dS, dO)
116 |
117 | # Compute optimal policy for double checking
118 | true_q = q_iteration(env, K=150, ent_wt=ent_wt, gamma=discount)
119 | true_sa_visits = compute_visitation(env, true_q, ent_wt=ent_wt, T=5, discount=discount)
120 | expert_pol = get_policy(true_q, ent_wt=ent_wt)
121 |
122 | # Run MaxEnt IRL State-only
123 | learned_rew, learned_q = tabular_maxent_irl(env, true_sa_visits, lr=0.01, num_itrs=1000,
124 | ent_wt=ent_wt, state_only=True,
125 | discount=discount)
126 | learned_pol = get_policy(learned_q, ent_wt=ent_wt)
127 |
128 |
129 | # Normalize reward (with state_only=True, the reward is only recovered up to a constant offset)
130 | adjusted_rew = learned_rew - np.mean(learned_rew) + np.mean(env.rew_matrix)
131 |
132 | diff_rew = np.abs(env.rew_matrix - adjusted_rew)
133 | diff_pol = np.abs(expert_pol - learned_pol)
134 | print('----- Results State Only -----')
135 | print('InfNormRewError', np.max(diff_rew))
136 | print('InfNormPolicyError', np.max(diff_pol))
137 | print('AvgDiffRew', np.mean(diff_rew))
138 | print('AvgDiffPol', np.mean(diff_pol))
139 | print('True Reward', env.rew_matrix)
140 | print('Learned Reward', adjusted_rew)
141 |
142 |
143 | # Run MaxEnt IRL State-Action
144 | learned_rew, learned_q = tabular_maxent_irl(env, true_sa_visits, lr=0.01, num_itrs=1000,
145 | ent_wt=ent_wt, state_only=False,
146 | discount=discount)
147 | learned_pol = get_policy(learned_q, ent_wt=ent_wt)
148 |
149 | # Normalize reward (the recovered reward is only accurate up to a constant offset)
150 | adjusted_rew = learned_rew - np.mean(learned_rew) + np.mean(env.rew_matrix)
151 |
152 | diff_rew = np.abs(env.rew_matrix - adjusted_rew)
153 | diff_pol = np.abs(expert_pol - learned_pol)
154 | print('----- Results State-Action -----')
155 | print('InfNormRewError', np.max(diff_rew))
156 | print('InfNormPolicyError', np.max(diff_pol))
157 | print('AvgDiffRew', np.mean(diff_rew))
158 | print('AvgDiffPol', np.mean(diff_pol))
159 | print('True Reward', env.rew_matrix)
160 | print('Learned Reward', adjusted_rew)
161 |
--------------------------------------------------------------------------------
/tabular_maxent_irl/q_iteration.py:
--------------------------------------------------------------------------------
1 | """
2 | Use q-iteration to solve for an optimal policy
3 |
4 | Usage: q_iteration(env, gamma=discount factor, ent_wt= entropy bonus)
5 | """
6 | import numpy as np
7 | from scipy.special import logsumexp as sp_lse  # scipy.misc.logsumexp was removed in newer SciPy releases
8 |
9 | def softmax(q, alpha=1.0):
10 | q = (1.0/alpha)*q
11 | q = q-np.max(q)
12 | probs = np.exp(q)
13 | probs = probs/np.sum(probs)
14 | return probs
15 |
16 | def logsumexp(q, alpha=1.0, axis=1):
17 | return alpha*sp_lse((1.0/alpha)*q, axis=axis)
18 |
19 | def get_policy(q_fn, ent_wt=1.0):
20 | """
21 | Return the Boltzmann (soft-optimal) policy induced by a Q-function
22 | """
23 | v_rew = logsumexp(q_fn, alpha=ent_wt)
24 | adv_rew = q_fn - np.expand_dims(v_rew, axis=1)
25 | pol_probs = np.exp((1.0/ent_wt)*adv_rew)
26 | assert np.all(np.isclose(np.sum(pol_probs, axis=1), 1.0)), str(pol_probs)
27 | return pol_probs
28 |
29 | def q_iteration(env, reward_matrix=None, K=50, gamma=0.99, ent_wt=0.1, warmstart_q=None, policy=None):
30 | """
31 | Perform tabular soft Q-iteration using the backup
32 | V(s) = ent_wt * logsumexp(Q(s, .) / ent_wt),  Q(s, a) <- r(s, a) + gamma * E_{s'|s,a}[V(s')]
33 | If policy is given, this computes Q_pi rather than Q_star
34 | """
35 | dim_obs = env.observation_space.flat_dim
36 | dim_act = env.action_space.flat_dim
37 | if reward_matrix is None:
38 | reward_matrix = env.rew_matrix
39 | if warmstart_q is None:
40 | q_fn = np.zeros((dim_obs, dim_act))
41 | else:
42 | q_fn = warmstart_q
43 |
44 | t_matrix = env.transition_matrix
45 | for k in range(K):
46 | if policy is None:
47 | v_fn = logsumexp(q_fn, alpha=ent_wt)
48 | else:
49 | v_fn = np.sum((q_fn - np.log(policy))*policy, axis=1)
50 | new_q = reward_matrix + gamma*t_matrix.dot(v_fn)
51 | q_fn = new_q
52 | return q_fn
53 |
54 |
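55 | # Illustrative check (a sketch, not part of the library): with ent_wt=1 and the
56 | # Q-row [1., 1.], the soft value is 1 + log(2) and the policy is uniform:
57 | #   logsumexp(np.array([[1., 1.]]))    # -> approximately array([1.6931])
58 | #   get_policy(np.array([[1., 1.]]))   # -> array([[0.5, 0.5]])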
--------------------------------------------------------------------------------
/tabular_maxent_irl/simple_env.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | from matplotlib import cm
4 |
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from matplotlib.patches import Rectangle
8 | from rllab.envs.base import Env
9 | from rllab.misc import logger
10 | from rllab.spaces import Box
11 | from rllab.spaces import Discrete
12 |
13 |
14 | from utils import flat_to_one_hot, np_seed
15 |
16 | class DiscreteEnv(Env):
17 | def __init__(self, transition_matrix, reward, init_state, terminate_on_reward=False):
18 | super(DiscreteEnv, self).__init__()
19 | dX, dA, dXX = transition_matrix.shape
20 | self.nstates = dX
21 | self.nactions = dA
22 | self.transitions = transition_matrix
23 | self.init_state = init_state
24 | self.reward = reward
25 | self.terminate_on_reward = terminate_on_reward
26 |
27 | self.__observation_space = Box(0, 1, shape=(self.nstates,))
28 | #max_A = 0
29 | #for trans in self.transitions:
30 | # max_A = max(max_A, len(self.transitions[trans]))
31 | self.__action_space = Discrete(dA)
32 |
33 | def reset(self):
34 | self.cur_state = self.init_state
35 | obs = flat_to_one_hot(self.cur_state, ndim=self.nstates)
36 | return obs
37 |
38 | def step(self, a):
39 | transition_probs = self.transitions[self.cur_state, a]
40 | next_state = np.random.choice(np.arange(self.nstates), p=transition_probs)
41 | r = self.reward[self.cur_state, a]  # reward is indexed (state, action); see random_env and reward_fn
42 | self.cur_state = next_state
43 | obs = flat_to_one_hot(self.cur_state, ndim=self.nstates)
44 |
45 | done = False
46 | if self.terminate_on_reward and r>0:
47 | done = True
48 | return obs, r, done, {}
49 |
50 | def tabular_trans_distr(self, s, a):
51 | return self.transitions[s, a]
52 |
53 | def reward_fn(self, s, a):
54 | return self.reward[s, a]
55 |
56 | def log_diagnostics(self, paths):
57 | #Ntraj = len(paths)
58 | #acts = np.array([traj['actions'] for traj in paths])
59 | obs = np.array([np.sum(traj['observations'], axis=0) for traj in paths])
60 |
61 | state_count = np.sum(obs, axis=0)
62 | #state_count = np.mean(state_count, axis=0)
63 | state_freq = state_count/float(np.sum(state_count))
64 | for state in range(self.nstates):
65 | logger.record_tabular('AvgStateFreq%d'%state, state_freq[state])
66 |
67 | @property
68 | def transition_matrix(self):
69 | return self.transitions
70 |
71 | @property
72 | def rew_matrix(self):
73 | return self.reward
74 |
75 | @property
76 | def initial_state_distribution(self):
77 | return flat_to_one_hot(self.init_state, ndim=self.nstates)
78 |
79 | @property
80 | def action_space(self):
81 | return self.__action_space
82 |
83 | @property
84 | def observation_space(self):
85 | return self.__observation_space
86 |
87 |
88 | def random_env(Nstates, Nact, seed=None, terminate=False, t_sparsity=0.75):
89 | assert Nstates >= 2
90 | if seed is None:
91 | seed = 0
92 | reward_state=0
93 | start_state=1
94 | with np_seed(seed):
95 | transition_matrix = np.random.rand(Nstates, Nact, Nstates)
96 | transition_matrix = np.exp(transition_matrix)
97 | for s in range(Nstates):
98 | for a in range(Nact):
99 | zero_idxs = np.random.randint(0, Nstates, size=int(Nstates*t_sparsity))
100 | transition_matrix[s, a, zero_idxs] = 0.0
101 |
102 | transition_matrix = transition_matrix/np.sum(transition_matrix, axis=2, keepdims=True)
103 | reward = np.zeros((Nstates, Nact))
104 | reward[reward_state, :] = 1.0
105 | #reward = np.random.randn(Nstates,1 ) + reward
106 |
107 | stable_action = seed % Nact #np.random.randint(0, Nact)
108 | transition_matrix[reward_state, stable_action] = np.zeros(Nstates)
109 | transition_matrix[reward_state, stable_action, reward_state] = 1
110 | return DiscreteEnv(transition_matrix, reward=reward, init_state=start_state, terminate_on_reward=terminate)
111 |
112 |
113 | if __name__ == '__main__':
114 | env = random_env(5, 2, seed=0)
115 | print(env.transitions)
116 | print(env.transitions[0,0])
117 | print(env.transitions[0,1])
118 | env.reset()
119 | for _ in range(100):
120 | print(env.step(env.action_space.sample()))
121 |
122 |
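123 | # Conventions assumed by maxent_irl.py and q_iteration.py: transition_matrix has
124 | # shape (S, A, S), with transition_matrix[s, a] a distribution over next states,
125 | # and the reward built by random_env has shape (S, A), indexed by (state, action).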
--------------------------------------------------------------------------------
/tabular_maxent_irl/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import scipy as sp
4 | import scipy.stats
5 | import contextlib
6 |
7 |
8 | def flat_to_one_hot(val, ndim):
9 | """
10 |
11 | >>> flat_to_one_hot(2, ndim=4)
12 | array([ 0., 0., 1., 0.])
13 | >>> flat_to_one_hot(4, ndim=5)
14 | array([ 0., 0., 0., 0., 1.])
15 | >>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5)
16 | array([[ 0., 0., 1., 0., 0.],
17 | [ 0., 0., 0., 0., 1.],
18 | [ 0., 0., 0., 1., 0.]])
19 | """
20 | shape = np.array(val).shape
21 | v = np.zeros(shape + (ndim,))
22 | if len(shape) == 1:
23 | v[np.arange(shape[0]), val] = 1.0
24 | else:
25 | v[val] = 1.0
26 | return v
27 |
28 | def one_hot_to_flat(val):
29 | """
30 | >>> one_hot_to_flat(np.array([0,0,0,0,1]))
31 | 4
32 | >>> one_hot_to_flat(np.array([0,0,1,0]))
33 | 2
34 | >>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]]))
35 | array([2, 0, 1])
36 | """
37 | idxs = np.array(np.where(val == 1.0))[-1]
38 | if len(val.shape) == 1:
39 | return int(idxs)
40 | return idxs
41 |
42 |
43 | def flatten_list(lol):
44 | return [ a for b in lol for a in b ]
45 |
46 | class TrainingIterator(object):
47 | def __init__(self, itrs, heartbeat=float('inf')):
48 | self.itrs = itrs
49 | self.heartbeat_time = heartbeat
50 | self.__vals = {}
51 |
52 | def random_idx(self, N, size):
53 | return np.random.randint(0, N, size=size)
54 |
55 | @property
56 | def itr(self):
57 | return self.__itr
58 |
59 | @property
60 | def heartbeat(self):
61 | return self.__heartbeat
62 |
63 | @property
64 | def elapsed(self):
65 | assert self.heartbeat, 'elapsed is only valid when heartbeat=True'
66 | return self.__elapsed
67 |
68 | def itr_message(self):
69 | return '==> Itr %d/%d (elapsed:%.2f)' % (self.itr+1, self.itrs, self.elapsed)
70 |
71 | def record(self, key, value):
72 | if key in self.__vals:
73 | self.__vals[key].append(value)
74 | else:
75 | self.__vals[key] = [value]
76 |
77 | def pop(self, key):
78 | # dict.pop with a default avoids a KeyError if the key was never recorded
79 | vals = self.__vals.pop(key, [])
80 | return vals
81 |
82 | def pop_mean(self, key):
83 | return np.mean(self.pop(key))
84 |
85 | def __iter__(self):
86 | prev_time = time.time()
87 | self.__heartbeat = False
88 | for i in range(self.itrs):
89 | self.__itr = i
90 | cur_time = time.time()
91 | if (cur_time-prev_time) > self.heartbeat_time or i==(self.itrs-1):
92 | self.__heartbeat = True
93 | self.__elapsed = cur_time-prev_time
94 | prev_time = cur_time
95 | yield self
96 | self.__heartbeat = False
97 |
98 |
99 | def gd_optimizer(lr, lr_sched=None):
100 | if lr_sched is None:
101 | lr_sched = {}
102 |
103 | itr = 0
104 | def update(x, grad):
105 | nonlocal itr, lr
106 | if itr in lr_sched:
107 | lr *= lr_sched[itr]
108 | new_x = x - lr * grad
109 | itr += 1
110 | return new_x
111 | return update
112 |
113 |
114 | def gd_momentum_optimizer(lr, momentum=0.9, lr_sched=None):
115 | if lr_sched is None:
116 | lr_sched = {}
117 |
118 | itr = 0
119 | prev_grad = None
120 | def update(x, grad):
121 | nonlocal itr, lr, prev_grad
122 | if itr in lr_sched:
123 | lr *= lr_sched[itr]
124 |
125 | if prev_grad is None:
126 | grad = grad
127 | else:
128 | grad = grad + momentum * prev_grad
129 | new_x = x - lr * grad
130 | prev_grad = grad
131 | itr += 1
132 | return new_x
133 | return update
134 |
135 |
136 | def adam_optimizer(lr, beta1=0.9, beta2=0.999, eps=1e-8):
137 | itr = 0
138 | pm = None
139 | pv = None
140 | def update(x, grad):
141 | nonlocal itr, lr, pm, pv
142 | if pm is None:
143 | pm = np.zeros_like(grad)
144 | pv = np.zeros_like(grad)
145 |
146 | pm = beta1 * pm + (1-beta1)*grad
147 | pv = beta2 * pv + (1-beta2)*(grad*grad)
148 | mhat = pm/(1-beta1**(itr+1))
149 | vhat = pv/(1-beta2**(itr+1))
150 | update_vec = mhat / (np.sqrt(vhat)+eps)
151 | new_x = x - lr * update_vec
152 | itr += 1
153 | return new_x
154 | return update
155 |
156 |
157 | @contextlib.contextmanager
158 | def np_seed(seed):
159 | """ A context for np random seeds """
160 | if seed is None:
161 | yield
162 | else:
163 | old_state = np.random.get_state()
164 | np.random.seed(seed)
165 | yield
166 | np.random.set_state(old_state)
167 |
168 |
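169 | # Example usage of the helpers above, mirroring tabular_maxent_irl in maxent_irl.py
170 | # (`x` and `grad` stand in for the caller's parameters and gradient):
171 | #   update = adam_optimizer(lr=1e-2)
172 | #   for it in TrainingIterator(100, heartbeat=1.0):
173 | #       x = update(x, grad)
174 | #       if it.heartbeat:
175 | #           print(it.itr_message())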
--------------------------------------------------------------------------------