├── .gitignore
├── README.md
├── baselines
│   ├── LICENSE
│   ├── __init__.py
│   └── common
│       ├── __init__.py
│       ├── math_util.py
│       ├── misc_util.py
│       ├── mpi_adam.py
│       ├── mpi_aux_update.py
│       ├── mpi_moments.py
│       ├── mpi_running_mean_std.py
│       ├── mpi_update.py
│       └── tf_util.py
├── environment.yml
├── envs
│   ├── __init__.py
│   ├── goal_env_ext
│   │   ├── __init__.py
│   │   ├── assets
│   │   │   ├── LICENSE.md
│   │   │   ├── dm_control
│   │   │   │   └── manipulator.xml
│   │   │   ├── fetch
│   │   │   │   ├── pick_and_place.xml
│   │   │   │   ├── push.xml
│   │   │   │   ├── reach.xml
│   │   │   │   ├── robot.xml
│   │   │   │   ├── shared.xml
│   │   │   │   └── slide.xml
│   │   │   ├── hand
│   │   │   │   ├── manipulate_block.xml
│   │   │   │   ├── manipulate_egg.xml
│   │   │   │   ├── manipulate_pen.xml
│   │   │   │   ├── reach.xml
│   │   │   │   ├── robot.xml
│   │   │   │   ├── shared.xml
│   │   │   │   └── shared_asset.xml
│   │   │   ├── reacher
│   │   │   │   └── reacher.xml
│   │   │   ├── stls
│   │   │   │   ├── .get
│   │   │   │   ├── fetch
│   │   │   │   │   ├── base_link_collision.stl
│   │   │   │   │   ├── bellows_link_collision.stl
│   │   │   │   │   ├── elbow_flex_link_collision.stl
│   │   │   │   │   ├── estop_link.stl
│   │   │   │   │   ├── forearm_roll_link_collision.stl
│   │   │   │   │   ├── gripper_link.stl
│   │   │   │   │   ├── head_pan_link_collision.stl
│   │   │   │   │   ├── head_tilt_link_collision.stl
│   │   │   │   │   ├── l_wheel_link_collision.stl
│   │   │   │   │   ├── laser_link.stl
│   │   │   │   │   ├── r_wheel_link_collision.stl
│   │   │   │   │   ├── shoulder_lift_link_collision.stl
│   │   │   │   │   ├── shoulder_pan_link_collision.stl
│   │   │   │   │   ├── torso_fixed_link.stl
│   │   │   │   │   ├── torso_lift_link_collision.stl
│   │   │   │   │   ├── upperarm_roll_link_collision.stl
│   │   │   │   │   ├── wrist_flex_link_collision.stl
│   │   │   │   │   └── wrist_roll_link_collision.stl
│   │   │   │   └── hand
│   │   │   │       ├── F1.stl
│   │   │   │       ├── F2.stl
│   │   │   │       ├── F3.stl
│   │   │   │       ├── TH1_z.stl
│   │   │   │       ├── TH2_z.stl
│   │   │   │       ├── TH3_z.stl
│   │   │   │       ├── forearm_electric.stl
│   │   │   │       ├── forearm_electric_cvx.stl
│   │   │   │       ├── knuckle.stl
│   │   │   │       ├── lfmetacarpal.stl
│   │   │   │       ├── palm.stl
│   │   │   │       └── wrist.stl
│   │   │   └── textures
│   │   │       ├── block.png
│   │   │       └── block_hidden.png
│   │   ├── dm_control
│   │   │   ├── __init__.py
│   │   │   ├── common
│   │   │   │   ├── __init__.py
│   │   │   │   ├── materials.xml
│   │   │   │   ├── skybox.xml
│   │   │   │   └── visual.xml
│   │   │   ├── dm_control_env.py
│   │   │   ├── finger.py
│   │   │   ├── finger.xml
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   ├── parse_amc.py
│   │   │   │   ├── parse_amc_test.py
│   │   │   │   ├── randomizers.py
│   │   │   │   └── randomizers_test.py
│   │   │   └── wrappers
│   │   │       ├── __init__.py
│   │   │       ├── action_noise.py
│   │   │       ├── action_noise_test.py
│   │   │       ├── pixels.py
│   │   │       └── pixels_test.py
│   │   ├── fetch
│   │   │   ├── __init__.py
│   │   │   ├── fetch_env.py
│   │   │   ├── pick_and_place.py
│   │   │   ├── push.py
│   │   │   ├── reach.py
│   │   │   └── slide.py
│   │   ├── goal_env_ext.py
│   │   └── hand
│   │       ├── __init__.py
│   │       ├── hand_env.py
│   │       └── reach.py
│   ├── rotations.py
│   └── utils.py
├── olaux
│   ├── cnn_actor_critic.py
│   ├── config.py
│   ├── ddpg.py
│   ├── her.py
│   ├── logger.py
│   ├── main.py
│   ├── normalizer.py
│   ├── quadratic.py
│   ├── replay_buffer.py
│   ├── rollout.py
│   ├── test_tf.py
│   └── utils.py
└── results
    ├── VisualFetchReach.png
    ├── VisualFinger.png
    ├── VisualHandReach.png
    ├── mujoco.gif
    └── quadratic.gif
/.gitignore: -------------------------------------------------------------------------------- 1 | *.sh 2 | data 3 | enable_cuda.sh 4 | *.img 5 | *.swp 6 | *.pyc 7 | *.pkl 8 | *.py~ 9 | .pytest_cache 10 | .DS_Store 11 | .idea 12 | 13 | .ipynb_checkpoints 14 | ghostdriver.log 15 | 16 | htmlcov 17 | 18 | *.egg-info 19 | .cache 20 | 21 | MUJOCO_LOG.TXT 22 | 23 | 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adaptive Auxiliary Task Weighting for Reinforcement Learning
2 | 3 | Xingyu Lin*, Harjatin Baweja*, George Kantor, David Held 4 | 5 | (* indicates equal contributions) 6 | 7 | **NeurIPS 2019** [[paper]](https://papers.nips.cc/paper/8724-adaptive-auxiliary-task-weighting-for-reinforcement-learning) 8 | 9 | ## Demo 10 | #### Quadratic Optimization 11 | ![](results/quadratic.gif) 12 | #### MuJoCo 13 | The animation below shows the policy learned with OL-AUX on the VisualHandReach task, along with some of the auxiliary tasks. From left to right: observation, goal, auto-encoder reconstruction, predicted optical flow, and egomotion-transformed image. 14 | 15 | ![](results/mujoco.gif) 16 | 17 | ## Installation 18 | 1. Install [MuJoCo](http://www.mujoco.org/) and [Conda](https://docs.conda.io/en/latest/miniconda.html) if you have not already. 19 | 2. Install the prerequisites 20 | 21 | ``` 22 | sudo apt-get update && sudo apt-get install libopenmpi-dev python3-dev zlib1g-dev 23 | ``` 24 | 25 | 3. Create the conda environment 26 | 27 | ``` 28 | conda env create --file environment.yml 29 | ``` 30 | 31 | 4. Check the OpenMPI version by running `mpirun --version`. The code is tested under version `2.0.2`. Make sure MPI is correctly installed by running `mpirun -np 2 python baselines/common/mpi_running_mean_std.py` 32 | 33 | ## Instructions 34 | First add the working directory to `PYTHONPATH`, e.g. `export PYTHONPATH=$(pwd):$PYTHONPATH`. 35 | 36 | ### Quadratic example 37 | ``` 38 | python olaux/quadratic.py 39 | ``` 40 | 41 | The resulting video will be saved to `results/quadratic.mp4`. 42 | 43 | ### MuJoCo experiments 44 | To reproduce the original results, make sure you have at least 8 CPU cores and ~60GB of memory. Run 45 | 46 | ``` 47 | . activate olaux 48 | export MUJOCO_PY_FORCE_CPU=True # Force CPU rendering 49 | export MUJOCO_GL=osmesa 50 | mkdir data && python olaux/main.py 51 | ``` 52 | Refer to the documentation in `olaux/main.py` for the available arguments. By default it runs OL-AUX on FetchReach and the data will be saved to the `data` directory created above. 53 | ## Cite Our Paper 54 | If you find this codebase useful in your research, please consider citing: 55 | 56 | ``` 57 | @incollection{NIPS2019_8724, 58 | title = {Adaptive Auxiliary Task Weighting for Reinforcement Learning}, 59 | author = {Lin, Xingyu and Baweja, Harjatin and Kantor, George and Held, David}, 60 | booktitle = {Advances in Neural Information Processing Systems 32}, 61 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett}, 62 | pages = {4772--4783}, 63 | year = {2019}, 64 | publisher = {Curran Associates, Inc.}, 65 | url = {http://papers.nips.cc/paper/8724-adaptive-auxiliary-task-weighting-for-reinforcement-learning.pdf} 66 | } 67 | ``` 68 | 69 | ## Caveat 70 | Unfortunately, due to a series of accidents with our computing cluster and local machines, the original code used to create the plots in the paper was lost. Furthermore, due to a complicated immigration situation, one of the authors who contributed significantly to the original code cannot legally work on anything related to this paper at the moment. 71 | 72 | This repository is a re-implementation of OL-AUX by one of the authors that mostly reproduces the original results. Please refer to the plots in `./results/` for the reproduced results. However, due to the issues mentioned above, many hyper-parameters and implementation details in this repository may differ from the original code used for the paper. We will potentially release another version of the code once the other author is able to legally work again. 73 | 74 | We apologize for these issues.
Please contact Xingyu Lin (xlin3@andrew.cmu.edu) and David Held (dheld@andrew.cmu.edu) if you have further questions. 75 | 76 | 77 | ## References 78 | The code here is based on the DDPG and HER implementation in [OpenAI Baselines](https://github.com/openai/baselines) 79 | -------------------------------------------------------------------------------- /baselines/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/baselines/__init__.py -------------------------------------------------------------------------------- /baselines/common/__init__.py: -------------------------------------------------------------------------------- 1 | from baselines.common.math_util import * 2 | from baselines.common.misc_util import * 3 | -------------------------------------------------------------------------------- /baselines/common/math_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.signal 3 | 4 | 5 | def discount(x, gamma): 6 | """ 7 | computes discounted sums along 0th dimension of x. 8 | 9 | inputs 10 | ------ 11 | x: ndarray 12 | gamma: float 13 | 14 | outputs 15 | ------- 16 | y: ndarray with same shape as x, satisfying 17 | 18 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k], 19 | where k = len(x) - t - 1 20 | 21 | """ 22 | assert x.ndim >= 1 23 | return scipy.signal.lfilter([1],[1,-gamma],x[::-1], axis=0)[::-1] 24 | 25 | def explained_variance(ypred,y): 26 | """ 27 | Computes fraction of variance that ypred explains about y. 
28 | Returns 1 - Var[y-ypred] / Var[y] 29 | 30 | interpretation: 31 | ev=0 => might as well have predicted zero 32 | ev=1 => perfect prediction 33 | ev<0 => worse than just predicting zero 34 | 35 | """ 36 | assert y.ndim == 1 and ypred.ndim == 1 37 | vary = np.var(y) 38 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary 39 | 40 | def explained_variance_2d(ypred, y): 41 | assert y.ndim == 2 and ypred.ndim == 2 42 | vary = np.var(y, axis=0) 43 | out = 1 - np.var(y-ypred)/vary 44 | out[vary < 1e-10] = 0 45 | return out 46 | 47 | def ncc(ypred, y): 48 | return np.corrcoef(ypred, y)[1,0] 49 | 50 | def flatten_arrays(arrs): 51 | return np.concatenate([arr.flat for arr in arrs]) 52 | 53 | def unflatten_vector(vec, shapes): 54 | i=0 55 | arrs = [] 56 | for shape in shapes: 57 | size = np.prod(shape) 58 | arr = vec[i:i+size].reshape(shape) 59 | arrs.append(arr) 60 | i += size 61 | return arrs 62 | 63 | def discount_with_boundaries(X, New, gamma): 64 | """ 65 | X: 2d array of floats, time x features 66 | New: 2d array of bools, indicating when a new episode has started 67 | """ 68 | Y = np.zeros_like(X) 69 | T = X.shape[0] 70 | Y[T-1] = X[T-1] 71 | for t in range(T-2, -1, -1): 72 | Y[t] = X[t] + gamma * Y[t+1] * (1 - New[t+1]) 73 | return Y 74 | 75 | def test_discount_with_boundaries(): 76 | gamma=0.9 77 | x = np.array([1.0, 2.0, 3.0, 4.0], 'float32') 78 | starts = [1.0, 0.0, 0.0, 1.0] 79 | y = discount_with_boundaries(x, starts, gamma) 80 | assert np.allclose(y, [ 81 | 1 + gamma * 2 + gamma**2 * 3, 82 | 2 + gamma * 3, 83 | 3, 84 | 4 85 | ]) -------------------------------------------------------------------------------- /baselines/common/misc_util.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import os 4 | import pickle 5 | import random 6 | import tempfile 7 | import zipfile 8 | 9 | 10 | def zipsame(*seqs): 11 | L = len(seqs[0]) 12 | assert all(len(seq) == L for seq in seqs[1:]) 13 | return zip(*seqs) 14 | 15 | 16 | def unpack(seq, sizes): 17 | """ 18 | Unpack 'seq' into a sequence of lists, with lengths specified by 'sizes'. 19 | None = just one bare element, not a list 20 | 21 | Example: 22 | unpack([1,2,3,4,5,6], [3,None,2]) -> ([1,2,3], 4, [5,6]) 23 | """ 24 | seq = list(seq) 25 | it = iter(seq) 26 | assert sum(1 if s is None else s for s in sizes) == len(seq), "Trying to unpack %s into %s" % (seq, sizes) 27 | for size in sizes: 28 | if size is None: 29 | yield it.__next__() 30 | else: 31 | li = [] 32 | for _ in range(size): 33 | li.append(it.__next__()) 34 | yield li 35 | 36 | 37 | class EzPickle(object): 38 | """Objects that are pickled and unpickled via their constructor 39 | arguments. 40 | 41 | Example usage: 42 | 43 | class Dog(Animal, EzPickle): 44 | def __init__(self, furcolor, tailkind="bushy"): 45 | Animal.__init__() 46 | EzPickle.__init__(furcolor, tailkind) 47 | ... 48 | 49 | When this object is unpickled, a new Dog will be constructed by passing the provided 50 | furcolor and tailkind into the constructor. However, philosophers are still not sure 51 | whether it is still the same dog. 52 | 53 | This is generally needed only for environments which wrap C/C++ code, such as MuJoCo 54 | and Atari. 
55 | """ 56 | 57 | def __init__(self, *args, **kwargs): 58 | self._ezpickle_args = args 59 | self._ezpickle_kwargs = kwargs 60 | 61 | def __getstate__(self): 62 | return {"_ezpickle_args": self._ezpickle_args, "_ezpickle_kwargs": self._ezpickle_kwargs} 63 | 64 | def __setstate__(self, d): 65 | out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"]) 66 | self.__dict__.update(out.__dict__) 67 | 68 | 69 | def set_global_seeds(i): 70 | try: 71 | import tensorflow as tf 72 | except ImportError: 73 | pass 74 | else: 75 | tf.set_random_seed(i) 76 | np.random.seed(i) 77 | random.seed(i) 78 | 79 | 80 | def pretty_eta(seconds_left): 81 | """Print the number of seconds in human readable format. 82 | 83 | Examples: 84 | 2 days 85 | 2 hours and 37 minutes 86 | less than a minute 87 | 88 | Paramters 89 | --------- 90 | seconds_left: int 91 | Number of seconds to be converted to the ETA 92 | Returns 93 | ------- 94 | eta: str 95 | String representing the pretty ETA. 96 | """ 97 | minutes_left = seconds_left // 60 98 | seconds_left %= 60 99 | hours_left = minutes_left // 60 100 | minutes_left %= 60 101 | days_left = hours_left // 24 102 | hours_left %= 24 103 | 104 | def helper(cnt, name): 105 | return "{} {}{}".format(str(cnt), name, ('s' if cnt > 1 else '')) 106 | 107 | if days_left > 0: 108 | msg = helper(days_left, 'day') 109 | if hours_left > 0: 110 | msg += ' and ' + helper(hours_left, 'hour') 111 | return msg 112 | if hours_left > 0: 113 | msg = helper(hours_left, 'hour') 114 | if minutes_left > 0: 115 | msg += ' and ' + helper(minutes_left, 'minute') 116 | return msg 117 | if minutes_left > 0: 118 | return helper(minutes_left, 'minute') 119 | return 'less than a minute' 120 | 121 | 122 | class RunningAvg(object): 123 | def __init__(self, gamma, init_value=None): 124 | """Keep a running estimate of a quantity. This is a bit like mean 125 | but more sensitive to recent changes. 126 | 127 | Parameters 128 | ---------- 129 | gamma: float 130 | Must be between 0 and 1, where 0 is the most sensitive to recent 131 | changes. 132 | init_value: float or None 133 | Initial value of the estimate. If None, it will be set on the first update. 134 | """ 135 | self._value = init_value 136 | self._gamma = gamma 137 | 138 | def update(self, new_val): 139 | """Update the estimate. 140 | 141 | Parameters 142 | ---------- 143 | new_val: float 144 | new observated value of estimated quantity. 145 | """ 146 | if self._value is None: 147 | self._value = new_val 148 | else: 149 | self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val 150 | 151 | def __float__(self): 152 | """Get the current estimate""" 153 | return self._value 154 | 155 | def boolean_flag(parser, name, default=False, help=None): 156 | """Add a boolean flag to argparse parser. 
157 | 158 | Parameters 159 | ---------- 160 | parser: argparse.Parser 161 | parser to add the flag to 162 | name: str 163 | -- will enable the flag, while --no- will disable it 164 | default: bool or None 165 | default value of the flag 166 | help: str 167 | help string for the flag 168 | """ 169 | dest = name.replace('-', '_') 170 | parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help) 171 | parser.add_argument("--no-" + name, action="store_false", dest=dest) 172 | 173 | 174 | def get_wrapper_by_name(env, classname): 175 | """Given an a gym environment possibly wrapped multiple times, returns a wrapper 176 | of class named classname or raises ValueError if no such wrapper was applied 177 | 178 | Parameters 179 | ---------- 180 | env: gym.Env of gym.Wrapper 181 | gym environment 182 | classname: str 183 | name of the wrapper 184 | 185 | Returns 186 | ------- 187 | wrapper: gym.Wrapper 188 | wrapper named classname 189 | """ 190 | currentenv = env 191 | while True: 192 | if classname == currentenv.class_name(): 193 | return currentenv 194 | elif isinstance(currentenv, gym.Wrapper): 195 | currentenv = currentenv.env 196 | else: 197 | raise ValueError("Couldn't find wrapper named %s" % classname) 198 | 199 | 200 | def relatively_safe_pickle_dump(obj, path, compression=False): 201 | """This is just like regular pickle dump, except from the fact that failure cases are 202 | different: 203 | 204 | - It's never possible that we end up with a pickle in corrupted state. 205 | - If a there was a different file at the path, that file will remain unchanged in the 206 | even of failure (provided that filesystem rename is atomic). 207 | - it is sometimes possible that we end up with useless temp file which needs to be 208 | deleted manually (it will be removed automatically on the next function call) 209 | 210 | The indended use case is periodic checkpoints of experiment state, such that we never 211 | corrupt previous checkpoints if the current one fails. 212 | 213 | Parameters 214 | ---------- 215 | obj: object 216 | object to pickle 217 | path: str 218 | path to the output file 219 | compression: bool 220 | if true pickle will be compressed 221 | """ 222 | temp_storage = path + ".relatively_safe" 223 | if compression: 224 | # Using gzip here would be simpler, but the size is limited to 2GB 225 | with tempfile.NamedTemporaryFile() as uncompressed_file: 226 | pickle.dump(obj, uncompressed_file) 227 | uncompressed_file.file.flush() 228 | with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip: 229 | myzip.write(uncompressed_file.name, "data") 230 | else: 231 | with open(temp_storage, "wb") as f: 232 | pickle.dump(obj, f) 233 | os.rename(temp_storage, path) 234 | 235 | 236 | def pickle_load(path, compression=False): 237 | """Unpickle a possible compressed pickle. 238 | 239 | Parameters 240 | ---------- 241 | path: str 242 | path to the output file 243 | compression: bool 244 | if true assumes that pickle was compressed when created and attempts decompression. 
245 | 246 | Returns 247 | ------- 248 | obj: object 249 | the unpickled object 250 | """ 251 | 252 | if compression: 253 | with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as myzip: 254 | with myzip.open("data") as f: 255 | return pickle.load(f) 256 | else: 257 | with open(path, "rb") as f: 258 | return pickle.load(f) 259 | -------------------------------------------------------------------------------- /baselines/common/mpi_adam.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import baselines.common.tf_util as U 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | 7 | class MpiAdam(object): 8 | def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None): 9 | self.var_list = var_list 10 | self.beta1 = beta1 11 | self.beta2 = beta2 12 | self.epsilon = epsilon 13 | self.scale_grad_by_procs = scale_grad_by_procs 14 | size = sum(U.numel(v) for v in var_list) 15 | self.m = np.zeros(size, 'float32') 16 | self.v = np.zeros(size, 'float32') 17 | self.t = 0 18 | self.setfromflat = U.SetFromFlat(var_list) 19 | self.getflat = U.GetFlat(var_list) 20 | self.comm = MPI.COMM_WORLD if comm is None else comm 21 | 22 | def update(self, localg, stepsize): 23 | if self.t % 100 == 0: 24 | self.check_synced() 25 | localg = localg.astype('float32') 26 | globalg = np.zeros_like(localg) 27 | self.comm.Allreduce(localg, globalg, op=MPI.SUM) 28 | if self.scale_grad_by_procs: 29 | globalg /= self.comm.Get_size() 30 | 31 | self.t += 1 32 | a = stepsize * np.sqrt(1 - self.beta2 ** self.t) / (1 - self.beta1 ** self.t) 33 | self.m = self.beta1 * self.m + (1 - self.beta1) * globalg 34 | self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg) 35 | step = (- a) * self.m / (np.sqrt(self.v) + self.epsilon) 36 | self.setfromflat(self.getflat() + step) 37 | 38 | def sync(self): 39 | theta = self.getflat() 40 | self.comm.Bcast(theta, root=0) 41 | self.setfromflat(theta) 42 | 43 | def check_synced(self): 44 | if self.comm.Get_rank() == 0: # this is root 45 | theta = self.getflat() 46 | self.comm.Bcast(theta, root=0) 47 | else: 48 | thetalocal = self.getflat() 49 | thetaroot = np.empty_like(thetalocal) 50 | self.comm.Bcast(thetaroot, root=0) 51 | assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal) 52 | 53 | 54 | @U.in_session 55 | def test_MpiAdam(): 56 | np.random.seed(0) 57 | tf.set_random_seed(0) 58 | 59 | a = tf.Variable(np.random.randn(3).astype('float32')) 60 | b = tf.Variable(np.random.randn(2, 5).astype('float32')) 61 | loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b)) 62 | 63 | stepsize = 1e-2 64 | update_op = tf.train.AdamOptimizer(stepsize).minimize(loss) 65 | do_update = U.function([], loss, updates=[update_op]) 66 | 67 | tf.get_default_session().run(tf.global_variables_initializer()) 68 | for i in range(10): 69 | print(i, do_update()) 70 | 71 | tf.set_random_seed(0) 72 | tf.get_default_session().run(tf.global_variables_initializer()) 73 | 74 | var_list = [a, b] 75 | lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op]) 76 | adam = MpiAdam(var_list) 77 | 78 | for i in range(10): 79 | l, g = lossandgrad() 80 | adam.update(g, stepsize) 81 | print(i, l) 82 | -------------------------------------------------------------------------------- /baselines/common/mpi_aux_update.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import baselines.common.tf_util 
as U 3 | import numpy as np 4 | 5 | 6 | class MpiAuxUpdate(object): 7 | def __init__(self, var_list, *, scale_grad_by_procs=True, comm=None): 8 | self.var_list = var_list 9 | self.scale_grad_by_procs = scale_grad_by_procs 10 | self.t = 0 11 | self.setfromflat = U.SetFromFlat(var_list) 12 | self.getflat = U.GetFlat(var_list) 13 | self.comm = MPI.COMM_WORLD if comm is None else comm 14 | 15 | def flatten(self, main_grad, aux_grads): 16 | flattened = np.concatenate([main_grad.reshape(1, -1), aux_grads], axis=0).flatten() 17 | N = len(aux_grads) + 1 18 | return flattened, N 19 | 20 | def unflatten(self, global_grad, N): 21 | assert len(global_grad) % N ==0 22 | grad = np.array(global_grad).reshape([N, -1]) 23 | main_grad = grad[0] 24 | aux_grads = grad[1:] 25 | return main_grad, aux_grads 26 | 27 | def get_syncd_grad(self, main_grad, aux_grads): 28 | if self.t % 100 == 0: 29 | self.check_synced() 30 | local_grad, N = self.flatten(main_grad, aux_grads) 31 | local_grad = local_grad.astype('float32') 32 | global_grad = np.zeros_like(local_grad) 33 | self.comm.Allreduce(local_grad, global_grad, op=MPI.SUM) 34 | global_grad /= self.comm.Get_size() 35 | self.t += 1 36 | return self.unflatten(global_grad, N) 37 | 38 | def update(self, globalg): 39 | self.setfromflat(np.clip(self.getflat() + globalg, 0., None)) 40 | 41 | def set(self, local_param): 42 | if self.t % 100 == 0: 43 | self.check_synced() 44 | local_param = local_param.astype('float32') 45 | global_param = np.zeros_like(local_param) 46 | self.comm.Allreduce(local_param, global_param, op=MPI.SUM) 47 | if self.scale_grad_by_procs: 48 | global_param /= self.comm.Get_size() 49 | 50 | self.t += 1 51 | self.setfromflat(np.clip(global_param, 0., None)) 52 | 53 | def sync(self): 54 | weight = self.getflat() 55 | self.comm.Bcast(weight, root=0) 56 | self.setfromflat(weight) 57 | 58 | def check_synced(self): 59 | if self.comm.Get_rank() == 0: # this is root 60 | weight = self.getflat() 61 | self.comm.Bcast(weight, root=0) 62 | else: 63 | weightlocal = self.getflat() 64 | weightroot = np.empty_like(weightlocal) 65 | self.comm.Bcast(weightroot, root=0) 66 | assert (weightroot == weightlocal).all(), (weightroot, weightlocal) 67 | -------------------------------------------------------------------------------- /baselines/common/mpi_moments.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import numpy as np 3 | from baselines.common import zipsame 4 | 5 | def mpi_mean(x, axis=0, comm=None, keepdims=False): 6 | x = np.asarray(x) 7 | assert x.ndim > 0 8 | if comm is None: comm = MPI.COMM_WORLD 9 | xsum = x.sum(axis=axis, keepdims=keepdims) 10 | n = xsum.size 11 | localsum = np.zeros(n+1, x.dtype) 12 | localsum[:n] = xsum.ravel() 13 | localsum[n] = x.shape[axis] 14 | globalsum = np.zeros_like(localsum) 15 | comm.Allreduce(localsum, globalsum, op=MPI.SUM) 16 | return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] 17 | 18 | def mpi_moments(x, axis=0, comm=None, keepdims=False): 19 | x = np.asarray(x) 20 | assert x.ndim > 0 21 | mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True) 22 | sqdiffs = np.square(x - mean) 23 | meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True) 24 | assert count1 == count 25 | std = np.sqrt(meansqdiff) 26 | if not keepdims: 27 | newshape = mean.shape[:axis] + mean.shape[axis+1:] 28 | mean = mean.reshape(newshape) 29 | std = std.reshape(newshape) 30 | return mean, std, count 31 | 32 | 33 | def test_runningmeanstd(): 34 | 
import subprocess 35 | subprocess.check_call(['mpirun', '-np', '3', 36 | 'python','-c', 37 | 'from baselines.common.mpi_moments import _helper_runningmeanstd; _helper_runningmeanstd()']) 38 | 39 | def _helper_runningmeanstd(): 40 | comm = MPI.COMM_WORLD 41 | np.random.seed(0) 42 | for (triple,axis) in [ 43 | ((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0), 44 | ((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0), 45 | ((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1), 46 | ]: 47 | 48 | 49 | x = np.concatenate(triple, axis=axis) 50 | ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]] 51 | 52 | 53 | ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis) 54 | 55 | for (a1,a2) in zipsame(ms1, ms2): 56 | print(a1, a2) 57 | assert np.allclose(a1, a2) 58 | print("ok!") 59 | 60 | -------------------------------------------------------------------------------- /baselines/common/mpi_running_mean_std.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import tensorflow as tf, baselines.common.tf_util as U, numpy as np 3 | 4 | class RunningMeanStd(object): 5 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 6 | def __init__(self, epsilon=1e-2, shape=()): 7 | 8 | self._sum = tf.get_variable( 9 | dtype=tf.float64, 10 | shape=shape, 11 | initializer=tf.constant_initializer(0.0), 12 | name="runningsum", trainable=False) 13 | self._sumsq = tf.get_variable( 14 | dtype=tf.float64, 15 | shape=shape, 16 | initializer=tf.constant_initializer(epsilon), 17 | name="runningsumsq", trainable=False) 18 | self._count = tf.get_variable( 19 | dtype=tf.float64, 20 | shape=(), 21 | initializer=tf.constant_initializer(epsilon), 22 | name="count", trainable=False) 23 | self.shape = shape 24 | 25 | self.mean = tf.to_float(self._sum / self._count) 26 | self.std = tf.sqrt( tf.maximum( tf.to_float(self._sumsq / self._count) - tf.square(self.mean) , 1e-2 )) 27 | 28 | newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum') 29 | newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var') 30 | newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count') 31 | self.incfiltparams = U.function([newsum, newsumsq, newcount], [], 32 | updates=[tf.assign_add(self._sum, newsum), 33 | tf.assign_add(self._sumsq, newsumsq), 34 | tf.assign_add(self._count, newcount)]) 35 | 36 | 37 | def update(self, x): 38 | x = x.astype('float64') 39 | n = int(np.prod(self.shape)) 40 | totalvec = np.zeros(n*2+1, 'float64') 41 | addvec = np.concatenate([x.sum(axis=0).ravel(), np.square(x).sum(axis=0).ravel(), np.array([len(x)],dtype='float64')]) 42 | MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM) 43 | self.incfiltparams(totalvec[0:n].reshape(self.shape), totalvec[n:2*n].reshape(self.shape), totalvec[2*n]) 44 | 45 | @U.in_session 46 | def test_runningmeanstd(): 47 | for (x1, x2, x3) in [ 48 | (np.random.randn(3), np.random.randn(4), np.random.randn(5)), 49 | (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)), 50 | ]: 51 | 52 | rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:]) 53 | U.initialize() 54 | 55 | x = np.concatenate([x1, x2, x3], axis=0) 56 | ms1 = [x.mean(axis=0), x.std(axis=0)] 57 | rms.update(x1) 58 | rms.update(x2) 59 | rms.update(x3) 60 | ms2 = [rms.mean.eval(), rms.std.eval()] 61 | #print(ms1,ms2) 62 | assert np.allclose(ms1, ms2) 63 | 64 | @U.in_session 65 | def test_dist(): 66 | np.random.seed(0) 67 | 
p1,p2,p3=(np.random.randn(3,1), np.random.randn(4,1), np.random.randn(5,1)) 68 | q1,q2,q3=(np.random.randn(6,1), np.random.randn(7,1), np.random.randn(8,1)) 69 | 70 | # p1,p2,p3=(np.random.randn(3), np.random.randn(4), np.random.randn(5)) 71 | # q1,q2,q3=(np.random.randn(6), np.random.randn(7), np.random.randn(8)) 72 | 73 | comm = MPI.COMM_WORLD 74 | assert comm.Get_size()==2 75 | if comm.Get_rank()==0: 76 | x1,x2,x3 = p1,p2,p3 77 | elif comm.Get_rank()==1: 78 | x1,x2,x3 = q1,q2,q3 79 | else: 80 | assert False 81 | 82 | rms = RunningMeanStd(epsilon=0.0, shape=(1,)) 83 | U.initialize() 84 | 85 | rms.update(x1) 86 | rms.update(x2) 87 | rms.update(x3) 88 | 89 | bigvec = np.concatenate([p1,p2,p3,q1,q2,q3]) 90 | 91 | def checkallclose(x,y): 92 | print(x,y) 93 | return np.allclose(x,y) 94 | 95 | assert checkallclose( 96 | bigvec.mean(axis=0), 97 | rms.mean.eval(), 98 | ) 99 | assert checkallclose( 100 | bigvec.std(axis=0), 101 | rms.std.eval(), 102 | ) 103 | 104 | 105 | if __name__ == "__main__": 106 | # Run with mpirun -np 2 python 107 | test_dist() 108 | #test_runningmeanstd() 109 | -------------------------------------------------------------------------------- /baselines/common/mpi_update.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import baselines.common.tf_util as U 3 | import numpy as np 4 | 5 | 6 | class MpiUpdate(object): 7 | def __init__(self, var_list, *, scale_grad_by_procs=True, comm=None): 8 | self.var_list = var_list 9 | self.scale_grad_by_procs = scale_grad_by_procs 10 | self.t = 0 11 | self.setfromflat = U.SetFromFlat(var_list) 12 | self.getflat = U.GetFlat(var_list) 13 | self.comm = MPI.COMM_WORLD if comm is None else comm 14 | 15 | def update(self, localg): 16 | if self.t % 100 == 0: 17 | self.check_synced() 18 | localg = localg.astype('float32') 19 | globalg = np.zeros_like(localg) 20 | self.comm.Allreduce(localg, globalg, op=MPI.SUM) 21 | if self.scale_grad_by_procs: 22 | globalg /= self.comm.Get_size() 23 | 24 | self.t += 1 25 | self.setfromflat(np.clip(self.getflat() + globalg, 0., None)) 26 | 27 | def set(self, local_param): 28 | if self.t % 100 == 0: 29 | self.check_synced() 30 | local_param = local_param.astype('float32') 31 | global_param = np.zeros_like(local_param) 32 | self.comm.Allreduce(local_param, global_param, op=MPI.SUM) 33 | if self.scale_grad_by_procs: 34 | global_param /= self.comm.Get_size() 35 | 36 | self.t += 1 37 | self.setfromflat(np.clip(global_param, 0., None)) 38 | 39 | def sync(self): 40 | weight = self.getflat() 41 | self.comm.Bcast(weight, root=0) 42 | self.setfromflat(weight) 43 | 44 | def check_synced(self): 45 | if self.comm.Get_rank() == 0: # this is root 46 | weight = self.getflat() 47 | self.comm.Bcast(weight, root=0) 48 | else: 49 | weightlocal = self.getflat() 50 | weightroot = np.empty_like(weightlocal) 51 | self.comm.Bcast(weightroot, root=0) 52 | assert (weightroot == weightlocal).all(), (weightroot, weightlocal) 53 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: olaux 2 | channels: 3 | - https://conda.anaconda.org/kne 4 | - https://conda.anaconda.org/tlatorre 5 | - https://conda.anaconda.org/cjs14 6 | - https://conda.anaconda.org/menpo 7 | - jjhelmus 8 | dependencies: 9 | - python==3.5.4 10 | - numpy<1.17 11 | - scipy 12 | - flask 13 | - path.py 14 | - matplotlib 15 | - scikit-learn 16 | - plotly 17 | - joblib 18 | - 
dm_control 19 | - cudnn==7.6.5 20 | - pip: 21 | - python-dateutil 22 | - opencv-python 23 | - mpi4py==3.0.0 24 | - Pillow 25 | - boto3 26 | - PyOpenGL 27 | - mujoco-py<1.50.2,>=1.50.1 28 | - cached_property 29 | - Cython 30 | - git+https://github.com/plotly/plotly.py.git@2594076e29584ede2d09f2aa40a8a195b3f3fc66#egg=plotly 31 | - gym==0.10.9 32 | - git+https://github.com/neocxi/prettytensor.git 33 | - chainer==1.18.0 34 | - tensorflow==1.14.0 35 | - tensorflow-gpu==1.14.0 36 | - dominate 37 | - scikit-image 38 | -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/__init__.py -------------------------------------------------------------------------------- /envs/goal_env_ext/__init__.py: -------------------------------------------------------------------------------- 1 | # Created by Xingyu Lin, 04/09/2018 2 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/LICENSE.md: -------------------------------------------------------------------------------- 1 | This work contains code used under the following license: 2 | 3 | # ShadowHand 4 | The model of the [ShadowHand](https://www.shadowrobot.com/products/dexterous-hand/) is based on [models 5 | provided by Shadow](https://github.com/shadow-robot/sr_common/tree/kinetic-devel/sr_description/hand/model). 6 | It was adapted and refined by Vikash Kumar and OpenAI. 7 | 8 | (C) Vikash Kumar, CSE, UW. Licensed under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 9 | 10 | # Fetch Robotics 11 | The model of the [Fetch](http://fetchrobotics.com/platforms-research-development/) is based on [models provided by Fetch](https://github.com/fetchrobotics/fetch_ros/tree/indigo-devel/fetch_description). 12 | It was adapted and refined by OpenAI. 
13 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/dm_control/manipulator.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | > 15 | 16 | 212 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/fetch/pick_and_place.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/fetch/push.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/fetch/reach.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/fetch/robot.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/fetch/shared.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/fetch/slide.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/hand/manipulate_block.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 
-------------------------------------------------------------------------------- /envs/goal_env_ext/assets/hand/manipulate_egg.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/hand/manipulate_pen.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/hand/reach.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/hand/shared_asset.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/reacher/reacher.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 42 | 43 | -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/.get: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/.get -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/base_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/base_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/bellows_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/bellows_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/elbow_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/elbow_flex_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/estop_link.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/estop_link.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/forearm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/forearm_roll_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/gripper_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/gripper_link.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/head_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/head_pan_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/head_tilt_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/head_tilt_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/l_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/l_wheel_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/laser_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/laser_link.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/r_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/r_wheel_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/shoulder_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/shoulder_lift_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/shoulder_pan_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/shoulder_pan_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/torso_fixed_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/torso_fixed_link.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/torso_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/torso_lift_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/upperarm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/upperarm_roll_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/wrist_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/wrist_flex_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/fetch/wrist_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/fetch/wrist_roll_link_collision.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/F1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/F1.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/F2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/F2.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/F3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/F3.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/TH1_z.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/TH1_z.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/TH2_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/TH2_z.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/TH3_z.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/TH3_z.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/forearm_electric.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/forearm_electric.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/forearm_electric_cvx.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/forearm_electric_cvx.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/knuckle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/knuckle.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/lfmetacarpal.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/lfmetacarpal.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/palm.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/palm.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/stls/hand/wrist.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/stls/hand/wrist.stl -------------------------------------------------------------------------------- /envs/goal_env_ext/assets/textures/block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/textures/block.png -------------------------------------------------------------------------------- 
/envs/goal_env_ext/assets/textures/block_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/envs/goal_env_ext/assets/textures/block_hidden.png -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/__init__.py: -------------------------------------------------------------------------------- 1 | from envs.goal_env_ext.dm_control.finger import Finger 2 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Functions to manage the common assets for domains.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | from dm_control.utils import io as resources 24 | 25 | _SUITE_DIR = os.path.dirname(os.path.dirname(__file__)) 26 | _FILENAMES = [ 27 | "./common/materials.xml", 28 | "./common/skybox.xml", 29 | "./common/visual.xml", 30 | ] 31 | 32 | ASSETS = {filename: resources.GetResource(os.path.join(_SUITE_DIR, filename)) 33 | for filename in _FILENAMES} 34 | 35 | 36 | def read_model(model_filename): 37 | """Reads a model XML file and returns its contents as a string.""" 38 | return resources.GetResource(os.path.join(_SUITE_DIR, model_filename)) 39 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/common/materials.xml: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/common/skybox.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/common/visual.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/dm_control_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | from gym.utils import seeding 4 | import os 5 | import numpy as np 6 | from numpy.random import random 7 | import time 8 | from os import path 9 | import cv2 as cv 10 | from dm_control import mujoco 11 | from dm_control.suite import base 12 | from 
dm_control import mujoco, viewer 13 | from dm_control.rl import control 14 | from dm_control.suite import base 15 | from dm_control.rl.control import PhysicsError 16 | from dm_control.mujoco import Physics 17 | from dm_control.utils import io as resources 18 | from dm_control.suite.utils import randomizers 19 | from dm_control.utils import xml_tools 20 | from envs.goal_env_ext.goal_env_ext import GoalEnvExt 21 | from termcolor import colored 22 | from lxml import etree 23 | from dm_control.utils import io as resources 24 | from envs.goal_env_ext.dm_control import common 25 | 26 | 27 | class DmControlEnv(GoalEnvExt): 28 | def __init__(self, model_path, n_substeps, initial_qpos, 29 | distance_threshold, 30 | distance_threshold_obs, 31 | horizon, 32 | image_size, 33 | n_actions, 34 | reward_type='sparse', 35 | use_image_goal=False, 36 | with_goal=True, 37 | use_visual_observation=True, 38 | default_camera_name='fixed', 39 | fully_observable=True, 40 | **kwargs): 41 | 42 | """ 43 | :param model_path: 44 | :param distance_threshold: 45 | :param frame_skip: 46 | :param horizon: 47 | :param goal_range: 48 | :param image_size: 49 | """ 50 | if model_path.startswith("/"): 51 | fullpath = model_path 52 | else: 53 | fullpath = os.path.join(os.path.dirname(__file__), model_path) 54 | if not path.exists(fullpath): 55 | raise IOError("File %s does not exist" % fullpath) 56 | 57 | self.model_path = fullpath 58 | model_string, assets = self.get_model_and_assets() 59 | self.physics = Physics.from_xml_string(model_string, 60 | assets=assets) # export MUJOCO_GL=osmesa to resolve mujoco context error 61 | # self.physics = Physics.from_xml_string(*self.get_model_and_assets1(fullpath)) 62 | self._init_configure() 63 | self.time_step = 0 64 | self.random = np.random.RandomState(None) 65 | self._fully_observable = fully_observable 66 | super(DmControlEnv, self).__init__(model_path=self.model_path, n_substeps=n_substeps, n_actions=n_actions, 67 | initial_qpos=initial_qpos, use_image_goal=use_image_goal, 68 | use_visual_observation=use_visual_observation, reward_type=reward_type, 69 | distance_threshold=distance_threshold, 70 | distance_threshold_obs=distance_threshold_obs, horizon=horizon, 71 | image_size=image_size, with_goal=with_goal, dm_control_env=True, 72 | default_camera_name=default_camera_name, **kwargs) 73 | 74 | def _set_action(self, action): 75 | try: 76 | self.physics.set_control(action.continuous_actions) 77 | except AttributeError: 78 | self.physics.set_control(action) 79 | 80 | try: 81 | for _ in range(self.n_substeps): 82 | self.physics.step() 83 | except PhysicsError as ex: 84 | print(colored(ex, 'red')) 85 | 86 | # def step(self, action): 87 | # action = action.flatten() 88 | # # action = np.clip(action, self.action_space.low, self.action_space.high) 89 | # 90 | # try: 91 | # self._set_action(action) 92 | # except NotImplementedError: 93 | # try: 94 | # self.physics.set_control(action.continuous_actions) 95 | # except AttributeError: 96 | # self.physics.set_control(action) 97 | # 98 | # try: 99 | # for _ in range(self.n_substeps): 100 | # self.physics.step() 101 | # except PhysicsError as ex: 102 | # print(colored(ex, 'red')) 103 | # 104 | # obs = self._get_obs() 105 | # 106 | # if self.use_auxiliary_info: 107 | # NotImplementedError 108 | # else: 109 | # aug_info = {} 110 | # 111 | # state_info = self.get_current_info() 112 | # info = {**aug_info, **state_info} 113 | # reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], info) 114 | # self.time_step += 1 115 | # done = False 
116 | # 117 | # if self.time_step >= self.horizon: 118 | # done = True 119 | # return obs, reward, done, info 120 | 121 | def _is_success(self, achieved_goal, desired_goal): 122 | 123 | achieved_goal = achieved_goal.reshape([-1, self.goal_state_dim]) 124 | desired_goal = desired_goal.reshape([-1, self.goal_state_dim]) 125 | d = np.linalg.norm(achieved_goal - desired_goal, axis=-1) 126 | return (d < self.distance_threshold).astype(np.float32) 127 | 128 | def get_model_and_assets(self): 129 | """Returns a tuple containing the model XML string and a dict of assets. 130 | 131 | Args: 132 | n_joints: An integer specifying the number of joints in the swimmer. 133 | 134 | Returns: 135 | A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of 136 | `{filename: contents_string}` pairs. 137 | """ 138 | 139 | return self._make_model(), common.ASSETS 140 | 141 | def read_model(self, model_filename): 142 | """Reads a model XML file and returns its contents as a string.""" 143 | return resources.GetResource(model_filename) 144 | 145 | def _init_configure(self): 146 | pass 147 | 148 | def render(self, mode='rgb_array', image_size=None, camera_name=None, depth=True): 149 | if image_size is None: 150 | image_size = self.image_size 151 | 152 | if camera_name is None: 153 | camera_name = self.default_camera_name 154 | 155 | if not depth: 156 | obs = self.physics.render(height=image_size, width=image_size, camera_id=camera_name, depth=depth) 157 | else: 158 | img = self.physics.render(height=image_size, width=image_size, camera_id=camera_name, depth=False) 159 | depth_img = self.physics.render(height=image_size, width=image_size, camera_id=camera_name, depth=True) 160 | obs = np.dstack((img, depth_img)) 161 | 162 | return obs 163 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/finger.py: -------------------------------------------------------------------------------- 1 | from gym import utils 2 | import numpy as np 3 | from dm_control.rl.control import PhysicsError 4 | from dm_control.suite.utils import randomizers 5 | from termcolor import colored 6 | from dm_control.utils import containers 7 | from envs.goal_env_ext.dm_control.dm_control_env import DmControlEnv 8 | from envs.goal_env_ext.dm_control import common 9 | 10 | _DEFAULT_TIME_LIMIT = 20 # (seconds) 11 | _CONTROL_TIMESTEP = .02 # (seconds) 12 | # For TURN tasks, the 'tip' geom needs to enter a spherical target of sizes: 13 | _EASY_TARGET_SIZE = 0.00 14 | _HARD_TARGET_SIZE = 0.03 15 | # Initial spin velocity for the Stop task. 16 | _INITIAL_SPIN_VELOCITY = 100 17 | # Spinning slower than this value (radian/second) is considered stopped. 18 | _STOP_VELOCITY = 1e-6 19 | # Spinning faster than this value (radian/second) is considered spinning. 
20 | _SPIN_VELOCITY = 15.0 21 | 22 | SUITE = containers.TaggedTasks() 23 | 24 | from collections import deque 25 | 26 | 27 | class Finger(DmControlEnv, utils.EzPickle): 28 | def __init__(self, model_path='finger.xml', n_substeps=2, n_actions=2, goal_range=[-2.0, 2.0], stack_obs=False, **kwargs): 29 | # horizon = 1000, goal_range = [-2.0, 2.0], image_size = 460, init_position = 'goal_range', use_auxiliary_loss = False, use_visual_observation = True, 30 | # noisy_reward_fp = False, noisy_reward_fn = False, use_true_reward = True,distance_threshold_obs = 0.0, **kwargs): 31 | 32 | """ 33 | :param model_path: 34 | :param distance_threshold: 35 | :param frame_skip: 36 | :param horizon: 37 | :param goal_range: 38 | :param action_type: Should be in ['force', 'velocity', 'position'] 39 | :param image_size: 40 | """ 41 | self.n_substeps = n_substeps 42 | self._target_radius = _EASY_TARGET_SIZE 43 | self.stack_obs = stack_obs 44 | self.stack_buffer = deque(maxlen=3) 45 | DmControlEnv.__init__( 46 | self, model_path, n_substeps=n_substeps, n_actions=n_actions, initial_qpos=None, default_camera_name='cam0', 47 | **kwargs) 48 | 49 | utils.EzPickle.__init__(self) 50 | 51 | self.goal_range = goal_range 52 | 53 | def get_model_and_assets(self): 54 | """Returns a tuple containing the model XML string and a dict of assets. 55 | 56 | Args: 57 | n_joints: An integer specifying the number of joints in the swimmer. 58 | 59 | Returns: 60 | A tuple `(model_xml_string, assets)`, where `assets` is a dict consisting of 61 | `{filename: contents_string}` pairs. 62 | """ 63 | 64 | return common.read_model('finger.xml'), common.ASSETS 65 | 66 | def touch(self): 67 | """Returns logarithmically scaled signals from the two touch sensors.""" 68 | return np.log1p(self.physics.named.data.sensordata[['touchtop', 'touchbottom']]) 69 | 70 | def hinge_velocity(self): 71 | """Returns the velocity of the hinge joint.""" 72 | return self.physics.named.data.sensordata['hinge_velocity'] 73 | 74 | def tip_position(self): 75 | """Returns the (x,z) position of the tip relative to the hinge.""" 76 | return (self.physics.named.data.sensordata['tip'][[0, 2]] - 77 | self.physics.named.data.sensordata['spinner'][[0, 2]]) 78 | 79 | def bounded_position(self): 80 | """Returns the positions, with the hinge angle replaced by tip position.""" 81 | return np.hstack((self.physics.named.data.sensordata[['proximal', 'distal']], 82 | self.tip_position())) 83 | 84 | def velocity(self): 85 | """Returns the velocities (extracted from sensordata).""" 86 | return self.physics.named.data.sensordata[['proximal_velocity', 'distal_velocity', 'hinge_velocity']] 87 | 88 | def target_position(self): 89 | """Returns the (x,z) position of the target relative to the hinge.""" 90 | return (self.physics.named.data.sensordata['target'][[0, 2]] - 91 | self.physics.named.data.sensordata['spinner'][[0, 2]]) 92 | 93 | def to_target(self): 94 | """Returns the vector from the tip to the target.""" 95 | return self.target_position() - self.tip_position() 96 | 97 | def dist_to_target(self): 98 | """Returns the signed distance to the target surface, negative is inside.""" 99 | return (np.linalg.norm(self.to_target()) - 100 | self.physics.named.model.site_size['target', 0]) 101 | 102 | def _reset_sim(self): 103 | self.time_step = 0 104 | with self.physics.reset_context(): 105 | target_angle = self.random.uniform(-np.pi, np.pi) 106 | self._set_to_goal(target_angle) 107 | 108 | hinge_x, hinge_z = self.physics.named.data.xanchor['hinge', ['x', 'z']] 109 | radius = 
self.physics.named.model.geom_size['cap1'].sum() 110 | target_x = hinge_x + radius * np.sin(-target_angle) 111 | target_z = hinge_z + radius * np.cos(-target_angle) 112 | self.physics.named.model.site_pos['target', ['x', 'z']] = target_x, target_z 113 | self.physics.named.model.site_size['target', 0] = self._target_radius 114 | # self.physics.named.model.site_size['tip', 0] = 0.02 115 | self.physics.named.model.site_size['tip', 0] = 0.05 116 | self.goal_image = self.render() 117 | with self.physics.reset_context(): 118 | self._set_random_joint_angles(self.physics, self.random) 119 | self.goal_state = self.target_position().copy() 120 | self.stack_buffer.clear() 121 | return True 122 | 123 | def _set_to_goal(self, target_angle): 124 | with self.physics.reset_context(): 125 | self.physics.named.data.qpos['hinge'] = target_angle 126 | 127 | def _set_random_joint_angles(self, physics, random, max_attempts=1000): 128 | """Sets the joints to a random collision-free state.""" 129 | for _ in range(max_attempts): 130 | randomizers.randomize_limited_and_rotational_joints(physics, random) 131 | # Check for collisions. 132 | physics.after_reset() 133 | if physics.data.ncon == 0: 134 | break 135 | else: 136 | raise RuntimeError('Could not find a collision-free state ' 137 | 'after {} attempts'.format(max_attempts)) 138 | 139 | # def _get_obs_stack(self, image_size=None, camera_name=None, depth=True): 140 | # try: 141 | # for _ in range(int(self.n_substeps / 3) + 1): 142 | # self.physics.step() 143 | # except PhysicsError as ex: 144 | # print(colored(ex, 'red')) 145 | # 146 | # ob1 = self.render().copy() 147 | # 148 | # try: 149 | # for _ in range(int(self.n_substeps / 3) + 1): 150 | # self.physics.step() 151 | # except PhysicsError as ex: 152 | # print(colored(ex, 'red')) 153 | # 154 | # ob2 = self.render().copy() 155 | # 156 | # try: 157 | # for _ in range(int(self.n_substeps / 3) + 1): 158 | # self.physics.step() 159 | # except PhysicsError as ex: 160 | # print(colored(ex, 'red')) 161 | # 162 | # ob3 = self.render().copy() 163 | # obs = np.dstack((ob1, ob2, ob3)) 164 | # return { 165 | # 'observation': obs.copy(), 166 | # 'achieved_goal': self.tip_position().copy(), 167 | # 'desired_goal': self.target_position().copy() 168 | # } 169 | 170 | def _get_obs(self): 171 | # traceback.print_stack() 172 | if self.use_visual_observation: 173 | if self.stack_obs: 174 | steps = [x[1] for x in self.stack_buffer] 175 | if len(self.stack_buffer) > 0 and self.time_step == self.stack_buffer[-1][1]: 176 | obs = np.concatenate([ob[0] for ob in self.stack_buffer], axis=-1) 177 | else: 178 | try: 179 | for _ in range(int(self.n_substeps) + 1): 180 | self.physics.step() 181 | except PhysicsError as ex: 182 | print(colored(ex, 'red')) 183 | new_obs = self.render() 184 | self.stack_buffer.append((new_obs.copy(), self.time_step)) 185 | # Fill the buffer 186 | while len(self.stack_buffer) < self.stack_buffer.maxlen: 187 | self.stack_buffer.append((new_obs.copy(), self.time_step)) 188 | obs = np.concatenate([ob[0] for ob in self.stack_buffer], axis=-1) 189 | else: 190 | try: 191 | for _ in range(int(self.n_substeps) + 1): 192 | self.physics.step() 193 | except PhysicsError as ex: 194 | print(colored(ex, 'red')) 195 | 196 | obs = self.render().copy() 197 | else: 198 | assert False 199 | obs = np.concatenate((self.bounded_position().flatten().copy(), self.velocity().flatten().copy(), 200 | self.touch().flatten().copy(), 201 | self.target_position().flatten().copy(), self.dist_to_target().flatten().copy())) 202 | 203 | if 
self.use_visual_observation and self.stack_buffer: 204 | achieved_goal = np.tile(obs[:, :, -4:], [1, 1, 3]) 205 | desired_goal = np.tile(self.goal_image, [1, 1, 3]) 206 | else: 207 | achieved_goal = obs.copy() 208 | desired_goal = self.goal_image.copy() 209 | return { 210 | 'observation': obs.copy(), 211 | 'achieved_goal': achieved_goal.copy(), 212 | 'desired_goal': desired_goal.copy() 213 | } 214 | 215 | def get_current_info(self): 216 | 217 | info = { 218 | 'is_success': self._is_success(self.tip_position().copy(), self.target_position().copy()), 219 | 'ag_state': self.tip_position().copy(), 220 | 'g_state': self.target_position().copy() 221 | } 222 | 223 | return info 224 | 225 | def _is_success(self, achieved_goal, desired_goal): 226 | achieved_goal = achieved_goal.reshape([-1, self.goal_state_dim]) 227 | desired_goal = desired_goal.reshape([-1, self.goal_state_dim]) 228 | d = np.linalg.norm(achieved_goal - desired_goal, axis=-1) - self.physics.named.model.site_size['target', 0] 229 | return (d <= self.distance_threshold).astype(np.float32) 230 | 231 | def _get_info_state(self, achieved_goal, desired_goal): 232 | # Given g, ag in state space and return the distance and success 233 | achieved_goal = achieved_goal.reshape([-1, self.goal_state_dim]) 234 | desired_goal = desired_goal.reshape([-1, self.goal_state_dim]) 235 | d = np.linalg.norm(achieved_goal - desired_goal, axis=-1) - self.physics.named.model.site_size['target', 0] 236 | return d, (d < self.distance_threshold).astype(np.float32) 237 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/finger.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Utility functions used in the control suite.""" 17 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/utils/parse_amc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Parse and convert amc motion capture data.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control.mujoco.wrapper import mjbindings 25 | import numpy as np 26 | from scipy import interpolate 27 | from six.moves import range 28 | 29 | mjlib = mjbindings.mjlib 30 | 31 | MOCAP_DT = 1.0/120.0 32 | CONVERSION_LENGTH = 0.056444 33 | 34 | _CMU_MOCAP_JOINT_ORDER = ( 35 | 'root0', 'root1', 'root2', 'root3', 'root4', 'root5', 'lowerbackrx', 36 | 'lowerbackry', 'lowerbackrz', 'upperbackrx', 'upperbackry', 'upperbackrz', 37 | 'thoraxrx', 'thoraxry', 'thoraxrz', 'lowerneckrx', 'lowerneckry', 38 | 'lowerneckrz', 'upperneckrx', 'upperneckry', 'upperneckrz', 'headrx', 39 | 'headry', 'headrz', 'rclaviclery', 'rclaviclerz', 'rhumerusrx', 40 | 'rhumerusry', 'rhumerusrz', 'rradiusrx', 'rwristry', 'rhandrx', 'rhandrz', 41 | 'rfingersrx', 'rthumbrx', 'rthumbrz', 'lclaviclery', 'lclaviclerz', 42 | 'lhumerusrx', 'lhumerusry', 'lhumerusrz', 'lradiusrx', 'lwristry', 43 | 'lhandrx', 'lhandrz', 'lfingersrx', 'lthumbrx', 'lthumbrz', 'rfemurrx', 44 | 'rfemurry', 'rfemurrz', 'rtibiarx', 'rfootrx', 'rfootrz', 'rtoesrx', 45 | 'lfemurrx', 'lfemurry', 'lfemurrz', 'ltibiarx', 'lfootrx', 'lfootrz', 46 | 'ltoesrx' 47 | ) 48 | 49 | Converted = collections.namedtuple('Converted', 50 | ['qpos', 'qvel', 'time']) 51 | 52 | 53 | def convert(file_name, physics, timestep): 54 | """Converts the parsed .amc values into qpos and qvel values and resamples. 55 | 56 | Args: 57 | file_name: The .amc file to be parsed and converted. 58 | physics: The corresponding physics instance. 59 | timestep: Desired output interval between resampled frames. 60 | 61 | Returns: 62 | A namedtuple with fields: 63 | `qpos`, a numpy array containing converted positional variables. 64 | `qvel`, a numpy array containing converted velocity variables. 65 | `time`, a numpy array containing the corresponding times. 66 | """ 67 | frame_values = parse(file_name) 68 | joint2index = {} 69 | for name in physics.named.data.qpos.axes.row.names: 70 | joint2index[name] = physics.named.data.qpos.axes.row.convert_key_item(name) 71 | index2joint = {} 72 | for joint, index in joint2index.items(): 73 | if isinstance(index, slice): 74 | indices = range(index.start, index.stop) 75 | else: 76 | indices = [index] 77 | for ii in indices: 78 | index2joint[ii] = joint 79 | 80 | # Convert frame_values to qpos 81 | amcvals2qpos_transformer = Amcvals2qpos(index2joint, _CMU_MOCAP_JOINT_ORDER) 82 | qpos_values = [] 83 | for frame_value in frame_values: 84 | qpos_values.append(amcvals2qpos_transformer(frame_value)) 85 | qpos_values = np.stack(qpos_values) # Time by nq 86 | 87 | # Interpolate/resample. 88 | # Note: interpolate quaternions rather than euler angles (slerp). 
89 | # see https://en.wikipedia.org/wiki/Slerp 90 | qpos_values_resampled = [] 91 | time_vals = np.arange(0, len(frame_values)*MOCAP_DT - 1e-8, MOCAP_DT) 92 | time_vals_new = np.arange(0, len(frame_values)*MOCAP_DT, timestep) 93 | while time_vals_new[-1] > time_vals[-1]: 94 | time_vals_new = time_vals_new[:-1] 95 | 96 | for i in range(qpos_values.shape[1]): 97 | f = interpolate.splrep(time_vals, qpos_values[:, i]) 98 | qpos_values_resampled.append(interpolate.splev(time_vals_new, f)) 99 | 100 | qpos_values_resampled = np.stack(qpos_values_resampled) # nq by ntime 101 | 102 | qvel_list = [] 103 | for t in range(qpos_values_resampled.shape[1]-1): 104 | p_tp1 = qpos_values_resampled[:, t + 1] 105 | p_t = qpos_values_resampled[:, t] 106 | qvel = [(p_tp1[:3]-p_t[:3])/ timestep, 107 | mj_quat2vel(mj_quatdiff(p_t[3:7], p_tp1[3:7]), timestep), 108 | (p_tp1[7:]-p_t[7:])/ timestep] 109 | qvel_list.append(np.concatenate(qvel)) 110 | 111 | qvel_values_resampled = np.vstack(qvel_list).T 112 | 113 | return Converted(qpos_values_resampled, qvel_values_resampled, time_vals_new) 114 | 115 | 116 | def parse(file_name): 117 | """Parses the amc file format.""" 118 | values = [] 119 | fid = open(file_name, 'r') 120 | line = fid.readline().strip() 121 | frame_ind = 1 122 | first_frame = True 123 | while True: 124 | # Parse first frame. 125 | if first_frame and line[0] == str(frame_ind): 126 | first_frame = False 127 | frame_ind += 1 128 | frame_vals = [] 129 | while True: 130 | line = fid.readline().strip() 131 | if not line or line == str(frame_ind): 132 | values.append(np.array(frame_vals, dtype=np.float)) 133 | break 134 | tokens = line.split() 135 | frame_vals.extend(tokens[1:]) 136 | # Parse other frames. 137 | elif line == str(frame_ind): 138 | frame_ind += 1 139 | frame_vals = [] 140 | while True: 141 | line = fid.readline().strip() 142 | if not line or line == str(frame_ind): 143 | values.append(np.array(frame_vals, dtype=np.float)) 144 | break 145 | tokens = line.split() 146 | frame_vals.extend(tokens[1:]) 147 | else: 148 | line = fid.readline().strip() 149 | if not line: 150 | break 151 | return values 152 | 153 | 154 | class Amcvals2qpos(object): 155 | """Callable that converts .amc values for a frame and to MuJoCo qpos format. 156 | """ 157 | 158 | def __init__(self, index2joint, joint_order): 159 | """Initializes a new Amcvals2qpos instance. 160 | 161 | Args: 162 | index2joint: List of joint angles in .amc file. 163 | joint_order: List of joint names in MuJoco MJCF. 164 | """ 165 | # Root is x,y,z, then quat. 166 | # need to get indices of qpos that order for amc default order 167 | self.qpos_root_xyz_ind = [0, 1, 2] 168 | self.root_xyz_ransform = np.array( 169 | [[1, 0, 0], [0, 0, -1], [0, 1, 0]]) * CONVERSION_LENGTH 170 | self.qpos_root_quat_ind = [3, 4, 5, 6] 171 | amc2qpos_transform = np.zeros((len(index2joint), len(joint_order))) 172 | for i in range(len(index2joint)): 173 | for j in range(len(joint_order)): 174 | if index2joint[i] == joint_order[j]: 175 | if 'rx' in index2joint[i]: 176 | amc2qpos_transform[i][j] = 1 177 | elif 'ry' in index2joint[i]: 178 | amc2qpos_transform[i][j] = 1 179 | elif 'rz' in index2joint[i]: 180 | amc2qpos_transform[i][j] = 1 181 | self.amc2qpos_transform = amc2qpos_transform 182 | 183 | def __call__(self, amc_val): 184 | """Converts a `.amc` frame to MuJoCo qpos format.""" 185 | amc_val_rad = np.deg2rad(amc_val) 186 | qpos = np.dot(self.amc2qpos_transform, amc_val_rad) 187 | 188 | # Root. 
189 | qpos[:3] = np.dot(self.root_xyz_ransform, amc_val[:3]) 190 | qpos_quat = euler2quat(amc_val[3], amc_val[4], amc_val[5]) 191 | qpos_quat = mj_quatprod(euler2quat(90, 0, 0), qpos_quat) 192 | 193 | for i, ind in enumerate(self.qpos_root_quat_ind): 194 | qpos[ind] = qpos_quat[i] 195 | 196 | return qpos 197 | 198 | 199 | def euler2quat(ax, ay, az): 200 | """Converts euler angles to a quaternion. 201 | 202 | Note: rotation order is zyx 203 | 204 | Args: 205 | ax: Roll angle (deg) 206 | ay: Pitch angle (deg). 207 | az: Yaw angle (deg). 208 | 209 | Returns: 210 | A numpy array representing the rotation as a quaternion. 211 | """ 212 | r1 = az 213 | r2 = ay 214 | r3 = ax 215 | 216 | c1 = np.cos(np.deg2rad(r1 / 2)) 217 | s1 = np.sin(np.deg2rad(r1 / 2)) 218 | c2 = np.cos(np.deg2rad(r2 / 2)) 219 | s2 = np.sin(np.deg2rad(r2 / 2)) 220 | c3 = np.cos(np.deg2rad(r3 / 2)) 221 | s3 = np.sin(np.deg2rad(r3 / 2)) 222 | 223 | q0 = c1 * c2 * c3 + s1 * s2 * s3 224 | q1 = c1 * c2 * s3 - s1 * s2 * c3 225 | q2 = c1 * s2 * c3 + s1 * c2 * s3 226 | q3 = s1 * c2 * c3 - c1 * s2 * s3 227 | 228 | return np.array([q0, q1, q2, q3]) 229 | 230 | 231 | def mj_quatprod(q, r): 232 | quaternion = np.zeros(4) 233 | mjlib.mju_mulQuat(quaternion, np.ascontiguousarray(q), 234 | np.ascontiguousarray(r)) 235 | return quaternion 236 | 237 | 238 | def mj_quat2vel(q, dt): 239 | vel = np.zeros(3) 240 | mjlib.mju_quat2Vel(vel, np.ascontiguousarray(q), dt) 241 | return vel 242 | 243 | 244 | def mj_quatneg(q): 245 | quaternion = np.zeros(4) 246 | mjlib.mju_negQuat(quaternion, np.ascontiguousarray(q)) 247 | return quaternion 248 | 249 | 250 | def mj_quatdiff(source, target): 251 | return mj_quatprod(mj_quatneg(source), np.ascontiguousarray(target)) 252 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/utils/parse_amc_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for parse_amc utility.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | # Internal dependencies. 25 | 26 | from absl.testing import absltest 27 | from dm_control.suite import humanoid_CMU 28 | from dm_control.suite.utils import parse_amc 29 | 30 | from dm_control.utils import io as resources 31 | 32 | _TEST_AMC_PATH = resources.GetResourceFilename( 33 | os.path.join(os.path.dirname(__file__), '../demos/zeros.amc')) 34 | 35 | 36 | class ParseAMCTest(absltest.TestCase): 37 | 38 | def test_sizes_of_parsed_data(self): 39 | 40 | # Instantiate the humanoid environment. 41 | env = humanoid_CMU.stand() 42 | 43 | # Parse and convert specified clip. 
44 | converted = parse_amc.convert( 45 | _TEST_AMC_PATH, env.physics, env.control_timestep()) 46 | 47 | self.assertEqual(converted.qpos.shape[0], 63) 48 | self.assertEqual(converted.qvel.shape[0], 62) 49 | self.assertEqual(converted.time.shape[0], converted.qpos.shape[1]) 50 | self.assertEqual(converted.qpos.shape[1], 51 | converted.qvel.shape[1] + 1) 52 | 53 | # Parse and convert specified clip -- WITH SMALLER TIMESTEP 54 | converted2 = parse_amc.convert( 55 | _TEST_AMC_PATH, env.physics, 0.5 * env.control_timestep()) 56 | 57 | self.assertEqual(converted2.qpos.shape[0], 63) 58 | self.assertEqual(converted2.qvel.shape[0], 62) 59 | self.assertEqual(converted2.time.shape[0], converted2.qpos.shape[1]) 60 | self.assertEqual(converted.qpos.shape[1], 61 | converted.qvel.shape[1] + 1) 62 | 63 | # Compare sizes of parsed objects for different timesteps 64 | self.assertEqual(converted.qpos.shape[1] * 2, converted2.qpos.shape[1]) 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/utils/randomizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Randomization functions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from dm_control.mujoco.wrapper import mjbindings 23 | import numpy as np 24 | from six.moves import range 25 | 26 | 27 | def random_limited_quaternion(random, limit): 28 | """Generates a random quaternion limited to the specified rotations.""" 29 | axis = random.randn(3) 30 | axis /= np.linalg.norm(axis) 31 | angle = random.rand() * limit 32 | 33 | quaternion = np.zeros(4) 34 | mjbindings.mjlib.mju_axisAngle2Quat(quaternion, axis, angle) 35 | 36 | return quaternion 37 | 38 | 39 | def randomize_limited_and_rotational_joints(physics, random=None): 40 | """Randomizes the positions of joints defined in the physics body. 41 | 42 | The following randomization rules apply: 43 | - Bounded joints (hinges or sliders) are sampled uniformly within the bounds. 44 | - Unbounded hinges are sampled uniformly in [-pi, pi]. 45 | - Quaternions for unlimited free joints and ball joints are sampled 46 | uniformly on the unit 3-sphere. 47 | - Quaternions for limited ball joints are sampled uniformly on a sector 48 | of the unit 3-sphere. 49 | - The linear degrees of freedom of free joints are not randomized. 50 | 51 | Args: 52 | physics: Instance of 'Physics' class that holds a loaded model. 53 | random: Optional instance of 'np.random.RandomState'. Defaults to the global 54 | NumPy random state.
55 | """ 56 | random = random or np.random 57 | 58 | hinge = mjbindings.enums.mjtJoint.mjJNT_HINGE 59 | slide = mjbindings.enums.mjtJoint.mjJNT_SLIDE 60 | ball = mjbindings.enums.mjtJoint.mjJNT_BALL 61 | free = mjbindings.enums.mjtJoint.mjJNT_FREE 62 | 63 | qpos = physics.named.data.qpos 64 | 65 | for joint_id in range(physics.model.njnt): 66 | joint_name = physics.model.id2name(joint_id, 'joint') 67 | joint_type = physics.model.jnt_type[joint_id] 68 | is_limited = physics.model.jnt_limited[joint_id] 69 | range_min, range_max = physics.model.jnt_range[joint_id] 70 | 71 | if is_limited: 72 | if joint_type == hinge or joint_type == slide: 73 | qpos[joint_name] = random.uniform(range_min, range_max) 74 | 75 | elif joint_type == ball: 76 | qpos[joint_name] = random_limited_quaternion(random, range_max) 77 | 78 | else: 79 | if joint_type == hinge: 80 | qpos[joint_name] = random.uniform(-np.pi, np.pi) 81 | 82 | elif joint_type == ball: 83 | quat = random.randn(4) 84 | quat /= np.linalg.norm(quat) 85 | qpos[joint_name] = quat 86 | 87 | elif joint_type == free: 88 | quat = random.rand(4) 89 | quat /= np.linalg.norm(quat) 90 | qpos[joint_name][3:] = quat 91 | 92 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/utils/randomizers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for randomizers.py.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Internal dependencies. 
23 | from absl.testing import absltest 24 | from absl.testing import parameterized 25 | from dm_control import mujoco 26 | from dm_control.mujoco.wrapper import mjbindings 27 | from dm_control.suite.utils import randomizers 28 | import numpy as np 29 | from six.moves import range 30 | 31 | mjlib = mjbindings.mjlib 32 | 33 | 34 | class RandomizeUnlimitedJointsTest(parameterized.TestCase): 35 | 36 | def setUp(self): 37 | self.rand = np.random.RandomState(100) 38 | 39 | def test_single_joint_of_each_type(self): 40 | physics = mujoco.Physics.from_xml_string(""" 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | """) 66 | 67 | randomizers.randomize_limited_and_rotational_joints(physics, self.rand) 68 | self.assertNotEqual(0., physics.named.data.qpos['hinge']) 69 | self.assertNotEqual(0., physics.named.data.qpos['limited_hinge']) 70 | self.assertNotEqual(0., physics.named.data.qpos['limited_slide']) 71 | 72 | self.assertNotEqual(0., np.sum(physics.named.data.qpos['ball'])) 73 | self.assertNotEqual(0., np.sum(physics.named.data.qpos['limited_ball'])) 74 | 75 | self.assertNotEqual(0., np.sum(physics.named.data.qpos['free'][3:])) 76 | 77 | # Unlimited slide and the positional part of the free joint remains 78 | # uninitialized. 79 | self.assertEqual(0., physics.named.data.qpos['slide']) 80 | self.assertEqual(0., np.sum(physics.named.data.qpos['free'][:3])) 81 | 82 | def test_multiple_joints_of_same_type(self): 83 | physics = mujoco.Physics.from_xml_string(""" 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | """) 93 | 94 | randomizers.randomize_limited_and_rotational_joints(physics, self.rand) 95 | self.assertNotEqual(0., physics.named.data.qpos['hinge_1']) 96 | self.assertNotEqual(0., physics.named.data.qpos['hinge_2']) 97 | self.assertNotEqual(0., physics.named.data.qpos['hinge_3']) 98 | 99 | self.assertNotEqual(physics.named.data.qpos['hinge_1'], 100 | physics.named.data.qpos['hinge_2']) 101 | 102 | self.assertNotEqual(physics.named.data.qpos['hinge_2'], 103 | physics.named.data.qpos['hinge_3']) 104 | 105 | self.assertNotEqual(physics.named.data.qpos['hinge_1'], 106 | physics.named.data.qpos['hinge_3']) 107 | 108 | def test_unlimited_hinge_randomization_range(self): 109 | physics = mujoco.Physics.from_xml_string(""" 110 | 111 | 112 | 113 | 114 | 115 | 116 | """) 117 | 118 | for _ in range(10): 119 | randomizers.randomize_limited_and_rotational_joints(physics, self.rand) 120 | self.assertBetween(physics.named.data.qpos['hinge'], -np.pi, np.pi) 121 | 122 | def test_limited_1d_joint_limits_are_respected(self): 123 | physics = mujoco.Physics.from_xml_string(""" 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | """) 135 | 136 | for _ in range(10): 137 | randomizers.randomize_limited_and_rotational_joints(physics, self.rand) 138 | self.assertBetween(physics.named.data.qpos['hinge'], 139 | np.deg2rad(0), np.deg2rad(10)) 140 | self.assertBetween(physics.named.data.qpos['slide'], 30, 50) 141 | 142 | def test_limited_ball_joint_are_respected(self): 143 | physics = mujoco.Physics.from_xml_string(""" 144 | 145 | 146 | 147 | 148 | 149 | 150 | """) 151 | 152 | body_axis = np.array([1., 0., 0.]) 153 | joint_axis = np.zeros(3) 154 | for _ in range(10): 155 | randomizers.randomize_limited_and_rotational_joints(physics, self.rand) 156 | 157 | quat = physics.named.data.qpos['ball'] 158 | mjlib.mju_rotVecQuat(joint_axis, body_axis, quat) 159 | angle_cos = np.dot(body_axis, joint_axis) 160 | self.assertGreater(angle_cos, 0.5) # 
cos(60) = 0.5 161 | 162 | 163 | if __name__ == '__main__': 164 | absltest.main() 165 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Environment wrappers used to extend or modify environment behaviour.""" 17 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/wrappers/action_noise.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Wrapper control suite environments that adds Gaussian noise to actions.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import numpy as np 23 | from dm_control.rl import environment 24 | 25 | 26 | _BOUNDS_MUST_BE_FINITE = ( 27 | 'All bounds in `env.action_spec()` must be finite, got: {action_spec}') 28 | 29 | 30 | class Wrapper(environment.Base): 31 | """Wraps a control environment and adds Gaussian noise to actions.""" 32 | 33 | def __init__(self, env, scale=0.01): 34 | """Initializes a new action noise Wrapper. 35 | 36 | Args: 37 | env: The control suite environment to wrap. 38 | scale: The standard deviation of the noise, expressed as a fraction 39 | of the max-min range for each action dimension. 40 | 41 | Raises: 42 | ValueError: If any of the action dimensions of the wrapped environment are 43 | unbounded. 
44 | """ 45 | action_spec = env.action_spec() 46 | if not (np.all(np.isfinite(action_spec.minimum)) and 47 | np.all(np.isfinite(action_spec.maximum))): 48 | raise ValueError(_BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec)) 49 | self._minimum = action_spec.minimum 50 | self._maximum = action_spec.maximum 51 | self._noise_std = scale * (action_spec.maximum - action_spec.minimum) 52 | self._env = env 53 | 54 | def step(self, action): 55 | noisy_action = action + self._env.task.random.normal(scale=self._noise_std) 56 | # Clip the noisy actions in place so that they fall within the bounds 57 | # specified by the `action_spec`. Note that MuJoCo implicitly clips out-of- 58 | # bounds control inputs, but we also clip here in case the actions do not 59 | # correspond directly to MuJoCo actuators, or if there are other wrapper 60 | # layers that expect the actions to be within bounds. 61 | np.clip(noisy_action, self._minimum, self._maximum, out=noisy_action) 62 | return self._env.step(noisy_action) 63 | 64 | def reset(self): 65 | return self._env.reset() 66 | 67 | def observation_spec(self): 68 | return self._env.observation_spec() 69 | 70 | def action_spec(self): 71 | return self._env.action_spec() 72 | 73 | def __getattr__(self, name): 74 | return getattr(self._env, name) 75 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/wrappers/action_noise_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for the action noise wrapper.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Internal dependencies. 23 | from absl.testing import absltest 24 | from absl.testing import parameterized 25 | from dm_control.rl import control 26 | from dm_control.suite.wrappers import action_noise 27 | import mock 28 | import numpy as np 29 | from dm_control.rl import specs 30 | 31 | 32 | class ActionNoiseTest(parameterized.TestCase): 33 | 34 | def make_action_spec(self, lower=(-1.,), upper=(1.,)): 35 | lower, upper = np.broadcast_arrays(lower, upper) 36 | return specs.BoundedArraySpec( 37 | shape=lower.shape, dtype=float, minimum=lower, maximum=upper) 38 | 39 | def make_mock_env(self, action_spec=None): 40 | action_spec = action_spec or self.make_action_spec() 41 | env = mock.Mock(spec=control.Environment) 42 | env.action_spec.return_value = action_spec 43 | return env 44 | 45 | def assertStepCalledOnceWithCorrectAction(self, env, expected_action): 46 | # NB: `assert_called_once_with()` doesn't support numpy arrays. 
47 | env.step.assert_called_once() 48 | actual_action = env.step.call_args_list[0][0][0] 49 | np.testing.assert_array_equal(expected_action, actual_action) 50 | 51 | @parameterized.parameters([ 52 | dict(lower=np.r_[-1., 0.], upper=np.r_[1., 2.], scale=0.05), 53 | dict(lower=np.r_[-1., 0.], upper=np.r_[1., 2.], scale=0.), 54 | dict(lower=np.r_[-1., 0.], upper=np.r_[-1., 0.], scale=0.05), 55 | ]) 56 | def test_step(self, lower, upper, scale): 57 | seed = 0 58 | std = scale * (upper - lower) 59 | expected_noise = np.random.RandomState(seed).normal(scale=std) 60 | action = np.random.RandomState(seed).uniform(lower, upper) 61 | expected_noisy_action = np.clip(action + expected_noise, lower, upper) 62 | task = mock.Mock(spec=control.Task) 63 | task.random = np.random.RandomState(seed) 64 | action_spec = self.make_action_spec(lower=lower, upper=upper) 65 | env = self.make_mock_env(action_spec=action_spec) 66 | env.task = task 67 | wrapped_env = action_noise.Wrapper(env, scale=scale) 68 | time_step = wrapped_env.step(action) 69 | self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action) 70 | self.assertIs(time_step, env.step(expected_noisy_action)) 71 | 72 | @parameterized.named_parameters([ 73 | dict(testcase_name='within_bounds', action=np.r_[-1.], noise=np.r_[0.1]), 74 | dict(testcase_name='below_lower', action=np.r_[-1.], noise=np.r_[-0.1]), 75 | dict(testcase_name='above_upper', action=np.r_[1.], noise=np.r_[0.1]), 76 | ]) 77 | def test_action_clipping(self, action, noise): 78 | lower = -1. 79 | upper = 1. 80 | expected_noisy_action = np.clip(action + noise, lower, upper) 81 | task = mock.Mock(spec=control.Task) 82 | task.random = mock.Mock(spec=np.random.RandomState) 83 | task.random.normal.return_value = noise 84 | action_spec = self.make_action_spec(lower=lower, upper=upper) 85 | env = self.make_mock_env(action_spec=action_spec) 86 | env.task = task 87 | wrapped_env = action_noise.Wrapper(env) 88 | time_step = wrapped_env.step(action) 89 | self.assertStepCalledOnceWithCorrectAction(env, expected_noisy_action) 90 | self.assertIs(time_step, env.step(expected_noisy_action)) 91 | 92 | @parameterized.parameters([ 93 | dict(lower=np.r_[-1., 0.], upper=np.r_[1., np.inf]), 94 | dict(lower=np.r_[np.nan, 0.], upper=np.r_[1., 2.]), 95 | ]) 96 | def test_error_if_action_bounds_non_finite(self, lower, upper): 97 | action_spec = self.make_action_spec(lower=lower, upper=upper) 98 | env = self.make_mock_env(action_spec=action_spec) 99 | with self.assertRaisesWithLiteralMatch( 100 | ValueError, 101 | action_noise._BOUNDS_MUST_BE_FINITE.format(action_spec=action_spec)): 102 | _ = action_noise.Wrapper(env) 103 | 104 | def test_reset(self): 105 | env = self.make_mock_env() 106 | wrapped_env = action_noise.Wrapper(env) 107 | time_step = wrapped_env.reset() 108 | env.reset.assert_called_once_with() 109 | self.assertIs(time_step, env.reset()) 110 | 111 | def test_observation_spec(self): 112 | env = self.make_mock_env() 113 | wrapped_env = action_noise.Wrapper(env) 114 | observation_spec = wrapped_env.observation_spec() 115 | env.observation_spec.assert_called_once_with() 116 | self.assertIs(observation_spec, env.observation_spec()) 117 | 118 | def test_action_spec(self): 119 | env = self.make_mock_env() 120 | wrapped_env = action_noise.Wrapper(env) 121 | # `env.action_spec()` is called in `Wrapper.__init__()` 122 | env.action_spec.reset_mock() 123 | action_spec = wrapped_env.action_spec() 124 | env.action_spec.assert_called_once_with() 125 | self.assertIs(action_spec, env.action_spec()) 126 | 127 | 
@parameterized.parameters(['task', 'physics', 'control_timestep']) 128 | def test_getattr(self, attribute_name): 129 | env = self.make_mock_env() 130 | wrapped_env = action_noise.Wrapper(env) 131 | attr = getattr(wrapped_env, attribute_name) 132 | self.assertIs(attr, getattr(env, attribute_name)) 133 | 134 | 135 | if __name__ == '__main__': 136 | absltest.main() 137 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/wrappers/pixels.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Wrapper that adds pixel observations to a control environment.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | from dm_control.rl import environment 25 | from dm_control.rl import specs 26 | 27 | STATE_KEY = 'state' 28 | 29 | 30 | class Wrapper(environment.Base): 31 | """Wraps a control environment and adds a rendered pixel observation.""" 32 | 33 | def __init__(self, env, pixels_only=True, render_kwargs=None, 34 | observation_key='pixels'): 35 | """Initializes a new pixel Wrapper. 36 | 37 | Args: 38 | env: The environment to wrap. 39 | pixels_only: If True (default), the original set of 'state' observations 40 | returned by the wrapped environment will be discarded, and the 41 | `OrderedDict` of observations will only contain pixels. If False, the 42 | `OrderedDict` will contain the original observations as well as the 43 | pixel observations. 44 | render_kwargs: Optional `dict` containing keyword arguments passed to the 45 | `mujoco.Physics.render` method. 46 | observation_key: Optional custom string specifying the pixel observation's 47 | key in the `OrderedDict` of observations. Defaults to 'pixels'. 48 | 49 | Raises: 50 | ValueError: If `env`'s observation spec is not compatible with the 51 | wrapper. Supported formats are a single array, or a dict of arrays. 52 | ValueError: If `env`'s observation already contains the specified 53 | `observation_key`. 54 | """ 55 | if render_kwargs is None: 56 | render_kwargs = {} 57 | 58 | wrapped_observation_spec = env.observation_spec() 59 | 60 | if isinstance(wrapped_observation_spec, specs.ArraySpec): 61 | self._observation_is_dict = False 62 | invalid_keys = set([STATE_KEY]) 63 | elif isinstance(wrapped_observation_spec, collections.MutableMapping): 64 | self._observation_is_dict = True 65 | invalid_keys = set(wrapped_observation_spec.keys()) 66 | else: 67 | raise ValueError('Unsupported observation spec structure.') 68 | 69 | if not pixels_only and observation_key in invalid_keys: 70 | raise ValueError('Duplicate or reserved observation key {!r}.' 
71 | .format(observation_key)) 72 | 73 | if pixels_only: 74 | self._observation_spec = collections.OrderedDict() 75 | elif self._observation_is_dict: 76 | self._observation_spec = wrapped_observation_spec.copy() 77 | else: 78 | self._observation_spec = collections.OrderedDict() 79 | self._observation_spec[STATE_KEY] = wrapped_observation_spec 80 | 81 | # Extend observation spec. 82 | pixels = env.physics.render(**render_kwargs) 83 | pixels_spec = specs.ArraySpec( 84 | shape=pixels.shape, dtype=pixels.dtype, name=observation_key) 85 | self._observation_spec[observation_key] = pixels_spec 86 | 87 | self._env = env 88 | self._pixels_only = pixels_only 89 | self._render_kwargs = render_kwargs 90 | self._observation_key = observation_key 91 | 92 | def reset(self): 93 | time_step = self._env.reset() 94 | return self._add_pixel_observation(time_step) 95 | 96 | def step(self, action): 97 | time_step = self._env.step(action) 98 | return self._add_pixel_observation(time_step) 99 | 100 | def observation_spec(self): 101 | return self._observation_spec 102 | 103 | def action_spec(self): 104 | return self._env.action_spec() 105 | 106 | def _add_pixel_observation(self, time_step): 107 | if self._pixels_only: 108 | observation = collections.OrderedDict() 109 | elif self._observation_is_dict: 110 | observation = type(time_step.observation)(time_step.observation) 111 | else: 112 | observation = collections.OrderedDict() 113 | observation[STATE_KEY] = time_step.observation 114 | 115 | pixels = self._env.physics.render(**self._render_kwargs) 116 | observation[self._observation_key] = pixels 117 | return time_step._replace(observation=observation) 118 | 119 | def __getattr__(self, name): 120 | return getattr(self._env, name) 121 | -------------------------------------------------------------------------------- /envs/goal_env_ext/dm_control/wrappers/pixels_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Tests for the pixel wrapper.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | 24 | # Internal dependencies. 
25 | from absl.testing import absltest 26 | from absl.testing import parameterized 27 | from dm_control.suite import cartpole 28 | from dm_control.suite.wrappers import pixels 29 | 30 | import numpy as np 31 | 32 | from dm_control.rl import environment 33 | from dm_control.rl import specs 34 | 35 | 36 | class FakePhysics(object): 37 | 38 | def render(self, *args, **kwargs): 39 | del args 40 | del kwargs 41 | return np.zeros((4, 5, 3), dtype=np.uint8) 42 | 43 | 44 | class FakeArrayObservationEnvironment(environment.Base): 45 | 46 | def __init__(self): 47 | self.physics = FakePhysics() 48 | 49 | def reset(self): 50 | return environment.restart(np.zeros((2,))) 51 | 52 | def step(self, action): 53 | del action 54 | return environment.transition(0.0, np.zeros((2,))) 55 | 56 | def action_spec(self): 57 | pass 58 | 59 | def observation_spec(self): 60 | return specs.ArraySpec(shape=(2,), dtype=np.float) 61 | 62 | 63 | class PixelsTest(parameterized.TestCase): 64 | 65 | @parameterized.parameters(True, False) 66 | def test_dict_observation(self, pixels_only): 67 | pixel_key = 'rgb' 68 | 69 | env = cartpole.swingup() 70 | 71 | # Make sure we are testing the right environment for the test. 72 | observation_spec = env.observation_spec() 73 | self.assertIsInstance(observation_spec, collections.OrderedDict) 74 | 75 | width = 320 76 | height = 240 77 | 78 | # The wrapper should only add one observation. 79 | wrapped = pixels.Wrapper(env, 80 | observation_key=pixel_key, 81 | pixels_only=pixels_only, 82 | render_kwargs={'width': width, 'height': height}) 83 | 84 | wrapped_observation_spec = wrapped.observation_spec() 85 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) 86 | 87 | if pixels_only: 88 | self.assertEqual(1, len(wrapped_observation_spec)) 89 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) 90 | else: 91 | self.assertEqual(len(observation_spec) + 1, len(wrapped_observation_spec)) 92 | expected_keys = list(observation_spec.keys()) + [pixel_key] 93 | self.assertEqual(expected_keys, list(wrapped_observation_spec.keys())) 94 | 95 | # Check that the added spec item is consistent with the added observation. 
96 | time_step = wrapped.reset() 97 | rgb_observation = time_step.observation[pixel_key] 98 | wrapped_observation_spec[pixel_key].validate(rgb_observation) 99 | 100 | self.assertEqual(rgb_observation.shape, (height, width, 3)) 101 | self.assertEqual(rgb_observation.dtype, np.uint8) 102 | 103 | @parameterized.parameters(True, False) 104 | def test_single_array_observation(self, pixels_only): 105 | pixel_key = 'depth' 106 | 107 | env = FakeArrayObservationEnvironment() 108 | observation_spec = env.observation_spec() 109 | self.assertIsInstance(observation_spec, specs.ArraySpec) 110 | 111 | wrapped = pixels.Wrapper(env, observation_key=pixel_key, 112 | pixels_only=pixels_only) 113 | wrapped_observation_spec = wrapped.observation_spec() 114 | self.assertIsInstance(wrapped_observation_spec, collections.OrderedDict) 115 | 116 | if pixels_only: 117 | self.assertEqual(1, len(wrapped_observation_spec)) 118 | self.assertEqual([pixel_key], list(wrapped_observation_spec.keys())) 119 | else: 120 | self.assertEqual(2, len(wrapped_observation_spec)) 121 | self.assertEqual([pixels.STATE_KEY, pixel_key], 122 | list(wrapped_observation_spec.keys())) 123 | 124 | time_step = wrapped.reset() 125 | 126 | depth_observation = time_step.observation[pixel_key] 127 | wrapped_observation_spec[pixel_key].validate(depth_observation) 128 | 129 | self.assertEqual(depth_observation.shape, (4, 5, 3)) 130 | self.assertEqual(depth_observation.dtype, np.uint8) 131 | 132 | if __name__ == '__main__': 133 | absltest.main() 134 | -------------------------------------------------------------------------------- /envs/goal_env_ext/fetch/__init__.py: -------------------------------------------------------------------------------- 1 | # Created by Xingyu Lin, 04/09/2018 2 | -------------------------------------------------------------------------------- /envs/goal_env_ext/fetch/pick_and_place.py: -------------------------------------------------------------------------------- 1 | from gym import utils 2 | from envs.goal_env_ext.fetch import fetch_env 3 | 4 | 5 | class FetchPickAndPlaceEnv(fetch_env.FetchEnv, utils.EzPickle): 6 | def __init__(self, reward_type='sparse'): 7 | initial_qpos = { 8 | 'robot0:slide0': 0.405, 9 | 'robot0:slide1': 0.48, 10 | 'robot0:slide2': 0.0, 11 | 'table0:slide0': 1.05, 12 | 'table0:slide1': 0.4, 13 | 'table0:slide2': 0.0, 14 | 'object0:joint': [1.25, 0.53, 0.4, 1., 0., 0., 0.], 15 | } 16 | fetch_env.FetchEnv.__init__( 17 | self, 'fetch/pick_and_place.xml', has_object=True, block_gripper=False, n_substeps=20, 18 | gripper_extra_height=0.2, target_in_the_air=True, target_offset=0.0, 19 | obj_range=0.15, target_range=0.15, distance_threshold=0.05, 20 | initial_qpos=initial_qpos, reward_type=reward_type) 21 | utils.EzPickle.__init__(self) 22 | -------------------------------------------------------------------------------- /envs/goal_env_ext/fetch/push.py: -------------------------------------------------------------------------------- 1 | from gym import utils 2 | from envs.goal_env_ext.fetch import fetch_env 3 | 4 | 5 | class FetchPushEnv(fetch_env.FetchEnv, utils.EzPickle): 6 | def __init__(self, reward_type='sparse', distance_threshold=0.05, OLSAGP=0.0, n_substeps=20, **kwargs): 7 | initial_qpos = { 8 | 'robot0:slide0': 0.405, 9 | 'robot0:slide1': 0.48, 10 | 'robot0:slide2': 0.0, 11 | 'table0:slide0': 1.05, 12 | 'table0:slide1': 0.4, 13 | 'table0:slide2': 0.0, 14 | 'object0:joint': [1.25, 0.53, 0.4, 1., 0., 0., 0.], 15 | } 16 | fetch_env.FetchEnv.__init__( 17 | self, 'fetch/push.xml', has_object=True, 
block_gripper=True, n_substeps=n_substeps, 18 | gripper_extra_height=0.0, target_in_the_air=False, target_offset=0.0, 19 | obj_range=0.15, target_range=0.15, distance_threshold=distance_threshold, 20 | initial_qpos=initial_qpos, reward_type=reward_type, object_location_same_as_gripper_probability=OLSAGP, 21 | **kwargs) 22 | utils.EzPickle.__init__(self) 23 | -------------------------------------------------------------------------------- /envs/goal_env_ext/fetch/reach.py: -------------------------------------------------------------------------------- 1 | from gym import utils 2 | from envs.goal_env_ext.fetch import fetch_env 3 | 4 | 5 | class FetchReachEnv(fetch_env.FetchEnv, utils.EzPickle): 6 | def __init__(self, reward_type='sparse', distance_threshold=0.05, n_substeps=20, **kwargs): 7 | initial_qpos = { 8 | 'robot0:slide0': 0.4049, 9 | 'robot0:slide1': 0.48, 10 | 'robot0:slide2': 0.0, 11 | 'table0:slide0': 1.05, 12 | 'table0:slide1': 0.4, 13 | 'table0:slide2': 0.0, 14 | } 15 | fetch_env.FetchEnv.__init__( 16 | self, 'fetch/reach.xml', has_object=False, block_gripper=True, n_substeps=n_substeps, 17 | gripper_extra_height=0.2, target_in_the_air=True, target_offset=0.0, 18 | obj_range=0.15, target_range=0.15, distance_threshold=distance_threshold, 19 | initial_qpos=initial_qpos, reward_type=reward_type, **kwargs) 20 | utils.EzPickle.__init__(self) 21 | -------------------------------------------------------------------------------- /envs/goal_env_ext/fetch/slide.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gym import utils 4 | from envs.goal_env_ext.fetch import fetch_env 5 | 6 | 7 | class FetchSlideEnv(fetch_env.FetchEnv, utils.EzPickle): 8 | def __init__(self, reward_type='sparse', distance_threshold=0.05, n_substeps=20, **kwargs): 9 | initial_qpos = { 10 | 'robot0:slide0': 0.05, 11 | 'robot0:slide1': 0.48, 12 | 'robot0:slide2': 0.0, 13 | 'table0:slide0': 0.7, 14 | 'table0:slide1': 0.3, 15 | 'table0:slide2': 0.0, 16 | 'object0:joint': [1.7, 1.1, 0.4, 1., 0., 0., 0.], 17 | } 18 | fetch_env.FetchEnv.__init__( 19 | self, 'fetch/slide.xml', has_object=True, block_gripper=True, n_substeps=n_substeps, 20 | gripper_extra_height=-0.02, target_in_the_air=False, target_offset=np.array([0.4, 0.0, 0.0]), 21 | obj_range=0.1, target_range=0.3, distance_threshold=distance_threshold, 22 | initial_qpos=initial_qpos, reward_type=reward_type, object_location_same_as_gripper_probability = 0.0, **kwargs) 23 | utils.EzPickle.__init__(self) 24 | -------------------------------------------------------------------------------- /envs/goal_env_ext/hand/__init__.py: -------------------------------------------------------------------------------- 1 | # Created by Xingyu Lin, 04/09/2018 2 | -------------------------------------------------------------------------------- /envs/goal_env_ext/hand/hand_env.py: -------------------------------------------------------------------------------- 1 | # Created by Xingyu Lin, 04/09/2018 2 | import os 3 | import copy 4 | import numpy as np 5 | 6 | import gym 7 | from gym import error, spaces 8 | from gym.utils import seeding 9 | from envs.goal_env_ext.goal_env_ext import GoalEnvExt 10 | 11 | 12 | class HandEnv(GoalEnvExt): 13 | def __init__(self, model_path, n_substeps, initial_qpos, relative_control, 14 | distance_threshold, 15 | distance_threshold_obs, 16 | horizon, 17 | image_size, 18 | reward_type='sparse', 19 | use_true_reward=False, 20 | use_visual_observation=True, 21 | use_image_goal=True, 22 | 
with_goal=True, 23 | **kwargs): 24 | self.relative_control = relative_control 25 | 26 | super(HandEnv, self).__init__(model_path=model_path, n_substeps=n_substeps, n_actions=20, 27 | initial_qpos=initial_qpos, use_image_goal=use_image_goal, 28 | use_visual_observation=use_visual_observation, 29 | use_true_reward=use_true_reward, reward_type=reward_type, 30 | distance_threshold=distance_threshold, 31 | distance_threshold_obs=distance_threshold_obs, horizon=horizon, 32 | image_size=image_size, with_goal=with_goal, **kwargs) 33 | 34 | # RobotEnv methods 35 | # ---------------------------- 36 | 37 | def _set_action(self, action): 38 | assert action.shape == (20,) 39 | 40 | ctrlrange = self.sim.model.actuator_ctrlrange 41 | actuation_range = (ctrlrange[:, 1] - ctrlrange[:, 0]) / 2. 42 | if self.relative_control: 43 | actuation_center = np.zeros_like(action) 44 | for i in range(self.sim.data.ctrl.shape[0]): 45 | actuation_center[i] = self.sim.data.get_joint_qpos( 46 | self.sim.model.actuator_names[i].replace(':A_', ':')) 47 | for joint_name in ['FF', 'MF', 'RF', 'LF']: 48 | act_idx = self.sim.model.actuator_name2id( 49 | 'robot0:A_{}J1'.format(joint_name)) 50 | actuation_center[act_idx] += self.sim.data.get_joint_qpos( 51 | 'robot0:{}J0'.format(joint_name)) 52 | else: 53 | actuation_center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2. 54 | self.sim.data.ctrl[:] = actuation_center + action * actuation_range 55 | self.sim.data.ctrl[:] = np.clip(self.sim.data.ctrl, ctrlrange[:, 0], ctrlrange[:, 1]) 56 | 57 | def _viewer_setup(self): 58 | body_id = self.sim.model.body_name2id('robot0:palm') 59 | lookat = self.sim.data.body_xpos[body_id] 60 | for idx, value in enumerate(lookat): 61 | self.viewer.cam.lookat[idx] = value 62 | self.viewer.cam.distance = 0.5 63 | self.viewer.cam.azimuth = 55. 64 | self.viewer.cam.elevation = -25. 
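# Editor's sketch (not part of the original file, hypothetical values): _set_action above maps a
# normalized action onto the actuator range via ctrl = center + action * half_range and then clips
# to ctrlrange. A minimal NumPy illustration of that mapping:
#
#     import numpy as np
#     ctrlrange = np.array([[-1.0, 1.0], [0.0, 2.0]])    # per-actuator [low, high]
#     half_range = (ctrlrange[:, 1] - ctrlrange[:, 0]) / 2.
#     center = (ctrlrange[:, 1] + ctrlrange[:, 0]) / 2.  # absolute-control case
#     action = np.array([0.5, -1.0])                     # normalized action in [-1, 1]
#     ctrl = np.clip(center + action * half_range, ctrlrange[:, 0], ctrlrange[:, 1])
#     # ctrl -> array([0.5, 0.0])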
65 | -------------------------------------------------------------------------------- /envs/goal_env_ext/hand/reach.py: -------------------------------------------------------------------------------- 1 | # Created by Xingyu Lin, 04/09/2018 2 | import numpy as np 3 | 4 | from gym import utils 5 | from envs.utils import robot_get_obs 6 | from envs.goal_env_ext.hand import hand_env 7 | 8 | FINGERTIP_SITE_NAMES = [ 9 | 'robot0:S_fftip', 10 | 'robot0:S_mftip', 11 | 'robot0:S_rftip', 12 | 'robot0:S_lftip', 13 | 'robot0:S_thtip', 14 | ] 15 | 16 | DEFAULT_INITIAL_QPOS = { 17 | 'robot0:WRJ1': -0.16514339750464327, 18 | 'robot0:WRJ0': -0.31973286565062153, 19 | 'robot0:FFJ3': 0.14340512546557435, 20 | 'robot0:FFJ2': 0.32028208333591573, 21 | 'robot0:FFJ1': 0.7126053607727917, 22 | 'robot0:FFJ0': 0.6705281001412586, 23 | 'robot0:MFJ3': 0.000246444303701037, 24 | 'robot0:MFJ2': 0.3152655251085491, 25 | 'robot0:MFJ1': 0.7659800313729842, 26 | 'robot0:MFJ0': 0.7323156897425923, 27 | 'robot0:RFJ3': 0.00038520700007378114, 28 | 'robot0:RFJ2': 0.36743546201985233, 29 | 'robot0:RFJ1': 0.7119514095008576, 30 | 'robot0:RFJ0': 0.6699446327514138, 31 | 'robot0:LFJ4': 0.0525442258033891, 32 | 'robot0:LFJ3': -0.13615534724474673, 33 | 'robot0:LFJ2': 0.39872030433433003, 34 | 'robot0:LFJ1': 0.7415570009679252, 35 | 'robot0:LFJ0': 0.704096378652974, 36 | 'robot0:THJ4': 0.003673823825070126, 37 | 'robot0:THJ3': 0.5506291436028695, 38 | 'robot0:THJ2': -0.014515151997119306, 39 | 'robot0:THJ1': -0.0015229223564485414, 40 | 'robot0:THJ0': -0.7894883021600622, 41 | } 42 | 43 | 44 | def goal_distance(goal_a, goal_b): 45 | assert goal_a.shape == goal_b.shape 46 | return np.linalg.norm(goal_a - goal_b, axis=-1) 47 | 48 | 49 | class HandReachEnv(hand_env.HandEnv, utils.EzPickle): 50 | def __init__(self, n_substeps=20, relative_control=False, 51 | initial_qpos=DEFAULT_INITIAL_QPOS, **kwargs): 52 | hand_env.HandEnv.__init__( 53 | self, 'hand/reach.xml', n_substeps=n_substeps, initial_qpos=initial_qpos, 54 | relative_control=relative_control, **kwargs) 55 | utils.EzPickle.__init__(self) 56 | 57 | def _get_achieved_goal(self): 58 | goal = [self.sim.data.get_site_xpos(name) for name in FINGERTIP_SITE_NAMES] 59 | return np.array(goal).flatten() 60 | 61 | # GoalEnvExt methods 62 | # ---------------------------- 63 | 64 | def _reset_sim(self): 65 | """Resets a simulation and indicates whether or not it was successful. 66 | If a reset was unsuccessful (e.g. if a randomized state caused an error in the 67 | simulation), this method should indicate such a failure by returning False. 68 | In such a case, this method will be called again to attempt a the reset again. 
69 | """ 70 | self.sim.set_state(self.initial_state) 71 | self.sim.forward() 72 | for i in range(20): 73 | action = self.action_space.sample() 74 | action += self.np_random.normal(scale=5, size=action.shape) 75 | self._set_action(action) 76 | self.sim.step() 77 | self._step_callback() 78 | # for name, _ in self.init_qpos.items(): 79 | # value = np.random.random()*3.14 - 1.57 80 | # self.sim.data.set_joint_qpos(name, value) 81 | self.goal_state = self._get_achieved_goal() 82 | self.goal_observation = self.render(mode='rgb_array', depth=True) 83 | 84 | # Revert to original state 85 | self.sim.set_state(self.initial_state) 86 | self.sim.forward() 87 | return True 88 | 89 | def _env_setup(self, initial_qpos): 90 | for name, value in initial_qpos.items(): 91 | self.sim.data.set_joint_qpos(name, value) 92 | self.sim.forward() 93 | 94 | self.initial_goal = self._get_achieved_goal().copy() 95 | self.palm_xpos = self.sim.data.body_xpos[self.sim.model.body_name2id('robot0:palm')].copy() 96 | 97 | def _get_obs(self): 98 | info = self.get_current_info() 99 | if self.use_visual_observation: 100 | obs = self.render(mode='rgb_array', depth=True) 101 | else: 102 | obs = info['obs_state'] 103 | 104 | if self.use_image_goal: 105 | assert self.use_visual_observation 106 | ag = obs.copy() 107 | g = self.goal_observation 108 | else: 109 | ag = info['ag_state'] 110 | g = info['g_state'] 111 | return { 112 | 'observation': obs.copy(), 113 | 'achieved_goal': ag.copy(), 114 | 'desired_goal': g.copy(), 115 | } 116 | 117 | def _sample_goal(self): 118 | thumb_name = 'robot0:S_thtip' 119 | finger_names = [name for name in FINGERTIP_SITE_NAMES if name != thumb_name] 120 | finger_name = self.np_random.choice(finger_names) 121 | 122 | thumb_idx = FINGERTIP_SITE_NAMES.index(thumb_name) 123 | finger_idx = FINGERTIP_SITE_NAMES.index(finger_name) 124 | assert thumb_idx != finger_idx 125 | 126 | # Pick a meeting point above the hand. 127 | meeting_pos = self.palm_xpos + np.array([0.0, -0.09, 0.05]) 128 | meeting_pos += self.np_random.normal(scale=0.005, size=meeting_pos.shape) 129 | 130 | # Slightly move meeting goal towards the respective finger to avoid that they 131 | # overlap. 132 | goal = self.initial_goal.copy().reshape(-1, 3) 133 | for idx in [thumb_idx, finger_idx]: 134 | offset_direction = (meeting_pos - goal[idx]) 135 | offset_direction /= np.linalg.norm(offset_direction) 136 | goal[idx] = meeting_pos - 0.005 * offset_direction 137 | 138 | if self.np_random.uniform() < 0.1: 139 | # With some probability, ask all fingers to move back to the origin. 140 | # This avoids that the thumb constantly stays near the goal position already. 141 | goal = self.initial_goal.copy() 142 | return goal.flatten() 143 | 144 | def _render_callback(self): 145 | # Visualize targets. 146 | sites_offset = (self.sim.data.site_xpos - self.sim.model.site_pos).copy() 147 | goal = self.goal_state.reshape(5, 3) 148 | for finger_idx in range(5): 149 | site_name = 'target{}'.format(finger_idx) 150 | site_id = self.sim.model.site_name2id(site_name) 151 | self.sim.model.site_pos[site_id] = goal[finger_idx] - sites_offset[site_id] 152 | 153 | # Visualize finger positions. 
154 | achieved_goal = self._get_achieved_goal().reshape(5, 3) 155 | for finger_idx in range(5): 156 | site_name = 'finger{}'.format(finger_idx) 157 | site_id = self.sim.model.site_name2id(site_name) 158 | self.sim.model.site_pos[site_id] = achieved_goal[finger_idx] - sites_offset[site_id] 159 | self.sim.forward() 160 | 161 | def get_current_info(self): 162 | robot_qpos, robot_qvel = robot_get_obs(self.sim) 163 | achieved_goal = self._get_achieved_goal().ravel() 164 | obs = np.concatenate([robot_qpos, robot_qvel, achieved_goal]) 165 | 166 | return { 167 | 'obs_state': obs.copy(), 168 | 'ag_state': achieved_goal.copy(), 169 | 'g_state': self.goal_state.copy(), 170 | } 171 | -------------------------------------------------------------------------------- /envs/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gym import error 4 | 5 | try: 6 | import mujoco_py 7 | except ImportError as e: 8 | raise error.DependencyNotInstalled( 9 | "{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format( 10 | e)) 11 | 12 | 13 | def robot_get_obs(sim): 14 | """Returns all joint positions and velocities associated with 15 | a robot. 16 | """ 17 | if sim.data.qpos is not None and sim.model.joint_names: 18 | names = [n for n in sim.model.joint_names if n.startswith('robot')] 19 | return ( 20 | np.array([sim.data.get_joint_qpos(name) for name in names]), 21 | np.array([sim.data.get_joint_qvel(name) for name in names]), 22 | ) 23 | return np.zeros(0), np.zeros(0) 24 | 25 | 26 | def ctrl_set_action(sim, action): 27 | """For torque actuators it copies the action into the MuJoCo ctrl field. 28 | For position actuators it sets the target relative to the current qpos. 29 | """ 30 | if sim.model.nmocap > 0: 31 | _, action = np.split(action, (sim.model.nmocap * 7,)) 32 | if sim.data.ctrl is not None: 33 | for i in range(action.shape[0]): 34 | if sim.model.actuator_biastype[i] == 0: 35 | sim.data.ctrl[i] = action[i] 36 | else: 37 | idx = sim.model.jnt_qposadr[sim.model.actuator_trnid[i, 0]] 38 | sim.data.ctrl[i] = sim.data.qpos[idx] + action[i] 39 | 40 | 41 | def mocap_set_action(sim, action): 42 | """The action controls the robot using mocaps. Specifically, bodies 43 | on the robot (for example the gripper wrist) are controlled with 44 | mocap bodies. In this case the action is the desired difference 45 | in position and orientation (quaternion), in world coordinates, 46 | of the target body. The mocap is positioned relative to 47 | the target body according to the delta, and the MuJoCo equality 48 | constraint optimizer tries to center the welded body on the mocap. 49 | """ 50 | if sim.model.nmocap > 0: 51 | action, _ = np.split(action, (sim.model.nmocap * 7,)) 52 | action = action.reshape(sim.model.nmocap, 7) 53 | 54 | pos_delta = action[:, :3] 55 | quat_delta = action[:, 3:] 56 | 57 | reset_mocap2body_xpos(sim) 58 | sim.data.mocap_pos[:] = sim.data.mocap_pos + pos_delta 59 | sim.data.mocap_quat[:] = sim.data.mocap_quat + quat_delta 60 | 61 | 62 | def reset_mocap_welds(sim): 63 | """Resets the mocap welds that we use for actuation.
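Each weld's eq_data is reset to the relative pose [0., 0., 0., 1., 0., 0., 0.], i.e. zero positional offset and an identity quaternion.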
64 | """ 65 | if sim.model.nmocap > 0 and sim.model.eq_data is not None: 66 | for i in range(sim.model.eq_data.shape[0]): 67 | if sim.model.eq_type[i] == mujoco_py.const.EQ_WELD: 68 | sim.model.eq_data[i, :] = np.array( 69 | [0., 0., 0., 1., 0., 0., 0.]) 70 | sim.forward() 71 | 72 | 73 | def reset_mocap2body_xpos(sim): 74 | """Resets the position and orientation of the mocap bodies to the same 75 | values as the bodies they're welded to. 76 | """ 77 | 78 | if (sim.model.eq_type is None or 79 | sim.model.eq_obj1id is None or 80 | sim.model.eq_obj2id is None): 81 | return 82 | for eq_type, obj1_id, obj2_id in zip(sim.model.eq_type, 83 | sim.model.eq_obj1id, 84 | sim.model.eq_obj2id): 85 | if eq_type != mujoco_py.const.EQ_WELD: 86 | continue 87 | 88 | mocap_id = sim.model.body_mocapid[obj1_id] 89 | if mocap_id != -1: 90 | # obj1 is the mocap, obj2 is the welded body 91 | body_idx = obj2_id 92 | else: 93 | # obj2 is the mocap, obj1 is the welded body 94 | mocap_id = sim.model.body_mocapid[obj2_id] 95 | body_idx = obj1_id 96 | 97 | assert (mocap_id != -1) 98 | sim.data.mocap_pos[mocap_id][:] = sim.data.body_xpos[body_idx] 99 | sim.data.mocap_quat[mocap_id][:] = sim.data.body_xquat[body_idx] 100 | 101 | 102 | def separate_img(rgbd_img): 103 | """ 104 | Given RGBD image, return rgb and depth image 105 | :return: 106 | """ 107 | return rgbd_img[:, :, :3], rgbd_img[:, :, -1] 108 | -------------------------------------------------------------------------------- /olaux/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from olaux.ddpg import DDPG 4 | from olaux.her import make_sample_her_transitions 5 | 6 | from envs.goal_env_ext.fetch.reach import FetchReachEnv 7 | from envs.goal_env_ext.hand.reach import HandReachEnv 8 | from envs.goal_env_ext.dm_control.finger import Finger 9 | 10 | CUSTOM_ENVS = { 11 | 'VisualFetchReach': FetchReachEnv, 12 | 'VisualHandReach': HandReachEnv, 13 | 'VisualFinger': Finger 14 | } 15 | 16 | env_arg_dict = { 17 | 'VisualFetchReach': {'n_cycles': 5, 'n_epochs': 40, 'horizon': 50, 'image_size': 100, 'distance_threshold': 5e-2, 18 | 'distance_threshold_obs': 0., 19 | 'use_image_goal': True, 'use_visual_observation': True, 'with_goal': False}, 20 | 'VisualHandReach': {'n_cycles': 10, 'n_epochs': 150, 'horizon': 50, 'image_size': 100, 'distance_threshold': 3e-2, 21 | 'distance_threshold_obs': 0., 'use_image_goal': True, 'use_visual_observation': True, 22 | 'with_goal': False}, 23 | 'VisualFinger': {'n_cycles': 10, 'n_epochs': 100, 'horizon': 50, 'distance_threshold': 0.07, 'distance_threshold_obs': 0., 'n_substeps': 10, 24 | 'image_size': 100, 'use_auxiliary_info': True, 'use_image_goal': True, 'use_visual_observation': True, 25 | 'stack_obs': False} 26 | } 27 | 28 | CACHED_ENVS = {} 29 | 30 | 31 | def cached_make_env(make_env): 32 | """ 33 | Only creates a new environment from the provided function if one has not yet already been 34 | created. This is useful here because we need to infer certain properties of the env, e.g. 35 | its observation and action spaces, without any intend of actually using it. 
36 | """ 37 | if make_env not in CACHED_ENVS: 38 | env = make_env() 39 | CACHED_ENVS[make_env] = env 40 | return CACHED_ENVS[make_env] 41 | 42 | 43 | def configure_her(vv): 44 | env = cached_make_env(vv['make_env']) 45 | env.reset() 46 | 47 | def reward_fun(ag_2, g, info): # vectorized 48 | return env.compute_reward(achieved_goal=ag_2, desired_goal=g, info=info) 49 | 50 | her_vv = {'reward_fun': reward_fun} 51 | for name in vv: 52 | if name.startswith('her_'): 53 | her_vv[name[len('her_'):]] = vv[name] 54 | sample_her_transitions = make_sample_her_transitions(**her_vv) 55 | 56 | return sample_her_transitions 57 | 58 | 59 | def configure_ddpg(vv, dims, shapes, reuse=False, clip_return=True, rank=0): 60 | sample_her_transitions = configure_her(vv) 61 | ddpg_vv = vv.copy() 62 | ddpg_vv.update({'input_dims': dims.copy(), # agent takes an input observations 63 | 'image_input_shapes': shapes.copy(), 64 | 'clip_pos_returns': False, # clip positive returns 65 | 'clip_return': (1. / (1. - vv['gamma'])) if clip_return else np.inf, # max abs of return 66 | 'sample_transitions': sample_her_transitions, 67 | 'rank': rank}) 68 | 69 | policy = DDPG(reuse=reuse, **ddpg_vv) 70 | return policy 71 | 72 | 73 | def configure_shapes(make_env): 74 | env = cached_make_env(make_env) 75 | env.reset() 76 | obs, _, _, info = env.step(env.action_space.sample()) 77 | 78 | shapes = { 79 | 'o': obs['observation'].shape, 80 | 'u': env.action_space.shape, 81 | 'g': obs['desired_goal'].shape, 82 | } 83 | for key, value in info.items(): 84 | value = np.array(value) 85 | if value.ndim == 0: 86 | value = value.reshape(1) 87 | shapes['info_{}'.format(key)] = value.shape 88 | return shapes 89 | -------------------------------------------------------------------------------- /olaux/her.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def make_sample_her_transitions(replay_strategy, 5 | replay_k, 6 | reward_fun, 7 | reward_choices=(-1, 1), 8 | use_aux_tasks=False): 9 | """ Creates a sample function that can be used for HER experience replay. 10 | Args: 11 | replay_strategy (in ['future', 'none']): the HER replay strategy; if set to 'none', 12 | regular DDPG experience replay is used 13 | replay_k (int): the ratio between HER replays and regular replays (e.g. k = 4 -> 4 times 14 | as many HER replays as regular replays are used) 15 | reward_fun (function): function to re-compute the reward with substituted goals 16 | reward_choices: Rewards when goal is achieved / not achieved 17 | use_aux_tasks: Whether auxiliary information will also be sampled in the transitions 18 | """ 19 | if replay_strategy == 'future': 20 | future_p = 1 - (1. / (1 + replay_k)) 21 | 22 | def _sample_her_transitions(episode_batch, batch_size_in_transitions): 23 | """episode_batch is {key: array(buffer_size x T x dim_key)} 24 | """ 25 | 26 | episode_batch['info_ag_2_state'] = episode_batch['info_ag_state'][:, 1:, :] 27 | T = episode_batch['u'].shape[1] 28 | if use_aux_tasks: 29 | episode_batch['info_transformation'] = episode_batch['info_transformation'][:, 1:, :] 30 | episode_batch['info_transformed_frame'] = episode_batch['info_transformed_frame'][:, 1:, :] 31 | episode_batch['info_op_flow'] = episode_batch['info_op_flow'][:, 1:, :] 32 | episode_batch['info_bw_frame'] = episode_batch['info_bw_frame'][:, 1:, :] 33 | 34 | rollout_batch_size = episode_batch['u'].shape[0] 35 | batch_size = batch_size_in_transitions 36 | 37 | # Select which episodes and time steps to use. 
38 | episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) 39 | t_samples = np.random.randint(T, size=batch_size) 40 | transitions = {key: episode_batch[key][episode_idxs, t_samples].copy() 41 | for key in episode_batch.keys()} 42 | # Baseline HER with a replay probability 43 | # Select future time indexes with probability future_p. These 44 | # will be used for HER replay by substituting in future goals. 45 | her_indexes = np.where(np.random.uniform(size=batch_size) < future_p) 46 | future_offset = np.random.uniform(size=batch_size) * (T - t_samples) 47 | 48 | future_offset = future_offset.astype(int) 49 | future_t = (t_samples + 1 + future_offset)[her_indexes] 50 | 51 | # Replace goal with achieved goal but only for the previously-selected 52 | # HER transitions (as defined by her_indexes). For the other transitions, 53 | # keep the original goal. 54 | 55 | future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t] 56 | future_ag_val = episode_batch['info_ag_state'][episode_idxs[her_indexes], future_t] 57 | transitions['info_g_state'][her_indexes] = future_ag_val 58 | transitions['g'][her_indexes] = future_ag 59 | 60 | # Reconstruct info dictionary for reward computation. 61 | info = {} 62 | for key, value in transitions.items(): 63 | if key.startswith('info_'): 64 | info[key.replace('info_', '')] = value 65 | 66 | info['ag_state'] = info['ag_2_state'] 67 | 68 | # Re-compute reward since we may have substituted the goal. 69 | reward_params = {k: transitions[k] for k in ['ag_2', 'g']} 70 | reward_params['info'] = info 71 | 72 | rewards = reward_fun(**reward_params) 73 | goal_reached = (np.round(rewards) == 0) # The environment's sparse reward is zero exactly when the goal is reached 74 | rewards = goal_reached * reward_choices[1] + (1. - goal_reached) * reward_choices[0] 75 | transitions['r'] = rewards 76 | 77 | transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:]) 78 | for k in transitions.keys()} 79 | assert (transitions['u'].shape[0] == batch_size_in_transitions) 80 | 81 | return transitions 82 | 83 | return _sample_her_transitions 84 | -------------------------------------------------------------------------------- /olaux/normalizer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import numpy as np 4 | from mpi4py import MPI 5 | import tensorflow as tf 6 | 7 | from olaux.utils import reshape_for_broadcasting 8 | 9 | 10 | class Normalizer: 11 | def __init__(self, size, eps=1e-2, default_clip_range=np.inf, sess=None): 12 | """A normalizer that ensures that observations are approximately distributed according to 13 | a standard Normal distribution (i.e. have mean zero and variance one).
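Statistics are accumulated locally via update() and averaged across MPI workers when recompute_stats() is called.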
14 | 15 | Args: 16 | size (int): the size of the observation to be normalized 17 | eps (float): a small constant that avoids underflows 18 | default_clip_range (float): normalized observations are clipped to be in 19 | [-default_clip_range, default_clip_range] 20 | sess (object): the TensorFlow session to be used 21 | """ 22 | self.size = size 23 | self.eps = eps 24 | self.default_clip_range = default_clip_range 25 | self.sess = sess if sess is not None else tf.get_default_session() 26 | 27 | self.local_sum = np.zeros(self.size, np.float32) 28 | self.local_sumsq = np.zeros(self.size, np.float32) 29 | self.local_count = np.zeros(1, np.float32) 30 | 31 | self.sum_tf = tf.get_variable( 32 | initializer=tf.zeros_initializer(), shape=self.local_sum.shape, name='sum', 33 | trainable=False, dtype=tf.float32) 34 | self.sumsq_tf = tf.get_variable( 35 | initializer=tf.zeros_initializer(), shape=self.local_sumsq.shape, name='sumsq', 36 | trainable=False, dtype=tf.float32) 37 | self.count_tf = tf.get_variable( 38 | initializer=tf.ones_initializer(), shape=self.local_count.shape, name='count', 39 | trainable=False, dtype=tf.float32) 40 | self.mean = tf.get_variable( 41 | initializer=tf.zeros_initializer(), shape=(self.size,), name='mean', 42 | trainable=False, dtype=tf.float32) 43 | self.std = tf.get_variable( 44 | initializer=tf.ones_initializer(), shape=(self.size,), name='std', 45 | trainable=False, dtype=tf.float32) 46 | self.count_pl = tf.placeholder(name='count_pl', shape=(1,), dtype=tf.float32) 47 | self.sum_pl = tf.placeholder(name='sum_pl', shape=(self.size,), dtype=tf.float32) 48 | self.sumsq_pl = tf.placeholder(name='sumsq_pl', shape=(self.size,), dtype=tf.float32) 49 | 50 | self.update_op = tf.group( 51 | self.count_tf.assign_add(self.count_pl), 52 | self.sum_tf.assign_add(self.sum_pl), 53 | self.sumsq_tf.assign_add(self.sumsq_pl) 54 | ) 55 | self.recompute_op = tf.group( 56 | tf.assign(self.mean, self.sum_tf / self.count_tf), 57 | tf.assign(self.std, tf.sqrt(tf.maximum( 58 | tf.square(self.eps), 59 | self.sumsq_tf / self.count_tf - tf.square(self.sum_tf / self.count_tf) 60 | ))), 61 | ) 62 | self.lock = threading.Lock() 63 | 64 | def update(self, v): 65 | v = v.reshape(-1, self.size) 66 | 67 | with self.lock: 68 | self.local_sum += v.sum(axis=0) 69 | self.local_sumsq += (np.square(v)).sum(axis=0) 70 | self.local_count[0] += v.shape[0] 71 | 72 | def normalize(self, v, clip_range=None): 73 | if clip_range is None: 74 | clip_range = self.default_clip_range 75 | mean = reshape_for_broadcasting(self.mean, v) 76 | std = reshape_for_broadcasting(self.std, v) 77 | return tf.clip_by_value((v - mean) / std, -clip_range, clip_range) 78 | 79 | def denormalize(self, v): 80 | mean = reshape_for_broadcasting(self.mean, v) 81 | std = reshape_for_broadcasting(self.std, v) 82 | return mean + v * std 83 | 84 | def _mpi_average(self, x): 85 | buf = np.zeros_like(x) 86 | MPI.COMM_WORLD.Allreduce(x, buf, op=MPI.SUM) 87 | buf /= MPI.COMM_WORLD.Get_size() 88 | return buf 89 | 90 | def synchronize(self, local_sum, local_sumsq, local_count, root=None): 91 | local_sum[...] = self._mpi_average(local_sum) 92 | local_sumsq[...] = self._mpi_average(local_sumsq) 93 | local_count[...] = self._mpi_average(local_count) 94 | return local_sum, local_sumsq, local_count 95 | 96 | def recompute_stats(self): 97 | with self.lock: 98 | # Copy over results. 99 | local_count = self.local_count.copy() 100 | local_sum = self.local_sum.copy() 101 | local_sumsq = self.local_sumsq.copy() 102 | 103 | # Reset. 104 | self.local_count[...] 
= 0 105 | self.local_sum[...] = 0 106 | self.local_sumsq[...] = 0 107 | 108 | # We perform the synchronization outside of the lock to keep the critical section as short 109 | # as possible. 110 | synced_sum, synced_sumsq, synced_count = self.synchronize( 111 | local_sum=local_sum, local_sumsq=local_sumsq, local_count=local_count) 112 | 113 | self.sess.run(self.update_op, feed_dict={ 114 | self.count_pl: synced_count, 115 | self.sum_pl: synced_sum, 116 | self.sumsq_pl: synced_sumsq, 117 | }) 118 | self.sess.run(self.recompute_op) 119 | 120 | 121 | class IdentityNormalizer: 122 | def __init__(self, size, std=1.): 123 | self.size = size 124 | self.mean = tf.zeros(self.size, tf.float32) 125 | self.std = std * tf.ones(self.size, tf.float32) 126 | 127 | def update(self, x): 128 | pass 129 | 130 | def normalize(self, x, clip_range=None): 131 | return x / self.std 132 | 133 | def denormalize(self, x): 134 | return self.std * x 135 | 136 | def synchronize(self): 137 | pass 138 | 139 | def recompute_stats(self): 140 | pass 141 | -------------------------------------------------------------------------------- /olaux/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | import numpy as np 4 | 5 | 6 | class ReplayBuffer: 7 | def __init__(self, buffer_shapes, size_in_transitions, T, sample_transitions): 8 | """Creates a replay buffer. 9 | 10 | Args: 11 | buffer_shapes (dict of ints): the shape for all buffers that are used in the replay 12 | buffer 13 | size_in_transitions (int): the size of the buffer, measured in transitions 14 | T (int): the time horizon for episodes 15 | sample_transitions (function): a function that samples from the replay buffer 16 | """ 17 | self.buffer_shapes = buffer_shapes 18 | self.size = size_in_transitions // T 19 | self.T = T 20 | self.sample_transitions = sample_transitions 21 | 22 | # self.buffers is {key: array(size_in_episodes x T or T+1 x dim_key)} 23 | self.buffers = {key: np.empty([self.size, *shape], dtype=np.float32) 24 | for key, shape in buffer_shapes.items()} 25 | 26 | # memory management 27 | self.current_size = 0 28 | self.n_transitions_stored = 0 29 | 30 | @property 31 | def full(self): 32 | return self.current_size == self.size 33 | 34 | def sample(self, batch_size): 35 | """Returns a dict {key: array(batch_size x shapes[key])} 36 | """ 37 | buffers = {} 38 | 39 | assert self.current_size > 0 40 | for key in self.buffers.keys(): 41 | buffers[key] = self.buffers[key][:self.current_size] 42 | 43 | buffers['o_2'] = buffers['o'][:, 1:, :] 44 | buffers['ag_2'] = buffers['ag'][:, 1:, :] 45 | 46 | transitions = self.sample_transitions(buffers, batch_size) 47 | 48 | for key in (['r', 'o_2', 'ag_2'] + list(self.buffers.keys())): 49 | assert key in transitions, "key %s missing from transitions" % key 50 | 51 | return transitions 52 | 53 | def store_episode(self, episode_batch): 54 | """episode_batch: array(batch_size x (T or T+1) x dim_key) 55 | """ 56 | batch_sizes = [len(episode_batch[key]) for key in episode_batch.keys()] 57 | T = len(episode_batch['o'][0]) # This is actually T+1 58 | assert np.all(np.array(batch_sizes) == batch_sizes[0]) 59 | batch_size = batch_sizes[0] 60 | 61 | idxs = self._get_storage_idx(batch_size) 62 | 63 | # load inputs into buffers 64 | for key in self.buffers.keys(): 65 | self.buffers[key][idxs] = episode_batch[key] 66 | self.n_transitions_stored += batch_size * self.T 67 | 68 | def get_current_episode_size(self): 69 | return self.current_size 70 | 71 | def 
get_current_size(self): 72 | return self.current_size * self.T 73 | 74 | def get_transitions_stored(self): 75 | return self.n_transitions_stored 76 | 77 | def clear_buffer(self): 78 | self.current_size = 0 79 | 80 | def _get_storage_idx(self, inc=None): 81 | inc = inc or 1 # size increment 82 | assert inc <= self.size, "Batch committed to replay is too large!" 83 | # go consecutively until you hit the end, and then go randomly. 84 | if self.current_size + inc <= self.size: 85 | idx = np.arange(self.current_size, self.current_size + inc) 86 | elif self.current_size < self.size: 87 | overflow = inc - (self.size - self.current_size) 88 | idx_a = np.arange(self.current_size, self.size) 89 | idx_b = np.random.randint(0, self.current_size, overflow) 90 | idx = np.concatenate([idx_a, idx_b]) 91 | else: 92 | idx = np.random.randint(0, self.size, inc) 93 | 94 | # update replay size 95 | self.current_size = min(self.size, self.current_size + inc) 96 | 97 | if inc == 1: 98 | idx = idx[0] 99 | return idx 100 | -------------------------------------------------------------------------------- /olaux/rollout.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | import pickle 5 | from mujoco_py import MujocoException 6 | 7 | from olaux.utils import convert_episode_to_batch_major, store_args 8 | import cv2 as cv 9 | 10 | 11 | def shapes_to_dims(input_shapes): 12 | return {key: np.prod(val) for key, val in input_shapes.items()} 13 | 14 | 15 | class RolloutWorker: 16 | @store_args 17 | def __init__(self, make_env, policy, shapes, logger, T, rollout_batch_size=1, 18 | exploit=False, use_target_net=False, compute_Q=False, noise_eps=0, 19 | random_eps=0, history_len=100, render=False, full_info=True, **kwargs): 20 | """Rollout worker generates experience by interacting with one or many environments. 21 | 22 | Args: 23 | make_env (function): a factory function that creates a new instance of the environment 24 | when called 25 | policy (object): the policy that is used to act 26 | dims (dict of ints): the dimensions for observations (o), goals (g), and actions (u) 27 | logger (object): the logger that is used by the rollout worker 28 | rollout_batch_size (int): the number of parallel rollouts that should be used 29 | exploit (boolean): whether or not to exploit, i.e. 
to act optimally according to the 30 | current policy without any exploration 31 | use_target_net (boolean): whether or not to use the target net for rollouts 32 | compute_Q (boolean): whether or not to compute the Q values alongside the actions 33 | noise_eps (float): scale of the additive Gaussian noise 34 | random_eps (float): probability of selecting a completely random action 35 | history_len (int): length of history for statistics smoothing 36 | render (boolean): whether or not to render the rollouts 37 | """ 38 | dims = shapes_to_dims(shapes) 39 | self.dims = dims 40 | self.envs = [make_env() for _ in range(rollout_batch_size)] 41 | assert self.T > 0 42 | 43 | self.info_keys = [key.replace('info_', '') for key in dims.keys() if key.startswith('info_')] 44 | 45 | self.logging_keys = ['success_state', 'success_obs', 'goal_dist_final_state', 'goal_dist_final_obs'] 46 | if self.compute_Q: 47 | self.logging_keys.append('mean_Q') 48 | # Smoothed by maxlen 49 | self.logging_history = {key: deque(maxlen=history_len) for key in self.logging_keys} 50 | # Logging history: [T, batch_id] 51 | self.n_episodes = 0 52 | self.g = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # goals 53 | self.initial_o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32) # observations 54 | self.initial_ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # achieved goals 55 | self.full_info = full_info 56 | if full_info: 57 | self.initial_info = [None] * rollout_batch_size 58 | self.reset_all_rollouts() 59 | self.clear_history() 60 | 61 | def reset_rollout(self, i): 62 | """Resets the `i`-th rollout environment, re-samples a new goal, and updates the `initial_o` 63 | and `g` arrays accordingly. 64 | """ 65 | if self.full_info: 66 | if hasattr(self.envs[i], 'get_initial_info'): 67 | self.initial_info[i] = self.envs[i].get_initial_info() 68 | else: 69 | # Only used for visualization of deprecated policy 70 | self.initial_info[i] = self.envs[i].get_current_info() 71 | 72 | obs = self.envs[i].reset() 73 | self.initial_o[i] = obs['observation'].flatten() 74 | self.initial_ag[i] = obs['achieved_goal'].flatten() 75 | self.g[i] = obs['desired_goal'].flatten() 76 | 77 | def reset_all_rollouts(self): 78 | """Resets all `rollout_batch_size` rollout workers. 79 | """ 80 | for i in range(self.rollout_batch_size): 81 | self.reset_rollout(i) 82 | 83 | def generate_rollouts(self): 84 | """Performs `rollout_batch_size` rollouts in parallel for time horizon `T` with the current 85 | policy acting on it accordingly. 
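Returns a batch-major dict with keys 'o', 'u', 'g' and 'ag' ('o' and 'ag' contain one more step than 'u'), plus one 'info_<key>' entry per key in self.info_keys when full_info is set.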
86 | """ 87 | self.reset_all_rollouts() 88 | 89 | # compute observations 90 | o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32) # observations 91 | ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # achieved goals 92 | o[:] = self.initial_o 93 | ag[:] = self.initial_ag 94 | 95 | # generate episodes 96 | obs, achieved_goals, acts, goals = [], [], [], [] 97 | info_values = [np.empty((self.T, self.rollout_batch_size, self.dims['info_' + key]), np.float32) for key in 98 | self.info_keys] 99 | Qs = [] 100 | for t in range(self.T): 101 | policy_output = self.policy.get_actions( 102 | o, ag, self.g, 103 | compute_Q=self.compute_Q, 104 | noise_eps=self.noise_eps if not self.exploit else 0., 105 | random_eps=self.random_eps if not self.exploit else 0., 106 | use_target_net=self.use_target_net) 107 | if self.compute_Q: 108 | u, Q = policy_output 109 | Qs.append(Q) 110 | else: 111 | u = policy_output 112 | if u.ndim == 1: 113 | # The non-batched case should still have a reasonable shape. 114 | u = u.reshape(1, -1) 115 | o_new = np.empty((self.rollout_batch_size, self.dims['o'])) 116 | ag_new = np.empty((self.rollout_batch_size, self.dims['g'])) 117 | # compute new states and observations 118 | for i in range(self.rollout_batch_size): 119 | try: 120 | # We fully ignore the reward here because it will have to be re-computed 121 | # for HER. 122 | curr_o_new, _, _, info = self.envs[i].step(u[i]) 123 | 124 | o_new[i] = curr_o_new['observation'].flatten() 125 | ag_new[i] = curr_o_new['achieved_goal'].flatten() 126 | for idx, key in enumerate(self.info_keys): 127 | info_values[idx][t, i] = info[key] 128 | if self.render == 1 and i==0: 129 | self.envs[i].render() 130 | elif self.render == 2 and i==0: 131 | display_o = cv.resize(curr_o_new['observation'], (500, 500)) 132 | display_g = cv.resize(curr_o_new['desired_goal'], (500, 500)) 133 | display_img = np.concatenate((display_o, display_g), axis=1) 134 | display_img = display_img[:, :, (2, 1, 0)] / 256. 135 | cv.imshow('display', display_img) 136 | cv.waitKey(10) 137 | except MujocoException as e: 138 | return self.generate_rollouts() 139 | if np.isnan(o_new).any(): 140 | self.logger.warning('NaN caught during rollout generation. Trying again...') 141 | self.reset_all_rollouts() 142 | return self.generate_rollouts() 143 | 144 | obs.append(o.copy()) 145 | achieved_goals.append(ag.copy()) 146 | acts.append(u.copy()) 147 | goals.append(self.g.copy()) 148 | o[...] = o_new 149 | ag[...] 
= ag_new 150 | obs.append(o.copy()) 151 | achieved_goals.append(ag.copy()) 152 | self.initial_o[:] = o 153 | 154 | episode = dict(o=obs, 155 | u=acts, 156 | g=goals, 157 | ag=achieved_goals) 158 | if self.full_info: 159 | for idx, key in enumerate(self.info_keys): 160 | init_info_values = [np.empty((1, self.rollout_batch_size, self.dims['info_' + key]), np.float32) for key 161 | in self.info_keys] 162 | for t_idx, t_key in enumerate(self.info_keys): 163 | for i in range(self.rollout_batch_size): 164 | init_info_values[t_idx][0, i] = self.initial_info[i][t_key] 165 | info_values[idx] = np.concatenate([init_info_values[idx], info_values[idx]], axis=0) 166 | 167 | for key, value in zip(self.info_keys, info_values): 168 | episode['info_{}'.format(key)] = value 169 | 170 | # stats 171 | d, d_bool = self.envs[0]._get_info_state(episode['info_ag_state'][-1][:], episode['info_g_state'][-1][:]) 172 | self.logging_history['goal_dist_final_state'].append(np.mean(d)) 173 | self.logging_history['success_state'].append(np.mean(d_bool)) 174 | 175 | d, d_bool = self.envs[0]._get_info_obs(episode['ag'][-1][:], episode['g'][-1][:]) 176 | self.logging_history['goal_dist_final_obs'].append(np.mean(d)) 177 | self.logging_history['success_obs'].append(np.mean(d_bool)) 178 | 179 | if self.compute_Q: 180 | self.logging_history['mean_Q'].append(np.mean(Qs)) 181 | self.n_episodes += self.rollout_batch_size 182 | 183 | return convert_episode_to_batch_major(episode) 184 | 185 | def clear_history(self): 186 | """Clears all histories that are used for statistics. 187 | """ 188 | for _, log_queue in self.logging_history.items(): 189 | log_queue.clear() 190 | 191 | def current_success_rate(self): 192 | return np.mean(self.logging_history['success_state']) 193 | 194 | def current_mean_Q(self): 195 | return np.mean(self.logging_history['mean_Q']) 196 | 197 | def save_policy(self, path): 198 | """Pickles the current policy for later inspection. 199 | """ 200 | # print(dir(self.policy)) 201 | with open(path, 'wb') as f: 202 | pickle.dump(self.policy, f) 203 | 204 | def logs(self, prefix='worker'): 205 | """Generates a list of (key, value) pairs containing all collected statistics. 206 | """ 207 | logs = [] 208 | 209 | for key, log_queue in sorted(self.logging_history.items()): 210 | logs += [(key, np.mean(self.logging_history[key]))] 211 | logs += [('episode', self.n_episodes)] 212 | 213 | if prefix != '' and not prefix.endswith('/'): 214 | return [(prefix + '/' + key, val) for key, val in logs] 215 | else: 216 | return logs 217 | 218 | def seed(self, seed): 219 | """Seeds each environment with a distinct seed derived from the passed-in global seed.
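For example, seed=100 gives the i-th environment the seed 100 + 1000 * i.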
220 | """ 221 | for idx, env in enumerate(self.envs): 222 | env.seed(seed + 1000 * idx) 223 | -------------------------------------------------------------------------------- /olaux/test_tf.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def conv_layer(o): 5 | layer = tf.keras.layers.Conv2D(name='conv1', 6 | kernel_size=[2, 2], 7 | filters=64, strides=2, 8 | padding='same', 9 | activation=tf.nn.relu) 10 | return layer(o) 11 | 12 | 13 | if __name__ == '__main__': 14 | o = tf.random_normal(shape=(5, 10, 10, 4)) 15 | x = conv_layer(o) 16 | sess = tf.Session() 17 | sess.run(tf.global_variables_initializer()) 18 | # print(sess.run(o)) 19 | print(sess.run([x])) 20 | 21 | 22 | # from __future__ import division, print_function, absolute_import 23 | # 24 | # # Import MNIST data 25 | # from tensorflow.examples.tutorials.mnist import input_data 26 | # mnist = input_data.read_data_sets("/tmp/data/", one_hot=False) 27 | # 28 | # import tensorflow as tf 29 | # import matplotlib.pyplot as plt 30 | # import numpy as np 31 | # 32 | # # Training Parameters 33 | # learning_rate = 0.001 34 | # num_steps = 2000 35 | # batch_size = 128 36 | # 37 | # # Network Parameters 38 | # num_input = 784 # MNIST data input (img shape: 28*28) 39 | # num_classes = 10 # MNIST total classes (0-9 digits) 40 | # dropout = 0.25 # Dropout, probability to drop a unit 41 | # 42 | # 43 | # # Create the neural network 44 | # def conv_net(x_dict, n_classes, dropout, reuse, is_training): 45 | # # Define a scope for reusing the variables 46 | # with tf.variable_scope('ConvNet', reuse=reuse): 47 | # # TF Estimator input is a dict, in case of multiple inputs 48 | # x = x_dict['images'] 49 | # 50 | # # MNIST data input is a 1-D vector of 784 features (28*28 pixels) 51 | # # Reshape to match picture format [Height x Width x Channel] 52 | # # Tensor input become 4-D: [Batch Size, Height, Width, Channel] 53 | # x = tf.reshape(x, shape=[-1, 28, 28, 1]) 54 | # 55 | # # Convolution Layer with 32 filters and a kernel size of 5 56 | # conv1 = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu) 57 | # # Max Pooling (down-sampling) with strides of 2 and kernel size of 2 58 | # conv1 = tf.layers.max_pooling2d(conv1, 2, 2) 59 | # 60 | # # Convolution Layer with 64 filters and a kernel size of 3 61 | # conv2 = tf.layers.conv2d(conv1, 64, 3, activation=tf.nn.relu) 62 | # # Max Pooling (down-sampling) with strides of 2 and kernel size of 2 63 | # conv2 = tf.layers.max_pooling2d(conv2, 2, 2) 64 | # 65 | # # Flatten the data to a 1-D vector for the fully connected layer 66 | # fc1 = tf.contrib.layers.flatten(conv2) 67 | # 68 | # # Fully connected layer (in tf contrib folder for now) 69 | # fc1 = tf.layers.dense(fc1, 1024) 70 | # # Apply Dropout (if is_training is False, dropout is not applied) 71 | # fc1 = tf.layers.dropout(fc1, rate=dropout, training=is_training) 72 | # 73 | # # Output layer, class prediction 74 | # out = tf.layers.dense(fc1, n_classes) 75 | # 76 | # return out 77 | # 78 | # 79 | # def model_fn(features, labels, mode): 80 | # # Build the neural network 81 | # # Because Dropout have different behavior at training and prediction time, we 82 | # # need to create 2 distinct computation graphs that still share the same weights. 
83 | # logits_train = conv_net(features, num_classes, dropout, reuse=False, is_training=True) 84 | # logits_test = conv_net(features, num_classes, dropout, reuse=True, is_training=False) 85 | # 86 | # # Predictions 87 | # pred_classes = tf.argmax(logits_test, axis=1) 88 | # pred_probas = tf.nn.softmax(logits_test) 89 | # 90 | # # If prediction mode, early return 91 | # if mode == tf.estimator.ModeKeys.PREDICT: 92 | # return tf.estimator.EstimatorSpec(mode, predictions=pred_classes) 93 | # 94 | # # Define loss and optimizer 95 | # loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( 96 | # logits=logits_train, labels=tf.cast(labels, dtype=tf.int32))) 97 | # optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 98 | # train_op = optimizer.minimize(loss_op, global_step=tf.train.get_global_step()) 99 | # 100 | # # Evaluate the accuracy of the model 101 | # acc_op = tf.metrics.accuracy(labels=labels, predictions=pred_classes) 102 | # 103 | # # TF Estimators requires to return a EstimatorSpec, that specify 104 | # # the different ops for training, evaluating, ... 105 | # estim_specs = tf.estimator.EstimatorSpec( 106 | # mode=mode, 107 | # predictions=pred_classes, 108 | # loss=loss_op, 109 | # train_op=train_op, 110 | # eval_metric_ops={'accuracy': acc_op}) 111 | # 112 | # return estim_specs 113 | # 114 | # 115 | # # Build the Estimator 116 | # model = tf.estimator.Estimator(model_fn) 117 | # 118 | # 119 | # # Define the input function for training 120 | # input_fn = tf.estimator.inputs.numpy_input_fn( 121 | # x={'images': mnist.train.images}, y=mnist.train.labels, 122 | # batch_size=batch_size, num_epochs=None, shuffle=True) 123 | # # Train the Model 124 | # model.train(input_fn, steps=num_steps) -------------------------------------------------------------------------------- /olaux/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import subprocess 4 | import sys 5 | import importlib 6 | import inspect 7 | import functools 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | from baselines.common import tf_util as U 13 | from baselines.common.mpi_moments import mpi_moments 14 | 15 | 16 | def shapes_to_dims(input_shapes): 17 | return {key: np.prod(val) for key, val in input_shapes.items()} 18 | 19 | 20 | def dims_to_shapes(input_dims): 21 | return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()} 22 | 23 | 24 | def mpi_average(value): 25 | if value == []: 26 | value = [0.] 27 | if not isinstance(value, list): 28 | value = [value] 29 | return mpi_moments(np.array(value))[0] 30 | 31 | 32 | def log_params(params, logger): 33 | for key in sorted(params.keys()): 34 | logger.info('{}: {}'.format(key, params[key])) 35 | 36 | 37 | def set_global_seeds(i): 38 | try: 39 | import tensorflow as tf 40 | except ImportError: 41 | pass 42 | else: 43 | tf.set_random_seed(i) 44 | np.random.seed(i) 45 | random.seed(i) 46 | 47 | 48 | def store_args(method): 49 | """Stores provided method args as instance attributes. 
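For example, decorating __init__(self, lr=1e-3, T=50) with @store_args sets self.lr and self.T automatically, using the declared defaults for any arguments that are not passed.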
50 | """ 51 | argspec = inspect.getfullargspec(method) 52 | defaults = {} 53 | if argspec.defaults is not None: 54 | defaults = dict( 55 | zip(argspec.args[-len(argspec.defaults):], argspec.defaults)) 56 | if argspec.kwonlydefaults is not None: 57 | defaults.update(argspec.kwonlydefaults) 58 | arg_names = argspec.args[1:] 59 | 60 | @functools.wraps(method) 61 | def wrapper(*positional_args, **keyword_args): 62 | self = positional_args[0] 63 | # Get default arg values 64 | args = defaults.copy() 65 | # Add provided arg values 66 | for name, value in zip(arg_names, positional_args[1:]): 67 | args[name] = value 68 | args.update(keyword_args) 69 | self.__dict__.update(args) 70 | return method(*positional_args, **keyword_args) 71 | 72 | return wrapper 73 | 74 | 75 | def import_function(spec): 76 | """Import a function identified by a string like "pkg.module:fn_name". 77 | """ 78 | mod_name, fn_name = spec.split(':') 79 | module = importlib.import_module(mod_name) 80 | fn = getattr(module, fn_name) 81 | return fn 82 | 83 | 84 | def flatten_grads(var_list, grads, clip_grad_range=None): 85 | """Flattens a variables and their gradients. 86 | """ 87 | if clip_grad_range is not None: 88 | return tf.concat([tf.reshape(tf.clip_by_value(grad, *clip_grad_range), [U.numel(v)]) 89 | for (v, grad) in zip(var_list, grads)], 0) 90 | else: 91 | return tf.concat([tf.reshape(grad, [U.numel(v)]) 92 | for (v, grad) in zip(var_list, grads)], 0) 93 | 94 | 95 | def nn(input, layers_sizes, reuse=None, flatten=False, name=""): 96 | """Creates a simple neural network 97 | """ 98 | for i, size in enumerate(layers_sizes): 99 | activation = tf.nn.relu if i < len(layers_sizes) - 1 else None 100 | input = tf.layers.dense(inputs=input, 101 | units=size, 102 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 103 | reuse=reuse, 104 | name=name + '_' + str(i)) 105 | if activation: 106 | input = activation(input) 107 | if flatten: 108 | assert layers_sizes[-1] == 1 109 | input = tf.reshape(input, [-1]) 110 | return input 111 | 112 | 113 | def install_mpi_excepthook(): 114 | import sys 115 | from mpi4py import MPI 116 | old_hook = sys.excepthook 117 | 118 | def new_hook(a, b, c): 119 | old_hook(a, b, c) 120 | sys.stdout.flush() 121 | sys.stderr.flush() 122 | MPI.COMM_WORLD.Abort() 123 | 124 | sys.excepthook = new_hook 125 | 126 | 127 | def mpi_fork(n): 128 | """Re-launches the current script with workers 129 | Returns "parent" for original parent, "child" for MPI children 130 | """ 131 | if n <= 1: 132 | return "child" 133 | if os.getenv("IN_MPI") is None: 134 | env = os.environ.copy() 135 | env.update( 136 | MKL_NUM_THREADS="1", 137 | OMP_NUM_THREADS="1", 138 | IN_MPI="1" 139 | ) 140 | core_binding = env.get('BASELINE_MPI_CORE_BINDING') 141 | cmd_binding = [] if core_binding is None or core_binding == 'False' else ["-bind-to", "core"] 142 | 143 | # "-bind-to core" is crucial for good performance 144 | 145 | cmd_profile = [] 146 | args = [ 147 | "mpirun", 148 | "-np", 149 | str(n), 150 | *cmd_binding, 151 | # "--report-bindings", 152 | *cmd_profile, 153 | # "-oversubscribe", 154 | sys.executable 155 | ] 156 | args += sys.argv 157 | subprocess.check_call(args, env=env) 158 | return "parent" 159 | else: 160 | install_mpi_excepthook() 161 | return "child" 162 | 163 | 164 | def convert_episode_to_batch_major(episode): 165 | """Converts an episode to have the batch dimension in the major (first) 166 | dimension. 
167 | """ 168 | episode_batch = {} 169 | for key in episode.keys(): 170 | val = np.array(episode[key]).copy() 171 | # make inputs batch-major instead of time-major 172 | episode_batch[key] = val.swapaxes(0, 1) 173 | 174 | return episode_batch 175 | 176 | 177 | def transitions_in_episode_batch(episode_batch): 178 | """Number of transitions in a given episode batch. 179 | """ 180 | shape = episode_batch['u'].shape 181 | return shape[0] * shape[1] 182 | 183 | 184 | def reshape_for_broadcasting(source, target): 185 | """Reshapes a tensor (source) to have the correct shape and dtype of the target 186 | before broadcasting it with MPI. 187 | """ 188 | dim = len(target.get_shape()) 189 | shape = ([1] * (dim - 1)) + [-1] 190 | return tf.reshape(tf.cast(source, target.dtype), shape) 191 | -------------------------------------------------------------------------------- /results/VisualFetchReach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/results/VisualFetchReach.png -------------------------------------------------------------------------------- /results/VisualFinger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/results/VisualFinger.png -------------------------------------------------------------------------------- /results/VisualHandReach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/results/VisualHandReach.png -------------------------------------------------------------------------------- /results/mujoco.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/results/mujoco.gif -------------------------------------------------------------------------------- /results/quadratic.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Xingyu-Lin/auxiliary-tasks-rl/0146d2e9f10eb64097ccd3f317797c90ce7f9a21/results/quadratic.gif --------------------------------------------------------------------------------