├── README.md ├── buffer.py ├── lompo.py ├── lompo.yaml ├── models.py ├── tools.py └── wrappers.py /README.md: -------------------------------------------------------------------------------- 1 | # LOMPO 2 | Official Codebase for [Offline Reinforcement Learning from Images with Latent Space Models](https://arxiv.org/abs/2012.11547) 3 | -------------------------------------------------------------------------------- /buffer.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | 4 | 5 | class Buffer(): 6 | pass 7 | 8 | class LatentReplayBuffer(object): 9 | def __init__(self, 10 | real_size: int, 11 | latent_size: int, 12 | obs_dim: int, 13 | action_dim: int, 14 | immutable: bool = False, 15 | load_from: str = None, 16 | silent: bool = False, 17 | seed:int = 0): 18 | 19 | self.immutable = immutable 20 | 21 | self.buffers = dict() 22 | self.sizes = {'real': real_size, 'latent': latent_size} 23 | for key in ['real', 'latent']: 24 | self.buffers[key] = Buffer() 25 | self.buffers[key]._obs = np.full((self.sizes[key], obs_dim), float('nan'), dtype=np.float32) 26 | self.buffers[key]._actions = np.full((self.sizes[key], action_dim), float('nan'), dtype=np.float32) 27 | self.buffers[key]._rewards = np.full((self.sizes[key], 1), float('nan'), dtype=np.float32) 28 | self.buffers[key]._next_obs = np.full((self.sizes[key], obs_dim), float('nan'), dtype=np.float32) 29 | self.buffers[key]._terminals = np.full((self.sizes[key], 1), float('nan'), dtype=np.float32) 30 | 31 | self._real_stored_steps = 0 32 | self._real_write_location = 0 33 | 34 | self._latent_stored_steps = 0 35 | self._latent_write_location = 0 36 | 37 | self._stored_steps = 0 38 | self._random = np.random.RandomState(seed) 39 | 40 | @property 41 | def obs_dim(self): 42 | return self._obs.shape[-1] 43 | 44 | @property 45 | def action_dim(self): 46 | return self._actions.shape[-1] 47 | 48 | def __len__(self): 49 | return self._stored_steps 50 | 51 | def save(self, location: str): 52 | f = h5py.File(location, 'w') 53 | f.create_dataset('obs', data=self.buffers['real']._obs[:self._real_stored_steps], compression='lzf') 54 | f.create_dataset('actions', data=self.buffers['real']._actions[:self._real_stored_steps], compression='lzf') 55 | f.create_dataset('rewards', data=self.buffers['real']._rewards[:self._real_stored_steps], compression='lzf') 56 | f.create_dataset('next_obs', data=self.buffers['real']._next_obs[:self._real_stored_steps], compression='lzf') 57 | f.create_dataset('terminals', data=self.buffers['real']._terminals[:self._real_stored_steps], compression='lzf') 58 | f.close() 59 | 60 | def load(self, location: str): 61 | with h5py.File(location, "r") as f: 62 | obs = np.array(f['obs']) 63 | self._real_stored_steps = obs.shape[0] 64 | self._real_write_location = obs.shape[0] % self.sizes['real'] 65 | 66 | self.buffers['real']._obs[:self._real_stored_steps] = np.array(f['obs']) 67 | self.buffers['real']._actions[:self._real_stored_steps] = np.array(f['actions']) 68 | self.buffers['real']._rewards[:self._real_stored_steps] = np.array(f['rewards']) 69 | self.buffers['real']._next_obs[:self._real_stored_steps] = np.array(f['next_obs']) 70 | self.buffers['real']._terminals[:self._real_stored_steps] = np.array(f['terminals']) 71 | 72 | 73 | def add_samples(self, obs_feats, actions, next_obs_feats, rewards, terminals, sample_type = 'latent'): 74 | if sample_type == 'real': 75 | for obsi, actsi, nobsi, rewi, termi in zip(obs_feats, actions, next_obs_feats, rewards, 
terminals): 76 | self.buffers['real']._obs[self._real_write_location] = obsi 77 | self.buffers['real']._actions[self._real_write_location] = actsi 78 | self.buffers['real']._next_obs[self._real_write_location] = nobsi 79 | self.buffers['real']._rewards[self._real_write_location] = rewi 80 | self.buffers['real']._terminals[self._real_write_location] = termi 81 | 82 | self._real_write_location = (self._real_write_location + 1) % self.sizes['real'] 83 | self._real_stored_steps = min(self._real_stored_steps + 1, self.sizes['real']) 84 | 85 | else: 86 | for obsi, actsi, nobsi, rewi, termi in zip(obs_feats, actions, next_obs_feats, rewards, terminals): 87 | self.buffers['latent']._obs[self._latent_write_location] = obsi 88 | self.buffers['latent']._actions[self._latent_write_location] = actsi 89 | self.buffers['latent']._next_obs[self._latent_write_location] = nobsi 90 | self.buffers['latent']._rewards[self._latent_write_location] = rewi 91 | self.buffers['latent']._terminals[self._latent_write_location] = termi 92 | 93 | self._latent_write_location = (self._latent_write_location + 1) % self.sizes['latent'] 94 | self._latent_stored_steps = min(self._latent_stored_steps + 1, self.sizes['latent']) 95 | 96 | self._stored_steps = self._real_stored_steps + self._latent_stored_steps 97 | 98 | 99 | def sample(self, batch_size, return_dict: bool = False): 100 | real_idxs = self._random.choice(self._real_stored_steps, batch_size) 101 | latent_idxs = self._random.choice(self._latent_stored_steps, batch_size) 102 | 103 | obs = np.concatenate([self.buffers['real']._obs[real_idxs], 104 | self.buffers['latent']._obs[latent_idxs]], axis = 0) 105 | actions = np.concatenate([self.buffers['real']._actions[real_idxs], 106 | self.buffers['latent']._actions[latent_idxs]], axis = 0) 107 | next_obs = np.concatenate([self.buffers['real']._next_obs[real_idxs], 108 | self.buffers['latent']._next_obs[latent_idxs]], axis = 0) 109 | rewards = np.concatenate([self.buffers['real']._rewards[real_idxs], 110 | self.buffers['latent']._rewards[latent_idxs]], axis = 0) 111 | terminals = np.concatenate([self.buffers['real']._terminals[real_idxs], 112 | self.buffers['latent']._terminals[latent_idxs]], axis = 0) 113 | 114 | data = { 115 | 'obs': obs, 116 | 'actions': actions, 117 | 'next_obs': next_obs, 118 | 'rewards': rewards, 119 | 'terminals': terminals 120 | } 121 | 122 | 123 | 124 | return data 125 | -------------------------------------------------------------------------------- /lompo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | from copy import deepcopy 4 | import functools 5 | import json 6 | import os 7 | import pathlib 8 | import sys 9 | import time 10 | 11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 12 | os.environ['MUJOCO_GL'] = 'osmesa' 13 | 14 | import numpy as np 15 | import random 16 | import tensorflow as tf 17 | from tensorflow.keras.mixed_precision import experimental as prec 18 | 19 | tf.get_logger().setLevel('ERROR') 20 | 21 | from tensorflow_probability import distributions as tfd 22 | 23 | 24 | import buffer 25 | import models 26 | import tools 27 | import wrappers 28 | 29 | def define_config(): 30 | config = tools.AttrDict() 31 | # General. 
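    # Note: every config entry defined in this function is also exposed as a
    # command-line flag of the same name via the argparse loop in __main__,
    # e.g. `python lompo.py --task dmc_walker_walk --seed 1 --logdir .logdir`.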
32 | config.logdir = pathlib.Path('.logdir') 33 | #config.datadir = pathlib.Path('./data/') 34 | config.datadir = pathlib.Path('.datadir/walker') 35 | config.seed = 0 36 | config.log_every = 1000 37 | config.save_every = 5000 38 | config.log_scalars = True 39 | config.log_images = True 40 | config.gpu_growth = True 41 | 42 | # Environment. 43 | config.task = 'dmc_walker_walk' 44 | config.envs = 1 45 | config.parallel = 'none' 46 | config.action_repeat = 2 47 | config.time_limit = 1000 48 | config.im_size = 64 49 | config.eval_noise = 0.0 50 | config.clip_rewards = 'none' 51 | config.precision = 32 52 | 53 | # Model. 54 | config.deter_size = 256 55 | config.stoch_size = 64 56 | config.num_models = 7 57 | config.num_units = 256 58 | config.proprio = False 59 | config.penalty_type = 'log_prob' 60 | config.dense_act = 'elu' 61 | config.cnn_act = 'relu' 62 | config.cnn_depth = 32 63 | config.pcont = False 64 | config.kl_scale = 1.0 65 | config.pcont_scale = 10.0 66 | config.weight_decay = 0.0 67 | config.weight_decay_pattern = r'.*' 68 | 69 | # Training. 70 | config.load_model = False 71 | config.load_agent = False 72 | config.load_buffer = False 73 | config.train_steps = 100000 74 | config.model_train_steps = 25000 75 | config.model_batch_size = 64 76 | config.model_batch_length = 50 77 | config.agent_batch_size = 256 78 | config.cql_samples = 16 79 | config.start_training = 50000 80 | config.agent_train_steps = 100000 81 | config.agent_itters_per_step = 200 82 | config.buffer_size = 2e6 83 | config.model_lr = 6e-4 84 | config.q_lr = 3e-4 85 | config.actor_lr = 3e-4 86 | config.grad_clip = 100.0 87 | config.tau=5e-3 88 | config.target_update_interval=1 89 | config.dataset_balance = False 90 | 91 | # Behavior. 92 | config.lmbd = 5.0 93 | config.alpha = 0.0 94 | config.sample = True 95 | config.discount = 0.99 96 | config.disclam = 0.95 97 | config.horizon = 5 98 | config.done_treshold = 0.5 99 | config.action_dist = 'tanh_normal' 100 | config.action_init_std = 5.0 101 | config.expl = 'additive_gaussian' 102 | config.expl_amount = 0.2 103 | config.expl_decay = 0.0 104 | config.expl_min = 0.0 105 | return config 106 | 107 | class Lompo(tools.Module): 108 | 109 | def __init__(self, config, datadir, actspace, writer): 110 | self._c = config 111 | self._actspace = actspace 112 | self._actdim = actspace.n if hasattr(actspace, 'n') else actspace.shape[0] 113 | episodes, steps = tools.count_episodes(config.datadir) 114 | self.latent_buffer = buffer.LatentReplayBuffer(steps, 115 | steps, 116 | self._c.deter_size + self._c.stoch_size, 117 | self._actdim) 118 | self.lmbd = config.lmbd 119 | self.alpha = config.alpha 120 | 121 | self._writer = writer 122 | tf.summary.experimental.set_step(0) 123 | self._metrics = dict() 124 | 125 | self._agent_step = 0 126 | self._model_step = 0 127 | 128 | self._random = np.random.RandomState(config.seed) 129 | self._float = prec.global_policy().compute_dtype 130 | self._dataset = iter(load_dataset(config.datadir, self._c)) 131 | self._episode_itterator = episode_itterator(config.datadir, self._c) 132 | 133 | self._build_model() 134 | for _ in range(10): 135 | self._model_train_step(next(self._dataset), prefix='eval') 136 | 137 | def __call__(self, obs, reset, state=None, training=True): 138 | if state is not None and reset.any(): 139 | mask = tf.cast(1 - reset, self._float)[:, None] 140 | state = tf.nest.map_structure(lambda x: x * mask, state) 141 | action, state = self.policy(obs, state, training) 142 | return action, state 143 | 144 | def load(self, filename): 145 | 
try: 146 | self.load_model(filename) 147 | except: 148 | pass 149 | 150 | try: 151 | self.load_agent(filename) 152 | except: 153 | pass 154 | 155 | def save(self, filename): 156 | self.save_model(filename) 157 | self.save_agentl(filename) 158 | 159 | def load_model(self, filename): 160 | self._encode.load(filename / 'encode.pkl') 161 | self._dynamics.load(filename / 'dynamic.pkl') 162 | self._decode.load(filename / 'decode.pkl') 163 | self._reward.load(filename / 'reward.pkl') 164 | if self._c.pcont: 165 | self._pcont.load(filename / 'pcont.pkl') 166 | if self._c.proprio: 167 | self._proprio.load(filename / 'proprio.pkl') 168 | 169 | def save_model(self, filename): 170 | filename.mkdir(parents=True, exist_ok=True) 171 | self._encode.save(filename / 'encode.pkl') 172 | self._dynamics.save(filename / 'dynamic.pkl') 173 | self._decode.save(filename / 'decode.pkl') 174 | self._reward.save(filename / 'reward.pkl') 175 | if self._c.pcont: 176 | self._pcont.save(filename / 'pcont.pkl') 177 | if self._c.proprio: 178 | self._proprio.save(filename / 'proprio.pkl') 179 | 180 | def load_agent(self, filename): 181 | self._qf1.load(filename / 'qf1.pkl') 182 | self._qf2.load(filename / 'qf2.pkl') 183 | self._target_qf1.load(filename / 'target_qf1.pkl') 184 | self._target_qf2.load(filename / 'target_qf2.pkl') 185 | self._actor.load(filename / 'actor.pkl') 186 | 187 | def save_agent(self, filename): 188 | filename.mkdir(parents=True, exist_ok=True) 189 | self._qf1.save(filename / 'qf1.pkl') 190 | self._qf2.save(filename / 'qf2.pkl') 191 | self._target_qf1.save(filename / 'target_qf1.pkl') 192 | self._target_qf2.save(filename / 'target_qf2.pkl') 193 | self._actor.save(filename / 'actor.pkl') 194 | 195 | def policy(self, obs, state, training): 196 | if state is None: 197 | latent = self._dynamics.initial(len(obs['image'])) 198 | action = tf.zeros((len(obs['image']), self._actdim), self._float) 199 | else: 200 | latent, action = state 201 | 202 | embed = self._encode(preprocess_raw(obs, self._c)) 203 | latent, _ = self._dynamics.obs_step(latent, action, embed) 204 | feat = self._dynamics.get_feat(latent) 205 | action = self._exploration(self._actor(feat), training) 206 | state = (latent, action) 207 | return action, state 208 | 209 | def _build_model(self): 210 | acts = dict(elu=tf.nn.elu, relu=tf.nn.relu, 211 | swish=tf.nn.swish, leaky_relu=tf.nn.leaky_relu) 212 | cnn_act = acts[self._c.cnn_act] 213 | act = acts[self._c.dense_act] 214 | 215 | #Create encoder based on environment observations 216 | if self._c.proprio: 217 | if self._c.im_size==64: 218 | self._encode = models.ConvEncoderProprio(self._c.cnn_depth, cnn_act) 219 | else: 220 | self._encode = models.ConvEncoderProprioLarge(self._c.cnn_depth, cnn_act) 221 | else: 222 | if self._c.im_size==64: 223 | self._encode = models.ConvEncoder(self._c.cnn_depth, cnn_act) 224 | else: 225 | self._encode = models.ConvEncoderLarge(self._c.cnn_depth, cnn_act) 226 | #RSSM model with ensables 227 | self._dynamics = models.RSSME(self._c.stoch_size, self._c.deter_size, 228 | self._c.deter_size, num_models = self._c.num_models) 229 | #Create decoder based on image size 230 | if self._c.im_size==64: 231 | self._decode = models.ConvDecoder(self._c.cnn_depth, cnn_act, 232 | shape = (self._c.im_size, self._c.im_size, 3)) 233 | else: 234 | self._decode = models.ConvDecoderLarge(self._c.cnn_depth, cnn_act, 235 | shape = (self._c.im_size, self._c.im_size, 3)) 236 | if self._c.proprio: 237 | self._proprio = models.DenseDecoder((self._propriodim,), 3, self._c.num_units, act=act) 
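        # NOTE: self._propriodim appears not to be set anywhere in this file, so the
        # proprioceptive decoder assumes it is provided externally when config.proprio is True.
        # The heads below decode episode continuation (pcont, binary) and reward from the
        # concatenated latent features.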
238 | if self._c.pcont: 239 | self._pcont = models.DenseDecoder((), 3, self._c.num_units, 'binary', act=act) 240 | self._reward = models.DenseDecoder((), 2, self._c.num_units, act=act) 241 | 242 | model_modules = [self._encode, self._dynamics, self._decode, self._reward] 243 | if self._c.proprio: 244 | model_modules.append(self._proprio) 245 | if self._c.pcont: 246 | model_modules.append(self._pcont) 247 | 248 | #Build actor-critic networks 249 | self._qf1 = models.DenseNetwork(1, 3, self._c.num_units, act=act) 250 | self._qf2 = models.DenseNetwork(1, 3, self._c.num_units, act=act) 251 | self._target_qf1 = deepcopy(self._qf2) 252 | self._target_qf2 = deepcopy(self._qf1) 253 | self._qf_criterion = tf.keras.losses.Huber() 254 | self._actor = models.ActorNetwork(self._actdim, 4, self._c.num_units, act=act) 255 | 256 | #Initialize optimizers 257 | Optimizer = functools.partial(tools.Adam, 258 | wd=self._c.weight_decay, 259 | clip=self._c.grad_clip, 260 | wdpattern=self._c.weight_decay_pattern) 261 | 262 | self._model_opt = Optimizer('model', model_modules, self._c.model_lr) 263 | self._qf_opt = Optimizer('qf', [self._qf1, self._qf2], self._c.q_lr) 264 | self._actor_opt = Optimizer('actor', [self._actor], self._c.actor_lr) 265 | 266 | def _exploration(self, action, training): 267 | if training: 268 | amount = self._c.expl_amount 269 | if self._c.expl_decay: 270 | amount *= 0.5 ** (tf.cast(self._agent_step, tf.float32) / self._c.expl_decay) 271 | if self._c.expl_min: 272 | amount = tf.maximum(self._c.expl_min, amount) 273 | self._metrics['expl_amount'] = amount 274 | elif self._c.eval_noise: 275 | amount = self._c.eval_noise 276 | else: 277 | return action 278 | if self._c.expl == 'additive_gaussian': 279 | return tf.clip_by_value(tfd.Normal(action, amount).sample(), -1, 1) 280 | if self._c.expl == 'completely_random': 281 | return tf.random.uniform(action.shape, -1, 1) 282 | raise NotImplementedError(self._c.expl) 283 | 284 | def fit_model(self, itters): 285 | for itter in range(itters): 286 | data = next(self._dataset) 287 | self._model_train_step(data) 288 | if itter % self._c.save_every == 0: 289 | self.save_model(self._c.logdir / 'model_step_{}'.format(itter)) 290 | self.save_model(self._c.logdir / 'final_model') 291 | 292 | def _model_train_step(self, data, prefix='train'): 293 | with tf.GradientTape() as model_tape: 294 | embed = self._encode(data) 295 | post, prior = self._dynamics.observe(embed, data['action']) 296 | feat = self._dynamics.get_feat(post) 297 | image_pred = self._decode(feat) 298 | reward_pred = self._reward(feat) 299 | likes = tools.AttrDict() 300 | likes.image = tf.reduce_mean(tf.boolean_mask(image_pred.log_prob(data['image']), 301 | data['mask'])) 302 | likes.reward = tf.reduce_mean(tf.boolean_mask(reward_pred.log_prob(data['reward']), 303 | data['mask'])) 304 | if self._c.pcont: 305 | pcont_pred = self._pcont(feat) 306 | pcont_target = data['terminal'] 307 | likes.pcont = tf.reduce_mean(tf.boolean_mask(pcont_pred.log_prob(pcont_target), 308 | data['mask'])) 309 | likes.pcont *= self._c.pcont_scale 310 | 311 | for key in prior.keys(): 312 | prior[key] = tf.boolean_mask(prior[key], data['mask']) 313 | post[key] = tf.boolean_mask(post[key], data['mask']) 314 | 315 | prior_dist = self._dynamics.get_dist(prior) 316 | post_dist = self._dynamics.get_dist(post) 317 | div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist)) 318 | model_loss = self._c.kl_scale * div - sum(likes.values()) 319 | 320 | if prefix == 'train': 321 | model_norm = self._model_opt(model_tape, 
model_loss) 322 | self._model_step += 1 323 | 324 | if self._model_step % self._c.log_every == 0: 325 | self._image_summaries(data, embed, image_pred, self._model_step, prefix) 326 | model_summaries = dict() 327 | model_summaries['model_train/KL Divergence'] = tf.reduce_mean(div) 328 | model_summaries['model_train/image_recon'] = tf.reduce_mean(likes.image) 329 | model_summaries['model_train/reward_recon'] = tf.reduce_mean(likes.reward) 330 | model_summaries['model_train/model_loss'] = tf.reduce_mean(model_loss) 331 | if prefix == 'train': 332 | model_summaries['model_train/model_norm'] = tf.reduce_mean(model_norm) 333 | if self._c.pcont: 334 | model_summaries['model_train/terminal_recon'] = tf.reduce_mean(likes.pcont) 335 | self._write_summaries(model_summaries, self._model_step) 336 | 337 | def train_agent(self, itters): 338 | for itter in range(itters): 339 | data = preprocess_latent(self.latent_buffer.sample(self._c.agent_batch_size)) 340 | self._agent_train_step(data) 341 | if self._agent_step % self._c.target_update_interval == 0: 342 | self._update_target_critics() 343 | if itter % self._c.save_every == 0: 344 | self.save_agent(self._c.logdir) 345 | self.save_agent(self._c.logdir / 'final_agent') 346 | 347 | def _agent_train_step(self, data): 348 | obs = data['obs'] 349 | actions = data['actions'] 350 | next_obs = data['next_obs'] 351 | rewards = data['rewards'] 352 | terminals = data['terminals'] 353 | 354 | with tf.GradientTape() as q_tape: 355 | q1_pred = self._qf1(tf.concat([obs, actions], axis = -1)) 356 | q2_pred = self._qf2(tf.concat([obs, actions], axis = -1)) 357 | #new_next_actions = self._exploration(self._actor(next_obs), True) 358 | new_actions = self._actor(obs) 359 | new_next_actions = self._actor(next_obs) 360 | 361 | target_q_values = tf.reduce_min([self._target_qf1(tf.concat([next_obs, new_next_actions], axis = -1)), 362 | self._target_qf2(tf.concat([next_obs, new_next_actions], axis = -1))], axis = 0) 363 | q_target = rewards + self._c.discount * (1.0 - terminals) * target_q_values 364 | 365 | expanded_actions = tf.expand_dims(actions, 0) 366 | tilled_actions = tf.tile(expanded_actions, [self._c.cql_samples, 1, 1]) 367 | tilled_actions = tf.random.uniform(tilled_actions.shape, minval = -1, maxval = 1) 368 | tilled_actions = tf.concat([tilled_actions, tf.expand_dims(new_actions, 0)], axis = 0) 369 | 370 | expanded_obs = tf.expand_dims(obs, 0) 371 | tilled_obs = tf.tile(expanded_obs, [self._c.cql_samples + 1, 1, 1]) 372 | 373 | q1_values = self._qf1(tf.concat([tilled_obs, tilled_actions], axis = -1)) 374 | q2_values = self._qf2(tf.concat([tilled_obs, tilled_actions], axis = -1)) 375 | q1_penalty = tf.math.reduce_logsumexp(q1_values, axis = 0) 376 | q2_penalty = tf.math.reduce_logsumexp(q2_values, axis = 0) 377 | 378 | qf1_loss = self.alpha * (tf.reduce_mean(q1_penalty) - tf.reduce_mean(q1_pred[:self._c.agent_batch_size])) + \ 379 | tf.reduce_mean((q1_pred - tf.stop_gradient(q_target))**2) 380 | qf2_loss = self.alpha * (tf.reduce_mean(q2_penalty) - tf.reduce_mean(q2_pred[:self._c.agent_batch_size])) + \ 381 | tf.reduce_mean((q2_pred - tf.stop_gradient(q_target))**2) 382 | 383 | q_loss = qf1_loss + qf2_loss 384 | 385 | with tf.GradientTape() as actor_tape: 386 | new_obs_actions = self._actor(obs) 387 | q_new_actions = tf.reduce_min([self._qf1(tf.concat([obs, new_obs_actions], axis = -1)), 388 | self._qf2(tf.concat([obs, new_obs_actions], axis = -1))], axis = 0) 389 | actor_loss = -tf.reduce_mean(q_new_actions) 390 | 391 | q_norm = self._qf_opt(q_tape, q_loss) 392 | 
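        # Critic gradients were applied above; the actor gradients recorded in actor_tape
        # are applied next, and both gradient norms are logged every `log_every` agent steps.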
actor_norm = self._actor_opt(actor_tape, actor_loss) 393 | self._agent_step += 1 394 | 395 | if self._agent_step % self._c.log_every == 0: 396 | agent_summaries = dict() 397 | agent_summaries['agent/Q1_value'] = tf.reduce_mean(q1_pred) 398 | agent_summaries['agent/Q2_value'] = tf.reduce_mean(q2_pred) 399 | agent_summaries['agent/Q_target'] = tf.reduce_mean(q_target) 400 | agent_summaries['agent/Q_loss'] = q_loss 401 | agent_summaries['agent/actor_loss'] = actor_loss 402 | agent_summaries['agent/Q_grad_norm'] = q_norm 403 | agent_summaries['agent/actor_grad_norm'] = actor_norm 404 | self._write_summaries(agent_summaries, self._agent_step) 405 | 406 | def _update_target_critics(self): 407 | tau = tf.constant(self._c.tau) 408 | for source_weight, target_weight in zip(self._qf1.trainable_variables, 409 | self._target_qf1.trainable_variables): 410 | target_weight.assign(tau * source_weight + (1.0 - tau) * target_weight) 411 | for source_weight, target_weight in zip(self._qf2.trainable_variables, 412 | self._target_qf2.trainable_variables): 413 | target_weight.assign(tau * source_weight + (1.0 - tau) * target_weight) 414 | 415 | def _generate_real_data(self, batches = 1): 416 | for _ in range(batches): 417 | data = next(self._dataset) 418 | if not self._c.pcont: 419 | data['terminal'] = tf.zeros_like(data['reward']) 420 | embed = self._encode(data) 421 | post, prior = self._dynamics.observe(embed, data['action']) 422 | feat = self._dynamics.get_feat(post) 423 | 424 | for i in range(len(feat)): 425 | obs = tf.boolean_mask(feat[i], data['mask'][i]).numpy()[:-1] 426 | action = tf.boolean_mask(data['action'][i], data['mask'][i]).numpy()[1:] 427 | next_obs = tf.boolean_mask(feat[i], data['mask'][i]).numpy()[1:] 428 | reward = tf.boolean_mask(data['reward'][i], data['mask'][i]).numpy()[1:] 429 | terminal = tf.boolean_mask(data['terminal'][i], data['mask'][i]).numpy()[1:] 430 | 431 | self.latent_buffer.add_samples(obs, 432 | action, 433 | next_obs, 434 | reward, 435 | terminal, 436 | sample_type = 'real') 437 | 438 | def _generate_latent_data(self, data): 439 | embed = self._encode(data) 440 | post, prior = self._dynamics.observe(embed, data['action']) 441 | if self._c.pcont: # Last step could be terminal. 
442 | post = {k: v[:, :-1] for k, v in post.items()} 443 | for key in post.keys(): 444 | post[key] = tf.boolean_mask(post[key], data['mask']) 445 | start = post 446 | 447 | policy = lambda state: tf.stop_gradient( 448 | self._exploration(self._actor(self._dynamics.get_feat(state)), True)) 449 | 450 | obs = [[] for _ in tf.nest.flatten(start)] 451 | next_obs = [[] for _ in tf.nest.flatten(start)] 452 | actions = [] 453 | full_posts = [[[] for _ in tf.nest.flatten(start)] for _ in range(self._c.num_models)] 454 | prev = start 455 | 456 | for index in range(self._c.horizon): 457 | [o.append(l) for o, l in zip(obs, tf.nest.flatten(prev))] 458 | a = policy(prev) 459 | actions.append(a) 460 | for i in range(self._c.num_models): 461 | p = self._dynamics.img_step(prev, a, k=i) 462 | [o.append(l) for o, l in zip(full_posts[i], tf.nest.flatten(p))] 463 | prev = self._dynamics.img_step(prev, a, k=np.random.choice(self._c.num_models, 1)[0]) 464 | [o.append(l) for o, l in zip(next_obs, tf.nest.flatten(prev))] 465 | 466 | obs = self._dynamics.get_feat(tf.nest.pack_sequence_as(start, [tf.stack(x, 0) for x in obs])) 467 | stoch = tf.nest.pack_sequence_as(start, [tf.stack(x, 0) for x in next_obs])['stoch'] 468 | next_obs = self._dynamics.get_feat(tf.nest.pack_sequence_as(start, [tf.stack(x, 0) for x in next_obs])) 469 | actions = tf.stack(actions, 0) 470 | rewards = self._reward(next_obs).mode() 471 | if self._c.pcont: 472 | dones = 1.0 * (self._pcont(next_obs).mean().numpy() > self._c.done_treshold) 473 | else: 474 | dones = tf.zeros_like(rewards) 475 | 476 | dists = [self._dynamics.get_dist( 477 | tf.nest.pack_sequence_as(start, [tf.stack(x, 0) for x in full_posts[i]])) 478 | for i in range(self._c.num_models)] 479 | 480 | #Compute penalty based on specification 481 | if self._c.penalty_type == 'log_prob': 482 | log_prob_vars = tf.math.reduce_std( 483 | tf.stack([d.log_prob(stoch) for d in dists], 0), 484 | axis = 0) 485 | modified_rewards = rewards - self.lmbd * log_prob_vars 486 | elif self._c.penalty_type == 'max_var': 487 | max_std = tf.reduce_max( 488 | tf.stack([tf.norm(d.stddev(), 2, -1) for d in dists], 0), 489 | axis = 0) 490 | modified_rewards = rewards - self.lmbd * max_std 491 | elif self._c.penalty_type == 'mean': 492 | mean_prediction = tf.reduce_mean(tf.stack([d.mean() for d in dists], 0), axis = 0) 493 | mean_disagreement = tf.reduce_mean( 494 | tf.stack([tf.norm(d.mean() - mean_prediction, 2, -1) for d in dists], 0), 495 | axis = 0) 496 | modified_rewards = rewards - self.lmbd * mean_disagreement 497 | elif self._c.penalty_type == None: 498 | modified_rewards = rewards 499 | 500 | self.latent_buffer.add_samples(flatten(obs).numpy(), 501 | flatten(actions).numpy(), 502 | flatten(next_obs).numpy(), 503 | flatten(modified_rewards).numpy(), 504 | flatten(dones), 505 | sample_type = 'latent') 506 | 507 | obs = [[] for _ in tf.nest.flatten(start)] 508 | next_obs = [[] for _ in tf.nest.flatten(start)] 509 | actions = [] 510 | full_posts = [[[] for _ in tf.nest.flatten(start)] for _ in range(self._c.num_models)] 511 | 512 | for key in prev.keys(): 513 | prev[key] = tf.boolean_mask(prev[key], flatten(1.0 - dones)) 514 | 515 | def _add_data(self, num_episodes = 1): 516 | self._process_data_to_latent(num_episodes = num_episodes) 517 | self._generate_latent_data(next(self._dataset)) 518 | 519 | def _process_data_to_latent(self, num_episodes=None): 520 | if num_episodes is None: 521 | num_episodes, _ = tools.count_episodes(self._c.datadir) 522 | 523 | for _ in range(num_episodes): 524 | filename = 
next(self._episode_itterator) 525 | try: 526 | with filename.open('rb') as f: 527 | episode = np.load(f) 528 | episode = {k: episode[k] for k in episode.keys()} 529 | except Exception as e: 530 | print(f'Could not load episode: {e}') 531 | continue 532 | 533 | obs = preprocess_raw(episode, self._c) 534 | if not self._c.pcont: 535 | obs['terminal'] = tf.zeros_like(obs['reward']) 536 | with tf.GradientTape(watch_accessed_variables=False) as _: 537 | embed = self._encode(obs) 538 | post, prior = self._dynamics.observe(tf.expand_dims(embed, 0), 539 | tf.expand_dims(obs['action'], 0)) 540 | feat = flatten(self._dynamics.get_feat(post)) 541 | self.latent_buffer.add_samples(feat.numpy()[:-1], 542 | obs['action'].numpy()[1:], 543 | feat.numpy()[1:], 544 | obs['reward'].numpy()[1:], 545 | obs['terminal'].numpy()[1:], 546 | sample_type = 'real') 547 | 548 | def _image_summaries(self, data, embed, image_pred, step=None, prefix = 'train'): 549 | truth = data['image'][:6] + 0.5 550 | recon = image_pred.mode()[:6] 551 | init, _ = self._dynamics.observe(embed[:6, :5], data['action'][:6, :5]) 552 | init = {k: v[:, -1] for k, v in init.items()} 553 | prior = self._dynamics.imagine(data['action'][:6, 5:], init) 554 | openl = self._decode(self._dynamics.get_feat(prior)).mode() 555 | model = tf.concat([recon[:, :5] + 0.5, openl + 0.5], 1) 556 | error_prior = (model - truth + 1) / 2 557 | error_posterior = (recon + 0.5 - truth + 1) / 2 558 | openl = tf.concat([truth, recon + 0.5, model, error_prior, error_posterior], 2) 559 | with self._writer.as_default(): 560 | tools.video_summary('agent/' + prefix, openl.numpy(), step=step) 561 | 562 | def _write_summaries(self, metrics, step = None): 563 | step = int(step) 564 | metrics = [(k, float(v)) for k, v in metrics.items()] 565 | with self._writer.as_default(): 566 | tf.summary.experimental.set_step(step) 567 | [tf.summary.scalar(k, m, step = step) for k, m in metrics] 568 | print(f'[{step}]', ' / '.join(f'{k} {v:.1f}' for k, v in metrics)) 569 | self._writer.flush() 570 | 571 | def preprocess_raw(obs, config): 572 | dtype = prec.global_policy().compute_dtype 573 | obs = obs.copy() 574 | 575 | with tf.device('cpu:0'): 576 | obs['image'] = tf.cast(obs['image'], dtype) / 255.0 - 0.5 577 | if 'image_128' in obs.keys(): 578 | obs['image_128'] = tf.cast(obs['image_128'], dtype) / 255.0 - 0.5 579 | clip_rewards = dict(none=lambda x: x, tanh=tf.tanh)[config.clip_rewards] 580 | obs['reward'] = clip_rewards(obs['reward']) 581 | for k in obs.keys(): 582 | obs[k] = tf.cast(obs[k], dtype) 583 | return obs 584 | 585 | def flatten(x): 586 | return tf.reshape(x, [-1] + list(x.shape[2:])) 587 | 588 | def preprocess_latent(batch): 589 | dtype = prec.global_policy().compute_dtype 590 | batch = batch.copy() 591 | with tf.device('cpu:0'): 592 | for key in batch.keys(): 593 | batch[key] = tf.cast(batch[key], dtype) 594 | return batch 595 | 596 | def count_steps(datadir, config): 597 | return tools.count_episodes(datadir)[1] * config.action_repeat 598 | 599 | 600 | def load_dataset(directory, config): 601 | episode = next(tools.load_episodes(directory, 1000, load_episodes = 1)) 602 | types = {k: v.dtype for k, v in episode.items()} 603 | shapes = {k: (None,) + v.shape[1:] for k, v in episode.items()} 604 | generator = lambda: tools.load_episodes(directory, config.train_steps, 605 | config.model_batch_length, config.dataset_balance) 606 | dataset = tf.data.Dataset.from_generator(generator, types, shapes) 607 | dataset = dataset.batch(config.model_batch_size, drop_remainder=True) 608 | 
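    # preprocess_raw normalizes images from [0, 255] to [-0.5, 0.5], optionally clips
    # rewards, and casts every field to the active compute dtype.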
dataset = dataset.map(functools.partial(preprocess_raw, config=config)) 609 | dataset = dataset.prefetch(10) 610 | return dataset 611 | 612 | def episode_itterator(datadir, config): 613 | while True: 614 | filenames = list(datadir.glob('*.npz')) 615 | for filename in list(filenames): 616 | yield filename 617 | 618 | def summarize_episode(episode, config, datadir, writer, prefix): 619 | length = (len(episode['reward']) - 1) * config.action_repeat 620 | ret = episode['reward'].sum() 621 | print(f'{prefix.title()} episode of length {length} with return {ret:.1f}.') 622 | metrics = [ 623 | (f'{prefix}/return', float(episode['reward'].sum())), 624 | (f'{prefix}/length', len(episode['reward']) - 1)] 625 | with writer.as_default(): # Env might run in a different thread. 626 | [tf.summary.scalar('sim/' + k, v) for k, v in metrics] 627 | if prefix == 'test': 628 | tools.video_summary(f'sim/{prefix}/video', episode['image'][None]) 629 | 630 | def make_env(config, writer, prefix, datadir, store): 631 | suite, task = config.task.split('_', 1) 632 | if suite == 'dmc': 633 | env = wrappers.DeepMindControl(task) 634 | env = wrappers.ActionRepeat(env, config.action_repeat) 635 | env = wrappers.NormalizeActions(env) 636 | elif suite == 'gym': 637 | env = wrappers.Gym(task, config, size=(128, 128)) 638 | env = wrappers.ActionRepeat(env, config.action_repeat) 639 | env = wrappers.NormalizeActions(env) 640 | elif task == 'door': 641 | env = wrappers.DoorOpen(config, size=(128, 128)) 642 | env = wrappers.ActionRepeat(env, config.action_repeat) 643 | env = wrappers.NormalizeActions(env) 644 | elif task == 'drawer': 645 | env = wrappers.DrawerOpen(config, size=(128, 128)) 646 | env = wrappers.ActionRepeat(env, config.action_repeat) 647 | env = wrappers.NormalizeActions(env) 648 | else: 649 | raise NotImplementedError(suite) 650 | env = wrappers.TimeLimit(env, config.time_limit / config.action_repeat) 651 | callbacks = [] 652 | if store: 653 | callbacks.append(lambda ep: tools.save_episodes(datadir, [ep])) 654 | if prefix == 'test': 655 | callbacks.append( 656 | lambda ep: summarize_episode(ep, config, datadir, writer, prefix)) 657 | env = wrappers.Collect(env, callbacks, config.precision) 658 | env = wrappers.RewardObs(env) 659 | return env 660 | 661 | 662 | def main(config): 663 | print(config) 664 | 665 | #Set random seeds 666 | os.environ['PYTHONHASHSEED']=str(config.seed) 667 | os.environ['TF_CUDNN_DETERMINISTIC'] = '1' 668 | random.seed(config.seed) 669 | np.random.seed(config.seed) 670 | tf.random.set_seed(config.seed) 671 | 672 | if config.gpu_growth: 673 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 674 | tf.config.experimental.set_memory_growth(gpu, True) 675 | 676 | config.logdir = config.logdir / config.task 677 | config.logdir = config.logdir / 'seed_{}'.format(config.seed) 678 | config.logdir.mkdir(parents=True, exist_ok=True) 679 | datadir = config.datadir 680 | tf_dir = config.logdir / 'tensorboard' 681 | writer = tf.summary.create_file_writer(str(tf_dir), max_queue=1000, flush_millis=20000) 682 | writer.set_as_default() 683 | 684 | 685 | # Create environments. 686 | train_envs = [wrappers.Async(lambda: make_env( 687 | config, writer, 'train', '.', store=False), config.parallel) 688 | for _ in range(config.envs)] 689 | test_envs = [wrappers.Async(lambda: make_env( 690 | config, writer, 'test', '.', store=False), config.parallel) 691 | for _ in range(config.envs)] 692 | actspace = train_envs[0].action_space 693 | 694 | # Train and regularly evaluate the agent. 
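    # Overall pipeline: (1) fit (or load) the latent dynamics model on the offline dataset,
    # (2) encode the dataset into the latent replay buffer, then (3) alternate between
    # evaluation rollouts, actor-critic updates on latent batches, and adding new
    # transitions (model-generated rollouts plus re-encoded episodes when config.sample
    # is True, otherwise re-encoded real episodes only).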
695 | agent = Lompo(config, datadir, actspace, writer) 696 | 697 | if agent._c.load_model: 698 | agent.load_model(config.logdir / 'final_model') 699 | print('Load pretarined model') 700 | if agent._c.load_buffer: 701 | agent.latent_buffer.load(agent._c.logdir / 'buffer.h5py') 702 | else: 703 | agent._process_data_to_latent() 704 | agent.latent_buffer.save(agent._c.logdir / 'buffer.h5py') 705 | else: 706 | agent.fit_model(agent._c.model_train_steps) 707 | #agent.save_model(config.logdir) 708 | #agent._generate_real_data(steps = 5000) 709 | agent._process_data_to_latent() 710 | agent.latent_buffer.save(agent._c.logdir / 'buffer.h5py') 711 | 712 | if agent._c.load_agent: 713 | agent.load_agent(config.logdir) 714 | print('Load pretarined actor') 715 | 716 | while agent.latent_buffer._latent_stored_steps < agent._c.start_training: 717 | agent._generate_latent_data(next(agent._dataset)) 718 | 719 | while agent._agent_step < int(config.agent_train_steps): 720 | print('Start evaluation.') 721 | tools.simulate( 722 | functools.partial(agent, training=False), test_envs, episodes=1) 723 | #agent._latent_evaluate(train_envs[0]) 724 | writer.flush() 725 | print('Start collection.') 726 | agent.train_agent(agent._c.agent_itters_per_step) 727 | #agent._generate_real_data(steps = 5) 728 | 729 | if config.sample: 730 | agent._add_data(num_episodes = 1) 731 | else: 732 | agent._process_data_to_latent(num_episodes = 1) 733 | 734 | 735 | for env in train_envs + test_envs: 736 | env.close() 737 | 738 | if __name__ == '__main__': 739 | try: 740 | import colored_traceback 741 | colored_traceback.add_hook() 742 | except ImportError: 743 | pass 744 | parser = argparse.ArgumentParser() 745 | for key, value in define_config().items(): 746 | parser.add_argument(f'--{key}', type=tools.args_type(value), default=value) 747 | main(parser.parse_args()) 748 | 749 | 750 | 751 | -------------------------------------------------------------------------------- /lompo.yaml: -------------------------------------------------------------------------------- 1 | name: lompo 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - backcall=0.2.0=py_0 7 | - blas=1.0=mkl 8 | - ca-certificates=2021.1.19=h06a4308_0 9 | - certifi=2020.12.5=py37h06a4308_0 10 | - dbus=1.13.18=hb2f20db_0 11 | - decorator=4.4.2=py_0 12 | - entrypoints=0.3=py37_0 13 | - expat=2.2.10=he6710b0_2 14 | - fontconfig=2.13.1=h6c09931_0 15 | - freetype=2.10.4=h5ab3b9f_0 16 | - glib=2.67.4=h36276a3_1 17 | - gst-plugins-base=1.14.0=h8213a91_2 18 | - gstreamer=1.14.0=h28cd5cc_2 19 | - icu=58.2=he6710b0_3 20 | - intel-openmp=2020.2=254 21 | - ipykernel=5.3.0=py37h5ca1d4c_0 22 | - ipython=7.16.1=py37h5ca1d4c_0 23 | - ipython_genutils=0.2.0=py37_0 24 | - jedi=0.17.1=py37_0 25 | - jpeg=9b=h024ee3a_2 26 | - jupyter_client=6.1.3=py_0 27 | - jupyter_core=4.6.3=py37_0 28 | - lcms2=2.11=h396b838_0 29 | - ld_impl_linux-64=2.33.1=h53a641e_7 30 | - libedit=3.1.20191231=h7b6447c_0 31 | - libffi=3.3=he6710b0_1 32 | - libgcc-ng=9.1.0=hdf63c60_0 33 | - libgfortran-ng=7.3.0=hdf63c60_0 34 | - libpng=1.6.37=hbc83047_0 35 | - libsodium=1.0.18=h7b6447c_0 36 | - libstdcxx-ng=9.1.0=hdf63c60_0 37 | - libtiff=4.1.0=h2733197_1 38 | - libuuid=1.0.3=h1bed415_2 39 | - libxcb=1.14=h7b6447c_0 40 | - libxml2=2.9.10=hb55368b_3 41 | - lz4-c=1.9.3=h2531618_0 42 | - matplotlib-base=3.3.4=py37h62a2d02_0 43 | - mkl=2020.2=256 44 | - mkl-service=2.3.0=py37he8ac12f_0 45 | - mkl_fft=1.2.1=py37h54f3939_0 46 | - mkl_random=1.1.1=py37h0573a6f_0 47 | - ncurses=6.2=he6710b0_1 48 | - 
numpy-base=1.19.2=py37hfa32c7d_0 49 | - olefile=0.46=py37_0 50 | - openssl=1.1.1j=h27cfd23_0 51 | - parso=0.7.0=py_0 52 | - pcre=8.44=he6710b0_0 53 | - pexpect=4.8.0=py37_0 54 | - pickleshare=0.7.5=py37_0 55 | - pip=20.1.1=py37_1 56 | - prompt-toolkit=3.0.5=py_0 57 | - ptyprocess=0.6.0=py37_0 58 | - pygments=2.6.1=py_0 59 | - pyparsing=2.4.7=pyhd3eb1b0_0 60 | - pyqt=5.9.2=py37h05f1152_2 61 | - python=3.7.7=hcff3b4d_5 62 | - python-dateutil=2.8.1=py_0 63 | - pyzmq=19.0.1=py37he6710b0_1 64 | - qt=5.9.7=h5867ecd_1 65 | - readline=8.0=h7b6447c_0 66 | - seaborn=0.11.1=pyhd3eb1b0_0 67 | - setuptools=47.3.1=py37_0 68 | - sip=4.19.8=py37hf484d3e_0 69 | - six=1.15.0=py_0 70 | - sqlite=3.32.3=h62c20be_0 71 | - tk=8.6.10=hbc83047_0 72 | - tornado=6.0.4=py37h7b6447c_1 73 | - traitlets=4.3.3=py37_0 74 | - wcwidth=0.2.5=py_0 75 | - wheel=0.34.2=py37_0 76 | - wurlitzer=2.0.0=py37_0 77 | - xz=5.2.5=h7b6447c_0 78 | - zeromq=4.3.2=he6710b0_2 79 | - zlib=1.2.11=h7b6447c_3 80 | - zstd=1.4.5=h9ceee32_0 81 | - pip: 82 | - absl-py==0.9.0 83 | - argon2-cffi==20.1.0 84 | - argparse==1.4.0 85 | - astunparse==1.6.3 86 | - atari-py==0.2.6 87 | - attrs==19.3.0 88 | - bleach==3.1.5 89 | - box2d-py==2.3.8 90 | - cachetools==4.1.1 91 | - cffi==1.14.0 92 | - chardet==3.0.4 93 | - click==7.1.2 94 | - cloudpickle==1.2.2 95 | - cycler==0.10.0 96 | - cython==0.29.20 97 | - defusedxml==0.6.0 98 | - dm-control==0.0.322773188 99 | - dm-tree==0.1.5 100 | - dotmap==1.3.17 101 | - easyprocess==0.3 102 | - entrypoint2==0.2.1 103 | - fasteners==0.15 104 | - future==0.18.2 105 | - gast==0.3.3 106 | - gif2numpy==1.3 107 | - glfw==1.11.2 108 | - google-auth==1.18.0 109 | - google-auth-oauthlib==0.4.1 110 | - google-pasta==0.2.0 111 | - googledrivedownloader==0.4 112 | - grpcio==1.30.0 113 | - gym==0.17.3 114 | - h5py==2.10.0 115 | - idna==2.10 116 | - imageio==2.8.0 117 | - importlib-metadata==1.7.0 118 | - ipywidgets==7.5.1 119 | - jeepney==0.4.3 120 | - jinja2==2.11.2 121 | - joblib==0.16.0 122 | - jsonschema==3.2.0 123 | - jupyter==1.0.0 124 | - jupyter-console==6.1.0 125 | - kaitaistruct==0.9 126 | - keras-preprocessing==1.1.2 127 | - kiwisolver==1.2.0 128 | - labmaze==1.0.2 129 | - lockfile==0.12.2 130 | - lxml==4.5.2 131 | - markdown==3.2.2 132 | - markupsafe==1.1.1 133 | - matplotlib==3.2.2 134 | - mistune==0.8.4 135 | - mjrl==1.0.0 136 | - monotonic==1.5 137 | - mpi4py==3.0.3 138 | - mss==6.0.0 139 | - mujoco-py==2.0.2.11 140 | - nbconvert==5.6.1 141 | - nbformat==5.0.7 142 | - networkx==2.4 143 | - notebook==6.1.1 144 | - numpy==1.19.0 145 | - oauthlib==3.1.0 146 | - opencv-python==4.2.0.34 147 | - opt-einsum==3.2.1 148 | - packaging==20.4 149 | - pandas==1.0.5 150 | - pandocfilters==1.4.2 151 | - pillow==7.2.0 152 | - prometheus-client==0.8.0 153 | - protobuf==3.12.2 154 | - psutil==5.7.2 155 | - pyasn1==0.4.8 156 | - pyasn1-modules==0.2.8 157 | - pybullet==3.0.8 158 | - pycparser==2.20 159 | - pygame==1.9.6 160 | - pyglet==1.5.0 161 | - pympler==0.8 162 | - pyrsistent==0.16.0 163 | - pyscreenshot==2.2 164 | - pytz==2020.1 165 | - pyvirtualdisplay==1.3.2 166 | - qtconsole==4.7.5 167 | - qtpy==1.9.0 168 | - requests==2.24.0 169 | - requests-oauthlib==1.3.0 170 | - roboschool==1.0.49 171 | - rsa==4.6 172 | - scipy==1.4.1 173 | - send2trash==1.5.0 174 | - spyder-kernels==0.5.2 175 | - stable-baselines==2.10.1 176 | - stable-baselines3==0.8.0 177 | - tensorboard==2.2.2 178 | - tensorboard-plugin-wit==1.7.0 179 | - tensorflow-estimator==2.2.0 180 | - tensorflow-gpu==2.2.0 181 | - tensorflow-probability==0.10.0 182 | - 
termcolor==1.1.0 183 | - terminado==0.8.3 184 | - testpath==0.4.4 185 | - torch==1.6.0 186 | - tqdm==4.48.0 187 | - transforms3d==0.3.1 188 | - urllib3==1.25.9 189 | - webencodings==0.5.1 190 | - werkzeug==1.0.1 191 | - widgetsnbextension==3.5.1 192 | - wrapt==1.12.1 193 | - zipp==3.1.0 194 | 195 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.keras import layers as tfkl 4 | from tensorflow_probability import distributions as tfd 5 | from tensorflow.keras.mixed_precision import experimental as prec 6 | 7 | import tools 8 | 9 | 10 | class RSSME(tools.Module): 11 | def __init__(self, stoch=30, deter=200, hidden=200, num_models=7, act=tf.nn.elu): 12 | super().__init__() 13 | self._activation = act 14 | self._stoch_size = stoch 15 | self._deter_size = deter 16 | self._hidden_size = hidden 17 | self._cell = tfkl.GRUCell(self._deter_size) 18 | self._k = num_models 19 | 20 | def initial(self, batch_size): 21 | dtype = prec.global_policy().compute_dtype 22 | return dict( 23 | mean=tf.zeros([batch_size, self._stoch_size], dtype), 24 | std=tf.zeros([batch_size, self._stoch_size], dtype), 25 | stoch=tf.zeros([batch_size, self._stoch_size], dtype), 26 | deter=self._cell.get_initial_state(None, batch_size, dtype)) 27 | 28 | def observe(self, embed, action, state=None): 29 | if state is None: 30 | state = self.initial(tf.shape(action)[0]) 31 | embed = tf.transpose(embed, [1, 0, 2]) 32 | action = tf.transpose(action, [1, 0, 2]) 33 | post, prior = tools.static_scan(lambda prev, inputs: self.obs_step(prev[0], *inputs), 34 | (action, embed), (state, state)) 35 | post = {k: tf.transpose(v, [1, 0, 2]) for k, v in post.items()} 36 | prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} 37 | return post, prior 38 | 39 | def imagine(self, action, state=None): 40 | if state is None: 41 | state = self.initial(tf.shape(action)[0]) 42 | assert isinstance(state, dict), state 43 | action = tf.transpose(action, [1, 0, 2]) 44 | prior = tools.static_scan(self.img_step, action, state) 45 | prior = {k: tf.transpose(v, [1, 0, 2]) for k, v in prior.items()} 46 | return prior 47 | 48 | def get_feat(self, state): 49 | return tf.concat([state['stoch'], state['deter']], -1) 50 | 51 | def get_feat_size(self, state): 52 | return self._stoch_size + self._deter_size 53 | 54 | def get_dist(self, state): 55 | return tfd.MultivariateNormalDiag(state['mean'], state['std']) 56 | 57 | def obs_step(self, prev_state, prev_action, embed): 58 | prior = self.img_step(prev_state, prev_action) 59 | x = tf.concat([prior['deter'], embed], -1) 60 | x = self.get('obs1', tfkl.Dense, self._hidden_size, self._activation)(x) 61 | x = self.get('obs2', tfkl.Dense, 2 * self._stoch_size, None)(x) 62 | mean, std = tf.split(x, 2, -1) 63 | std = tf.nn.softplus(std) + 0.1 64 | stoch = self.get_dist({'mean': mean, 'std': std}).sample() 65 | post = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': prior['deter']} 66 | return post, prior 67 | 68 | def img_step(self, prev_state, prev_action, k=None): 69 | if k is None: 70 | k = np.random.choice(self._k) 71 | x = tf.concat([prev_state['stoch'], prev_action], -1) 72 | x = self.get('img1', tfkl.Dense, self._hidden_size, self._activation)(x) 73 | x, deter = self._cell(x, [prev_state['deter']]) 74 | deter = deter[0] # Keras wraps the state in a list. 
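        # Each of the self._k ensemble members owns its own output head ('img2_{k}',
        # 'img3_{k}'); the GRU cell above is shared, so members differ only in how they
        # map the recurrent state to the stochastic prior parameters.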
75 | x = self.get('img2_{}'.format(k), tfkl.Dense, self._hidden_size, self._activation)(x) 76 | x = self.get('img3_{}'.format(k), tfkl.Dense, 2 * self._stoch_size, None)(x) 77 | mean, std = tf.split(x, 2, -1) 78 | std = tf.nn.softplus(std) + 0.1 79 | stoch = self.get_dist({'mean': mean, 'std': std}).sample() 80 | prior = {'mean': mean, 'std': std, 'stoch': stoch, 'deter': deter} 81 | return prior 82 | 83 | 84 | class MultivariateNormalDiag(tools.Module): 85 | def __init__(self, hidden_size, latent_size, scale=None): 86 | super().__init__() 87 | self.latent_size = latent_size 88 | self.scale = scale 89 | self.dense1 = tf.keras.layers.Dense(hidden_size, activation=tf.nn.leaky_relu) 90 | self.dense2 = tf.keras.layers.Dense(hidden_size, activation=tf.nn.leaky_relu) 91 | self.output_layer = tf.keras.layers.Dense(2 * latent_size if self.scale 92 | is None else latent_size) 93 | 94 | def __call__(self, *inputs): 95 | if len(inputs) > 1: 96 | inputs = tf.concat(inputs, axis=-1) 97 | else: 98 | inputs, = inputs 99 | out = self.dense1(inputs) 100 | out = self.dense2(out) 101 | out = self.output_layer(out) 102 | loc = out[..., :self.latent_size] 103 | if self.scale is None: 104 | assert out.shape[-1] == 2 * self.latent_size 105 | scale_diag = tf.nn.softplus(out[..., self.latent_size:]) + 1e-5 106 | else: 107 | assert out.shape[-1].value == self.latent_size 108 | scale_diag = tf.ones_like(loc) * self.scale 109 | return loc, scale_diag 110 | 111 | 112 | class ConstantMultivariateNormalDiag(tools.Module): 113 | def __init__(self, latent_size, scale=None): 114 | super().__init__() 115 | self.latent_size = latent_size 116 | self.scale = scale 117 | 118 | def __call__(self, *inputs): 119 | # first input should not have any dimensions after the batch_shape, step_type 120 | batch_shape = tf.shape(inputs[0]) # input is only used to infer batch_shape 121 | shape = tf.concat([batch_shape, [self.latent_size]], axis=0) 122 | loc = tf.zeros(shape) 123 | if self.scale is None: 124 | scale_diag = tf.ones(shape) 125 | else: 126 | scale_diag = tf.ones(shape) * self.scale 127 | return loc, scale_diag 128 | 129 | 130 | class ConvEncoderLarge(tools.Module): 131 | def __init__(self, depth=32, act=tf.nn.relu): 132 | self._act = act 133 | self._depth = depth 134 | 135 | def __call__(self, obs): 136 | kwargs = dict(strides=2, activation=self._act) 137 | x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:])) 138 | x = self.get('h1', tfkl.Conv2D, 1 * self._depth, 4, **kwargs)(x) 139 | x = self.get('h2', tfkl.Conv2D, 2 * self._depth, 4, **kwargs)(x) 140 | x = self.get('h3', tfkl.Conv2D, 4 * self._depth, 4, **kwargs)(x) 141 | x = self.get('h4', tfkl.Conv2D, 8 * self._depth, 4, **kwargs)(x) 142 | x = self.get('h5', tfkl.Conv2D, 8 * self._depth, 4, **kwargs)(x) 143 | shape = tf.concat([tf.shape(obs['image'])[:-3], [32 * self._depth]], 0) 144 | return tf.reshape(x, shape) 145 | 146 | 147 | class ConvDecoderLarge(tools.Module): 148 | def __init__(self, depth=32, act=tf.nn.relu, shape=(128, 128, 3)): 149 | self._act = act 150 | self._depth = depth 151 | self._shape = shape 152 | 153 | def __call__(self, features): 154 | kwargs = dict(strides=2, activation=self._act) 155 | x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features) 156 | x = tf.reshape(x, [-1, 1, 1, 32 * self._depth]) 157 | x = self.get('h2', tfkl.Conv2DTranspose, 4 * self._depth, 5, **kwargs)(x) 158 | x = self.get('h3', tfkl.Conv2DTranspose, 2 * self._depth, 5, **kwargs)(x) 159 | x = self.get('h4', tfkl.Conv2DTranspose, 1 * self._depth, 5, 
**kwargs)(x) 160 | x = self.get('h5', tfkl.Conv2DTranspose, 1 * self._depth, 6, **kwargs)(x) 161 | x = self.get('h6', tfkl.Conv2DTranspose, self._shape[-1], 6, strides=2)(x) 162 | mean = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0)) 163 | return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)) 164 | 165 | 166 | class ConvEncoder(tools.Module): 167 | def __init__(self, depth=32, act=tf.nn.relu): 168 | self._act = act 169 | self._depth = depth 170 | 171 | def __call__(self, obs): 172 | kwargs = dict(strides=2, activation=self._act) 173 | x = tf.reshape(obs['image'], (-1,) + tuple(obs['image'].shape[-3:])) 174 | x = self.get('h1', tfkl.Conv2D, 1 * self._depth, 4, **kwargs)(x) 175 | x = self.get('h2', tfkl.Conv2D, 2 * self._depth, 4, **kwargs)(x) 176 | x = self.get('h3', tfkl.Conv2D, 4 * self._depth, 4, **kwargs)(x) 177 | x = self.get('h4', tfkl.Conv2D, 8 * self._depth, 4, **kwargs)(x) 178 | shape = tf.concat([tf.shape(obs['image'])[:-3], [32 * self._depth]], 0) 179 | return tf.reshape(x, shape) 180 | 181 | 182 | class ConvDecoder(tools.Module): 183 | def __init__(self, depth=32, act=tf.nn.relu, shape=(64, 64, 3)): 184 | self._act = act 185 | self._depth = depth 186 | self._shape = shape 187 | 188 | def __call__(self, features): 189 | kwargs = dict(strides=2, activation=self._act) 190 | x = self.get('h1', tfkl.Dense, 32 * self._depth, None)(features) 191 | x = tf.reshape(x, [-1, 1, 1, 32 * self._depth]) 192 | x = self.get('h2', tfkl.Conv2DTranspose, 4 * self._depth, 5, **kwargs)(x) 193 | x = self.get('h3', tfkl.Conv2DTranspose, 2 * self._depth, 5, **kwargs)(x) 194 | x = self.get('h4', tfkl.Conv2DTranspose, 1 * self._depth, 6, **kwargs)(x) 195 | x = self.get('h5', tfkl.Conv2DTranspose, self._shape[-1], 6, strides=2)(x) 196 | mean = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0)) 197 | return tfd.Independent(tfd.Normal(mean, 1), len(self._shape)) 198 | 199 | 200 | class DenseDecoder(tools.Module): 201 | def __init__(self, shape, layers, units, dist='normal', act=tf.nn.elu): 202 | self._shape = shape 203 | self._layers = layers 204 | self._units = units 205 | self._dist = dist 206 | self._act = act 207 | 208 | def __call__(self, features): 209 | x = features 210 | for index in range(self._layers): 211 | x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) 212 | x = self.get(f'hout', tfkl.Dense, np.prod(self._shape))(x) 213 | x = tf.reshape(x, tf.concat([tf.shape(features)[:-1], self._shape], 0)) 214 | if self._dist == 'normal': 215 | return tfd.Independent(tfd.Normal(x, 1), len(self._shape)) 216 | if self._dist == 'binary': 217 | return tfd.Independent(tfd.Bernoulli(x), len(self._shape)) 218 | raise NotImplementedError(self._dist) 219 | 220 | 221 | class DenseNetwork(tools.Module): 222 | def __init__(self, shape, layers, units, act=tf.nn.elu): 223 | self._shape = shape 224 | self._layers = layers 225 | self._units = units 226 | self._act = act 227 | 228 | def __call__(self, features): 229 | x = features 230 | for index in range(self._layers): 231 | x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) 232 | x = self.get(f'hout', tfkl.Dense, self._shape)(x) 233 | return x 234 | 235 | 236 | class ActorNetwork(tools.Module): 237 | def __init__(self, shape, layers, units, act=tf.nn.elu, mean_scale=1.0): 238 | self._shape = shape 239 | self._layers = layers 240 | self._units = units 241 | self._act = act 242 | self._mean_scale = mean_scale 243 | 244 | def __call__(self, features): 245 | x = features 246 | for index in 
range(self._layers): 247 | x = self.get(f'h{index}', tfkl.Dense, self._units, self._act)(x) 248 | x = self.get(f'hout', tfkl.Dense, self._shape)(x) 249 | x = self._mean_scale * tf.tanh(x) 250 | return x 251 | 252 | 253 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import datetime 3 | import io 4 | import pathlib 5 | import pickle 6 | import uuid 7 | 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | import tensorflow.compat.v1 as tf1 12 | from tensorflow_probability import distributions as tfd 13 | 14 | 15 | class AttrDict(dict): 16 | __setattr__ = dict.__setitem__ 17 | __getattr__ = dict.__getitem__ 18 | 19 | 20 | class Module(tf.Module): 21 | def save(self, filename): 22 | values = tf.nest.map_structure(lambda x: x.numpy(), self.variables) 23 | with pathlib.Path(filename).open('wb') as f: 24 | pickle.dump(values, f) 25 | 26 | def load(self, filename): 27 | with pathlib.Path(filename).open('rb') as f: 28 | values = pickle.load(f) 29 | tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values) 30 | 31 | def get(self, name, ctor, *args, **kwargs): 32 | # Create or get layer by name to avoid mentioning it in the constructor. 33 | if not hasattr(self, '_modules'): 34 | self._modules = {} 35 | if name not in self._modules: 36 | self._modules[name] = ctor(*args, **kwargs) 37 | return self._modules[name] 38 | 39 | 40 | def video_summary(name, video, step=None, fps=20): 41 | name = name if isinstance(name, str) else name.decode('utf-8') 42 | if np.issubdtype(video.dtype, np.floating): 43 | video = np.clip(255 * video, 0, 255).astype(np.uint8) 44 | B, T, H, W, C = video.shape 45 | try: 46 | frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C)) 47 | summary = tf1.Summary() 48 | image = tf1.Summary.Image(height=B * H, width=T * W, colorspace=C) 49 | image.encoded_image_string = encode_gif(frames, fps) 50 | summary.value.add(tag=name + '/gif', image=image) 51 | tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) 52 | except (IOError, OSError) as e: 53 | print('GIF summaries require ffmpeg in $PATH.', e) 54 | frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C)) 55 | tf.summary.image(name + '/grid', frames, step) 56 | 57 | 58 | def encode_gif(frames, fps): 59 | from subprocess import Popen, PIPE 60 | h, w, c = frames[0].shape 61 | pxfmt = {1: 'gray', 3: 'rgb24'}[c] 62 | cmd = ' '.join([ 63 | f'ffmpeg -y -f rawvideo -vcodec rawvideo', 64 | f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex', 65 | f'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse', 66 | f'-r {fps:.02f} -f gif -']) 67 | proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE) 68 | for image in frames: 69 | proc.stdin.write(image.tostring()) 70 | out, err = proc.communicate() 71 | if proc.returncode: 72 | raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')])) 73 | del proc 74 | return out 75 | 76 | 77 | def simulate(agent, envs, steps=0, episodes=0, state=None): 78 | # Initialize or unpack simulation state. 79 | if state is None: 80 | step, episode = 0, 0 81 | done = np.ones(len(envs), np.bool) 82 | length = np.zeros(len(envs), np.int32) 83 | obs = [None] * len(envs) 84 | agent_state = None 85 | else: 86 | step, episode, done, length, obs, agent_state = state 87 | while (steps and step < steps) or (episodes and episode < episodes): 88 | # Reset envs if necessary. 
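        # Environments are expected to expose non-blocking reset()/step() calls that
        # return promise callables (here the environments are wrapped by wrappers.Async).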
89 | if done.any(): 90 | indices = [index for index, d in enumerate(done) if d] 91 | promises = [envs[i].reset(blocking=False) for i in indices] 92 | for index, promise in zip(indices, promises): 93 | obs[index] = promise() 94 | # Step agents. 95 | obs = {k: np.stack([o[k] for o in obs]) for k in obs[0]} 96 | action, agent_state = agent(obs, done, agent_state) 97 | action = np.array(action) 98 | assert len(action) == len(envs) 99 | # Step envs. 100 | promises = [e.step(a, blocking=False) for e, a in zip(envs, action)] 101 | obs, _, done = zip(*[p()[:3] for p in promises]) 102 | obs = list(obs) 103 | done = np.stack(done) 104 | episode += int(done.sum()) 105 | length += 1 106 | step += (done * length).sum() 107 | length *= (1 - done) 108 | # Return new state to allow resuming the simulation. 109 | return (step - steps, episode - episodes, done, length, obs, agent_state) 110 | 111 | def count_episodes(directory): 112 | filenames = directory.glob('*.npz') 113 | lengths = [int(n.stem.rsplit('-', 1)[-1]) - 1 for n in filenames] 114 | episodes, steps = len(lengths), sum(lengths) 115 | return episodes, steps 116 | 117 | 118 | def save_episodes(directory, episodes): 119 | directory = pathlib.Path(directory).expanduser() 120 | directory.mkdir(parents=True, exist_ok=True) 121 | timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S') 122 | for episode in episodes: 123 | identifier = str(uuid.uuid4().hex) 124 | length = len(episode['reward']) 125 | filename = directory / f'{timestamp}-{identifier}-{length}.npz' 126 | with io.BytesIO() as f1: 127 | np.savez_compressed(f1, **episode) 128 | f1.seek(0) 129 | with filename.open('wb') as f2: 130 | f2.write(f1.read()) 131 | 132 | 133 | def load_episodes(directory, rescan, length=None, balance=False, seed=0, load_episodes = 1000): 134 | directory = pathlib.Path(directory).expanduser() 135 | random = np.random.RandomState(seed) 136 | filenames = list(directory.glob('*.npz')) 137 | load_episodes = min(len(filenames), load_episodes) 138 | if load_episodes is None: 139 | load_episodes = int(count_episodes(directory)[0] / 20) 140 | 141 | while True: 142 | cache = {} 143 | for filename in random.choice(list(directory.glob('*.npz')), 144 | load_episodes, 145 | replace = False): 146 | try: 147 | with filename.open('rb') as f: 148 | episode = np.load(f) 149 | episode = {k: episode[k] for k in episode.keys() if k not in ['image_128']} 150 | #episode['reward'] = copy.deepcopy(episode['success']) 151 | except Exception as e: 152 | print(f'Could not load episode: {e}') 153 | continue 154 | cache[filename] = episode 155 | 156 | keys = list(cache.keys()) 157 | for index in random.choice(len(keys), rescan): 158 | episode = copy.deepcopy(cache[keys[index]]) 159 | if length: 160 | total = len(next(iter(episode.values()))) 161 | available = total - length 162 | if available < 0: 163 | for key in episode.keys(): 164 | shape = episode[key].shape 165 | episode[key] = np.concatenate([episode[key], 166 | np.zeros([abs(available)] + list(shape[1:]))], 167 | axis = 0) 168 | episode['mask'] = np.ones(length) 169 | episode['mask'][available:] = 0.0 170 | elif available > 0: 171 | if balance: 172 | index = min(random.randint(0, total), available) 173 | else: 174 | index = int(random.randint(0, available)) 175 | episode = {k: v[index: index + length] for k, v in episode.items()} 176 | episode['mask'] = np.ones(length) 177 | else: 178 | episode['mask'] = np.ones_like(episode['reward']) 179 | else: 180 | episode['mask'] = np.ones_like(episode['reward']) 181 | yield episode 182 | 183 | 
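# Minimal usage sketch for the Adam wrapper below (illustrative, mirrors its use in lompo.py;
# `compute_model_loss` is a placeholder for any scalar loss over the modules' variables):
#
#   opt = Adam('model', [encoder, dynamics, decoder], lr=6e-4, clip=100.0)
#   with tf.GradientTape() as tape:
#       loss = compute_model_loss(...)
#   grad_norm = opt(tape, loss)   # applies clipped gradients, returns the global norm
#
# NOTE: the `wd` / `wdpattern` arguments are stored but weight decay is not applied
# in __call__ below.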
184 | class Adam(tf.Module): 185 | def __init__(self, name, modules, lr, clip=None, wd=None, wdpattern=r'.*'): 186 | self._name = name 187 | self._modules = modules 188 | self._clip = clip 189 | self._wd = wd 190 | self._wdpattern = wdpattern 191 | self._opt = tf.optimizers.Adam(lr) 192 | 193 | @property 194 | def variables(self): 195 | return self._opt.variables() 196 | 197 | def __call__(self, tape, loss): 198 | variables = [module.variables for module in self._modules] 199 | self._variables = tf.nest.flatten(variables) 200 | assert len(loss.shape) == 0, loss.shape 201 | grads = tape.gradient(loss, self._variables) 202 | norm = tf.linalg.global_norm(grads) 203 | if self._clip: 204 | grads, _ = tf.clip_by_global_norm(grads, self._clip, norm) 205 | self._opt.apply_gradients(zip(grads, self._variables)) 206 | return norm 207 | 208 | 209 | def args_type(default): 210 | if isinstance(default, bool): 211 | return lambda x: bool(['False', 'True'].index(x)) 212 | if isinstance(default, int): 213 | return lambda x: float(x) if ('e' in x or '.' in x) else int(x) 214 | if isinstance(default, pathlib.Path): 215 | return lambda x: pathlib.Path(x).expanduser() 216 | return type(default) 217 | 218 | 219 | def static_scan(fn, inputs, start, reverse=False): 220 | last = start 221 | outputs = [[] for _ in tf.nest.flatten(start)] 222 | indices = range(len(tf.nest.flatten(inputs)[0])) 223 | if reverse: 224 | indices = reversed(indices) 225 | for index in indices: 226 | inp = tf.nest.map_structure(lambda x: x[index], inputs) 227 | last = fn(last, inp) 228 | [o.append(l) for o, l in zip(outputs, tf.nest.flatten(last))] 229 | if reverse: 230 | outputs = [list(reversed(x)) for x in outputs] 231 | outputs = [tf.stack(x, 0) for x in outputs] 232 | return tf.nest.pack_sequence_as(start, outputs) 233 | 234 | 235 | def _mnd_sample(self, sample_shape=(), seed=None, name='sample'): 236 | return tf.random.normal( 237 | tuple(sample_shape) + tuple(self.event_shape), 238 | self.mean(), self.stddev(), self.dtype, seed, name) 239 | 240 | 241 | tfd.MultivariateNormalDiag.sample = _mnd_sample 242 | 243 | 244 | -------------------------------------------------------------------------------- /wrappers.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import functools 3 | import sys 4 | import threading 5 | import traceback 6 | 7 | import gym 8 | import mujoco_py 9 | import d4rl 10 | import robel 11 | import metaworld 12 | import numpy as np 13 | from PIL import Image 14 | 15 | class DrawerOpen: 16 | def __init__(self, config, size=(128, 128)): 17 | self._env = metaworld.envs.mujoco.sawyer_xyz.v2.sawyer_drawer_open_v2.SawyerDrawerOpenEnvV2() 18 | self._env._last_rand_vec = np.array([-0.1, 0.9, 0.0]) 19 | self._env._set_task_called = True 20 | self.size = size 21 | 22 | #Setup camera in environment 23 | self.viewer = mujoco_py.MjRenderContextOffscreen(self._env.sim, -1) 24 | self.viewer.cam.elevation = -22.5 25 | self.viewer.cam.azimuth = 15 26 | self.viewer.cam.distance = 0.75 27 | self.viewer.cam.lookat[0] = -0.15 28 | self.viewer.cam.lookat[1] = 0.7 29 | self.viewer.cam.lookat[2] = 0.10 30 | 31 | def __getattr__(self, attr): 32 | if attr == '_wrapped_env': 33 | raise AttributeError() 34 | return getattr(self._env, attr) 35 | 36 | def step(self, action): 37 | state, reward, done, info = self._env.step(action) 38 | img = self.render(mode='rgb_array', width = self.size[0], height = self.size[1]) 39 | obs = {'state':state, 'image':img} 40 | reward = 1.0 * info['success'] 41 
|         return obs, reward, done, info
42 | 
43 |     def reset(self):
44 |         state = self._env.reset()
45 |         state = self._env.reset()
46 |         img = self.render(mode='rgb_array', width = self.size[0], height = self.size[1])
47 |         if getattr(self, 'use_transform', False):  # guard: this wrapper never sets use_transform in __init__
48 |             img = img[self.pad:-self.pad, self.pad:-self.pad, :]
49 |         obs = {'state':state, 'image':img}
50 |         return obs
51 | 
52 |     def render(self, mode, width = 128, height = 128):
53 |         self.viewer.render(width=width, height=height)
54 |         img = self.viewer.read_pixels(self.size[0], self.size[1], depth=False)
55 |         img = img[::-1]
56 |         return img
57 | 
58 | 
59 | class Hammer:
60 |     def __init__(self, config, size=(128, 128)):
61 |         self._env = metaworld.envs.mujoco.sawyer_xyz.v2.sawyer_hammer_v2.SawyerHammerEnvV2()
62 |         self._env._last_rand_vec = np.array([-0.06, 0.4, 0.02])
63 |         self._env._set_task_called = True
64 |         self.size = size
65 | 
66 |         #Setup camera in environment
67 |         self.viewer = mujoco_py.MjRenderContextOffscreen(self._env.sim, -1)
68 |         self.viewer.cam.elevation = -15
69 |         self.viewer.cam.azimuth = 137.5
70 |         self.viewer.cam.distance = 0.9
71 |         self.viewer.cam.lookat[0] = -0.
72 |         self.viewer.cam.lookat[1] = 0.6
73 |         self.viewer.cam.lookat[2] = 0.175
74 | 
75 |     def __getattr__(self, attr):
76 |         if attr == '_wrapped_env':
77 |             raise AttributeError()
78 |         return getattr(self._env, attr)
79 | 
80 | 
81 |     def step(self, action):
82 |         state, reward, done, info = self._env.step(action)
83 |         img = self.render(mode='rgb_array', width = self.size[0], height = self.size[1])
84 |         obs = {'state':state, 'image':img}
85 |         return obs, reward, done, info
86 | 
87 |     def reset(self):
88 |         state = self._env.reset()
89 |         img = self.render(mode='rgb_array', width = self.size[0], height = self.size[1])
90 |         obs = {'state':state, 'image':img}
91 |         return obs
92 | 
93 |     def render(self, mode, width = 128, height = 128):
94 |         self.viewer.render(width=width, height=height)
95 |         img = self.viewer.read_pixels(self.size[0], self.size[1], depth=False)
96 |         img = img[::-1]
97 |         return img
98 | 
99 | 
100 | class DoorOpen:
101 |     def __init__(self, config, size=(128, 128)):
102 |         self._env = metaworld.envs.mujoco.sawyer_xyz.v2.sawyer_door_v2.SawyerDoorEnvV2()
103 |         self._env._last_rand_vec = np.array([0.0, 1.0, .1525])
104 |         self._env._set_task_called = True
105 |         self.size = size
106 | 
107 |         #Setup camera in environment
108 |         self.viewer = mujoco_py.MjRenderContextOffscreen(self._env.sim, -1)
109 |         self.viewer.cam.elevation = -12.5
110 |         self.viewer.cam.azimuth = 115
111 |         self.viewer.cam.distance = 1.05
112 |         self.viewer.cam.lookat[0] = 0.075
113 |         self.viewer.cam.lookat[1] = 0.75
114 |         self.viewer.cam.lookat[2] = 0.15
115 | 
116 |     def __getattr__(self, attr):
117 |         if attr == '_wrapped_env':
118 |             raise AttributeError()
119 |         return getattr(self._env, attr)
120 | 
121 |     def step(self, action):
122 |         state, reward, done, info = self._env.step(action)
123 |         img = self.render(mode='rgb_array', width = self.size[0], height = self.size[1])
124 |         obs = {'state':state, 'image':img}
125 |         reward = 1.0 * info['success']
126 |         return obs, reward, done, info
127 | 
128 |     def reset(self):
129 |         state = self._env.reset()
130 |         img = self.render(mode='rgb_array', width = self.size[0], height = self.size[1])
131 |         obs = {'state':state, 'image':img}
132 |         return obs
133 | 
134 |     def render(self, mode, width = 128, height = 128):
135 |         self.viewer.render(width=width, height=height)
136 |         img = self.viewer.read_pixels(self.size[0], self.size[1], depth=False)
137 |         img = img[::-1]
138 |         return img
139 | 
140 | 
141 | class Gym:
142 |     def __init__(self, name, config, size=(64, 64)):
143 |         self._env = gym.make(name)
144 |         self.size = size
145 |         self.use_transform = config.use_transform
146 |         self.pad = int(config.pad/2)
147 | 
148 |     def __getattr__(self, attr):
149 |         if attr == '_wrapped_env':
150 |             raise AttributeError()
151 |         return getattr(self._env, attr)
152 | 
153 |     def step(self, action):
154 |         state, reward, done, info = self._env.step(action)
155 |         img = self._env.render(mode='rgb_array', width = self.size[0], height = self.size[1])
156 |         if self.use_transform:
157 |             img = img[self.pad:-self.pad, self.pad:-self.pad, :]
158 |         obs = {'state':state, 'image':img}
159 |         return obs, reward, done, info
160 | 
161 |     def reset(self):
162 |         state = self._env.reset()
163 |         img = self._env.render(mode='rgb_array', width = self.size[0], height = self.size[1])
164 |         if self.use_transform:
165 |             img = img[self.pad:-self.pad, self.pad:-self.pad, :]
166 |         obs = {'state':state, 'image':img}
167 |         return obs
168 | 
169 |     def render(self, *args, **kwargs):
170 |         if kwargs.get('mode', 'rgb_array') != 'rgb_array':
171 |             raise ValueError("Only render mode 'rgb_array' is supported.")
172 |         return self._env.render(mode='rgb_array', width = self.size[0], height = self.size[1])
173 | 
174 | 
175 | class DeepMindControl:
176 |     def __init__(self, name, size=(64, 64), camera=None):
177 |         domain, task = name.split('_', 1)
178 |         if domain == 'cup':  # Only domain with multiple words.
179 |             domain = 'ball_in_cup'
180 |         if isinstance(domain, str):
181 |             from dm_control import suite
182 |             self._env = suite.load(domain, task)
183 |         else:
184 |             assert task is None
185 |             self._env = domain()
186 |         self._size = size
187 |         if camera is None:
188 |             camera = dict(quadruped=2).get(domain, 0)
189 |         self._camera = camera
190 | 
191 |     @property
192 |     def observation_space(self):
193 |         spaces = {}
194 |         for key, value in self._env.observation_spec().items():
195 |             spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32)
196 |         spaces['image'] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
197 |         return gym.spaces.Dict(spaces)
198 | 
199 |     @property
200 |     def action_space(self):
201 |         spec = self._env.action_spec()
202 |         return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
203 | 
204 |     def step(self, action):
205 |         time_step = self._env.step(action)
206 |         obs = dict(time_step.observation)
207 |         obs['image'] = self.render()
208 |         reward = time_step.reward or 0
209 |         done = time_step.last()
210 |         info = {'discount': np.array(time_step.discount, np.float32)}
211 |         return obs, reward, done, info
212 | 
213 |     def reset(self):
214 |         time_step = self._env.reset()
215 |         obs = dict(time_step.observation)
216 |         obs['image'] = self.render()
217 |         return obs
218 | 
219 |     def render(self, *args, **kwargs):
220 |         if kwargs.get('mode', 'rgb_array') != 'rgb_array':
221 |             raise ValueError("Only render mode 'rgb_array' is supported.")
222 |         return self._env.physics.render(*self._size, camera_id=self._camera)
223 | 
224 | 
225 | class Collect:
226 |     def __init__(self, env, callbacks=None, precision=32):
227 |         self._env = env
228 |         self._callbacks = callbacks or ()
229 |         self._precision = precision
230 |         self._episode = None
231 | 
232 |     def __getattr__(self, name):
233 |         return getattr(self._env, name)
234 | 
235 |     def step(self, action):
236 |         obs, reward, done, info = self._env.step(action)
237 |         obs = {k: self._convert(v) for k, v in obs.items()}
238 |         transition = obs.copy()
239 |         transition['action'] = action
240 | 
transition['reward'] = reward 241 | transition['discount'] = info.get('discount', np.array(1 - float(done))) 242 | self._episode.append(transition) 243 | if done: 244 | episode = {k: [t[k] for t in self._episode] for k in self._episode[0]} 245 | episode = {k: self._convert(v) for k, v in episode.items()} 246 | info['episode'] = episode 247 | for callback in self._callbacks: 248 | callback(episode) 249 | return obs, reward, done, info 250 | 251 | def reset(self): 252 | obs = self._env.reset() 253 | transition = obs.copy() 254 | transition['action'] = np.zeros(self._env.action_space.shape) 255 | transition['reward'] = 0.0 256 | transition['discount'] = 1.0 257 | self._episode = [transition] 258 | return obs 259 | 260 | def _convert(self, value): 261 | value = np.array(value) 262 | if np.issubdtype(value.dtype, np.floating): 263 | dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self._precision] 264 | elif np.issubdtype(value.dtype, np.signedinteger): 265 | dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] 266 | elif np.issubdtype(value.dtype, np.uint8): 267 | dtype = np.uint8 268 | else: 269 | raise NotImplementedError(value.dtype) 270 | return value.astype(dtype) 271 | 272 | 273 | class TimeLimit: 274 | def __init__(self, env, duration): 275 | self._env = env 276 | self._duration = duration 277 | self._step = None 278 | 279 | def __getattr__(self, name): 280 | return getattr(self._env, name) 281 | 282 | def step(self, action): 283 | assert self._step is not None, 'Must reset environment.' 284 | obs, reward, done, info = self._env.step(action) 285 | self._step += 1 286 | if self._step >= self._duration: 287 | done = True 288 | if 'discount' not in info: 289 | info['discount'] = np.array(1.0).astype(np.float32) 290 | self._step = None 291 | return obs, reward, done, info 292 | 293 | def reset(self): 294 | self._step = 0 295 | return self._env.reset() 296 | 297 | 298 | class ActionRepeat: 299 | def __init__(self, env, amount): 300 | self._env = env 301 | self._amount = amount 302 | 303 | def __getattr__(self, name): 304 | return getattr(self._env, name) 305 | 306 | def step(self, action): 307 | done = False 308 | total_reward = 0 309 | current_step = 0 310 | while current_step < self._amount and not done: 311 | obs, reward, done, info = self._env.step(action) 312 | total_reward += reward 313 | current_step += 1 314 | return obs, total_reward, done, info 315 | 316 | 317 | class NormalizeActions: 318 | def __init__(self, env): 319 | self._env = env 320 | self._mask = np.logical_and(np.isfinite(env.action_space.low), 321 | np.isfinite(env.action_space.high)) 322 | self._low = np.where(self._mask, env.action_space.low, -1) 323 | self._high = np.where(self._mask, env.action_space.high, 1) 324 | 325 | def __getattr__(self, name): 326 | return getattr(self._env, name) 327 | 328 | @property 329 | def action_space(self): 330 | low = np.where(self._mask, -np.ones_like(self._low), self._low) 331 | high = np.where(self._mask, np.ones_like(self._low), self._high) 332 | return gym.spaces.Box(low, high, dtype=np.float32) 333 | 334 | def step(self, action): 335 | original = (action + 1) / 2 * (self._high - self._low) + self._low 336 | original = np.where(self._mask, original, action) 337 | return self._env.step(original) 338 | 339 | 340 | class ObsDict: 341 | def __init__(self, env, key='obs'): 342 | self._env = env 343 | self._key = key 344 | 345 | def __getattr__(self, name): 346 | return getattr(self._env, name) 347 | 348 | @property 349 | def observation_space(self): 350 | spaces 
= {self._key: self._env.observation_space} 351 | return gym.spaces.Dict(spaces) 352 | 353 | @property 354 | def action_space(self): 355 | return self._env.action_space 356 | 357 | def step(self, action): 358 | obs, reward, done, info = self._env.step(action) 359 | obs = {self._key: np.array(obs)} 360 | return obs, reward, done, info 361 | 362 | def reset(self): 363 | obs = self._env.reset() 364 | obs = {self._key: np.array(obs)} 365 | return obs 366 | 367 | 368 | class RewardObs: 369 | def __init__(self, env): 370 | self._env = env 371 | 372 | def __getattr__(self, name): 373 | return getattr(self._env, name) 374 | 375 | @property 376 | def observation_space(self): 377 | spaces = self._env.observation_space.spaces 378 | assert 'reward' not in spaces 379 | spaces['reward'] = gym.spaces.Box(-np.inf, np.inf, dtype=np.float32) 380 | return gym.spaces.Dict(spaces) 381 | 382 | def step(self, action): 383 | obs, reward, done, info = self._env.step(action) 384 | obs['reward'] = reward 385 | return obs, reward, done, info 386 | 387 | def reset(self): 388 | obs = self._env.reset() 389 | obs['reward'] = 0.0 390 | return obs 391 | 392 | class Async: 393 | _ACCESS = 1 394 | _CALL = 2 395 | _RESULT = 3 396 | _EXCEPTION = 4 397 | _CLOSE = 5 398 | 399 | def __init__(self, ctor, strategy='process'): 400 | self._strategy = strategy 401 | if strategy == 'none': 402 | self._env = ctor() 403 | elif strategy == 'thread': 404 | import multiprocessing.dummy as mp 405 | elif strategy == 'process': 406 | import multiprocessing as mp 407 | else: 408 | raise NotImplementedError(strategy) 409 | if strategy != 'none': 410 | self._conn, conn = mp.Pipe() 411 | self._process = mp.Process(target=self._worker, args=(ctor, conn)) 412 | atexit.register(self.close) 413 | self._process.start() 414 | self._obs_space = None 415 | self._action_space = None 416 | 417 | @property 418 | def observation_space(self): 419 | if not self._obs_space: 420 | self._obs_space = self.__getattr__('observation_space') 421 | return self._obs_space 422 | 423 | @property 424 | def action_space(self): 425 | if not self._action_space: 426 | self._action_space = self.__getattr__('action_space') 427 | return self._action_space 428 | 429 | def __getattr__(self, name): 430 | if self._strategy == 'none': 431 | return getattr(self._env, name) 432 | self._conn.send((self._ACCESS, name)) 433 | return self._receive() 434 | 435 | def call(self, name, *args, **kwargs): 436 | blocking = kwargs.pop('blocking', True) 437 | if self._strategy == 'none': 438 | return functools.partial(getattr(self._env, name), *args, **kwargs) 439 | payload = name, args, kwargs 440 | self._conn.send((self._CALL, payload)) 441 | promise = self._receive 442 | return promise() if blocking else promise 443 | 444 | def close(self): 445 | if self._strategy == 'none': 446 | try: 447 | self._env.close() 448 | except AttributeError: 449 | pass 450 | return 451 | try: 452 | self._conn.send((self._CLOSE, None)) 453 | self._conn.close() 454 | except IOError: 455 | # The connection was already closed. 456 | pass 457 | self._process.join() 458 | 459 | def step(self, action, blocking=True): 460 | return self.call('step', action, blocking=blocking) 461 | 462 | 463 | def reset(self, blocking=True): 464 | return self.call('reset', blocking=blocking) 465 | 466 | def _receive(self): 467 | try: 468 | message, payload = self._conn.recv() 469 | except ConnectionResetError: 470 | raise RuntimeError('Environment worker crashed.') 471 | # Re-raise exceptions in the main process. 
472 | if message == self._EXCEPTION: 473 | stacktrace = payload 474 | raise Exception(stacktrace) 475 | if message == self._RESULT: 476 | return payload 477 | raise KeyError(f'Received message of unexpected type {message}') 478 | 479 | def _worker(self, ctor, conn): 480 | try: 481 | env = ctor() 482 | while True: 483 | try: 484 | # Only block for short times to have keyboard exceptions be raised. 485 | if not conn.poll(0.1): 486 | continue 487 | message, payload = conn.recv() 488 | except (EOFError, KeyboardInterrupt): 489 | break 490 | if message == self._ACCESS: 491 | name = payload 492 | result = getattr(env, name) 493 | conn.send((self._RESULT, result)) 494 | continue 495 | if message == self._CALL: 496 | name, args, kwargs = payload 497 | result = getattr(env, name)(*args, **kwargs) 498 | conn.send((self._RESULT, result)) 499 | continue 500 | if message == self._CLOSE: 501 | assert payload is None 502 | break 503 | raise KeyError(f'Received message of unknown type {message}') 504 | except Exception: 505 | stacktrace = ''.join(traceback.format_exception(*sys.exc_info())) 506 | print(f'Error in environment process: {stacktrace}') 507 | conn.send((self._EXCEPTION, stacktrace)) 508 | conn.close() 509 | --------------------------------------------------------------------------------
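The wrappers in wrappers.py above are meant to be stacked around a base environment, with Async optionally pushing the whole stack into a worker process so that step and reset can return promises (as used by the non-blocking simulation loop). Below is a hedged composition sketch; the task name, action repeat, and time limit are placeholder values, and make_env is a hypothetical helper rather than the repository's actual entry point.

```python
import functools

def make_env(task='walker_walk', action_repeat=2, time_limit=1000, callbacks=()):
    # Hypothetical composition (illustrative values, not the paper's settings).
    env = DeepMindControl(task)
    env = ActionRepeat(env, action_repeat)
    env = NormalizeActions(env)
    env = TimeLimit(env, time_limit // action_repeat)
    env = Collect(env, callbacks)  # gathers full episodes and hands them to callbacks
    return env

if __name__ == '__main__':
    # Run the stack in a separate process; attribute access, step, and reset are
    # forwarded through the Async message protocol defined above.
    env = Async(functools.partial(make_env, 'walker_walk'), strategy='process')
    obs = env.reset(blocking=True)
    promise = env.step(env.action_space.sample(), blocking=False)
    obs, reward, done, info = promise()  # resolve the promise when the result is needed
    env.close()
```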