├── .gitignore
├── README.md
├── atari_wrappers.py
├── cmd_util.py
├── console_util.py
├── load_log.py
├── monitor.py
├── mpi_util.py
├── policies
│   ├── __init__.py
│   ├── cnn_gru_policy_dynamics.py
│   └── cnn_policy_param_matched.py
├── ppo_agent.py
├── recorder.py
├── replayer.py
├── run_atari.py
├── stochastic_policy.py
├── tf_util.py
├── utils.py
└── vec_env.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **Status:** Archive (code is provided as-is, no updates expected)
2 |
3 | ## [Exploration by Random Network Distillation](https://arxiv.org/abs/1810.12894) ##
4 |
5 |
6 | Yuri Burda*, Harri Edwards*, Amos Storkey, Oleg Klimov
7 | *equal contribution
8 |
9 | OpenAI
10 | University of Edinburgh
11 |
12 |
13 | ### Installation and Usage
14 | The following command should train an RND agent on Montezuma's Revenge:
15 | ```bash
16 | python run_atari.py --gamma_ext 0.999
17 | ```
18 | To use more than one GPU/machine, use MPI. For example, the following should use 1024 parallel environments (128 per MPI worker) to collect experience on an 8-GPU machine:
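```bash
mpiexec -n 8 python run_atari.py --num_env 128 --gamma_ext 0.999
```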
19 |
20 | ### [Blog post and videos](https://blog.openai.com/reinforcement-learning-with-prediction-based-rewards/)
21 |
--------------------------------------------------------------------------------
/atari_wrappers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import deque
3 | import gym
4 | from gym import spaces
5 | import cv2
6 | from copy import copy
7 |
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | def unwrap(env):
11 | if hasattr(env, "unwrapped"):
12 | return env.unwrapped
13 | elif hasattr(env, "env"):
14 | return unwrap(env.env)
15 | elif hasattr(env, "leg_env"):
16 | return unwrap(env.leg_env)
17 | else:
18 | return env
19 |
20 | class MaxAndSkipEnv(gym.Wrapper):
21 | def __init__(self, env, skip=4):
22 | """Return only every `skip`-th frame"""
23 | gym.Wrapper.__init__(self, env)
24 | # most recent raw observations (for max pooling across time steps)
25 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
26 | self._skip = skip
27 |
28 | def step(self, action):
29 | """Repeat action, sum reward, and max over last observations."""
30 | total_reward = 0.0
31 | done = None
32 | for i in range(self._skip):
33 | obs, reward, done, info = self.env.step(action)
34 | if i == self._skip - 2: self._obs_buffer[0] = obs
35 | if i == self._skip - 1: self._obs_buffer[1] = obs
36 | total_reward += reward
37 | if done:
38 | break
39 | # Note that the observation on the done=True frame
40 | # doesn't matter
41 | max_frame = self._obs_buffer.max(axis=0)
42 |
43 | return max_frame, total_reward, done, info
44 |
45 | def reset(self, **kwargs):
46 | return self.env.reset(**kwargs)
47 |
48 | class ClipRewardEnv(gym.RewardWrapper):
49 | def __init__(self, env):
50 | gym.RewardWrapper.__init__(self, env)
51 |
52 | def reward(self, reward):
53 | """Bin reward to {+1, 0, -1} by its sign."""
54 | return float(np.sign(reward))
55 |
56 | class WarpFrame(gym.ObservationWrapper):
57 | def __init__(self, env):
58 | """Warp frames to 84x84 as done in the Nature paper and later work."""
59 | gym.ObservationWrapper.__init__(self, env)
60 | self.width = 84
61 | self.height = 84
62 | self.observation_space = spaces.Box(low=0, high=255,
63 | shape=(self.height, self.width, 1), dtype=np.uint8)
64 |
65 | def observation(self, frame):
66 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
67 | frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
68 | return frame[:, :, None]
69 |
70 | class FrameStack(gym.Wrapper):
71 | def __init__(self, env, k):
72 | """Stack k last frames.
73 |
74 | Returns lazy array, which is much more memory efficient.
75 |
76 | See Also
77 | --------
78 | LazyFrames (defined later in this module)
79 | """
80 | gym.Wrapper.__init__(self, env)
81 | self.k = k
82 | self.frames = deque([], maxlen=k)
83 | shp = env.observation_space.shape
84 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)
85 |
86 | def reset(self):
87 | ob = self.env.reset()
88 | for _ in range(self.k):
89 | self.frames.append(ob)
90 | return self._get_ob()
91 |
92 | def step(self, action):
93 | ob, reward, done, info = self.env.step(action)
94 | self.frames.append(ob)
95 | return self._get_ob(), reward, done, info
96 |
97 | def _get_ob(self):
98 | assert len(self.frames) == self.k
99 | return LazyFrames(list(self.frames))
100 |
101 | class ScaledFloatFrame(gym.ObservationWrapper):
102 | def __init__(self, env):
103 | gym.ObservationWrapper.__init__(self, env)
104 |
105 | def observation(self, observation):
106 | # careful! This undoes the memory optimization, use
107 | # with smaller replay buffers only.
108 | return np.array(observation).astype(np.float32) / 255.0
109 |
110 | class LazyFrames(object):
111 | def __init__(self, frames):
112 | """This object ensures that common frames between the observations are only stored once.
113 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
114 | buffers.
115 |
116 | This object should only be converted to numpy array before being passed to the model.
117 |
118 | You'd not believe how complex the previous solution was."""
119 | self._frames = frames
120 | self._out = None
121 |
122 | def _force(self):
123 | if self._out is None:
124 | self._out = np.concatenate(self._frames, axis=2)
125 | self._frames = None
126 | return self._out
127 |
128 | def __array__(self, dtype=None):
129 | out = self._force()
130 | if dtype is not None:
131 | out = out.astype(dtype)
132 | return out
133 |
134 | def __len__(self):
135 | return len(self._force())
136 |
137 | def __getitem__(self, i):
138 | return self._force()[i]
139 |
140 | class MontezumaInfoWrapper(gym.Wrapper):
141 | def __init__(self, env, room_address):
142 | super(MontezumaInfoWrapper, self).__init__(env)
143 | self.room_address = room_address
144 | self.visited_rooms = set()
145 |
146 | def get_current_room(self):
147 | ram = unwrap(self.env).ale.getRAM()
148 | assert len(ram) == 128
149 | return int(ram[self.room_address])
150 |
151 | def step(self, action):
152 | obs, rew, done, info = self.env.step(action)
153 | self.visited_rooms.add(self.get_current_room())
154 | if done:
155 | if 'episode' not in info:
156 | info['episode'] = {}
157 | info['episode'].update(visited_rooms=copy(self.visited_rooms))
158 | self.visited_rooms.clear()
159 | return obs, rew, done, info
160 |
161 | def reset(self):
162 | return self.env.reset()
163 |
164 | class DummyMontezumaInfoWrapper(gym.Wrapper):
165 |
166 | def __init__(self, env):
167 | super(DummyMontezumaInfoWrapper, self).__init__(env)
168 |
169 | def step(self, action):
170 | obs, rew, done, info = self.env.step(action)
171 | if done:
172 | if 'episode' not in info:
173 | info['episode'] = {}
174 | info['episode'].update(pos_count=0,
175 | visited_rooms=set([0]))
176 | return obs, rew, done, info
177 |
178 | def reset(self):
179 | return self.env.reset()
180 |
181 | class AddRandomStateToInfo(gym.Wrapper):
182 | def __init__(self, env):
183 | """Adds the random state to the info field on the first step after reset
184 | """
185 | gym.Wrapper.__init__(self, env)
186 |
187 | def step(self, action):
188 | ob, r, d, info = self.env.step(action)
189 | if d:
190 | if 'episode' not in info:
191 | info['episode'] = {}
192 | info['episode']['rng_at_episode_start'] = self.rng_at_episode_start
193 | return ob, r, d, info
194 |
195 | def reset(self, **kwargs):
196 | self.rng_at_episode_start = copy(self.unwrapped.np_random)
197 | return self.env.reset(**kwargs)
198 |
199 |
200 | def make_atari(env_id, max_episode_steps=4500):
201 | env = gym.make(env_id)
202 | env._max_episode_steps = max_episode_steps*4
203 | assert 'NoFrameskip' in env.spec.id
204 | env = StickyActionEnv(env)
205 | env = MaxAndSkipEnv(env, skip=4)
206 | if "Montezuma" in env_id or "Pitfall" in env_id:
207 | env = MontezumaInfoWrapper(env, room_address=3 if "Montezuma" in env_id else 1)
208 | else:
209 | env = DummyMontezumaInfoWrapper(env)
210 | env = AddRandomStateToInfo(env)
211 | return env
212 |
213 | def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False):
214 | """Configure environment for DeepMind-style Atari.
215 | """
216 | env = WarpFrame(env)
217 | if scale:
218 | env = ScaledFloatFrame(env)
219 | if clip_rewards:
220 | env = ClipRewardEnv(env)
221 | if frame_stack:
222 | env = FrameStack(env, 4)
223 | # env = NormalizeObservation(env)
224 | return env
225 |
226 |
227 | class StickyActionEnv(gym.Wrapper):
228 | def __init__(self, env, p=0.25):
229 | super(StickyActionEnv, self).__init__(env)
230 | self.p = p
231 | self.last_action = 0
232 |
233 | def reset(self):
234 | self.last_action = 0
235 | return self.env.reset()
236 |
237 | def step(self, action):
238 | if self.unwrapped.np_random.uniform() < self.p:
239 | action = self.last_action
240 | self.last_action = action
241 | obs, reward, done, info = self.env.step(action)
242 | return obs, reward, done, info
243 |
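A minimal usage sketch (not part of the original module; the env id is illustrative and assumes the gym Atari extras are installed): `make_atari` builds the base environment with sticky actions, 4-frame max-and-skip and room tracking, and `wrap_deepmind` adds the 84x84 grayscale warp plus optional reward clipping and frame stacking.

```python
from atari_wrappers import make_atari, wrap_deepmind

env = make_atari("MontezumaRevengeNoFrameskip-v4", max_episode_steps=4500)
env = wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False)

ob = env.reset()                                   # uint8 frame, shape (84, 84, 1)
ob, rew, done, info = env.step(env.action_space.sample())
if done and "episode" in info:
    print(info["episode"]["visited_rooms"])        # set of visited room ids
```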
--------------------------------------------------------------------------------
/cmd_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for scripts like run_atari.py.
3 | """
4 |
5 | import os
6 |
7 | import gym
8 | from gym.wrappers import FlattenDictWrapper
9 | from mpi4py import MPI
10 | from baselines import logger
11 | from monitor import Monitor
12 | from atari_wrappers import make_atari, wrap_deepmind
13 | from vec_env import SubprocVecEnv
14 |
15 |
16 | def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0, max_episode_steps=4500):
17 | """
18 | Create a wrapped, monitored SubprocVecEnv for Atari.
19 | """
20 | if wrapper_kwargs is None: wrapper_kwargs = {}
21 | def make_env(rank): # pylint: disable=C0111
22 | def _thunk():
23 | env = make_atari(env_id, max_episode_steps=max_episode_steps)
24 | env.seed(seed + rank)
25 | env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), allow_early_resets=True)
26 | return wrap_deepmind(env, **wrapper_kwargs)
27 | return _thunk
28 | # set_global_seeds(seed)
29 | return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
30 |
31 | def arg_parser():
32 | """
33 | Create an empty argparse.ArgumentParser.
34 | """
35 | import argparse
36 | return argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
37 |
38 | def atari_arg_parser():
39 | """
40 | Create an argparse.ArgumentParser for run_atari.py.
41 | """
42 | parser = arg_parser()
43 | parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
44 | parser.add_argument('--seed', help='RNG seed', type=int, default=0)
45 | parser.add_argument('--num-timesteps', type=int, default=int(10e6))
46 | return parser
47 |
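A hedged sketch of how these helpers fit together (env id and `num_env` are illustrative; running it requires mpi4py, baselines and a gym Atari install, and `SubprocVecEnv` spawns one worker process per env):

```python
from cmd_util import make_atari_env

venv = make_atari_env("MontezumaRevengeNoFrameskip-v4", num_env=4, seed=0)
obs = venv.reset()          # batched observations, one row per sub-process env
print(obs.shape)            # e.g. (4, 84, 84, 1) with the default wrappers
venv.close()
```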
--------------------------------------------------------------------------------
/console_util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from contextlib import contextmanager
3 | import numpy as np
4 | import time
5 |
6 | # ================================================================
7 | # Misc
8 | # ================================================================
9 |
10 | def fmt_row(width, row, header=False):
11 | out = " | ".join(fmt_item(x, width) for x in row)
12 | if header: out = out + "\n" + "-"*len(out)
13 | return out
14 |
15 | def fmt_item(x, l):
16 | if isinstance(x, np.ndarray):
17 | assert x.ndim==0
18 | x = x.item()
19 | if isinstance(x, (float, np.float32, np.float64)):
20 | v = abs(x)
21 | if (v < 1e-4 or v > 1e+4) and v > 0:
22 | rep = "%7.2e" % x
23 | else:
24 | rep = "%7.5f" % x
25 | else: rep = str(x)
26 | return " "*(l - len(rep)) + rep
27 |
28 | color2num = dict(
29 | gray=30,
30 | red=31,
31 | green=32,
32 | yellow=33,
33 | blue=34,
34 | magenta=35,
35 | cyan=36,
36 | white=37,
37 | crimson=38
38 | )
39 |
40 | def colorize(string, color, bold=False, highlight=False):
41 | attr = []
42 | num = color2num[color]
43 | if highlight: num += 10
44 | attr.append(str(num))
45 | if bold: attr.append('1')
46 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)
47 |
48 |
49 | MESSAGE_DEPTH = 0
50 |
51 | @contextmanager
52 | def timed(msg):
53 | global MESSAGE_DEPTH #pylint: disable=W0603
54 | print(colorize('\t'*MESSAGE_DEPTH + '=: ' + msg, color='magenta'))
55 | tstart = time.time()
56 | MESSAGE_DEPTH += 1
57 | yield
58 | MESSAGE_DEPTH -= 1
59 | print(colorize('\t'*MESSAGE_DEPTH + "done in %.3f seconds"%(time.time() - tstart), color='magenta'))
60 |
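Two small, hypothetical usage examples for the helpers above: `fmt_row` right-aligns each item in fixed-width columns (with an underline when `header=True`), and `timed` prints the wall-clock duration of the wrapped block.

```python
print(fmt_row(10, ["loss", "kl", "entropy"], header=True))
print(fmt_row(10, [0.0123, 1.2e-6, 42]))

with timed("dummy work"):
    time.sleep(0.1)
```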
--------------------------------------------------------------------------------
/load_log.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import os
4 | import pickle
5 | import exptag
6 |
7 | separator = '------------'
8 |
9 | def parse(key, value):
10 | value = value.strip()
11 | try:
12 | if value.startswith('['):
13 | value = value[1:-1].split(';')
14 | if value != ['']:
15 | try:
16 | value = [int(v) for v in value]
17 | except:
18 | value = [str(v) for v in value]
19 | else:
20 | value = []
21 | elif ';' in value:
22 | value = 0.
23 | elif value in ['nan', '', '-inf']:
24 | value = np.nan
25 | else:
26 | value = eval(value)
27 | except:
28 | import ipdb; ipdb.set_trace()
29 | print(f"failed to parse value {key}:{value.__repr__()}")
30 | value = 0.
31 |
32 | return value
33 |
34 | def get_hash(filename):
35 | import hashlib
36 | hash_md5 = hashlib.md5()
37 | with open(filename, "rb") as f:
38 | for chunk in iter(lambda: f.read(4096), b""):
39 | hash_md5.update(chunk)
40 | return hash_md5.hexdigest()
41 |
42 | def pickle_cache_result(f):
43 | def cached_f(filename):
44 | current_hash = get_hash(filename)
45 | cache_filename = filename + '_cache'
46 | if os.path.exists(cache_filename):
47 | with open(cache_filename, 'rb') as fl:
48 | try:
49 | stored_hash, stored_result = pickle.load(fl)
50 | if stored_hash == current_hash:
51 | # pass
52 | return stored_result
53 | except:
54 | pass
55 | result = f(filename)
56 | with open(cache_filename, 'wb') as fl:
57 | pickle.dump((current_hash, result), fl)
58 | return result
59 | return cached_f
60 |
61 | @pickle_cache_result
62 | def parse_csv(filename):
63 | import csv
64 |
65 | timeseries = {}
66 | keys = []
67 | with open(filename, 'r') as f:
68 | reader = csv.reader(f)
69 | for i, row in enumerate(reader):
70 | if i == 0:
71 | keys = row
72 | else:
73 | for column, key in enumerate(keys):
74 | if key not in timeseries:
75 | timeseries[key] = []
76 | timeseries[key].append(parse(key, row[column]))
77 | print(f'parsing {filename}')
78 | if 'opt_featvar' in timeseries:
79 | timeseries['opt_feat_var'] = timeseries['opt_featvar']
80 | return timeseries
81 |
82 |
83 | def get_filename_from_tag(tag):
84 | folder = exptag.get_last_experiment_folder_by_tag(tag)
85 | return os.path.join(folder, "progress.csv")
86 |
87 | def get_filenames_from_tags(tags):
88 | return [get_filename_from_tag(tag) for tag in tags]
89 |
90 | def get_timeseries_from_filenames(filenames):
91 | return [parse_csv(f) for f in filenames]
92 |
93 |
94 | def get_timeseries_from_tags(tags):
95 | return get_timeseries_from_filenames(get_filenames_from_tags(tags))
96 |
97 |
98 | if __name__ == '__main__':
99 | parser = argparse.ArgumentParser()
100 | parser.add_argument('--tags', type=lambda x: x.split(','), nargs='+', default=None)
101 | parser.add_argument('--x_axis', type=str, choices=['tcount', 'n_updates'], default='tcount')
102 | args = parser.parse_args()
103 |
104 | x_axis = args.x_axis
105 |
106 | timeseries_groups = []
107 | for tag_group in args.tags:
108 | timeseries_groups.append(get_timeseries_from_tags(tags=tag_group))
109 | for tag_group, timeseries in zip(args.tags, timeseries_groups):
110 | rooms = []
111 | for tag, t in zip(tag_group, timeseries):
112 | if 'rooms' in t:
113 | rooms.append((tag, t['rooms'][-1], t['best_ret'][-1], t[x_axis][-1]))
114 | else:
115 | print(f"tag {tag} has no rooms")
116 | rooms = sorted(rooms, key=lambda x: len(x[1]))
117 | all_rooms = set.union(*[set(r[1]) for r in rooms])
118 | import pprint
119 | for tag, r, best_ret, max_x in rooms:
120 | print(f'{tag}:{best_ret}:{r}(@{max_x})')
121 | pprint.pprint(all_rooms)
122 | keys = set.intersection(*[set(t.keys()) for t in sum(timeseries_groups, [])])
123 |
124 | keys = sorted(list(keys))
125 | import matplotlib.pyplot as plt
126 |
127 | n_rows = int(np.ceil(np.sqrt(len(keys))))
128 | n_cols = len(keys) // n_rows + 1
129 | fig, axes = plt.subplots(n_rows, n_cols, sharex=True)
130 | for i in range(n_rows):
131 | for j in range(n_cols):
132 | ind = i * n_cols + j
133 | if ind < len(keys):
134 | key = keys[ind]
135 | for color, timeseries in zip('rgbykcm', timeseries_groups):
136 | if key in ['all_places', 'recent_places', 'global_all_places', 'global_recent_places', 'rooms']:
137 | if any(isinstance(t[key][-1], list) for t in timeseries):
138 | for t in timeseries:
139 | t[key] = list(map(len, t[key]))
140 | max_timesteps = min((len(_[x_axis]) for _ in timeseries))
141 | try:
142 | data = np.asarray([t[key][:max_timesteps] for t in timeseries], dtype=np.float32)
143 | except:
144 | import ipdb; ipdb.set_trace()
145 | lines = [np.nan_to_num(d[key]) for d in timeseries]
146 | lines_x = [np.asarray(d[x_axis]) for d in timeseries]
147 | alphas = [0.2/np.sqrt(len(lines)) for l in lines]
148 | lines += [np.nan_to_num(np.nanmean(data, 0))]
149 | alphas += [1.]
150 | lines_x += [np.asarray(timeseries[0][x_axis][:max_timesteps])]
151 | for alpha, y, x in zip(alphas, lines, lines_x):
152 | axes[i, j].plot(x, y, color=color, alpha=alpha)
153 | axes[i, j].set_title(key)
154 |
155 | plt.show()
156 | plt.close()
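The `__main__` block above is a plotting CLI; a hypothetical invocation is shown below (tags are resolved to experiment folders by the external `exptag` module, which is not part of this listing). Each space-separated argument to `--tags` is one comma-separated group of runs that is averaged into a single curve:

```bash
python load_log.py --tags run_a1,run_a2 run_b1 --x_axis tcount
```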
--------------------------------------------------------------------------------
/monitor.py:
--------------------------------------------------------------------------------
1 | __all__ = ['Monitor', 'get_monitor_files', 'load_results']
2 |
3 | import gym
4 | from gym.core import Wrapper
5 | import time
6 | from glob import glob
7 | import csv
8 | import os.path as osp
9 | import json
10 | import numpy as np
11 |
12 | class Monitor(Wrapper):
13 | EXT = "monitor.csv"
14 | f = None
15 |
16 | def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
17 | Wrapper.__init__(self, env=env)
18 | self.tstart = time.time()
19 | if filename is None:
20 | self.f = None
21 | self.logger = None
22 | else:
23 | if not filename.endswith(Monitor.EXT):
24 | if osp.isdir(filename):
25 | filename = osp.join(filename, Monitor.EXT)
26 | else:
27 | filename = filename + "." + Monitor.EXT
28 | self.f = open(filename, "wt")
29 | self.f.write('#%s\n'%json.dumps({"t_start": self.tstart, 'env_id' : env.spec and env.spec.id}))
30 | self.logger = csv.DictWriter(self.f, fieldnames=('r', 'l', 't')+reset_keywords+info_keywords)
31 | self.logger.writeheader()
32 | self.f.flush()
33 |
34 | self.reset_keywords = reset_keywords
35 | self.info_keywords = info_keywords
36 | self.allow_early_resets = allow_early_resets
37 | self.rewards = None
38 | self.needs_reset = True
39 | self.episode_rewards = []
40 | self.episode_lengths = []
41 | self.episode_times = []
42 | self.total_steps = 0
43 | self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()
44 |
45 | def reset(self, **kwargs):
46 | if not self.allow_early_resets and not self.needs_reset:
47 | raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)")
48 | self.rewards = []
49 | self.needs_reset = False
50 | for k in self.reset_keywords:
51 | v = kwargs.get(k)
52 | if v is None:
53 | raise ValueError('Expected you to pass kwarg %s into reset'%k)
54 | self.current_reset_info[k] = v
55 | return self.env.reset(**kwargs)
56 |
57 | def step(self, action):
58 | if self.needs_reset:
59 | raise RuntimeError("Tried to step environment that needs reset")
60 | ob, rew, done, info = self.env.step(action)
61 | self.rewards.append(rew)
62 | if done:
63 | self.needs_reset = True
64 | eprew = sum(self.rewards)
65 | eplen = len(self.rewards)
66 | epinfo = {"r": round(eprew, 6), "l": eplen, "t": round(time.time() - self.tstart, 6)}
67 | for k in self.info_keywords:
68 | epinfo[k] = info[k]
69 | self.episode_rewards.append(eprew)
70 | self.episode_lengths.append(eplen)
71 | self.episode_times.append(time.time() - self.tstart)
72 | epinfo.update(self.current_reset_info)
73 | if self.logger:
74 | self.logger.writerow(epinfo)
75 | self.f.flush()
76 | if "episode" not in info:
77 | info["episoide"] = {}
78 | info['episode'].update(epinfo)
79 | self.total_steps += 1
80 | return (ob, rew, done, info)
81 |
82 | def close(self):
83 | if self.f is not None:
84 | self.f.close()
85 |
86 | def get_total_steps(self):
87 | return self.total_steps
88 |
89 | def get_episode_rewards(self):
90 | return self.episode_rewards
91 |
92 | def get_episode_lengths(self):
93 | return self.episode_lengths
94 |
95 | def get_episode_times(self):
96 | return self.episode_times
97 |
98 | class LoadMonitorResultsError(Exception):
99 | pass
100 |
101 | def get_monitor_files(dir):
102 | return glob(osp.join(dir, "*" + Monitor.EXT))
103 |
104 | def load_results(dir):
105 | import pandas
106 | monitor_files = (
107 | glob(osp.join(dir, "*monitor.json")) +
108 | glob(osp.join(dir, "*monitor.csv"))) # get both csv and (old) json files
109 | if not monitor_files:
110 | raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, dir))
111 | dfs = []
112 | headers = []
113 | for fname in monitor_files:
114 | with open(fname, 'rt') as fh:
115 | if fname.endswith('csv'):
116 | firstline = fh.readline()
117 | assert firstline[0] == '#'
118 | header = json.loads(firstline[1:])
119 | df = pandas.read_csv(fh, index_col=None)
120 | headers.append(header)
121 | elif fname.endswith('json'): # Deprecated json format
122 | episodes = []
123 | lines = fh.readlines()
124 | header = json.loads(lines[0])
125 | headers.append(header)
126 | for line in lines[1:]:
127 | episode = json.loads(line)
128 | episodes.append(episode)
129 | df = pandas.DataFrame(episodes)
130 | else:
131 | assert 0, 'unreachable'
132 | df['t'] += header['t_start']
133 | dfs.append(df)
134 | df = pandas.concat(dfs)
135 | df.sort_values('t', inplace=True)
136 | df.reset_index(inplace=True)
137 | df['t'] -= min(header['t_start'] for header in headers)
138 | df.headers = headers # HACK to preserve backwards compatibility
139 | return df
140 |
141 | def test_monitor():
142 | import uuid, os, pandas  # used below; missing from the module-level imports
143 | env = gym.make("CartPole-v1"); env.seed(0)
144 | mon_file = "/tmp/baselines-test-%s.monitor.csv" % uuid.uuid4()
145 | menv = Monitor(env, mon_file)
146 | menv.reset()
147 | for _ in range(1000):
148 | _, _, done, _ = menv.step(0)
149 | if done:
150 | menv.reset()
151 |
152 | f = open(mon_file, 'rt')
153 |
154 | firstline = f.readline()
155 | assert firstline.startswith('#')
156 | metadata = json.loads(firstline[1:])
157 | assert metadata['env_id'] == "CartPole-v1"
158 | assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata"
159 |
160 | last_logline = pandas.read_csv(f, index_col=None)
161 | assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline"
162 | f.close()
163 | os.remove(mon_file)
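A brief, hypothetical sketch of how `Monitor` is used in this repo (cf. `make_atari_env` in cmd_util.py): it writes one CSV row per finished episode and mirrors the same stats into `info['episode']`.

```python
env = Monitor(gym.make("CartPole-v1"), "/tmp/demo", allow_early_resets=True)
ob = env.reset()
done = False
while not done:
    ob, rew, done, info = env.step(env.action_space.sample())
print(info["episode"])      # e.g. {'r': 22.0, 'l': 22, 't': 0.01}
env.close()
```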
--------------------------------------------------------------------------------
/mpi_util.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from mpi4py import MPI
3 | import os, numpy as np
4 | import platform
5 | import tensorflow as tf
6 |
7 | def sync_from_root(sess, variables, comm=None):
8 | """
9 | Send the root node's parameters to every worker.
10 | Arguments:
11 | sess: the TensorFlow session.
12 | variables: all parameter variables including optimizer's
13 | """
14 | if comm is None: comm = MPI.COMM_WORLD
15 | rank = comm.Get_rank()
16 | for var in variables:
17 | if rank == 0:
18 | comm.bcast(sess.run(var))
19 | else:
20 | import tensorflow as tf
21 | sess.run(tf.assign(var, comm.bcast(None)))
22 |
23 | # def gpu_count():
24 | # """
25 | # Count the GPUs on this machine.
26 | # """
27 | # if shutil.which('nvidia-smi') is None:
28 | # return 0
29 | # output = subprocess.check_output(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
30 | # return max(0, len(output.split(b'\n')) - 2)
31 | #
32 | # def setup_mpi_gpus():
33 | # """
34 | # Set CUDA_VISIBLE_DEVICES using MPI.
35 | # """
36 | # num_gpus = gpu_count()
37 | # if num_gpus == 0:
38 | # return
39 | # local_rank, _ = get_local_rank_size(MPI.COMM_WORLD)
40 | # os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank % num_gpus)
41 |
42 | def guess_available_gpus(n_gpus=None):
43 | if n_gpus is not None:
44 | return list(range(n_gpus))
45 | if 'CUDA_VISIBLE_DEVICES' in os.environ:
46 | cuda_visible_devices = os.environ['CUDA_VISIBLE_DEVICES']
47 | cuda_visible_devices = cuda_visible_devices.split(',')
48 | return [int(n) for n in cuda_visible_devices]
49 | if 'RCALL_NUM_GPU' in os.environ:
50 | n_gpus = int(os.environ['RCALL_NUM_GPU'])
51 | return list(range(n_gpus))
52 | nvidia_dir = '/proc/driver/nvidia/gpus/'
53 | if os.path.exists(nvidia_dir):
54 | n_gpus = len(os.listdir(nvidia_dir))
55 | return list(range(n_gpus))
56 | raise Exception("Couldn't guess the available gpus on this machine")
57 |
58 |
59 | def setup_mpi_gpus():
60 | """
61 | Set CUDA_VISIBLE_DEVICES using MPI.
62 | """
63 | available_gpus = guess_available_gpus()
64 |
65 | node_id = platform.node()
66 | nodes_ordered_by_rank = MPI.COMM_WORLD.allgather(node_id)
67 | processes_outranked_on_this_node = [n for n in nodes_ordered_by_rank[:MPI.COMM_WORLD.Get_rank()] if n == node_id]
68 | local_rank = len(processes_outranked_on_this_node)
69 |
70 | os.environ['CUDA_VISIBLE_DEVICES'] = str(available_gpus[local_rank])
71 |
72 |
73 | def get_local_rank_size(comm):
74 | """
75 | Returns the rank of each process on its machine
76 | The processes on a given machine will be assigned ranks
77 | 0, 1, 2, ..., N-1,
78 | where N is the number of processes on this machine.
79 |
80 | Useful if you want to assign one gpu per machine
81 | """
82 | this_node = platform.node()
83 | ranks_nodes = comm.allgather((comm.Get_rank(), this_node))
84 | node2rankssofar = defaultdict(int)
85 | local_rank = None
86 | for (rank, node) in ranks_nodes:
87 | if rank == comm.Get_rank():
88 | local_rank = node2rankssofar[node]
89 | node2rankssofar[node] += 1
90 | assert local_rank is not None
91 | return local_rank, node2rankssofar[this_node]
92 |
93 | def share_file(comm, path):
94 | """
95 | Copies the file from rank 0 to all other ranks
96 | Puts it in the same place on all machines
97 | """
98 | localrank, _ = get_local_rank_size(comm)
99 | if comm.Get_rank() == 0:
100 | with open(path, 'rb') as fh:
101 | data = fh.read()
102 | comm.bcast(data)
103 | else:
104 | data = comm.bcast(None)
105 | if localrank == 0:
106 | os.makedirs(os.path.dirname(path), exist_ok=True)
107 | with open(path, 'wb') as fh:
108 | fh.write(data)
109 | comm.Barrier()
110 |
111 | def dict_gather_mean(comm, d):
112 | alldicts = comm.allgather(d)
113 | size = comm.Get_size()
114 | k2li = defaultdict(list)
115 | for d in alldicts:
116 | for (k,v) in d.items():
117 | k2li[k].append(v)
118 | k2mean = {}
119 | for (k,li) in k2li.items():
120 | k2mean[k] = np.mean(li, axis=0) if len(li) == size else np.nan
121 | return k2mean
122 |
123 | class MpiAdamOptimizer(tf.train.AdamOptimizer):
124 | """Adam optimizer that averages gradients across mpi processes."""
125 | def __init__(self, comm, **kwargs):
126 | self.comm = comm
127 | tf.train.AdamOptimizer.__init__(self, **kwargs)
128 | def compute_gradients(self, loss, var_list, **kwargs):
129 | grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs)
130 | grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
131 | flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0)
132 | shapes = [v.shape.as_list() for g, v in grads_and_vars]
133 | sizes = [int(np.prod(s)) for s in shapes]
134 |
135 | num_tasks = self.comm.Get_size()
136 | buf = np.zeros(sum(sizes), np.float32)
137 |
138 | def _collect_grads(flat_grad):
139 | self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
140 | np.divide(buf, float(num_tasks), out=buf)
141 | return buf
142 |
143 | avg_flat_grad = tf.py_func(_collect_grads, [flat_grad], tf.float32)
144 | avg_flat_grad.set_shape(flat_grad.shape)
145 | avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
146 | avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
147 | for g, (_, v) in zip(avg_grads, grads_and_vars)]
148 |
149 | return avg_grads_and_vars
150 |
151 |
152 | def mpi_mean(x, axis=0, comm=None, keepdims=False):
153 | x = np.asarray(x)
154 | assert x.ndim > 0
155 | if comm is None: comm = MPI.COMM_WORLD
156 | xsum = x.sum(axis=axis, keepdims=keepdims)
157 | n = xsum.size
158 | localsum = np.zeros(n+1, x.dtype)
159 | localsum[:n] = xsum.ravel()
160 | localsum[n] = x.shape[axis]
161 | globalsum = np.zeros_like(localsum)
162 | comm.Allreduce(localsum, globalsum, op=MPI.SUM)
163 | return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
164 |
165 | def mpi_moments(x, axis=0, comm=None, keepdims=False):
166 | x = np.asarray(x)
167 | assert x.ndim > 0
168 | mean, count = mpi_mean(x, axis=axis, comm=comm, keepdims=True)
169 | sqdiffs = np.square(x - mean)
170 | meansqdiff, count1 = mpi_mean(sqdiffs, axis=axis, comm=comm, keepdims=True)
171 | assert count1 == count
172 | std = np.sqrt(meansqdiff)
173 | if not keepdims:
174 | newshape = mean.shape[:axis] + mean.shape[axis+1:]
175 | mean = mean.reshape(newshape)
176 | std = std.reshape(newshape)
177 | return mean, std, count
178 |
179 | class RunningMeanStd(object):
180 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
181 | def __init__(self, epsilon=1e-4, shape=(), comm=None, use_mpi=True):
182 | self.mean = np.zeros(shape, 'float64')
183 | self.use_mpi = use_mpi
184 | self.var = np.ones(shape, 'float64')
185 | self.count = epsilon
186 | if comm is None:
187 | from mpi4py import MPI
188 | comm = MPI.COMM_WORLD
189 | self.comm = comm
190 |
191 |
192 | def update(self, x):
193 | if self.use_mpi:
194 | batch_mean, batch_std, batch_count = mpi_moments(x, axis=0, comm=self.comm)
195 | else:
196 | batch_mean, batch_std, batch_count = np.mean(x, axis=0), np.std(x, axis=0), x.shape[0]
197 | batch_var = np.square(batch_std)
198 | self.update_from_moments(batch_mean, batch_var, batch_count)
199 |
200 | def update_from_moments(self, batch_mean, batch_var, batch_count):
201 | delta = batch_mean - self.mean
202 | tot_count = self.count + batch_count
203 |
204 | new_mean = self.mean + delta * batch_count / tot_count
205 | m_a = self.var * (self.count)
206 | m_b = batch_var * (batch_count)
207 | M2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count)
208 | new_var = M2 / (self.count + batch_count)
209 |
210 | new_count = batch_count + self.count
211 |
212 | self.mean = new_mean
213 | self.var = new_var
214 | self.count = new_count
215 |
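A small single-process sketch of `RunningMeanStd` (the same object the policies keep as `ob_rms`); passing `use_mpi=False` sidesteps the MPI reduction so it can run without `mpiexec`. Shapes are illustrative.

```python
rms = RunningMeanStd(shape=(84, 84, 1), use_mpi=False)
batch = np.random.randint(0, 256, size=(64, 84, 84, 1)).astype(np.float64)
rms.update(batch)                                        # parallel-variance update
whitened = np.clip((batch[0] - rms.mean) / np.sqrt(rms.var), -5.0, 5.0)
```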
--------------------------------------------------------------------------------
/policies/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/random-network-distillation/f75c0f1efa473d5109d487062fd8ed49ddce6634/policies/__init__.py
--------------------------------------------------------------------------------
/policies/cnn_gru_policy_dynamics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from baselines import logger
4 | from utils import fc, conv
5 | from stochastic_policy import StochasticPolicy
6 | from tf_util import get_available_gpus
7 | from mpi_util import RunningMeanStd
8 |
9 |
10 | def to2d(x):
11 | size = 1
12 | for shapel in x.get_shape()[1:]: size *= shapel.value
13 | return tf.reshape(x, (-1, size))
14 |
15 |
16 |
17 | class GRUCell(tf.nn.rnn_cell.RNNCell):
18 | """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
19 | def __init__(self, num_units, rec_gate_init=-1.0):
20 | tf.nn.rnn_cell.RNNCell.__init__(self)
21 | self._num_units = num_units
22 | self.rec_gate_init = rec_gate_init
23 | @property
24 | def state_size(self):
25 | return self._num_units
26 | @property
27 | def output_size(self):
28 | return self._num_units
29 | def call(self, inputs, state):
30 | """Gated recurrent unit (GRU) with nunits cells."""
31 | x, new = inputs
32 | h = state
33 | h *= (1.0 - new)
34 | hx = tf.concat([h, x], axis=1)
35 | mr = tf.sigmoid(fc(hx, nh=self._num_units * 2, scope='mr', init_bias=self.rec_gate_init))
36 | # r: read strength. m: 'remember' strength (how much of the previous state is kept)
37 | m, r = tf.split(mr, 2, axis=1)
38 | rh_x = tf.concat([r * h, x], axis=1)
39 | htil = tf.tanh(fc(rh_x, nh=self._num_units, scope='htil'))
40 | h = m * h + (1.0 - m) * htil
41 | return h, h
42 |
43 | class CnnGruPolicy(StochasticPolicy):
44 | def __init__(self, scope, ob_space, ac_space,
45 | policy_size='normal', maxpool=False, extrahid=True, hidsize=128, memsize=128, rec_gate_init=0.0,
46 | update_ob_stats_independently_per_gpu=True,
47 | proportion_of_exp_used_for_predictor_update=1.,
48 | dynamics_bonus = False,
49 | ):
50 | StochasticPolicy.__init__(self, scope, ob_space, ac_space)
51 | self.proportion_of_exp_used_for_predictor_update = proportion_of_exp_used_for_predictor_update
52 | enlargement = {
53 | 'small': 1,
54 | 'normal': 2,
55 | 'large': 4
56 | }[policy_size]
57 | rep_size = 512
58 | self.ph_mean = tf.placeholder(dtype=tf.float32, shape=list(ob_space.shape[:2])+[1], name="obmean")
59 | self.ph_std = tf.placeholder(dtype=tf.float32, shape=list(ob_space.shape[:2])+[1], name="obstd")
60 | memsize *= enlargement
61 | hidsize *= enlargement
62 | convfeat = 16*enlargement
63 | self.ob_rms = RunningMeanStd(shape=list(ob_space.shape[:2])+[1], use_mpi=not update_ob_stats_independently_per_gpu)
64 | ph_istate = tf.placeholder(dtype=tf.float32,shape=(None,memsize), name='state')
65 | pdparamsize = self.pdtype.param_shape()[0]
66 | self.memsize = memsize
67 |
68 | self.pdparam_opt, self.vpred_int_opt, self.vpred_ext_opt, self.snext_opt = \
69 | self.apply_policy(self.ph_ob[None][:,:-1],
70 | ph_new=self.ph_new,
71 | ph_istate=ph_istate,
72 | reuse=False,
73 | scope=scope,
74 | hidsize=hidsize,
75 | memsize=memsize,
76 | extrahid=extrahid,
77 | sy_nenvs=self.sy_nenvs,
78 | sy_nsteps=self.sy_nsteps - 1,
79 | pdparamsize=pdparamsize,
80 | rec_gate_init=rec_gate_init
81 | )
82 | self.pdparam_rollout, self.vpred_int_rollout, self.vpred_ext_rollout, self.snext_rollout = \
83 | self.apply_policy(self.ph_ob[None],
84 | ph_new=self.ph_new,
85 | ph_istate=ph_istate,
86 | reuse=True,
87 | scope=scope,
88 | hidsize=hidsize,
89 | memsize=memsize,
90 | extrahid=extrahid,
91 | sy_nenvs=self.sy_nenvs,
92 | sy_nsteps=self.sy_nsteps,
93 | pdparamsize=pdparamsize,
94 | rec_gate_init=rec_gate_init
95 | )
96 | if dynamics_bonus:
97 | self.define_dynamics_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement)
98 | else:
99 | self.define_self_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement)
100 |
101 |
102 |
103 | pd = self.pdtype.pdfromflat(self.pdparam_rollout)
104 | self.a_samp = pd.sample()
105 | self.nlp_samp = pd.neglogp(self.a_samp)
106 | self.entropy_rollout = pd.entropy()
107 | self.pd_rollout = pd
108 |
109 | self.pd_opt = self.pdtype.pdfromflat(self.pdparam_opt)
110 |
111 | self.ph_istate = ph_istate
112 |
113 | @staticmethod
114 | def apply_policy(ph_ob, ph_new, ph_istate, reuse, scope, hidsize, memsize, extrahid, sy_nenvs, sy_nsteps, pdparamsize, rec_gate_init):
115 | data_format = 'NHWC'
116 | ph = ph_ob
117 | assert len(ph.shape.as_list()) == 5 # B,T,H,W,C
118 | logger.info("CnnGruPolicy: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
119 | X = tf.cast(ph, tf.float32) / 255.
120 | X = tf.reshape(X, (-1, *ph.shape.as_list()[-3:]))
121 |
122 | activ = tf.nn.relu
123 | yes_gpu = any(get_available_gpus())
124 |
125 | with tf.variable_scope(scope, reuse=reuse), tf.device('/gpu:0' if yes_gpu else '/cpu:0'):
126 | X = activ(conv(X, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), data_format=data_format))
127 | X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format))
128 | X = activ(conv(X, 'c3', nf=64, rf=4, stride=1, init_scale=np.sqrt(2), data_format=data_format))
129 | X = to2d(X)
130 | X = activ(fc(X, 'fc1', nh=hidsize, init_scale=np.sqrt(2)))
131 | X = tf.reshape(X, [sy_nenvs, sy_nsteps, hidsize])
132 | X, snext = tf.nn.dynamic_rnn(
133 | GRUCell(memsize, rec_gate_init=rec_gate_init), (X, ph_new[:,:,None]),
134 | dtype=tf.float32, time_major=False, initial_state=ph_istate)
135 | X = tf.reshape(X, (-1, memsize))
136 | Xtout = X
137 | if extrahid:
138 | Xtout = X + activ(fc(Xtout, 'fc2val', nh=memsize, init_scale=0.1))
139 | X = X + activ(fc(X, 'fc2act', nh=memsize, init_scale=0.1))
140 | pdparam = fc(X, 'pd', nh=pdparamsize, init_scale=0.01)
141 | vpred_int = fc(Xtout, 'vf_int', nh=1, init_scale=0.01)
142 | vpred_ext = fc(Xtout, 'vf_ext', nh=1, init_scale=0.01)
143 |
144 | pdparam = tf.reshape(pdparam, (sy_nenvs, sy_nsteps, pdparamsize))
145 | vpred_int = tf.reshape(vpred_int, (sy_nenvs, sy_nsteps))
146 | vpred_ext = tf.reshape(vpred_ext, (sy_nenvs, sy_nsteps))
147 | return pdparam, vpred_int, vpred_ext, snext
148 |
149 | def define_self_prediction_rew(self, convfeat, rep_size, enlargement):
150 | #RND.
151 | # Random target network.
152 | for ph in self.ph_ob.values():
153 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C
154 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
155 | xr = ph[:,1:]
156 | xr = tf.cast(xr, tf.float32)
157 | xr = tf.reshape(xr, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:]
158 | xr = tf.clip_by_value((xr - self.ph_mean) / self.ph_std, -5.0, 5.0)
159 |
160 | xr = tf.nn.leaky_relu(conv(xr, 'c1r', nf=convfeat * 1, rf=8, stride=4, init_scale=np.sqrt(2)))
161 | xr = tf.nn.leaky_relu(conv(xr, 'c2r', nf=convfeat * 2 * 1, rf=4, stride=2, init_scale=np.sqrt(2)))
162 | xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=3, stride=1, init_scale=np.sqrt(2)))
163 | rgbr = [to2d(xr)]
164 | X_r = fc(rgbr[0], 'fc1r', nh=rep_size, init_scale=np.sqrt(2))
165 |
166 | # Predictor network.
167 | for ph in self.ph_ob.values():
168 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C
169 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
170 | xrp = ph[:,1:]
171 | xrp = tf.cast(xrp, tf.float32)
172 | xrp = tf.reshape(xrp, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:]
173 | xrp = tf.clip_by_value((xrp - self.ph_mean) / self.ph_std, -5.0, 5.0)
174 |
175 | xrp = tf.nn.leaky_relu(conv(xrp, 'c1rp_pred', nf=convfeat, rf=8, stride=4, init_scale=np.sqrt(2)))
176 | xrp = tf.nn.leaky_relu(conv(xrp, 'c2rp_pred', nf=convfeat * 2, rf=4, stride=2, init_scale=np.sqrt(2)))
177 | xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=3, stride=1, init_scale=np.sqrt(2)))
178 | rgbrp = to2d(xrp)
179 | X_r_hat = tf.nn.relu(fc(rgbrp, 'fc1r_hat1_pred', nh=256 * enlargement, init_scale=np.sqrt(2)))
180 | X_r_hat = tf.nn.relu(fc(X_r_hat, 'fc1r_hat2_pred', nh=256 * enlargement, init_scale=np.sqrt(2)))
181 | X_r_hat = fc(X_r_hat, 'fc1r_hat3_pred', nh=rep_size, init_scale=np.sqrt(2))
182 |
183 | self.feat_var = tf.reduce_mean(tf.nn.moments(X_r, axes=[0])[1])
184 | self.max_feat = tf.reduce_max(tf.abs(X_r))
185 | self.int_rew = tf.reduce_mean(tf.square(tf.stop_gradient(X_r) - X_r_hat), axis=-1, keep_dims=True)
186 | self.int_rew = tf.reshape(self.int_rew, (self.sy_nenvs, self.sy_nsteps - 1))
187 |
188 | noisy_targets = tf.stop_gradient(X_r)
189 | self.aux_loss = tf.reduce_mean(tf.square(noisy_targets - X_r_hat), -1)
190 | mask = tf.random_uniform(shape=tf.shape(self.aux_loss), minval=0., maxval=1., dtype=tf.float32)
191 | mask = tf.cast(mask < self.proportion_of_exp_used_for_predictor_update, tf.float32)
192 | self.aux_loss = tf.reduce_sum(mask * self.aux_loss) / tf.maximum(tf.reduce_sum(mask), 1.)
193 |
194 | def define_dynamics_prediction_rew(self, convfeat, rep_size, enlargement):
195 | #Dynamics based bonus.
196 |
197 | # Random target network.
198 | for ph in self.ph_ob.values():
199 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C
200 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
201 | xr = ph[:,1:]
202 | xr = tf.cast(xr, tf.float32)
203 | xr = tf.reshape(xr, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:]
204 | xr = tf.clip_by_value((xr - self.ph_mean) / self.ph_std, -5.0, 5.0)
205 |
206 | xr = tf.nn.leaky_relu(conv(xr, 'c1r', nf=convfeat * 1, rf=8, stride=4, init_scale=np.sqrt(2)))
207 | xr = tf.nn.leaky_relu(conv(xr, 'c2r', nf=convfeat * 2 * 1, rf=4, stride=2, init_scale=np.sqrt(2)))
208 | xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=3, stride=1, init_scale=np.sqrt(2)))
209 | rgbr = [to2d(xr)]
210 | X_r = fc(rgbr[0], 'fc1r', nh=rep_size, init_scale=np.sqrt(2))
211 |
212 | # Predictor network.
213 | ac_one_hot = tf.one_hot(self.ph_ac, self.ac_space.n, axis=2)
214 | assert ac_one_hot.get_shape().ndims == 3
215 | assert ac_one_hot.get_shape().as_list() == [None, None, self.ac_space.n], ac_one_hot.get_shape().as_list()
216 | ac_one_hot = tf.reshape(ac_one_hot, (-1, self.ac_space.n))
217 | def cond(x):
218 | return tf.concat([x, ac_one_hot], 1)
219 |
220 | for ph in self.ph_ob.values():
221 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C
222 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
223 | xrp = ph[:,:-1]
224 | xrp = tf.cast(xrp, tf.float32)
225 | xrp = tf.reshape(xrp, (-1, *ph.shape.as_list()[-3:]))
226 | # ph_mean, ph_std are 84x84x1, so we subtract the average of the last channel from all channels. Is this ok?
227 | xrp = tf.clip_by_value((xrp - self.ph_mean) / self.ph_std, -5.0, 5.0)
228 |
229 | xrp = tf.nn.leaky_relu(conv(xrp, 'c1rp_pred', nf=convfeat, rf=8, stride=4, init_scale=np.sqrt(2)))
230 | xrp = tf.nn.leaky_relu(conv(xrp, 'c2rp_pred', nf=convfeat * 2, rf=4, stride=2, init_scale=np.sqrt(2)))
231 | xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=3, stride=1, init_scale=np.sqrt(2)))
232 | rgbrp = to2d(xrp)
233 |
234 | # X_r_hat = tf.nn.relu(fc(rgb[0], 'fc1r_hat1', nh=256 * enlargement, init_scale=np.sqrt(2)))
235 | X_r_hat = tf.nn.relu(fc(cond(rgbrp), 'fc1r_hat1_pred', nh=256 * enlargement, init_scale=np.sqrt(2)))
236 | X_r_hat = tf.nn.relu(fc(cond(X_r_hat), 'fc1r_hat2_pred', nh=256 * enlargement, init_scale=np.sqrt(2)))
237 | X_r_hat = fc(cond(X_r_hat), 'fc1r_hat3_pred', nh=rep_size, init_scale=np.sqrt(2))
238 |
239 | self.feat_var = tf.reduce_mean(tf.nn.moments(X_r, axes=[0])[1])
240 | self.max_feat = tf.reduce_max(tf.abs(X_r))
241 | self.int_rew = tf.reduce_mean(tf.square(tf.stop_gradient(X_r) - X_r_hat), axis=-1, keep_dims=True)
242 | self.int_rew = tf.reshape(self.int_rew, (self.sy_nenvs, self.sy_nsteps - 1))
243 |
244 | noisy_targets = tf.stop_gradient(X_r)
245 | self.aux_loss = tf.reduce_mean(tf.square(noisy_targets - X_r_hat), -1)
246 | mask = tf.random_uniform(shape=tf.shape(self.aux_loss), minval=0., maxval=1., dtype=tf.float32)
247 | mask = tf.cast(mask < self.proportion_of_exp_used_for_predictor_update, tf.float32)
248 | self.aux_loss = tf.reduce_sum(mask * self.aux_loss) / tf.maximum(tf.reduce_sum(mask), 1.)
249 |
250 | def initial_state(self, n):
251 | return np.zeros((n, self.memsize), np.float32)
252 |
253 | def call(self, dict_obs, new, istate, update_obs_stats=False):
254 | for ob in dict_obs.values():
255 | if ob is not None:
256 | if update_obs_stats:
257 | raise NotImplementedError
258 | ob = ob.astype(np.float32)
259 | ob = ob.reshape(-1, *self.ob_space.shape)
260 | self.ob_rms.update(ob)
261 | # Note: if it fails here with a ph vs observations inconsistency, check whether you're loading the agent from disk.
262 | # It will use whatever observation space was saved to disk along with the other ctor params.
263 | feed1 = { self.ph_ob[k]: dict_obs[k][:,None] for k in self.ph_ob_keys }
264 | feed2 = { self.ph_istate: istate, self.ph_new: new[:,None].astype(np.float32) }
265 | feed1.update({self.ph_mean: self.ob_rms.mean, self.ph_std: self.ob_rms.var ** 0.5})
266 | # for f in feed1:
267 | # print(f)
268 | a, vpred_int,vpred_ext, nlp, newstate, ent = tf.get_default_session().run(
269 | [self.a_samp, self.vpred_int_rollout,self.vpred_ext_rollout, self.nlp_samp, self.snext_rollout, self.entropy_rollout],
270 | feed_dict={**feed1, **feed2})
271 | return a[:,0], vpred_int[:,0],vpred_ext[:,0], nlp[:,0], newstate, ent[:,0]
272 |
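For orientation, a hedged sketch of the rollout-time interface this class exposes (used by ppo_agent.py): `policy`, `dict_obs`, `news` and `num_envs` are placeholder names, and a default TensorFlow session with initialized variables is assumed.

```python
istate = policy.initial_state(num_envs)        # zero GRU state, shape (num_envs, memsize)
acs, vpred_int, vpred_ext, nlps, istate, ent = policy.call(
    dict_obs,                                  # {ob_key: array of shape (num_envs, H, W, C)}
    new=news,                                  # array of shape (num_envs,): episode-start flags
    istate=istate)
```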
--------------------------------------------------------------------------------
/policies/cnn_policy_param_matched.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from baselines import logger
4 | from utils import fc, conv, ortho_init
5 | from stochastic_policy import StochasticPolicy
6 | from tf_util import get_available_gpus
7 | from mpi_util import RunningMeanStd
8 |
9 |
10 | def to2d(x):
11 | size = 1
12 | for shapel in x.get_shape()[1:]: size *= shapel.value
13 | return tf.reshape(x, (-1, size))
14 |
15 | def _fcnobias(x, scope, nh, *, init_scale=1.0):
16 | with tf.variable_scope(scope):
17 | nin = x.get_shape()[1].value
18 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
19 | return tf.matmul(x, w)
20 | def _normalize(x):
21 | eps = 1e-5
22 | mean, var = tf.nn.moments(x, axes=(-1,), keepdims=True)
23 | return (x - mean) / tf.sqrt(var + eps)
24 |
25 |
26 | class CnnPolicy(StochasticPolicy):
27 | def __init__(self, scope, ob_space, ac_space,
28 | policy_size='normal', maxpool=False, extrahid=True, hidsize=128, memsize=128, rec_gate_init=0.0,
29 | update_ob_stats_independently_per_gpu=True,
30 | proportion_of_exp_used_for_predictor_update=1.,
31 | dynamics_bonus = False,
32 | ):
33 | StochasticPolicy.__init__(self, scope, ob_space, ac_space)
34 | self.proportion_of_exp_used_for_predictor_update = proportion_of_exp_used_for_predictor_update
35 | enlargement = {
36 | 'small': 1,
37 | 'normal': 2,
38 | 'large': 4
39 | }[policy_size]
40 | rep_size = 512
41 | self.ph_mean = tf.placeholder(dtype=tf.float32, shape=list(ob_space.shape[:2])+[1], name="obmean")
42 | self.ph_std = tf.placeholder(dtype=tf.float32, shape=list(ob_space.shape[:2])+[1], name="obstd")
43 | memsize *= enlargement
44 | hidsize *= enlargement
45 | convfeat = 16*enlargement
46 | self.ob_rms = RunningMeanStd(shape=list(ob_space.shape[:2])+[1], use_mpi=not update_ob_stats_independently_per_gpu)
47 | ph_istate = tf.placeholder(dtype=tf.float32,shape=(None,memsize), name='state')
48 | pdparamsize = self.pdtype.param_shape()[0]
49 | self.memsize = memsize
50 |
51 | #Inputs to policy and value function will have different shapes depending on whether it is rollout
52 | #or optimization time, so we treat the two cases separately.
53 | self.pdparam_opt, self.vpred_int_opt, self.vpred_ext_opt, self.snext_opt = \
54 | self.apply_policy(self.ph_ob[None][:,:-1],
55 | reuse=False,
56 | scope=scope,
57 | hidsize=hidsize,
58 | memsize=memsize,
59 | extrahid=extrahid,
60 | sy_nenvs=self.sy_nenvs,
61 | sy_nsteps=self.sy_nsteps - 1,
62 | pdparamsize=pdparamsize
63 | )
64 | self.pdparam_rollout, self.vpred_int_rollout, self.vpred_ext_rollout, self.snext_rollout = \
65 | self.apply_policy(self.ph_ob[None],
66 | reuse=True,
67 | scope=scope,
68 | hidsize=hidsize,
69 | memsize=memsize,
70 | extrahid=extrahid,
71 | sy_nenvs=self.sy_nenvs,
72 | sy_nsteps=self.sy_nsteps,
73 | pdparamsize=pdparamsize
74 | )
75 | if dynamics_bonus:
76 | self.define_dynamics_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement)
77 | else:
78 | self.define_self_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement)
79 |
80 | pd = self.pdtype.pdfromflat(self.pdparam_rollout)
81 | self.a_samp = pd.sample()
82 | self.nlp_samp = pd.neglogp(self.a_samp)
83 | self.entropy_rollout = pd.entropy()
84 | self.pd_rollout = pd
85 |
86 | self.pd_opt = self.pdtype.pdfromflat(self.pdparam_opt)
87 |
88 | self.ph_istate = ph_istate
89 |
90 | @staticmethod
91 | def apply_policy(ph_ob, reuse, scope, hidsize, memsize, extrahid, sy_nenvs, sy_nsteps, pdparamsize):
92 | data_format = 'NHWC'
93 | ph = ph_ob
94 | assert len(ph.shape.as_list()) == 5 # B,T,H,W,C
95 | logger.info("CnnPolicy: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
96 | X = tf.cast(ph, tf.float32) / 255.
97 | X = tf.reshape(X, (-1, *ph.shape.as_list()[-3:]))
98 |
99 | activ = tf.nn.relu
100 | yes_gpu = any(get_available_gpus())
101 | with tf.variable_scope(scope, reuse=reuse), tf.device('/gpu:0' if yes_gpu else '/cpu:0'):
102 | X = activ(conv(X, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), data_format=data_format))
103 | X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format))
104 | X = activ(conv(X, 'c3', nf=64, rf=4, stride=1, init_scale=np.sqrt(2), data_format=data_format))
105 | X = to2d(X)
106 | mix_other_observations = [X]
107 | X = tf.concat(mix_other_observations, axis=1)
108 | X = activ(fc(X, 'fc1', nh=hidsize, init_scale=np.sqrt(2)))
109 | additional_size = 448
110 | X = activ(fc(X, 'fc_additional', nh=additional_size, init_scale=np.sqrt(2)))
111 | snext = tf.zeros((sy_nenvs, memsize))
112 | mix_timeout = [X]
113 |
114 | Xtout = tf.concat(mix_timeout, axis=1)
115 | if extrahid:
116 | Xtout = X + activ(fc(Xtout, 'fc2val', nh=additional_size, init_scale=0.1))
117 | X = X + activ(fc(X, 'fc2act', nh=additional_size, init_scale=0.1))
118 | pdparam = fc(X, 'pd', nh=pdparamsize, init_scale=0.01)
119 | vpred_int = fc(Xtout, 'vf_int', nh=1, init_scale=0.01)
120 | vpred_ext = fc(Xtout, 'vf_ext', nh=1, init_scale=0.01)
121 |
122 | pdparam = tf.reshape(pdparam, (sy_nenvs, sy_nsteps, pdparamsize))
123 | vpred_int = tf.reshape(vpred_int, (sy_nenvs, sy_nsteps))
124 | vpred_ext = tf.reshape(vpred_ext, (sy_nenvs, sy_nsteps))
125 | return pdparam, vpred_int, vpred_ext, snext
126 |
127 | def define_self_prediction_rew(self, convfeat, rep_size, enlargement):
128 | logger.info("Using RND BONUS ****************************************************")
129 |
130 | #RND bonus.
131 |
132 | # Random target network.
133 | for ph in self.ph_ob.values():
134 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C
135 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
136 | xr = ph[:,1:]
137 | xr = tf.cast(xr, tf.float32)
138 | xr = tf.reshape(xr, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:]
139 | xr = tf.clip_by_value((xr - self.ph_mean) / self.ph_std, -5.0, 5.0)
140 |
141 | xr = tf.nn.leaky_relu(conv(xr, 'c1r', nf=convfeat * 1, rf=8, stride=4, init_scale=np.sqrt(2)))
142 | xr = tf.nn.leaky_relu(conv(xr, 'c2r', nf=convfeat * 2 * 1, rf=4, stride=2, init_scale=np.sqrt(2)))
143 | xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=3, stride=1, init_scale=np.sqrt(2)))
144 | rgbr = [to2d(xr)]
145 | X_r = fc(rgbr[0], 'fc1r', nh=rep_size, init_scale=np.sqrt(2))
146 |
147 | # Predictor network.
148 | for ph in self.ph_ob.values():
149 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C
150 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
151 | xrp = ph[:,1:]
152 | xrp = tf.cast(xrp, tf.float32)
153 | xrp = tf.reshape(xrp, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:]
154 | xrp = tf.clip_by_value((xrp - self.ph_mean) / self.ph_std, -5.0, 5.0)
155 |
156 | xrp = tf.nn.leaky_relu(conv(xrp, 'c1rp_pred', nf=convfeat, rf=8, stride=4, init_scale=np.sqrt(2)))
157 | xrp = tf.nn.leaky_relu(conv(xrp, 'c2rp_pred', nf=convfeat * 2, rf=4, stride=2, init_scale=np.sqrt(2)))
158 | xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=3, stride=1, init_scale=np.sqrt(2)))
159 | rgbrp = to2d(xrp)
160 | # X_r_hat = tf.nn.relu(fc(rgb[0], 'fc1r_hat1', nh=256 * enlargement, init_scale=np.sqrt(2)))
161 | X_r_hat = tf.nn.relu(fc(rgbrp, 'fc1r_hat1_pred', nh=256 * enlargement, init_scale=np.sqrt(2)))
162 | X_r_hat = tf.nn.relu(fc(X_r_hat, 'fc1r_hat2_pred', nh=256 * enlargement, init_scale=np.sqrt(2)))
163 | X_r_hat = fc(X_r_hat, 'fc1r_hat3_pred', nh=rep_size, init_scale=np.sqrt(2))
164 |
165 | self.feat_var = tf.reduce_mean(tf.nn.moments(X_r, axes=[0])[1])
166 | self.max_feat = tf.reduce_max(tf.abs(X_r))
167 | self.int_rew = tf.reduce_mean(tf.square(tf.stop_gradient(X_r) - X_r_hat), axis=-1, keep_dims=True)
168 | self.int_rew = tf.reshape(self.int_rew, (self.sy_nenvs, self.sy_nsteps - 1))
169 |
170 | targets = tf.stop_gradient(X_r)
171 | # self.aux_loss = tf.reduce_mean(tf.square(noisy_targets-X_r_hat))
172 | self.aux_loss = tf.reduce_mean(tf.square(targets - X_r_hat), -1)
173 | mask = tf.random_uniform(shape=tf.shape(self.aux_loss), minval=0., maxval=1., dtype=tf.float32)
174 | mask = tf.cast(mask < self.proportion_of_exp_used_for_predictor_update, tf.float32)
175 | self.aux_loss = tf.reduce_sum(mask * self.aux_loss) / tf.maximum(tf.reduce_sum(mask), 1.)
176 |
177 | def define_dynamics_prediction_rew(self, convfeat, rep_size, enlargement):
178 | #Dynamics loss with random features.
179 |
180 | # Random target network.
181 | for ph in self.ph_ob.values():
182 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C
183 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
184 | xr = ph[:,1:]
185 | xr = tf.cast(xr, tf.float32)
186 | xr = tf.reshape(xr, (-1, *ph.shape.as_list()[-3:]))[:, :, :, -1:]
187 | xr = tf.clip_by_value((xr - self.ph_mean) / self.ph_std, -5.0, 5.0)
188 |
189 | xr = tf.nn.leaky_relu(conv(xr, 'c1r', nf=convfeat * 1, rf=8, stride=4, init_scale=np.sqrt(2)))
190 | xr = tf.nn.leaky_relu(conv(xr, 'c2r', nf=convfeat * 2 * 1, rf=4, stride=2, init_scale=np.sqrt(2)))
191 | xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=3, stride=1, init_scale=np.sqrt(2)))
192 | rgbr = [to2d(xr)]
193 | X_r = fc(rgbr[0], 'fc1r', nh=rep_size, init_scale=np.sqrt(2))
194 |
195 | # Predictor network.
196 | ac_one_hot = tf.one_hot(self.ph_ac, self.ac_space.n, axis=2)
197 | assert ac_one_hot.get_shape().ndims == 3
198 | assert ac_one_hot.get_shape().as_list() == [None, None, self.ac_space.n], ac_one_hot.get_shape().as_list()
199 | ac_one_hot = tf.reshape(ac_one_hot, (-1, self.ac_space.n))
200 | def cond(x):
201 | return tf.concat([x, ac_one_hot], 1)
202 |
203 | for ph in self.ph_ob.values():
204 | if len(ph.shape.as_list()) == 5: # B,T,H,W,C
205 | logger.info("CnnTarget: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
206 | xrp = ph[:,:-1]
207 | xrp = tf.cast(xrp, tf.float32)
208 | xrp = tf.reshape(xrp, (-1, *ph.shape.as_list()[-3:]))
209 | # ph_mean, ph_std are 84x84x1, so we subtract the average of the last channel from all channels. Is this ok?
210 | xrp = tf.clip_by_value((xrp - self.ph_mean) / self.ph_std, -5.0, 5.0)
211 |
212 | xrp = tf.nn.leaky_relu(conv(xrp, 'c1rp_pred', nf=convfeat, rf=8, stride=4, init_scale=np.sqrt(2)))
213 | xrp = tf.nn.leaky_relu(conv(xrp, 'c2rp_pred', nf=convfeat * 2, rf=4, stride=2, init_scale=np.sqrt(2)))
214 | xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=3, stride=1, init_scale=np.sqrt(2)))
215 | rgbrp = to2d(xrp)
216 |
217 | # X_r_hat = tf.nn.relu(fc(rgb[0], 'fc1r_hat1', nh=256 * enlargement, init_scale=np.sqrt(2)))
218 | X_r_hat = tf.nn.relu(fc(cond(rgbrp), 'fc1r_hat1_pred', nh=256 * enlargement, init_scale=np.sqrt(2)))
219 | X_r_hat = tf.nn.relu(fc(cond(X_r_hat), 'fc1r_hat2_pred', nh=256 * enlargement, init_scale=np.sqrt(2)))
220 | X_r_hat = fc(cond(X_r_hat), 'fc1r_hat3_pred', nh=rep_size, init_scale=np.sqrt(2))
221 |
222 | self.feat_var = tf.reduce_mean(tf.nn.moments(X_r, axes=[0])[1])
223 | self.max_feat = tf.reduce_max(tf.abs(X_r))
224 | self.int_rew = tf.reduce_mean(tf.square(tf.stop_gradient(X_r) - X_r_hat), axis=-1, keep_dims=True)
225 | self.int_rew = tf.reshape(self.int_rew, (self.sy_nenvs, self.sy_nsteps - 1))
226 |
227 | noisy_targets = tf.stop_gradient(X_r)
228 | # self.aux_loss = tf.reduce_mean(tf.square(noisy_targets-X_r_hat))
229 | self.aux_loss = tf.reduce_mean(tf.square(noisy_targets - X_r_hat), -1)
230 | mask = tf.random_uniform(shape=tf.shape(self.aux_loss), minval=0., maxval=1., dtype=tf.float32)
231 | mask = tf.cast(mask < self.proportion_of_exp_used_for_predictor_update, tf.float32)
232 | self.aux_loss = tf.reduce_sum(mask * self.aux_loss) / tf.maximum(tf.reduce_sum(mask), 1.)
233 |
234 | def initial_state(self, n):
235 | return np.zeros((n, self.memsize), np.float32)
236 |
237 | def call(self, dict_obs, new, istate, update_obs_stats=False):
238 | for ob in dict_obs.values():
239 | if ob is not None:
240 | if update_obs_stats:
241 | raise NotImplementedError
242 | ob = ob.astype(np.float32)
243 | ob = ob.reshape(-1, *self.ob_space.shape)
244 | self.ob_rms.update(ob)
245 | # Note: if this fails with a ph vs observations inconsistency, check whether you're loading an agent from disk.
246 | # It will use whatever observation space was saved to disk along with the other constructor params.
247 | feed1 = { self.ph_ob[k]: dict_obs[k][:,None] for k in self.ph_ob_keys }
248 | feed2 = { self.ph_istate: istate, self.ph_new: new[:,None].astype(np.float32) }
249 | feed1.update({self.ph_mean: self.ob_rms.mean, self.ph_std: self.ob_rms.var ** 0.5})
250 | # for f in feed1:
251 | # print(f)
252 | a, vpred_int,vpred_ext, nlp, newstate, ent = tf.get_default_session().run(
253 | [self.a_samp, self.vpred_int_rollout,self.vpred_ext_rollout, self.nlp_samp, self.snext_rollout, self.entropy_rollout],
254 | feed_dict={**feed1, **feed2})
255 | return a[:,0], vpred_int[:,0],vpred_ext[:,0], nlp[:,0], newstate, ent[:,0]
256 |
--------------------------------------------------------------------------------
/ppo_agent.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import deque, defaultdict
3 | from copy import copy
4 |
5 | import numpy as np
6 | import psutil
7 | import tensorflow as tf
8 | from mpi4py import MPI
9 | from baselines import logger
10 | import tf_util
11 | from recorder import Recorder
12 | from utils import explained_variance
13 | from console_util import fmt_row
14 | from mpi_util import MpiAdamOptimizer, RunningMeanStd, sync_from_root
15 |
16 | NO_STATES = ['NO_STATES']
17 |
18 | class SemicolonList(list):
19 | def __str__(self):
20 | return '['+';'.join([str(x) for x in self])+']'
21 |
22 | class InteractionState(object):
23 | """
24 | Parts of the PPOAgent's state that are based on interaction with a single batch of envs
25 | """
26 | def __init__(self, ob_space, ac_space, nsteps, gamma, venvs, stochpol, comm):
27 | self.lump_stride = venvs[0].num_envs
28 | self.venvs = venvs
29 | assert all(venv.num_envs == self.lump_stride for venv in self.venvs[1:]), 'All venvs should have the same num_envs'
30 | self.nlump = len(venvs)
31 | nenvs = self.nenvs = self.nlump * self.lump_stride
32 | self.reset_counter = 0
33 | self.env_results = [None] * self.nlump
34 | self.buf_vpreds_int = np.zeros((nenvs, nsteps), np.float32)
35 | self.buf_vpreds_ext = np.zeros((nenvs, nsteps), np.float32)
36 | self.buf_nlps = np.zeros((nenvs, nsteps), np.float32)
37 | self.buf_advs = np.zeros((nenvs, nsteps), np.float32)
38 | self.buf_advs_int = np.zeros((nenvs, nsteps), np.float32)
39 | self.buf_advs_ext = np.zeros((nenvs, nsteps), np.float32)
40 | self.buf_rews_int = np.zeros((nenvs, nsteps), np.float32)
41 | self.buf_rews_ext = np.zeros((nenvs, nsteps), np.float32)
42 | self.buf_acs = np.zeros((nenvs, nsteps, *ac_space.shape), ac_space.dtype)
43 | self.buf_obs = { k: np.zeros(
44 | [nenvs, nsteps] + stochpol.ph_ob[k].shape.as_list()[2:],
45 | dtype=stochpol.ph_ob_dtypes[k])
46 | for k in stochpol.ph_ob_keys }
47 | self.buf_ob_last = { k: self.buf_obs[k][:, 0, ...].copy() for k in stochpol.ph_ob_keys }
48 | self.buf_epinfos = [{} for _ in range(self.nenvs)]
49 | self.buf_news = np.zeros((nenvs, nsteps), np.float32)
50 | self.buf_ent = np.zeros((nenvs, nsteps), np.float32)
51 | self.mem_state = stochpol.initial_state(nenvs)
52 | self.seg_init_mem_state = copy(self.mem_state) # Memory state at beginning of segment of timesteps
53 | self.rff_int = RewardForwardFilter(gamma)
54 | self.rff_rms_int = RunningMeanStd(comm=comm, use_mpi=True)
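    | # rff_int keeps a discounted running sum of intrinsic rewards; rff_rms_int tracks the running
    | # mean/std of those sums and is used in update() to rescale the intrinsic rewards.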
55 | self.buf_new_last = self.buf_news[:, 0, ...].copy()
56 | self.buf_vpred_int_last = self.buf_vpreds_int[:, 0, ...].copy()
57 | self.buf_vpred_ext_last = self.buf_vpreds_ext[:, 0, ...].copy()
58 | self.step_count = 0 # counts number of timesteps that you've interacted with this set of environments
59 | self.t_last_update = time.time()
60 | self.statlists = defaultdict(lambda : deque([], maxlen=100)) # Count other stats, e.g. optimizer outputs
61 | self.stats = defaultdict(float) # Count episodes and timesteps
62 | self.stats['epcount'] = 0
63 | self.stats['n_updates'] = 0
64 | self.stats['tcount'] = 0
65 |
66 | def close(self):
67 | for venv in self.venvs:
68 | venv.close()
69 |
70 | def dict_gather(comm, d, op='mean'):
71 | if comm is None: return d
72 | alldicts = comm.allgather(d)
73 | size = comm.Get_size()
74 | k2li = defaultdict(list)
75 | for d in alldicts:
76 | for (k,v) in d.items():
77 | k2li[k].append(v)
78 | result = {}
79 | for (k,li) in k2li.items():
80 | if op=='mean':
81 | result[k] = np.mean(li, axis=0)
82 | elif op=='sum':
83 | result[k] = np.sum(li, axis=0)
84 | elif op=="max":
85 | result[k] = np.max(li, axis=0)
86 | else:
87 | assert 0, op
88 | return result
89 |
90 |
91 | class PpoAgent(object):
92 | envs = None
93 | def __init__(self, *, scope,
94 | ob_space, ac_space,
95 | stochpol_fn,
96 | nsteps, nepochs=4, nminibatches=1,
97 | gamma=0.99,
98 | gamma_ext=0.99,
99 | lam=0.95,
100 | ent_coef=0,
101 | cliprange=0.2,
102 | max_grad_norm=1.0,
103 | vf_coef=1.0,
104 | lr=30e-5,
105 | adam_hps=None,
106 | testing=False,
107 | comm=None, comm_train=None, use_news=False,
108 | update_ob_stats_every_step=True,
109 | int_coeff=None,
110 | ext_coeff=None,
111 | ):
112 | self.lr = lr
113 | self.ext_coeff = ext_coeff
114 | self.int_coeff = int_coeff
115 | self.use_news = use_news
116 | self.update_ob_stats_every_step = update_ob_stats_every_step
117 | self.abs_scope = (tf.get_variable_scope().name + '/' + scope).lstrip('/')
118 | self.testing = testing
119 | self.comm_log = MPI.COMM_SELF
120 | if comm is not None and comm.Get_size() > 1:
121 | self.comm_log = comm
122 | assert not testing or comm.Get_rank() != 0, "Worker number zero can't be testing"
123 | if comm_train is not None:
124 | self.comm_train, self.comm_train_size = comm_train, comm_train.Get_size()
125 | else:
126 | self.comm_train, self.comm_train_size = self.comm_log, self.comm_log.Get_size()
127 | self.is_log_leader = self.comm_log.Get_rank()==0
128 | self.is_train_leader = self.comm_train.Get_rank()==0
129 | with tf.variable_scope(scope):
130 | self.best_ret = -np.inf
131 | self.local_best_ret = - np.inf
132 | self.rooms = []
133 | self.local_rooms = []
134 | self.scores = []
135 | self.ob_space = ob_space
136 | self.ac_space = ac_space
137 | self.stochpol = stochpol_fn()
138 | self.nepochs = nepochs
139 | self.cliprange = cliprange
140 | self.nsteps = nsteps
141 | self.nminibatches = nminibatches
142 | self.gamma = gamma
143 | self.gamma_ext = gamma_ext
144 | self.lam = lam
145 | self.adam_hps = adam_hps or dict()
146 | self.ph_adv = tf.placeholder(tf.float32, [None, None])
147 | self.ph_ret_int = tf.placeholder(tf.float32, [None, None])
148 | self.ph_ret_ext = tf.placeholder(tf.float32, [None, None])
149 | self.ph_oldnlp = tf.placeholder(tf.float32, [None, None])
150 | self.ph_oldvpred = tf.placeholder(tf.float32, [None, None])
151 | self.ph_lr = tf.placeholder(tf.float32, [])
152 | self.ph_lr_pred = tf.placeholder(tf.float32, [])
153 | self.ph_cliprange = tf.placeholder(tf.float32, [])
154 |
155 | #Define loss.
156 | neglogpac = self.stochpol.pd_opt.neglogp(self.stochpol.ph_ac)
157 | entropy = tf.reduce_mean(self.stochpol.pd_opt.entropy())
158 | vf_loss_int = (0.5 * vf_coef) * tf.reduce_mean(tf.square(self.stochpol.vpred_int_opt - self.ph_ret_int))
159 | vf_loss_ext = (0.5 * vf_coef) * tf.reduce_mean(tf.square(self.stochpol.vpred_ext_opt - self.ph_ret_ext))
160 | vf_loss = vf_loss_int + vf_loss_ext
161 | ratio = tf.exp(self.ph_oldnlp - neglogpac) # p_new / p_old
162 | negadv = - self.ph_adv
163 | pg_losses1 = negadv * ratio
164 | pg_losses2 = negadv * tf.clip_by_value(ratio, 1.0 - self.ph_cliprange, 1.0 + self.ph_cliprange)
165 | pg_loss = tf.reduce_mean(tf.maximum(pg_losses1, pg_losses2))
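    | # Standard PPO clipped surrogate, written over negated advantages: minimizing
    | # max(-A*ratio, -A*clip(ratio, 1-eps, 1+eps)) equals maximizing min(A*ratio, A*clip(ratio, 1-eps, 1+eps)).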
166 | ent_loss = (- ent_coef) * entropy
167 | approxkl = .5 * tf.reduce_mean(tf.square(neglogpac - self.ph_oldnlp))
168 | maxkl = .5 * tf.reduce_max(tf.square(neglogpac - self.ph_oldnlp))
169 | clipfrac = tf.reduce_mean(tf.to_float(tf.greater(tf.abs(ratio - 1.0), self.ph_cliprange)))
170 | loss = pg_loss + ent_loss + vf_loss + self.stochpol.aux_loss
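    | # Single joint objective: policy surrogate, entropy regularizer, both value heads, and the
    | # RND predictor (aux) loss share one optimizer step.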
171 |
172 | #Create optimizer.
173 | params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.abs_scope)
174 | logger.info("PPO: using MpiAdamOptimizer connected to %i peers" % self.comm_train_size)
175 | trainer = MpiAdamOptimizer(self.comm_train, learning_rate=self.ph_lr, **self.adam_hps)
176 | grads_and_vars = trainer.compute_gradients(loss, params)
177 | grads, vars = zip(*grads_and_vars)
178 | global_grad_norm = tf.global_norm(grads)
179 | if max_grad_norm:
180 | grads, _ = tf.clip_by_global_norm(grads, max_grad_norm, use_norm=global_grad_norm)
181 | grads_and_vars = list(zip(grads, vars))
182 | self._train = trainer.apply_gradients(grads_and_vars)
183 |
184 | #Quantities for reporting.
185 | self._losses = [loss, pg_loss, vf_loss, entropy, clipfrac, approxkl, maxkl, self.stochpol.aux_loss,
186 | self.stochpol.feat_var, self.stochpol.max_feat, global_grad_norm]
187 | self.loss_names = ['tot', 'pg', 'vf', 'ent', 'clipfrac', 'approxkl', 'maxkl', "auxloss", "featvar",
188 | "maxfeat", "gradnorm"]
189 | self.I = None
190 | self.disable_policy_update = None
191 | allvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.abs_scope)
192 | if self.is_log_leader:
193 | tf_util.display_var_info(allvars)
194 | tf.get_default_session().run(tf.variables_initializer(allvars))
195 | sync_from_root(tf.get_default_session(), allvars) #Syncs initialization across mpi workers.
196 | self.t0 = time.time()
197 | self.global_tcount = 0
198 |
199 | def start_interaction(self, venvs, disable_policy_update=False):
200 | self.I = InteractionState(ob_space=self.ob_space, ac_space=self.ac_space,
201 | nsteps=self.nsteps, gamma=self.gamma,
202 | venvs=venvs, stochpol=self.stochpol, comm=self.comm_train)
203 | self.disable_policy_update = disable_policy_update
204 | self.recorder = Recorder(nenvs=self.I.nenvs, score_multiple=venvs[0].score_multiple)
205 |
206 | def collect_random_statistics(self, num_timesteps):
207 | #Initializes observation normalization with data from random agent.
208 | all_ob = []
209 | for lump in range(self.I.nlump):
210 | all_ob.append(self.I.venvs[lump].reset())
211 | for step in range(num_timesteps):
212 | for lump in range(self.I.nlump):
213 | acs = np.random.randint(low=0, high=self.ac_space.n, size=(self.I.lump_stride,))
214 | self.I.venvs[lump].step_async(acs)
215 | ob, _, _, _ = self.I.venvs[lump].step_wait()
216 | all_ob.append(ob)
217 | if len(all_ob) % (128 * self.I.nlump) == 0:
218 | ob_ = np.asarray(all_ob).astype(np.float32).reshape((-1, *self.ob_space.shape))
219 | self.stochpol.ob_rms.update(ob_[:,:,:,-1:])
220 | all_ob.clear()
221 |
222 | def stop_interaction(self):
223 | self.I.close()
224 | self.I = None
225 |
226 | @logger.profile("update")
227 | def update(self):
228 |
229 | #Some logic gathering best ret, rooms etc using MPI.
230 | temp = sum(MPI.COMM_WORLD.allgather(self.local_rooms), [])
231 | temp = sorted(list(set(temp)))
232 | self.rooms = temp
233 |
234 | temp = sum(MPI.COMM_WORLD.allgather(self.scores), [])
235 | temp = sorted(list(set(temp)))
236 | self.scores = temp
237 |
238 | temp = sum(MPI.COMM_WORLD.allgather([self.local_best_ret]), [])
239 | self.best_ret = max(temp)
240 |
241 | eprews = MPI.COMM_WORLD.allgather(np.mean(list(self.I.statlists["eprew"])))
242 | local_best_rets = MPI.COMM_WORLD.allgather(self.local_best_ret)
243 | n_rooms = sum(MPI.COMM_WORLD.allgather([len(self.local_rooms)]), [])
244 |
245 | if MPI.COMM_WORLD.Get_rank() == 0:
246 | logger.info(f"Rooms visited {self.rooms}")
247 | logger.info(f"Best return {self.best_ret}")
248 | logger.info(f"Best local return {sorted(local_best_rets)}")
249 | logger.info(f"eprews {sorted(eprews)}")
250 | logger.info(f"n_rooms {sorted(n_rooms)}")
251 | logger.info(f"Extrinsic coefficient {self.ext_coeff}")
252 | logger.info(f"Gamma {self.gamma}")
253 | logger.info(f"Gamma ext {self.gamma_ext}")
254 | logger.info(f"All scores {sorted(self.scores)}")
255 |
256 |
257 | #Normalize intrinsic rewards.
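    | # Intrinsic rewards are divided by the running std of their discounted sums so the scale of the
    | # exploration bonus stays roughly constant over training.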
258 | rffs_int = np.array([self.I.rff_int.update(rew) for rew in self.I.buf_rews_int.T])
259 | self.I.rff_rms_int.update(rffs_int.ravel())
260 | rews_int = self.I.buf_rews_int / np.sqrt(self.I.rff_rms_int.var)
261 | self.mean_int_rew = np.mean(rews_int)
262 | self.max_int_rew = np.max(rews_int)
263 |
264 | #Don't normalize extrinsic rewards.
265 | rews_ext = self.I.buf_rews_ext
266 |
267 | rewmean, rewstd, rewmax = self.I.buf_rews_int.mean(), self.I.buf_rews_int.std(), np.max(self.I.buf_rews_int)
268 |
269 | #Calculate intrinsic returns and advantages.
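    | # GAE(lambda), computed backwards over the rollout:
    | #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
    | #   adv_t   = delta_t + gamma * lam * (1 - done_{t+1}) * adv_{t+1}
    | # With use_news=False the intrinsic stream ignores episode boundaries (non-episodic bonus).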
270 | lastgaelam = 0
271 | for t in range(self.nsteps-1, -1, -1): # nsteps-1 ... 0
272 | if self.use_news:
273 | nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last
274 | else:
275 | nextnew = 0.0 #No dones for intrinsic reward.
276 | nextvals = self.I.buf_vpreds_int[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_int_last
277 | nextnotnew = 1 - nextnew
278 | delta = rews_int[:, t] + self.gamma * nextvals * nextnotnew - self.I.buf_vpreds_int[:, t]
279 | self.I.buf_advs_int[:, t] = lastgaelam = delta + self.gamma * self.lam * nextnotnew * lastgaelam
280 | rets_int = self.I.buf_advs_int + self.I.buf_vpreds_int
281 |
282 | #Calculate extrinsic returns and advantages.
283 | lastgaelam = 0
284 | for t in range(self.nsteps-1, -1, -1): # nsteps-1 ... 0
285 | nextnew = self.I.buf_news[:, t + 1] if t + 1 < self.nsteps else self.I.buf_new_last
286 | #Use dones for extrinsic reward.
287 | nextvals = self.I.buf_vpreds_ext[:, t + 1] if t + 1 < self.nsteps else self.I.buf_vpred_ext_last
288 | nextnotnew = 1 - nextnew
289 | delta = rews_ext[:, t] + self.gamma_ext * nextvals * nextnotnew - self.I.buf_vpreds_ext[:, t]
290 | self.I.buf_advs_ext[:, t] = lastgaelam = delta + self.gamma_ext * self.lam * nextnotnew * lastgaelam
291 | rets_ext = self.I.buf_advs_ext + self.I.buf_vpreds_ext
292 |
293 | #Combine the extrinsic and intrinsic advantages.
294 | self.I.buf_advs = self.int_coeff*self.I.buf_advs_int + self.ext_coeff*self.I.buf_advs_ext
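    | # Two value heads but a single policy: the intrinsic and extrinsic advantage streams are mixed
    | # with fixed coefficients before the PPO update.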
295 |
296 | #Collects info for reporting.
297 | info = dict(
298 | advmean = self.I.buf_advs.mean(),
299 | advstd = self.I.buf_advs.std(),
300 | retintmean = rets_int.mean(), # previously retmean
301 | retintstd = rets_int.std(), # previously retstd
302 | retextmean = rets_ext.mean(), # previously not there
303 | retextstd = rets_ext.std(), # previously not there
304 | rewintmean_unnorm = rewmean, # previously rewmean
305 | rewintmax_unnorm = rewmax, # previously not there
306 | rewintmean_norm = self.mean_int_rew, # previously rewintmean
307 | rewintmax_norm = self.max_int_rew, # previously rewintmax
308 | rewintstd_unnorm = rewstd, # previously rewstd
309 | vpredintmean = self.I.buf_vpreds_int.mean(), # previously vpredmean
310 | vpredintstd = self.I.buf_vpreds_int.std(), # previously vpredstd
311 | vpredextmean = self.I.buf_vpreds_ext.mean(), # previously not there
312 | vpredextstd = self.I.buf_vpreds_ext.std(), # previously not there
313 | ev_int = np.clip(explained_variance(self.I.buf_vpreds_int.ravel(), rets_int.ravel()), -1, None),
314 | ev_ext = np.clip(explained_variance(self.I.buf_vpreds_ext.ravel(), rets_ext.ravel()), -1, None),
315 | rooms = SemicolonList(self.rooms),
316 | n_rooms = len(self.rooms),
317 | best_ret = self.best_ret,
318 | reset_counter = self.I.reset_counter
319 | )
320 |
321 | info['mem_available'] = psutil.virtual_memory().available
322 |
323 | to_record = {'acs': self.I.buf_acs,
324 | 'rews_int': self.I.buf_rews_int,
325 | 'rews_int_norm': rews_int,
326 | 'rews_ext': self.I.buf_rews_ext,
327 | 'vpred_int': self.I.buf_vpreds_int,
328 | 'vpred_ext': self.I.buf_vpreds_ext,
329 | 'adv_int': self.I.buf_advs_int,
330 | 'adv_ext': self.I.buf_advs_ext,
331 | 'ent': self.I.buf_ent,
332 | 'ret_int': rets_int,
333 | 'ret_ext': rets_ext,
334 | }
335 | if self.I.venvs[0].record_obs:
336 | to_record['obs'] = self.I.buf_obs[None]
337 | self.recorder.record(bufs=to_record,
338 | infos=self.I.buf_epinfos)
339 |
340 |
341 | #Create feeddict for optimization.
342 | envsperbatch = self.I.nenvs // self.nminibatches
343 | ph_buf = [
344 | (self.stochpol.ph_ac, self.I.buf_acs),
345 | (self.ph_ret_int, rets_int),
346 | (self.ph_ret_ext, rets_ext),
347 | (self.ph_oldnlp, self.I.buf_nlps),
348 | (self.ph_adv, self.I.buf_advs),
349 | ]
350 | if self.I.mem_state is not NO_STATES:
351 | ph_buf.extend([
352 | (self.stochpol.ph_istate, self.I.seg_init_mem_state),
353 | (self.stochpol.ph_new, self.I.buf_news),
354 | ])
355 |
356 | verbose = True
357 | if verbose and self.is_log_leader:
358 | samples = np.prod(self.I.buf_advs.shape)
359 | logger.info("buffer shape %s, samples_per_mpi=%i, mini_per_mpi=%i, samples=%i, mini=%i " % (
360 | str(self.I.buf_advs.shape),
361 | samples, samples//self.nminibatches,
362 | samples*self.comm_train_size, samples*self.comm_train_size//self.nminibatches))
363 | logger.info(" "*6 + fmt_row(13, self.loss_names))
364 |
365 |
366 | epoch = 0
367 | start = 0
368 | #Optimizes on current data for several epochs.
369 | while epoch < self.nepochs:
370 | end = start + envsperbatch
371 | mbenvinds = slice(start, end, None)
372 |
373 | fd = {ph : buf[mbenvinds] for (ph, buf) in ph_buf}
374 | fd.update({self.ph_lr : self.lr, self.ph_cliprange : self.cliprange})
375 | fd[self.stochpol.ph_ob[None]] = np.concatenate([self.I.buf_obs[None][mbenvinds], self.I.buf_ob_last[None][mbenvinds, None]], 1)
376 | assert list(fd[self.stochpol.ph_ob[None]].shape) == [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape), \
377 | [fd[self.stochpol.ph_ob[None]].shape, [self.I.nenvs//self.nminibatches, self.nsteps + 1] + list(self.ob_space.shape)]
378 | fd.update({self.stochpol.ph_mean:self.stochpol.ob_rms.mean, self.stochpol.ph_std:self.stochpol.ob_rms.var**0.5})
379 |
380 | ret = tf.get_default_session().run(self._losses+[self._train], feed_dict=fd)[:-1]
381 | if not self.testing:
382 | lossdict = dict(zip(self.loss_names, ret))
383 | else:
384 | lossdict = {}
385 | #Synchronize the lossdict across mpi processes, otherwise weights may be rolled back on one process but not another.
386 | _maxkl = lossdict.pop('maxkl')
387 | lossdict = dict_gather(self.comm_train, lossdict, op='mean')
388 | maxmaxkl = dict_gather(self.comm_train, {"maxkl":_maxkl}, op='max')
389 | lossdict["maxkl"] = maxmaxkl["maxkl"]
390 | if verbose and self.is_log_leader:
391 | logger.info("%i:%03i %s" % (epoch, start, fmt_row(13, [lossdict[n] for n in self.loss_names])))
392 | start += envsperbatch
393 | if start == self.I.nenvs:
394 | epoch += 1
395 | start = 0
396 |
397 | if self.is_train_leader:
398 | self.I.stats["n_updates"] += 1
399 | info.update([('opt_'+n, lossdict[n]) for n in self.loss_names])
400 | tnow = time.time()
401 | info['tps'] = self.nsteps * self.I.nenvs / (tnow - self.I.t_last_update)
402 | info['time_elapsed'] = time.time() - self.t0
403 | self.I.t_last_update = tnow
404 | self.stochpol.update_normalization( # Necessary for continuous control tasks with odd obs ranges; only implemented in the MLP policy.
405 | ob=self.I.buf_obs # NOTE: not shared via MPI
406 | )
407 | return info
408 |
409 | def env_step(self, l, acs):
410 | self.I.venvs[l].step_async(acs)
411 | self.I.env_results[l] = None
412 |
413 | def env_get(self, l):
414 | """
415 | Get the most recent (obs, rews, dones, infos) from the vectorized environment,
416 | calling step_wait if necessary.
417 | """
418 | if self.I.step_count == 0: # On the zeroth step with a new venv, we need to call reset on the environment
419 | ob = self.I.venvs[l].reset()
420 | out = self.I.env_results[l] = (ob, None, np.ones(self.I.lump_stride, bool), {})
421 | else:
422 | if self.I.env_results[l] is None:
423 | out = self.I.env_results[l] = self.I.venvs[l].step_wait()
424 | else:
425 | out = self.I.env_results[l]
426 | return out
427 |
428 | @logger.profile("step")
429 | def step(self):
430 | #Does a rollout.
431 | t = self.I.step_count % self.nsteps
432 | epinfos = []
433 | for l in range(self.I.nlump):
434 | obs, prevrews, news, infos = self.env_get(l)
435 | for env_pos_in_lump, info in enumerate(infos):
436 | if 'episode' in info:
437 | #Information like rooms visited is added to info on end of episode.
438 | epinfos.append(info['episode'])
439 | info_with_places = info['episode']
440 | try:
441 | info_with_places['places'] = info['episode']['visited_rooms']
442 | except:
443 | import ipdb; ipdb.set_trace()
444 | self.I.buf_epinfos[env_pos_in_lump+l*self.I.lump_stride][t] = info_with_places
445 |
446 | sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
447 | memsli = slice(None) if self.I.mem_state is NO_STATES else sli
448 | dict_obs = self.stochpol.ensure_observation_is_dict(obs)
449 | with logger.ProfileKV("policy_inference"):
450 | #Calls the policy and value function on current observation.
451 | acs, vpreds_int, vpreds_ext, nlps, self.I.mem_state[memsli], ent = self.stochpol.call(dict_obs, news, self.I.mem_state[memsli],
452 | update_obs_stats=self.update_ob_stats_every_step)
453 | self.env_step(l, acs)
454 |
455 | #Update buffer with transition.
456 | for k in self.stochpol.ph_ob_keys:
457 | self.I.buf_obs[k][sli, t] = dict_obs[k]
458 | self.I.buf_news[sli, t] = news
459 | self.I.buf_vpreds_int[sli, t] = vpreds_int
460 | self.I.buf_vpreds_ext[sli, t] = vpreds_ext
461 | self.I.buf_nlps[sli, t] = nlps
462 | self.I.buf_acs[sli, t] = acs
463 | self.I.buf_ent[sli, t] = ent
464 |
465 | if t > 0:
466 | self.I.buf_rews_ext[sli, t-1] = prevrews
467 |
468 | self.I.step_count += 1
469 | if t == self.nsteps - 1 and not self.disable_policy_update:
470 | #We need to take one extra step so every transition has a reward.
471 | for l in range(self.I.nlump):
472 | sli = slice(l * self.I.lump_stride, (l + 1) * self.I.lump_stride)
473 | memsli = slice(None) if self.I.mem_state is NO_STATES else sli
474 | nextobs, rews, nextnews, _ = self.env_get(l)
475 | dict_nextobs = self.stochpol.ensure_observation_is_dict(nextobs)
476 | for k in self.stochpol.ph_ob_keys:
477 | self.I.buf_ob_last[k][sli] = dict_nextobs[k]
478 | self.I.buf_new_last[sli] = nextnews
479 | with logger.ProfileKV("policy_inference"):
480 | _, self.I.buf_vpred_int_last[sli], self.I.buf_vpred_ext_last[sli], _, _, _ = self.stochpol.call(dict_nextobs, nextnews, self.I.mem_state[memsli], update_obs_stats=False)
481 | self.I.buf_rews_ext[sli, t] = rews
482 |
483 | #Calculate the intrinsic rewards for the rollout.
484 | fd = {}
485 | fd[self.stochpol.ph_ob[None]] = np.concatenate([self.I.buf_obs[None], self.I.buf_ob_last[None][:,None]], 1)
486 | fd.update({self.stochpol.ph_mean: self.stochpol.ob_rms.mean,
487 | self.stochpol.ph_std: self.stochpol.ob_rms.var ** 0.5})
488 | fd[self.stochpol.ph_ac] = self.I.buf_acs
489 | self.I.buf_rews_int[:] = tf.get_default_session().run(self.stochpol.int_rew, fd)
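    | # One pass over the whole rollout: the feed holds nsteps+1 observations per env (rollout obs plus
    | # the extra last obs), so int_rew yields exactly one intrinsic reward per transition.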
490 |
491 | if not self.update_ob_stats_every_step:
492 | #Update observation normalization parameters after the rollout is completed.
493 | obs_ = self.I.buf_obs[None].astype(np.float32)
494 | self.stochpol.ob_rms.update(obs_.reshape((-1, *obs_.shape[2:]))[:,:,:,-1:])
495 | if not self.testing:
496 | update_info = self.update()
497 | else:
498 | update_info = {}
499 | self.I.seg_init_mem_state = copy(self.I.mem_state)
500 | global_i_stats = dict_gather(self.comm_log, self.I.stats, op='sum')
501 | global_deque_mean = dict_gather(self.comm_log, { n : np.mean(dvs) for n,dvs in self.I.statlists.items() }, op='mean')
502 | update_info.update(global_i_stats)
503 | update_info.update(global_deque_mean)
504 | self.global_tcount = global_i_stats['tcount']
505 | for infos_ in self.I.buf_epinfos:
506 | infos_.clear()
507 | else:
508 | update_info = {}
509 |
510 | #Some reporting logic.
511 | for epinfo in epinfos:
512 | if self.testing:
513 | self.I.statlists['eprew_test'].append(epinfo['r'])
514 | self.I.statlists['eplen_test'].append(epinfo['l'])
515 | else:
516 | if "visited_rooms" in epinfo:
517 | self.local_rooms += list(epinfo["visited_rooms"])
518 | self.local_rooms = sorted(list(set(self.local_rooms)))
519 | score_multiple = self.I.venvs[0].score_multiple
520 | if score_multiple is None:
521 | score_multiple = 1000
522 | rounded_score = int(epinfo["r"] / score_multiple) * score_multiple
523 | self.scores.append(rounded_score)
524 | self.scores = sorted(list(set(self.scores)))
525 | self.I.statlists['eprooms'].append(len(epinfo["visited_rooms"]))
526 |
527 | self.I.statlists['eprew'].append(epinfo['r'])
528 | if self.local_best_ret is None:
529 | self.local_best_ret = epinfo["r"]
530 | elif epinfo["r"] > self.local_best_ret:
531 | self.local_best_ret = epinfo["r"]
532 |
533 | self.I.statlists['eplen'].append(epinfo['l'])
534 | self.I.stats['epcount'] += 1
535 | self.I.stats['tcount'] += epinfo['l']
536 | self.I.stats['rewtotal'] += epinfo['r']
537 | # self.I.stats["best_ext_ret"] = self.best_ret
538 |
539 |
540 | return {'update' : update_info}
541 |
542 |
543 | class RewardForwardFilter(object):
544 | def __init__(self, gamma):
545 | self.rewems = None
546 | self.gamma = gamma
547 | def update(self, rews):
548 | if self.rewems is None:
549 | self.rewems = rews
550 | else:
551 | self.rewems = self.rewems * self.gamma + rews
552 | return self.rewems
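    | # Keeps a discounted running sum of rewards across consecutive rollouts (rewems <- gamma*rewems + rews).
    | # It is only used to estimate the scale of intrinsic returns for normalization, never fed to the agent.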
553 |
554 | def flatten_lists(listoflists):
555 | return [el for list_ in listoflists for el in list_]
556 |
--------------------------------------------------------------------------------
/recorder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 |
7 | from baselines import logger
8 | from mpi4py import MPI
9 |
10 | def is_square(n):
11 | return n == (int(np.sqrt(n))) ** 2
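    | # Used below to throttle saving: an episode is recorded when the running count of its rounded score
    | # (or of its set of visited rooms) is a perfect square, so new outcomes are always kept while
    | # frequently seen ones are saved less and less often.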
12 |
13 | class Recorder(object):
14 | def __init__(self, nenvs, score_multiple=1):
15 | self.episodes = [defaultdict(list) for _ in range(nenvs)]
16 | self.total_episodes = 0
17 | self.filename = self.get_filename()
18 | self.score_multiple = score_multiple
19 |
20 | self.all_scores = {}
21 | self.all_places = {}
22 |
23 | def record(self, bufs, infos):
24 | for env_id, ep_infos in enumerate(infos):
25 | left_step = 0
26 | done_steps = sorted(ep_infos.keys())
27 | for right_step in done_steps:
28 | for key in bufs:
29 | self.episodes[env_id][key].append(bufs[key][env_id, left_step:right_step].copy())
30 | self.record_episode(env_id, ep_infos[right_step])
31 | left_step = right_step
32 | for key in bufs:
33 | self.episodes[env_id][key].clear()
34 | for key in bufs:
35 | self.episodes[env_id][key].append(bufs[key][env_id, left_step:].copy())
36 |
37 |
38 | def record_episode(self, env_id, info):
39 | self.total_episodes += 1
40 | if self.episode_worth_saving(env_id, info):
41 | episode = {}
42 | for key in self.episodes[env_id]:
43 | episode[key] = np.concatenate(self.episodes[env_id][key])
44 | info['env_id'] = env_id
45 | episode['info'] = info
46 | with open(self.filename, 'ab') as f:
47 | pickle.dump(episode, f, protocol=-1)
48 |
49 | def get_score(self, info):
50 | return int(info['r']/self.score_multiple) * self.score_multiple
51 |
52 | def episode_worth_saving(self, env_id, info):
53 | if self.score_multiple is None:
54 | return False
55 | r = self.get_score(info)
56 | if r not in self.all_scores:
57 | self.all_scores[r] = 0
58 | else:
59 | self.all_scores[r] += 1
60 | hashable_places = tuple(sorted(info['places']))
61 | if hashable_places not in self.all_places:
62 | self.all_places[hashable_places] = 0
63 | else:
64 | self.all_places[hashable_places] += 1
65 | if is_square(self.all_scores[r]) or is_square(self.all_places[hashable_places]):
66 | return True
67 | if 15 in info['places']:
68 | return True
69 | return False
70 |
71 | def get_filename(self):
72 | filename = os.path.join(logger.get_dir(), 'videos_{}.pk'.format(MPI.COMM_WORLD.Get_rank()))
73 | return filename
74 |
75 |
--------------------------------------------------------------------------------
/replayer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import glob
4 | import os
5 | import pickle
6 | import sys
7 |
8 | import exptag
9 | import ipdb
10 | import numpy as np
11 | from atari_wrappers import make_atari, wrap_deepmind
12 | from run_atari import add_env_params
13 |
14 | seen_scores = set()
15 |
16 |
17 | class EpisodeIterator(object):
18 | def __init__(self, filenames):
19 | if args['filter'] == 'none':
20 | cond = lambda info: True
21 | elif args['filter'] == 'rew':
22 | cond = lambda info: info['r'] < args['rew_max'] and (info['r'] > args['rew_min'])
23 | elif args['filter'] == 'room':
24 | def cond(info):
25 | return any(int(room) in args['room_number'] for room in info['places'])
26 |
27 | self.filenames = filenames
28 | self.condition = cond
29 | self.episode_number = 0
30 |
31 | def iterate(self):
32 | for filename in self.filenames:
33 | print("Opening file", filename)
34 | with open(filename, 'rb') as f:
35 | yield from self.iterate_over_episodes_in_file(f, condition=self.condition)
36 | return  # ending the generator; raising StopIteration here is an error under PEP 479
37 |
38 | def iterate_over_episodes_in_file(self, file, condition):
39 | while True:
40 | try:
41 | episode = pickle.load(file)
42 | except:
43 | return  # end of file; ending the generator instead of raising StopIteration (PEP 479)
44 |
45 | info = episode['info']
46 | if condition(info):
47 | print(self.episode_number)
48 | self.episode_number += 1
49 | if self.episode_number >= args['skip']:
50 | if 'obs' in episode:
51 | # import ipdb; ipdb.set_trace()
52 | yield episode
53 | else:
54 | unwrapped_env = env.unwrapped
55 | if 'rng_at_episode_start' in info:
56 | random_state = info['rng_at_episode_start']
57 | unwrapped_env.np_random.set_state(random_state.get_state())
58 | if hasattr(unwrapped_env, "scene"):
59 | unwrapped_env.scene.np_random.set_state(random_state.get_state())
60 | ob = env.reset()
61 | ret = 0
62 | frames = []
63 | infos = []
64 | for i, a in enumerate(episode['acs']):
65 | ob, r, d, info = env.step(a)
66 | if args['display'] == 'game':
67 | rend = unwrapped_env.render(mode="rgb_array")
68 | else:
69 | rend = np.asarray(ob)[:, :, :1]
70 | frames.append(rend)
71 | ret += r
72 | infos.append(info)
73 | assert not d or i == len(episode['acs']) - 1, ipdb.set_trace()
74 | assert d, ipdb.set_trace()
75 | assert ret == episode['info']['r'], (ret, episode['info']['r'])
76 | episode['obs'] = frames
77 | episode['infos'] = infos
78 | print(episode.keys())
79 | yield episode
80 |
81 |
82 | class Animation(object):
83 | def __init__(self, episodes):
84 | self.episodes = episodes
85 |
86 | self.pause = False
87 | self.delta = 1
88 | self.j = 0
89 |
90 | self.fig = self.create_empty_figure()
91 |
92 | self.fig.canvas.mpl_connect('key_press_event', self.onKeyPress)
93 |
94 | self.axes = {}
95 | self.lines = {}
96 | self.dots = {}
97 | # self.ax1 = self.fig.add_subplot(1, 2, 1)
98 | # self.ax2 = self.fig.add_subplot(1, 2, 2)
99 |
100 | def create_empty_figure(self):
101 | fig = plt.figure()
102 | for evt, callback in fig.canvas.callbacks.callbacks.items():
103 | items = list(callback.items())
104 | for cid, _ in items:
105 | fig.canvas.mpl_disconnect(cid)
106 | return fig
107 |
108 | def onKeyPress(self, event):
109 | if event.key == 'left':
110 | self.pause = True
111 | if self.j > 0:
112 | self.j -= 1
113 | elif event.key == 'right':
114 | self.pause = True
115 | if self.j < len(self.episode['obs']) - 1:
116 | self.j += 1
117 | elif event.key == 'n':
118 | self.pause = False
119 | self.j = len(self.episode['obs']) - 1
120 | elif event.key == ' ':
121 | self.pause = not self.pause
122 | elif event.key == 'q':
123 | sys.exit()
124 | elif event.key == 'f':
125 | self.delta = 1 if self.delta > 1 else 8
126 | elif event.key == 'b':
127 | self.j = max(self.j-100, 0)
128 |
129 | def create_axes(self, episode):
130 | assert self.axes == {}
131 | keys = [key for key in episode.keys() if key not in ['acs', 'infos', 'obs', 'info']]
132 | keys.insert(0, 'obs')
133 | n_rows = int(np.floor(np.sqrt(len(keys))))
134 | n_cols = int(np.ceil(len(keys) / n_rows))
135 | for i, key in enumerate(keys, start=1):
136 | self.axes[key] = self.fig.add_subplot(n_rows, n_cols, i)
137 |
138 | def process_frame(self, frame):
139 | if frame.shape[-1] == 3:
140 | return frame
141 | else:
142 | return frame[:, :, -1]
143 |
144 | def run(self):
145 | self.episode = next(self.episodes)
146 |
147 | if self.axes == {}:
148 | self.create_axes(self.episode)
149 |
150 | self.im = self.axes['obs'].imshow(self.process_frame(self.episode['obs'][0]), cmap='gray')
151 | for key in self.axes:
152 | if key != 'obs':
153 | line, = self.axes[key].plot(self.episode[key], alpha=0.5)
154 | dot = matplotlib.patches.Ellipse(xy=(0, 0), width=1, height=0.0001, color='r')
155 | self.axes[key].add_artist(dot)
156 | self.axes[key].set_title(key)
157 | self.lines[key] = line
158 | self.dots[key] = dot
159 |
160 |
161 | def draw_frame_i(i):
162 | # update the data
163 | if self.j == 0:
164 | for key in self.axes:
165 | if key != 'obs':
166 | data = self.episode[key]
167 | n_timesteps = len(data)
168 | self.lines[key].set_data(range(n_timesteps), data)
169 | self.axes[key].set_xlim(0, n_timesteps)
170 | min_y, max_y = np.min(data), np.max(data)
171 | self.axes[key].set_ylim(min_y, max_y)
172 |
173 | self.dots[key].height = (max_y - min_y) / 30.
174 | self.dots[key].width = n_timesteps / 30.
175 | self.im.set_data(self.process_frame(self.episode['obs'][self.j]))
176 | for key in self.axes:
177 | if key != 'obs':
178 | self.dots[key].center = (self.j, self.episode[key][self.j])
179 | if not self.pause:
180 | self.j += self.delta
181 | if self.j > len(self.episode['obs']) - 1:
182 | self.episode = next(self.episodes)
183 | self.j = 0
184 | return [self.im] + list(self.lines.values()) + list(self.dots.values())
185 |
186 | ani = animation.FuncAnimation(self.fig, draw_frame_i, blit=False, interval=1,
187 | repeat=False)
188 | plt.show()
189 | plt.close()
190 |
191 |
192 | if __name__ == '__main__':
193 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
194 | add_env_params(parser)
195 | parser.add_argument('--filter', type=str, default='none')
196 | parser.add_argument('--rew_min', type=int, default=0)
197 | parser.add_argument('--rew_max', type=int, default=np.inf)
198 | parser.add_argument('--tag', type=str, default=None)
199 | parser.add_argument('--kind', type=str, default='plot')
200 | parser.add_argument('--display', type=str, default='game', choices=['game', 'agent'])
201 | parser.add_argument('--skip', type=int, default=0)
202 | parser.add_argument('--room_number', type=lambda x: [int(_) for _ in x.split(',')], default=[15])
203 |
204 |
205 |
206 | args = parser.parse_args().__dict__
207 | folder = exptag.get_last_experiment_folder_by_tag(args['tag'])
208 |
209 | def date_from_folder(folder):
210 | assert folder.startswith('openai-')
211 | date_started = folder[len('openai-'):]
212 | return datetime.datetime.strptime(date_started, "%Y-%m-%d-%H-%M-%S-%f")
213 |
214 | date_started = date_from_folder(os.path.basename(folder))
215 | machine_dir = os.path.dirname(folder)
216 | if machine_dir[-4:-1]=='-00':
217 | all_machine_dirs = glob.glob(machine_dir[:-1]+'*')
218 | else:
219 | all_machine_dirs = [machine_dir]
220 | other_folders = []
221 | for machine_dir in all_machine_dirs:
222 | this_machine_other_folders = os.listdir(machine_dir)
223 | this_machine_other_folders = [f_ for f_ in this_machine_other_folders
224 | if f_.startswith("openai-") and abs((date_from_folder(f_) - date_started).total_seconds()) < 3]
225 | this_machine_other_folders = [os.path.join(machine_dir, f_) for f_ in this_machine_other_folders]
226 | other_folders.extend(this_machine_other_folders)
227 |
228 | filenames = [glob.glob(os.path.join(f_, "videos_*.pk")) for f_ in other_folders]
229 | assert all(len(files_) == 1 for files_ in filenames), filenames
230 | filenames = [files_[0] for files_ in filenames]
231 |
232 | env = make_atari(args['env'], max_episode_steps=args['max_episode_steps'])
233 | if args['display'] == 'agent':
234 | env = wrap_deepmind(env, frame_stack=4, clip_rewards=False)
235 | env.reset()
236 | un_env = env.unwrapped
237 | rend_shape = un_env.render(mode='rgb_array').shape
238 | episodes = EpisodeIterator(filenames).iterate()
239 | if args['kind'] == 'movie':
240 | import imageio
241 | import time
242 | for i, episode in enumerate(episodes):
243 | filename = os.path.expanduser('~/tmp/movie_{}.mp4'.format(time.time()))
244 | imageio.mimwrite(filename, episode["obs"], fps=30)
245 | print(filename)
246 |
247 | else:
248 | import matplotlib.patches
249 | import matplotlib.pyplot as plt
250 | import matplotlib.animation as animation
251 | print('left/right, space, n, q, f keys are special')
252 | Animation(episodes).run()
253 |
--------------------------------------------------------------------------------
/run_atari.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import functools
3 | import os
4 |
5 | from baselines import logger
6 | from mpi4py import MPI
7 | import mpi_util
8 | import tf_util
9 | from cmd_util import make_atari_env, arg_parser
10 | from policies.cnn_gru_policy_dynamics import CnnGruPolicy
11 | from policies.cnn_policy_param_matched import CnnPolicy
12 | from ppo_agent import PpoAgent
13 | from utils import set_global_seeds
14 | from vec_env import VecFrameStack
15 |
16 |
17 | def train(*, env_id, num_env, hps, num_timesteps, seed):
18 | venv = VecFrameStack(
19 | make_atari_env(env_id, num_env, seed, wrapper_kwargs=dict(),
20 | start_index=num_env * MPI.COMM_WORLD.Get_rank(),
21 | max_episode_steps=hps.pop('max_episode_steps')),
22 | hps.pop('frame_stack'))
23 | # venv.score_multiple = {'Mario': 500,
24 | # 'MontezumaRevengeNoFrameskip-v4': 100,
25 | # 'GravitarNoFrameskip-v4': 250,
26 | # 'PrivateEyeNoFrameskip-v4': 500,
27 | # 'SolarisNoFrameskip-v4': None,
28 | # 'VentureNoFrameskip-v4': 200,
29 | # 'PitfallNoFrameskip-v4': 100,
30 | # }[env_id]
31 | venv.score_multiple = 1
32 | venv.record_obs = (env_id == 'SolarisNoFrameskip-v4')
33 | ob_space = venv.observation_space
34 | ac_space = venv.action_space
35 | gamma = hps.pop('gamma')
36 | policy = {'rnn': CnnGruPolicy,
37 | 'cnn': CnnPolicy}[hps.pop('policy')]
38 | agent = PpoAgent(
39 | scope='ppo',
40 | ob_space=ob_space,
41 | ac_space=ac_space,
42 | stochpol_fn=functools.partial(
43 | policy,
44 | scope='pol',
45 | ob_space=ob_space,
46 | ac_space=ac_space,
47 | update_ob_stats_independently_per_gpu=hps.pop('update_ob_stats_independently_per_gpu'),
48 | proportion_of_exp_used_for_predictor_update=hps.pop('proportion_of_exp_used_for_predictor_update'),
49 | dynamics_bonus = hps.pop("dynamics_bonus")
50 | ),
51 | gamma=gamma,
52 | gamma_ext=hps.pop('gamma_ext'),
53 | lam=hps.pop('lam'),
54 | nepochs=hps.pop('nepochs'),
55 | nminibatches=hps.pop('nminibatches'),
56 | lr=hps.pop('lr'),
57 | cliprange=0.1,
58 | nsteps=128,
59 | ent_coef=0.001,
60 | max_grad_norm=hps.pop('max_grad_norm'),
61 | use_news=hps.pop("use_news"),
62 | comm=MPI.COMM_WORLD if MPI.COMM_WORLD.Get_size() > 1 else None,
63 | update_ob_stats_every_step=hps.pop('update_ob_stats_every_step'),
64 | int_coeff=hps.pop('int_coeff'),
65 | ext_coeff=hps.pop('ext_coeff'),
66 | )
67 | agent.start_interaction([venv])
68 | if hps.pop('update_ob_stats_from_random_agent'):
69 | agent.collect_random_statistics(num_timesteps=128*50)
70 | assert len(hps) == 0, "Unused hyperparameters: %s" % list(hps.keys())
71 |
72 | counter = 0
73 | while True:
74 | info = agent.step()
75 | if info['update']:
76 | logger.logkvs(info['update'])
77 | logger.dumpkvs()
78 | counter += 1
79 | if agent.I.stats['tcount'] > num_timesteps:
80 | break
81 |
82 | agent.stop_interaction()
83 |
84 |
85 | def add_env_params(parser):
86 | parser.add_argument('--env', help='environment ID', default='MontezumaRevengeNoFrameskip-v4')
87 | parser.add_argument('--seed', help='RNG seed', type=int, default=0)
88 | parser.add_argument('--max_episode_steps', type=int, default=4500)
89 |
90 |
91 | def main():
92 | parser = arg_parser()
93 | add_env_params(parser)
94 | parser.add_argument('--num-timesteps', type=int, default=int(1e12))
95 | parser.add_argument('--num_env', type=int, default=32)
96 | parser.add_argument('--use_news', type=int, default=0)
97 | parser.add_argument('--gamma', type=float, default=0.99)
98 | parser.add_argument('--gamma_ext', type=float, default=0.99)
99 | parser.add_argument('--lam', type=float, default=0.95)
100 | parser.add_argument('--update_ob_stats_every_step', type=int, default=0)
101 | parser.add_argument('--update_ob_stats_independently_per_gpu', type=int, default=0)
102 | parser.add_argument('--update_ob_stats_from_random_agent', type=int, default=1)
103 | parser.add_argument('--proportion_of_exp_used_for_predictor_update', type=float, default=1.)
104 | parser.add_argument('--tag', type=str, default='')
105 | parser.add_argument('--policy', type=str, default='rnn', choices=['cnn', 'rnn'])
106 | parser.add_argument('--int_coeff', type=float, default=1.)
107 | parser.add_argument('--ext_coeff', type=float, default=2.)
108 | parser.add_argument('--dynamics_bonus', type=int, default=0)
109 |
110 |
111 | args = parser.parse_args()
112 | logger.configure(dir=logger.get_dir(), format_strs=['stdout', 'log', 'csv'] if MPI.COMM_WORLD.Get_rank() == 0 else [])
113 | if MPI.COMM_WORLD.Get_rank() == 0:
114 | with open(os.path.join(logger.get_dir(), 'experiment_tag.txt'), 'w') as f:
115 | f.write(args.tag)
116 | # shutil.copytree(os.path.dirname(os.path.abspath(__file__)), os.path.join(logger.get_dir(), 'code'))
117 |
118 | mpi_util.setup_mpi_gpus()
119 |
120 | seed = 10000 * args.seed + MPI.COMM_WORLD.Get_rank()
121 | set_global_seeds(seed)
122 |
123 | hps = dict(
124 | frame_stack=4,
125 | nminibatches=4,
126 | nepochs=4,
127 | lr=0.0001,
128 | max_grad_norm=0.0,
129 | use_news=args.use_news,
130 | gamma=args.gamma,
131 | gamma_ext=args.gamma_ext,
132 | max_episode_steps=args.max_episode_steps,
133 | lam=args.lam,
134 | update_ob_stats_every_step=args.update_ob_stats_every_step,
135 | update_ob_stats_independently_per_gpu=args.update_ob_stats_independently_per_gpu,
136 | update_ob_stats_from_random_agent=args.update_ob_stats_from_random_agent,
137 | proportion_of_exp_used_for_predictor_update=args.proportion_of_exp_used_for_predictor_update,
138 | policy=args.policy,
139 | int_coeff=args.int_coeff,
140 | ext_coeff=args.ext_coeff,
141 | dynamics_bonus = args.dynamics_bonus
142 | )
143 |
144 | tf_util.make_session(make_default=True)
145 | train(env_id=args.env, num_env=args.num_env, seed=seed,
146 | num_timesteps=args.num_timesteps, hps=hps)
147 |
148 |
149 | if __name__ == '__main__':
150 | main()
151 |
--------------------------------------------------------------------------------
/stochastic_policy.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from baselines.common.distributions import make_pdtype
3 | from collections import OrderedDict
4 | from gym import spaces
5 |
6 | def canonical_dtype(orig_dt):
7 | if orig_dt.kind == 'f':
8 | return tf.float32
9 | elif orig_dt.kind in 'iu':
10 | return tf.int32
11 | else:
12 | raise NotImplementedError
13 |
14 | class StochasticPolicy(object):
15 | def __init__(self, scope, ob_space, ac_space):
16 | self.abs_scope = (tf.get_variable_scope().name + '/' + scope).lstrip('/')
17 | self.ob_space = ob_space
18 | self.ac_space = ac_space
19 | self.pdtype = make_pdtype(ac_space)
20 | self.ph_new = tf.placeholder(dtype=tf.float32, shape=(None, None), name='new')
21 | self.ph_ob_keys = []
22 | self.ph_ob_dtypes = {}
23 | shapes = {}
24 | if isinstance(ob_space, spaces.Dict):
25 | assert isinstance(ob_space.spaces, OrderedDict)
26 | for key, box in ob_space.spaces.items():
27 | assert isinstance(box, spaces.Box)
28 | self.ph_ob_keys.append(key)
29 | # Keys must be ordered, because tf.concat(ph) depends on order. Here we don't keep the OrderedDict
30 | # order and sort the keys instead; the rationale is to give freedom to modify the environment.
31 | self.ph_ob_keys.sort()
32 | for k in self.ph_ob_keys:
33 | self.ph_ob_dtypes[k] = ob_space.spaces[k].dtype
34 | shapes[k] = ob_space.spaces[k].shape
35 | else:
36 | print(ob_space)
37 | box = ob_space
38 | assert isinstance(box, spaces.Box)
39 | self.ph_ob_keys = [None]
40 | self.ph_ob_dtypes = { None: box.dtype }
41 | shapes = { None: box.shape }
42 | self.ph_ob = OrderedDict([(k, tf.placeholder(
43 | canonical_dtype(self.ph_ob_dtypes[k]),
44 | (None, None,) + tuple(shapes[k]),
45 | name=(('obs/%s'%k) if k is not None else 'obs')
46 | )) for k in self.ph_ob_keys ])
47 | assert list(self.ph_ob.keys())==self.ph_ob_keys, "\n%s\n%s\n" % (list(self.ph_ob.keys()), self.ph_ob_keys)
48 | ob_shape = tf.shape(next(iter(self.ph_ob.values())))
49 | self.sy_nenvs = ob_shape[0]
50 | self.sy_nsteps = ob_shape[1]
51 | self.ph_ac = self.pdtype.sample_placeholder([None, None], name='ac')
52 | self.pd = self.vpred = self.ph_istate = None
53 |
54 | def finalize(self, pd, vpred, ph_istate=None): #pylint: disable=W0221
55 | self.pd = pd
56 | self.vpred = vpred
57 | self.ph_istate = ph_istate
58 |
59 | def ensure_observation_is_dict(self, ob):
60 | if self.ph_ob_keys==[None]:
61 | return { None: ob }
62 | else:
63 | return ob
64 |
65 | def call(self, ob, new, istate):
66 | """
67 | Return acs, vpred, neglogprob, nextstate
68 | """
69 | raise NotImplementedError
70 |
71 | def initial_state(self, n):
72 | raise NotImplementedError
73 |
74 | def update_normalization(self, ob):
75 | pass
76 |
--------------------------------------------------------------------------------
/tf_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf # pylint: ignore-module
3 | import copy
4 | import os
5 | import functools
6 | import collections
7 | import multiprocessing
8 |
9 | def switch(condition, then_expression, else_expression):
10 | """Switches between two operations depending on a scalar value (int or bool).
11 | Note that both `then_expression` and `else_expression`
12 | should be symbolic tensors of the *same shape*.
13 |
14 | # Arguments
15 | condition: scalar tensor.
16 | then_expression: TensorFlow operation.
17 | else_expression: TensorFlow operation.
18 | """
19 | x_shape = copy.copy(then_expression.get_shape())
20 | x = tf.cond(tf.cast(condition, 'bool'),
21 | lambda: then_expression,
22 | lambda: else_expression)
23 | x.set_shape(x_shape)
24 | return x
25 |
26 | # ================================================================
27 | # Extras
28 | # ================================================================
29 |
30 | def lrelu(x, leak=0.2):
31 | f1 = 0.5 * (1 + leak)
32 | f2 = 0.5 * (1 - leak)
33 | return f1 * x + f2 * abs(x)
34 |
35 | # ================================================================
36 | # Mathematical utils
37 | # ================================================================
38 |
39 | def huber_loss(x, delta=1.0):
40 | """Reference: https://en.wikipedia.org/wiki/Huber_loss"""
41 | return tf.where(
42 | tf.abs(x) < delta,
43 | tf.square(x) * 0.5,
44 | delta * (tf.abs(x) - 0.5 * delta)
45 | )
46 |
47 | # ================================================================
48 | # Global session
49 | # ================================================================
50 |
51 | def make_session(num_cpu=None, make_default=False, graph=None):
52 | """Returns a session that will use CPU's only"""
53 | if num_cpu is None:
54 | num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count()))
55 | tf_config = tf.ConfigProto(
56 | inter_op_parallelism_threads=num_cpu,
57 | intra_op_parallelism_threads=num_cpu)
58 | if make_default:
59 | return tf.InteractiveSession(config=tf_config, graph=graph)
60 | else:
61 | return tf.Session(config=tf_config, graph=graph)
62 |
63 | def single_threaded_session():
64 | """Returns a session which will only use a single CPU"""
65 | return make_session(num_cpu=1)
66 |
67 | def in_session(f):
68 | @functools.wraps(f)
69 | def newfunc(*args, **kwargs):
70 | with tf.Session():
71 | f(*args, **kwargs)
72 | return newfunc
73 |
74 | ALREADY_INITIALIZED = set()
75 |
76 | def initialize():
77 | """Initialize all the uninitialized variables in the global scope."""
78 | new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
79 | tf.get_default_session().run(tf.variables_initializer(new_variables))
80 | ALREADY_INITIALIZED.update(new_variables)
81 |
82 | # ================================================================
83 | # Model components
84 | # ================================================================
85 |
86 | def normc_initializer(std=1.0, axis=0):
87 | def _initializer(shape, dtype=None, partition_info=None): # pylint: disable=W0613
88 | out = np.random.randn(*shape).astype(np.float32)
89 | out *= std / np.sqrt(np.square(out).sum(axis=axis, keepdims=True))
90 | return tf.constant(out)
91 | return _initializer
92 |
93 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None,
94 | summary_tag=None):
95 | with tf.variable_scope(name):
96 | stride_shape = [1, stride[0], stride[1], 1]
97 | filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]
98 |
99 | # there are "num input feature maps * filter height * filter width"
100 | # inputs to each hidden unit
101 | fan_in = intprod(filter_shape[:3])
102 | # each unit in the lower layer receives a gradient from:
103 | # "num output feature maps * filter height * filter width" /
104 | # pooling size
105 | fan_out = intprod(filter_shape[:2]) * num_filters
106 | # initialize weights with random weights
107 | w_bound = np.sqrt(6. / (fan_in + fan_out))
108 |
109 | w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound),
110 | collections=collections)
111 | b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.zeros_initializer(),
112 | collections=collections)
113 |
114 | if summary_tag is not None:
115 | tf.summary.image(summary_tag,
116 | tf.transpose(tf.reshape(w, [filter_size[0], filter_size[1], -1, 1]),
117 | [2, 0, 1, 3]),
118 | max_outputs=10)
119 |
120 | return tf.nn.conv2d(x, w, stride_shape, pad) + b
121 |
122 | # ================================================================
123 | # Theano-like Function
124 | # ================================================================
125 |
126 | def function(inputs, outputs, updates=None, givens=None):
127 | """Just like Theano function. Take a bunch of tensorflow placeholders and expressions
128 | computed based on those placeholders and produces f(inputs) -> outputs. Function f takes
129 | values to be fed to the input's placeholders and produces the values of the expressions
130 | in outputs.
131 |
132 | Input values can be passed in the same order as inputs or can be provided as kwargs based
133 | on placeholder name (passed to constructor or accessible via placeholder.op.name).
134 |
135 | Example:
136 | x = tf.placeholder(tf.int32, (), name="x")
137 | y = tf.placeholder(tf.int32, (), name="y")
138 | z = 3 * x + 2 * y
139 | lin = function([x, y], z, givens={y: 0})
140 |
141 | with single_threaded_session():
142 | initialize()
143 |
144 | assert lin(2) == 6
145 | assert lin(x=3) == 9
146 | assert lin(2, 2) == 10
147 | assert lin(x=2, y=3) == 12
148 |
149 | Parameters
150 | ----------
151 | inputs: [tf.placeholder, tf.constant, or object with make_feed_dict method]
152 | list of input arguments
153 | outputs: [tf.Variable] or tf.Variable
154 | list of outputs or a single output to be returned from function. Returned
155 | value will also have the same shape.
156 | """
157 | if isinstance(outputs, list):
158 | return _Function(inputs, outputs, updates, givens=givens)
159 | elif isinstance(outputs, (dict, collections.OrderedDict)):
160 | f = _Function(inputs, outputs.values(), updates, givens=givens)
161 | return lambda *args, **kwargs: type(outputs)(zip(outputs.keys(), f(*args, **kwargs)))
162 | else:
163 | f = _Function(inputs, [outputs], updates, givens=givens)
164 | return lambda *args, **kwargs: f(*args, **kwargs)[0]
165 |
166 |
167 | class _Function(object):
168 | def __init__(self, inputs, outputs, updates, givens):
169 | for inpt in inputs:
170 | if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0):
171 | assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method"
172 | self.inputs = inputs
173 | updates = updates or []
174 | self.update_group = tf.group(*updates)
175 | self.outputs_update = list(outputs) + [self.update_group]
176 | self.givens = {} if givens is None else givens
177 |
178 | def _feed_input(self, feed_dict, inpt, value):
179 | if hasattr(inpt, 'make_feed_dict'):
180 | feed_dict.update(inpt.make_feed_dict(value))
181 | else:
182 | feed_dict[inpt] = value
183 |
184 | def __call__(self, *args):
185 | assert len(args) <= len(self.inputs), "Too many arguments provided"
186 | feed_dict = {}
187 | # Update the args
188 | for inpt, value in zip(self.inputs, args):
189 | self._feed_input(feed_dict, inpt, value)
190 | # Update feed dict with givens.
191 | for inpt in self.givens:
192 | feed_dict[inpt] = feed_dict.get(inpt, self.givens[inpt])
193 | results = tf.get_default_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
194 | return results
195 |
196 | # ================================================================
197 | # Flat vectors
198 | # ================================================================
199 |
200 | def var_shape(x):
201 | out = x.get_shape().as_list()
202 | assert all(isinstance(a, int) for a in out), \
203 | "shape function assumes that shape is fully known"
204 | return out
205 |
206 | def numel(x):
207 | return intprod(var_shape(x))
208 |
209 | def intprod(x):
210 | return int(np.prod(x))
211 |
212 | def flatgrad(loss, var_list, clip_norm=None):
213 | grads = tf.gradients(loss, var_list)
214 | if clip_norm is not None:
215 | grads = [tf.clip_by_norm(grad, clip_norm=clip_norm) for grad in grads]
216 | return tf.concat(axis=0, values=[
217 | tf.reshape(grad if grad is not None else tf.zeros_like(v), [numel(v)])
218 | for (v, grad) in zip(var_list, grads)
219 | ])
220 |
221 | class SetFromFlat(object):
222 | def __init__(self, var_list, dtype=tf.float32):
223 | assigns = []
224 | shapes = list(map(var_shape, var_list))
225 | total_size = np.sum([intprod(shape) for shape in shapes])
226 |
227 | self.theta = theta = tf.placeholder(dtype, [total_size])
228 | start = 0
229 | assigns = []
230 | for (shape, v) in zip(shapes, var_list):
231 | size = intprod(shape)
232 | assigns.append(tf.assign(v, tf.reshape(theta[start:start + size], shape)))
233 | start += size
234 | self.op = tf.group(*assigns)
235 |
236 | def __call__(self, theta):
237 | tf.get_default_session().run(self.op, feed_dict={self.theta: theta})
238 |
239 | class GetFlat(object):
240 | def __init__(self, var_list):
241 | self.op = tf.concat(axis=0, values=[tf.reshape(v, [numel(v)]) for v in var_list])
242 |
243 | def __call__(self):
244 | return tf.get_default_session().run(self.op)
245 |
246 | _PLACEHOLDER_CACHE = {} # name -> (placeholder, dtype, shape)
247 |
248 | def get_placeholder(name, dtype, shape):
249 | if name in _PLACEHOLDER_CACHE:
250 | out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
251 | assert dtype1 == dtype and shape1 == shape
252 | return out
253 | else:
254 | out = tf.placeholder(dtype=dtype, shape=shape, name=name)
255 | _PLACEHOLDER_CACHE[name] = (out, dtype, shape)
256 | return out
257 |
258 | def get_placeholder_cached(name):
259 | return _PLACEHOLDER_CACHE[name][0]
260 |
261 | def flattenallbut0(x):
262 | return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
263 |
264 |
265 | # ================================================================
266 | # Diagnostics
267 | # ================================================================
268 |
269 | def display_var_info(vars):
270 | from baselines import logger
271 | count_params = 0
272 | for v in vars:
273 | name = v.name
274 | if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue
275 | v_params = np.prod(v.shape.as_list())
276 | count_params += v_params
277 |         if "/b:" in name or "/biases" in name: continue    # Wx+b: bias params are counted in the total but not printed
278 | logger.info(" %s%s %i params %s" % (name, " "*(55-len(name)), v_params, str(v.shape)))
279 |
280 | logger.info("Total model parameters: %0.2f million" % (count_params*1e-6))
281 |
282 |
283 | def get_available_gpus():
284 | # recipe from here:
285 | # https://stackoverflow.com/questions/38559755/how-to-get-current-available-gpus-in-tensorflow?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
286 |
287 | from tensorflow.python.client import device_lib
288 | local_device_protos = device_lib.list_local_devices()
289 | return [x.name for x in local_device_protos if x.device_type == 'GPU']
290 |
291 | # ================================================================
292 | # Saving variables
293 | # ================================================================
294 |
295 | def load_state(fname):
296 | saver = tf.train.Saver()
297 | saver.restore(tf.get_default_session(), fname)
298 |
299 | def save_state(fname):
300 | os.makedirs(os.path.dirname(fname), exist_ok=True)
301 | saver = tf.train.Saver()
302 | saver.save(tf.get_default_session(), fname)
303 |
304 |
305 |
--------------------------------------------------------------------------------
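
A minimal usage sketch of the flat-vector helpers defined above (`GetFlat`, `SetFromFlat`, `flatgrad`). It assumes TensorFlow 1.x graph mode, a default session, and that the repo root is on `PYTHONPATH` so `tf_util` is importable; the toy variables, loss, and learning rate are illustrative, not part of the repo.

```python
import numpy as np
import tensorflow as tf
from tf_util import GetFlat, SetFromFlat, flatgrad  # assumes repo root on PYTHONPATH

# Two toy variables and a quadratic loss, purely for illustration.
w = tf.get_variable("w", [3, 2])
b = tf.get_variable("b", [2])
loss = tf.reduce_sum(tf.square(w)) + tf.reduce_sum(tf.square(b))

get_flat = GetFlat([w, b])            # reads all parameters as one 1-D vector
set_from_flat = SetFromFlat([w, b])   # writes a 1-D vector back into the variables
flat_grad = flatgrad(loss, [w, b], clip_norm=10.0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    theta = get_flat()                # shape (8,) = 3*2 + 2
    g = sess.run(flat_grad)           # flat gradient, same shape as theta
    set_from_flat(theta - 0.01 * g)   # one plain gradient step on the flat vector
```
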
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import random
4 | from mpi_util import mpi_moments
5 |
6 |
7 | def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):
8 | with tf.variable_scope(scope):
9 | nin = x.get_shape()[1].value
10 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
11 | b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(init_bias))
12 | return tf.matmul(x, w)+b
13 |
14 | def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False):
15 | if data_format == 'NHWC':
16 | channel_ax = 3
17 | strides = [1, stride, stride, 1]
18 | bshape = [1, 1, 1, nf]
19 | elif data_format == 'NCHW':
20 | channel_ax = 1
21 | strides = [1, 1, stride, stride]
22 | bshape = [1, nf, 1, 1]
23 | else:
24 | raise NotImplementedError
25 | bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1]
26 | nin = x.get_shape()[channel_ax].value
27 | wshape = [rf, rf, nin, nf]
28 | with tf.variable_scope(scope):
29 | w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale))
30 | b = tf.get_variable("b", bias_var_shape, initializer=tf.constant_initializer(0.0))
31 | if not one_dim_bias and data_format == 'NHWC':
32 | b = tf.reshape(b, bshape)
33 | return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format)
34 |
35 | def ortho_init(scale=1.0):
36 | def _ortho_init(shape, dtype, partition_info=None):
37 |         # Lasagne-style orthogonal initializer, adapted for TF
38 | shape = tuple(shape)
39 | if len(shape) == 2:
40 | flat_shape = shape
41 | elif len(shape) == 4: # assumes NHWC
42 | flat_shape = (np.prod(shape[:-1]), shape[-1])
43 | else:
44 | raise NotImplementedError
45 | a = np.random.normal(0.0, 1.0, flat_shape)
46 | u, _, v = np.linalg.svd(a, full_matrices=False)
47 | q = u if u.shape == flat_shape else v # pick the one with the correct shape
48 | q = q.reshape(shape)
49 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
50 | return _ortho_init
51 |
52 | def tile_images(array, n_cols=None, max_images=None, div=1):
53 | if max_images is not None:
54 | array = array[:max_images]
55 | if len(array.shape) == 4 and array.shape[3] == 1:
56 | array = array[:, :, :, 0]
57 | assert len(array.shape) in [3, 4], "wrong number of dimensions - shape {}".format(array.shape)
58 | if len(array.shape) == 4:
59 |         assert array.shape[3] == 3, "wrong number of channels - shape {}".format(array.shape)
60 | if n_cols is None:
61 | n_cols = max(int(np.sqrt(array.shape[0])) // div * div, div)
62 | n_rows = int(np.ceil(float(array.shape[0]) / n_cols))
63 |
64 | def cell(i, j):
65 | ind = i * n_cols + j
66 | return array[ind] if ind < array.shape[0] else np.zeros(array[0].shape)
67 |
68 | def row(i):
69 | return np.concatenate([cell(i, j) for j in range(n_cols)], axis=1)
70 |
71 | return np.concatenate([row(i) for i in range(n_rows)], axis=0)
72 |
73 |
74 | def set_global_seeds(i):
75 | try:
76 | import tensorflow as tf
77 | except ImportError:
78 | pass
79 | else:
80 | from mpi4py import MPI
81 | tf.set_random_seed(i)
82 | np.random.seed(i)
83 | random.seed(i)
84 |
85 |
86 | def explained_variance_non_mpi(ypred,y):
87 | """
88 | Computes fraction of variance that ypred explains about y.
89 | Returns 1 - Var[y-ypred] / Var[y]
90 |
91 | interpretation:
92 | ev=0 => might as well have predicted zero
93 | ev=1 => perfect prediction
94 | ev<0 => worse than just predicting zero
95 |
96 | """
97 | assert y.ndim == 1 and ypred.ndim == 1
98 | vary = np.var(y)
99 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary
100 |
101 | def mpi_var(x):
102 | return mpi_moments(x)[1]**2
103 |
104 | def explained_variance(ypred,y):
105 | """
106 | Computes fraction of variance that ypred explains about y.
107 | Returns 1 - Var[y-ypred] / Var[y]
108 |
109 | interpretation:
110 | ev=0 => might as well have predicted zero
111 | ev=1 => perfect prediction
112 | ev<0 => worse than just predicting zero
113 |
114 | """
115 | assert y.ndim == 1 and ypred.ndim == 1
116 | vary = mpi_var(y)
117 | return np.nan if vary==0 else 1 - mpi_var(y-ypred)/vary
118 |
119 |
--------------------------------------------------------------------------------
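
A small numeric illustration of the explained-variance diagnostic above, using the non-MPI variant. Importing `utils` pulls in `mpi_util` (and therefore `mpi4py`) at load time, so that dependency is assumed; the arrays are made up for the example.

```python
import numpy as np
from utils import explained_variance_non_mpi  # assumes repo root on PYTHONPATH and mpi4py installed

y = np.array([1.0, 2.0, 3.0, 4.0])            # "true" values, Var[y] = 1.25
noisy = y + np.array([0.5, -0.5, 0.5, -0.5])  # residual variance 0.25

print(explained_variance_non_mpi(y, y))                  # 1.0 -> perfect prediction
print(explained_variance_non_mpi(noisy, y))              # 0.8 -> 1 - 0.25 / 1.25
print(explained_variance_non_mpi(np.zeros_like(y), y))   # 0.0 -> no better than predicting zero
```
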
/vec_env.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from multiprocessing import Process, Pipe
3 | from baselines import logger
4 | from utils import tile_images
5 |
6 | class AlreadySteppingError(Exception):
7 | """
8 | Raised when an asynchronous step is running while
9 | step_async() is called again.
10 | """
11 | def __init__(self):
12 | msg = 'already running an async step'
13 | Exception.__init__(self, msg)
14 |
15 | class NotSteppingError(Exception):
16 | """
17 | Raised when an asynchronous step is not running but
18 | step_wait() is called.
19 | """
20 | def __init__(self):
21 | msg = 'not running an async step'
22 | Exception.__init__(self, msg)
23 |
24 | class VecEnv(ABC):
25 | """
26 | An abstract asynchronous, vectorized environment.
27 | """
28 | def __init__(self, num_envs, observation_space, action_space):
29 | self.num_envs = num_envs
30 | self.observation_space = observation_space
31 | self.action_space = action_space
32 |
33 | @abstractmethod
34 | def reset(self):
35 | """
36 | Reset all the environments and return an array of
37 | observations, or a tuple of observation arrays.
38 |
39 | If step_async is still doing work, that work will
40 | be cancelled and step_wait() should not be called
41 | until step_async() is invoked again.
42 | """
43 | pass
44 |
45 | @abstractmethod
46 | def step_async(self, actions):
47 | """
48 | Tell all the environments to start taking a step
49 | with the given actions.
50 | Call step_wait() to get the results of the step.
51 |
52 | You should not call this if a step_async run is
53 | already pending.
54 | """
55 | pass
56 |
57 | @abstractmethod
58 | def step_wait(self):
59 | """
60 | Wait for the step taken with step_async().
61 |
62 | Returns (obs, rews, dones, infos):
63 | - obs: an array of observations, or a tuple of
64 | arrays of observations.
65 | - rews: an array of rewards
66 | - dones: an array of "episode done" booleans
67 | - infos: a sequence of info objects
68 | """
69 | pass
70 |
71 | @abstractmethod
72 | def close(self):
73 | """
74 | Clean up the environments' resources.
75 | """
76 | pass
77 |
78 | def step(self, actions):
79 | self.step_async(actions)
80 | return self.step_wait()
81 |
82 | def render(self, mode='human'):
83 | logger.warn('Render not defined for %s'%self)
84 |
85 | @property
86 | def unwrapped(self):
87 | if isinstance(self, VecEnvWrapper):
88 | return self.venv.unwrapped
89 | else:
90 | return self
91 |
92 | class VecEnvWrapper(VecEnv):
93 | def __init__(self, venv, observation_space=None, action_space=None):
94 | self.venv = venv
95 | VecEnv.__init__(self,
96 | num_envs=venv.num_envs,
97 | observation_space=observation_space or venv.observation_space,
98 | action_space=action_space or venv.action_space)
99 |
100 | def step_async(self, actions):
101 | self.venv.step_async(actions)
102 |
103 | @abstractmethod
104 | def reset(self):
105 | pass
106 |
107 | @abstractmethod
108 | def step_wait(self):
109 | pass
110 |
111 | def close(self):
112 | return self.venv.close()
113 |
114 |     def render(self, mode='human'):
115 |         return self.venv.render(mode=mode)
116 |
117 | class CloudpickleWrapper(object):
118 | """
119 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
120 | """
121 | def __init__(self, x):
122 | self.x = x
123 | def __getstate__(self):
124 | import cloudpickle
125 | return cloudpickle.dumps(self.x)
126 | def __setstate__(self, ob):
127 | import pickle
128 | self.x = pickle.loads(ob)
129 |
130 | import numpy as np
131 | from gym import spaces
132 |
133 | class VecFrameStack(VecEnvWrapper):
134 | """
135 |     Frame-stacking wrapper for a vectorized environment
136 | """
137 | def __init__(self, venv, nstack):
138 | self.venv = venv
139 | self.nstack = nstack
140 | wos = venv.observation_space # wrapped ob space
141 | low = np.repeat(wos.low, self.nstack, axis=-1)
142 | high = np.repeat(wos.high, self.nstack, axis=-1)
143 | self.stackedobs = np.zeros((venv.num_envs,)+low.shape, low.dtype)
144 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
145 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
146 |
147 | def step_wait(self):
148 | obs, rews, news, infos = self.venv.step_wait()
149 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1)
150 | for (i, new) in enumerate(news):
151 | if new:
152 | self.stackedobs[i] = 0
153 | self.stackedobs[..., -obs.shape[-1]:] = obs
154 | return self.stackedobs, rews, news, infos
155 |
156 | def reset(self):
157 | """
158 | Reset all environments
159 | """
160 | obs = self.venv.reset()
161 | self.stackedobs[...] = 0
162 | self.stackedobs[..., -obs.shape[-1]:] = obs
163 | return self.stackedobs
164 |
165 | def close(self):
166 | self.venv.close()
167 |
168 |
206 |
207 |
208 | def worker(remote, parent_remote, env_fn_wrapper):
209 | parent_remote.close()
210 | env = env_fn_wrapper.x()
211 | while True:
212 | cmd, data = remote.recv()
213 | if cmd == 'step':
214 | ob, reward, done, info = env.step(data)
215 | if done:
216 | ob = env.reset()
217 | remote.send((ob, reward, done, info))
218 | elif cmd == 'reset':
219 | ob = env.reset()
220 | remote.send(ob)
221 | elif cmd == 'render':
222 | remote.send(env.render(mode='rgb_array'))
223 | elif cmd == 'close':
224 | remote.close()
225 | break
226 | elif cmd == 'get_spaces':
227 | remote.send((env.observation_space, env.action_space))
228 | else:
229 | raise NotImplementedError
230 |
231 |
232 | class SubprocVecEnv(VecEnv):
233 | def __init__(self, env_fns, spaces=None):
234 | """
235 |         env_fns: list of functions that each create a gym environment to run in a subprocess
236 | """
237 | self.waiting = False
238 | self.closed = False
239 | nenvs = len(env_fns)
240 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
241 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
242 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
243 | for p in self.ps:
244 | p.daemon = True # if the main process crashes, we should not cause things to hang
245 | p.start()
246 | for remote in self.work_remotes:
247 | remote.close()
248 |
249 | self.remotes[0].send(('get_spaces', None))
250 | observation_space, action_space = self.remotes[0].recv()
251 | VecEnv.__init__(self, len(env_fns), observation_space, action_space)
252 |
253 | def step_async(self, actions):
254 | for remote, action in zip(self.remotes, actions):
255 | remote.send(('step', action))
256 | self.waiting = True
257 |
258 | def step_wait(self):
259 | results = [remote.recv() for remote in self.remotes]
260 | self.waiting = False
261 | obs, rews, dones, infos = zip(*results)
262 | return np.stack(obs), np.stack(rews), np.stack(dones), infos
263 |
264 | def reset(self):
265 | for remote in self.remotes:
266 | remote.send(('reset', None))
267 | return np.stack([remote.recv() for remote in self.remotes])
268 |
269 |     def reset_task(self):  # note: worker() above does not handle 'reset_task', so this raises NotImplementedError in the subprocess
270 | for remote in self.remotes:
271 | remote.send(('reset_task', None))
272 | return np.stack([remote.recv() for remote in self.remotes])
273 |
274 | def close(self):
275 | if self.closed:
276 | return
277 | if self.waiting:
278 | for remote in self.remotes:
279 | remote.recv()
280 | for remote in self.remotes:
281 | remote.send(('close', None))
282 | for p in self.ps:
283 | p.join()
284 | self.closed = True
285 |
286 | def render(self, mode='human'):
287 | for pipe in self.remotes:
288 | pipe.send(('render', None))
289 | imgs = [pipe.recv() for pipe in self.remotes]
290 | bigimg = tile_images(imgs)
291 | if mode == 'human':
292 | import cv2
293 | cv2.imshow('vecenv', bigimg[:,:,::-1])
294 | cv2.waitKey(1)
295 | elif mode == 'rgb_array':
296 | return bigimg
297 | else:
298 | raise NotImplementedError
299 |
--------------------------------------------------------------------------------
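
A minimal sketch of wiring the vectorized-environment pieces above together, outside the Atari pipeline in run_atari.py. It uses CartPole purely as a stand-in environment; the env id, number of workers, and stack depth are illustrative, and `baselines` plus `mpi4py` are assumed to be installed since vec_env imports `logger` and `utils` at load time.

```python
import gym
import numpy as np
from vec_env import SubprocVecEnv, VecFrameStack  # assumes repo root on PYTHONPATH

def make_env(seed):
    # Thunk so the environment is constructed inside the worker subprocess.
    def _thunk():
        env = gym.make('CartPole-v0')  # stand-in env; the paper's experiments use Atari
        env.seed(seed)
        return env
    return _thunk

if __name__ == '__main__':
    venv = SubprocVecEnv([make_env(s) for s in range(4)])  # 4 subprocess workers
    venv = VecFrameStack(venv, nstack=4)                   # stack the last 4 observations
    obs = venv.reset()                                      # shape (4, 16) for CartPole
    for _ in range(10):
        actions = np.array([venv.action_space.sample() for _ in range(venv.num_envs)])
        obs, rews, dones, infos = venv.step(actions)        # step_async + step_wait
    venv.close()
```

Note that `worker` resets an environment as soon as it reports `done`, so the observation returned on a terminal step already comes from the next episode.
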