├── .gitignore
├── README.md
├── atari_wrappers.py
├── cmd_util.py
├── console_util.py
├── load_log.py
├── monitor.py
├── mpi_util.py
├── policies
│   ├── __init__.py
│   ├── cnn_gru_policy_dynamics.py
│   └── cnn_policy_param_matched.py
├── ppo_agent.py
├── recorder.py
├── replayer.py
├── run_atari.py
├── stochastic_policy.py
├── tf_util.py
├── utils.py
└── vec_env.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **Status:** Archive (code is provided as-is, no updates expected)
2 | 
3 | ## [Exploration by Random Network Distillation](https://arxiv.org/abs/1810.12894) ##
4 | 
5 | 
6 | Yuri Burda*, Harri Edwards*, Amos Storkey, Oleg Klimov
7 | *equal contribution
8 | 
9 | OpenAI
10 | University of Edinburgh 11 | 12 | 13 | ### Installation and Usage 14 | The following command should train an RND agent on Montezuma's Revenge 15 | ```bash 16 | python run_atari.py --gamma_ext 0.999 17 | ``` 18 | To use more than one gpu/machine, use MPI (e.g. `mpiexec -n 8 python run_atari.py --num_env 128 --gamma_ext 0.999` should use 1024 parallel environments to collect experience on an 8 gpu machine). 19 | 20 | ### [Blog post and videos](https://blog.openai.com/reinforcement-learning-with-prediction-based-rewards/) 21 | -------------------------------------------------------------------------------- /atari_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | import gym 4 | from gym import spaces 5 | import cv2 6 | from copy import copy 7 | 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | def unwrap(env): 11 | if hasattr(env, "unwrapped"): 12 | return env.unwrapped 13 | elif hasattr(env, "env"): 14 | return unwrap(env.env) 15 | elif hasattr(env, "leg_env"): 16 | return unwrap(env.leg_env) 17 | else: 18 | return env 19 | 20 | class MaxAndSkipEnv(gym.Wrapper): 21 | def __init__(self, env, skip=4): 22 | """Return only every `skip`-th frame""" 23 | gym.Wrapper.__init__(self, env) 24 | # most recent raw observations (for max pooling across time steps) 25 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) 26 | self._skip = skip 27 | 28 | def step(self, action): 29 | """Repeat action, sum reward, and max over last observations.""" 30 | total_reward = 0.0 31 | done = None 32 | for i in range(self._skip): 33 | obs, reward, done, info = self.env.step(action) 34 | if i == self._skip - 2: self._obs_buffer[0] = obs 35 | if i == self._skip - 1: self._obs_buffer[1] = obs 36 | total_reward += reward 37 | if done: 38 | break 39 | # Note that the observation on the done=True frame 40 | # doesn't matter 41 | max_frame = self._obs_buffer.max(axis=0) 42 | 43 | return max_frame, total_reward, done, info 44 | 45 | def reset(self, **kwargs): 46 | return self.env.reset(**kwargs) 47 | 48 | class ClipRewardEnv(gym.RewardWrapper): 49 | def __init__(self, env): 50 | gym.RewardWrapper.__init__(self, env) 51 | 52 | def reward(self, reward): 53 | """Bin reward to {+1, 0, -1} by its sign.""" 54 | return float(np.sign(reward)) 55 | 56 | class WarpFrame(gym.ObservationWrapper): 57 | def __init__(self, env): 58 | """Warp frames to 84x84 as done in the Nature paper and later work.""" 59 | gym.ObservationWrapper.__init__(self, env) 60 | self.width = 84 61 | self.height = 84 62 | self.observation_space = spaces.Box(low=0, high=255, 63 | shape=(self.height, self.width, 1), dtype=np.uint8) 64 | 65 | def observation(self, frame): 66 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 67 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA) 68 | return frame[:, :, None] 69 | 70 | class FrameStack(gym.Wrapper): 71 | def __init__(self, env, k): 72 | """Stack k last frames. 73 | 74 | Returns lazy array, which is much more memory efficient. 
75 | 76 | See Also 77 | -------- 78 | rl_common.atari_wrappers.LazyFrames 79 | """ 80 | gym.Wrapper.__init__(self, env) 81 | self.k = k 82 | self.frames = deque([], maxlen=k) 83 | shp = env.observation_space.shape 84 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8) 85 | 86 | def reset(self): 87 | ob = self.env.reset() 88 | for _ in range(self.k): 89 | self.frames.append(ob) 90 | return self._get_ob() 91 | 92 | def step(self, action): 93 | ob, reward, done, info = self.env.step(action) 94 | self.frames.append(ob) 95 | return self._get_ob(), reward, done, info 96 | 97 | def _get_ob(self): 98 | assert len(self.frames) == self.k 99 | return LazyFrames(list(self.frames)) 100 | 101 | class ScaledFloatFrame(gym.ObservationWrapper): 102 | def __init__(self, env): 103 | gym.ObservationWrapper.__init__(self, env) 104 | 105 | def observation(self, observation): 106 | # careful! This undoes the memory optimization, use 107 | # with smaller replay buffers only. 108 | return np.array(observation).astype(np.float32) / 255.0 109 | 110 | class LazyFrames(object): 111 | def __init__(self, frames): 112 | """This object ensures that common frames between the observations are only stored once. 113 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 114 | buffers. 115 | 116 | This object should only be converted to numpy array before being passed to the model. 117 | 118 | You'd not believe how complex the previous solution was.""" 119 | self._frames = frames 120 | self._out = None 121 | 122 | def _force(self): 123 | if self._out is None: 124 | self._out = np.concatenate(self._frames, axis=2) 125 | self._frames = None 126 | return self._out 127 | 128 | def __array__(self, dtype=None): 129 | out = self._force() 130 | if dtype is not None: 131 | out = out.astype(dtype) 132 | return out 133 | 134 | def __len__(self): 135 | return len(self._force()) 136 | 137 | def __getitem__(self, i): 138 | return self._force()[i] 139 | 140 | class MontezumaInfoWrapper(gym.Wrapper): 141 | def __init__(self, env, room_address): 142 | super(MontezumaInfoWrapper, self).__init__(env) 143 | self.room_address = room_address 144 | self.visited_rooms = set() 145 | 146 | def get_current_room(self): 147 | ram = unwrap(self.env).ale.getRAM() 148 | assert len(ram) == 128 149 | return int(ram[self.room_address]) 150 | 151 | def step(self, action): 152 | obs, rew, done, info = self.env.step(action) 153 | self.visited_rooms.add(self.get_current_room()) 154 | if done: 155 | if 'episode' not in info: 156 | info['episode'] = {} 157 | info['episode'].update(visited_rooms=copy(self.visited_rooms)) 158 | self.visited_rooms.clear() 159 | return obs, rew, done, info 160 | 161 | def reset(self): 162 | return self.env.reset() 163 | 164 | class DummyMontezumaInfoWrapper(gym.Wrapper): 165 | 166 | def __init__(self, env): 167 | super(DummyMontezumaInfoWrapper, self).__init__(env) 168 | 169 | def step(self, action): 170 | obs, rew, done, info = self.env.step(action) 171 | if done: 172 | if 'episode' not in info: 173 | info['episode'] = {} 174 | info['episode'].update(pos_count=0, 175 | visited_rooms=set([0])) 176 | return obs, rew, done, info 177 | 178 | def reset(self): 179 | return self.env.reset() 180 | 181 | class AddRandomStateToInfo(gym.Wrapper): 182 | def __init__(self, env): 183 | """Adds the random state to the info field on the first step after reset 184 | """ 185 | gym.Wrapper.__init__(self, env) 186 | 187 | def step(self, action): 188 | ob, r, d, info = 
self.env.step(action) 189 | if d: 190 | if 'episode' not in info: 191 | info['episode'] = {} 192 | info['episode']['rng_at_episode_start'] = self.rng_at_episode_start 193 | return ob, r, d, info 194 | 195 | def reset(self, **kwargs): 196 | self.rng_at_episode_start = copy(self.unwrapped.np_random) 197 | return self.env.reset(**kwargs) 198 | 199 | 200 | def make_atari(env_id, max_episode_steps=4500): 201 | env = gym.make(env_id) 202 | env._max_episode_steps = max_episode_steps*4 203 | assert 'NoFrameskip' in env.spec.id 204 | env = StickyActionEnv(env) 205 | env = MaxAndSkipEnv(env, skip=4) 206 | if "Montezuma" in env_id or "Pitfall" in env_id: 207 | env = MontezumaInfoWrapper(env, room_address=3 if "Montezuma" in env_id else 1) 208 | else: 209 | env = DummyMontezumaInfoWrapper(env) 210 | env = AddRandomStateToInfo(env) 211 | return env 212 | 213 | def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False): 214 | """Configure environment for DeepMind-style Atari. 215 | """ 216 | env = WarpFrame(env) 217 | if scale: 218 | env = ScaledFloatFrame(env) 219 | if clip_rewards: 220 | env = ClipRewardEnv(env) 221 | if frame_stack: 222 | env = FrameStack(env, 4) 223 | # env = NormalizeObservation(env) 224 | return env 225 | 226 | 227 | class StickyActionEnv(gym.Wrapper): 228 | def __init__(self, env, p=0.25): 229 | super(StickyActionEnv, self).__init__(env) 230 | self.p = p 231 | self.last_action = 0 232 | 233 | def reset(self): 234 | self.last_action = 0 235 | return self.env.reset() 236 | 237 | def step(self, action): 238 | if self.unwrapped.np_random.uniform() < self.p: 239 | action = self.last_action 240 | self.last_action = action 241 | obs, reward, done, info = self.env.step(action) 242 | return obs, reward, done, info 243 | -------------------------------------------------------------------------------- /cmd_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for scripts like run_atari.py. 3 | """ 4 | 5 | import os 6 | 7 | import gym 8 | from gym.wrappers import FlattenDictWrapper 9 | from mpi4py import MPI 10 | from baselines import logger 11 | from monitor import Monitor 12 | from atari_wrappers import make_atari, wrap_deepmind 13 | from vec_env import SubprocVecEnv 14 | 15 | 16 | def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0, max_episode_steps=4500): 17 | """ 18 | Create a wrapped, monitored SubprocVecEnv for Atari. 19 | """ 20 | if wrapper_kwargs is None: wrapper_kwargs = {} 21 | def make_env(rank): # pylint: disable=C0111 22 | def _thunk(): 23 | env = make_atari(env_id, max_episode_steps=max_episode_steps) 24 | env.seed(seed + rank) 25 | env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), allow_early_resets=True) 26 | return wrap_deepmind(env, **wrapper_kwargs) 27 | return _thunk 28 | # set_global_seeds(seed) 29 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) 30 | 31 | def arg_parser(): 32 | """ 33 | Create an empty argparse.ArgumentParser. 34 | """ 35 | import argparse 36 | return argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 37 | 38 | def atari_arg_parser(): 39 | """ 40 | Create an argparse.ArgumentParser for run_atari.py. 
41 | """ 42 | parser = arg_parser() 43 | parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') 44 | parser.add_argument('--seed', help='RNG seed', type=int, default=0) 45 | parser.add_argument('--num-timesteps', type=int, default=int(10e6)) 46 | return parser 47 | -------------------------------------------------------------------------------- /console_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from contextlib import contextmanager 3 | import numpy as np 4 | import time 5 | 6 | # ================================================================ 7 | # Misc 8 | # ================================================================ 9 | 10 | def fmt_row(width, row, header=False): 11 | out = " | ".join(fmt_item(x, width) for x in row) 12 | if header: out = out + "\n" + "-"*len(out) 13 | return out 14 | 15 | def fmt_item(x, l): 16 | if isinstance(x, np.ndarray): 17 | assert x.ndim==0 18 | x = x.item() 19 | if isinstance(x, (float, np.float32, np.float64)): 20 | v = abs(x) 21 | if (v < 1e-4 or v > 1e+4) and v > 0: 22 | rep = "%7.2e" % x 23 | else: 24 | rep = "%7.5f" % x 25 | else: rep = str(x) 26 | return " "*(l - len(rep)) + rep 27 | 28 | color2num = dict( 29 | gray=30, 30 | red=31, 31 | green=32, 32 | yellow=33, 33 | blue=34, 34 | magenta=35, 35 | cyan=36, 36 | white=37, 37 | crimson=38 38 | ) 39 | 40 | def colorize(string, color, bold=False, highlight=False): 41 | attr = [] 42 | num = color2num[color] 43 | if highlight: num += 10 44 | attr.append(str(num)) 45 | if bold: attr.append('1') 46 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 47 | 48 | 49 | MESSAGE_DEPTH = 0 50 | 51 | @contextmanager 52 | def timed(msg): 53 | global MESSAGE_DEPTH #pylint: disable=W0603 54 | print(colorize('\t'*MESSAGE_DEPTH + '=: ' + msg, color='magenta')) 55 | tstart = time.time() 56 | MESSAGE_DEPTH += 1 57 | yield 58 | MESSAGE_DEPTH -= 1 59 | print(colorize('\t'*MESSAGE_DEPTH + "done in %.3f seconds"%(time.time() - tstart), color='magenta')) 60 | -------------------------------------------------------------------------------- /load_log.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | import pickle 5 | import exptag 6 | 7 | separator = '------------' 8 | 9 | def parse(key, value): 10 | value = value.strip() 11 | try: 12 | if value.startswith('['): 13 | value = value[1:-1].split(';') 14 | if value != ['']: 15 | try: 16 | value = [int(v) for v in value] 17 | except: 18 | value = [str(v) for v in value] 19 | else: 20 | value = [] 21 | elif ';' in value: 22 | value = 0. 23 | elif value in ['nan', '', '-inf']: 24 | value = np.nan 25 | else: 26 | value = eval(value) 27 | except: 28 | import ipdb; ipdb.set_trace() 29 | print(f"failed to parse value {key}:{value.__repr__()}") 30 | value = 0. 
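# parse() heuristically decodes cells from the logger's progress.csv: bracketed strings such as
# '[1;5;7]' become lists split on ';', bare 'nan'/''/'-inf' map to np.nan, and anything else is
# passed through eval(); cells that cannot be parsed fall back to 0. Roughly, as an illustration:
#     parse('rooms', '[1;5;7]')  -> [1, 5, 7]
#     parse('eprew', 'nan')      -> nan
#     parse('tcount', '123')     -> 123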
31 | 32 | return value 33 | 34 | def get_hash(filename): 35 | import hashlib 36 | hash_md5 = hashlib.md5() 37 | with open(filename, "rb") as f: 38 | for chunk in iter(lambda: f.read(4096), b""): 39 | hash_md5.update(chunk) 40 | return hash_md5.hexdigest() 41 | 42 | def pickle_cache_result(f): 43 | def cached_f(filename): 44 | current_hash = get_hash(filename) 45 | cache_filename = filename + '_cache' 46 | if os.path.exists(cache_filename): 47 | with open(cache_filename, 'rb') as fl: 48 | try: 49 | stored_hash, stored_result = pickle.load(fl) 50 | if stored_hash == current_hash: 51 | # pass 52 | return stored_result 53 | except: 54 | pass 55 | result = f(filename) 56 | with open(cache_filename, 'wb') as fl: 57 | pickle.dump((current_hash, result), fl) 58 | return result 59 | return cached_f 60 | 61 | @pickle_cache_result 62 | def parse_csv(filename): 63 | import csv 64 | 65 | timeseries = {} 66 | keys = [] 67 | with open(filename, 'r') as f: 68 | reader = csv.reader(f) 69 | for i, row in enumerate(reader): 70 | if i == 0: 71 | keys = row 72 | else: 73 | for column, key in enumerate(keys): 74 | if key not in timeseries: 75 | timeseries[key] = [] 76 | timeseries[key].append(parse(key, row[column])) 77 | print(f'parsing {filename}') 78 | if 'opt_featvar' in timeseries: 79 | timeseries['opt_feat_var'] = timeseries['opt_featvar'] 80 | return timeseries 81 | 82 | 83 | def get_filename_from_tag(tag): 84 | folder = exptag.get_last_experiment_folder_by_tag(tag) 85 | return os.path.join(folder, "progress.csv") 86 | 87 | def get_filenames_from_tags(tags): 88 | return [get_filename_from_tag(tag) for tag in tags] 89 | 90 | def get_timeseries_from_filenames(filenames): 91 | return [parse_csv(f) for f in filenames] 92 | 93 | 94 | def get_timeseries_from_tags(tags): 95 | return get_timeseries_from_filenames(get_filenames_from_tags(tags)) 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--tags', type=lambda x: x.split(','), nargs='+', default=None) 101 | parser.add_argument('--x_axis', type=str, choices=['tcount', 'n_updates'], default='tcount') 102 | args = parser.parse_args() 103 | 104 | x_axis = args.x_axis 105 | 106 | timeseries_groups = [] 107 | for tag_group in args.tags: 108 | timeseries_groups.append(get_timeseries_from_tags(tags=tag_group)) 109 | for tag_group, timeseries in zip(args.tags, timeseries_groups): 110 | rooms = [] 111 | for tag, t in zip(tag_group, timeseries): 112 | if 'rooms' in t: 113 | rooms.append((tag, t['rooms'][-1], t['best_ret'][-1], t[x_axis][-1])) 114 | else: 115 | print(f"tag {tag} has no rooms") 116 | rooms = sorted(rooms, key=lambda x: len(x[1])) 117 | all_rooms = set.union(*[set(r[1]) for r in rooms]) 118 | import pprint 119 | for tag, r, best_ret, max_x in rooms: 120 | print(f'{tag}:{best_ret}:{r}(@{max_x})') 121 | pprint.pprint(all_rooms) 122 | keys = set.intersection(*[set(t.keys()) for t in sum(timeseries_groups, [])]) 123 | 124 | keys = sorted(list(keys)) 125 | import matplotlib.pyplot as plt 126 | 127 | n_rows = int(np.ceil(np.sqrt(len(keys)))) 128 | n_cols = len(keys) // n_rows + 1 129 | fig, axes = plt.subplots(n_rows, n_cols, sharex=True) 130 | for i in range(n_rows): 131 | for j in range(n_cols): 132 | ind = i * n_cols + j 133 | if ind < len(keys): 134 | key = keys[ind] 135 | for color, timeseries in zip('rgbykcm', timeseries_groups): 136 | if key in ['all_places', 'recent_places', 'global_all_places', 'global_recent_places', 'rooms']: 137 | if any(isinstance(t[key][-1], list) for t in timeseries): 138 
| for t in timeseries: 139 | t[key] = list(map(len, t[key])) 140 | max_timesteps = min((len(_[x_axis]) for _ in timeseries)) 141 | try: 142 | data = np.asarray([t[key][:max_timesteps] for t in timeseries], dtype=np.float32) 143 | except: 144 | import ipdb; ipdb.set_trace() 145 | lines = [np.nan_to_num(d[key]) for d in timeseries] 146 | lines_x = [np.asarray(d[x_axis]) for d in timeseries] 147 | alphas = [0.2/np.sqrt(len(lines)) for l in lines] 148 | lines += [np.nan_to_num(np.nanmean(data, 0))] 149 | alphas += [1.] 150 | lines_x += [np.asarray(timeseries[0][x_axis][:max_timesteps])] 151 | for alpha, y, x in zip(alphas, lines, lines_x): 152 | axes[i, j].plot(x, y, color=color, alpha=alpha) 153 | axes[i, j].set_title(key) 154 | 155 | plt.show() 156 | plt.close() -------------------------------------------------------------------------------- /monitor.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Monitor', 'get_monitor_files', 'load_results'] 2 | 3 | import gym 4 | from gym.core import Wrapper 5 | import time 6 | from glob import glob 7 | import csv 8 | import os.path as osp 9 | import json 10 | import numpy as np 11 | 12 | class Monitor(Wrapper): 13 | EXT = "monitor.csv" 14 | f = None 15 | 16 | def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()): 17 | Wrapper.__init__(self, env=env) 18 | self.tstart = time.time() 19 | if filename is None: 20 | self.f = None 21 | self.logger = None 22 | else: 23 | if not filename.endswith(Monitor.EXT): 24 | if osp.isdir(filename): 25 | filename = osp.join(filename, Monitor.EXT) 26 | else: 27 | filename = filename + "." + Monitor.EXT 28 | self.f = open(filename, "wt") 29 | self.f.write('#%s\n'%json.dumps({"t_start": self.tstart, 'env_id' : env.spec and env.spec.id})) 30 | self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords+info_keywords) 31 | self.logger.writeheader() 32 | self.f.flush() 33 | 34 | self.reset_keywords = reset_keywords 35 | self.info_keywords = info_keywords 36 | self.allow_early_resets = allow_early_resets 37 | self.rewards = None 38 | self.needs_reset = True 39 | self.episode_rewards = [] 40 | self.episode_lengths = [] 41 | self.episode_times = [] 42 | self.total_steps = 0 43 | self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() 44 | 45 | def reset(self, **kwargs): 46 | if not self.allow_early_resets and not self.needs_reset: 47 | raise RuntimeError("Tried to reset an environment before done. 
If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)") 48 | self.rewards = [] 49 | self.needs_reset = False 50 | for k in self.reset_keywords: 51 | v = kwargs.get(k) 52 | if v is None: 53 | raise ValueError('Expected you to pass kwarg %s into reset'%k) 54 | self.current_reset_info[k] = v 55 | return self.env.reset(**kwargs) 56 | 57 | def step(self, action): 58 | if self.needs_reset: 59 | raise RuntimeError("Tried to step environment that needs reset") 60 | ob, rew, done, info = self.env.step(action) 61 | self.rewards.append(rew) 62 | if done: 63 | self.needs_reset = True 64 | eprew = sum(self.rewards) 65 | eplen = len(self.rewards) 66 | epinfo = {"r": round(eprew, 6), "l": eplen, "t": round(time.time() - self.tstart, 6)} 67 | for k in self.info_keywords: 68 | epinfo[k] = info[k] 69 | self.episode_rewards.append(eprew) 70 | self.episode_lengths.append(eplen) 71 | self.episode_times.append(time.time() - self.tstart) 72 | epinfo.update(self.current_reset_info) 73 | if self.logger: 74 | self.logger.writerow(epinfo) 75 | self.f.flush() 76 | if "episode" not in info: 77 | info["episoide"] = {} 78 | info['episode'].update(epinfo) 79 | self.total_steps += 1 80 | return (ob, rew, done, info) 81 | 82 | def close(self): 83 | if self.f is not None: 84 | self.f.close() 85 | 86 | def get_total_steps(self): 87 | return self.total_steps 88 | 89 | def get_episode_rewards(self): 90 | return self.episode_rewards 91 | 92 | def get_episode_lengths(self): 93 | return self.episode_lengths 94 | 95 | def get_episode_times(self): 96 | return self.episode_times 97 | 98 | class LoadMonitorResultsError(Exception): 99 | pass 100 | 101 | def get_monitor_files(dir): 102 | return glob(osp.join(dir, "*" + Monitor.EXT)) 103 | 104 | def load_results(dir): 105 | import pandas 106 | monitor_files = ( 107 | glob(osp.join(dir, "*monitor.json")) + 108 | glob(osp.join(dir, "*monitor.csv"))) # get both csv and (old) json files 109 | if not monitor_files: 110 | raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, dir)) 111 | dfs = [] 112 | headers = [] 113 | for fname in monitor_files: 114 | with open(fname, 'rt') as fh: 115 | if fname.endswith('csv'): 116 | firstline = fh.readline() 117 | assert firstline[0] == '#' 118 | header = json.loads(firstline[1:]) 119 | df = pandas.read_csv(fh, index_col=None) 120 | headers.append(header) 121 | elif fname.endswith('json'): # Deprecated json format 122 | episodes = [] 123 | lines = fh.readlines() 124 | header = json.loads(lines[0]) 125 | headers.append(header) 126 | for line in lines[1:]: 127 | episode = json.loads(line) 128 | episodes.append(episode) 129 | df = pandas.DataFrame(episodes) 130 | else: 131 | assert 0, 'unreachable' 132 | df['t'] += header['t_start'] 133 | dfs.append(df) 134 | df = pandas.concat(dfs) 135 | df.sort_values('t', inplace=True) 136 | df.reset_index(inplace=True) 137 | df['t'] -= min(header['t_start'] for header in headers) 138 | df.headers = headers # HACK to preserve backwards compatibility 139 | return df 140 | 141 | def test_monitor(): 142 | env = gym.make("CartPole-v1") 143 | env.seed(0) 144 | mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4() 145 | menv = Monitor(env, mon_file) 146 | menv.reset() 147 | for _ in range(1000): 148 | _, _, done, _ = menv.step(0) 149 | if done: 150 | menv.reset() 151 | 152 | f = open(mon_file, 'rt') 153 | 154 | firstline = f.readline() 155 | assert firstline.startswith('#') 156 | metadata = json.loads(firstline[1:]) 157 | assert 
metadata['env_id'] == "CartPole-v1" 158 | assert set(metadata.keys()) == {'env_id', 'gym_version', 't_start'}, "Incorrect keys in monitor metadata" 159 | 160 | last_logline = pandas.read_csv(f, index_col=None) 161 | assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline" 162 | f.close() 163 | os.remove(mon_file) -------------------------------------------------------------------------------- /mpi_util.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from mpi4py import MPI 3 | import os, numpy as np 4 | import platform 5 | import tensorflow as tf 6 | 7 | def sync_from_root(sess, variables, comm=None): 8 | """ 9 | Send the root node's parameters to every worker. 10 | Arguments: 11 | sess: the TensorFlow session. 12 | variables: all parameter variables including optimizer's 13 | """ 14 | if comm is None: comm = MPI.COMM_WORLD 15 | rank = comm.Get_rank() 16 | for var in variables: 17 | if rank == 0: 18 | comm.bcast(sess.run(var)) 19 | else: 20 | import tensorflow as tf 21 | sess.run(tf.assign(var, comm.bcast(None))) 22 | 23 | # def gpu_count(): 24 | # """ 25 | # Count the GPUs on this machine. 26 | # """ 27 | # if shutil.which('nvidia-smi') is None: 28 | # return 0 29 | # output = subprocess.check_output(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv']) 30 | # return max(0, len(output.split(b'\n')) - 2) 31 | # 32 | # def setup_mpi_gpus(): 33 | # """ 34 | # Set CUDA_VISIBLE_DEVICES using MPI. 35 | # """ 36 | # num_gpus = gpu_count() 37 | # if num_gpus == 0: 38 | # return 39 | # local_rank, _ = get_local_rank_size(MPI.COMM_WORLD) 40 | # os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank % num_gpus) 41 | 42 | def guess_available_gpus(n_gpus=None): 43 | if n_gpus is not None: 44 | return list(range(n_gpus)) 45 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 46 | cuda_visible_divices = os.environ['CUDA_VISIBLE_DEVICES'] 47 | cuda_visible_divices = cuda_visible_divices.split(',') 48 | return [int(n) for n in cuda_visible_divices] 49 | if 'RCALL_NUM_GPU' not in os.environ: 50 | n_gpus = int(os.environ['RCALL_NUM_GPU']) 51 | return list(range(n_gpus)) 52 | nvidia_dir = '/proc/driver/nvidia/gpus/' 53 | if os.path.exists(nvidia_dir): 54 | n_gpus = len(os.listdir(nvidia_dir)) 55 | return list(range(n_gpus)) 56 | raise Exception("Couldn't guess the available gpus on this machine") 57 | 58 | 59 | def setup_mpi_gpus(): 60 | """ 61 | Set CUDA_VISIBLE_DEVICES using MPI. 62 | """ 63 | available_gpus = guess_available_gpus() 64 | 65 | node_id = platform.node() 66 | nodes_ordered_by_rank = MPI.COMM_WORLD.allgather(node_id) 67 | processes_outranked_on_this_node = [n for n in nodes_ordered_by_rank[:MPI.COMM_WORLD.Get_rank()] if n == node_id] 68 | local_rank = len(processes_outranked_on_this_node) 69 | 70 | os.environ['CUDA_VISIBLE_DEVICES'] = str(available_gpus[local_rank]) 71 | 72 | 73 | def get_local_rank_size(comm): 74 | """ 75 | Returns the rank of each process on its machine 76 | The processes on a given machine will be assigned ranks 77 | 0, 1, 2, ..., N-1, 78 | where N is the number of processes on this machine. 
79 | 80 | Useful if you want to assign one gpu per machine 81 | """ 82 | this_node = platform.node() 83 | ranks_nodes = comm.allgather((comm.Get_rank(), this_node)) 84 | node2rankssofar = defaultdict(int) 85 | local_rank = None 86 | for (rank, node) in ranks_nodes: 87 | if rank == comm.Get_rank(): 88 | local_rank = node2rankssofar[node] 89 | node2rankssofar[node] += 1 90 | assert local_rank is not None 91 | return local_rank, node2rankssofar[this_node] 92 | 93 | def share_file(comm, path): 94 | """ 95 | Copies the file from rank 0 to all other ranks 96 | Puts it in the same place on all machines 97 | """ 98 | localrank, _ = get_local_rank_size(comm) 99 | if comm.Get_rank() == 0: 100 | with open(path, 'rb') as fh: 101 | data = fh.read() 102 | comm.bcast(data) 103 | else: 104 | data = comm.bcast(None) 105 | if localrank == 0: 106 | os.makedirs(os.path.dirname(path), exist_ok=True) 107 | with open(path, 'wb') as fh: 108 | fh.write(data) 109 | comm.Barrier() 110 | 111 | def dict_gather_mean(comm, d): 112 | alldicts = comm.allgather(d) 113 | size = comm.Get_size() 114 | k2li = defaultdict(list) 115 | for d in alldicts: 116 | for (k,v) in d.items(): 117 | k2li[k].append(v) 118 | k2mean = {} 119 | for (k,li) in k2li.items(): 120 | k2mean[k] = np.mean(li, axis=0) if len(li) == size else np.nan 121 | return k2mean 122 | 123 | class MpiAdamOptimizer(tf.train.AdamOptimizer): 124 | """Adam optimizer that averages gradients across mpi processes.""" 125 | def __init__(self, comm, **kwargs): 126 | self.comm = comm 127 | tf.train.AdamOptimizer.__init__(self, **kwargs) 128 | def compute_gradients(self, loss, var_list, **kwargs): 129 | grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs) 130 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None] 131 | flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0) 132 | shapes = [v.shape.as_list() for g, v in grads_and_vars] 133 | sizes = [int(np.prod(s)) for s in shapes] 134 | 135 | num_tasks = self.comm.Get_size() 136 | buf = np.zeros(sum(sizes), np.float32) 137 | 138 | def _collect_grads(flat_grad): 139 | self.comm.Allreduce(flat_grad, buf, op=MPI.SUM) 140 | np.divide(buf, float(num_tasks), out=buf) 141 | return buf 142 | 143 | avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32) 144 | avg_flat_grad.set_shape(flat_grad.shape) 145 | avg_grads = tf.split(avg_flat_grad, sizes, axis=0) 146 | avg_grads_and_vars = [(tf.reshape(g, v.shape), v) 147 | for g, (_, v) in zip(avg_grads, grads_and_vars)] 148 | 149 | return avg_grads_and_vars 150 | 151 | 152 | def mpi_mean(x, axis=0, comm=None, keepdims=False): 153 | x = np.asarray(x) 154 | assert x.ndim > 0 155 | if comm is None: comm = MPI.COMM_WORLD 156 | xsum = x.sum(axis=axis, keepdims=keepdims) 157 | n = xsum.size 158 | localsum = np.zeros(n+1, x.dtype) 159 | localsum[:n] = xsum.ravel() 160 | localsum[n] = x.shape[axis] 161 | globalsum = np.zeros_like(localsum) 162 | comm.Allreduce(localsum, globalsum, op=MPI.SUM) 163 | return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] 164 | 165 | def mpi_moments(x, axis=0, comm=None, keepdims=False): 166 | x = np.asarray(x) 167 | assert x.ndim > 0 168 | mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True) 169 | sqdiffs = np.square(x - mean) 170 | meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True) 171 | assert count1 == count 172 | std = np.sqrt(meansqdiff) 173 | if not keepdims: 174 | newshape = mean.shape[:axis] + mean.shape[axis+1:] 175 | 
mean = mean.reshape(newshape) 176 | std = std.reshape(newshape) 177 | return mean, std, count 178 | 179 | class RunningMeanStd(object): 180 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 181 | def __init__(self, epsilon=1e-4, shape=(), comm=None, use_mpi=True): 182 | self.mean = np.zeros(shape, 'float64') 183 | self.use_mpi = use_mpi 184 | self.var = np.ones(shape, 'float64') 185 | self.count = epsilon 186 | if comm is None: 187 | from mpi4py import MPI 188 | comm = MPI.COMM_WORLD 189 | self.comm = comm 190 | 191 | 192 | def update(self, x): 193 | if self.use_mpi: 194 | batch_mean, batch_std, batch_count = mpi_moments(x, axis=0, comm=self.comm) 195 | else: 196 | batch_mean, batch_std, batch_count = np.mean(x, axis=0), np.std(x, axis=0), x.shape[0] 197 | batch_var = np.square(batch_std) 198 | self.update_from_moments(batch_mean, batch_var, batch_count) 199 | 200 | def update_from_moments(self, batch_mean, batch_var, batch_count): 201 | delta = batch_mean - self.mean 202 | tot_count = self.count + batch_count 203 | 204 | new_mean = self.mean + delta * batch_count / tot_count 205 | m_a = self.var * (self.count) 206 | m_b = batch_var * (batch_count) 207 | M2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 208 | new_var = M2 / (self.count + batch_count) 209 | 210 | new_count = batch_count + self.count 211 | 212 | self.mean = new_mean 213 | self.var = new_var 214 | self.count = new_count 215 | -------------------------------------------------------------------------------- /policies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/random-network-distillation/f75c0f1efa473d5109d487062fd8ed49ddce6634/policies/__init__.py -------------------------------------------------------------------------------- /policies/cnn_gru_policy_dynamics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from baselines import logger 4 | from utils import fc, conv 5 | from stochastic_policy import StochasticPolicy 6 | from tf_util import get_available_gpus 7 | from mpi_util import RunningMeanStd 8 | 9 | 10 | def to2d(x): 11 | size = 1 12 | for shapel in x.get_shape()[1:]: size *= shapel.value 13 | return tf.reshape(x, (-1, size)) 14 | 15 | 16 | 17 | class GRUCell(tf.nn.rnn_cell.RNNCell): 18 | """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" 19 | def __init__(self, num_units, rec_gate_init=-1.0): 20 | tf.nn.rnn_cell.RNNCell.__init__(self) 21 | self._num_units = num_units 22 | self.rec_gate_init = rec_gate_init 23 | @property 24 | def state_size(self): 25 | return self._num_units 26 | @property 27 | def output_size(self): 28 | return self._num_units 29 | def call(self, inputs, state): 30 | """Gated recurrent unit (GRU) with nunits cells.""" 31 | x, new = inputs 32 | h = state 33 | h *= (1.0 - new) 34 | hx = tf.concat([h, x], axis=1) 35 | mr = tf.sigmoid(fc(hx, nh=self._num_units * 2, scope='mr', init_bias=self.rec_gate_init)) 36 | # r: read strength. 
m: 'member strength 37 | m, r = tf.split(mr, 2, axis=1) 38 | rh_x = tf.concat([r * h, x], axis=1) 39 | htil = tf.tanh(fc(rh_x, nh=self._num_units, scope='htil')) 40 | h = m * h + (1.0 - m) * htil 41 | return h, h 42 | 43 | class CnnGruPolicy(StochasticPolicy): 44 | def __init__(self, scope, ob_space, ac_space, 45 | policy_size='normal', maxpool=False, extrahid=True, hidsize=128, memsize=128, rec_gate_init=0.0, 46 | update_ob_stats_independently_per_gpu=True, 47 | proportion_of_exp_used_for_predictor_update=1., 48 | dynamics_bonus = False, 49 | ): 50 | StochasticPolicy.__init__(self, scope, ob_space, ac_space) 51 | self.proportion_of_exp_used_for_predictor_update = proportion_of_exp_used_for_predictor_update 52 | enlargement = { 53 | 'small': 1, 54 | 'normal': 2, 55 | 'large': 4 56 | }[policy_size] 57 | rep_size = 512 58 | self.ph_mean = tf.placeholder(dtype=tf.float32, shape=list(ob_space.shape[:2])+[1], name="obmean") 59 | self.ph_std = tf.placeholder(dtype=tf.float32, shape=list(ob_space.shape[:2])+[1], name="obstd") 60 | memsize *= enlargement 61 | hidsize *= enlargement 62 | convfeat = 16*enlargement 63 | self.ob_rms = RunningMeanStd(shape=list(ob_space.shape[:2])+[1], use_mpi=not update_ob_stats_independently_per_gpu) 64 | ph_istate = tf.placeholder(dtype=tf.float32,shape=(None,memsize), name='state') 65 | pdparamsize = self.pdtype.param_shape()[0] 66 | self.memsize = memsize 67 | 68 | self.pdparam_opt, self.vpred_int_opt, self.vpred_ext_opt, self.snext_opt = \ 69 | self.apply_policy(self.ph_ob[None][:,:-1], 70 | ph_new=self.ph_new, 71 | ph_istate=ph_istate, 72 | reuse=False, 73 | scope=scope, 74 | hidsize=hidsize, 75 | memsize=memsize, 76 | extrahid=extrahid, 77 | sy_nenvs=self.sy_nenvs, 78 | sy_nsteps=self.sy_nsteps - 1, 79 | pdparamsize=pdparamsize, 80 | rec_gate_init=rec_gate_init 81 | ) 82 | self.pdparam_rollout, self.vpred_int_rollout, self.vpred_ext_rollout, self.snext_rollout = \ 83 | self.apply_policy(self.ph_ob[None], 84 | ph_new=self.ph_new, 85 | ph_istate=ph_istate, 86 | reuse=True, 87 | scope=scope, 88 | hidsize=hidsize, 89 | memsize=memsize, 90 | extrahid=extrahid, 91 | sy_nenvs=self.sy_nenvs, 92 | sy_nsteps=self.sy_nsteps, 93 | pdparamsize=pdparamsize, 94 | rec_gate_init=rec_gate_init 95 | ) 96 | if dynamics_bonus: 97 | self.define_dynamics_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement) 98 | else: 99 | self.define_self_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement) 100 | 101 | 102 | 103 | pd = self.pdtype.pdfromflat(self.pdparam_rollout) 104 | self.a_samp = pd.sample() 105 | self.nlp_samp = pd.neglogp(self.a_samp) 106 | self.entropy_rollout = pd.entropy() 107 | self.pd_rollout = pd 108 | 109 | self.pd_opt = self.pdtype.pdfromflat(self.pdparam_opt) 110 | 111 | self.ph_istate = ph_istate 112 | 113 | @staticmethod 114 | def apply_policy(ph_ob, ph_new, ph_istate, reuse, scope, hidsize, memsize, extrahid, sy_nenvs, sy_nsteps, pdparamsize, rec_gate_init): 115 | data_format = 'NHWC' 116 | ph = ph_ob 117 | assert len(ph.shape.as_list()) == 5 # B,T,H,W,C 118 | logger.info("CnnGruPolicy: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 119 | X = tf.cast(ph, tf.float32) / 255. 
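# Pixel observations are scaled to [0, 1]; the reshape below folds the (batch, time) axes together
# so the convolutional stack runs on a (B*T, H, W, C) tensor, and the GRU further down unfolds it
# back to (sy_nenvs, sy_nsteps, ...) for the recurrent pass.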
120 | X = tf.reshape(X, (-1, *ph.shape.as_list()[-3:])) 121 | 122 | activ = tf.nn.relu 123 | yes_gpu = any(get_available_gpus()) 124 | 125 | with tf.variable_scope(scope, reuse=reuse), tf.device('/gpu:0' if yes_gpu else '/cpu:0'): 126 | X = activ(conv(X, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), data_format=data_format)) 127 | X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format)) 128 | X = activ(conv(X, 'c3', nf=64, rf=4, stride=1, init_scale=np.sqrt(2), data_format=data_format)) 129 | X = to2d(X) 130 | X = activ(fc(X, 'fc1', nh=hidsize, init_scale=np.sqrt(2))) 131 | X = tf.reshape(X, [sy_nenvs, sy_nsteps, hidsize]) 132 | X, snext = tf.nn.dynamic_rnn( 133 | GRUCell(memsize, rec_gate_init=rec_gate_init), (X, ph_new[:,:,None]), 134 | dtype=tf.float32, time_major=False, initial_state=ph_istate) 135 | X = tf.reshape(X, (-1, memsize)) 136 | Xtout = X 137 | if extrahid: 138 | Xtout = X + activ(fc(Xtout, 'fc2val', nh=memsize, init_scale=0.1)) 139 | X = X + activ(fc(X, 'fc2act', nh=memsize, init_scale=0.1)) 140 | pdparam = fc(X, 'pd', nh=pdparamsize, init_scale=0.01) 141 | vpred_int = fc(Xtout, 'vf_int', nh=1, init_scale=0.01) 142 | vpred_ext = fc(Xtout, 'vf_ext', nh=1, init_scale=0.01) 143 | 144 | pdparam = tf.reshape(pdparam, (sy_nenvs, sy_nsteps, pdparamsize)) 145 | vpred_int = tf.reshape(vpred_int, (sy_nenvs, sy_nsteps)) 146 | vpred_ext = tf.reshape(vpred_ext, (sy_nenvs, sy_nsteps)) 147 | return pdparam, vpred_int, vpred_ext, snext 148 | 149 | def define_self_prediction_rew(self, convfeat, rep_size, enlargement): 150 | #RND. 151 | # Random target network. 152 | for ph in self.ph_ob.values(): 153 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C 154 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 155 | xr = ph[:,1:] 156 | xr = tf.cast(xr, tf.float32) 157 | xr = tf.reshape(xr, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:] 158 | xr = tf.clip_by_value((xr - self.ph_mean) / self.ph_std, -5.0, 5.0) 159 | 160 | xr = tf.nn.leaky_relu(conv(xr, 'c1r', nf=convfeat * 1, rf=8, stride=4, init_scale=np.sqrt(2))) 161 | xr = tf.nn.leaky_relu(conv(xr, 'c2r', nf=convfeat * 2 * 1, rf=4, stride=2, init_scale=np.sqrt(2))) 162 | xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=3, stride=1, init_scale=np.sqrt(2))) 163 | rgbr = [to2d(xr)] 164 | X_r = fc(rgbr[0], 'fc1r', nh=rep_size, init_scale=np.sqrt(2)) 165 | 166 | # Predictor network. 
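# The predictor below is trained (through aux_loss) to match the frozen random target embedding
# X_r computed above; its per-timestep squared error, taken a few lines down as int_rew, is the
# exploration bonus.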
167 | for ph in self.ph_ob.values(): 168 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C 169 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 170 | xrp = ph[:,1:] 171 | xrp = tf.cast(xrp, tf.float32) 172 | xrp = tf.reshape(xrp, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:] 173 | xrp = tf.clip_by_value((xrp - self.ph_mean) / self.ph_std, -5.0, 5.0) 174 | 175 | xrp = tf.nn.leaky_relu(conv(xrp, 'c1rp_pred', nf=convfeat, rf=8, stride=4, init_scale=np.sqrt(2))) 176 | xrp = tf.nn.leaky_relu(conv(xrp, 'c2rp_pred', nf=convfeat * 2, rf=4, stride=2, init_scale=np.sqrt(2))) 177 | xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=3, stride=1, init_scale=np.sqrt(2))) 178 | rgbrp = to2d(xrp) 179 | X_r_hat = tf.nn.relu(fc(rgbrp, 'fc1r_hat1_pred', nh=256 * enlargement, init_scale=np.sqrt(2))) 180 | X_r_hat = tf.nn.relu(fc(X_r_hat, 'fc1r_hat2_pred', nh=256 * enlargement, init_scale=np.sqrt(2))) 181 | X_r_hat = fc(X_r_hat, 'fc1r_hat3_pred', nh=rep_size, init_scale=np.sqrt(2)) 182 | 183 | self.feat_var = tf.reduce_mean(tf.nn.moments(X_r, axes=[0])[1]) 184 | self.max_feat = tf.reduce_max(tf.abs(X_r)) 185 | self.int_rew = tf.reduce_mean(tf.square(tf.stop_gradient(X_r) - X_r_hat), axis=-1, keep_dims=True) 186 | self.int_rew = tf.reshape(self.int_rew, (self.sy_nenvs, self.sy_nsteps - 1)) 187 | 188 | noisy_targets = tf.stop_gradient(X_r) 189 | self.aux_loss = tf.reduce_mean(tf.square(noisy_targets - X_r_hat), -1) 190 | mask = tf.random_uniform(shape=tf.shape(self.aux_loss), minval=0., maxval=1., dtype=tf.float32) 191 | mask = tf.cast(mask < self.proportion_of_exp_used_for_predictor_update, tf.float32) 192 | self.aux_loss = tf.reduce_sum(mask * self.aux_loss) / tf.maximum(tf.reduce_sum(mask), 1.) 193 | 194 | def define_dynamics_prediction_rew(self, convfeat, rep_size, enlargement): 195 | #Dynamics based bonus. 196 | 197 | # Random target network. 198 | for ph in self.ph_ob.values(): 199 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C 200 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 201 | xr = ph[:,1:] 202 | xr = tf.cast(xr, tf.float32) 203 | xr = tf.reshape(xr, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:] 204 | xr = tf.clip_by_value((xr - self.ph_mean) / self.ph_std, -5.0, 5.0) 205 | 206 | xr = tf.nn.leaky_relu(conv(xr, 'c1r', nf=convfeat * 1, rf=8, stride=4, init_scale=np.sqrt(2))) 207 | xr = tf.nn.leaky_relu(conv(xr, 'c2r', nf=convfeat * 2 * 1, rf=4, stride=2, init_scale=np.sqrt(2))) 208 | xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=3, stride=1, init_scale=np.sqrt(2))) 209 | rgbr = [to2d(xr)] 210 | X_r = fc(rgbr[0], 'fc1r', nh=rep_size, init_scale=np.sqrt(2)) 211 | 212 | # Predictor network. 213 | ac_one_hot = tf.one_hot(self.ph_ac, self.ac_space.n, axis=2) 214 | assert ac_one_hot.get_shape().ndims == 3 215 | assert ac_one_hot.get_shape().as_list() == [None, None, self.ac_space.n], ac_one_hot.get_shape().as_list() 216 | ac_one_hot = tf.reshape(ac_one_hot, (-1, self.ac_space.n)) 217 | def cond(x): 218 | return tf.concat([x, ac_one_hot], 1) 219 | 220 | for ph in self.ph_ob.values(): 221 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C 222 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 223 | xrp = ph[:,:-1] 224 | xrp = tf.cast(xrp, tf.float32) 225 | xrp = tf.reshape(xrp, (-1, *ph.shape.as_list()[-3:])) 226 | # ph_mean, ph_std are 84x84x1, so we subtract the average of the last channel from all channels. Is this ok? 
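# In other words: the running observation statistics are tracked for the most recent frame only
# (shape H x W x 1) and broadcast over all stacked frames here, because this dynamics variant
# feeds the whole frame stack to the predictor instead of just the last frame.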
227 | xrp = tf.clip_by_value((xrp - self.ph_mean) / self.ph_std, -5.0, 5.0) 228 | 229 | xrp = tf.nn.leaky_relu(conv(xrp, 'c1rp_pred', nf=convfeat, rf=8, stride=4, init_scale=np.sqrt(2))) 230 | xrp = tf.nn.leaky_relu(conv(xrp, 'c2rp_pred', nf=convfeat * 2, rf=4, stride=2, init_scale=np.sqrt(2))) 231 | xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=3, stride=1, init_scale=np.sqrt(2))) 232 | rgbrp = to2d(xrp) 233 | 234 | # X_r_hat = tf.nn.relu(fc(rgb[0], 'fc1r_hat1', nh=256 * enlargement, init_scale=np.sqrt(2))) 235 | X_r_hat = tf.nn.relu(fc(cond(rgbrp), 'fc1r_hat1_pred', nh=256 * enlargement, init_scale=np.sqrt(2))) 236 | X_r_hat = tf.nn.relu(fc(cond(X_r_hat), 'fc1r_hat2_pred', nh=256 * enlargement, init_scale=np.sqrt(2))) 237 | X_r_hat = fc(cond(X_r_hat), 'fc1r_hat3_pred', nh=rep_size, init_scale=np.sqrt(2)) 238 | 239 | self.feat_var = tf.reduce_mean(tf.nn.moments(X_r, axes=[0])[1]) 240 | self.max_feat = tf.reduce_max(tf.abs(X_r)) 241 | self.int_rew = tf.reduce_mean(tf.square(tf.stop_gradient(X_r) - X_r_hat), axis=-1, keep_dims=True) 242 | self.int_rew = tf.reshape(self.int_rew, (self.sy_nenvs, self.sy_nsteps - 1)) 243 | 244 | noisy_targets = tf.stop_gradient(X_r) 245 | self.aux_loss = tf.reduce_mean(tf.square(noisy_targets - X_r_hat), -1) 246 | mask = tf.random_uniform(shape=tf.shape(self.aux_loss), minval=0., maxval=1., dtype=tf.float32) 247 | mask = tf.cast(mask < self.proportion_of_exp_used_for_predictor_update, tf.float32) 248 | self.aux_loss = tf.reduce_sum(mask * self.aux_loss) / tf.maximum(tf.reduce_sum(mask), 1.) 249 | 250 | def initial_state(self, n): 251 | return np.zeros((n, self.memsize), np.float32) 252 | 253 | def call(self, dict_obs, new, istate, update_obs_stats=False): 254 | for ob in dict_obs.values(): 255 | if ob is not None: 256 | if update_obs_stats: 257 | raise NotImplementedError 258 | ob = ob.astype(np.float32) 259 | ob = ob.reshape(-1, *self.ob_space.shape) 260 | self.ob_rms.update(ob) 261 | # Note: if it fails here with ph vs observations inconsistency, check if you're loading agent from disk. 262 | # It will use whatever observation spaces saved to disk along with other ctor params. 
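# call() pushes a single environment step through the *_rollout graph: each observation gets a
# singleton time axis via [:, None], the running observation mean/std are fed through
# ph_mean/ph_std, and the per-env outputs are recovered with [:, 0] below.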
263 | feed1 = { self.ph_ob[k]: dict_obs[k][:,None] for k in self.ph_ob_keys } 264 | feed2 = { self.ph_istate: istate, self.ph_new: new[:,None].astype(np.float32) } 265 | feed1.update({self.ph_mean: self.ob_rms.mean, self.ph_std: self.ob_rms.var ** 0.5}) 266 | # for f in feed1: 267 | # print(f) 268 | a, vpred_int,vpred_ext, nlp, newstate, ent = tf.get_default_session().run( 269 | [self.a_samp, self.vpred_int_rollout,self.vpred_ext_rollout, self.nlp_samp, self.snext_rollout, self.entropy_rollout], 270 | feed_dict={**feed1, **feed2}) 271 | return a[:,0], vpred_int[:,0],vpred_ext[:,0], nlp[:,0], newstate, ent[:,0] 272 | -------------------------------------------------------------------------------- /policies/cnn_policy_param_matched.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from baselines import logger 4 | from utils import fc, conv, ortho_init 5 | from stochastic_policy import StochasticPolicy 6 | from tf_util import get_available_gpus 7 | from mpi_util import RunningMeanStd 8 | 9 | 10 | def to2d(x): 11 | size = 1 12 | for shapel in x.get_shape()[1:]: size *= shapel.value 13 | return tf.reshape(x, (-1, size)) 14 | 15 | def _fcnobias(x, scope, nh, *, init_scale=1.0): 16 | with tf.variable_scope(scope): 17 | nin = x.get_shape()[1].value 18 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) 19 | return tf.matmul(x, w) 20 | def _normalize(x): 21 | eps = 1e-5 22 | mean, var = tf.nn.moments(x, axes=(-1,), keepdims=True) 23 | return (x - mean) / tf.sqrt(var + eps) 24 | 25 | 26 | class CnnPolicy(StochasticPolicy): 27 | def __init__(self, scope, ob_space, ac_space, 28 | policy_size='normal', maxpool=False, extrahid=True, hidsize=128, memsize=128, rec_gate_init=0.0, 29 | update_ob_stats_independently_per_gpu=True, 30 | proportion_of_exp_used_for_predictor_update=1., 31 | dynamics_bonus = False, 32 | ): 33 | StochasticPolicy.__init__(self, scope, ob_space, ac_space) 34 | self.proportion_of_exp_used_for_predictor_update = proportion_of_exp_used_for_predictor_update 35 | enlargement = { 36 | 'small': 1, 37 | 'normal': 2, 38 | 'large': 4 39 | }[policy_size] 40 | rep_size = 512 41 | self.ph_mean = tf.placeholder(dtype=tf.float32, shape=list(ob_space.shape[:2])+[1], name="obmean") 42 | self.ph_std = tf.placeholder(dtype=tf.float32, shape=list(ob_space.shape[:2])+[1], name="obstd") 43 | memsize *= enlargement 44 | hidsize *= enlargement 45 | convfeat = 16*enlargement 46 | self.ob_rms = RunningMeanStd(shape=list(ob_space.shape[:2])+[1], use_mpi=not update_ob_stats_independently_per_gpu) 47 | ph_istate = tf.placeholder(dtype=tf.float32,shape=(None,memsize), name='state') 48 | pdparamsize = self.pdtype.param_shape()[0] 49 | self.memsize = memsize 50 | 51 | #Inputs to policy and value function will have different shapes depending on whether it is rollout 52 | #or optimization time, so we treat separately. 
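# Concretely: the *_opt heads below consume the stored observations minus the final bootstrap
# frame (sy_nsteps - 1 timesteps), while the *_rollout heads consume the full sequence; the second
# apply_policy call shares weights with the first via reuse=True.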
53 | self.pdparam_opt, self.vpred_int_opt, self.vpred_ext_opt, self.snext_opt = \ 54 | self.apply_policy(self.ph_ob[None][:,:-1], 55 | reuse=False, 56 | scope=scope, 57 | hidsize=hidsize, 58 | memsize=memsize, 59 | extrahid=extrahid, 60 | sy_nenvs=self.sy_nenvs, 61 | sy_nsteps=self.sy_nsteps - 1, 62 | pdparamsize=pdparamsize 63 | ) 64 | self.pdparam_rollout, self.vpred_int_rollout, self.vpred_ext_rollout, self.snext_rollout = \ 65 | self.apply_policy(self.ph_ob[None], 66 | reuse=True, 67 | scope=scope, 68 | hidsize=hidsize, 69 | memsize=memsize, 70 | extrahid=extrahid, 71 | sy_nenvs=self.sy_nenvs, 72 | sy_nsteps=self.sy_nsteps, 73 | pdparamsize=pdparamsize 74 | ) 75 | if dynamics_bonus: 76 | self.define_dynamics_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement) 77 | else: 78 | self.define_self_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement) 79 | 80 | pd = self.pdtype.pdfromflat(self.pdparam_rollout) 81 | self.a_samp = pd.sample() 82 | self.nlp_samp = pd.neglogp(self.a_samp) 83 | self.entropy_rollout = pd.entropy() 84 | self.pd_rollout = pd 85 | 86 | self.pd_opt = self.pdtype.pdfromflat(self.pdparam_opt) 87 | 88 | self.ph_istate = ph_istate 89 | 90 | @staticmethod 91 | def apply_policy(ph_ob, reuse, scope, hidsize, memsize, extrahid, sy_nenvs, sy_nsteps, pdparamsize): 92 | data_format = 'NHWC' 93 | ph = ph_ob 94 | assert len(ph.shape.as_list()) == 5 # B,T,H,W,C 95 | logger.info("CnnPolicy: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 96 | X = tf.cast(ph, tf.float32) / 255. 97 | X = tf.reshape(X, (-1, *ph.shape.as_list()[-3:])) 98 | 99 | activ = tf.nn.relu 100 | yes_gpu = any(get_available_gpus()) 101 | with tf.variable_scope(scope, reuse=reuse), tf.device('/gpu:0' if yes_gpu else '/cpu:0'): 102 | X = activ(conv(X, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), data_format=data_format)) 103 | X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format)) 104 | X = activ(conv(X, 'c3', nf=64, rf=4, stride=1, init_scale=np.sqrt(2), data_format=data_format)) 105 | X = to2d(X) 106 | mix_other_observations = [X] 107 | X = tf.concat(mix_other_observations, axis=1) 108 | X = activ(fc(X, 'fc1', nh=hidsize, init_scale=np.sqrt(2))) 109 | additional_size = 448 110 | X = activ(fc(X, 'fc_additional', nh=additional_size, init_scale=np.sqrt(2))) 111 | snext = tf.zeros((sy_nenvs, memsize)) 112 | mix_timeout = [X] 113 | 114 | Xtout = tf.concat(mix_timeout, axis=1) 115 | if extrahid: 116 | Xtout = X + activ(fc(Xtout, 'fc2val', nh=additional_size, init_scale=0.1)) 117 | X = X + activ(fc(X, 'fc2act', nh=additional_size, init_scale=0.1)) 118 | pdparam = fc(X, 'pd', nh=pdparamsize, init_scale=0.01) 119 | vpred_int = fc(Xtout, 'vf_int', nh=1, init_scale=0.01) 120 | vpred_ext = fc(Xtout, 'vf_ext', nh=1, init_scale=0.01) 121 | 122 | pdparam = tf.reshape(pdparam, (sy_nenvs, sy_nsteps, pdparamsize)) 123 | vpred_int = tf.reshape(vpred_int, (sy_nenvs, sy_nsteps)) 124 | vpred_ext = tf.reshape(vpred_ext, (sy_nenvs, sy_nsteps)) 125 | return pdparam, vpred_int, vpred_ext, snext 126 | 127 | def define_self_prediction_rew(self, convfeat, rep_size, enlargement): 128 | logger.info("Using RND BONUS ****************************************************") 129 | 130 | #RND bonus. 131 | 132 | # Random target network. 
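# The target network ('c1r'/'c2r'/'c3r'/'fc1r') is randomly initialized and effectively frozen:
# its output X_r is wrapped in tf.stop_gradient wherever it is used, so aux_loss never updates it.
# Only the most recent frame of the stack ([:, :, :, -1:]) is used, whitened with ph_mean/ph_std
# and clipped to [-5, 5].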
133 | for ph in self.ph_ob.values(): 134 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C 135 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 136 | xr = ph[:,1:] 137 | xr = tf.cast(xr, tf.float32) 138 | xr = tf.reshape(xr, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:] 139 | xr = tf.clip_by_value((xr - self.ph_mean) / self.ph_std, -5.0, 5.0) 140 | 141 | xr = tf.nn.leaky_relu(conv(xr, 'c1r', nf=convfeat * 1, rf=8, stride=4, init_scale=np.sqrt(2))) 142 | xr = tf.nn.leaky_relu(conv(xr, 'c2r', nf=convfeat * 2 * 1, rf=4, stride=2, init_scale=np.sqrt(2))) 143 | xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=3, stride=1, init_scale=np.sqrt(2))) 144 | rgbr = [to2d(xr)] 145 | X_r = fc(rgbr[0], 'fc1r', nh=rep_size, init_scale=np.sqrt(2)) 146 | 147 | # Predictor network. 148 | for ph in self.ph_ob.values(): 149 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C 150 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 151 | xrp = ph[:,1:] 152 | xrp = tf.cast(xrp, tf.float32) 153 | xrp = tf.reshape(xrp, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:] 154 | xrp = tf.clip_by_value((xrp - self.ph_mean) / self.ph_std, -5.0, 5.0) 155 | 156 | xrp = tf.nn.leaky_relu(conv(xrp, 'c1rp_pred', nf=convfeat, rf=8, stride=4, init_scale=np.sqrt(2))) 157 | xrp = tf.nn.leaky_relu(conv(xrp, 'c2rp_pred', nf=convfeat * 2, rf=4, stride=2, init_scale=np.sqrt(2))) 158 | xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=3, stride=1, init_scale=np.sqrt(2))) 159 | rgbrp = to2d(xrp) 160 | # X_r_hat = tf.nn.relu(fc(rgb[0], 'fc1r_hat1', nh=256 * enlargement, init_scale=np.sqrt(2))) 161 | X_r_hat = tf.nn.relu(fc(rgbrp, 'fc1r_hat1_pred', nh=256 * enlargement, init_scale=np.sqrt(2))) 162 | X_r_hat = tf.nn.relu(fc(X_r_hat, 'fc1r_hat2_pred', nh=256 * enlargement, init_scale=np.sqrt(2))) 163 | X_r_hat = fc(X_r_hat, 'fc1r_hat3_pred', nh=rep_size, init_scale=np.sqrt(2)) 164 | 165 | self.feat_var = tf.reduce_mean(tf.nn.moments(X_r, axes=[0])[1]) 166 | self.max_feat = tf.reduce_max(tf.abs(X_r)) 167 | self.int_rew = tf.reduce_mean(tf.square(tf.stop_gradient(X_r) - X_r_hat), axis=-1, keep_dims=True) 168 | self.int_rew = tf.reshape(self.int_rew, (self.sy_nenvs, self.sy_nsteps - 1)) 169 | 170 | targets = tf.stop_gradient(X_r) 171 | # self.aux_loss = tf.reduce_mean(tf.square(noisy_targets-X_r_hat)) 172 | self.aux_loss = tf.reduce_mean(tf.square(targets - X_r_hat), -1) 173 | mask = tf.random_uniform(shape=tf.shape(self.aux_loss), minval=0., maxval=1., dtype=tf.float32) 174 | mask = tf.cast(mask < self.proportion_of_exp_used_for_predictor_update, tf.float32) 175 | self.aux_loss = tf.reduce_sum(mask * self.aux_loss) / tf.maximum(tf.reduce_sum(mask), 1.) 176 | 177 | def define_dynamics_prediction_rew(self, convfeat, rep_size, enlargement): 178 | #Dynamics loss with random features. 179 | 180 | # Random target network. 
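# In this variant (enabled with dynamics_bonus=True) the frozen random network still embeds the
# next observation ph[:, 1:], but the predictor below sees the previous observation ph[:, :-1]
# concatenated with a one-hot of the action via cond(), so the intrinsic reward becomes a
# forward-dynamics prediction error measured in random-feature space.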
181 | for ph in self.ph_ob.values(): 182 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C 183 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 184 | xr = ph[:,1:] 185 | xr = tf.cast(xr, tf.float32) 186 | xr = tf.reshape(xr, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:] 187 | xr = tf.clip_by_value((xr - self.ph_mean) / self.ph_std, -5.0, 5.0) 188 | 189 | xr = tf.nn.leaky_relu(conv(xr, 'c1r', nf=convfeat * 1, rf=8, stride=4, init_scale=np.sqrt(2))) 190 | xr = tf.nn.leaky_relu(conv(xr, 'c2r', nf=convfeat * 2 * 1, rf=4, stride=2, init_scale=np.sqrt(2))) 191 | xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=3, stride=1, init_scale=np.sqrt(2))) 192 | rgbr = [to2d(xr)] 193 | X_r = fc(rgbr[0], 'fc1r', nh=rep_size, init_scale=np.sqrt(2)) 194 | 195 | # Predictor network. 196 | ac_one_hot = tf.one_hot(self.ph_ac, self.ac_space.n, axis=2) 197 | assert ac_one_hot.get_shape().ndims == 3 198 | assert ac_one_hot.get_shape().as_list() == [None, None, self.ac_space.n], ac_one_hot.get_shape().as_list() 199 | ac_one_hot = tf.reshape(ac_one_hot, (-1, self.ac_space.n)) 200 | def cond(x): 201 | return tf.concat([x, ac_one_hot], 1) 202 | 203 | for ph in self.ph_ob.values(): 204 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C 205 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape))) 206 | xrp = ph[:,:-1] 207 | xrp = tf.cast(xrp, tf.float32) 208 | xrp = tf.reshape(xrp, (-1, *ph.shape.as_list()[-3:])) 209 | # ph_mean, ph_std are 84x84x1, so we subtract the average of the last channel from all channels. Is this ok? 210 | xrp = tf.clip_by_value((xrp - self.ph_mean) / self.ph_std, -5.0, 5.0) 211 | 212 | xrp = tf.nn.leaky_relu(conv(xrp, 'c1rp_pred', nf=convfeat, rf=8, stride=4, init_scale=np.sqrt(2))) 213 | xrp = tf.nn.leaky_relu(conv(xrp, 'c2rp_pred', nf=convfeat * 2, rf=4, stride=2, init_scale=np.sqrt(2))) 214 | xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=3, stride=1, init_scale=np.sqrt(2))) 215 | rgbrp = to2d(xrp) 216 | 217 | # X_r_hat = tf.nn.relu(fc(rgb[0], 'fc1r_hat1', nh=256 * enlargement, init_scale=np.sqrt(2))) 218 | X_r_hat = tf.nn.relu(fc(cond(rgbrp), 'fc1r_hat1_pred', nh=256 * enlargement, init_scale=np.sqrt(2))) 219 | X_r_hat = tf.nn.relu(fc(cond(X_r_hat), 'fc1r_hat2_pred', nh=256 * enlargement, init_scale=np.sqrt(2))) 220 | X_r_hat = fc(cond(X_r_hat), 'fc1r_hat3_pred', nh=rep_size, init_scale=np.sqrt(2)) 221 | 222 | self.feat_var = tf.reduce_mean(tf.nn.moments(X_r, axes=[0])[1]) 223 | self.max_feat = tf.reduce_max(tf.abs(X_r)) 224 | self.int_rew = tf.reduce_mean(tf.square(tf.stop_gradient(X_r) - X_r_hat), axis=-1, keep_dims=True) 225 | self.int_rew = tf.reshape(self.int_rew, (self.sy_nenvs, self.sy_nsteps - 1)) 226 | 227 | noisy_targets = tf.stop_gradient(X_r) 228 | # self.aux_loss = tf.reduce_mean(tf.square(noisy_targets-X_r_hat)) 229 | self.aux_loss = tf.reduce_mean(tf.square(noisy_targets - X_r_hat), -1) 230 | mask = tf.random_uniform(shape=tf.shape(self.aux_loss), minval=0., maxval=1., dtype=tf.float32) 231 | mask = tf.cast(mask < self.proportion_of_exp_used_for_predictor_update, tf.float32) 232 | self.aux_loss = tf.reduce_sum(mask * self.aux_loss) / tf.maximum(tf.reduce_sum(mask), 1.) 
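# A minimal NumPy sketch of the loss masking used in both bonus definitions above, assuming
# illustrative names (errors, keep_prob) that are not part of this file:
#
#     errors = np.random.rand(8)                                  # per-sample predictor errors
#     keep = (np.random.rand(8) < keep_prob).astype(np.float32)   # keep ~keep_prob of the samples
#     loss = (keep * errors).sum() / max(keep.sum(), 1.0)         # mean over kept samples only
#
# Dropping a fraction of samples from the predictor update is meant to keep the predictor's
# effective batch size roughly constant as the number of parallel environments grows, so the
# intrinsic reward does not collapse too quickly.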
233 | 234 | def initial_state(self, n): 235 | return np.zeros((n, self.memsize), np.float32) 236 | 237 | def call(self, dict_obs, new, istate, update_obs_stats=False): 238 | for ob in dict_obs.values(): 239 | if ob is not None: 240 | if update_obs_stats: 241 | raise NotImplementedError 242 | ob = ob.astype(np.float32) 243 | ob = ob.reshape(-1, *self.ob_space.shape) 244 | self.ob_rms.update(ob) 245 | # Note: if it fails here with ph vs observations inconsistency, check if you're loading agent from disk. 246 | # It will use whatever observation spaces saved to disk along with other ctor params. 247 | feed1 = { self.ph_ob[k]: dict_obs[k][:,None] for k in self.ph_ob_keys } 248 | feed2 = { self.ph_istate: istate, self.ph_new: new[:,None].astype(np.float32) } 249 | feed1.update({self.ph_mean: self.ob_rms.mean, self.ph_std: self.ob_rms.var ** 0.5}) 250 | # for f in feed1: 251 | # print(f) 252 | a, vpred_int,vpred_ext, nlp, newstate, ent = tf.get_default_session().run( 253 | [self.a_samp, self.vpred_int_rollout,self.vpred_ext_rollout, self.nlp_samp, self.snext_rollout, self.entropy_rollout], 254 | feed_dict={**feed1, **feed2}) 255 | return a[:,0], vpred_int[:,0],vpred_ext[:,0], nlp[:,0], newstate, ent[:,0] 256 | -------------------------------------------------------------------------------- /ppo_agent.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import deque, defaultdict 3 | from copy import copy 4 | 5 | import numpy as np 6 | import psutil 7 | import tensorflow as tf 8 | from mpi4py import MPI 9 | from baselines import logger 10 | import tf_util 11 | from recorder import Recorder 12 | from utils import explained_variance 13 | from console_util import fmt_row 14 | from mpi_util import MpiAdamOptimizer, RunningMeanStd, sync_from_root 15 | 16 | NO_STATES = ['NO_STATES'] 17 | 18 | class SemicolonList(list): 19 | def __str__(self): 20 | return '['+';'.join([str(x) for x in self])+']' 21 | 22 | class InteractionState(object): 23 | """ 24 | Parts of the PPOAgent's state that are based on interaction with a single batch of envs 25 | """ 26 | def __init__(self, ob_space, ac_space, nsteps, gamma, venvs, stochpol, comm): 27 | self.lump_stride = venvs[0].num_envs 28 | self.venvs = venvs 29 | assert all(venv.num_envs == self.lump_stride for venv in self.venvs[1:]), 'All venvs should have the same num_envs' 30 | self.nlump = len(venvs) 31 | nenvs = self.nenvs = self.nlump * self.lump_stride 32 | self.reset_counter = 0 33 | self.env_results = [None] * self.nlump 34 | self.buf_vpreds_int = np.zeros((nenvs, nsteps), np.float32) 35 | self.buf_vpreds_ext = np.zeros((nenvs, nsteps), np.float32) 36 | self.buf_nlps = np.zeros((nenvs, nsteps), np.float32) 37 | self.buf_advs = np.zeros((nenvs, nsteps), np.float32) 38 | self.buf_advs_int = np.zeros((nenvs, nsteps), np.float32) 39 | self.buf_advs_ext = np.zeros((nenvs, nsteps), np.float32) 40 | self.buf_rews_int = np.zeros((nenvs, nsteps), np.float32) 41 | self.buf_rews_ext = np.zeros((nenvs, nsteps), np.float32) 42 | self.buf_acs = np.zeros((nenvs, nsteps, *ac_space.shape), ac_space.dtype) 43 | self.buf_obs = { k: np.zeros( 44 | [nenvs, nsteps] + stochpol.ph_ob[k].shape.as_list()[2:], 45 | dtype=stochpol.ph_ob_dtypes[k]) 46 | for k in stochpol.ph_ob_keys } 47 | self.buf_ob_last = { k: self.buf_obs[k][:, 0, ...].copy() for k in stochpol.ph_ob_keys } 48 | self.buf_epinfos = [{} for _ in range(self.nenvs)] 49 | self.buf_news = np.zeros((nenvs, nsteps), np.float32) 50 | self.buf_ent = 
np.zeros((nenvs, nsteps), np.float32) 51 | self.mem_state = stochpol.initial_state(nenvs) 52 | self.seg_init_mem_state = copy(self.mem_state) # Memory state at beginning of segment of timesteps 53 | self.rff_int = RewardForwardFilter(gamma) 54 | self.rff_rms_int = RunningMeanStd(comm=comm, use_mpi=True) 55 | self.buf_new_last = self.buf_news[:, 0, ...].copy() 56 | self.buf_vpred_int_last = self.buf_vpreds_int[:, 0, ...].copy() 57 | self.buf_vpred_ext_last = self.buf_vpreds_ext[:, 0, ...].copy() 58 | self.step_count = 0 # counts number of timesteps that you've interacted with this set of environments 59 | self.t_last_update = time.time() 60 | self.statlists = defaultdict(lambda : deque([], maxlen=100)) # Count other stats, e.g. optimizer outputs 61 | self.stats = defaultdict(float) # Count episodes and timesteps 62 | self.stats['epcount'] = 0 63 | self.stats['n_updates'] = 0 64 | self.stats['tcount'] = 0 65 | 66 | def close(self): 67 | for venv in self.venvs: 68 | venv.close() 69 | 70 | def dict_gather(comm, d, op='mean'): 71 | if comm is None: return d 72 | alldicts = comm.allgather(d) 73 | size = comm.Get_size() 74 | k2li = defaultdict(list) 75 | for d in alldicts: 76 | for (k,v) in d.items(): 77 | k2li[k].append(v) 78 | result = {} 79 | for (k,li) in k2li.items(): 80 | if op=='mean': 81 | result[k] = np.mean(li, axis=0) 82 | elif op=='sum': 83 | result[k] = np.sum(li, axis=0) 84 | elif op=="max": 85 | result[k] = np.max(li, axis=0) 86 | else: 87 | assert 0, op 88 | return result 89 | 90 | 91 | class PpoAgent(object): 92 | envs = None 93 | def __init__(self, *, scope, 94 | ob_space, ac_space, 95 | stochpol_fn, 96 | nsteps, nepochs=4, nminibatches=1, 97 | gamma=0.99, 98 | gamma_ext=0.99, 99 | lam=0.95, 100 | ent_coef=0, 101 | cliprange=0.2, 102 | max_grad_norm=1.0, 103 | vf_coef=1.0, 104 | lr=30e-5, 105 | adam_hps=None, 106 | testing=False, 107 | comm=None, comm_train=None, use_news=False, 108 | update_ob_stats_every_step=True, 109 | int_coeff=None, 110 | ext_coeff=None, 111 | ): 112 | self.lr = lr 113 | self.ext_coeff = ext_coeff 114 | self.int_coeff = int_coeff 115 | self.use_news = use_news 116 | self.update_ob_stats_every_step = update_ob_stats_every_step 117 | self.abs_scope = (tf.get_variable_scope().name + '/' + scope).lstrip('/') 118 | self.testing = testing 119 | self.comm_log = MPI.COMM_SELF 120 | if comm is not None and comm.Get_size() > 1: 121 | self.comm_log = comm 122 | assert not testing or comm.Get_rank() != 0, "Worker number zero can't be testing" 123 | if comm_train is not None: 124 | self.comm_train, self.comm_train_size = comm_train, comm_train.Get_size() 125 | else: 126 | self.comm_train, self.comm_train_size = self.comm_log, self.comm_log.Get_size() 127 | self.is_log_leader = self.comm_log.Get_rank()==0 128 | self.is_train_leader = self.comm_train.Get_rank()==0 129 | with tf.variable_scope(scope): 130 | self.best_ret = -np.inf 131 | self.local_best_ret = - np.inf 132 | self.rooms = [] 133 | self.local_rooms = [] 134 | self.scores = [] 135 | self.ob_space = ob_space 136 | self.ac_space = ac_space 137 | self.stochpol = stochpol_fn() 138 | self.nepochs = nepochs 139 | self.cliprange = cliprange 140 | self.nsteps = nsteps 141 | self.nminibatches = nminibatches 142 | self.gamma = gamma 143 | self.gamma_ext = gamma_ext 144 | self.lam = lam 145 | self.adam_hps = adam_hps or dict() 146 | self.ph_adv = tf.placeholder(tf.float32, [None, None]) 147 | self.ph_ret_int = tf.placeholder(tf.float32, [None, None]) 148 | self.ph_ret_ext = tf.placeholder(tf.float32, [None, None]) 149 | 
self.ph_oldnlp = tf.placeholder(tf.float32, [None, None]) 150 | self.ph_oldvpred = tf.placeholder(tf.float32, [None, None]) 151 | self.ph_lr = tf.placeholder(tf.float32, []) 152 | self.ph_lr_pred = tf.placeholder(tf.float32, []) 153 | self.ph_cliprange = tf.placeholder(tf.float32, []) 154 | 155 | #Define loss. 156 | neglogpac = self.stochpol.pd_opt.neglogp(self.stochpol.ph_ac) 157 | entropy = tf.reduce_mean(self.stochpol.pd_opt.entropy()) 158 | vf_loss_int = (0.5 * vf_coef) * tf.reduce_mean(tf.square(self.stochpol.vpred_int_opt - self.ph_ret_int)) 159 | vf_loss_ext = (0.5 * vf_coef) * tf.reduce_mean(tf.square(self.stochpol.vpred_ext_opt - self.ph_ret_ext)) 160 | vf_loss = vf_loss_int + vf_loss_ext 161 | ratio = tf.exp(self.ph_oldnlp - neglogpac) # p_new / p_old 162 | negadv = - self.ph_adv 163 | pg_losses1 = negadv * ratio 164 | pg_losses2 = negadv * tf.clip_by_value(ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange) 165 | pg_loss = tf.reduce_mean(tf.maximum(pg_losses1, pg_losses2)) 166 | ent_loss = (- ent_coef) * entropy 167 | approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.ph_oldnlp)) 168 | maxkl = .5 * tf.reduce_max(tf.square(neglogpac - self.ph_oldnlp)) 169 | clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), self.ph_cliprange))) 170 | loss = pg_loss + ent_loss + vf_loss + self.stochpol.aux_loss 171 | 172 | #Create optimizer. 173 | params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.abs_scope) 174 | logger.info("PPO: using MpiAdamOptimizer connected to %i peers" % self.comm_train_size) 175 | trainer = MpiAdamOptimizer(self.comm_train, learning_rate=self.ph_lr, **self.adam_hps) 176 | grads_and_vars = trainer.compute_gradients(loss, params) 177 | grads, vars = zip(*grads_and_vars) 178 | if max_grad_norm: 179 | _, _grad_norm = tf.clip_by_global_norm(grads, max_grad_norm) 180 | global_grad_norm = tf.global_norm(grads) 181 | grads_and_vars = list(zip(grads, vars)) 182 | self._train = trainer.apply_gradients(grads_and_vars) 183 | 184 | #Quantities for reporting. 185 | self._losses = [loss, pg_loss, vf_loss, entropy, clipfrac, approxkl, maxkl, self.stochpol.aux_loss, 186 | self.stochpol.feat_var, self.stochpol.max_feat, global_grad_norm] 187 | self.loss_names = ['tot', 'pg', 'vf', 'ent', 'clipfrac', 'approxkl', 'maxkl', "auxloss", "featvar", 188 | "maxfeat", "gradnorm"] 189 | self.I = None 190 | self.disable_policy_update = None 191 | allvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.abs_scope) 192 | if self.is_log_leader: 193 | tf_util.display_var_info(allvars) 194 | tf.get_default_session().run(tf.variables_initializer(allvars)) 195 | sync_from_root(tf.get_default_session(), allvars) #Syncs initialization across mpi workers. 196 | self.t0 = time.time() 197 | self.global_tcount = 0 198 | 199 | def start_interaction(self, venvs, disable_policy_update=False): 200 | self.I = InteractionState(ob_space=self.ob_space, ac_space=self.ac_space, 201 | nsteps=self.nsteps, gamma=self.gamma, 202 | venvs=venvs, stochpol=self.stochpol, comm=self.comm_train) 203 | self.disable_policy_update = disable_policy_update 204 | self.recorder = Recorder(nenvs=self.I.nenvs, score_multiple=venvs[0].score_multiple) 205 | 206 | def collect_random_statistics(self, num_timesteps): 207 | #Initializes observation normalization with data from random agent. 
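        # Random actions are used because the RND heads whiten their inputs with ph_mean/ph_std
        # and clip to [-5, 5], so ob_rms needs sensible statistics before the first policy update.
        # Frames are buffered below and, every 128 * nlump steps, the running mean/std is updated
        # on the newest channel of the frame stack.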
208 | all_ob = [] 209 | for lump in range(self.I.nlump): 210 | all_ob.append(self.I.venvs[lump].reset()) 211 | for step in range(num_timesteps): 212 | for lump in range(self.I.nlump): 213 | acs = np.random.randint(low=0, high=self.ac_space.n, size=(self.I.lump_stride,)) 214 | self.I.venvs[lump].step_async(acs) 215 | ob, _, _, _ = self.I.venvs[lump].step_wait() 216 | all_ob.append(ob) 217 | if len(all_ob) % (128 * self.I.nlump) == 0: 218 | ob_ = np.asarray(all_ob).astype(np.float32).reshape((-1, *self.ob_space.shape)) 219 | self.stochpol.ob_rms.update(ob_[:,:,:,-1:]) 220 | all_ob.clear() 221 | 222 | def stop_interaction(self): 223 | self.I.close() 224 | self.I = None 225 | 226 | @logger.profile("update") 227 | def update(self): 228 | 229 | #Some logic gathering best ret, rooms etc using MPI. 230 | temp = sum(MPI.COMM_WORLD.allgather(self.local_rooms), []) 231 | temp = sorted(list(set(temp))) 232 | self.rooms = temp 233 | 234 | temp = sum(MPI.COMM_WORLD.allgather(self.scores), []) 235 | temp = sorted(list(set(temp))) 236 | self.scores = temp 237 | 238 | temp = sum(MPI.COMM_WORLD.allgather([self.local_best_ret]), []) 239 | self.best_ret = max(temp) 240 | 241 | eprews = MPI.COMM_WORLD.allgather(np.mean(list(self.I.statlists["eprew"]))) 242 | local_best_rets = MPI.COMM_WORLD.allgather(self.local_best_ret) 243 | n_rooms = sum(MPI.COMM_WORLD.allgather([len(self.local_rooms)]), []) 244 | 245 | if MPI.COMM_WORLD.Get_rank() == 0: 246 | logger.info(f"Rooms visited {self.rooms}") 247 | logger.info(f"Best return {self.best_ret}") 248 | logger.info(f"Best local return {sorted(local_best_rets)}") 249 | logger.info(f"eprews {sorted(eprews)}") 250 | logger.info(f"n_rooms {sorted(n_rooms)}") 251 | logger.info(f"Extrinsic coefficient {self.ext_coeff}") 252 | logger.info(f"Gamma {self.gamma}") 253 | logger.info(f"Gamma ext {self.gamma_ext}") 254 | logger.info(f"All scores {sorted(self.scores)}") 255 | 256 | 257 | #Normalize intrinsic rewards. 258 | rffs_int = np.array([self.I.rff_int.update(rew) for rew in self.I.buf_rews_int.T]) 259 | self.I.rff_rms_int.update(rffs_int.ravel()) 260 | rews_int = self.I.buf_rews_int / np.sqrt(self.I.rff_rms_int.var) 261 | self.mean_int_rew = np.mean(rews_int) 262 | self.max_int_rew = np.max(rews_int) 263 | 264 | #Don't normalize extrinsic rewards. 265 | rews_ext = self.I.buf_rews_ext 266 | 267 | rewmean, rewstd, rewmax = self.I.buf_rews_int.mean(), self.I.buf_rews_int.std(), np.max(self.I.buf_rews_int) 268 | 269 | #Calculate intrinsic returns and advantages. 270 | lastgaelam = 0 271 | for t in range(self.nsteps-1, -1, -1): # nsteps-2 ... 0 272 | if self.use_news: 273 | nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last 274 | else: 275 | nextnew = 0.0 #No dones for intrinsic reward. 276 | nextvals = self.I.buf_vpreds_int[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_int_last 277 | nextnotnew = 1 - nextnew 278 | delta = rews_int[:, t] + self.gamma * nextvals * nextnotnew - self.I.buf_vpreds_int[:, t] 279 | self.I.buf_advs_int[:, t] = lastgaelam = delta + self.gamma * self.lam * nextnotnew * lastgaelam 280 | rets_int = self.I.buf_advs_int + self.I.buf_vpreds_int 281 | 282 | #Calculate extrinsic returns and advantages. 283 | lastgaelam = 0 284 | for t in range(self.nsteps-1, -1, -1): # nsteps-2 ... 0 285 | nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last 286 | #Use dones for extrinsic reward. 
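            #Both streams use the same GAE recursion (gamma for the intrinsic head above,
            #gamma_ext for the extrinsic head here):
            #  delta_t = r_t + gamma * V(s_{t+1}) * (1 - new_{t+1}) - V(s_t)
            #  A_t     = delta_t + gamma * lam * (1 - new_{t+1}) * A_{t+1}
            #The intrinsic stream ignores episode ends unless use_news is set; the extrinsic
            #stream always uses them.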
287 | nextvals = self.I.buf_vpreds_ext[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_ext_last 288 | nextnotnew = 1 - nextnew 289 | delta = rews_ext[:, t] + self.gamma_ext * nextvals * nextnotnew - self.I.buf_vpreds_ext[:, t] 290 | self.I.buf_advs_ext[:, t] = lastgaelam = delta + self.gamma_ext * self.lam * nextnotnew * lastgaelam 291 | rets_ext = self.I.buf_advs_ext + self.I.buf_vpreds_ext 292 | 293 | #Combine the extrinsic and intrinsic advantages. 294 | self.I.buf_advs = self.int_coeff*self.I.buf_advs_int + self.ext_coeff*self.I.buf_advs_ext 295 | 296 | #Collects info for reporting. 297 | info = dict( 298 | advmean = self.I.buf_advs.mean(), 299 | advstd = self.I.buf_advs.std(), 300 | retintmean = rets_int.mean(), # previously retmean 301 | retintstd = rets_int.std(), # previously retstd 302 | retextmean = rets_ext.mean(), # previously not there 303 | retextstd = rets_ext.std(), # previously not there 304 | rewintmean_unnorm = rewmean, # previously rewmean 305 | rewintmax_unnorm = rewmax, # previously not there 306 | rewintmean_norm = self.mean_int_rew, # previously rewintmean 307 | rewintmax_norm = self.max_int_rew, # previously rewintmax 308 | rewintstd_unnorm = rewstd, # previously rewstd 309 | vpredintmean = self.I.buf_vpreds_int.mean(), # previously vpredmean 310 | vpredintstd = self.I.buf_vpreds_int.std(), # previously vrpedstd 311 | vpredextmean = self.I.buf_vpreds_ext.mean(), # previously not there 312 | vpredextstd = self.I.buf_vpreds_ext.std(), # previously not there 313 | ev_int = np.clip(explained_variance(self.I.buf_vpreds_int.ravel(), rets_int.ravel()), -1, None), 314 | ev_ext = np.clip(explained_variance(self.I.buf_vpreds_ext.ravel(), rets_ext.ravel()), -1, None), 315 | rooms = SemicolonList(self.rooms), 316 | n_rooms = len(self.rooms), 317 | best_ret = self.best_ret, 318 | reset_counter = self.I.reset_counter 319 | ) 320 | 321 | info[f'mem_available'] = psutil.virtual_memory().available 322 | 323 | to_record = {'acs': self.I.buf_acs, 324 | 'rews_int': self.I.buf_rews_int, 325 | 'rews_int_norm': rews_int, 326 | 'rews_ext': self.I.buf_rews_ext, 327 | 'vpred_int': self.I.buf_vpreds_int, 328 | 'vpred_ext': self.I.buf_vpreds_ext, 329 | 'adv_int': self.I.buf_advs_int, 330 | 'adv_ext': self.I.buf_advs_ext, 331 | 'ent': self.I.buf_ent, 332 | 'ret_int': rets_int, 333 | 'ret_ext': rets_ext, 334 | } 335 | if self.I.venvs[0].record_obs: 336 | to_record['obs'] = self.I.buf_obs[None] 337 | self.recorder.record(bufs=to_record, 338 | infos=self.I.buf_epinfos) 339 | 340 | 341 | #Create feeddict for optimization. 342 | envsperbatch = self.I.nenvs // self.nminibatches 343 | ph_buf = [ 344 | (self.stochpol.ph_ac, self.I.buf_acs), 345 | (self.ph_ret_int, rets_int), 346 | (self.ph_ret_ext, rets_ext), 347 | (self.ph_oldnlp, self.I.buf_nlps), 348 | (self.ph_adv, self.I.buf_advs), 349 | ] 350 | if self.I.mem_state is not NO_STATES: 351 | ph_buf.extend([ 352 | (self.stochpol.ph_istate, self.I.seg_init_mem_state), 353 | (self.stochpol.ph_new, self.I.buf_news), 354 | ]) 355 | 356 | verbose = True 357 | if verbose and self.is_log_leader: 358 | samples = np.prod(self.I.buf_advs.shape) 359 | logger.info("buffer shape %s, samples_per_mpi=%i, mini_per_mpi=%i, samples=%i, mini=%i " % ( 360 | str(self.I.buf_advs.shape), 361 | samples, samples//self.nminibatches, 362 | samples*self.comm_train_size, samples*self.comm_train_size//self.nminibatches)) 363 | logger.info(" "*6 + fmt_row(13, self.loss_names)) 364 | 365 | 366 | epoch = 0 367 | start = 0 368 | #Optimizes on current data for several epochs. 
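        #One epoch below is a full pass over all environments: each minibatch is a contiguous
        #slice of envsperbatch environments (with all of their timesteps), and MpiAdamOptimizer
        #averages gradients across MPI workers so every process applies the same update.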
369 | while epoch < self.nepochs: 370 | end = start + envsperbatch 371 | mbenvinds = slice(start, end, None) 372 | 373 | fd = {ph : buf[mbenvinds] for (ph, buf) in ph_buf} 374 | fd.update({self.ph_lr : self.lr, self.ph_cliprange : self.cliprange}) 375 | fd[self.stochpol.ph_ob[None]] = np.concatenate([self.I.buf_obs[None][mbenvinds], self.I.buf_ob_last[None][mbenvinds, None]], 1) 376 | assert list(fd[self.stochpol.ph_ob[None]].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape), \ 377 | [fd[self.stochpol.ph_ob[None]].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape)] 378 | fd.update({self.stochpol.ph_mean:self.stochpol.ob_rms.mean, self.stochpol.ph_std:self.stochpol.ob_rms.var**0.5}) 379 | 380 | ret = tf.get_default_session().run(self._losses+[self._train], feed_dict=fd)[:-1] 381 | if not self.testing: 382 | lossdict = dict(zip([n for n in self.loss_names], ret), axis=0) 383 | else: 384 | lossdict = {} 385 | #Synchronize the lossdict across mpi processes, otherwise weights may be rolled back on one process but not another. 386 | _maxkl = lossdict.pop('maxkl') 387 | lossdict = dict_gather(self.comm_train, lossdict, op='mean') 388 | maxmaxkl = dict_gather(self.comm_train, {"maxkl":_maxkl}, op='max') 389 | lossdict["maxkl"] = maxmaxkl["maxkl"] 390 | if verbose and self.is_log_leader: 391 | logger.info("%i:%03i %s" % (epoch, start, fmt_row(13, [lossdict[n] for n in self.loss_names]))) 392 | start += envsperbatch 393 | if start == self.I.nenvs: 394 | epoch += 1 395 | start = 0 396 | 397 | if self.is_train_leader: 398 | self.I.stats["n_updates"] += 1 399 | info.update([('opt_'+n, lossdict[n]) for n in self.loss_names]) 400 | tnow = time.time() 401 | info['tps'] = self.nsteps * self.I.nenvs / (tnow - self.I.t_last_update) 402 | info['time_elapsed'] = time.time() - self.t0 403 | self.I.t_last_update = tnow 404 | self.stochpol.update_normalization( # Necessary for continuous control tasks with odd obs ranges, only implemented in mlp policy, 405 | ob=self.I.buf_obs # NOTE: not shared via MPI 406 | ) 407 | return info 408 | 409 | def env_step(self, l, acs): 410 | self.I.venvs[l].step_async(acs) 411 | self.I.env_results[l] = None 412 | 413 | def env_get(self, l): 414 | """ 415 | Get most recent (obs, rews, dones, infos) from vectorized environment 416 | Using step_wait if necessary 417 | """ 418 | if self.I.step_count == 0: # On the zeroth step with a new venv, we need to call reset on the environment 419 | ob = self.I.venvs[l].reset() 420 | out = self.I.env_results[l] = (ob, None, np.ones(self.I.lump_stride, bool), {}) 421 | else: 422 | if self.I.env_results[l] is None: 423 | out = self.I.env_results[l] = self.I.venvs[l].step_wait() 424 | else: 425 | out = self.I.env_results[l] 426 | return out 427 | 428 | @logger.profile("step") 429 | def step(self): 430 | #Does a rollout. 431 | t = self.I.step_count % self.nsteps 432 | epinfos = [] 433 | for l in range(self.I.nlump): 434 | obs, prevrews, news, infos = self.env_get(l) 435 | for env_pos_in_lump, info in enumerate(infos): 436 | if 'episode' in info: 437 | #Information like rooms visited is added to info on end of episode. 
438 | epinfos.append(info['episode']) 439 | info_with_places = info['episode'] 440 | try: 441 | info_with_places['places'] = info['episode']['visited_rooms'] 442 | except: 443 | import ipdb; ipdb.set_trace() 444 | self.I.buf_epinfos[env_pos_in_lump+l*self.I.lump_stride][t] = info_with_places 445 | 446 | sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride) 447 | memsli = slice(None) if self.I.mem_state is NO_STATES else sli 448 | dict_obs = self.stochpol.ensure_observation_is_dict(obs) 449 | with logger.ProfileKV("policy_inference"): 450 | #Calls the policy and value function on current observation. 451 | acs, vpreds_int, vpreds_ext, nlps, self.I.mem_state[memsli], ent = self.stochpol.call(dict_obs, news, self.I.mem_state[memsli], 452 | update_obs_stats=self.update_ob_stats_every_step) 453 | self.env_step(l, acs) 454 | 455 | #Update buffer with transition. 456 | for k in self.stochpol.ph_ob_keys: 457 | self.I.buf_obs[k][sli, t] = dict_obs[k] 458 | self.I.buf_news[sli, t] = news 459 | self.I.buf_vpreds_int[sli, t] = vpreds_int 460 | self.I.buf_vpreds_ext[sli, t] = vpreds_ext 461 | self.I.buf_nlps[sli, t] = nlps 462 | self.I.buf_acs[sli, t] = acs 463 | self.I.buf_ent[sli, t] = ent 464 | 465 | if t > 0: 466 | self.I.buf_rews_ext[sli, t-1] = prevrews 467 | 468 | self.I.step_count += 1 469 | if t == self.nsteps - 1 and not self.disable_policy_update: 470 | #We need to take one extra step so every transition has a reward. 471 | for l in range(self.I.nlump): 472 | sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride) 473 | memsli = slice(None) if self.I.mem_state is NO_STATES else sli 474 | nextobs, rews, nextnews, _ = self.env_get(l) 475 | dict_nextobs = self.stochpol.ensure_observation_is_dict(nextobs) 476 | for k in self.stochpol.ph_ob_keys: 477 | self.I.buf_ob_last[k][sli] = dict_nextobs[k] 478 | self.I.buf_new_last[sli] = nextnews 479 | with logger.ProfileKV("policy_inference"): 480 | _, self.I.buf_vpred_int_last[sli], self.I.buf_vpred_ext_last[sli], _, _, _ = self.stochpol.call(dict_nextobs, nextnews, self.I.mem_state[memsli], update_obs_stats=False) 481 | self.I.buf_rews_ext[sli, t] = rews 482 | 483 | #Calcuate the intrinsic rewards for the rollout. 484 | fd = {} 485 | fd[self.stochpol.ph_ob[None]] = np.concatenate([self.I.buf_obs[None], self.I.buf_ob_last[None][:,None]], 1) 486 | fd.update({self.stochpol.ph_mean: self.stochpol.ob_rms.mean, 487 | self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5}) 488 | fd[self.stochpol.ph_ac] = self.I.buf_acs 489 | self.I.buf_rews_int[:] = tf.get_default_session().run(self.stochpol.int_rew, fd) 490 | 491 | if not self.update_ob_stats_every_step: 492 | #Update observation normalization parameters after the rollout is completed. 493 | obs_ = self.I.buf_obs[None].astype(np.float32) 494 | self.stochpol.ob_rms.update(obs_.reshape((-1, *obs_.shape[2:]))[:,:,:,-1:]) 495 | if not self.testing: 496 | update_info = self.update() 497 | else: 498 | update_info = {} 499 | self.I.seg_init_mem_state = copy(self.I.mem_state) 500 | global_i_stats = dict_gather(self.comm_log, self.I.stats, op='sum') 501 | global_deque_mean = dict_gather(self.comm_log, { n : np.mean(dvs) for n,dvs in self.I.statlists.items() }, op='mean') 502 | update_info.update(global_i_stats) 503 | update_info.update(global_deque_mean) 504 | self.global_tcount = global_i_stats['tcount'] 505 | for infos_ in self.I.buf_epinfos: 506 | infos_.clear() 507 | else: 508 | update_info = {} 509 | 510 | #Some reporting logic. 
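        #For each finished episode: in testing mode only reward and length are tracked; otherwise
        #visited rooms are merged into local_rooms, the score is rounded to a multiple of
        #score_multiple before being recorded, and the best local return is updated.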
511 | for epinfo in epinfos: 512 | if self.testing: 513 | self.I.statlists['eprew_test'].append(epinfo['r']) 514 | self.I.statlists['eplen_test'].append(epinfo['l']) 515 | else: 516 | if "visited_rooms" in epinfo: 517 | self.local_rooms += list(epinfo["visited_rooms"]) 518 | self.local_rooms = sorted(list(set(self.local_rooms))) 519 | score_multiple = self.I.venvs[0].score_multiple 520 | if score_multiple is None: 521 | score_multiple = 1000 522 | rounded_score = int(epinfo["r"] / score_multiple) * score_multiple 523 | self.scores.append(rounded_score) 524 | self.scores = sorted(list(set(self.scores))) 525 | self.I.statlists['eprooms'].append(len(epinfo["visited_rooms"])) 526 | 527 | self.I.statlists['eprew'].append(epinfo['r']) 528 | if self.local_best_ret is None: 529 | self.local_best_ret = epinfo["r"] 530 | elif epinfo["r"] > self.local_best_ret: 531 | self.local_best_ret = epinfo["r"] 532 | 533 | self.I.statlists['eplen'].append(epinfo['l']) 534 | self.I.stats['epcount'] += 1 535 | self.I.stats['tcount'] += epinfo['l'] 536 | self.I.stats['rewtotal'] += epinfo['r'] 537 | # self.I.stats["best_ext_ret"] = self.best_ret 538 | 539 | 540 | return {'update' : update_info} 541 | 542 | 543 | class RewardForwardFilter(object): 544 | def __init__(self, gamma): 545 | self.rewems = None 546 | self.gamma = gamma 547 | def update(self, rews): 548 | if self.rewems is None: 549 | self.rewems = rews 550 | else: 551 | self.rewems = self.rewems * self.gamma + rews 552 | return self.rewems 553 | 554 | def flatten_lists(listoflists): 555 | return [el for list_ in listoflists for el in list_] 556 | -------------------------------------------------------------------------------- /recorder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | from baselines import logger 8 | from mpi4py import MPI 9 | 10 | def is_square(n): 11 | return n == (int(np.sqrt(n))) ** 2 12 | 13 | class Recorder(object): 14 | def __init__(self, nenvs, score_multiple=1): 15 | self.episodes = [defaultdict(list) for _ in range(nenvs)] 16 | self.total_episodes = 0 17 | self.filename = self.get_filename() 18 | self.score_multiple = score_multiple 19 | 20 | self.all_scores = {} 21 | self.all_places = {} 22 | 23 | def record(self, bufs, infos): 24 | for env_id, ep_infos in enumerate(infos): 25 | left_step = 0 26 | done_steps = sorted(ep_infos.keys()) 27 | for right_step in done_steps: 28 | for key in bufs: 29 | self.episodes[env_id][key].append(bufs[key][env_id, left_step:right_step].copy()) 30 | self.record_episode(env_id, ep_infos[right_step]) 31 | left_step = right_step 32 | for key in bufs: 33 | self.episodes[env_id][key].clear() 34 | for key in bufs: 35 | self.episodes[env_id][key].append(bufs[key][env_id, left_step:].copy()) 36 | 37 | 38 | def record_episode(self, env_id, info): 39 | self.total_episodes += 1 40 | if self.episode_worth_saving(env_id, info): 41 | episode = {} 42 | for key in self.episodes[env_id]: 43 | episode[key] = np.concatenate(self.episodes[env_id][key]) 44 | info['env_id'] = env_id 45 | episode['info'] = info 46 | with open(self.filename, 'ab') as f: 47 | pickle.dump(episode, f, protocol=-1) 48 | 49 | def get_score(self, info): 50 | return int(info['r']/self.score_multiple) * self.score_multiple 51 | 52 | def episode_worth_saving(self, env_id, info): 53 | if self.score_multiple is None: 54 | return False 55 | r = self.get_score(info) 56 | if r not in self.all_scores: 57 | 
self.all_scores[r] = 0 58 | else: 59 | self.all_scores[r] += 1 60 | hashable_places = tuple(sorted(info['places'])) 61 | if hashable_places not in self.all_places: 62 | self.all_places[hashable_places] = 0 63 | else: 64 | self.all_places[hashable_places] += 1 65 | if is_square(self.all_scores[r]) or is_square(self.all_places[hashable_places]): 66 | return True 67 | if 15 in info['places']: 68 | return True 69 | return False 70 | 71 | def get_filename(self): 72 | filename = os.path.join(logger.get_dir(), 'videos_{}.pk'.format(MPI.COMM_WORLD.Get_rank())) 73 | return filename 74 | 75 | -------------------------------------------------------------------------------- /replayer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import glob 4 | import os 5 | import pickle 6 | import sys 7 | 8 | import exptag 9 | import ipdb 10 | import numpy as np 11 | from atari_wrappers import make_atari, wrap_deepmind 12 | from run_atari import add_env_params 13 | 14 | seen_scores = set() 15 | 16 | 17 | class EpisodeIterator(object): 18 | def __init__(self, filenames): 19 | if args['filter'] == 'none': 20 | cond = lambda info: True 21 | elif args['filter'] == 'rew': 22 | cond = lambda info: info['r'] < args['rew_max'] and (info['r'] > args['rew_min']) 23 | elif args['filter'] == 'room': 24 | def cond(info): 25 | return any(int(room) in args['room_number'] for room in info['places']) 26 | 27 | self.filenames = filenames 28 | self.condition = cond 29 | self.episode_number = 0 30 | 31 | def iterate(self): 32 | for filename in self.filenames: 33 | print("Opening file", filename) 34 | with open(filename, 'rb') as f: 35 | yield from self.iterate_over_episodes_in_file(f, condition=self.condition) 36 | raise StopIteration 37 | 38 | def iterate_over_episodes_in_file(self, file, condition): 39 | while True: 40 | try: 41 | episode = pickle.load(file) 42 | except: 43 | raise StopIteration 44 | 45 | info = episode['info'] 46 | if condition(info): 47 | print(self.episode_number) 48 | self.episode_number += 1 49 | if self.episode_number >= args['skip']: 50 | if 'obs' in episode: 51 | # import ipdb; ipdb.set_trace() 52 | yield episode 53 | else: 54 | unwrapped_env = env.unwrapped 55 | if 'rng_at_episode_start' in info: 56 | random_state = info['rng_at_episode_start'] 57 | unwrapped_env.np_random.set_state(random_state.get_state()) 58 | if hasattr(unwrapped_env, "scene"): 59 | unwrapped_env.scene.np_random.set_state(random_state.get_state()) 60 | ob = env.reset() 61 | ret = 0 62 | frames = [] 63 | infos = [] 64 | for i, a in enumerate(episode['acs']): 65 | ob, r, d, info = env.step(a) 66 | if args['display'] == 'game': 67 | rend = unwrapped_env.render(mode="rgb_array") 68 | else: 69 | rend = np.asarray(ob)[:, :, :1] 70 | frames.append(rend) 71 | ret += r 72 | infos.append(info) 73 | assert not d or i == len(episode['acs']) - 1, ipdb.set_trace() 74 | assert d, ipdb.set_trace() 75 | assert ret == episode['info']['r'], (ret, episode['info']['r']) 76 | episode['obs'] = frames 77 | episode['infos'] = infos 78 | print(episode.keys()) 79 | yield episode 80 | 81 | 82 | class Animation(object): 83 | def __init__(self, episodes): 84 | self.episodes = episodes 85 | 86 | self.pause = False 87 | self.delta = 1 88 | self.j = 0 89 | 90 | self.fig = self.create_empty_figure() 91 | 92 | self.fig.canvas.mpl_connect('key_press_event', self.onKeyPress) 93 | 94 | self.axes = {} 95 | self.lines = {} 96 | self.dots = {} 97 | # self.ax1 = self.fig.add_subplot(1, 2, 1) 98 | 
# self.ax2 = self.fig.add_subplot(1, 2, 2) 99 | 100 | def create_empty_figure(self): 101 | fig = plt.figure() 102 | for evt, callback in fig.canvas.callbacks.callbacks.items(): 103 | items = list(callback.items()) 104 | for cid, _ in items: 105 | fig.canvas.mpl_disconnect(cid) 106 | return fig 107 | 108 | def onKeyPress(self, event): 109 | if event.key == 'left': 110 | self.pause = True 111 | if self.j > 0: 112 | self.j -= 1 113 | elif event.key == 'right': 114 | self.pause = True 115 | if self.j < len(self.episode['obs']) - 1: 116 | self.j += 1 117 | elif event.key == 'n': 118 | self.pause = False 119 | self.j = len(self.episode['obs']) - 1 120 | elif event.key == ' ': 121 | self.pause = not self.pause 122 | elif event.key == 'q': 123 | sys.exit() 124 | elif event.key == 'f': 125 | self.delta = 1 if self.delta > 1 else 8 126 | elif event.key == 'b': 127 | self.j = max(self.j-100, 0) 128 | 129 | def create_axes(self, episode): 130 | assert self.axes == {} 131 | keys = [key for key in episode.keys() if key not in ['acs', 'infos', 'obs', 'info']] 132 | keys.insert(0, 'obs') 133 | n_rows = int(np.floor(np.sqrt(len(keys)))) 134 | n_cols = int(np.ceil(len(keys) / n_rows)) 135 | for i, key in enumerate(keys, start=1): 136 | self.axes[key] = self.fig.add_subplot(n_rows, n_cols, i) 137 | 138 | def process_frame(self, frame): 139 | if frame.shape[-1] == 3: 140 | return frame 141 | else: 142 | return frame[:, :, -1] 143 | 144 | def run(self): 145 | self.episode = next(self.episodes) 146 | 147 | if self.axes == {}: 148 | self.create_axes(self.episode) 149 | 150 | self.im = self.axes['obs'].imshow(self.process_frame(self.episode['obs'][0]), cmap='gray') 151 | for key in self.axes: 152 | if key != 'obs': 153 | line, = self.axes[key].plot(self.episode[key], alpha=0.5) 154 | dot = matplotlib.patches.Ellipse(xy=(0, 0), width=1, height=0.0001, color='r') 155 | self.axes[key].add_artist(dot) 156 | self.axes[key].set_title(key) 157 | self.lines[key] = line 158 | self.dots[key] = dot 159 | 160 | 161 | def draw_frame_i(i): 162 | # update the data 163 | if self.j == 0: 164 | for key in self.axes: 165 | if key != 'obs': 166 | data = self.episode[key] 167 | n_timesteps = len(data) 168 | self.lines[key].set_data(range(n_timesteps), data) 169 | self.axes[key].set_xlim(0, n_timesteps) 170 | min_y, max_y = np.min(data), np.max(data) 171 | self.axes[key].set_ylim(min_y, max_y) 172 | 173 | self.dots[key].height = (max_y - min_y) / 30. 174 | self.dots[key].width = n_timesteps / 30. 
175 | self.im.set_data(self.process_frame(self.episode['obs'][self.j])) 176 | for key in self.axes: 177 | if key != 'obs': 178 | self.dots[key].center = (self.j, self.episode[key][self.j]) 179 | if not self.pause: 180 | self.j += self.delta 181 | if self.j > len(self.episode['obs']) - 1: 182 | self.episode = next(episodes) 183 | self.j = 0 184 | return [self.im] + list(self.lines.values()) + list(self.dots.values()) 185 | 186 | ani = animation.FuncAnimation(self.fig, draw_frame_i, blit=False, interval=1, 187 | repeat=False) 188 | plt.show() 189 | plt.close() 190 | 191 | 192 | if __name__ == '__main__': 193 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 194 | add_env_params(parser) 195 | parser.add_argument('--filter', type=str, default='none') 196 | parser.add_argument('--rew_min', type=int, default=0) 197 | parser.add_argument('--rew_max', type=int, default=np.inf) 198 | parser.add_argument('--tag', type=str, default=None) 199 | parser.add_argument('--kind', type=str, default='plot') 200 | parser.add_argument('--display', type=str, default='game', choices=['game', 'agent']) 201 | parser.add_argument('--skip', type=int, default=0) 202 | parser.add_argument('--room_number', type=lambda x: [int(_) for _ in x.split(',')], default=[15]) 203 | 204 | 205 | 206 | args = parser.parse_args().__dict__ 207 | folder = exptag.get_last_experiment_folder_by_tag(args['tag']) 208 | 209 | def date_from_folder(folder): 210 | assert folder.startswith('openai-') 211 | date_started = folder[len('openai-'):] 212 | return datetime.datetime.strptime(date_started, "%Y-%m-%d-%H-%M-%S-%f") 213 | 214 | date_started = date_from_folder(os.path.basename(folder)) 215 | machine_dir = os.path.dirname(folder) 216 | if machine_dir[-4:-1]=='-00': 217 | all_machine_dirs = glob.glob(machine_dir[:-1]+'*') 218 | else: 219 | all_machine_dirs = [machine_dir] 220 | other_folders = [] 221 | for machine_dir in all_machine_dirs: 222 | this_machine_other_folders = os.listdir(machine_dir) 223 | this_machine_other_folders = [f_ for f_ in this_machine_other_folders 224 | if f_.startswith("openai-") and abs((date_from_folder(f_) - date_started).total_seconds()) < 3] 225 | this_machine_other_folders = [os.path.join(machine_dir, f_) for f_ in this_machine_other_folders] 226 | other_folders.extend(this_machine_other_folders) 227 | 228 | filenames = [glob.glob(os.path.join(f_, "videos_*.pk")) for f_ in other_folders] 229 | assert all(len(files_) == 1 for files_ in filenames), filenames 230 | filenames = [files_[0] for files_ in filenames] 231 | 232 | env = make_atari(args['env'], max_episode_steps=args['max_episode_steps']) 233 | if args['display'] == 'agent': 234 | env = wrap_deepmind(env, frame_stack=4, clip_rewards=False) 235 | env.reset() 236 | un_env = env.unwrapped 237 | rend_shape = un_env.render(mode='rgb_array').shape 238 | episodes = EpisodeIterator(filenames).iterate() 239 | if args['kind'] == 'movie': 240 | import imageio 241 | import time 242 | for i, episode in enumerate(episodes): 243 | filename = os.path.expanduser('~/tmp/movie_{}.mp4'.format(time.time())) 244 | imageio.mimwrite(filename, episode["obs"], fps=30) 245 | print(filename) 246 | 247 | else: 248 | import matplotlib.patches 249 | import matplotlib.pyplot as plt 250 | import matplotlib.animation as animation 251 | print('left/right, space, n, q, f keys are special') 252 | Animation(episodes).run() 253 | -------------------------------------------------------------------------------- /run_atari.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import functools 3 | import os 4 | 5 | from baselines import logger 6 | from mpi4py import MPI 7 | import mpi_util 8 | import tf_util 9 | from cmd_util import make_atari_env, arg_parser 10 | from policies.cnn_gru_policy_dynamics import CnnGruPolicy 11 | from policies.cnn_policy_param_matched import CnnPolicy 12 | from ppo_agent import PpoAgent 13 | from utils import set_global_seeds 14 | from vec_env import VecFrameStack 15 | 16 | 17 | def train(*, env_id, num_env, hps, num_timesteps, seed): 18 | venv = VecFrameStack( 19 | make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(), 20 | start_index=num_env * MPI.COMM_WORLD.Get_rank(), 21 | max_episode_steps=hps.pop('max_episode_steps')), 22 | hps.pop('frame_stack')) 23 | # venv.score_multiple = {'Mario': 500, 24 | # 'MontezumaRevengeNoFrameskip-v4': 100, 25 | # 'GravitarNoFrameskip-v4': 250, 26 | # 'PrivateEyeNoFrameskip-v4': 500, 27 | # 'SolarisNoFrameskip-v4': None, 28 | # 'VentureNoFrameskip-v4': 200, 29 | # 'PitfallNoFrameskip-v4': 100, 30 | # }[env_id] 31 | venv.score_multiple = 1 32 | venv.record_obs = True if env_id == 'SolarisNoFrameskip-v4' else False 33 | ob_space = venv.observation_space 34 | ac_space = venv.action_space 35 | gamma = hps.pop('gamma') 36 | policy = {'rnn': CnnGruPolicy, 37 | 'cnn': CnnPolicy}[hps.pop('policy')] 38 | agent = PpoAgent( 39 | scope='ppo', 40 | ob_space=ob_space, 41 | ac_space=ac_space, 42 | stochpol_fn=functools.partial( 43 | policy, 44 | scope='pol', 45 | ob_space=ob_space, 46 | ac_space=ac_space, 47 | update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'), 48 | proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'), 49 | dynamics_bonus = hps.pop("dynamics_bonus") 50 | ), 51 | gamma=gamma, 52 | gamma_ext=hps.pop('gamma_ext'), 53 | lam=hps.pop('lam'), 54 | nepochs=hps.pop('nepochs'), 55 | nminibatches=hps.pop('nminibatches'), 56 | lr=hps.pop('lr'), 57 | cliprange=0.1, 58 | nsteps=128, 59 | ent_coef=0.001, 60 | max_grad_norm=hps.pop('max_grad_norm'), 61 | use_news=hps.pop("use_news"), 62 | comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None, 63 | update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'), 64 | int_coeff=hps.pop('int_coeff'), 65 | ext_coeff=hps.pop('ext_coeff'), 66 | ) 67 | agent.start_interaction([venv]) 68 | if hps.pop('update_ob_stats_from_random_agent'): 69 | agent.collect_random_statistics(num_timesteps=128*50) 70 | assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys()) 71 | 72 | counter = 0 73 | while True: 74 | info = agent.step() 75 | if info['update']: 76 | logger.logkvs(info['update']) 77 | logger.dumpkvs() 78 | counter += 1 79 | if agent.I.stats['tcount'] > num_timesteps: 80 | break 81 | 82 | agent.stop_interaction() 83 | 84 | 85 | def add_env_params(parser): 86 | parser.add_argument('--env', help='environment ID', default='MontezumaRevengeNoFrameskip-v4') 87 | parser.add_argument('--seed', help='RNG seed', type=int, default=0) 88 | parser.add_argument('--max_episode_steps', type=int, default=4500) 89 | 90 | 91 | def main(): 92 | parser = arg_parser() 93 | add_env_params(parser) 94 | parser.add_argument('--num-timesteps', type=int, default=int(1e12)) 95 | parser.add_argument('--num_env', type=int, default=32) 96 | parser.add_argument('--use_news', type=int, default=0) 97 | parser.add_argument('--gamma', type=float, default=0.99) 98 | 
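    # gamma discounts the intrinsic return; gamma_ext (next argument) discounts the extrinsic
    # return, so the two value heads can use different horizons.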
parser.add_argument('--gamma_ext', type=float, default=0.99) 99 | parser.add_argument('--lam', type=float, default=0.95) 100 | parser.add_argument('--update_ob_stats_every_step', type=int, default=0) 101 | parser.add_argument('--update_ob_stats_independently_per_gpu', type=int, default=0) 102 | parser.add_argument('--update_ob_stats_from_random_agent', type=int, default=1) 103 | parser.add_argument('--proportion_of_exp_used_for_predictor_update', type=float, default=1.) 104 | parser.add_argument('--tag', type=str, default='') 105 | parser.add_argument('--policy', type=str, default='rnn', choices=['cnn', 'rnn']) 106 | parser.add_argument('--int_coeff', type=float, default=1.) 107 | parser.add_argument('--ext_coeff', type=float, default=2.) 108 | parser.add_argument('--dynamics_bonus', type=int, default=0) 109 | 110 | 111 | args = parser.parse_args() 112 | logger.configure(dir=logger.get_dir(), format_strs=['stdout', 'log', 'csv'] if MPI.COMM_WORLD.Get_rank() == 0 else []) 113 | if MPI.COMM_WORLD.Get_rank() == 0: 114 | with open(os.path.join(logger.get_dir(), 'experiment_tag.txt'), 'w') as f: 115 | f.write(args.tag) 116 | # shutil.copytree(os.path.dirname(os.path.abspath(__file__)), os.path.join(logger.get_dir(), 'code')) 117 | 118 | mpi_util.setup_mpi_gpus() 119 | 120 | seed = 10000 * args.seed + MPI.COMM_WORLD.Get_rank() 121 | set_global_seeds(seed) 122 | 123 | hps = dict( 124 | frame_stack=4, 125 | nminibatches=4, 126 | nepochs=4, 127 | lr=0.0001, 128 | max_grad_norm=0.0, 129 | use_news=args.use_news, 130 | gamma=args.gamma, 131 | gamma_ext=args.gamma_ext, 132 | max_episode_steps=args.max_episode_steps, 133 | lam=args.lam, 134 | update_ob_stats_every_step=args.update_ob_stats_every_step, 135 | update_ob_stats_independently_per_gpu=args.update_ob_stats_independently_per_gpu, 136 | update_ob_stats_from_random_agent=args.update_ob_stats_from_random_agent, 137 | proportion_of_exp_used_for_predictor_update=args.proportion_of_exp_used_for_predictor_update, 138 | policy=args.policy, 139 | int_coeff=args.int_coeff, 140 | ext_coeff=args.ext_coeff, 141 | dynamics_bonus = args.dynamics_bonus 142 | ) 143 | 144 | tf_util.make_session(make_default=True) 145 | train(env_id=args.env, num_env=args.num_env, seed=seed, 146 | num_timesteps=args.num_timesteps, hps=hps) 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | -------------------------------------------------------------------------------- /stochastic_policy.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from baselines.common.distributions import make_pdtype 3 | from collections import OrderedDict 4 | from gym import spaces 5 | 6 | def canonical_dtype(orig_dt): 7 | if orig_dt.kind == 'f': 8 | return tf.float32 9 | elif orig_dt.kind in 'iu': 10 | return tf.int32 11 | else: 12 | raise NotImplementedError 13 | 14 | class StochasticPolicy(object): 15 | def __init__(self, scope, ob_space, ac_space): 16 | self.abs_scope = (tf.get_variable_scope().name + '/' + scope).lstrip('/') 17 | self.ob_space = ob_space 18 | self.ac_space = ac_space 19 | self.pdtype = make_pdtype(ac_space) 20 | self.ph_new = tf.placeholder(dtype=tf.float32, shape=(None, None), name='new') 21 | self.ph_ob_keys = [] 22 | self.ph_ob_dtypes = {} 23 | shapes = {} 24 | if isinstance(ob_space, spaces.Dict): 25 | assert isinstance(ob_space.spaces, OrderedDict) 26 | for key, box in ob_space.spaces.items(): 27 | assert isinstance(box, spaces.Box) 28 | self.ph_ob_keys.append(key) 29 | # Keys must be ordered, 
because tf.concat(ph) depends on order. Here we don't keep OrderedDict 30 | # order and sort keys instead. Rationale is to give freedom to modify environment. 31 | self.ph_ob_keys.sort() 32 | for k in self.ph_ob_keys: 33 | self.ph_ob_dtypes[k] = ob_space.spaces[k].dtype 34 | shapes[k] = ob_space.spaces[k].shape 35 | else: 36 | print(ob_space) 37 | box = ob_space 38 | assert isinstance(box, spaces.Box) 39 | self.ph_ob_keys = [None] 40 | self.ph_ob_dtypes = { None: box.dtype } 41 | shapes = { None: box.shape } 42 | self.ph_ob = OrderedDict([(k, tf.placeholder( 43 | canonical_dtype(self.ph_ob_dtypes[k]), 44 | (None, None,) + tuple(shapes[k]), 45 | name=(('obs/%s'%k) if k is not None else 'obs') 46 | )) for k in self.ph_ob_keys ]) 47 | assert list(self.ph_ob.keys())==self.ph_ob_keys, "\n%s\n%s\n" % (list(self.ph_ob.keys()), self.ph_ob_keys) 48 | ob_shape = tf.shape(next(iter(self.ph_ob.values()))) 49 | self.sy_nenvs = ob_shape[0] 50 | self.sy_nsteps = ob_shape[1] 51 | self.ph_ac = self.pdtype.sample_placeholder([None, None], name='ac') 52 | self.pd = self.vpred = self.ph_istate = None 53 | 54 | def finalize(self, pd, vpred, ph_istate=None): #pylint: disable=W0221 55 | self.pd = pd 56 | self.vpred = vpred 57 | self.ph_istate = ph_istate 58 | 59 | def ensure_observation_is_dict(self, ob): 60 | if self.ph_ob_keys==[None]: 61 | return { None: ob } 62 | else: 63 | return ob 64 | 65 | def call(self, ob, new, istate): 66 | """ 67 | Return acs, vpred, neglogprob, nextstate 68 | """ 69 | raise NotImplementedError 70 | 71 | def initial_state(self, n): 72 | raise NotImplementedError 73 | 74 | def update_normalization(self, ob): 75 | pass 76 | -------------------------------------------------------------------------------- /tf_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf # pylint: ignore-module 3 | import copy 4 | import os 5 | import functools 6 | import collections 7 | import multiprocessing 8 | 9 | def switch(condition, then_expression, else_expression): 10 | """Switches between two operations depending on a scalar value (int or bool). 11 | Note that both `then_expression` and `else_expression` 12 | should be symbolic tensors of the *same shape*. 13 | 14 | # Arguments 15 | condition: scalar tensor. 16 | then_expression: TensorFlow operation. 17 | else_expression: TensorFlow operation. 
18 | """ 19 | x_shape = copy.copy(then_expression.get_shape()) 20 | x = tf.cond(tf.cast(condition, 'bool'), 21 | lambda: then_expression, 22 | lambda: else_expression) 23 | x.set_shape(x_shape) 24 | return x 25 | 26 | # ================================================================ 27 | # Extras 28 | # ================================================================ 29 | 30 | def lrelu(x, leak=0.2): 31 | f1 = 0.5 * (1 + leak) 32 | f2 = 0.5 * (1 - leak) 33 | return f1 * x + f2 * abs(x) 34 | 35 | # ================================================================ 36 | # Mathematical utils 37 | # ================================================================ 38 | 39 | def huber_loss(x, delta=1.0): 40 | """Reference: https://en.wikipedia.org/wiki/Huber_loss""" 41 | return tf.where( 42 | tf.abs(x) < delta, 43 | tf.square(x) * 0.5, 44 | delta * (tf.abs(x) - 0.5 * delta) 45 | ) 46 | 47 | # ================================================================ 48 | # Global session 49 | # ================================================================ 50 | 51 | def make_session(num_cpu=None, make_default=False, graph=None): 52 | """Returns a session that will use CPU's only""" 53 | if num_cpu is None: 54 | num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count())) 55 | tf_config = tf.ConfigProto( 56 | inter_op_parallelism_threads=num_cpu, 57 | intra_op_parallelism_threads=num_cpu) 58 | if make_default: 59 | return tf.InteractiveSession(config=tf_config, graph=graph) 60 | else: 61 | return tf.Session(config=tf_config, graph=graph) 62 | 63 | def single_threaded_session(): 64 | """Returns a session which will only use a single CPU""" 65 | return make_session(num_cpu=1) 66 | 67 | def in_session(f): 68 | @functools.wraps(f) 69 | def newfunc(*args, **kwargs): 70 | with tf.Session(): 71 | f(*args, **kwargs) 72 | return newfunc 73 | 74 | ALREADY_INITIALIZED = set() 75 | 76 | def initialize(): 77 | """Initialize all the uninitialized variables in the global scope.""" 78 | new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED 79 | tf.get_default_session().run(tf.variables_initializer(new_variables)) 80 | ALREADY_INITIALIZED.update(new_variables) 81 | 82 | # ================================================================ 83 | # Model components 84 | # ================================================================ 85 | 86 | def normc_initializer(std=1.0, axis=0): 87 | def _initializer(shape, dtype=None, partition_info=None): # pylint: disable=W0613 88 | out = np.random.randn(*shape).astype(np.float32) 89 | out *= std / np.sqrt(np.square(out).sum(axis=axis, keepdims=True)) 90 | return tf.constant(out) 91 | return _initializer 92 | 93 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None, 94 | summary_tag=None): 95 | with tf.variable_scope(name): 96 | stride_shape = [1, stride[0], stride[1], 1] 97 | filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters] 98 | 99 | # there are "num input feature maps * filter height * filter width" 100 | # inputs to each hidden unit 101 | fan_in = intprod(filter_shape[:3]) 102 | # each unit in the lower layer receives a gradient from: 103 | # "num output feature maps * filter height * filter width" / 104 | # pooling size 105 | fan_out = intprod(filter_shape[:2]) * num_filters 106 | # initialize weights with random weights 107 | w_bound = np.sqrt(6. 
/ (fan_in + fan_out)) 108 | 109 | w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound), 110 | collections=collections) 111 | b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.zeros_initializer(), 112 | collections=collections) 113 | 114 | if summary_tag is not None: 115 | tf.summary.image(summary_tag, 116 | tf.transpose(tf.reshape(w, [filter_size[0], filter_size[1], -1, 1]), 117 | [2, 0, 1, 3]), 118 | max_images=10) 119 | 120 | return tf.nn.conv2d(x, w, stride_shape, pad) + b 121 | 122 | # ================================================================ 123 | # Theano-like Function 124 | # ================================================================ 125 | 126 | def function(inputs, outputs, updates=None, givens=None): 127 | """Just like Theano function. Take a bunch of tensorflow placeholders and expressions 128 | computed based on those placeholders and produces f(inputs) -> outputs. Function f takes 129 | values to be fed to the input's placeholders and produces the values of the expressions 130 | in outputs. 131 | 132 | Input values can be passed in the same order as inputs or can be provided as kwargs based 133 | on placeholder name (passed to constructor or accessible via placeholder.op.name). 134 | 135 | Example: 136 | x = tf.placeholder(tf.int32, (), name="x") 137 | y = tf.placeholder(tf.int32, (), name="y") 138 | z = 3 * x + 2 * y 139 | lin = function([x, y], z, givens={y: 0}) 140 | 141 | with single_threaded_session(): 142 | initialize() 143 | 144 | assert lin(2) == 6 145 | assert lin(x=3) == 9 146 | assert lin(2, 2) == 10 147 | assert lin(x=2, y=3) == 12 148 | 149 | Parameters 150 | ---------- 151 | inputs: [tf.placeholder, tf.constant, or object with make_feed_dict method] 152 | list of input arguments 153 | outputs: [tf.Variable] or tf.Variable 154 | list of outputs or a single output to be returned from function. Returned 155 | value will also have the same shape. 156 | """ 157 | if isinstance(outputs, list): 158 | return _Function(inputs, outputs, updates, givens=givens) 159 | elif isinstance(outputs, (dict, collections.OrderedDict)): 160 | f = _Function(inputs, outputs.values(), updates, givens=givens) 161 | return lambda *args, **kwargs: type(outputs)(zip(outputs.keys(), f(*args, **kwargs))) 162 | else: 163 | f = _Function(inputs, [outputs], updates, givens=givens) 164 | return lambda *args, **kwargs: f(*args, **kwargs)[0] 165 | 166 | 167 | class _Function(object): 168 | def __init__(self, inputs, outputs, updates, givens): 169 | for inpt in inputs: 170 | if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0): 171 | assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method" 172 | self.inputs = inputs 173 | updates = updates or [] 174 | self.update_group = tf.group(*updates) 175 | self.outputs_update = list(outputs) + [self.update_group] 176 | self.givens = {} if givens is None else givens 177 | 178 | def _feed_input(self, feed_dict, inpt, value): 179 | if hasattr(inpt, 'make_feed_dict'): 180 | feed_dict.update(inpt.make_feed_dict(value)) 181 | else: 182 | feed_dict[inpt] = value 183 | 184 | def __call__(self, *args): 185 | assert len(args) <= len(self.inputs), "Too many arguments provided" 186 | feed_dict = {} 187 | # Update the args 188 | for inpt, value in zip(self.inputs, args): 189 | self._feed_input(feed_dict, inpt, value) 190 | # Update feed dict with givens. 
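        # givens act as defaults: placeholders already fed explicitly above keep their values,
        # otherwise the value from givens is used.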
191 | for inpt in self.givens: 192 | feed_dict[inpt] = feed_dict.get(inpt, self.givens[inpt]) 193 | results = tf.get_default_session().run(self.outputs_update, feed_dict=feed_dict)[:-1] 194 | return results 195 | 196 | # ================================================================ 197 | # Flat vectors 198 | # ================================================================ 199 | 200 | def var_shape(x): 201 | out = x.get_shape().as_list() 202 | assert all(isinstance(a, int) for a in out), \ 203 | "shape function assumes that shape is fully known" 204 | return out 205 | 206 | def numel(x): 207 | return intprod(var_shape(x)) 208 | 209 | def intprod(x): 210 | return int(np.prod(x)) 211 | 212 | def flatgrad(loss, var_list, clip_norm=None): 213 | grads = tf.gradients(loss, var_list) 214 | if clip_norm is not None: 215 | grads = [tf.clip_by_norm(grad, clip_norm=clip_norm) for grad in grads] 216 | return tf.concat(axis=0, values=[ 217 | tf.reshape(grad if grad is not None else tf.zeros_like(v), [numel(v)]) 218 | for (v, grad) in zip(var_list, grads) 219 | ]) 220 | 221 | class SetFromFlat(object): 222 | def __init__(self, var_list, dtype=tf.float32): 223 | assigns = [] 224 | shapes = list(map(var_shape, var_list)) 225 | total_size = np.sum([intprod(shape) for shape in shapes]) 226 | 227 | self.theta = theta = tf.placeholder(dtype, [total_size]) 228 | start = 0 229 | assigns = [] 230 | for (shape, v) in zip(shapes, var_list): 231 | size = intprod(shape) 232 | assigns.append(tf.assign(v, tf.reshape(theta[start:start + size], shape))) 233 | start += size 234 | self.op = tf.group(*assigns) 235 | 236 | def __call__(self, theta): 237 | tf.get_default_session().run(self.op, feed_dict={self.theta: theta}) 238 | 239 | class GetFlat(object): 240 | def __init__(self, var_list): 241 | self.op = tf.concat(axis=0, values=[tf.reshape(v, [numel(v)]) for v in var_list]) 242 | 243 | def __call__(self): 244 | return tf.get_default_session().run(self.op) 245 | 246 | _PLACEHOLDER_CACHE = {} # name -> (placeholder, dtype, shape) 247 | 248 | def get_placeholder(name, dtype, shape): 249 | if name in _PLACEHOLDER_CACHE: 250 | out, dtype1, shape1 = _PLACEHOLDER_CACHE[name] 251 | assert dtype1 == dtype and shape1 == shape 252 | return out 253 | else: 254 | out = tf.placeholder(dtype=dtype, shape=shape, name=name) 255 | _PLACEHOLDER_CACHE[name] = (out, dtype, shape) 256 | return out 257 | 258 | def get_placeholder_cached(name): 259 | return _PLACEHOLDER_CACHE[name][0] 260 | 261 | def flattenallbut0(x): 262 | return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])]) 263 | 264 | 265 | # ================================================================ 266 | # Diagnostics 267 | # ================================================================ 268 | 269 | def display_var_info(vars): 270 | from baselines import logger 271 | count_params = 0 272 | for v in vars: 273 | name = v.name 274 | if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue 275 | v_params = np.prod(v.shape.as_list()) 276 | count_params += v_params 277 | if "/b:" in name or "/biases" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print 278 | logger.info(" %s%s %i params %s" % (name, " "*(55-len(name)), v_params, str(v.shape))) 279 | 280 | logger.info("Total model parameters: %0.2f million" % (count_params*1e-6)) 281 | 282 | 283 | def get_available_gpus(): 284 | # recipe from here: 285 | # 
https://stackoverflow.com/questions/38559755/how-to-get-current-available-gpus-in-tensorflow?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa 286 | 287 | from tensorflow.python.client import device_lib 288 | local_device_protos = device_lib.list_local_devices() 289 | return [x.name for x in local_device_protos if x.device_type == 'GPU'] 290 | 291 | # ================================================================ 292 | # Saving variables 293 | # ================================================================ 294 | 295 | def load_state(fname): 296 | saver = tf.train.Saver() 297 | saver.restore(tf.get_default_session(), fname) 298 | 299 | def save_state(fname): 300 | os.makedirs(os.path.dirname(fname), exist_ok=True) 301 | saver = tf.train.Saver() 302 | saver.save(tf.get_default_session(), fname) 303 | 304 | 305 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import random 4 | from mpi_util import mpi_moments 5 | 6 | 7 | def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0): 8 | with tf.variable_scope(scope): 9 | nin = x.get_shape()[1].value 10 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) 11 | b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias)) 12 | return tf.matmul(x, w)+b 13 | 14 | def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False): 15 | if data_format == 'NHWC': 16 | channel_ax = 3 17 | strides = [1, stride, stride, 1] 18 | bshape = [1, 1, 1, nf] 19 | elif data_format == 'NCHW': 20 | channel_ax = 1 21 | strides = [1, 1, stride, stride] 22 | bshape = [1, nf, 1, 1] 23 | else: 24 | raise NotImplementedError 25 | bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1] 26 | nin = x.get_shape()[channel_ax].value 27 | wshape = [rf, rf, nin, nf] 28 | with tf.variable_scope(scope): 29 | w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale)) 30 | b = tf.get_variable("b", bias_var_shape, initializer=tf.constant_initializer(0.0)) 31 | if not one_dim_bias and data_format == 'NHWC': 32 | b = tf.reshape(b, bshape) 33 | return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format) 34 | 35 | def ortho_init(scale=1.0): 36 | def _ortho_init(shape, dtype, partition_info=None): 37 | #lasagne ortho init for tf 38 | shape = tuple(shape) 39 | if len(shape) == 2: 40 | flat_shape = shape 41 | elif len(shape) == 4: # assumes NHWC 42 | flat_shape = (np.prod(shape[:-1]), shape[-1]) 43 | else: 44 | raise NotImplementedError 45 | a = np.random.normal(0.0, 1.0, flat_shape) 46 | u, _, v = np.linalg.svd(a, full_matrices=False) 47 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 48 | q = q.reshape(shape) 49 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32) 50 | return _ortho_init 51 | 52 | def tile_images(array, n_cols=None, max_images=None, div=1): 53 | if max_images is not None: 54 | array = array[:max_images] 55 | if len(array.shape) == 4 and array.shape[3] == 1: 56 | array = array[:, :, :, 0] 57 | assert len(array.shape) in [3, 4], "wrong number of dimensions - shape {}".format(array.shape) 58 | if len(array.shape) == 4: 59 | assert array.shape[3] == 3, "wrong number of channels- shape {}".format(array.shape) 60 | if n_cols is None: 61 | n_cols = max(int(np.sqrt(array.shape[0])) // div * div, div) 62 | n_rows = 
int(np.ceil(float(array.shape[0]) / n_cols)) 63 | 64 | def cell(i, j): 65 | ind = i * n_cols + j 66 | return array[ind] if ind < array.shape[0] else np.zeros(array[0].shape) 67 | 68 | def row(i): 69 | return np.concatenate([cell(i, j) for j in range(n_cols)], axis=1) 70 | 71 | return np.concatenate([row(i) for i in range(n_rows)], axis=0) 72 | 73 | 74 | def set_global_seeds(i): 75 | try: 76 | import tensorflow as tf 77 | except ImportError: 78 | pass 79 | else: 80 | from mpi4py import MPI 81 | tf.set_random_seed(i) 82 | np.random.seed(i) 83 | random.seed(i) 84 | 85 | 86 | def explained_variance_non_mpi(ypred,y): 87 | """ 88 | Computes fraction of variance that ypred explains about y. 89 | Returns 1 - Var[y-ypred] / Var[y] 90 | 91 | interpretation: 92 | ev=0 => might as well have predicted zero 93 | ev=1 => perfect prediction 94 | ev<0 => worse than just predicting zero 95 | 96 | """ 97 | assert y.ndim == 1 and ypred.ndim == 1 98 | vary = np.var(y) 99 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary 100 | 101 | def mpi_var(x): 102 | return mpi_moments(x)[1]**2 103 | 104 | def explained_variance(ypred,y): 105 | """ 106 | Computes fraction of variance that ypred explains about y. 107 | Returns 1 - Var[y-ypred] / Var[y] 108 | 109 | interpretation: 110 | ev=0 => might as well have predicted zero 111 | ev=1 => perfect prediction 112 | ev<0 => worse than just predicting zero 113 | 114 | """ 115 | assert y.ndim == 1 and ypred.ndim == 1 116 | vary = mpi_var(y) 117 | return np.nan if vary==0 else 1 - mpi_var(y-ypred)/vary 118 | 119 | -------------------------------------------------------------------------------- /vec_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from multiprocessing import Process, Pipe 3 | from baselines import logger 4 | from utils import tile_images 5 | 6 | class AlreadySteppingError(Exception): 7 | """ 8 | Raised when an asynchronous step is running while 9 | step_async() is called again. 10 | """ 11 | def __init__(self): 12 | msg = 'already running an async step' 13 | Exception.__init__(self, msg) 14 | 15 | class NotSteppingError(Exception): 16 | """ 17 | Raised when an asynchronous step is not running but 18 | step_wait() is called. 19 | """ 20 | def __init__(self): 21 | msg = 'not running an async step' 22 | Exception.__init__(self, msg) 23 | 24 | class VecEnv(ABC): 25 | """ 26 | An abstract asynchronous, vectorized environment. 27 | """ 28 | def __init__(self, num_envs, observation_space, action_space): 29 | self.num_envs = num_envs 30 | self.observation_space = observation_space 31 | self.action_space = action_space 32 | 33 | @abstractmethod 34 | def reset(self): 35 | """ 36 | Reset all the environments and return an array of 37 | observations, or a tuple of observation arrays. 38 | 39 | If step_async is still doing work, that work will 40 | be cancelled and step_wait() should not be called 41 | until step_async() is invoked again. 42 | """ 43 | pass 44 | 45 | @abstractmethod 46 | def step_async(self, actions): 47 | """ 48 | Tell all the environments to start taking a step 49 | with the given actions. 50 | Call step_wait() to get the results of the step. 51 | 52 | You should not call this if a step_async run is 53 | already pending. 54 | """ 55 | pass 56 | 57 | @abstractmethod 58 | def step_wait(self): 59 | """ 60 | Wait for the step taken with step_async(). 
61 | 62 | Returns (obs, rews, dones, infos): 63 | - obs: an array of observations, or a tuple of 64 | arrays of observations. 65 | - rews: an array of rewards 66 | - dones: an array of "episode done" booleans 67 | - infos: a sequence of info objects 68 | """ 69 | pass 70 | 71 | @abstractmethod 72 | def close(self): 73 | """ 74 | Clean up the environments' resources. 75 | """ 76 | pass 77 | 78 | def step(self, actions): 79 | self.step_async(actions) 80 | return self.step_wait() 81 | 82 | def render(self, mode='human'): 83 | logger.warn('Render not defined for %s'%self) 84 | 85 | @property 86 | def unwrapped(self): 87 | if isinstance(self, VecEnvWrapper): 88 | return self.venv.unwrapped 89 | else: 90 | return self 91 | 92 | class VecEnvWrapper(VecEnv): 93 | def __init__(self, venv, observation_space=None, action_space=None): 94 | self.venv = venv 95 | VecEnv.__init__(self, 96 | num_envs=venv.num_envs, 97 | observation_space=observation_space or venv.observation_space, 98 | action_space=action_space or venv.action_space) 99 | 100 | def step_async(self, actions): 101 | self.venv.step_async(actions) 102 | 103 | @abstractmethod 104 | def reset(self): 105 | pass 106 | 107 | @abstractmethod 108 | def step_wait(self): 109 | pass 110 | 111 | def close(self): 112 | return self.venv.close() 113 | 114 | def render(self, mode='human'): 115 | return self.venv.render(mode=mode) 116 | 117 | class CloudpickleWrapper(object): 118 | """ 119 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 120 | """ 121 | def __init__(self, x): 122 | self.x = x 123 | def __getstate__(self): 124 | import cloudpickle 125 | return cloudpickle.dumps(self.x) 126 | def __setstate__(self, ob): 127 | import pickle 128 | self.x = pickle.loads(ob) 129 | 130 | import numpy as np 131 | from gym import spaces 132 | 133 | class VecFrameStack(VecEnvWrapper): 134 | """ 135 | Frame-stacking wrapper for a vectorized environment 136 | """ 137 | def __init__(self, venv, nstack): 138 | self.venv = venv 139 | self.nstack = nstack 140 | wos = venv.observation_space # wrapped ob space 141 | low = np.repeat(wos.low, self.nstack, axis=-1) 142 | high = np.repeat(wos.high, self.nstack, axis=-1) 143 | self.stackedobs = np.zeros((venv.num_envs,)+low.shape, low.dtype) 144 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 145 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 146 | 147 | def step_wait(self): 148 | obs, rews, news, infos = self.venv.step_wait() 149 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1) 150 | for (i, new) in enumerate(news): 151 | if new: 152 | self.stackedobs[i] = 0 153 | self.stackedobs[..., -obs.shape[-1]:] = obs 154 | return self.stackedobs, rews, news, infos 155 | 156 | def reset(self): 157 | """ 158 | Reset all environments 159 | """ 160 | obs = self.venv.reset() 161 | self.stackedobs[...]
= 0 162 | self.stackedobs[..., -obs.shape[-1]:] = obs 163 | return self.stackedobs 164 | 165 | def close(self): 166 | self.venv.close() 167 | 168 | 208 | def worker(remote, parent_remote, env_fn_wrapper): 209 | parent_remote.close() 210 | env = env_fn_wrapper.x() 211 | while True: 212 | cmd, data = remote.recv() 213 | if cmd == 'step': 214 | ob, reward, done, info = env.step(data) 215 | if done: 216 | ob = env.reset() 217 | remote.send((ob, reward, done, info)) 218 | elif cmd == 'reset': 219 | ob = env.reset() 220 | remote.send(ob) 221 | elif cmd == 'render': 222 | remote.send(env.render(mode='rgb_array')) 223 | elif cmd == 'close': 224 | remote.close() 225 | break 226 | elif cmd == 'get_spaces': 227 | remote.send((env.observation_space, env.action_space)) 228 | else: 229 | raise NotImplementedError 230 | 231 | 232 | class SubprocVecEnv(VecEnv): 233 | def __init__(self, env_fns, spaces=None): 234 | """ 235 | envs: list of gym environments to run in subprocesses 236 | """ 237 | self.waiting = False 238 | self.closed = False 239 | nenvs = len(env_fns) 240 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) 241 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 242 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 243 | for p in self.ps: 244 | p.daemon = True # if the main process crashes, we should not cause things to hang 245 | p.start() 246 | for remote in self.work_remotes: 247 | remote.close() 248 | 249 | self.remotes[0].send(('get_spaces', None)) 250 | observation_space, action_space = self.remotes[0].recv() 251 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 252 | 253 | def step_async(self, actions): 254 | for remote, action in zip(self.remotes, actions): 255 | remote.send(('step', action)) 256 | self.waiting = True 257 | 258 | def step_wait(self): 259 | results = [remote.recv() for remote in self.remotes] 260 | self.waiting = False 261 | obs, rews, dones, infos = zip(*results) 262 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 263 | 264 | def reset(self): 265 | for remote in self.remotes: 266 | remote.send(('reset', None)) 267 | return np.stack([remote.recv() for remote in self.remotes]) 268 | 269 | def reset_task(self): 270 | for remote in
self.remotes: 271 | remote.send(('reset_task', None)) 272 | return np.stack([remote.recv() for remote in self.remotes]) 273 | 274 | def close(self): 275 | if self.closed: 276 | return 277 | if self.waiting: 278 | for remote in self.remotes: 279 | remote.recv() 280 | for remote in self.remotes: 281 | remote.send(('close', None)) 282 | for p in self.ps: 283 | p.join() 284 | self.closed = True 285 | 286 | def render(self, mode='human'): 287 | for pipe in self.remotes: 288 | pipe.send(('render', None)) 289 | imgs = [pipe.recv() for pipe in self.remotes] 290 | bigimg = tile_images(imgs) 291 | if mode == 'human': 292 | import cv2 293 | cv2.imshow('vecenv', bigimg[:,:,::-1]) 294 | cv2.waitKey(1) 295 | elif mode == 'rgb_array': 296 | return bigimg 297 | else: 298 | raise NotImplementedError 299 | --------------------------------------------------------------------------------
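
For orientation, here is a minimal, hypothetical usage sketch (not a file from this repository) showing how the pieces defined in `vec_env.py` fit together: `SubprocVecEnv` runs one gym environment per worker process, and `VecFrameStack` stacks the last few observations along the channel axis. The environment id, the number of workers, and the `nstack` value are arbitrary example choices, and the snippet assumes a gym version with the classic `env.seed()` API.

```python
# Hypothetical sketch: wiring SubprocVecEnv and VecFrameStack from vec_env.py together.
# The env id, worker count, and nstack are arbitrary example values.
import gym
from vec_env import SubprocVecEnv, VecFrameStack

def make_env(rank):
    # Return a thunk so each environment is constructed inside its own worker process.
    def _thunk():
        env = gym.make("PongNoFrameskip-v4")  # example env id, for illustration only
        env.seed(rank)
        return env
    return _thunk

if __name__ == "__main__":  # good practice, since SubprocVecEnv starts worker processes
    venv = VecFrameStack(SubprocVecEnv([make_env(i) for i in range(4)]), nstack=4)
    obs = venv.reset()  # shape: (num_envs,) + stacked observation shape
    for _ in range(100):
        actions = [venv.action_space.sample() for _ in range(venv.num_envs)]
        obs, rews, dones, infos = venv.step(actions)  # step() = step_async() + step_wait()
    venv.close()
```

In this sketch the worker processes auto-reset finished episodes (the `worker` loop above calls `env.reset()` when `done` is true), so the caller only needs to keep stepping.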