144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 | dir(self.manager_proxies[0])
156 | Python
157 | EXPRESSION
158 |
159 |
160 | node_manager_proxy.get_update_count()
161 | Python
162 | EXPRESSION
163 |
164 |
165 | node_manager_proxy.get_
166 | Python
167 | EXPRESSION
168 |
169 |
170 | self.manager_proxies[0].get_update_count()
171 | Python
172 | EXPRESSION
173 |
174 |
175 | self.manager_proxies[tree_idx].get_update_count()
176 | Python
177 | EXPRESSION
178 |
179 |
180 | manager_proxy.get_update_count()
181 | Python
182 | EXPRESSION
183 |
184 |
185 | manager_proxy.
186 | Python
187 | EXPRESSION
188 |
189 |
190 | manager_listener.listen()
191 | Python
192 | EXPRESSION
193 |
194 |
195 | np.array(r)[np.where(np.array(r) != 0.0)]
196 | Python
197 | EXPRESSION
198 |
199 |
200 | np.where(np.array(r) != 0.0)
201 | Python
202 | EXPRESSION
203 |
204 |
205 |
206 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at anjiliu219@gmail.com. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/Env/AtariEnv/AtariEnvWrapper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .atari_wrappers import make_atari, wrap_deepmind, FrameStack
3 | from .vec_env import VecEnv
4 | from .VideoRecord import VideoRecorder
5 | from multiprocessing import Process, Pipe
6 |
7 |
8 | # cf https://github.com/openai/baselines
9 |
10 | def make_atari_env(env_name, rank, seed, enable_record = False, record_path = "./1.mp4"):
11 | env = make_atari(env_name)
12 | env.seed(seed + rank)
13 | env = wrap_deepmind(env, episode_life = False, clip_rewards = False)
14 | env = FrameStack(env, 4)
15 | recorder = VideoRecorder(env, path = record_path, enabled = enable_record)
16 | return env, recorder
17 |
18 |
19 | def worker(remote, parent_remote, env_fn_wrapper):
20 | parent_remote.close()
21 | env = env_fn_wrapper.x()
22 | while True:
23 | cmd, data = remote.recv()
24 | if cmd == 'step':
25 | ob, reward, done, info = env.step(data)
26 | if done:
27 | ob = env.reset()
28 | remote.send((ob, reward, done, info))
29 | elif cmd == 'reset':
30 | ob = env.reset()
31 | remote.send(ob)
32 | elif cmd == 'reset_task':
33 | ob = env.reset_task()
34 | remote.send(ob)
35 | elif cmd == 'close':
36 | remote.close()
37 | break
38 | elif cmd == 'get_spaces':
39 | remote.send((env.action_space, env.observation_space))
40 | elif cmd == 'render':
41 | env.render()
42 | else:
43 | raise NotImplementedError
44 |
45 |
46 | class CloudpickleWrapper(object):
47 | """
48 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
49 | """
50 | def __init__(self, x):
51 | self.x = x
52 | def __getstate__(self):
53 | import cloudpickle
54 | return cloudpickle.dumps(self.x)
55 | def __setstate__(self, ob):
56 | import pickle
57 | self.x = pickle.loads(ob)
58 |
59 |
60 | class RenderSubprocVecEnv(VecEnv):
61 | def __init__(self, env_fns, render_interval):
62 | """ Minor addition to SubprocVecEnv, automatically renders environments
63 |
64 | envs: list of gym environments to run in subprocesses
65 | """
66 | self.closed = False
67 | nenvs = len(env_fns)
68 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
69 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
70 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
71 | for p in self.ps:
72 | p.daemon = True # if the main process crashes, we should not cause things to hang
73 | p.start()
74 | for remote in self.work_remotes:
75 | remote.close()
76 |
77 | self.remotes[0].send(('get_spaces', None))
78 | self.action_space, self.observation_space = self.remotes[0].recv()
79 |
80 | self.render_interval = render_interval
81 | self.render_timer = 0
82 |
83 | def step(self, actions):
84 | for remote, action in zip(self.remotes, actions):
85 | remote.send(('step', action))
86 | results = [remote.recv() for remote in self.remotes]
87 | obs, rews, dones, infos = zip(*results)
88 |
89 | self.render_timer += 1
90 | if self.render_timer == self.render_interval:
91 | for remote in self.remotes:
92 | remote.send(('render', None))
93 | self.render_timer = 0
94 |
95 | return np.stack(obs), np.stack(rews), np.stack(dones), infos
96 |
97 | def reset(self):
98 | for remote in self.remotes:
99 | remote.send(('reset', None))
100 | return np.stack([remote.recv() for remote in self.remotes])
101 |
102 | def reset_task(self):
103 | for remote in self.remotes:
104 | remote.send(('reset_task', None))
105 | return np.stack([remote.recv() for remote in self.remotes])
106 |
107 | def close(self):
108 | if self.closed:
109 | return
110 |
111 | for remote in self.remotes:
112 | remote.send(('close', None))
113 | for p in self.ps:
114 | p.join()
115 | self.closed = True
116 |
117 | @property
118 | def num_envs(self):
119 | return len(self.remotes)
120 |
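
A minimal usage sketch for make_atari_env above, assuming gym with the Atari ROMs installed and the older four-value step API; the environment id and seed are arbitrary:

    from Env.AtariEnv.AtariEnvWrapper import make_atari_env

    env, recorder = make_atari_env("PongNoFrameskip-v4", rank=0, seed=123)
    obs = env.reset()                       # LazyFrames holding 4 stacked 84x84 frames
    for _ in range(10):
        action = env.action_space.sample()  # random action from the discrete action space
        obs, reward, done, info = env.step(action)
        recorder.capture_frame()            # no-op here because enable_record defaults to False
        if done:
            obs = env.reset()
    env.close()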
--------------------------------------------------------------------------------
/Env/AtariEnv/VideoRecord.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import subprocess
4 | import tempfile
5 | import os.path
6 | import distutils.spawn, distutils.version
7 | import numpy as np
8 | from six import StringIO
9 | import six
10 | from gym import error, logger
11 |
12 | def touch(path):
13 | open(path, 'a').close()
14 |
15 | class VideoRecorder():
16 | """VideoRecorder renders a nice movie of a rollout, frame by frame. It
17 | comes with an `enabled` option so you can still use the same code
18 | on episodes where you don't want to record video.
19 |
20 | Note:
21 | You are responsible for calling `close` on a created
22 | VideoRecorder, or else you may leak an encoder process.
23 |
24 | Args:
25 | env (Env): Environment to take video of.
26 | path (Optional[str]): Path to the video file; will be randomly chosen if omitted.
27 | base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added.
28 | metadata (Optional[dict]): Contents to save to the metadata file.
29 | enabled (bool): Whether to actually record video, or just no-op (for convenience)
30 | """
31 |
32 | def __init__(self, env, path=None, metadata=None, enabled=True, base_path=None):
33 | modes = env.metadata.get('render.modes', [])
34 | self._async = env.metadata.get('semantics.async')
35 | self.enabled = enabled
36 |
37 | # Don't bother setting anything else if not enabled
38 | if not self.enabled:
39 | return
40 |
41 | self.ansi_mode = False
42 | if 'rgb_array' not in modes:
43 | if 'ansi' in modes:
44 | self.ansi_mode = True
45 | else:
46 | logger.info('Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(env))
47 | # Whoops, turns out we shouldn't be enabled after all
48 | self.enabled = False
49 | return
50 |
51 | if path is not None and base_path is not None:
52 | raise error.Error("You can pass at most one of `path` or `base_path`.")
53 |
54 | self.last_frame = None
55 | self.env = env
56 |
57 | required_ext = '.json' if self.ansi_mode else '.mp4'
58 | if path is None:
59 | if base_path is not None:
60 | # Base path given, append ext
61 | path = base_path + required_ext
62 | else:
63 | # Otherwise, just generate a unique filename
64 | with tempfile.NamedTemporaryFile(suffix=required_ext, delete=False) as f:
65 | path = f.name
66 | self.path = path
67 |
68 | path_base, actual_ext = os.path.splitext(self.path)
69 |
70 | if actual_ext != required_ext:
71 | hint = " HINT: The environment is text-only, therefore we're recording its text output in a structured JSON format." if self.ansi_mode else ''
72 | raise error.Error("Invalid path given: {} -- must have file extension {}.{}".format(self.path, required_ext, hint))
73 |         # Touch the file in any case, so we know it's present. (This
74 |         # corrects for platform differences. Using ffmpeg on
75 |         # OS X, the file is precreated, but not on Linux.)
76 | touch(path)
77 |
78 | self.frames_per_sec = env.metadata.get('video.frames_per_second', 30)
79 | self.encoder = None # lazily start the process
80 | self.broken = False
81 |
82 | # Dump metadata
83 | self.metadata = metadata or {}
84 | self.metadata['content_type'] = 'video/vnd.openai.ansivid' if self.ansi_mode else 'video/mp4'
85 | self.metadata_path = '{}.meta.json'.format(path_base)
86 | self.write_metadata()
87 |
88 | logger.info('Starting new video recorder writing to %s', self.path)
89 | self.empty = True
90 |
91 | @property
92 | def functional(self):
93 | return self.enabled and not self.broken
94 |
95 | def capture_frame(self):
96 | """Render the given `env` and add the resulting frame to the video."""
97 | if not self.functional: return
98 | logger.debug('Capturing video frame: path=%s', self.path)
99 |
100 | render_mode = 'ansi' if self.ansi_mode else 'rgb_array'
101 | frame = self.env.render(mode=render_mode)
102 |
103 | if frame is None:
104 | if self._async:
105 | return
106 | else:
107 | # Indicates a bug in the environment: don't want to raise
108 | # an error here.
109 | logger.warn('Env returned None on render(). Disabling further rendering for video recorder by marking as disabled: path=%s metadata_path=%s', self.path, self.metadata_path)
110 | self.broken = True
111 | else:
112 | self.last_frame = frame
113 | if self.ansi_mode:
114 | self._encode_ansi_frame(frame)
115 | else:
116 | self._encode_image_frame(frame)
117 |
118 | def close(self):
119 | """Make sure to manually close, or else you'll leak the encoder process"""
120 | if not self.enabled:
121 | return
122 |
123 | if self.encoder:
124 | logger.debug('Closing video encoder: path=%s', self.path)
125 | self.encoder.close()
126 | self.encoder = None
127 | else:
128 | # No frames captured. Set metadata, and remove the empty output file.
129 | os.remove(self.path)
130 |
131 | if self.metadata is None:
132 | self.metadata = {}
133 | self.metadata['empty'] = True
134 |
135 | # If broken, get rid of the output file, otherwise we'd leak it.
136 | if self.broken:
137 | logger.info('Cleaning up paths for broken video recorder: path=%s metadata_path=%s', self.path, self.metadata_path)
138 |
139 | # Might have crashed before even starting the output file, don't try to remove in that case.
140 | if os.path.exists(self.path):
141 | os.remove(self.path)
142 |
143 | if self.metadata is None:
144 | self.metadata = {}
145 | self.metadata['broken'] = True
146 |
147 | self.write_metadata()
148 |
149 | def write_metadata(self):
150 | with open(self.metadata_path, 'w') as f:
151 | json.dump(self.metadata, f)
152 |
153 | def _encode_ansi_frame(self, frame):
154 | if not self.encoder:
155 | self.encoder = TextEncoder(self.path, self.frames_per_sec)
156 | self.metadata['encoder_version'] = self.encoder.version_info
157 | self.encoder.capture_frame(frame)
158 | self.empty = False
159 |
160 | def _encode_image_frame(self, frame):
161 | if not self.encoder:
162 | self.encoder = ImageEncoder(self.path, frame.shape, self.frames_per_sec)
163 | self.metadata['encoder_version'] = self.encoder.version_info
164 |
165 | try:
166 | self.encoder.capture_frame(frame)
167 | except error.InvalidFrame as e:
168 | logger.warn('Tried to pass invalid video frame, marking as broken: %s', e)
169 | self.broken = True
170 | else:
171 | self.empty = False
172 |
173 |
174 | class TextEncoder(object):
175 | """Store a moving picture made out of ANSI frames. Format adapted from
176 | https://github.com/asciinema/asciinema/blob/master/doc/asciicast-v1.md"""
177 |
178 | def __init__(self, output_path, frames_per_sec):
179 | self.output_path = output_path
180 | self.frames_per_sec = frames_per_sec
181 | self.frames = []
182 |
183 | def capture_frame(self, frame):
184 | string = None
185 | if isinstance(frame, str):
186 | string = frame
187 | elif isinstance(frame, StringIO):
188 | string = frame.getvalue()
189 | else:
190 | raise error.InvalidFrame('Wrong type {} for {}: text frame must be a string or StringIO'.format(type(frame), frame))
191 |
192 | frame_bytes = string.encode('utf-8')
193 |
194 | if frame_bytes[-1:] != six.b('\n'):
195 | raise error.InvalidFrame('Frame must end with a newline: """{}"""'.format(string))
196 |
197 | if six.b('\r') in frame_bytes:
198 |             raise error.InvalidFrame('Frame contains carriage returns (only newlines are allowed): """{}"""'.format(string))
199 |
200 | self.frames.append(frame_bytes)
201 |
202 | def close(self):
203 | #frame_duration = float(1) / self.frames_per_sec
204 | frame_duration = .5
205 |
206 | # Turn frames into events: clear screen beforehand
207 | # https://rosettacode.org/wiki/Terminal_control/Clear_the_screen#Python
208 | # https://rosettacode.org/wiki/Terminal_control/Cursor_positioning#Python
209 | clear_code = six.b("%c[2J\033[1;1H" % (27))
210 | # Decode the bytes as UTF-8 since JSON may only contain UTF-8
211 | events = [ (frame_duration, (clear_code+frame.replace(six.b('\n'),six.b('\r\n'))).decode('utf-8')) for frame in self.frames ]
212 |
213 | # Calculate frame size from the largest frames.
214 | # Add some padding since we'll get cut off otherwise.
215 | height = max([frame.count(six.b('\n')) for frame in self.frames]) + 1
216 | width = max([max([len(line) for line in frame.split(six.b('\n'))]) for frame in self.frames]) + 2
217 |
218 | data = {
219 | "version": 1,
220 | "width": width,
221 | "height": height,
222 | "duration": len(self.frames)*frame_duration,
223 | "command": "-",
224 | "title": "gym VideoRecorder episode",
225 | "env": {}, # could add some env metadata here
226 | "stdout": events,
227 | }
228 |
229 | with open(self.output_path, 'w') as f:
230 | json.dump(data, f)
231 |
232 | @property
233 | def version_info(self):
234 | return {'backend':'TextEncoder','version':1}
235 |
236 | class ImageEncoder(object):
237 | def __init__(self, output_path, frame_shape, frames_per_sec):
238 | self.proc = None
239 | self.output_path = output_path
240 | # Frame shape should be lines-first, so w and h are swapped
241 | h, w, pixfmt = frame_shape
242 | if pixfmt != 3 and pixfmt != 4:
243 | raise error.InvalidFrame("Your frame has shape {}, but we require (w,h,3) or (w,h,4), i.e. RGB values for a w-by-h image, with an optional alpha channl.".format(frame_shape))
244 | self.wh = (w,h)
245 | self.includes_alpha = (pixfmt == 4)
246 | self.frame_shape = frame_shape
247 | self.frames_per_sec = frames_per_sec
248 |
249 | if distutils.spawn.find_executable('avconv') is not None:
250 | self.backend = 'avconv'
251 | elif distutils.spawn.find_executable('ffmpeg') is not None:
252 | self.backend = 'ffmpeg'
253 | else:
254 | raise error.DependencyNotInstalled("""Found neither the ffmpeg nor avconv executables. On OS X, you can install ffmpeg via `brew install ffmpeg`. On most Ubuntu variants, `sudo apt-get install ffmpeg` should do it. On Ubuntu 14.04, however, you'll need to install avconv with `sudo apt-get install libav-tools`.""")
255 |
256 | self.start()
257 |
258 | @property
259 | def version_info(self):
260 | return {
261 | 'backend':self.backend,
262 | 'version':str(subprocess.check_output([self.backend, '-version'],
263 | stderr=subprocess.STDOUT)),
264 | 'cmdline':self.cmdline
265 | }
266 |
267 | def start(self):
268 | self.cmdline = (self.backend,
269 | '-nostats',
270 | '-loglevel', 'error', # suppress warnings
271 | '-y',
272 | '-r', '%d' % self.frames_per_sec,
273 |
274 | # input
275 | '-f', 'rawvideo',
276 | '-s:v', '{}x{}'.format(*self.wh),
277 | '-pix_fmt',('rgb32' if self.includes_alpha else 'rgb24'),
278 | '-i', '-', # this used to be /dev/stdin, which is not Windows-friendly
279 |
280 | # output
281 | '-vcodec', 'libx264',
282 | '-pix_fmt', 'yuv420p',
283 | self.output_path
284 | )
285 |
286 | logger.debug('Starting ffmpeg with "%s"', ' '.join(self.cmdline))
287 | if hasattr(os,'setsid'): #setsid not present on Windows
288 | self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid)
289 | else:
290 | self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE)
291 |
292 | def capture_frame(self, frame):
293 | if not isinstance(frame, (np.ndarray, np.generic)):
294 | raise error.InvalidFrame('Wrong type {} for {} (must be np.ndarray or np.generic)'.format(type(frame), frame))
295 | if frame.shape != self.frame_shape:
296 | raise error.InvalidFrame("Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format(frame.shape, self.frame_shape))
297 | if frame.dtype != np.uint8:
298 | raise error.InvalidFrame("Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(frame.dtype))
299 |
300 | if distutils.version.LooseVersion(np.__version__) >= distutils.version.LooseVersion('1.9.0'):
301 | self.proc.stdin.write(frame.tobytes())
302 | else:
303 | self.proc.stdin.write(frame.tostring())
304 |
305 | def close(self):
306 | self.proc.stdin.close()
307 | ret = self.proc.wait()
308 | if ret != 0:
309 | logger.error("VideoRecorder encoder exited with status {}".format(ret))
310 |
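
A short usage sketch of the VideoRecorder API described in the class docstring, assuming an older gym (env.render(mode='rgb_array')) and ffmpeg or avconv on the PATH; the environment id is arbitrary:

    import gym
    from Env.AtariEnv.VideoRecord import VideoRecorder

    env = gym.make("CartPole-v1")
    recorder = VideoRecorder(env, path="./rollout.mp4", enabled=True)

    env.reset()
    done = False
    while not done:
        _, _, done, _ = env.step(env.action_space.sample())
        recorder.capture_frame()   # renders an rgb_array frame and feeds it to the encoder
    recorder.close()               # must be called, otherwise the encoder process leaks
    env.close()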
--------------------------------------------------------------------------------
/Env/AtariEnv/atari_wrappers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | os.environ.setdefault('PATH', '')
4 | from collections import deque
5 | import gym
6 | from gym import spaces
7 | import cv2
8 | cv2.ocl.setUseOpenCL(False)
9 | from copy import deepcopy
10 |
11 | class TimeLimit(gym.Wrapper):
12 | def __init__(self, env, max_episode_steps=None):
13 | super(TimeLimit, self).__init__(env)
14 | self._max_episode_steps = max_episode_steps
15 | self._elapsed_steps = 0
16 |
17 | def step(self, ac):
18 | observation, reward, done, info = self.env.step(ac)
19 | self._elapsed_steps += 1
20 | if self._elapsed_steps >= self._max_episode_steps:
21 | done = True
22 | info['TimeLimit.truncated'] = True
23 | return observation, reward, done, info
24 |
25 | def reset(self, **kwargs):
26 | self._elapsed_steps = 0
27 | return self.env.reset(**kwargs)
28 |
29 | class NoopResetEnv(gym.Wrapper):
30 | def __init__(self, env, noop_max=30):
31 | """Sample initial states by taking random number of no-ops on reset.
32 | No-op is assumed to be action 0.
33 | """
34 | gym.Wrapper.__init__(self, env)
35 | self.noop_max = noop_max
36 | self.override_num_noops = None
37 | self.noop_action = 0
38 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
39 |
40 | def reset(self, **kwargs):
41 | """ Do no-op action for a number of steps in [1, noop_max]."""
42 | self.env.reset(**kwargs)
43 | if self.override_num_noops is not None:
44 | noops = self.override_num_noops
45 | else:
46 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
47 | assert noops > 0
48 | obs = None
49 | for _ in range(noops):
50 | obs, _, done, _ = self.env.step(self.noop_action)
51 | if done:
52 | obs = self.env.reset(**kwargs)
53 | return obs
54 |
55 | def step(self, ac):
56 | return self.env.step(ac)
57 |
58 | class FireResetEnv(gym.Wrapper):
59 | def __init__(self, env):
60 | """Take action on reset for environments that are fixed until firing."""
61 | gym.Wrapper.__init__(self, env)
62 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
63 | assert len(env.unwrapped.get_action_meanings()) >= 3
64 |
65 | def reset(self, **kwargs):
66 | self.env.reset(**kwargs)
67 | obs, _, done, _ = self.env.step(1)
68 | if done:
69 | self.env.reset(**kwargs)
70 | obs, _, done, _ = self.env.step(2)
71 | if done:
72 | self.env.reset(**kwargs)
73 | return obs
74 |
75 | def step(self, ac):
76 | return self.env.step(ac)
77 |
78 | class EpisodicLifeEnv(gym.Wrapper):
79 | def __init__(self, env):
80 | """Make end-of-life == end-of-episode, but only reset on true game over.
81 | Done by DeepMind for the DQN and co. since it helps value estimation.
82 | """
83 | gym.Wrapper.__init__(self, env)
84 | self.lives = 0
85 | self.was_real_done = True
86 |
87 | def step(self, action):
88 | obs, reward, done, info = self.env.step(action)
89 | self.was_real_done = done
90 | # check current lives, make loss of life terminal,
91 | # then update lives to handle bonus lives
92 | lives = self.env.unwrapped.ale.lives()
93 | if lives < self.lives and lives > 0:
94 | # for Qbert sometimes we stay in lives == 0 condition for a few frames
95 | # so it's important to keep lives > 0, so that we only reset once
96 | # the environment advertises done.
97 | done = True
98 | self.lives = lives
99 | return obs, reward, done, info
100 |
101 | def reset(self, **kwargs):
102 | """Reset only when lives are exhausted.
103 | This way all states are still reachable even though lives are episodic,
104 | and the learner need not know about any of this behind-the-scenes.
105 | """
106 | if self.was_real_done:
107 | obs = self.env.reset(**kwargs)
108 | else:
109 | # no-op step to advance from terminal/lost life state
110 | obs, _, _, _ = self.env.step(0)
111 | self.lives = self.env.unwrapped.ale.lives()
112 | return obs
113 |
114 | class MaxAndSkipEnv(gym.Wrapper):
115 | def __init__(self, env, skip=4):
116 | """Return only every `skip`-th frame"""
117 | gym.Wrapper.__init__(self, env)
118 | # most recent raw observations (for max pooling across time steps)
119 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
120 | self._skip = skip
121 |
122 | def step(self, action):
123 | """Repeat action, sum reward, and max over last observations."""
124 | total_reward = 0.0
125 | done = None
126 | for i in range(self._skip):
127 | obs, reward, done, info = self.env.step(action)
128 | if i == self._skip - 2: self._obs_buffer[0] = obs
129 | if i == self._skip - 1: self._obs_buffer[1] = obs
130 | total_reward += reward
131 | if done:
132 | break
133 | # Note that the observation on the done=True frame
134 | # doesn't matter
135 | max_frame = self._obs_buffer.max(axis=0)
136 |
137 | return max_frame, total_reward, done, info
138 |
139 | def reset(self, **kwargs):
140 | return self.env.reset(**kwargs)
141 |
142 | class ClipRewardEnv(gym.RewardWrapper):
143 | def __init__(self, env):
144 | gym.RewardWrapper.__init__(self, env)
145 |
146 | def reward(self, reward):
147 | """Bin reward to {+1, 0, -1} by its sign."""
148 | return np.sign(reward)
149 |
150 |
151 | class WarpFrame(gym.ObservationWrapper):
152 | def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
153 | """
154 | Warp frames to 84x84 as done in the Nature paper and later work.
155 |
156 | If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
157 | observation should be warped.
158 | """
159 | super().__init__(env)
160 | self._width = width
161 | self._height = height
162 | self._grayscale = grayscale
163 | self._key = dict_space_key
164 | if self._grayscale:
165 | num_colors = 1
166 | else:
167 | num_colors = 3
168 |
169 | new_space = gym.spaces.Box(
170 | low=0,
171 | high=255,
172 | shape=(self._height, self._width, num_colors),
173 | dtype=np.uint8,
174 | )
175 | if self._key is None:
176 | original_space = self.observation_space
177 | self.observation_space = new_space
178 | else:
179 | original_space = self.observation_space.spaces[self._key]
180 | self.observation_space.spaces[self._key] = new_space
181 | assert original_space.dtype == np.uint8 and len(original_space.shape) == 3
182 |
183 | def observation(self, obs):
184 | if self._key is None:
185 | frame = obs
186 | else:
187 | frame = obs[self._key]
188 |
189 | if self._grayscale:
190 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
191 | frame = cv2.resize(
192 | frame, (self._width, self._height), interpolation=cv2.INTER_AREA
193 | )
194 | if self._grayscale:
195 | frame = np.expand_dims(frame, -1)
196 |
197 | if self._key is None:
198 | obs = frame
199 | else:
200 | obs = obs.copy()
201 | obs[self._key] = frame
202 | return obs
203 |
204 |
205 | class FrameStack(gym.Wrapper):
206 | def __init__(self, env, k):
207 | """Stack k last frames.
208 |
209 | Returns lazy array, which is much more memory efficient.
210 |
211 | See Also
212 | --------
213 | baselines.common.atari_wrappers.LazyFrames
214 | """
215 | gym.Wrapper.__init__(self, env)
216 | self.k = k
217 | self.frames = deque([], maxlen=k)
218 | shp = env.observation_space.shape
219 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype)
220 |
221 | def reset(self):
222 | ob = self.env.reset()
223 | for _ in range(self.k):
224 | self.frames.append(ob / 255.0)
225 | return self._get_ob()
226 |
227 | def step(self, action):
228 | ob, reward, done, info = self.env.step(action)
229 | self.frames.append(ob / 255.0)
230 | return self._get_ob(), reward, done, info
231 |
232 | def _get_ob(self):
233 | assert len(self.frames) == self.k
234 | return LazyFrames(list(self.frames))
235 |
236 | def clone_full_state(self):
237 | state_data = self.unwrapped.clone_full_state()
238 | frame_data = self.frames.copy()
239 |
240 | full_state_data = (state_data, frame_data)
241 |
242 | return full_state_data
243 |
244 | def restore_full_state(self, full_state_data):
245 | state_data, frame_data = full_state_data
246 |
247 | self.unwrapped.restore_full_state(state_data)
248 | self.frames = frame_data.copy()
249 |
250 | def get_state(self):
251 | return self._get_ob()
252 |
253 |
254 | class ScaledFloatFrame(gym.ObservationWrapper):
255 | def __init__(self, env):
256 | gym.ObservationWrapper.__init__(self, env)
257 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)
258 |
259 | def observation(self, observation):
260 | # careful! This undoes the memory optimization, use
261 | # with smaller replay buffers only.
262 | return np.array(observation).astype(np.float32) / 255.0
263 |
264 | class LazyFrames(object):
265 | def __init__(self, frames):
266 | """This object ensures that common frames between the observations are only stored once.
267 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
268 | buffers.
269 |
270 | This object should only be converted to numpy array before being passed to the model.
271 |
272 | You'd not believe how complex the previous solution was."""
273 | self._frames = frames
274 | self._out = None
275 |
276 | def _force(self):
277 | if self._out is None:
278 | self._out = np.stack(self._frames, axis=0).reshape((-1, 84, 84))
279 | self._frames = None
280 | return self._out
281 |
282 | def __array__(self, dtype=None):
283 | out = self._force()
284 | if dtype is not None:
285 | out = out.astype(dtype)
286 | return out
287 |
288 | def __len__(self):
289 | return len(self._force())
290 |
291 | def __getitem__(self, i):
292 | return self._force()[i]
293 |
294 | def count(self):
295 | frames = self._force()
296 | return frames.shape[frames.ndim - 1]
297 |
298 | def frame(self, i):
299 | return self._force()[..., i]
300 |
301 | def make_atari(env_id, max_episode_steps=None):
302 | env = gym.make(env_id)
303 | assert 'NoFrameskip' in env.spec.id
304 | env = NoopResetEnv(env, noop_max=30)
305 | env = MaxAndSkipEnv(env, skip=4)
306 | if max_episode_steps is not None:
307 | env = TimeLimit(env, max_episode_steps=max_episode_steps)
308 | return env
309 |
310 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
311 | """Configure environment for DeepMind-style Atari.
312 | """
313 | if episode_life:
314 | env = EpisodicLifeEnv(env)
315 | if 'FIRE' in env.unwrapped.get_action_meanings():
316 | env = FireResetEnv(env)
317 | env = WarpFrame(env)
318 | if scale:
319 | env = ScaledFloatFrame(env)
320 | if clip_rewards:
321 | env = ClipRewardEnv(env)
322 | if frame_stack:
323 | env = FrameStack(env, 4)
324 | return env
325 |
326 |
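
A short sketch of the standard wrapper chain built from this file, mirroring the call in AtariEnvWrapper.make_atari_env; the environment id is arbitrary and the Atari ROMs are assumed to be installed:

    from Env.AtariEnv.atari_wrappers import make_atari, wrap_deepmind, FrameStack

    env = make_atari("BreakoutNoFrameskip-v4")                        # NoopReset + MaxAndSkip
    env = wrap_deepmind(env, episode_life=False, clip_rewards=False)  # WarpFrame to 84x84 grayscale
    env = FrameStack(env, 4)                                          # stack 4 frames

    obs = env.reset()
    print(env.observation_space.shape)  # (84, 84, 4)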
--------------------------------------------------------------------------------
/Env/AtariEnv/subproc_vec_env.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 |
3 | import numpy as np
4 | from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
5 |
6 |
7 | def worker(remote, parent_remote, env_fn_wrapper):
8 | parent_remote.close()
9 | env = env_fn_wrapper.x()
10 | try:
11 | while True:
12 | cmd, data = remote.recv()
13 | if cmd == 'step':
14 | ob, reward, done, info = env.step(data)
15 | if done:
16 | ob = env.reset()
17 | remote.send((ob, reward, done, info))
18 | elif cmd == 'reset':
19 | ob = env.reset()
20 | remote.send(ob)
21 | elif cmd == 'render':
22 | remote.send(env.render(mode='rgb_array'))
23 | elif cmd == 'close':
24 | remote.close()
25 | break
26 | elif cmd == 'get_spaces_spec':
27 | remote.send((env.observation_space, env.action_space, env.spec))
28 | else:
29 | raise NotImplementedError
30 | except KeyboardInterrupt:
31 | print('SubprocVecEnv worker: got KeyboardInterrupt')
32 | finally:
33 | env.close()
34 |
35 |
36 | class SubprocVecEnv(VecEnv):
37 | """
38 |     VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes.
39 | Recommended to use when num_envs > 1 and step() can be a bottleneck.
40 | """
41 | def __init__(self, env_fns, spaces=None, context='spawn'):
42 | """
43 | Arguments:
44 |
45 | env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
46 | """
47 | self.waiting = False
48 | self.closed = False
49 | nenvs = len(env_fns)
50 | ctx = mp.get_context(context)
51 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
52 | self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
53 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
54 | for p in self.ps:
55 | p.daemon = True # if the main process crashes, we should not cause things to hang
56 | with clear_mpi_env_vars():
57 | p.start()
58 | for remote in self.work_remotes:
59 | remote.close()
60 |
61 | self.remotes[0].send(('get_spaces_spec', None))
62 | observation_space, action_space, self.spec = self.remotes[0].recv()
63 | self.viewer = None
64 | VecEnv.__init__(self, len(env_fns), observation_space, action_space)
65 |
66 | def step_async(self, actions):
67 | self._assert_not_closed()
68 | for remote, action in zip(self.remotes, actions):
69 | remote.send(('step', action))
70 | self.waiting = True
71 |
72 | def step_wait(self):
73 | self._assert_not_closed()
74 | results = [remote.recv() for remote in self.remotes]
75 | self.waiting = False
76 | obs, rews, dones, infos = zip(*results)
77 | return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos
78 |
79 | def reset(self):
80 | self._assert_not_closed()
81 | for remote in self.remotes:
82 | remote.send(('reset', None))
83 | return _flatten_obs([remote.recv() for remote in self.remotes])
84 |
85 | def close_extras(self):
86 | self.closed = True
87 | if self.waiting:
88 | for remote in self.remotes:
89 | remote.recv()
90 | for remote in self.remotes:
91 | remote.send(('close', None))
92 | for p in self.ps:
93 | p.join()
94 |
95 | def get_images(self):
96 | self._assert_not_closed()
97 | for pipe in self.remotes:
98 | pipe.send(('render', None))
99 | imgs = [pipe.recv() for pipe in self.remotes]
100 | return imgs
101 |
102 | def _assert_not_closed(self):
103 | assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
104 |
105 | def __del__(self):
106 | if not self.closed:
107 | self.close()
108 |
109 | def _flatten_obs(obs):
110 | assert isinstance(obs, (list, tuple))
111 | assert len(obs) > 0
112 |
113 | if isinstance(obs[0], dict):
114 | keys = obs[0].keys()
115 | return {k: np.stack([o[k] for o in obs]) for k in keys}
116 | else:
117 | return np.stack(obs)
118 |
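
A minimal usage sketch for SubprocVecEnv, assuming cloudpickle is installed and an older gym API (env.seed); the environment id and number of workers are arbitrary:

    import gym
    from Env.AtariEnv.subproc_vec_env import SubprocVecEnv

    def make_env(seed):
        def _thunk():
            env = gym.make("CartPole-v1")
            env.seed(seed)
            return env
        return _thunk

    if __name__ == "__main__":  # required: the default start method is 'spawn'
        venv = SubprocVecEnv([make_env(i) for i in range(4)])
        obs = venv.reset()      # shape: (4,) + single-environment observation shape
        obs, rews, dones, infos = venv.step([venv.action_space.sample() for _ in range(4)])
        venv.close()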
--------------------------------------------------------------------------------
/Env/AtariEnv/vec_env.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import os
3 | import numpy as np
4 | from abc import ABC, abstractmethod
5 | def tile_images(img_nhwc):
6 | """
7 | Tile N images into one big PxQ image
8 | (P,Q) are chosen to be as close as possible, and if N
9 | is square, then P=Q.
10 |
11 | input: img_nhwc, list or array of images, ndim=4 once turned into array
12 | n = batch index, h = height, w = width, c = channel
13 | returns:
14 | bigim_HWc, ndarray with ndim=3
15 | """
16 | img_nhwc = np.asarray(img_nhwc)
17 | N, h, w, c = img_nhwc.shape
18 | H = int(np.ceil(np.sqrt(N)))
19 | W = int(np.ceil(float(N)/H))
20 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)])
21 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c)
22 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)
23 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c)
24 | return img_Hh_Ww_c
25 |
26 | class AlreadySteppingError(Exception):
27 | """
28 | Raised when an asynchronous step is running while
29 | step_async() is called again.
30 | """
31 |
32 | def __init__(self):
33 | msg = 'already running an async step'
34 | Exception.__init__(self, msg)
35 |
36 |
37 | class NotSteppingError(Exception):
38 | """
39 | Raised when an asynchronous step is not running but
40 | step_wait() is called.
41 | """
42 |
43 | def __init__(self):
44 | msg = 'not running an async step'
45 | Exception.__init__(self, msg)
46 |
47 |
48 | class VecEnv(ABC):
49 | """
50 | An abstract asynchronous, vectorized environment.
51 | Used to batch data from multiple copies of an environment, so that
52 |     each observation becomes a batch of observations, and the expected action is a batch of actions to
53 | be applied per-environment.
54 | """
55 | closed = False
56 | viewer = None
57 |
58 | metadata = {
59 | 'render.modes': ['human', 'rgb_array']
60 | }
61 |
62 | def __init__(self, num_envs, observation_space, action_space):
63 | self.num_envs = num_envs
64 | self.observation_space = observation_space
65 | self.action_space = action_space
66 |
67 | @abstractmethod
68 | def reset(self):
69 | """
70 | Reset all the environments and return an array of
71 | observations, or a dict of observation arrays.
72 |
73 | If step_async is still doing work, that work will
74 | be cancelled and step_wait() should not be called
75 | until step_async() is invoked again.
76 | """
77 | pass
78 |
79 | @abstractmethod
80 | def step_async(self, actions):
81 | """
82 | Tell all the environments to start taking a step
83 | with the given actions.
84 | Call step_wait() to get the results of the step.
85 |
86 | You should not call this if a step_async run is
87 | already pending.
88 | """
89 | pass
90 |
91 | @abstractmethod
92 | def step_wait(self):
93 | """
94 | Wait for the step taken with step_async().
95 |
96 | Returns (obs, rews, dones, infos):
97 | - obs: an array of observations, or a dict of
98 | arrays of observations.
99 | - rews: an array of rewards
100 | - dones: an array of "episode done" booleans
101 | - infos: a sequence of info objects
102 | """
103 | pass
104 |
105 | def close_extras(self):
106 | """
107 | Clean up the extra resources, beyond what's in this base class.
108 | Only runs when not self.closed.
109 | """
110 | pass
111 |
112 | def close(self):
113 | if self.closed:
114 | return
115 | if self.viewer is not None:
116 | self.viewer.close()
117 | self.close_extras()
118 | self.closed = True
119 |
120 | def step(self, actions):
121 | """
122 | Step the environments synchronously.
123 |
124 | This is available for backwards compatibility.
125 | """
126 | self.step_async(actions)
127 | return self.step_wait()
128 |
129 | def render(self, mode='human'):
130 | imgs = self.get_images()
131 | bigimg = tile_images(imgs)
132 | if mode == 'human':
133 | self.get_viewer().imshow(bigimg)
134 | return self.get_viewer().isopen
135 | elif mode == 'rgb_array':
136 | return bigimg
137 | else:
138 | raise NotImplementedError
139 |
140 | def get_images(self):
141 | """
142 | Return RGB images from each environment
143 | """
144 | raise NotImplementedError
145 |
146 | @property
147 | def unwrapped(self):
148 | if isinstance(self, VecEnvWrapper):
149 | return self.venv.unwrapped
150 | else:
151 | return self
152 |
153 | def get_viewer(self):
154 | if self.viewer is None:
155 | from gym.envs.classic_control import rendering
156 | self.viewer = rendering.SimpleImageViewer()
157 | return self.viewer
158 |
159 | class VecEnvWrapper(VecEnv):
160 | """
161 | An environment wrapper that applies to an entire batch
162 | of environments at once.
163 | """
164 |
165 | def __init__(self, venv, observation_space=None, action_space=None):
166 | self.venv = venv
167 | super().__init__(num_envs=venv.num_envs,
168 | observation_space=observation_space or venv.observation_space,
169 | action_space=action_space or venv.action_space)
170 |
171 | def step_async(self, actions):
172 | self.venv.step_async(actions)
173 |
174 | @abstractmethod
175 | def reset(self):
176 | pass
177 |
178 | @abstractmethod
179 | def step_wait(self):
180 | pass
181 |
182 | def close(self):
183 | return self.venv.close()
184 |
185 | def render(self, mode='human'):
186 | return self.venv.render(mode=mode)
187 |
188 | def get_images(self):
189 | return self.venv.get_images()
190 |
191 | def __getattr__(self, name):
192 | if name.startswith('_'):
193 | raise AttributeError("attempted to get missing private attribute '{}'".format(name))
194 | return getattr(self.venv, name)
195 |
196 | class VecEnvObservationWrapper(VecEnvWrapper):
197 | @abstractmethod
198 | def process(self, obs):
199 | pass
200 |
201 | def reset(self):
202 | obs = self.venv.reset()
203 | return self.process(obs)
204 |
205 | def step_wait(self):
206 | obs, rews, dones, infos = self.venv.step_wait()
207 | return self.process(obs), rews, dones, infos
208 |
209 | class CloudpickleWrapper(object):
210 | """
211 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
212 | """
213 |
214 | def __init__(self, x):
215 | self.x = x
216 |
217 | def __getstate__(self):
218 | import cloudpickle
219 | return cloudpickle.dumps(self.x)
220 |
221 | def __setstate__(self, ob):
222 | import pickle
223 | self.x = pickle.loads(ob)
224 |
225 |
226 | @contextlib.contextmanager
227 | def clear_mpi_env_vars():
228 | """
229 |     `from mpi4py import MPI` will call MPI_Init by default. If the child process inherits MPI environment variables, MPI will think that the child process is an MPI process just like the parent and may do bad things such as hang.
230 |     This context manager is a hacky way to clear those environment variables temporarily, e.g. while we are starting multiprocessing
231 |     Processes.
232 | """
233 | removed_environment = {}
234 | for k, v in list(os.environ.items()):
235 | for prefix in ['OMPI_', 'PMI_']:
236 | if k.startswith(prefix):
237 | removed_environment[k] = v
238 | del os.environ[k]
239 | try:
240 | yield
241 | finally:
242 | os.environ.update(removed_environment)
243 |
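
A small sketch of what clear_mpi_env_vars does, using an OMPI_ variable purely for illustration:

    import os
    from Env.AtariEnv.vec_env import clear_mpi_env_vars

    os.environ["OMPI_COMM_WORLD_RANK"] = "0"
    with clear_mpi_env_vars():
        # OMPI_*/PMI_* variables are removed here, so a process started inside
        # the block will not mistake itself for an MPI process.
        assert "OMPI_COMM_WORLD_RANK" not in os.environ
    assert os.environ["OMPI_COMM_WORLD_RANK"] == "0"  # restored on exit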
--------------------------------------------------------------------------------
/Env/AtariEnv/vec_frame_stack.py:
--------------------------------------------------------------------------------
1 | from .vec_env import VecEnvWrapper
2 | import numpy as np
3 | from gym import spaces
4 |
5 |
6 | class VecFrameStack(VecEnvWrapper):
7 | def __init__(self, venv, nstack):
8 | self.venv = venv
9 | self.nstack = nstack
10 | wos = venv.observation_space # wrapped ob space
11 | low = np.repeat(wos.low, self.nstack, axis=-1)
12 | high = np.repeat(wos.high, self.nstack, axis=-1)
13 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
14 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
15 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
16 |
17 | def step_wait(self):
18 | obs, rews, news, infos = self.venv.step_wait()
19 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1)
20 | for (i, new) in enumerate(news):
21 | if new:
22 | self.stackedobs[i] = 0
23 | self.stackedobs[..., -obs.shape[-1]:] = obs
24 | return self.stackedobs, rews, news, infos
25 |
26 | def reset(self):
27 | obs = self.venv.reset()
28 | self.stackedobs[...] = 0
29 | self.stackedobs[..., -obs.shape[-1]:] = obs
30 | return self.stackedobs
31 |
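
A standalone toy example of the rolling-buffer update used in step_wait above (one environment, a 1x1 "image", three stacked channels; all names and numbers are illustrative):

    import numpy as np

    stackedobs = np.zeros((1, 1, 1, 3))
    for frame in [1.0, 2.0, 3.0, 4.0]:
        obs = np.full((1, 1, 1, 1), frame)
        stackedobs = np.roll(stackedobs, shift=-1, axis=-1)  # drop the oldest channel
        stackedobs[..., -obs.shape[-1]:] = obs               # write the newest frame last
    print(stackedobs[0, 0, 0])  # [2. 3. 4.]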
--------------------------------------------------------------------------------
/Env/EnvWrapper.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from copy import deepcopy
3 |
4 | from Env.AtariEnv.AtariEnvWrapper import make_atari_env
5 |
6 |
7 | # To allow easy extension to other tasks, we build a wrapper on top of the 'real' environment.
8 | class EnvWrapper():
9 | def __init__(self, env_name, max_episode_length = 0, enable_record = False, record_path = "1.mp4"):
10 | self.env_name = env_name
11 |
12 | self.env_type = None
13 |
14 | try:
15 | self.env, self.recorder = make_atari_env(env_name, 0, 0, enable_record = enable_record,
16 | record_path = record_path)
17 |
18 | # Call reset to avoid gym bugs.
19 | self.env.reset()
20 |
21 | self.env_type = "Atari"
22 | except gym.error.Error:
23 | exit(1)
24 |
25 | assert isinstance(self.env.action_space, gym.spaces.Discrete), "Should be discrete action space."
26 | self.action_n = self.env.action_space.n
27 |
28 | self.max_episode_length = self.env._max_episode_steps if max_episode_length == 0 else max_episode_length
29 |
30 | self.current_step_count = 0
31 |
32 | self.since_last_reset = 0
33 |
34 | def reset(self):
35 | state = self.env.reset()
36 |
37 | self.current_step_count = 0
38 | self.since_last_reset = 0
39 |
40 | return state
41 |
42 | def step(self, action):
43 | next_state, reward, done, _ = self.env.step(action)
44 |
45 | self.current_step_count += 1
46 | if self.current_step_count >= self.max_episode_length:
47 | done = True
48 |
49 | self.since_last_reset += 1
50 |
51 | return next_state, reward, done
52 |
53 | def checkpoint(self):
54 | return deepcopy(self.env.clone_full_state()), self.current_step_count
55 |
56 | def restore(self, checkpoint):
57 | if self.since_last_reset > 20000:
58 | self.reset()
59 | self.since_last_reset = 0
60 |
61 | self.env.restore_full_state(checkpoint[0])
62 |
63 | self.current_step_count = checkpoint[1]
64 |
65 | return self.env.get_state()
66 |
67 | def render(self):
68 | self.env.render()
69 |
70 | def capture_frame(self):
71 | self.recorder.capture_frame()
72 |
73 | def store_video_files(self):
74 | self.recorder.write_metadata()
75 |
76 | def close(self):
77 | self.env.close()
78 |
79 | def seed(self, seed):
80 | self.env.seed(seed)
81 |
82 | def get_action_n(self):
83 | return self.action_n
84 |
85 | def get_max_episode_length(self):
86 | return self.max_episode_length
87 |
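
A minimal usage sketch of the checkpoint/restore cycle this wrapper exposes, assuming the Atari ROMs are installed (the emulator state is cloned via clone_full_state); the environment id and step budget are arbitrary:

    from Env.EnvWrapper import EnvWrapper

    env = EnvWrapper("PongNoFrameskip-v4", max_episode_length=1000)
    state = env.reset()

    ckpt = env.checkpoint()                # (cloned emulator state, current step count)
    for _ in range(10):
        state, reward, done = env.step(0)  # note: step() returns three values, not four

    state = env.restore(ckpt)              # rewind the emulator to the saved point
    env.close()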
--------------------------------------------------------------------------------
/Env/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Env/__init__.py
--------------------------------------------------------------------------------
/Figures/Figure_atari_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_atari_results.png
--------------------------------------------------------------------------------
/Figures/Figure_puct_conceptual_idea.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_puct_conceptual_idea.png
--------------------------------------------------------------------------------
/Figures/Figure_puct_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_puct_pipeline.png
--------------------------------------------------------------------------------
/Figures/Figure_tap_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_tap_results.png
--------------------------------------------------------------------------------
/Figures/Figure_time_consumption.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_time_consumption.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Anji Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Logs/.gitkeep
--------------------------------------------------------------------------------
/Mem/CheckpointManager.py:
--------------------------------------------------------------------------------
1 | # This is the centralized game-state storage.
2 | class CheckpointManager():
3 | def __init__(self):
4 | self.buffer = dict()
5 |
6 | self.envs = dict()
7 |
8 | def hock_env(self, name, env):
9 | self.envs[name] = env
10 |
11 | def checkpoint_env(self, name, idx):
12 | self.store(idx, self.envs[name].checkpoint())
13 |
14 | def load_checkpoint_env(self, name, idx):
15 | self.envs[name].restore(self.retrieve(idx))
16 |
17 | def store(self, idx, checkpoint_data):
18 | assert idx not in self.buffer
19 |
20 | self.buffer[idx] = checkpoint_data
21 |
22 | def retrieve(self, idx):
23 | assert idx in self.buffer
24 |
25 | return self.buffer[idx]
26 |
27 | def length(self):
28 | return len(self.buffer)
29 |
30 | def clear(self):
31 | self.buffer.clear()
32 |
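
A short sketch of how the manager is meant to be used together with EnvWrapper; the name "main" and index 0 are arbitrary:

    from Mem.CheckpointManager import CheckpointManager
    from Env.EnvWrapper import EnvWrapper

    env = EnvWrapper("PongNoFrameskip-v4", max_episode_length=1000)
    manager = CheckpointManager()
    manager.hock_env("main", env)               # register the environment under a name

    env.reset()
    manager.checkpoint_env("main", idx=0)       # store the current emulator state under index 0
    env.step(0)
    manager.load_checkpoint_env("main", idx=0)  # roll the environment back to index 0
    manager.clear()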
--------------------------------------------------------------------------------
/Mem/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Mem/__init__.py
--------------------------------------------------------------------------------
/Node/UCTnode.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | import math
4 | import random
5 |
6 | from Utils.MovingAvegCalculator import MovingAvegCalculator
7 |
8 |
9 | class UCTnode():
10 | def __init__(self, action_n, state, checkpoint_idx, parent, tree,
11 | prior_prob = None, is_head = False, allowed_actions = None):
12 | self.action_n = action_n
13 | self.state = state
14 | self.checkpoint_idx = checkpoint_idx
15 | self.parent = parent
16 | self.tree = tree
17 | self.is_head = is_head
18 | self.allowed_actions = allowed_actions
19 |
20 | if tree is not None:
21 | self.max_width = tree.max_width
22 | else:
23 | self.max_width = 0
24 |
25 | self.children = [None for _ in range(self.action_n)]
26 | self.rewards = [0.0 for _ in range(self.action_n)]
27 | self.dones = [False for _ in range(self.action_n)]
28 | self.children_visit_count = [0 for _ in range(self.action_n)]
29 | self.Q_values = [0 for _ in range(self.action_n)]
30 | self.visit_count = 0
31 |
32 | if prior_prob is not None:
33 | self.prior_prob = prior_prob
34 | else:
35 | self.prior_prob = np.ones([self.action_n], dtype = np.float32) / self.action_n
36 |
37 | # Record traverse history
38 | self.traverse_history = list()
39 |
40 | # Updated node count
41 | self.updated_node_count = 0
42 |
43 | # Moving average calculator
44 | self.moving_aveg_calculator = MovingAvegCalculator(window_length = 500)
45 |
46 | def no_child_available(self):
47 | # All child nodes have not been expanded.
48 | return self.updated_node_count == 0
49 |
50 | def all_child_visited(self):
51 | # All child nodes have been visited and updated.
52 | if self.is_head:
53 | if self.allowed_actions is None:
54 | return self.updated_node_count == self.action_n
55 | else:
56 | return self.updated_node_count == len(self.allowed_actions)
57 | else:
58 | return self.updated_node_count == self.max_width
59 |
60 | def select_action(self):
61 | best_score = -10000.0
62 | best_action = 0
63 |
64 | for action in range(self.action_n):
65 | if self.children[action] is None:
66 | continue
67 |
68 | if self.allowed_actions is not None and action not in self.allowed_actions:
69 | continue
70 |
71 | exploit_score = self.Q_values[action] / self.children_visit_count[action]
72 | explore_score = math.sqrt(1.0 * math.log(self.visit_count) / self.children_visit_count[action])
73 | score_std = self.moving_aveg_calculator.get_standard_deviation()
74 | score = exploit_score + score_std * explore_score
75 |
76 | if score > best_score:
77 | best_score = score
78 | best_action = action
79 |
80 | return best_action
81 |
82 | def max_utility_action(self):
83 | best_score = -10000.0
84 | best_action = 0
85 |
86 | for action in range(self.action_n):
87 | if self.children[action] is None:
88 | continue
89 |
90 | score = self.Q_values[action] / self.children_visit_count[action]
91 |
92 | if score > best_score:
93 | best_score = score
94 | best_action = action
95 |
96 | return best_action
97 |
98 | def select_expand_action(self):
99 | count = 0
100 |
101 | while True:
102 | if self.allowed_actions is None:
103 | if count < 20:
104 | action = self.categorical(self.prior_prob)
105 | else:
106 | action = np.random.randint(0, self.action_n)
107 | else:
108 | action = random.choice(self.allowed_actions)
109 |
110 | if count > 100:
111 | return action
112 |
113 | if self.children_visit_count[action] > 0 and count < 10:
114 | count += 1
115 | continue
116 |
117 | if self.children[action] is None:
118 | return action
119 |
120 | count += 1
121 |
122 | def update_history(self, action_taken, reward):
123 | self.traverse_history = (action_taken, reward)
124 |
125 | def update(self, accu_reward):
126 | action_taken = self.traverse_history[0]
127 | reward = self.traverse_history[1]
128 |
129 | accu_reward = reward + self.tree.gamma * accu_reward
130 |
131 | if self.children_visit_count[action_taken] == 0:
132 | self.updated_node_count += 1
133 |
134 | self.children_visit_count[action_taken] += 1
135 | self.Q_values[action_taken] += accu_reward
136 |
137 | self.visit_count += 1
138 |
139 | self.moving_aveg_calculator.add_number(accu_reward)
140 |
141 | return accu_reward
142 |
143 | def add_child(self, action, child_state, checkpoint_idx, prior_prob):
144 | if self.children[action] is not None:
145 | node = self.children[action]
146 | else:
147 | node = UCTnode(
148 | action_n = self.action_n,
149 | state = child_state,
150 | checkpoint_idx = checkpoint_idx,
151 | parent = self,
152 | tree = self.tree,
153 | prior_prob = prior_prob
154 | )
155 |
156 | self.children[action] = node
157 |
158 | return node
159 |
160 | @staticmethod
161 | def categorical(pvals):
162 | num = np.random.random()
163 | for i in range(pvals.size):
164 | if num < pvals[i]:
165 | return i
166 | else:
167 | num -= pvals[i]
168 |
169 | return pvals.size - 1
170 |
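
The tree policy in select_action is a UCB1-style rule, score(a) = Q(a) / N(a) + sigma * sqrt(ln N / N(a)), where sigma is the moving-average standard deviation. A standalone numeric check of that rule, with illustrative values:

    import math

    Q_values = [3.0, 5.0]             # summed returns per child
    children_visit_count = [3, 10]    # per-child visit counts N(a)
    visit_count = 13                  # parent visit count N
    score_std = 1.0                   # stand-in for the moving-average standard deviation

    scores = [
        Q / n + score_std * math.sqrt(math.log(visit_count) / n)
        for Q, n in zip(Q_values, children_visit_count)
    ]
    # child 0: 1.0 + sqrt(ln 13 / 3) ~ 1.92,  child 1: 0.5 + sqrt(ln 13 / 10) ~ 1.01
    print(scores.index(max(scores)))  # 0 -> the less-visited child with the higher mean wins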
--------------------------------------------------------------------------------
/Node/WU_UCTnode.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | import math
4 |
5 | from Utils.MovingAvegCalculator import MovingAvegCalculator
6 |
7 |
8 | class WU_UCTnode():
9 | def __init__(self, action_n, state, checkpoint_idx, parent, tree,
10 | prior_prob = None, is_head = False):
11 | self.action_n = action_n
12 | self.state = state
13 | self.checkpoint_idx = checkpoint_idx
14 | self.parent = parent
15 | self.tree = tree
16 | self.is_head = is_head
17 |
18 | if tree is not None:
19 | self.max_width = tree.max_width
20 | else:
21 | self.max_width = 0
22 |
23 | self.children = [None for _ in range(self.action_n)]
24 | self.rewards = [0.0 for _ in range(self.action_n)]
25 | self.dones = [False for _ in range(self.action_n)]
26 | self.children_visit_count = [0 for _ in range(self.action_n)]
27 | self.children_completed_visit_count = [0 for _ in range(self.action_n)]
28 | self.Q_values = [0 for _ in range(self.action_n)]
29 | self.visit_count = 0
30 |
31 | if prior_prob is not None:
32 | self.prior_prob = prior_prob
33 | else:
34 | self.prior_prob = np.ones([self.action_n], dtype=np.float32) / self.action_n
35 |
36 | # Record traverse history
37 | self.traverse_history = dict()
38 |
39 | # Visited node count
40 | self.visited_node_count = 0
41 |
42 | # Updated node count
43 | self.updated_node_count = 0
44 |
45 | # Moving average calculator
46 | self.moving_aveg_calculator = MovingAvegCalculator(window_length = 500)
47 |
48 | def no_child_available(self):
49 |         # No child node has received a completed update yet.
50 | return self.updated_node_count == 0
51 |
52 | def all_child_visited(self):
53 | # All child nodes have been visited (not necessarily updated).
54 | if self.is_head:
55 | return self.visited_node_count == self.action_n
56 | else:
57 | return self.visited_node_count == self.max_width
58 |
59 | def all_child_updated(self):
60 | # All child nodes have been updated.
61 | if self.is_head:
62 | return self.updated_node_count == self.action_n
63 | else:
64 | return self.updated_node_count == self.max_width
65 |
66 |     # Make a shallow clone of this node, containing only the data needed by the expansion workers.
67 | def shallow_clone(self):
68 | node = WU_UCTnode(
69 | action_n = self.action_n,
70 | state = deepcopy(self.state),
71 | checkpoint_idx = self.checkpoint_idx,
72 | parent = None,
73 | tree = None,
74 | prior_prob = None,
75 | is_head = False
76 | )
77 |
78 | for action in range(self.action_n):
79 | if self.children[action] is not None:
80 | node.children[action] = 1
81 |
82 | node.children_visit_count = deepcopy(self.children_visit_count)
83 | node.children_completed_visit_count = deepcopy(self.children_completed_visit_count)
84 |
85 | node.visited_node_count = self.visited_node_count
86 | node.updated_node_count = self.updated_node_count
87 |
88 | node.action_n = self.action_n
89 | node.max_width = self.max_width
90 |
91 | node.prior_prob = self.prior_prob.copy()
92 |
93 | return node
94 |
95 | # Select action according to the P-UCT tree policy
96 | def select_action(self):
97 | best_score = -10000.0
98 | best_action = 0
99 |
100 | for action in range(self.action_n):
101 | if self.children[action] is None:
102 | continue
103 |
104 | exploit_score = self.Q_values[action] / self.children_completed_visit_count[action]
105 | explore_score = math.sqrt(2.0 * math.log(self.visit_count) / self.children_visit_count[action])
106 | score_std = self.moving_aveg_calculator.get_standard_deviation()
107 | score = exploit_score + score_std * 2.0 * explore_score
108 |
109 | if score > best_score:
110 | best_score = score
111 | best_action = action
112 |
113 | return best_action
114 |
115 | # Return the action with maximum utility.
116 | def max_utility_action(self):
117 | best_score = -10000.0
118 | best_action = 0
119 |
120 | for action in range(self.action_n):
121 | if self.children[action] is None:
122 | continue
123 |
124 | score = self.Q_values[action] / self.children_completed_visit_count[action]
125 |
126 | if score > best_score:
127 | best_score = score
128 | best_action = action
129 |
130 | return best_action
131 |
132 | # Choose an action to expand
133 | def select_expand_action(self):
134 | count = 0
135 |
136 | while True:
137 | if count < 20:
138 | action = self.categorical(self.prior_prob)
139 | else:
140 | action = np.random.randint(0, self.action_n)
141 |
142 | if count > 100:
143 | return action
144 |
145 | if self.children_visit_count[action] > 0 and count < 10:
146 | count += 1
147 | continue
148 |
149 | if self.children[action] is None:
150 | return action
151 |
152 | count += 1
153 |
154 | # Update traverse history, used to perform update
155 | def update_history(self, idx, action_taken, reward):
156 | if idx in self.traverse_history:
157 | return False
158 | else:
159 | self.traverse_history[idx] = (action_taken, reward)
160 | return True
161 |
162 | # Incomplete update, called by WU_UCT.py
163 | def update_incomplete(self, idx):
164 | action_taken = self.traverse_history[idx][0]
165 |
166 | if self.children_visit_count[action_taken] == 0:
167 | self.visited_node_count += 1
168 |
169 | self.children_visit_count[action_taken] += 1
170 | self.visit_count += 1
171 |
172 | # Complete update, called by WU_UCT.py
173 | def update_complete(self, idx, accu_reward):
174 | if idx not in self.traverse_history:
175 | raise RuntimeError("idx {} should be in traverse_history".format(idx))
176 | else:
177 | item = self.traverse_history.pop(idx)
178 | action_taken = item[0]
179 | reward = item[1]
180 |
181 | accu_reward = reward + self.tree.gamma * accu_reward
182 |
183 | if self.children_completed_visit_count[action_taken] == 0:
184 | self.updated_node_count += 1
185 |
186 | self.children_completed_visit_count[action_taken] += 1
187 | self.Q_values[action_taken] += accu_reward
188 |
189 | self.moving_aveg_calculator.add_number(accu_reward)
190 |
191 | return accu_reward
192 |
193 | # Add a child to current node.
194 | def add_child(self, action, child_state, checkpoint_idx, prior_prob = None):
195 | if self.children[action] is not None:
196 | node = self.children[action]
197 | else:
198 | node = WU_UCTnode(
199 | action_n = self.action_n,
200 | state = child_state,
201 | checkpoint_idx = checkpoint_idx,
202 | parent = self,
203 | tree = self.tree,
204 | prior_prob = prior_prob
205 | )
206 |
207 | self.children[action] = node
208 |
209 | return node
210 |
211 | # Draw a sample from the categorical distribution parametrized by 'pvals'.
212 | @staticmethod
213 | def categorical(pvals):
214 | num = np.random.random()
215 | for i in range(pvals.size):
216 | if num < pvals[i]:
217 | return i
218 | else:
219 | num -= pvals[i]
220 |
221 | return pvals.size - 1
222 |
--------------------------------------------------------------------------------
/Node/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Node/__init__.py
--------------------------------------------------------------------------------
/OutLogs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/OutLogs/.gitkeep
--------------------------------------------------------------------------------
/OutLogs/WU-UCT_PongNoFrameskip-v0_123_2.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/OutLogs/WU-UCT_PongNoFrameskip-v0_123_2.mat
--------------------------------------------------------------------------------
/ParallelPool/PoolManager.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pipe
2 | from copy import deepcopy
3 | import os
4 | import torch
5 |
6 | from ParallelPool.Worker import Worker
7 |
8 |
9 | # Works in the main process and manages sub-workers
10 | class PoolManager():
11 | def __init__(self, worker_num, env_params, policy = "Random",
12 | gamma = 1.0, seed = 123, device = "cpu", need_policy = True):
13 | self.worker_num = worker_num
14 | self.env_params = env_params
15 | self.policy = policy
16 | self.gamma = gamma
17 | self.seed = seed
18 | self.need_policy = need_policy
19 |
20 | # Buffer for workers and pipes
21 | self.workers = []
22 | self.pipes = []
23 |
24 | # CUDA device parallelization
25 | # if multiple cuda devices exist, use them all
26 | if torch.cuda.is_available():
27 | torch_device_num = torch.cuda.device_count()
28 | else:
29 | torch_device_num = 0
30 |
31 | # Initialize workers
32 | for worker_idx in range(worker_num):
33 | parent_pipe, child_pipe = Pipe()
34 | self.pipes.append(parent_pipe)
35 |
36 | worker = Worker(
37 | pipe = child_pipe,
38 | env_params = deepcopy(env_params),
39 | policy = policy,
40 | gamma = gamma,
41 | seed = seed + worker_idx,
42 | device = device + ":" + str(int(torch_device_num * worker_idx / worker_num))
43 | if device == "cuda" else device,
44 | need_policy = need_policy
45 | )
46 | self.workers.append(worker)
47 |
48 | # Start workers
49 | for worker in self.workers:
50 | worker.start()
51 |
52 | # Worker status: 0 for idle, 1 for busy
53 | self.worker_status = [0 for _ in range(worker_num)]
54 |
55 | def has_idle_server(self):
56 | for status in self.worker_status:
57 | if status == 0:
58 | return True
59 |
60 | return False
61 |
62 | def server_occupied_rate(self):
63 | occupied_count = 0.0
64 |
65 | for status in self.worker_status:
66 | occupied_count += status
67 |
68 | return occupied_count / self.worker_num
69 |
70 | def find_idle_worker(self):
71 | for idx, status in enumerate(self.worker_status):
72 | if status == 0:
73 | self.worker_status[idx] = 1
74 | return idx
75 |
76 | return None
77 |
78 | def assign_expansion_task(self, checkpoint_data, curr_node,
79 | saving_idx, task_simulation_idx):
80 | worker_idx = self.find_idle_worker()
81 |
82 | self.send_safe_protocol(worker_idx, "Expansion", (
83 | checkpoint_data,
84 | curr_node,
85 | saving_idx,
86 | task_simulation_idx
87 | ))
88 |
89 | self.worker_status[worker_idx] = 1
90 |
91 | def assign_simulation_task(self, task_idx, checkpoint_data, first_action = None):
92 | worker_idx = self.find_idle_worker()
93 |
94 | self.send_safe_protocol(worker_idx, "Simulation", (
95 | task_idx,
96 | checkpoint_data,
97 | first_action
98 | ))
99 |
100 | self.worker_status[worker_idx] = 1
101 |
102 | def get_complete_expansion_task(self):
103 | flag = False
104 | selected_worker_idx = -1
105 |
106 | while not flag:
107 | for worker_idx in range(self.worker_num):
108 | item = self.receive_safe_protocol_tapcheck(worker_idx)
109 |
110 | if item is not None:
111 | flag = True
112 | selected_worker_idx = worker_idx
113 | break
114 |
115 | command, args = item
116 | assert command == "ReturnExpansion"
117 |
118 | # Set to idle
119 | self.worker_status[selected_worker_idx] = 0
120 |
121 | return args
122 |
123 | def get_complete_simulation_task(self):
124 | flag = False
125 | selected_worker_idx = -1
126 |
127 | while not flag:
128 | for worker_idx in range(self.worker_num):
129 | item = self.receive_safe_protocol_tapcheck(worker_idx)
130 |
131 | if item is not None:
132 | flag = True
133 | selected_worker_idx = worker_idx
134 | break
135 |
136 | command, args = item
137 | assert command == "ReturnSimulation"
138 |
139 | # Set to idle
140 | self.worker_status[selected_worker_idx] = 0
141 |
142 | return args
143 |
144 | def send_safe_protocol(self, worker_idx, command, args):
145 | success = False
146 |
147 | while not success:
148 | self.pipes[worker_idx].send((command, args))
149 |
150 | ret = self.pipes[worker_idx].recv()
151 | if ret == command:
152 | success = True
153 |
154 | def wait_until_all_envs_idle(self):
155 | for worker_idx in range(self.worker_num):
156 | if self.worker_status[worker_idx] == 0:
157 | continue
158 |
159 | self.receive_safe_protocol(worker_idx)
160 |
161 | self.worker_status[worker_idx] = 0
162 |
163 | def receive_safe_protocol(self, worker_idx):
164 | self.pipes[worker_idx].poll(None)
165 |
166 | command, args = self.pipes[worker_idx].recv()
167 |
168 | self.pipes[worker_idx].send(command)
169 |
170 | return deepcopy(command), deepcopy(args)
171 |
172 | def receive_safe_protocol_tapcheck(self, worker_idx):
173 | flag = self.pipes[worker_idx].poll()
174 | if not flag:
175 | return None
176 |
177 | command, args = self.pipes[worker_idx].recv()
178 |
179 | self.pipes[worker_idx].send(command)
180 |
181 | return deepcopy(command), deepcopy(args)
182 |
183 | def close_pool(self):
184 | for worker_idx in range(self.worker_num):
185 | self.send_safe_protocol(worker_idx, "KillProc", None)
186 |
187 | for worker in self.workers:
188 | worker.join()
189 |
--------------------------------------------------------------------------------
/ParallelPool/Worker.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Process
2 | from copy import deepcopy
3 | import random
4 | import numpy as np
5 |
6 | from Env.EnvWrapper import EnvWrapper
7 |
8 | from Policy.PPO.PPOPolicy import PPOAtariCNN, PPOSmallAtariCNN
9 |
10 | from Policy.PolicyWrapper import PolicyWrapper
11 |
12 |
13 | # Slave workers
14 | class Worker(Process):
15 | def __init__(self, pipe, env_params, policy = "Random", gamma = 1.0, seed = 123,
16 | device = "cpu", need_policy = True):
17 | super(Worker, self).__init__()
18 |
19 | self.pipe = pipe
20 | self.env_params = deepcopy(env_params)
21 | self.gamma = gamma
22 | self.seed = seed
23 | self.policy = deepcopy(policy)
24 | self.device = deepcopy(device)
25 | self.need_policy = need_policy
26 |
27 | self.wrapped_env = None
28 | self.action_n = None
29 | self.max_episode_length = None
30 |
31 | self.policy_wrapper = None
32 |
33 | # Initialize the environment
34 | def init_process(self):
35 | self.wrapped_env = EnvWrapper(**self.env_params)
36 |
37 | self.wrapped_env.seed(self.seed)
38 |
39 | self.action_n = self.wrapped_env.get_action_n()
40 | self.max_episode_length = self.wrapped_env.get_max_episode_length()
41 |
42 | # Initialize the default policy
43 | def init_policy(self):
44 | self.policy_wrapper = PolicyWrapper(
45 | self.policy,
46 | self.env_params["env_name"],
47 | self.action_n,
48 | self.device
49 | )
50 |
51 | def run(self):
52 | self.init_process()
53 | self.init_policy()
54 |
55 | print("> Worker ready.")
56 |
57 | while True:
58 | # Wait for tasks
59 | command, args = self.receive_safe_protocol()
60 |
61 | if command == "KillProc":
62 | return
63 | elif command == "Expansion":
64 | checkpoint_data, curr_node, saving_idx, task_idx = args
65 |
66 | # Select expand action, and do expansion
67 | expand_action, next_state, reward, done, \
68 | checkpoint_data = self.expand_node(checkpoint_data, curr_node)
69 |
70 | item = (expand_action, next_state, reward, done, checkpoint_data,
71 | saving_idx, task_idx)
72 |
73 | self.send_safe_protocol("ReturnExpansion", item)
74 | elif command == "Simulation":
75 | if args is None:
76 | raise RuntimeError
77 | else:
78 | task_idx, checkpoint_data, first_action = args
79 |
80 | state = self.wrapped_env.restore(checkpoint_data)
81 |
82 | # Prior probability is calculated for the new node
83 | prior_prob = self.get_prior_prob(state)
84 |
85 |                     # When a simulation is invoked because the maximum search depth
86 |                     # was reached, an action has already been selected. Therefore,
87 |                     # we need to execute it first.
88 | if first_action is not None:
89 | state, reward, done = self.wrapped_env.step(first_action)
90 |
91 | if first_action is not None and done:
92 | accu_reward = reward
93 | else:
94 | # Simulate until termination condition satisfied
95 | accu_reward = self.simulate(state)
96 |
97 | if first_action is not None:
98 | self.send_safe_protocol("ReturnSimulation", (task_idx, accu_reward, reward, done))
99 | else:
100 | self.send_safe_protocol("ReturnSimulation", (task_idx, accu_reward, prior_prob))
101 |
102 | def expand_node(self, checkpoint_data, curr_node):
103 | self.wrapped_env.restore(checkpoint_data)
104 |
105 | # Choose action to expand, according to the shallow copy node
106 | expand_action = curr_node.select_expand_action()
107 |
108 | # Execute the action, and observe new state, etc.
109 | next_state, reward, done = self.wrapped_env.step(expand_action)
110 |
111 | if not done:
112 | checkpoint_data = self.wrapped_env.checkpoint()
113 | else:
114 | checkpoint_data = None
115 |
116 | return expand_action, next_state, reward, done, checkpoint_data
117 |
118 | def simulate(self, state, max_simulation_step = 100, lamda = 0.5):
119 | step_count = 0
120 | accu_reward = 0.0
121 | accu_gamma = 1.0
122 |
123 | start_state_value = self.get_value(state)
124 |
125 | done = False
126 | # A strict upper bound for simulation count
127 | while not done and step_count < max_simulation_step:
128 | action = self.get_action(state)
129 |
130 | next_state, reward, done = self.wrapped_env.step(action)
131 |
132 | accu_reward += reward * accu_gamma
133 | accu_gamma *= self.gamma
134 |
135 | state = deepcopy(next_state)
136 |
137 | step_count += 1
138 |
139 | if not done:
140 | accu_reward += self.get_value(state) * accu_gamma
141 |
142 | # Use V(s) to stabilize simulation return
143 | accu_reward = accu_reward * lamda + start_state_value * (1.0 - lamda)
144 |
145 | return accu_reward
146 |
147 | def get_action(self, state):
148 | return self.policy_wrapper.get_action(state)
149 |
150 | def get_value(self, state):
151 | return self.policy_wrapper.get_value(state)
152 |
153 | def get_prior_prob(self, state):
154 | return self.policy_wrapper.get_prior_prob(state)
155 |
156 | # Send message through pipe
157 | def send_safe_protocol(self, command, args):
158 | success = False
159 |
160 | count = 0
161 | while not success:
162 | self.pipe.send((command, args))
163 |
164 | ret = self.pipe.recv()
165 | if ret == command or count >= 10:
166 | success = True
167 |
168 | count += 1
169 |
170 | # Receive message from pipe
171 | def receive_safe_protocol(self):
172 | self.pipe.poll(None)
173 |
174 | command, args = self.pipe.recv()
175 |
176 | self.pipe.send(command)
177 |
178 | return deepcopy(command), deepcopy(args)
179 |
--------------------------------------------------------------------------------
/ParallelPool/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/ParallelPool/__init__.py
--------------------------------------------------------------------------------
/Policy/PPO/PPOPolicy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | import time
8 | import os
9 |
10 | from Env.AtariEnv.atari_wrappers import LazyFrames
11 |
12 |
13 | def ortho_weights(shape, scale=1.):
14 | """ PyTorch port of ortho_init from baselines.a2c.utils """
15 | shape = tuple(shape)
16 |
17 | if len(shape) == 2:
18 | flat_shape = shape[1], shape[0]
19 | elif len(shape) == 4:
20 | flat_shape = (np.prod(shape[1:]), shape[0])
21 | else:
22 | raise NotImplementedError
23 |
24 | a = np.random.normal(0., 1., flat_shape)
25 | u, _, v = np.linalg.svd(a, full_matrices=False)
26 | q = u if u.shape == flat_shape else v
27 | q = q.transpose().copy().reshape(shape)
28 |
29 | if len(shape) == 2:
30 | return torch.from_numpy((scale * q).astype(np.float32))
31 | if len(shape) == 4:
32 | return torch.from_numpy((scale * q[:, :shape[1], :shape[2]]).astype(np.float32))
33 |
34 |
35 | def atari_initializer(module):
36 | """ Parameter initializer for Atari models
37 |
38 | Initializes Linear, Conv2d, and LSTM weights.
39 | """
40 | classname = module.__class__.__name__
41 |
42 | if classname == 'Linear':
43 | module.weight.data = ortho_weights(module.weight.data.size(), scale=np.sqrt(2.))
44 | module.bias.data.zero_()
45 |
46 | elif classname == 'Conv2d':
47 | module.weight.data = ortho_weights(module.weight.data.size(), scale=np.sqrt(2.))
48 | module.bias.data.zero_()
49 |
50 | elif classname == 'LSTM':
51 | for name, param in module.named_parameters():
52 | if 'weight_ih' in name:
53 | param.data = ortho_weights(param.data.size(), scale=1.)
54 | if 'weight_hh' in name:
55 | param.data = ortho_weights(param.data.size(), scale=1.)
56 | if 'bias' in name:
57 | param.data.zero_()
58 |
59 |
60 | class PPOAtariCNN():
61 | def __init__(self, num_actions, device = "cpu", checkpoint_dir = ""):
62 | self.num_actions = num_actions
63 | self.device = torch.device(device)
64 | self.checkpoint_dir = checkpoint_dir
65 |
66 | self.model = AtariCNN(num_actions, self.device)
67 |
68 | if checkpoint_dir != "" and os.path.exists(checkpoint_dir):
69 | checkpoint = torch.load(checkpoint_dir, map_location = "cpu")
70 | self.model.load_state_dict(checkpoint["policy"])
71 |
72 | self.model.to(device)
73 |
74 | def get_action(self, state, logit = False):
75 | return self.model.get_action(state, logit = logit)
76 |
77 | def get_value(self, state):
78 | return self.model.get_value(state)
79 |
80 |
81 | class PPOSmallAtariCNN():
82 | def __init__(self, num_actions, device = "cpu", checkpoint_dir = ""):
83 | self.num_actions = num_actions
84 | self.device = torch.device(device)
85 | self.checkpoint_dir = checkpoint_dir
86 |
87 | self.model = SmallPolicyAtariCNN(num_actions, self.device)
88 |
89 | if checkpoint_dir != "" and os.path.exists(checkpoint_dir):
90 | checkpoint = torch.load(checkpoint_dir, map_location = "cpu")
91 | # self.model.load_state_dict(checkpoint["policy"])
92 |
93 | self.model.to(device)
94 |
95 | self.optimizer = optim.Adam(self.model.parameters(), lr = 1e-3)
96 |
97 | self.mseLoss = nn.MSELoss()
98 |
99 | def get_action(self, state):
100 | return self.model.get_action(state)
101 |
102 | def get_value(self, state):
103 | # raise RuntimeError("Small policy net does not support value evaluation.")
104 | return self.model.get_value(state)
105 |
106 | def train_step(self, state_batch, policy_batch, value_batch, temperature = 2.5):
107 | self.optimizer.zero_grad()
108 |
109 | # policy_batch = F.softmax(policy_batch / temperature, dim = 1)
110 |
111 | out_policy, out_value = self.model(state_batch)
112 | # out_policy = F.softmax(out_policy / temperature, dim = 1)
113 |
114 | # loss = -(policy_batch * torch.log(out_policy + 1e-8)).sum(dim = 1).mean()
115 | loss = self.mseLoss(policy_batch, out_policy) + self.mseLoss(value_batch, out_value)
116 | loss.backward()
117 |
118 | self.optimizer.step()
119 |
120 | return loss.detach().cpu().numpy()
121 |
122 | def save(self, path):
123 | torch.save({"policy": self.model.state_dict()}, path)
124 |
125 |
126 | class AtariCNN(nn.Module):
127 | def __init__(self, num_actions, device):
128 | """ Basic convolutional actor-critic network for Atari 2600 games
129 |
130 | Equivalent to the network in the original DQN paper.
131 |
132 | Args:
133 | num_actions (int): the number of available discrete actions
134 | """
135 | super().__init__()
136 |
137 | self.conv = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4),
138 | nn.ReLU(inplace=True),
139 | nn.Conv2d(32, 64, 4, stride=2),
140 | nn.ReLU(inplace=True),
141 | nn.Conv2d(64, 64, 3, stride=1),
142 | nn.ReLU(inplace=True))
143 |
144 | self.fc = nn.Sequential(nn.Linear(64 * 7 * 7, 512),
145 | nn.ReLU(inplace=True))
146 |
147 | self.pi = nn.Linear(512, num_actions)
148 | self.v = nn.Linear(512, 1)
149 |
150 | self.num_actions = num_actions
151 |
152 | self.device = device
153 |
154 | def forward(self, conv_in):
155 | """ Module forward pass
156 |
157 | Args:
158 | conv_in (Variable): convolutional input, shaped [N x 4 x 84 x 84]
159 |
160 | Returns:
161 | pi (Variable): action probability logits, shaped [N x self.num_actions]
162 | v (Variable): value predictions, shaped [N x 1]
163 | """
164 | N = conv_in.size()[0]
165 |
166 | conv_out = self.conv(conv_in).view(N, 64 * 7 * 7)
167 |
168 | fc_out = self.fc(conv_out)
169 |
170 | pi_out = self.pi(fc_out)
171 | v_out = self.v(fc_out)
172 |
173 | return pi_out, v_out
174 |
175 | def get_action(self, conv_in, logit = False):
176 | if isinstance(conv_in, LazyFrames):
177 | conv_in = torch.from_numpy(np.array(conv_in)).type(torch.float32).to(self.device).unsqueeze(0)
178 | elif isinstance(conv_in, np.ndarray):
179 | conv_in = torch.from_numpy(conv_in).to(self.device)
180 |
181 | N = conv_in.size(0)
182 | s = time.time()
183 | conv_out = self.conv(conv_in).view(N, 64 * 7 * 7)
184 | aa = time.time()
185 | fc_out = self.fc(conv_out)
186 |
187 | if logit:
188 | pi_out = self.pi(fc_out)
189 | else:
190 | pi_out = F.softmax(self.pi(fc_out), dim = 1)
191 |
192 | if N == 1:
193 | pi_out = pi_out.view(-1)
194 | e = time.time()
195 | # print("large", e - s, aa - s)
196 | return pi_out.detach().cpu().numpy()
197 |
198 | def get_value(self, conv_in):
199 | if isinstance(conv_in, LazyFrames):
200 | conv_in = torch.from_numpy(np.array(conv_in)).type(torch.float32).to(self.device).unsqueeze(0)
201 | else:
202 | raise NotImplementedError()
203 |
204 | N = conv_in.size(0)
205 |
206 | conv_out = self.conv(conv_in).view(N, 64 * 7 * 7)
207 |
208 | fc_out = self.fc(conv_out)
209 |
210 | v_out = self.v(fc_out)
211 |
212 | if N == 1:
213 | v_out = v_out.sum()
214 |
215 | return v_out.detach().cpu().numpy()
216 |
217 |
218 | class SmallPolicyAtariCNN(nn.Module):
219 | def __init__(self, num_actions, device):
220 |         """ Small convolutional actor-critic network for Atari 2600 games
221 |
222 |         A lightweight variant of AtariCNN, used as the distilled default policy.
223 |
224 | Args:
225 | num_actions (int): the number of available discrete actions
226 | """
227 | super().__init__()
228 |
229 | self.conv = nn.Sequential(nn.Conv2d(4, 16, 3, stride = 4),
230 | nn.ReLU(inplace = True),
231 | nn.Conv2d(16, 16, 3, stride = 4),
232 | nn.ReLU(inplace = True))
233 |
234 | self.fc = nn.Sequential(nn.Linear(16 * 5 * 5, 64),
235 | nn.ReLU(inplace = True))
236 |
237 | self.pi = nn.Linear(64, num_actions)
238 | self.v = nn.Linear(64, 1)
239 |
240 | self.num_actions = num_actions
241 |
242 | self.device = device
243 |
244 | def forward(self, conv_in):
245 | """ Module forward pass
246 |
247 | Args:
248 | conv_in (Variable): convolutional input, shaped [N x 4 x 84 x 84]
249 |
250 | Returns:
251 | pi (Variable): action probability logits, shaped [N x self.num_actions]
252 | v (Variable): value predictions, shaped [N x 1]
253 | """
254 | N = conv_in.size()[0]
255 |
256 | conv_out = self.conv(conv_in).view(N, 16 * 5 * 5)
257 |
258 | fc_out = self.fc(conv_out)
259 |
260 | pi_out = self.pi(fc_out)
261 | v_out = self.v(fc_out)
262 |
263 | return pi_out, v_out
264 |
265 | def get_action(self, conv_in, logit = False, get_tensor = False):
266 | if isinstance(conv_in, LazyFrames):
267 | conv_in = torch.from_numpy(np.array(conv_in)).type(torch.float32).to(self.device).unsqueeze(0)
268 | elif isinstance(conv_in, np.ndarray):
269 | conv_in = torch.from_numpy(conv_in).to(self.device)
270 |
271 | N = conv_in.size(0)
272 | s = time.time()
273 | # with torch.no_grad():
274 | conv_out = self.conv(conv_in).view(N, 16 * 5 * 5)
275 | aa = time.time()
276 | fc_out = self.fc(conv_out)
277 | bb = time.time()
278 | if logit:
279 | pi_out = self.pi(fc_out)
280 | else:
281 | pi_out = F.softmax(self.pi(fc_out), dim = 1)
282 | cc = time.time()
283 | if N == 1:
284 | pi_out = pi_out.view(-1)
285 | e = time.time()
286 |         # print("small", e - s, aa - s, bb - aa, cc - bb)
287 | if get_tensor:
288 | return pi_out
289 | else:
290 | return pi_out.detach().cpu().numpy()
291 |
292 | def get_value(self, conv_in):
293 | if isinstance(conv_in, LazyFrames):
294 | conv_in = torch.from_numpy(np.array(conv_in)).type(torch.float32).to(self.device).unsqueeze(0)
295 | else:
296 | raise NotImplementedError()
297 |
298 | N = conv_in.size(0)
299 |
300 | conv_out = self.conv(conv_in).view(N, 16 * 5 * 5)
301 |
302 | fc_out = self.fc(conv_out)
303 |
304 | v_out = self.v(fc_out)
305 | # print("a")
306 | if N == 1:
307 | v_out = v_out.sum()
308 |
309 | return v_out.detach().cpu().numpy()
310 |
--------------------------------------------------------------------------------
/Policy/PPO/PolicyFiles/PPO_AlienNoFrameskip-v0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Policy/PPO/PolicyFiles/PPO_AlienNoFrameskip-v0.pt
--------------------------------------------------------------------------------
/Policy/PolicyWrapper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import os
4 |
5 | from Policy.PPO.PPOPolicy import PPOAtariCNN, PPOSmallAtariCNN
6 |
7 |
8 | class PolicyWrapper():
9 | def __init__(self, policy_name, env_name, action_n, device):
10 | self.policy_name = policy_name
11 | self.env_name = env_name
12 | self.action_n = action_n
13 | self.device = device
14 |
15 | self.policy_func = None
16 |
17 | self.init_policy()
18 |
19 | def init_policy(self):
20 | if self.policy_name == "Random":
21 | self.policy_func = None
22 |
23 | elif self.policy_name == "PPO":
24 | assert os.path.exists("./Policy/PPO/PolicyFiles/PPO_" + self.env_name + ".pt"), "Policy file not found"
25 |
26 | self.policy_func = PPOAtariCNN(
27 | self.action_n,
28 | device = self.device,
29 | checkpoint_dir = "./Policy/PPO/PolicyFiles/PPO_" + self.env_name + ".pt"
30 | )
31 |
32 | elif self.policy_name == "DistillPPO":
33 | assert os.path.exists("./Policy/PPO/PolicyFiles/PPO_" + self.env_name + ".pt"), "Policy file not found"
34 | assert os.path.exists("./Policy/PPO/PolicyFiles/SmallPPO_" + self.env_name + ".pt"), "Policy file not found"
35 |
36 | full_policy = PPOAtariCNN(
37 | self.action_n,
38 | device = "cpu", # To save memory
39 | checkpoint_dir = "./Policy/PPO/PolicyFiles/PPO_" + self.env_name + ".pt"
40 | )
41 |
42 | small_policy = PPOSmallAtariCNN(
43 | self.action_n,
44 | device = self.device,
45 | checkpoint_dir = "./Policy/PPO/PolicyFiles/SmallPPO_" + self.env_name + ".pt"
46 | )
47 |
48 | self.policy_func = [full_policy, small_policy]
49 | else:
50 | raise NotImplementedError()
51 |
52 | def get_action(self, state):
53 | if self.policy_name == "Random":
54 | return random.randint(0, self.action_n - 1)
55 | elif self.policy_name == "PPO":
56 | return self.categorical(self.policy_func.get_action(state))
57 | elif self.policy_name == "DistillPPO":
58 | return self.categorical(self.policy_func[1].get_action(state))
59 | else:
60 | raise NotImplementedError()
61 |
62 | def get_value(self, state):
63 | if self.policy_name == "Random":
64 | return 0.0
65 | elif self.policy_name == "PPO":
66 | return self.policy_func.get_value(state)
67 | elif self.policy_name == "DistillPPO":
68 | return self.policy_func[0].get_value(state)
69 | else:
70 | raise NotImplementedError()
71 |
72 | def get_prior_prob(self, state):
73 | if self.policy_name == "Random":
74 | return np.ones([self.action_n], dtype = np.float32) / self.action_n
75 | elif self.policy_name == "PPO":
76 | return self.policy_func.get_action(state)
77 | elif self.policy_name == "DistillPPO":
78 | return self.policy_func[0].get_action(state)
79 | else:
80 | raise NotImplementedError()
81 |
82 | @staticmethod
83 | def categorical(probs):
84 | val = random.random()
85 | chosen_idx = 0
86 |
87 | for prob in probs:
88 | val -= prob
89 |
90 | if val < 0.0:
91 | break
92 |
93 | chosen_idx += 1
94 |
95 | if chosen_idx >= len(probs):
96 | chosen_idx = len(probs) - 1
97 |
98 | return chosen_idx
99 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WU-UCT (Watch the Unobserved in UCT)
2 | A novel parallel UCT algorithm with linear speedup and negligible performance loss. This package provides a demo on Atari games (see [Running](#Running)). To allow easy extension to other environments, we wrote the code in an extendable way; modifications to only two files are needed to support other environments (see [Run on your own environments](#Run-on-your-own-environments)).
3 |
4 | This work has been accepted at **ICLR 2020** for **oral presentation (48/2594)**.
5 |
6 | # A quick demo!
7 |
8 | This [Google Colaboratory link](https://colab.research.google.com/drive/140Ea6pd8abvg_HdZDVsCVfdeLICSa-f0) contains a demo on the PongNoFrameskip-v0 environment. We thank [@lvisdd](https://github.com/lvisdd) for building this useful demo!
9 |
10 | # Introduction
11 | Note: For full details of WU-UCT, please refer to our [arXiv](https://arxiv.org/abs/1810.11755) or [OpenReview](https://openreview.net/forum?id=BJlQtJSKDB) paper.
12 |
13 | ## Conceptual idea
14 |
15 |
16 |
17 |
18 | We use the above figure to demonstrate the main problem caused by parallelizing UCT. (a) illustrates the four main steps of UCT: selection, expansion, simulation, and backpropagation. (b) demonstrates an ideal (but unrealistic) parallel algorithm, where the return V (cumulative reward) is available as soon as a simulation starts (in real-world cases it is observable only after the simulation completes). (c) if we parallelize UCT naively, problems such as *collapse of exploration* or *exploitation failure* arise. Specifically, since fewer statistics are available at the selection step, the algorithm cannot choose the "best" node to query. (d) we propose to keep track of on-going but not-yet-terminated simulations (called unobserved samples) to correct and compensate for the outdated statistics. This enables a principled selection step in parallel settings, allowing WU-UCT to achieve linear speedup with negligible performance loss.
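
Concretely, the tree policy implemented in `select_action` of [Node/WU_UCTnode.py](https://github.com/liuanji/WU-UCT/tree/master/Node/WU_UCTnode.py) scores each child action roughly as

```
score(s, a) = Q(s, a) / N_completed(s, a)
            + 2 * sigma * sqrt(2 * ln(N(s) + O(s)) / (N(s, a) + O(s, a)))
```

where N counts completed simulations, O counts on-going (unobserved) ones, and sigma is a moving estimate of the standard deviation of observed returns that scales the exploration term. This is a rough reading of the code; please refer to the paper for the precise formulation.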
19 |
20 |
21 |
22 |
23 |
24 | WU-UCT achieves ideal speedup with up to 16 workers, without performance degradation.
25 |
26 |
27 |
28 |
29 |
30 | WU-UCT shows a clear advantage over baseline parallel approaches, in terms of both speed and accuracy.
31 |
32 | ## System implementation
33 |
34 |
35 |
36 |
37 | Our implementation of the system consists of a master process and two sets of slave workers, i.e., expansion workers and simulation workers. With a clear division of labor, we parallelize the most time-consuming expansion and simulation steps, while maintaining the sequential structure of the selection and backpropagation steps.
38 |
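The master and the workers communicate over multiprocessing pipes with a simple acknowledged message protocol: each message is a (command, args) tuple, and the receiver echoes the command name back as an acknowledgement (see `ParallelPool/PoolManager.py` and `ParallelPool/Worker.py`). A minimal sketch of this handshake, simplified from the actual implementation, is:
```python
from multiprocessing import Process, Pipe


def worker(pipe):
    # Wait for a task and acknowledge it by echoing the command name.
    command, args = pipe.recv()
    pipe.send(command)

    # ... perform the task, then report the result with the same protocol.
    pipe.send(("ReturnSimulation", {"task_idx": args["task_idx"], "reward": 0.0}))
    pipe.recv()  # wait for the master's acknowledgement


if __name__ == "__main__":
    parent_pipe, child_pipe = Pipe()
    proc = Process(target = worker, args = (child_pipe,))
    proc.start()

    parent_pipe.send(("Simulation", {"task_idx": 0}))  # assign a task
    assert parent_pipe.recv() == "Simulation"          # worker acknowledged it

    command, result = parent_pipe.recv()               # completed task comes back
    parent_pipe.send(command)                          # acknowledge the result

    proc.join()
```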
39 |
40 |
41 |
42 |
43 | The breakdown of time consumption (tested with 16 expansion and simulation workers) indicates that we successfully parallelize the most time-consuming expansion and simulation processes while keeping the time consumption of the other steps relatively small.
44 |
45 | # Usage
46 | ## Prerequisites
47 | - Python 3.x
48 | - PyTorch 1.0
49 | - Gym (with atari) 0.14.0
50 | - Numpy 1.17.2
51 | - Scipy 1.3.1
52 | - OpenCV-Python 4.1.1.26
53 |
54 | ## Running
55 | 1. Download or clone the repository.
56 | 2. Run with the default settings:
57 | ```
58 | python3 main.py --model WU-UCT
59 | ```
60 | 3. For additional hyperparameters, please have a look at [main.py](https://github.com/liuanji/WU-UCT/tree/master/main.py) (they are also listed below), where descriptions are included. For example, if you want to run the game PongNoFrameskip-v0 with 200 MCTS rollouts, simply run:
61 | ```
62 | python3 main.py --model WU-UCT --env-name PongNoFrameskip-v0 --MCTS-max-steps 200
63 | ```
64 | or if you want to record the video of gameplay, run:
65 | ```
66 | python3 main.py --model WU-UCT --env-name PongNoFrameskip-v0 --record-video
67 | ```
68 |
69 | * A full list of parameters (an example combining several of them follows this list)
70 |     * --model: MCTS model to use (currently supports WU-UCT and UCT).
71 |     * --env-name: name of the environment.
72 |     * --MCTS-max-steps: number of simulation steps in the planning phase.
73 |     * --MCTS-max-depth: maximum planning depth.
74 |     * --MCTS-max-width: maximum width for each node.
75 |     * --gamma: environment discount factor.
76 |     * --expansion-worker-num: number of expansion workers.
77 |     * --simulation-worker-num: number of simulation workers.
78 |     * --seed: random seed for the environment.
79 |     * --max-episode-length: a strict upper bound on the environment's episode length.
80 |     * --policy: default policy (see above).
81 |     * --device: supports "cpu", "cuda:x", and "cuda". If "cuda" is given, all available CUDA devices will be used. Usually used to load the policy.
82 |     * --record-video: see above.
83 |     * --mode: MCTS or Distill, see [Planning with prior policy](#Planning-with-prior-policy).
84 |
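For example, to plan with the PPO default policy and 16 expansion/simulation workers on a GPU (one possible combination of the flags above; the corresponding PPO policy file must be generated first, see below), run:
```
python3 main.py --model WU-UCT --env-name PongNoFrameskip-v0 --policy PPO --expansion-worker-num 16 --simulation-worker-num 16 --device cuda
```
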
85 | ### Planning with prior policy
86 | The code currently supports three default policies (the policy used to perform simulations): *Random*, *PPO*, and *DistillPPO* (to use them, change the “--policy” parameter). To use the *PPO* and *DistillPPO* policies, the corresponding policy files need to be put in [./Policy/PPO/PolicyFiles](https://github.com/liuanji/WU-UCT/tree/master/Policy/PPO/PolicyFiles). PPO policy files can be generated with [Atari_PPO_training](https://github.com/liuanji/WU-UCT/tree/master/Utils/Atari_PPO_training) (or see [Github](https://github.com/lnpalmer/PPO)). For example, by running
87 | ```
88 | cd Utils/Atari_PPO_training
89 | python3 main.py PongNoFrameskip-v0
90 | ```
91 | a policy file will be generated in [./Utils/Atari_PPO_training/save](https://github.com/liuanji/WU-UCT/tree/master/Utils/Atari_PPO_training/save). To run DistillPPO, we first have to run the distillation training process:
92 | ```
93 | python3 main.py --mode Distill --env-name PongNoFrameskip-v0
94 | ```
95 |
96 | ## Run on your own environments
97 | We kindly provide an [environment wrapper](https://github.com/liuanji/WU-UCT/tree/master/Env/EnvWrapper.py) and a [policy wrapper](https://github.com/liuanji/WU-UCT/tree/master/Policy/PolicyWrapper.py) to make extensions to other environments easy. All you need to do is modify [./Env/EnvWrapper.py](https://github.com/liuanji/WU-UCT/tree/master/Env/EnvWrapper.py) and [./Policy/PolicyWrapper.py](https://github.com/liuanji/WU-UCT/tree/master/Policy/PolicyWrapper.py) to fit your own environment. Please follow the instructions below.
98 |
99 | 1. Edit the class EnvWrapper in [./Env/EnvWrapper.py](https://github.com/liuanji/WU-UCT/tree/master/Env/EnvWrapper.py).
100 |
101 |    Nest your environment into the wrapper by providing specific functionality in each of the member functions of EnvWrapper (a minimal skeleton is sketched after this list). There are currently four input arguments to EnvWrapper: *env_name*, *max_episode_length*, *enable_record*, and *record_path*. If additional information needs to be passed in, you may first consider encoding it in *env_name*.
102 |
103 | 2. Edit the class PolicyWrapper in [./Policy/PolicyWrapper.py](https://github.com/liuanji/WU-UCT/tree/master/Policy/PolicyWrapper.py).
104 |
105 |    Similarly, nest your default policy in PolicyWrapper, and select the corresponding policy using --policy. You will need to rewrite the three member functions *get_action*, *get_value*, and *get_prior_prob*.
106 |
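As a rough sketch only (this skeleton is not part of the repository; the method names simply follow how EnvWrapper is used by Tree/WU_UCT.py and ParallelPool/Worker.py), a custom environment wrapper might look like:
```python
class EnvWrapper():
    def __init__(self, env_name, max_episode_length = 0,
                 enable_record = False, record_path = ""):
        # Construct and store your environment here.
        ...

    def reset(self):
        # Return the initial state.
        ...

    def step(self, action):
        # Apply the action; return (next_state, reward, done).
        ...

    def checkpoint(self):
        # Return a picklable snapshot of the full environment state.
        ...

    def restore(self, checkpoint_data):
        # Restore the environment from a snapshot and return the current state.
        ...

    def seed(self, seed):
        # Seed the environment's random number generator.
        ...

    def get_action_n(self):
        # Return the number of discrete actions.
        ...

    def get_max_episode_length(self):
        # Return the maximum episode length.
        ...

    # Only needed when --record-video is used.
    def capture_frame(self):
        ...

    def store_video_files(self):
        ...
```
The default policy is adapted analogously by filling in *get_action*, *get_value*, and *get_prior_prob* in PolicyWrapper.
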
107 | # Updates and to-dos
108 | ## Past updates
109 | 1. Refactor prior policy module to support easy reuse (Sep. 26, 2019).
110 |
111 | ## To-do list
112 | (empty)
113 |
114 | # Reference
115 | Please cite the paper in the following format if you use this code in your research :)
116 | ```
117 | @inproceedings{liu2020watch,
118 | title = {Watch the Unobserved: A Simple Approach to Parallelizing Monte Carlo Tree Search},
119 | author = {Anji Liu and Jianshu Chen and Mingze Yu and Yu Zhai and Xuewen Zhou and Ji Liu},
120 | booktitle = {International Conference on Learning Representations},
121 | month = apr,
122 | year = {2020},
123 | url = "https://openreview.net/forum?id=BJlQtJSKDB"
124 | }
125 | ```
126 |
--------------------------------------------------------------------------------
/Records/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Records/.gitkeep
--------------------------------------------------------------------------------
/Results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Results/.gitkeep
--------------------------------------------------------------------------------
/Tree/UCT.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | from multiprocessing import Process
4 | import gc
5 | import time
6 | import random
7 | import os
8 | import torch
9 |
10 | from Node.UCTnode import UCTnode
11 |
12 | from Env.EnvWrapper import EnvWrapper
13 |
14 | from Mem.CheckpointManager import CheckpointManager
15 |
16 | from Policy.PolicyWrapper import PolicyWrapper
17 |
18 |
19 | class UCT():
20 | def __init__(self, env_params, max_steps = 1000, max_depth = 20, max_width = 5,
21 | gamma = 1.0, policy = "Random", seed = 123, device = torch.device("cpu")):
22 | self.env_params = env_params
23 | self.max_steps = max_steps
24 | self.max_depth = max_depth
25 | self.max_width = max_width
26 | self.gamma = gamma
27 | self.policy = policy
28 | self.seed = seed
29 | self.device = device
30 |
31 | self.policy_wrapper = None
32 |
33 | # Environment
34 | self.wrapped_env = EnvWrapper(**env_params)
35 |
36 | # Environment properties
37 | self.action_n = self.wrapped_env.get_action_n()
38 | self.max_width = min(self.action_n, self.max_width)
39 |
40 | assert self.max_depth > 0 and 0 < self.max_width <= self.action_n
41 |
42 | # Checkpoint data manager
43 | self.checkpoint_data_manager = CheckpointManager()
44 | self.checkpoint_data_manager.hock_env("main", self.wrapped_env)
45 |
46 | # For MCTS tree
47 | self.root_node = None
48 | self.global_saving_idx = 0
49 |
50 | self.init_policy()
51 |
52 | def init_policy(self):
53 | self.policy_wrapper = PolicyWrapper(
54 | self.policy,
55 | self.env_params["env_name"],
56 | self.action_n,
57 | self.device
58 | )
59 |
60 |     # Entrance of the UCT algorithm
61 | def simulate_trajectory(self, max_episode_length = -1):
62 | state = self.wrapped_env.reset()
63 | accu_reward = 0.0
64 | done = False
65 | step_count = 0
66 | rewards = []
67 | times = []
68 |
69 | game_start_time = time.time()
70 |
71 | while not done and (max_episode_length == -1 or step_count < max_episode_length):
72 | simulation_start_time = time.time()
73 | action = self.simulate_single_move(state)
74 | simulation_end_time = time.time()
75 |
76 | next_state, reward, done = self.wrapped_env.step(action)
77 | rewards.append(reward)
78 | times.append(simulation_end_time - simulation_start_time)
79 |
80 | print("> Time step {}, take action {}, instance reward {}, cumulative reward {}, used {} seconds".format(
81 | step_count, action, reward, accu_reward + reward, simulation_end_time - simulation_start_time))
82 |
83 | accu_reward += reward
84 | state = next_state
85 | step_count += 1
86 |
87 | game_end_time = time.time()
88 | print("> game ended. total reward: {}, used time {} s".format(accu_reward, game_end_time - game_start_time))
89 |
90 | return accu_reward, np.array(rewards, dtype = np.float32), np.array(times, dtype = np.float32)
91 |
92 | def simulate_single_move(self, state):
93 | # Clear cache
94 | self.root_node = None
95 | self.global_saving_idx = 0
96 | self.checkpoint_data_manager.clear()
97 |
98 | gc.collect()
99 |
100 | # Construct root node
101 | self.checkpoint_data_manager.checkpoint_env("main", self.global_saving_idx)
102 |
103 | self.root_node = UCTnode(
104 | action_n = self.action_n,
105 | state = state,
106 | checkpoint_idx = self.global_saving_idx,
107 | parent = None,
108 | tree = self,
109 | is_head = True
110 | )
111 |
112 | self.global_saving_idx += 1
113 |
114 | for _ in range(self.max_steps):
115 | self.simulate_single_step()
116 |
117 | best_action = self.root_node.max_utility_action()
118 |
119 | self.checkpoint_data_manager.load_checkpoint_env("main", self.root_node.checkpoint_idx)
120 |
121 | return best_action
122 |
123 | def simulate_single_step(self):
124 | # Go into root node
125 | curr_node = self.root_node
126 |
127 | # Selection
128 | curr_depth = 1
129 | while True:
130 | if curr_node.no_child_available() or (not curr_node.all_child_visited() and
131 | curr_node != self.root_node and np.random.random() < 0.5) or \
132 | (not curr_node.all_child_visited() and curr_node == self.root_node):
133 | # If no child node has been updated, we have to perform expansion anyway.
134 |                 # Or if the root node is not fully visited.
135 |                 # Or if a non-root node is not fully visited, with probability 1/2.
136 |
137 | need_expansion = True
138 | break
139 |
140 | else:
141 | action = curr_node.select_action()
142 |
143 | curr_node.update_history(action, curr_node.rewards[action])
144 |
145 | if curr_node.dones[action] or curr_depth >= self.max_depth:
146 | need_expansion = False
147 | break
148 |
149 | next_node = curr_node.children[action]
150 |
151 | curr_depth += 1
152 | curr_node = next_node
153 |
154 | # Expansion
155 | if need_expansion:
156 | expand_action = curr_node.select_expand_action()
157 |
158 | self.checkpoint_data_manager.load_checkpoint_env("main", curr_node.checkpoint_idx)
159 | next_state, reward, done = self.wrapped_env.step(expand_action)
160 | self.checkpoint_data_manager.checkpoint_env("main", self.global_saving_idx)
161 |
162 | curr_node.rewards[expand_action] = reward
163 | curr_node.dones[expand_action] = done
164 |
165 | curr_node.update_history(
166 | action_taken = expand_action,
167 | reward = reward
168 | )
169 |
170 | curr_node.add_child(
171 | expand_action,
172 | next_state,
173 | self.global_saving_idx,
174 | prior_prob = self.get_prior_prob(next_state)
175 | )
176 | self.global_saving_idx += 1
177 | else:
178 | self.checkpoint_data_manager.load_checkpoint_env("main", curr_node.checkpoint_idx)
179 | next_state, reward, done = self.wrapped_env.step(action)
180 |
181 | curr_node.rewards[action] = reward
182 | curr_node.dones[action] = done
183 |
184 | # Simulation
185 | done = False
186 | accu_reward = 0.0
187 | accu_gamma = 1.0
188 |
189 | while not done:
190 | action = self.get_action(next_state)
191 |
192 | next_state, reward, done = self.wrapped_env.step(action)
193 |
194 | accu_reward += reward * accu_gamma
195 |
196 | accu_gamma *= self.gamma
197 |
198 | # Complete Update
199 | self.complete_update(curr_node, self.root_node, accu_reward)
200 |
201 | def get_action(self, state):
202 | return self.policy_wrapper.get_action(state)
203 |
204 | def get_prior_prob(self, state):
205 | return self.policy_wrapper.get_prior_prob(state)
206 |
207 | def close(self):
208 | pass
209 |
210 | @staticmethod
211 | def complete_update(curr_node, curr_node_head, accu_reward):
212 | while curr_node != curr_node_head:
213 | accu_reward = curr_node.update(accu_reward)
214 | curr_node = curr_node.parent
215 |
216 | curr_node_head.update(accu_reward)
217 |
--------------------------------------------------------------------------------
/Tree/WU_UCT.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from copy import deepcopy
3 | import gc
4 | import time
5 | import logging
6 |
7 | from Node.WU_UCTnode import WU_UCTnode
8 |
9 | from Env.EnvWrapper import EnvWrapper
10 |
11 | from ParallelPool.PoolManager import PoolManager
12 |
13 | from Mem.CheckpointManager import CheckpointManager
14 |
15 |
16 | class WU_UCT():
17 | def __init__(self, env_params, max_steps = 1000, max_depth = 20, max_width = 5,
18 | gamma = 1.0, expansion_worker_num = 16, simulation_worker_num = 16,
19 | policy = "Random", seed = 123, device = "cpu", record_video = False):
20 | self.env_params = env_params
21 | self.max_steps = max_steps
22 | self.max_depth = max_depth
23 | self.max_width = max_width
24 | self.gamma = gamma
25 | self.expansion_worker_num = expansion_worker_num
26 | self.simulation_worker_num = simulation_worker_num
27 | self.policy = policy
28 | self.device = device
29 | self.record_video = record_video
30 |
31 | # Environment
32 | record_path = "Records/P-UCT_" + env_params["env_name"] + ".mp4"
33 | self.wrapped_env = EnvWrapper(**env_params, enable_record = record_video,
34 | record_path = record_path)
35 |
36 | # Environment properties
37 | self.action_n = self.wrapped_env.get_action_n()
38 | self.max_width = min(self.action_n, self.max_width)
39 |
40 | assert self.max_depth > 0 and 0 < self.max_width <= self.action_n
41 |
42 | # Expansion worker pool
43 | self.expansion_worker_pool = PoolManager(
44 | worker_num = expansion_worker_num,
45 | env_params = env_params,
46 | policy = policy,
47 | gamma = gamma,
48 | seed = seed,
49 | device = device,
50 | need_policy = False
51 | )
52 |
53 | # Simulation worker pool
54 | self.simulation_worker_pool = PoolManager(
55 | worker_num = simulation_worker_num,
56 | env_params = env_params,
57 | policy = policy,
58 | gamma = gamma,
59 | seed = seed,
60 | device = device,
61 | need_policy = True
62 | )
63 |
64 | # Checkpoint data manager
65 | self.checkpoint_data_manager = CheckpointManager()
66 | self.checkpoint_data_manager.hock_env("main", self.wrapped_env)
67 |
68 | # For MCTS tree
69 | self.root_node = None
70 | self.global_saving_idx = 0
71 |
72 | # Task recorder
73 | self.expansion_task_recorder = dict()
74 | self.unscheduled_expansion_tasks = list()
75 | self.simulation_task_recorder = dict()
76 | self.unscheduled_simulation_tasks = list()
77 |
78 | # Simulation count
79 | self.simulation_count = 0
80 |
81 | # Logging
82 | logging.basicConfig(filename = "Logs/P-UCT_" + self.env_params["env_name"] + "_" +
83 | str(self.simulation_worker_num) + ".log", level = logging.INFO)
84 |
85 | # Entrance of the P-UCT algorithm
86 |     # This is the outer loop of P-UCT simulation, where the P-UCT agent consecutively plans the best action and
87 |     # interacts with the environment.
88 | def simulate_trajectory(self, max_episode_length = -1):
89 | state = self.wrapped_env.reset()
90 | accu_reward = 0.0
91 | done = False
92 | step_count = 0
93 | rewards = []
94 | times = []
95 |
96 |         game_start_time = time.perf_counter()
97 |
98 | logging.info("Start simulation")
99 |
100 | while not done and (max_episode_length == -1 or step_count < max_episode_length):
101 | # Plan a best action under the current state
102 |             simulation_start_time = time.perf_counter()
103 |             action = self.simulate_single_move(state)
104 |             simulation_end_time = time.perf_counter()
105 |
106 | # Interact with the environment
107 | next_state, reward, done = self.wrapped_env.step(action)
108 | rewards.append(reward)
109 | times.append(simulation_end_time - simulation_start_time)
110 |
111 | print("> Time step {}, take action {}, instance reward {}, cumulative reward {}, used {} seconds".format(
112 | step_count, action, reward, accu_reward + reward, simulation_end_time - simulation_start_time))
113 | logging.info("> Time step {}, take action {}, instance reward {}, cumulative reward {}, used {} seconds".format(
114 | step_count, action, reward, accu_reward + reward, simulation_end_time - simulation_start_time))
115 |
116 | # Record video
117 | if self.record_video:
118 | self.wrapped_env.capture_frame()
119 | self.wrapped_env.store_video_files()
120 |
121 | # update game status
122 | accu_reward += reward
123 | state = next_state
124 | step_count += 1
125 |
126 |         game_end_time = time.perf_counter()
127 | print("> game ended. total reward: {}, used time {} s".format(accu_reward, game_end_time - game_start_time))
128 | logging.info("> game ended. total reward: {}, used time {} s".format(accu_reward,
129 | game_end_time - game_start_time))
130 |
131 | return accu_reward, np.array(rewards, dtype = np.float32), np.array(times, dtype = np.float32)
132 |
133 |     # This is the planning process of P-UCT. Starting from a tree with only a root node,
134 | # P-UCT performs selection, expansion, simulation, and backpropagation on it.
135 | def simulate_single_move(self, state):
136 | # Clear cache
137 | self.root_node = None
138 | self.global_saving_idx = 0
139 | self.checkpoint_data_manager.clear()
140 |
141 | # Clear recorders
142 | self.expansion_task_recorder.clear()
143 | self.unscheduled_expansion_tasks.clear()
144 | self.simulation_task_recorder.clear()
145 | self.unscheduled_simulation_tasks.clear()
146 |
147 | gc.collect()
148 |
149 | # Free all workers
150 | self.expansion_worker_pool.wait_until_all_envs_idle()
151 | self.simulation_worker_pool.wait_until_all_envs_idle()
152 |
153 | # Construct root node
154 | self.checkpoint_data_manager.checkpoint_env("main", self.global_saving_idx)
155 | self.root_node = WU_UCTnode(
156 | action_n = self.action_n,
157 | state = state,
158 | checkpoint_idx = self.global_saving_idx,
159 | parent = None,
160 | tree = self,
161 | is_head = True
162 | )
163 |
164 | # An index used to retrieve game-states
165 | self.global_saving_idx += 1
166 |
167 |         # t_complete in the original paper, measures the number of completed simulations
168 | self.simulation_count = 0
169 |
170 | # Repeatedly invoke the master loop (Figure 2 of the paper)
171 | sim_idx = 0
172 | while self.simulation_count < self.max_steps:
173 | self.simulate_single_step(sim_idx)
174 |
175 | sim_idx += 1
176 |
177 | # Select the best root action
178 | best_action = self.root_node.max_utility_action()
179 |
180 | # Retrieve the game-state before simulation begins
181 | self.checkpoint_data_manager.load_checkpoint_env("main", self.root_node.checkpoint_idx)
182 |
183 | return best_action
184 |
185 | def simulate_single_step(self, sim_idx):
186 | # Go into root node
187 | curr_node = self.root_node
188 |
189 | # Selection
190 | curr_depth = 1
191 | while True:
192 | if curr_node.no_child_available() or (not curr_node.all_child_visited() and
193 | curr_node != self.root_node and np.random.random() < 0.5) or \
194 | (not curr_node.all_child_visited() and curr_node == self.root_node):
195 | # If no child node has been updated, we have to perform expansion anyway.
196 |                 # Or if the root node is not fully visited.
197 |                 # Or if a non-root node is not fully visited, with probability 1/2.
198 |
199 | cloned_curr_node = curr_node.shallow_clone()
200 | checkpoint_data = self.checkpoint_data_manager.retrieve(curr_node.checkpoint_idx)
201 |
202 | # Record the task
203 | self.expansion_task_recorder[sim_idx] = (checkpoint_data, cloned_curr_node, curr_node)
204 | self.unscheduled_expansion_tasks.append(sim_idx)
205 |
206 | need_expansion = True
207 | break
208 |
209 | else:
210 | action = curr_node.select_action()
211 |
212 | curr_node.update_history(sim_idx, action, curr_node.rewards[action])
213 |
214 | if curr_node.dones[action] or curr_depth >= self.max_depth:
215 |                     # Terminal transition or maximum depth exceeded
216 | need_expansion = False
217 | break
218 |
219 | if curr_node.children[action] is None:
220 | need_expansion = False
221 | break
222 |
223 | next_node = curr_node.children[action]
224 |
225 | curr_depth += 1
226 | curr_node = next_node
227 |
228 | # Expansion
229 | if not need_expansion:
230 | if not curr_node.dones[action]:
231 |                 # Reached the maximum depth but the episode has not terminated.
232 | # Record simulation task.
233 |
234 | self.simulation_task_recorder[sim_idx] = (
235 | action,
236 | curr_node,
237 | curr_node.checkpoint_idx,
238 | None
239 | )
240 | self.unscheduled_simulation_tasks.append(sim_idx)
241 | else:
242 | # Reach terminal node.
243 | # In this case, update directly.
244 |
245 | self.incomplete_update(curr_node, self.root_node, sim_idx)
246 | self.complete_update(curr_node, self.root_node, 0.0, sim_idx)
247 |
248 | self.simulation_count += 1
249 |
250 | else:
251 | # Assign tasks to idle server
252 | while len(self.unscheduled_expansion_tasks) > 0 and self.expansion_worker_pool.has_idle_server():
253 | # Get a task
254 | curr_idx = np.random.randint(0, len(self.unscheduled_expansion_tasks))
255 | task_idx = self.unscheduled_expansion_tasks.pop(curr_idx)
256 |
257 | # Assign the task to server
258 | checkpoint_data, cloned_curr_node, _ = self.expansion_task_recorder[task_idx]
259 | self.expansion_worker_pool.assign_expansion_task(
260 | checkpoint_data,
261 | cloned_curr_node,
262 | self.global_saving_idx,
263 | task_idx
264 | )
265 | self.global_saving_idx += 1
266 |
267 | # Wait for an expansion task to complete
268 | if self.expansion_worker_pool.server_occupied_rate() >= 0.99:
269 | expand_action, next_state, reward, done, checkpoint_data, \
270 | saving_idx, task_idx = self.expansion_worker_pool.get_complete_expansion_task()
271 |
272 | curr_node = self.expansion_task_recorder.pop(task_idx)[2]
273 | curr_node.update_history(task_idx, expand_action, reward)
274 |
275 | # Record info
276 | curr_node.dones[expand_action] = done
277 | curr_node.rewards[expand_action] = reward
278 |
279 | if done:
280 |                     # If this expansion results in a terminal node, perform the update directly.
281 | # (simulation is not needed)
282 |
283 | self.incomplete_update(curr_node, self.root_node, task_idx)
284 | self.complete_update(curr_node, self.root_node, 0.0, task_idx)
285 |
286 | self.simulation_count += 1
287 |
288 | else:
289 | # Schedule the task to the simulation task buffer.
290 |
291 | self.checkpoint_data_manager.store(saving_idx, checkpoint_data)
292 |
293 | self.simulation_task_recorder[task_idx] = (
294 | expand_action,
295 | curr_node,
296 | saving_idx,
297 | deepcopy(next_state)
298 | )
299 | self.unscheduled_simulation_tasks.append(task_idx)
300 |
301 | # Assign simulation tasks to idle environment server
302 | while len(self.unscheduled_simulation_tasks) > 0 and self.simulation_worker_pool.has_idle_server():
303 | # Get a task
304 | idx = np.random.randint(0, len(self.unscheduled_simulation_tasks))
305 | task_idx = self.unscheduled_simulation_tasks.pop(idx)
306 |
307 | checkpoint_data = self.checkpoint_data_manager.retrieve(self.simulation_task_recorder[task_idx][2])
308 |
309 |                 first_action = None if self.simulation_task_recorder[task_idx][3] \
310 |                     is not None else self.simulation_task_recorder[task_idx][0]
311 |
312 | # Assign the task to server
313 | self.simulation_worker_pool.assign_simulation_task(
314 | task_idx,
315 | checkpoint_data,
316 |                     first_action = first_action
317 | )
318 |
319 | # Perform incomplete update
320 | self.incomplete_update(
321 | self.simulation_task_recorder[task_idx][1], # This is the corresponding node
322 | self.root_node,
323 | task_idx
324 | )
325 |
326 | # Wait for a simulation task to complete
327 | if self.simulation_worker_pool.server_occupied_rate() >= 0.99:
328 | args = self.simulation_worker_pool.get_complete_simulation_task()
329 | if len(args) == 3:
330 | task_idx, accu_reward, prior_prob = args
331 | else:
332 | task_idx, accu_reward, reward, done = args
333 | expand_action, curr_node, saving_idx, next_state = self.simulation_task_recorder.pop(task_idx)
334 |
335 | if len(args) == 4:
336 | curr_node.rewards[expand_action] = reward
337 | curr_node.dones[expand_action] = done
338 |
339 | # Add node
340 | if next_state is not None:
341 | curr_node.add_child(
342 | expand_action,
343 | next_state,
344 | saving_idx,
345 | prior_prob = prior_prob
346 | )
347 |
348 | # Complete Update
349 | self.complete_update(curr_node, self.root_node, accu_reward, task_idx)
350 |
351 | self.simulation_count += 1
352 |
353 | def close(self):
354 | # Free sub-processes
355 | self.expansion_worker_pool.close_pool()
356 | self.simulation_worker_pool.close_pool()
357 |
358 |     # Incomplete update: track unobserved (in-flight) samples along the path to the root (Algorithm 2 in the paper)
359 | @staticmethod
360 | def incomplete_update(curr_node, curr_node_head, idx):
361 | while curr_node != curr_node_head:
362 | curr_node.update_incomplete(idx)
363 | curr_node = curr_node.parent
364 |
365 | curr_node_head.update_incomplete(idx)
366 |
367 | # Complete update tracks the observed samples (Algorithm 3 in the paper)
368 | @staticmethod
369 | def complete_update(curr_node, curr_node_head, accu_reward, idx):
370 | while curr_node != curr_node_head:
371 | accu_reward = curr_node.update_complete(idx, accu_reward)
372 | curr_node = curr_node.parent
373 |
374 | curr_node_head.update_complete(idx, accu_reward)
375 |
--------------------------------------------------------------------------------
/Tree/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Tree/__init__.py
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Lukas Palmer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/README.md:
--------------------------------------------------------------------------------
1 | # PPO
2 | PyTorch implementation of Proximal Policy Optimization
3 |
4 | 
5 |
6 | ## Usage
7 |
8 | Example command line usage:
9 | ````
10 | python main.py BreakoutNoFrameskip-v0 --num-workers 8 --render
11 | ````
12 |
13 | This will run PPO with 8 parallel training environments, which will be rendered on the screen. Run with `-h` for usage information.
14 |
15 | ## Performance
16 |
17 | Results are comparable to those of the original PPO paper. Note that the horizontal axis here shows environment steps, whereas the plots in the paper use frames (4 frames per step).
18 |
19 | Training episode reward versus environment steps for `BreakoutNoFrameskip-v3`:
20 |
21 | 
22 |
23 | ## References
24 |
25 | [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)
26 |
27 | [OpenAI Baselines](https://github.com/openai/baselines)
28 |
29 | This code uses some environment utilities such as `SubprocVecEnv` and `VecFrameStack` from OpenAI's Baselines.
30 |
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/atari_wrappers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | os.environ.setdefault('PATH', '')
4 | from collections import deque
5 | import gym
6 | from gym import spaces
7 | import cv2
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | class TimeLimit(gym.Wrapper):
11 | def __init__(self, env, max_episode_steps=None):
12 | super(TimeLimit, self).__init__(env)
13 | self._max_episode_steps = max_episode_steps
14 | self._elapsed_steps = 0
15 |
16 | def step(self, ac):
17 | observation, reward, done, info = self.env.step(ac)
18 | self._elapsed_steps += 1
19 | if self._elapsed_steps >= self._max_episode_steps:
20 | done = True
21 | info['TimeLimit.truncated'] = True
22 | return observation, reward, done, info
23 |
24 | def reset(self, **kwargs):
25 | self._elapsed_steps = 0
26 | return self.env.reset(**kwargs)
27 |
28 | class NoopResetEnv(gym.Wrapper):
29 | def __init__(self, env, noop_max=30):
30 | """Sample initial states by taking random number of no-ops on reset.
31 | No-op is assumed to be action 0.
32 | """
33 | gym.Wrapper.__init__(self, env)
34 | self.noop_max = noop_max
35 | self.override_num_noops = None
36 | self.noop_action = 0
37 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
38 |
39 | def reset(self, **kwargs):
40 | """ Do no-op action for a number of steps in [1, noop_max]."""
41 | self.env.reset(**kwargs)
42 | if self.override_num_noops is not None:
43 | noops = self.override_num_noops
44 | else:
45 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101
46 | assert noops > 0
47 | obs = None
48 | for _ in range(noops):
49 | obs, _, done, _ = self.env.step(self.noop_action)
50 | if done:
51 | obs = self.env.reset(**kwargs)
52 | return obs
53 |
54 | def step(self, ac):
55 | return self.env.step(ac)
56 |
57 | class FireResetEnv(gym.Wrapper):
58 | def __init__(self, env):
59 | """Take action on reset for environments that are fixed until firing."""
60 | gym.Wrapper.__init__(self, env)
61 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
62 | assert len(env.unwrapped.get_action_meanings()) >= 3
63 |
64 | def reset(self, **kwargs):
65 | self.env.reset(**kwargs)
66 | obs, _, done, _ = self.env.step(1)
67 | if done:
68 | self.env.reset(**kwargs)
69 | obs, _, done, _ = self.env.step(2)
70 | if done:
71 | self.env.reset(**kwargs)
72 | return obs
73 |
74 | def step(self, ac):
75 | return self.env.step(ac)
76 |
77 | class EpisodicLifeEnv(gym.Wrapper):
78 | def __init__(self, env):
79 | """Make end-of-life == end-of-episode, but only reset on true game over.
80 | Done by DeepMind for the DQN and co. since it helps value estimation.
81 | """
82 | gym.Wrapper.__init__(self, env)
83 | self.lives = 0
84 | self.was_real_done = True
85 |
86 | def step(self, action):
87 | obs, reward, done, info = self.env.step(action)
88 | self.was_real_done = done
89 | # check current lives, make loss of life terminal,
90 | # then update lives to handle bonus lives
91 | lives = self.env.unwrapped.ale.lives()
92 | if lives < self.lives and lives > 0:
93 | # for Qbert sometimes we stay in lives == 0 condition for a few frames
94 | # so it's important to keep lives > 0, so that we only reset once
95 | # the environment advertises done.
96 | done = True
97 | self.lives = lives
98 | return obs, reward, done, info
99 |
100 | def reset(self, **kwargs):
101 | """Reset only when lives are exhausted.
102 | This way all states are still reachable even though lives are episodic,
103 | and the learner need not know about any of this behind-the-scenes.
104 | """
105 | if self.was_real_done:
106 | obs = self.env.reset(**kwargs)
107 | else:
108 | # no-op step to advance from terminal/lost life state
109 | obs, _, _, _ = self.env.step(0)
110 | self.lives = self.env.unwrapped.ale.lives()
111 | return obs
112 |
113 | class MaxAndSkipEnv(gym.Wrapper):
114 | def __init__(self, env, skip=4):
115 | """Return only every `skip`-th frame"""
116 | gym.Wrapper.__init__(self, env)
117 | # most recent raw observations (for max pooling across time steps)
118 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
119 | self._skip = skip
120 |
121 | def step(self, action):
122 | """Repeat action, sum reward, and max over last observations."""
123 | total_reward = 0.0
124 | done = None
125 | for i in range(self._skip):
126 | obs, reward, done, info = self.env.step(action)
127 | if i == self._skip - 2: self._obs_buffer[0] = obs
128 | if i == self._skip - 1: self._obs_buffer[1] = obs
129 | total_reward += reward
130 | if done:
131 | break
132 | # Note that the observation on the done=True frame
133 | # doesn't matter
134 | max_frame = self._obs_buffer.max(axis=0)
135 |
136 | return max_frame, total_reward, done, info
137 |
138 | def reset(self, **kwargs):
139 | return self.env.reset(**kwargs)
140 |
141 | class ClipRewardEnv(gym.RewardWrapper):
142 | def __init__(self, env):
143 | gym.RewardWrapper.__init__(self, env)
144 |
145 | def reward(self, reward):
146 | """Bin reward to {+1, 0, -1} by its sign."""
147 | return np.sign(reward)
148 |
149 |
150 | class WarpFrame(gym.ObservationWrapper):
151 | def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
152 | """
153 | Warp frames to 84x84 as done in the Nature paper and later work.
154 |
155 | If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
156 | observation should be warped.
157 | """
158 | super().__init__(env)
159 | self._width = width
160 | self._height = height
161 | self._grayscale = grayscale
162 | self._key = dict_space_key
163 | if self._grayscale:
164 | num_colors = 1
165 | else:
166 | num_colors = 3
167 |
168 | new_space = gym.spaces.Box(
169 | low=0,
170 | high=255,
171 | shape=(self._height, self._width, num_colors),
172 | dtype=np.uint8,
173 | )
174 | if self._key is None:
175 | original_space = self.observation_space
176 | self.observation_space = new_space
177 | else:
178 | original_space = self.observation_space.spaces[self._key]
179 | self.observation_space.spaces[self._key] = new_space
180 | assert original_space.dtype == np.uint8 and len(original_space.shape) == 3
181 |
182 | def observation(self, obs):
183 | if self._key is None:
184 | frame = obs
185 | else:
186 | frame = obs[self._key]
187 |
188 | if self._grayscale:
189 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
190 | frame = cv2.resize(
191 | frame, (self._width, self._height), interpolation=cv2.INTER_AREA
192 | )
193 | if self._grayscale:
194 | frame = np.expand_dims(frame, -1)
195 |
196 | if self._key is None:
197 | obs = frame
198 | else:
199 | obs = obs.copy()
200 | obs[self._key] = frame
201 | return obs
202 |
203 |
204 | class FrameStack(gym.Wrapper):
205 | def __init__(self, env, k):
206 | """Stack k last frames.
207 |
208 | Returns lazy array, which is much more memory efficient.
209 |
210 | See Also
211 | --------
212 | baselines.common.atari_wrappers.LazyFrames
213 | """
214 | gym.Wrapper.__init__(self, env)
215 | self.k = k
216 | self.frames = deque([], maxlen=k)
217 | shp = env.observation_space.shape
218 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype)
219 |
220 | def reset(self):
221 | ob = self.env.reset()
222 | for _ in range(self.k):
223 | self.frames.append(ob)
224 | return self._get_ob()
225 |
226 | def step(self, action):
227 | ob, reward, done, info = self.env.step(action)
228 | self.frames.append(ob)
229 | return self._get_ob(), reward, done, info
230 |
231 | def _get_ob(self):
232 | assert len(self.frames) == self.k
233 | return LazyFrames(list(self.frames))
234 |
235 | class ScaledFloatFrame(gym.ObservationWrapper):
236 | def __init__(self, env):
237 | gym.ObservationWrapper.__init__(self, env)
238 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)
239 |
240 | def observation(self, observation):
241 | # careful! This undoes the memory optimization, use
242 | # with smaller replay buffers only.
243 | return np.array(observation).astype(np.float32) / 255.0
244 |
245 | class LazyFrames(object):
246 | def __init__(self, frames):
247 | """This object ensures that common frames between the observations are only stored once.
248 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
249 | buffers.
250 |
251 | This object should only be converted to numpy array before being passed to the model.
252 |
253 | You'd not believe how complex the previous solution was."""
254 | self._frames = frames
255 | self._out = None
256 |
257 | def _force(self):
258 | if self._out is None:
259 | self._out = np.concatenate(self._frames, axis=-1)
260 | self._frames = None
261 | return self._out
262 |
263 | def __array__(self, dtype=None):
264 | out = self._force()
265 | if dtype is not None:
266 | out = out.astype(dtype)
267 | return out
268 |
269 | def __len__(self):
270 | return len(self._force())
271 |
272 | def __getitem__(self, i):
273 | return self._force()[i]
274 |
275 | def count(self):
276 | frames = self._force()
277 | return frames.shape[frames.ndim - 1]
278 |
279 | def frame(self, i):
280 | return self._force()[..., i]
281 |
282 | def make_atari(env_id, max_episode_steps=None):
283 | env = gym.make(env_id)
284 | assert 'NoFrameskip' in env.spec.id
285 | env = NoopResetEnv(env, noop_max=30)
286 | env = MaxAndSkipEnv(env, skip=4)
287 | if max_episode_steps is not None:
288 | env = TimeLimit(env, max_episode_steps=max_episode_steps)
289 | return env
290 |
291 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
292 | """Configure environment for DeepMind-style Atari.
293 | """
294 | if episode_life:
295 | env = EpisodicLifeEnv(env)
296 | if 'FIRE' in env.unwrapped.get_action_meanings():
297 | env = FireResetEnv(env)
298 | env = WarpFrame(env)
299 | if scale:
300 | env = ScaledFloatFrame(env)
301 | if clip_rewards:
302 | env = ClipRewardEnv(env)
303 | if frame_stack:
304 | env = FrameStack(env, 4)
305 | return env
306 |
307 |
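
The wrappers above are normally composed through `make_atari` and `wrap_deepmind`. A minimal sketch of that composition (the environment id and the random-action loop are only illustrative):

````
import numpy as np
from atari_wrappers import make_atari, wrap_deepmind

env = make_atari("BreakoutNoFrameskip-v4")      # NoopResetEnv + MaxAndSkipEnv
env = wrap_deepmind(env, frame_stack=True)      # EpisodicLife, FireReset, WarpFrame, ClipReward, FrameStack

ob = env.reset()
done = False
while not done:
    ob, reward, done, info = env.step(env.action_space.sample())
    frame = np.array(ob)                        # LazyFrames -> uint8 array of shape (84, 84, 4)
env.close()
````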
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/envs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gym
3 | from gym import spaces
4 | from atari_wrappers import make_atari, wrap_deepmind
5 | from vec_env import VecEnv
6 | from multiprocessing import Process, Pipe
7 |
8 |
9 | # cf https://github.com/openai/baselines
10 |
11 | def make_env(env_name, rank, seed):
12 | env = make_atari(env_name)
13 | env.seed(seed + rank)
14 | env = wrap_deepmind(env, episode_life=False, clip_rewards=False)
15 | return env
16 |
17 |
18 | def worker(remote, parent_remote, env_fn_wrapper):
19 | parent_remote.close()
20 | env = env_fn_wrapper.x()
21 | while True:
22 | cmd, data = remote.recv()
23 | if cmd == 'step':
24 | ob, reward, done, info = env.step(data)
25 | if done:
26 | ob = env.reset()
27 | remote.send((ob, reward, done, info))
28 | elif cmd == 'reset':
29 | ob = env.reset()
30 | remote.send(ob)
31 | elif cmd == 'reset_task':
32 | ob = env.reset_task()
33 | remote.send(ob)
34 | elif cmd == 'close':
35 | remote.close()
36 | break
37 | elif cmd == 'get_spaces':
38 | remote.send((env.action_space, env.observation_space))
39 | elif cmd == 'render':
40 | env.render()
41 | else:
42 | raise NotImplementedError
43 |
44 |
45 | class CloudpickleWrapper(object):
46 | """
47 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
48 | """
49 | def __init__(self, x):
50 | self.x = x
51 | def __getstate__(self):
52 | import cloudpickle
53 | return cloudpickle.dumps(self.x)
54 | def __setstate__(self, ob):
55 | import pickle
56 | self.x = pickle.loads(ob)
57 |
58 |
59 | class RenderSubprocVecEnv(VecEnv):
60 | def __init__(self, env_fns, render_interval):
61 | """ Minor addition to SubprocVecEnv, automatically renders environments
62 |
63 | envs: list of gym environments to run in subprocesses
64 | """
65 | self.closed = False
66 | nenvs = len(env_fns)
67 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
68 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
69 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
70 | for p in self.ps:
71 | p.daemon = True # if the main process crashes, we should not cause things to hang
72 | p.start()
73 | for remote in self.work_remotes:
74 | remote.close()
75 |
76 | self.remotes[0].send(('get_spaces', None))
77 | self.action_space, self.observation_space = self.remotes[0].recv()
78 |
79 | self.render_interval = render_interval
80 | self.render_timer = 0
81 |
82 | def step(self, actions):
83 | for remote, action in zip(self.remotes, actions):
84 | remote.send(('step', action))
85 | results = [remote.recv() for remote in self.remotes]
86 | obs, rews, dones, infos = zip(*results)
87 |
88 | self.render_timer += 1
89 | if self.render_timer == self.render_interval:
90 | for remote in self.remotes:
91 | remote.send(('render', None))
92 | self.render_timer = 0
93 |
94 | return np.stack(obs), np.stack(rews), np.stack(dones), infos
95 |
96 | def reset(self):
97 | for remote in self.remotes:
98 | remote.send(('reset', None))
99 | return np.stack([remote.recv() for remote in self.remotes])
100 |
101 | def reset_task(self):
102 | for remote in self.remotes:
103 | remote.send(('reset_task', None))
104 | return np.stack([remote.recv() for remote in self.remotes])
105 |
106 | def close(self):
107 | if self.closed:
108 | return
109 |
110 | for remote in self.remotes:
111 | remote.send(('close', None))
112 | for p in self.ps:
113 | p.join()
114 | self.closed = True
115 |
116 | @property
117 | def num_envs(self):
118 | return len(self.remotes)
119 |
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.optim as optim
4 | from atari_wrappers import FrameStack
5 | from subproc_vec_env import SubprocVecEnv
6 | from vec_frame_stack import VecFrameStack
7 | import multiprocessing
8 |
9 | from envs import make_env, RenderSubprocVecEnv
10 | from models import AtariCNN
11 | from ppo import PPO
12 | from utils import set_seed, cuda_if
13 |
14 | if __name__ == "__main__":
15 | multiprocessing.set_start_method("forkserver")
16 |
17 | parser = argparse.ArgumentParser(description='PPO', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 | parser.add_argument('env_id', type=str, help='Gym environment id')
19 | parser.add_argument('--arch', type=str, default='cnn', help='policy architecture, {lstm, cnn}')
20 | parser.add_argument('--num-workers', type=int, default=8, help='number of parallel actors')
21 | parser.add_argument('--opt-epochs', type=int, default=3, help='optimization epochs between environment interaction')
22 | parser.add_argument('--total-steps', type=int, default=int(10e6), help='total number of environment steps to take')
23 | parser.add_argument('--worker-steps', type=int, default=128, help='steps per worker between optimization rounds')
24 | parser.add_argument('--sequence-steps', type=int, default=32, help='steps per sequence (for backprop through time)')
25 | parser.add_argument('--minibatch-steps', type=int, default=256, help='steps per optimization minibatch')
26 | parser.add_argument('--lr', type=float, default=2.5e-4, help='initial learning rate')
27 | parser.add_argument('--lr-func', type=str, default='linear', help='learning rate schedule function, {linear, constant}')
28 | parser.add_argument('--clip', type=float, default=.1, help='initial probability ratio clipping range')
29 | parser.add_argument('--clip-func', type=str, default='linear', help='clip range schedule function, {linear, constant}')
30 | parser.add_argument('--gamma', type=float, default=.99, help='discount factor')
31 | parser.add_argument('--lambd', type=float, default=.95, help='GAE lambda parameter')
32 | parser.add_argument('--value-coef', type=float, default=1., help='value loss coeffecient')
33 | parser.add_argument('--entropy-coef', type=float, default=.01, help='entropy loss coeffecient')
34 | parser.add_argument('--max-grad-norm', type=float, default=.5, help='grad norm to clip at')
35 | parser.add_argument('--no-cuda', action='store_true', help='disable CUDA acceleration')
36 | parser.add_argument('--render', action='store_true', help='render training environments')
37 | parser.add_argument('--render-interval', type=int, default=4, help='steps between environment renders')
38 | parser.add_argument('--plot-reward', action='store_true', help='plot episode reward')
39 | parser.add_argument('--plot-points', type=int, default=20, help='number of plot points (groups with mean, std)')
40 | parser.add_argument('--plot-path', type=str, default='ep_reward.png', help='path to save reward plot to')
41 | parser.add_argument('--seed', type=int, default=0, help='random seed')
42 | args = parser.parse_args()
43 |
44 | set_seed(args.seed)
45 |
46 | cuda = torch.cuda.is_available() and not args.no_cuda
47 |
48 | env_fns = []
49 | for rank in range(args.num_workers):
50 |         env_fns.append(lambda rank=rank: make_env(args.env_id, rank, args.seed + rank))  # bind rank now; a bare lambda would give every worker the final rank
51 | if args.render:
52 | venv = RenderSubprocVecEnv(env_fns, args.render_interval)
53 | else:
54 | venv = SubprocVecEnv(env_fns)
55 | venv = VecFrameStack(venv, 4)
56 |
57 | test_env = make_env(args.env_id, 0, args.seed)
58 | test_env = FrameStack(test_env, 4)
59 |
60 | policy = {'cnn': AtariCNN}[args.arch](venv.action_space.n)
61 | policy = cuda_if(policy, cuda)
62 |
63 | optimizer = optim.Adam(policy.parameters())
64 |
65 | if args.lr_func == 'linear':
66 | lr_func = lambda a: args.lr * (1. - a)
67 | elif args.lr_func == 'constant':
68 | lr_func = lambda a: args.lr
69 |
70 | if args.clip_func == 'linear':
71 | clip_func = lambda a: args.clip * (1. - a)
72 | elif args.clip_func == 'constant':
73 | clip_func = lambda a: args.clip
74 |
75 | algorithm = PPO(policy, venv, test_env, optimizer,
76 | lr_func=lr_func, clip_func=clip_func, gamma=args.gamma, lambd=args.lambd,
77 | worker_steps=args.worker_steps, sequence_steps=args.sequence_steps,
78 | minibatch_steps=args.minibatch_steps,
79 | value_coef=args.value_coef, entropy_coef=args.entropy_coef,
80 | max_grad_norm=args.max_grad_norm,
81 | cuda=cuda,
82 | plot_reward=args.plot_reward, plot_points=args.plot_points, plot_path=args.plot_path, env_name = args.env_id)
83 | algorithm.run(args.total_steps)
84 |
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 |
6 |
7 | def ortho_weights(shape, scale=1.):
8 | """ PyTorch port of ortho_init from baselines.a2c.utils """
9 | shape = tuple(shape)
10 |
11 | if len(shape) == 2:
12 | flat_shape = shape[1], shape[0]
13 | elif len(shape) == 4:
14 | flat_shape = (np.prod(shape[1:]), shape[0])
15 | else:
16 | raise NotImplementedError
17 |
18 | a = np.random.normal(0., 1., flat_shape)
19 | u, _, v = np.linalg.svd(a, full_matrices=False)
20 | q = u if u.shape == flat_shape else v
21 | q = q.transpose().copy().reshape(shape)
22 |
23 | if len(shape) == 2:
24 | return torch.from_numpy((scale * q).astype(np.float32))
25 | if len(shape) == 4:
26 | return torch.from_numpy((scale * q[:, :shape[1], :shape[2]]).astype(np.float32))
27 |
28 |
29 | def atari_initializer(module):
30 | """ Parameter initializer for Atari models
31 |
32 | Initializes Linear, Conv2d, and LSTM weights.
33 | """
34 | classname = module.__class__.__name__
35 |
36 | if classname == 'Linear':
37 | module.weight.data = ortho_weights(module.weight.data.size(), scale=np.sqrt(2.))
38 | module.bias.data.zero_()
39 |
40 | elif classname == 'Conv2d':
41 | module.weight.data = ortho_weights(module.weight.data.size(), scale=np.sqrt(2.))
42 | module.bias.data.zero_()
43 |
44 | elif classname == 'LSTM':
45 | for name, param in module.named_parameters():
46 | if 'weight_ih' in name:
47 | param.data = ortho_weights(param.data.size(), scale=1.)
48 | if 'weight_hh' in name:
49 | param.data = ortho_weights(param.data.size(), scale=1.)
50 | if 'bias' in name:
51 | param.data.zero_()
52 |
53 |
54 | class AtariCNN(nn.Module):
55 | def __init__(self, num_actions):
56 | """ Basic convolutional actor-critic network for Atari 2600 games
57 |
58 | Equivalent to the network in the original DQN paper.
59 |
60 | Args:
61 | num_actions (int): the number of available discrete actions
62 | """
63 | super().__init__()
64 |
65 | self.conv = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4),
66 | nn.ReLU(inplace=True),
67 | nn.Conv2d(32, 64, 4, stride=2),
68 | nn.ReLU(inplace=True),
69 | nn.Conv2d(64, 64, 3, stride=1),
70 | nn.ReLU(inplace=True))
71 |
72 | self.fc = nn.Sequential(nn.Linear(64 * 7 * 7, 512),
73 | nn.ReLU(inplace=True))
74 |
75 | self.pi = nn.Linear(512, num_actions)
76 | self.v = nn.Linear(512, 1)
77 |
78 | self.num_actions = num_actions
79 |
80 | # parameter initialization
81 | self.apply(atari_initializer)
82 | self.pi.weight.data = ortho_weights(self.pi.weight.size(), scale=.01)
83 | self.v.weight.data = ortho_weights(self.v.weight.size())
84 |
85 | def forward(self, conv_in):
86 | """ Module forward pass
87 |
88 | Args:
89 | conv_in (Variable): convolutional input, shaped [N x 4 x 84 x 84]
90 |
91 | Returns:
92 | pi (Variable): action probability logits, shaped [N x self.num_actions]
93 | v (Variable): value predictions, shaped [N x 1]
94 | """
95 | N = conv_in.size()[0]
96 |
97 | conv_out = self.conv(conv_in).view(N, 64 * 7 * 7)
98 |
99 | fc_out = self.fc(conv_out)
100 |
101 | pi_out = self.pi(fc_out)
102 | v_out = self.v(fc_out)
103 |
104 | return pi_out, v_out
105 |
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/ppo.py:
--------------------------------------------------------------------------------
1 | import random
2 | import copy
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as Fnn
7 | import matplotlib
8 | matplotlib.use('Agg')
9 | import matplotlib.pyplot as plt
10 | from torch.autograd import Variable
11 |
12 | from utils import gae, cuda_if, mean_std_groups, set_lr
13 |
14 |
15 | class PPO:
16 | def __init__(self, policy, venv, test_env, optimizer,
17 | lr_func=None, clip_func=None, gamma=.99, lambd=.95,
18 | worker_steps=128, sequence_steps=32, minibatch_steps=256,
19 | opt_epochs=3, value_coef=1., entropy_coef=.01, max_grad_norm=.5,
20 | cuda=False, plot_reward=False, plot_points=20, plot_path='ep_reward.png',
21 | test_repeat_max=100, env_name = ""):
22 | """ Proximal Policy Optimization algorithm class
23 |
24 | Evaluates a policy over a vectorized environment and
25 | optimizes over policy, value, entropy objectives.
26 |
27 | Assumes discrete action space.
28 |
29 | Args:
30 | policy (nn.Module): the policy to optimize
31 | venv (vec_env): the vectorized environment to use
32 | test_env (Env): the environment to use for policy testing
33 | optimizer (optim.Optimizer): the optimizer to use
34 | clip (float): probability ratio clipping range
35 | gamma (float): discount factor
36 | lambd (float): GAE lambda parameter
37 | worker_steps (int): steps per worker between optimization rounds
38 | sequence_steps (int): steps per sequence (for backprop through time)
39 |             minibatch_steps (int): steps per optimization minibatch
40 | """
41 | self.policy = policy
42 | self.policy_old = copy.deepcopy(policy)
43 | self.venv = venv
44 | self.test_env = test_env
45 | self.optimizer = optimizer
46 |
47 | self.env_name = env_name
48 |
49 | self.lr_func = lr_func
50 | self.clip_func = clip_func
51 |
52 | self.num_workers = venv.num_envs
53 | self.worker_steps = worker_steps
54 | self.sequence_steps = sequence_steps
55 | self.minibatch_steps = minibatch_steps
56 |
57 | self.opt_epochs = opt_epochs
58 | self.gamma = gamma
59 | self.lambd = lambd
60 | self.value_coef = value_coef
61 | self.entropy_coef = entropy_coef
62 | self.max_grad_norm = max_grad_norm
63 | self.cuda = cuda
64 |
65 | self.plot_reward = plot_reward
66 | self.plot_points = plot_points
67 | self.plot_path = plot_path
68 | self.ep_reward = np.zeros(self.num_workers)
69 | self.reward_histr = []
70 | self.steps_histr = []
71 |
72 | self.objective = PPOObjective()
73 |
74 | self.last_ob = self.venv.reset()
75 |
76 | self.taken_steps = 0
77 |
78 | self.test_repeat_max = test_repeat_max
79 |
80 | def run(self, total_steps):
81 | """ Runs PPO
82 |
83 | Args:
84 | total_steps (int): total number of environment steps to run for
85 | """
86 | N = self.num_workers
87 | T = self.worker_steps
88 | E = self.opt_epochs
89 | A = self.venv.action_space.n
90 |
91 | while self.taken_steps < total_steps:
92 | progress = self.taken_steps / total_steps
93 |
94 | obs, rewards, masks, actions, steps = self.interact()
95 | ob_shape = obs.size()[2:]
96 |
97 | ep_reward = self.test()
98 | self.reward_histr.append(ep_reward)
99 | self.steps_histr.append(self.taken_steps)
100 |
101 | # statistic logic
102 | group_size = len(self.steps_histr) // self.plot_points
103 | if self.plot_reward and len(self.steps_histr) % (self.plot_points * 10) == 0 and group_size >= 10:
104 | x_means, _, y_means, y_stds = \
105 | mean_std_groups(np.array(self.steps_histr), np.array(self.reward_histr), group_size)
106 | fig = plt.figure()
107 | fig.set_size_inches(8, 6)
108 | plt.ticklabel_format(axis='x', style='sci', scilimits=(-2, 6))
109 | plt.errorbar(x_means, y_means, yerr=y_stds, ecolor='xkcd:blue', fmt='xkcd:black', capsize=5, elinewidth=1.5, mew=1.5, linewidth=1.5)
110 | plt.title('Training progress')
111 | plt.xlabel('Total steps')
112 | plt.ylabel('Episode reward')
113 | plt.savefig(self.plot_path, dpi=200)
114 | plt.clf()
115 | plt.close()
116 | plot_timer = 0
117 |
118 | # TEMP upgrade to support recurrence
119 |
120 | # compute advantages, returns with GAE
121 | obs_ = obs.view(((T + 1) * N,) + ob_shape)
122 | obs_ = Variable(obs_)
123 | _, values = self.policy(obs_)
124 | values = values.view(T + 1, N, 1)
125 | advantages, returns = gae(rewards, masks, values, self.gamma, self.lambd)
126 |
127 | self.policy_old.load_state_dict(self.policy.state_dict())
128 | for e in range(E):
129 | self.policy.zero_grad()
130 |
131 | MB = steps // self.minibatch_steps
132 |
133 | b_obs = Variable(obs[:T].view((steps,) + ob_shape))
134 | b_rewards = Variable(rewards.view(steps, 1))
135 | b_masks = Variable(masks.view(steps, 1))
136 | b_actions = Variable(actions.view(steps, 1))
137 | b_advantages = Variable(advantages.view(steps, 1))
138 | b_returns = Variable(returns.view(steps, 1))
139 |
140 | b_inds = np.arange(steps)
141 | np.random.shuffle(b_inds)
142 |
143 | for start in range(0, steps, self.minibatch_steps):
144 | mb_inds = b_inds[start:start + self.minibatch_steps]
145 | mb_inds = cuda_if(torch.from_numpy(mb_inds).long(), self.cuda)
146 | mb_obs, mb_rewards, mb_masks, mb_actions, mb_advantages, mb_returns = \
147 | [arr[mb_inds] for arr in [b_obs, b_rewards, b_masks, b_actions, b_advantages, b_returns]]
148 |
149 | mb_pis, mb_vs = self.policy(mb_obs)
150 | mb_pi_olds, mb_v_olds = self.policy_old(mb_obs)
151 | mb_pi_olds, mb_v_olds = mb_pi_olds.detach(), mb_v_olds.detach()
152 |
153 | losses = self.objective(self.clip_func(progress),
154 | mb_pis, mb_vs, mb_pi_olds, mb_v_olds,
155 | mb_actions, mb_advantages, mb_returns)
156 | policy_loss, value_loss, entropy_loss = losses
157 | loss = policy_loss + value_loss * self.value_coef + entropy_loss * self.entropy_coef
158 |
159 | set_lr(self.optimizer, self.lr_func(progress))
160 | self.optimizer.zero_grad()
161 | loss.backward()
162 | torch.nn.utils.clip_grad_norm(self.policy.parameters(), self.max_grad_norm)
163 | self.optimizer.step()
164 |
165 | self.taken_steps += steps
166 | print(self.taken_steps)
167 |
168 | torch.save({'policy': self.policy.state_dict()}, "./save/PPO_" + self.env_name + ".pt")
169 |
170 | def interact(self):
171 | """ Interacts with the environment
172 |
173 | Returns:
174 |             obs (FloatTensor): observations shaped [T + 1 x N x ...]
175 | rewards (FloatTensor): rewards shaped [T x N x 1]
176 | masks (FloatTensor): continuation masks shaped [T x N x 1]
177 | zero at done timesteps, one otherwise
178 | actions (LongTensor): discrete actions shaped [T x N x 1]
179 | steps (int): total number of steps taken
180 | """
181 | N = self.num_workers
182 | T = self.worker_steps
183 |
184 | # TEMP needs to be generalized, does conv-specific transpose for PyTorch
185 | obs = torch.zeros(T + 1, N, 4, 84, 84)
186 | obs = cuda_if(obs, self.cuda)
187 | rewards = torch.zeros(T, N, 1)
188 | rewards = cuda_if(rewards, self.cuda)
189 | masks = torch.zeros(T, N, 1)
190 | masks = cuda_if(masks, self.cuda)
191 | actions = torch.zeros(T, N, 1).long()
192 | actions = cuda_if(actions, self.cuda)
193 |
194 | for t in range(T):
195 | # interaction logic
196 | ob = torch.from_numpy(self.last_ob.transpose((0, 3, 1, 2))).float()
197 | ob = Variable(ob / 255.)
198 | ob = cuda_if(ob, self.cuda)
199 | obs[t] = ob.data
200 |
201 | pi, v = self.policy(ob)
202 | u = cuda_if(torch.rand(pi.size()), self.cuda)
203 |             _, action = torch.max(pi.data - (-u.log()).log(), 1)  # Gumbel-max trick: samples an action from softmax(pi)
204 | action = action.unsqueeze(1)
205 | actions[t] = action
206 |
207 | self.last_ob, reward, done, _ = self.venv.step(action.cpu().numpy())
208 | reward = torch.from_numpy(reward).unsqueeze(1)
209 | rewards[t] = torch.clamp(reward, min=-1., max=1.)
210 | masks[t] = mask = torch.from_numpy((1. - done)).unsqueeze(1)
211 |
212 | ob = torch.from_numpy(self.last_ob.transpose((0, 3, 1, 2))).float()
213 | ob = Variable(ob / 255.)
214 | ob = cuda_if(ob, self.cuda)
215 | obs[T] = ob.data
216 |
217 | steps = N * T
218 |
219 | return obs, rewards, masks, actions, steps
220 |
221 | def test(self):
222 | ob = self.test_env.reset()
223 | done = False
224 | ep_reward = 0
225 | last_action = np.array([-1])
226 | action_repeat = 0
227 |
228 | while not done:
229 | ob = np.array(ob)
230 | ob = torch.from_numpy(ob.transpose((2, 0, 1))).float().unsqueeze(0)
231 | ob = Variable(ob / 255., volatile=True)
232 | ob = cuda_if(ob, self.cuda)
233 |
234 | pi, v = self.policy(ob)
235 | _, action = torch.max(pi, dim=1)
236 |
237 | # abort after {self.test_repeat_max} discrete action repeats
238 | if action.data[0] == last_action.data[0]:
239 | action_repeat += 1
240 | if action_repeat == self.test_repeat_max:
241 | return ep_reward
242 | else:
243 | action_repeat = 0
244 | last_action = action
245 |
246 | ob, reward, done, _ = self.test_env.step(action.data.cpu().numpy())
247 |
248 | ep_reward += reward
249 |
250 | return ep_reward
251 |
252 |
253 | class PPOObjective(nn.Module):
254 | def forward(self, clip, pi, v, pi_old, v_old, action, advantage, returns):
255 | """ Computes PPO objectives
256 |
257 | Assumes discrete action space.
258 |
259 | Args:
260 | clip (float): probability ratio clipping range
261 | pi (Variable): discrete action logits, shaped [N x num_actions]
262 | v (Variable): value predictions, shaped [N x 1]
263 | pi_old (Variable): old discrete action logits, shaped [N x num_actions]
264 | v_old (Variable): old value predictions, shaped [N x 1]
265 | action (Variable): discrete actions, shaped [N x 1]
266 | advantage (Variable): action advantages, shaped [N x 1]
267 | returns (Variable): discounted returns, shaped [N x 1]
268 |
269 | Returns:
270 | policy_loss (Variable): policy surrogate loss, shaped [1]
271 | value_loss (Variable): value loss, shaped [1]
272 | entropy_loss (Variable): entropy loss, shaped [1]
273 | """
274 | prob = Fnn.softmax(pi)
275 | log_prob = Fnn.log_softmax(pi)
276 | action_prob = prob.gather(1, action)
277 |
278 | prob_old = Fnn.softmax(pi_old)
279 | action_prob_old = prob_old.gather(1, action)
280 |
281 | ratio = action_prob / (action_prob_old + 1e-10)
282 |
283 | advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-5)
284 |
285 | surr1 = ratio * advantage
286 | surr2 = torch.clamp(ratio, min=1. - clip, max=1. + clip) * advantage
287 |
288 | policy_loss = -torch.min(surr1, surr2).mean()
289 | value_loss = (.5 * (v - returns) ** 2.).mean()
290 | entropy_loss = (prob * log_prob).sum(1).mean()
291 |
292 | return policy_loss, value_loss, entropy_loss
293 |
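
For reference, the three losses returned by `PPOObjective.forward` correspond to the standard PPO terms. In the usual notation, with $\hat{A}_t$ the batch-normalized advantage, $R_t$ the discounted return, and $\epsilon$ the clip range:

$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}, \qquad L^{\text{policy}} = -\,\mathbb{E}_t\!\left[\min\!\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\right]$$

$$L^{\text{value}} = \mathbb{E}_t\!\left[\tfrac{1}{2}\big(V_\theta(s_t)-R_t\big)^2\right], \qquad L^{\text{entropy}} = \mathbb{E}_t\!\left[\sum_a \pi_\theta(a \mid s_t)\,\log\pi_\theta(a \mid s_t)\right]$$

`PPO.run` then minimizes $L^{\text{policy}} + \texttt{value\_coef}\cdot L^{\text{value}} + \texttt{entropy\_coef}\cdot L^{\text{entropy}}$, with advantages normalized per minibatch before the surrogate is formed.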
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/save/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Utils/Atari_PPO_training/save/.gitkeep
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/subproc_vec_env.py:
--------------------------------------------------------------------------------
1 | import multiprocessing as mp
2 |
3 | import numpy as np
4 | from vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
5 |
6 |
7 | def worker(remote, parent_remote, env_fn_wrapper):
8 | parent_remote.close()
9 | env = env_fn_wrapper.x()
10 | try:
11 | while True:
12 | cmd, data = remote.recv()
13 | if cmd == 'step':
14 | ob, reward, done, info = env.step(data)
15 | if done:
16 | ob = env.reset()
17 | remote.send((ob, reward, done, info))
18 | elif cmd == 'reset':
19 | ob = env.reset()
20 | remote.send(ob)
21 | elif cmd == 'render':
22 | remote.send(env.render(mode='rgb_array'))
23 | elif cmd == 'close':
24 | remote.close()
25 | break
26 | elif cmd == 'get_spaces_spec':
27 | remote.send((env.observation_space, env.action_space, env.spec))
28 | else:
29 | raise NotImplementedError
30 | except KeyboardInterrupt:
31 | print('SubprocVecEnv worker: got KeyboardInterrupt')
32 | finally:
33 | env.close()
34 |
35 |
36 | class SubprocVecEnv(VecEnv):
37 | """
38 |     VecEnv that runs multiple environments in parallel in subprocesses and communicates with them via pipes.
39 | Recommended to use when num_envs > 1 and step() can be a bottleneck.
40 | """
41 | def __init__(self, env_fns, spaces=None, context='spawn'):
42 | """
43 | Arguments:
44 |
45 | env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
46 | """
47 | self.waiting = False
48 | self.closed = False
49 | nenvs = len(env_fns)
50 | ctx = mp.get_context(context)
51 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
52 | self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
53 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
54 | for p in self.ps:
55 | p.daemon = True # if the main process crashes, we should not cause things to hang
56 | with clear_mpi_env_vars():
57 | p.start()
58 | for remote in self.work_remotes:
59 | remote.close()
60 |
61 | self.remotes[0].send(('get_spaces_spec', None))
62 | observation_space, action_space, self.spec = self.remotes[0].recv()
63 | self.viewer = None
64 | VecEnv.__init__(self, len(env_fns), observation_space, action_space)
65 |
66 | def step_async(self, actions):
67 | self._assert_not_closed()
68 | for remote, action in zip(self.remotes, actions):
69 | remote.send(('step', action))
70 | self.waiting = True
71 |
72 | def step_wait(self):
73 | self._assert_not_closed()
74 | results = [remote.recv() for remote in self.remotes]
75 | self.waiting = False
76 | obs, rews, dones, infos = zip(*results)
77 | return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos
78 |
79 | def reset(self):
80 | self._assert_not_closed()
81 | for remote in self.remotes:
82 | remote.send(('reset', None))
83 | return _flatten_obs([remote.recv() for remote in self.remotes])
84 |
85 | def close_extras(self):
86 | self.closed = True
87 | if self.waiting:
88 | for remote in self.remotes:
89 | remote.recv()
90 | for remote in self.remotes:
91 | remote.send(('close', None))
92 | for p in self.ps:
93 | p.join()
94 |
95 | def get_images(self):
96 | self._assert_not_closed()
97 | for pipe in self.remotes:
98 | pipe.send(('render', None))
99 | imgs = [pipe.recv() for pipe in self.remotes]
100 | return imgs
101 |
102 | def _assert_not_closed(self):
103 | assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
104 |
105 | def __del__(self):
106 | if not self.closed:
107 | self.close()
108 |
109 | def _flatten_obs(obs):
110 | assert isinstance(obs, (list, tuple))
111 | assert len(obs) > 0
112 |
113 | if isinstance(obs[0], dict):
114 | keys = obs[0].keys()
115 | return {k: np.stack([o[k] for o in obs]) for k in keys}
116 | else:
117 | return np.stack(obs)
118 |
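
A minimal usage sketch of `SubprocVecEnv` (here `CartPole-v1` is only an illustrative stand-in; the repository builds its Atari environments through `envs.make_env`):

````
import gym
from subproc_vec_env import SubprocVecEnv

if __name__ == "__main__":
    # Four copies of the environment, each stepping in its own subprocess.
    env_fns = [lambda: gym.make("CartPole-v1") for _ in range(4)]
    venv = SubprocVecEnv(env_fns)

    obs = venv.reset()                                    # stacked observations, shape (4, ...)
    actions = [venv.action_space.sample() for _ in range(venv.num_envs)]
    obs, rews, dones, infos = venv.step(actions)          # step() = step_async() + step_wait()
    venv.close()
````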
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 | from atari_wrappers import FrameStack
6 | import multiprocessing
7 | 
8 | from envs import make_env
9 | from models import AtariCNN
10 | from ppo import PPO
11 | from utils import set_seed, cuda_if
12 | 
13 |
14 | if __name__ == "__main__":
15 | multiprocessing.set_start_method("forkserver")
16 |
17 | parser = argparse.ArgumentParser(description='PPO', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 | parser.add_argument('env_id', type=str, help='Gym environment id')
19 | parser.add_argument('--arch', type=str, default='cnn', help='policy architecture, {lstm, cnn}')
20 | parser.add_argument('--num-workers', type=int, default=8, help='number of parallel actors')
21 | parser.add_argument('--opt-epochs', type=int, default=3, help='optimization epochs between environment interaction')
22 | parser.add_argument('--total-steps', type=int, default=int(10e6), help='total number of environment steps to take')
23 | parser.add_argument('--worker-steps', type=int, default=128, help='steps per worker between optimization rounds')
24 | parser.add_argument('--sequence-steps', type=int, default=32, help='steps per sequence (for backprop through time)')
25 | parser.add_argument('--minibatch-steps', type=int, default=256, help='steps per optimization minibatch')
26 | parser.add_argument('--lr', type=float, default=2.5e-4, help='initial learning rate')
27 | parser.add_argument('--lr-func', type=str, default='linear', help='learning rate schedule function, {linear, constant}')
28 | parser.add_argument('--clip', type=float, default=.1, help='initial probability ratio clipping range')
29 | parser.add_argument('--clip-func', type=str, default='linear', help='clip range schedule function, {linear, constant}')
30 | parser.add_argument('--gamma', type=float, default=.99, help='discount factor')
31 | parser.add_argument('--lambd', type=float, default=.95, help='GAE lambda parameter')
32 | parser.add_argument('--value-coef', type=float, default=1., help='value loss coeffecient')
33 | parser.add_argument('--entropy-coef', type=float, default=.01, help='entropy loss coeffecient')
34 | parser.add_argument('--max-grad-norm', type=float, default=.5, help='grad norm to clip at')
35 | parser.add_argument('--no-cuda', action='store_true', help='disable CUDA acceleration')
36 | parser.add_argument('--render', action='store_true', help='render training environments')
37 | parser.add_argument('--render-interval', type=int, default=4, help='steps between environment renders')
38 | parser.add_argument('--plot-reward', action='store_true', help='plot episode reward')
39 | parser.add_argument('--plot-points', type=int, default=20, help='number of plot points (groups with mean, std)')
40 | parser.add_argument('--plot-path', type=str, default='ep_reward.png', help='path to save reward plot to')
41 | parser.add_argument('--seed', type=int, default=0, help='random seed')
42 | args = parser.parse_args()
43 |
44 | set_seed(args.seed)
45 |
46 | cuda = torch.cuda.is_available() and not args.no_cuda
47 |     test_repeat_max = 100  # abort evaluation after this many consecutive repeats of one action
48 | test_env = make_env(args.env_id, 0, args.seed)
49 | test_env = FrameStack(test_env, 4)
50 |
51 |     policy = {'cnn': AtariCNN}[args.arch](test_env.action_space.n)
52 |     checkpoint = torch.load("./save/PPO_" + args.env_id + ".pt")
53 |     policy.load_state_dict(checkpoint["policy"])
54 | policy = cuda_if(policy, cuda)
55 |
56 |     ob = test_env.reset()
57 | done = False
58 | ep_reward = 0
59 | last_action = np.array([-1])
60 | action_repeat = 0
61 |
62 | while not done:
63 | ob = np.array(ob)
64 | ob = torch.from_numpy(ob.transpose((2, 0, 1))).float().unsqueeze(0)
65 | ob = Variable(ob / 255., volatile=True)
66 |         ob = cuda_if(ob, cuda)
67 |
68 | pi, v = policy(ob)
69 | _, action = torch.max(pi, dim=1)
70 |
71 |         # abort after {test_repeat_max} discrete action repeats
72 | if action.data[0] == last_action.data[0]:
73 | action_repeat += 1
74 |             if action_repeat == test_repeat_max:
75 |                 break
76 | else:
77 | action_repeat = 0
78 | last_action = action
79 |
80 |         ob, reward, done, _ = test_env.step(action.data.cpu().numpy())
81 |
82 | ep_reward += reward
83 |
84 | print(ep_reward)
85 |
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 |
5 | def set_seed(seed):
6 | random.seed(seed)
7 | np.random.seed(seed)
8 | torch.manual_seed(seed)
9 |
10 | def cuda_if(torch_object, cuda):
11 | return torch_object.cuda() if cuda else torch_object
12 |
13 | def gae(rewards, masks, values, gamma, lambd):
14 | """ Generalized Advantage Estimation
15 |
16 | Args:
17 | rewards (FloatTensor): rewards shaped [T x N x 1]
18 | masks (FloatTensor): continuation masks shaped [T x N x 1]
19 | zero at done timesteps, one otherwise
20 | values (Variable): value predictions shaped [(T + 1) x N x 1]
21 | gamma (float): discount factor
22 | lambd (float): GAE lambda parameter
23 |
24 | Returns:
25 | advantages (FloatTensor): advantages shaped [T x N x 1]
26 | returns (FloatTensor): returns shaped [T x N x 1]
27 | """
28 | T, N, _ = rewards.size()
29 |
30 | cuda = rewards.is_cuda
31 |
32 | advantages = torch.zeros(T, N, 1)
33 | advantages = cuda_if(advantages, cuda)
34 | advantage_t = torch.zeros(N, 1)
35 | advantage_t = cuda_if(advantage_t, cuda)
36 |
37 | for t in reversed(range(T)):
38 | delta = rewards[t] + values[t + 1].data * gamma * masks[t] - values[t].data
39 | advantage_t = delta + advantage_t * gamma * lambd * masks[t]
40 | advantages[t] = advantage_t
41 |
42 | returns = values[:T].data + advantages
43 |
44 | return advantages, returns
45 |
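
The backward loop above is the standard GAE recursion. With $m_t$ the continuation mask ($0$ at done steps), it computes

$$\delta_t = r_t + \gamma\,m_t\,V(s_{t+1}) - V(s_t), \qquad \hat{A}_t = \delta_t + \gamma\lambda\,m_t\,\hat{A}_{t+1}, \qquad R_t = \hat{A}_t + V(s_t).$$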
46 |
47 | def mean_std_groups(x, y, group_size):
48 | num_groups = int(len(x) / group_size)
49 |
50 | x, x_tail = x[:group_size * num_groups], x[group_size * num_groups:]
51 | x = x.reshape((num_groups, group_size))
52 |
53 | y, y_tail = y[:group_size * num_groups], y[group_size * num_groups:]
54 | y = y.reshape((num_groups, group_size))
55 |
56 | x_means = x.mean(axis=1)
57 | x_stds = x.std(axis=1)
58 |
59 | if len(x_tail) > 0:
60 | x_means = np.concatenate([x_means, x_tail.mean(axis=0, keepdims=True)])
61 | x_stds = np.concatenate([x_stds, x_tail.std(axis=0, keepdims=True)])
62 |
63 | y_means = y.mean(axis=1)
64 | y_stds = y.std(axis=1)
65 |
66 | if len(y_tail) > 0:
67 | y_means = np.concatenate([y_means, y_tail.mean(axis=0, keepdims=True)])
68 | y_stds = np.concatenate([y_stds, y_tail.std(axis=0, keepdims=True)])
69 |
70 | return x_means, x_stds, y_means, y_stds
71 |
72 | def set_lr(optimizer, lr):
73 | for param_group in optimizer.param_groups:
74 | param_group['lr'] = lr
75 |
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/vec_env.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import os
3 | import numpy as np
4 | from abc import ABC, abstractmethod
5 | def tile_images(img_nhwc):
6 | """
7 | Tile N images into one big PxQ image
8 | (P,Q) are chosen to be as close as possible, and if N
9 | is square, then P=Q.
10 |
11 | input: img_nhwc, list or array of images, ndim=4 once turned into array
12 | n = batch index, h = height, w = width, c = channel
13 | returns:
14 | bigim_HWc, ndarray with ndim=3
15 | """
16 | img_nhwc = np.asarray(img_nhwc)
17 | N, h, w, c = img_nhwc.shape
18 | H = int(np.ceil(np.sqrt(N)))
19 | W = int(np.ceil(float(N)/H))
20 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)])
21 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c)
22 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)
23 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c)
24 | return img_Hh_Ww_c
25 |
26 | class AlreadySteppingError(Exception):
27 | """
28 | Raised when an asynchronous step is running while
29 | step_async() is called again.
30 | """
31 |
32 | def __init__(self):
33 | msg = 'already running an async step'
34 | Exception.__init__(self, msg)
35 |
36 |
37 | class NotSteppingError(Exception):
38 | """
39 | Raised when an asynchronous step is not running but
40 | step_wait() is called.
41 | """
42 |
43 | def __init__(self):
44 | msg = 'not running an async step'
45 | Exception.__init__(self, msg)
46 |
47 |
48 | class VecEnv(ABC):
49 | """
50 | An abstract asynchronous, vectorized environment.
51 | Used to batch data from multiple copies of an environment, so that
52 |     each observation becomes a batch of observations, and the expected action is a batch of actions to
53 | be applied per-environment.
54 | """
55 | closed = False
56 | viewer = None
57 |
58 | metadata = {
59 | 'render.modes': ['human', 'rgb_array']
60 | }
61 |
62 | def __init__(self, num_envs, observation_space, action_space):
63 | self.num_envs = num_envs
64 | self.observation_space = observation_space
65 | self.action_space = action_space
66 |
67 | @abstractmethod
68 | def reset(self):
69 | """
70 | Reset all the environments and return an array of
71 | observations, or a dict of observation arrays.
72 |
73 | If step_async is still doing work, that work will
74 | be cancelled and step_wait() should not be called
75 | until step_async() is invoked again.
76 | """
77 | pass
78 |
79 | @abstractmethod
80 | def step_async(self, actions):
81 | """
82 | Tell all the environments to start taking a step
83 | with the given actions.
84 | Call step_wait() to get the results of the step.
85 |
86 | You should not call this if a step_async run is
87 | already pending.
88 | """
89 | pass
90 |
91 | @abstractmethod
92 | def step_wait(self):
93 | """
94 | Wait for the step taken with step_async().
95 |
96 | Returns (obs, rews, dones, infos):
97 | - obs: an array of observations, or a dict of
98 | arrays of observations.
99 | - rews: an array of rewards
100 | - dones: an array of "episode done" booleans
101 | - infos: a sequence of info objects
102 | """
103 | pass
104 |
105 | def close_extras(self):
106 | """
107 | Clean up the extra resources, beyond what's in this base class.
108 | Only runs when not self.closed.
109 | """
110 | pass
111 |
112 | def close(self):
113 | if self.closed:
114 | return
115 | if self.viewer is not None:
116 | self.viewer.close()
117 | self.close_extras()
118 | self.closed = True
119 |
120 | def step(self, actions):
121 | """
122 | Step the environments synchronously.
123 |
124 | This is available for backwards compatibility.
125 | """
126 | self.step_async(actions)
127 | return self.step_wait()
128 |
129 | def render(self, mode='human'):
130 | imgs = self.get_images()
131 | bigimg = tile_images(imgs)
132 | if mode == 'human':
133 | self.get_viewer().imshow(bigimg)
134 | return self.get_viewer().isopen
135 | elif mode == 'rgb_array':
136 | return bigimg
137 | else:
138 | raise NotImplementedError
139 |
140 | def get_images(self):
141 | """
142 | Return RGB images from each environment
143 | """
144 | raise NotImplementedError
145 |
146 | @property
147 | def unwrapped(self):
148 | if isinstance(self, VecEnvWrapper):
149 | return self.venv.unwrapped
150 | else:
151 | return self
152 |
153 | def get_viewer(self):
154 | if self.viewer is None:
155 | from gym.envs.classic_control import rendering
156 | self.viewer = rendering.SimpleImageViewer()
157 | return self.viewer
158 |
159 | class VecEnvWrapper(VecEnv):
160 | """
161 | An environment wrapper that applies to an entire batch
162 | of environments at once.
163 | """
164 |
165 | def __init__(self, venv, observation_space=None, action_space=None):
166 | self.venv = venv
167 | super().__init__(num_envs=venv.num_envs,
168 | observation_space=observation_space or venv.observation_space,
169 | action_space=action_space or venv.action_space)
170 |
171 | def step_async(self, actions):
172 | self.venv.step_async(actions)
173 |
174 | @abstractmethod
175 | def reset(self):
176 | pass
177 |
178 | @abstractmethod
179 | def step_wait(self):
180 | pass
181 |
182 | def close(self):
183 | return self.venv.close()
184 |
185 | def render(self, mode='human'):
186 | return self.venv.render(mode=mode)
187 |
188 | def get_images(self):
189 | return self.venv.get_images()
190 |
191 | def __getattr__(self, name):
192 | if name.startswith('_'):
193 | raise AttributeError("attempted to get missing private attribute '{}'".format(name))
194 | return getattr(self.venv, name)
195 |
196 | class VecEnvObservationWrapper(VecEnvWrapper):
197 | @abstractmethod
198 | def process(self, obs):
199 | pass
200 |
201 | def reset(self):
202 | obs = self.venv.reset()
203 | return self.process(obs)
204 |
205 | def step_wait(self):
206 | obs, rews, dones, infos = self.venv.step_wait()
207 | return self.process(obs), rews, dones, infos
208 |
209 | class CloudpickleWrapper(object):
210 | """
211 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
212 | """
213 |
214 | def __init__(self, x):
215 | self.x = x
216 |
217 | def __getstate__(self):
218 | import cloudpickle
219 | return cloudpickle.dumps(self.x)
220 |
221 | def __setstate__(self, ob):
222 | import pickle
223 | self.x = pickle.loads(ob)
224 |
225 |
226 | @contextlib.contextmanager
227 | def clear_mpi_env_vars():
228 | """
229 | from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang.
230 | This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing
231 | Processes.
232 | """
233 | removed_environment = {}
234 | for k, v in list(os.environ.items()):
235 | for prefix in ['OMPI_', 'PMI_']:
236 | if k.startswith(prefix):
237 | removed_environment[k] = v
238 | del os.environ[k]
239 | try:
240 | yield
241 | finally:
242 | os.environ.update(removed_environment)
243 |
--------------------------------------------------------------------------------
/Utils/Atari_PPO_training/vec_frame_stack.py:
--------------------------------------------------------------------------------
1 | from vec_env import VecEnvWrapper
2 | import numpy as np
3 | from gym import spaces
4 |
5 |
6 | class VecFrameStack(VecEnvWrapper):
7 | def __init__(self, venv, nstack):
8 | self.venv = venv
9 | self.nstack = nstack
10 | wos = venv.observation_space # wrapped ob space
11 | low = np.repeat(wos.low, self.nstack, axis=-1)
12 | high = np.repeat(wos.high, self.nstack, axis=-1)
13 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
14 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
15 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
16 |
17 | def step_wait(self):
18 | obs, rews, news, infos = self.venv.step_wait()
19 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1)
20 | for (i, new) in enumerate(news):
21 | if new:
22 | self.stackedobs[i] = 0
23 | self.stackedobs[..., -obs.shape[-1]:] = obs
24 | return self.stackedobs, rews, news, infos
25 |
26 | def reset(self):
27 | obs = self.venv.reset()
28 | self.stackedobs[...] = 0
29 | self.stackedobs[..., -obs.shape[-1]:] = obs
30 | return self.stackedobs
31 |
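
A rough usage sketch: assuming a vectorized Atari environment whose per-env observations are 84x84x1 grayscale frames (the make_atari_vec_env helper below is a placeholder, not part of this repository), stacking four frames concatenates them along the channel axis.

from vec_frame_stack import VecFrameStack

venv = make_atari_vec_env("AlienNoFrameskip-v0", num_envs=8)  # placeholder; obs shape (8, 84, 84, 1)
stacked = VecFrameStack(venv, nstack=4)

obs = stacked.reset()
print(obs.shape)  # (8, 84, 84, 4); the newest frame occupies the last channel slot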
--------------------------------------------------------------------------------
/Utils/MovingAvegCalculator.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | class MovingAvegCalculator():
5 | def __init__(self, window_length):
6 | self.num_added = 0
7 | self.window_length = window_length
8 | self.window = [0.0 for _ in range(window_length)]
9 |
10 | self.aveg = 0.0
11 | self.var = 0.0
12 |
13 | self.last_std = 0.0
14 |
15 | def add_number(self, num):
16 | idx = self.num_added % self.window_length
17 | old_num = self.window[idx]
18 | self.window[idx] = num
19 | self.num_added += 1
20 |
21 | old_aveg = self.aveg
22 | if self.num_added <= self.window_length:
23 | delta = num - old_aveg
24 | self.aveg += delta / self.num_added
25 | self.var += delta * (num - self.aveg)
26 | else:
27 | delta = num - old_num
28 | self.aveg += delta / self.window_length
29 | self.var += delta * ((num - self.aveg) + (old_num - old_aveg))
30 |
31 | if self.num_added <= self.window_length:
32 | if self.num_added == 1:
33 | variance = 0.1
34 | else:
35 | variance = self.var / (self.num_added - 1)
36 | else:
37 | variance = self.var / self.window_length
38 |
39 | try:
40 | std = math.sqrt(variance)
41 | if math.isnan(std):
42 | std = 0.1
43 |         except ValueError:  # variance may be marginally negative due to floating-point error
44 | std = 0.1
45 |
46 | self.last_std = std
47 |
48 | return self.aveg, std
49 |
50 | def get_standard_deviation(self):
51 | return self.last_std
52 |
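
A quick worked example of the sliding-window statistics (window of length 3; the numbers follow directly from the update rules above, assuming the module is importable from the repository root):

from Utils.MovingAvegCalculator import MovingAvegCalculator

calc = MovingAvegCalculator(window_length=3)
for x in [1.0, 2.0, 3.0, 4.0]:
    mean, std = calc.add_number(x)

# After 4.0 is added the window holds [2.0, 3.0, 4.0]: mean == 3.0, and
# std == sqrt(2.0 / 3) ~= 0.816, because the accumulated squared deviation
# is divided by window_length once the window is full.
print(mean, calc.get_standard_deviation())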
--------------------------------------------------------------------------------
/Utils/NetworkDistillation/Distillation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | import sys
5 | sys.path.append("../Env")
6 |
7 | from Env.EnvWrapper import EnvWrapper
8 |
9 | from Policy.PPO.PPOPolicy import PPOAtariCNN, PPOSmallAtariCNN
10 |
11 | from .ReplayBuffer import ReplayBuffer
12 |
13 |
14 | class Distillation():
15 | def __init__(self, wrapped_env, teacher_network, student_network,
16 | temperature = 2.5, buffer_size = 1e5, batch_size = 32,
17 | device = torch.device("cpu")):
18 | self.wrapped_env = wrapped_env
19 | self.teacher_network = teacher_network
20 | self.student_network = student_network
21 | self.temperature = temperature
22 | self.buffer_size = buffer_size
23 | self.batch_size = batch_size
24 | self.device = device
25 |
26 | # Replay buffer
27 | self.replay_buffer = ReplayBuffer(max_size = buffer_size, device = self.device)
28 |
29 | def train_step(self):
30 | state_batch, policy_batch, value_batch = self.replay_buffer.sample(self.batch_size)
31 |
32 | loss = self.student_network.train_step(state_batch, policy_batch, value_batch, temperature = self.temperature)
33 |
34 | return loss
35 |
36 | def gather_samples(self, max_step_count = 10000):
37 | state = self.wrapped_env.reset()
38 |
39 | step_count = 0
40 |
41 | done = False
42 | while not done:
43 | if np.random.random() < 0.9:
44 | action = self.categorical(self.student_network.get_action(state))
45 | else:
46 | action = np.random.randint(0, self.wrapped_env.action_n)
47 | target_policy = self.teacher_network.get_action(state, logit = True)
48 | target_value = self.teacher_network.get_value(state)
49 |
50 | self.replay_buffer.add((np.array(state), target_policy, target_value))
51 |
52 | state, _, done = self.wrapped_env.step(action)
53 |
54 | step_count += 1
55 |
56 | if step_count > max_step_count:
57 | return
58 |
59 | @staticmethod
60 | def categorical(probs):
61 | val = random.random()
62 | chosen_idx = 0
63 |
64 | for prob in probs:
65 | val -= prob
66 |
67 | if val < 0.0:
68 | break
69 |
70 | chosen_idx += 1
71 |
72 | if chosen_idx >= len(probs):
73 | chosen_idx = len(probs) - 1
74 |
75 | return chosen_idx
76 |
77 |
78 | def train_distillation(env_name, device):
79 | device = torch.device("cuda:0" if device == "cuda" else device)
80 |
81 | wrapped_env = EnvWrapper(env_name = env_name, max_episode_length = 100000)
82 |
83 | teacher_network = PPOAtariCNN(wrapped_env.action_n, device,
84 | checkpoint_dir = "./Policy/PPO/PolicyFiles/PPO_" + env_name + ".pt")
85 | student_network = PPOSmallAtariCNN(wrapped_env.action_n, device,
86 | checkpoint_dir = "")
87 |
88 | distillation = Distillation(wrapped_env, teacher_network, student_network, device = device)
89 |
90 | for _ in range(1000):
91 | for _ in range(10):
92 | distillation.gather_samples()
93 |
94 | for _ in range(1000):
95 | loss = distillation.train_step()
96 | print(loss)
97 |
98 | student_network.save("./Policy/PPO/PolicyFiles/SmallPPO_" + env_name + ".pt")
99 |
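
The categorical helper above samples an index from a discrete distribution given its probabilities; the small sanity check below (not part of the file, and assuming the repository's dependencies are importable) makes that concrete. The full pipeline is reached from main.py via train_distillation(args.env_name, args.device) when --mode Distill is passed.

from Utils.NetworkDistillation.Distillation import Distillation

probs = [0.1, 0.7, 0.2]
counts = [0, 0, 0]
for _ in range(10000):
    counts[Distillation.categorical(probs)] += 1
print(counts)  # roughly [1000, 7000, 2000]: indices drawn in proportion to probs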
--------------------------------------------------------------------------------
/Utils/NetworkDistillation/ReplayBuffer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class ReplayBuffer(object):
6 | def __init__(self, max_size = 1e6, device = None):
7 | self.storage = []
8 | self.max_size = max_size
9 | self.ptr = 0
10 |
11 | self.device = device
12 |
13 | def add(self, data):
14 | if len(self.storage) == self.max_size:
15 | self.storage[int(self.ptr)] = data
16 | self.ptr = (self.ptr + 1) % self.max_size
17 | else:
18 | self.storage.append(data)
19 |
20 | def sample(self, batch_size):
21 | ind = np.random.randint(0, len(self.storage), size = batch_size)
22 | state_batch, policy_batch, value_batch = [], [], []
23 |
24 | for i in ind:
25 | state, policy, value = self.storage[i]
26 | state_batch.append(np.array(state, copy = False))
27 | policy_batch.append(np.array(policy, copy = False))
28 | value_batch.append(value)
29 |
30 | state_batch = torch.from_numpy(np.array(state_batch, dtype = np.float32))
31 | policy_batch = torch.from_numpy(np.array(policy_batch, dtype = np.float32))
32 | value_batch = torch.from_numpy(np.array(value_batch, dtype = np.float32))
33 |
34 | if self.device is not None:
35 | state_batch = state_batch.to(self.device)
36 | policy_batch = policy_batch.to(self.device)
37 | value_batch = value_batch.to(self.device).unsqueeze(-1)
38 |
39 | return state_batch, policy_batch, value_batch
40 |
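
A minimal usage sketch of the buffer; the shapes below are illustrative placeholders, not taken from the training code.

import numpy as np
import torch

from Utils.NetworkDistillation.ReplayBuffer import ReplayBuffer

buffer = ReplayBuffer(max_size=1000, device=torch.device("cpu"))
for _ in range(64):
    state = np.random.rand(4, 84, 84).astype(np.float32)   # assumed observation shape
    policy = np.random.rand(6).astype(np.float32)           # assumed 6-way policy target
    buffer.add((state, policy, 0.0))                        # scalar value target

states, policies, values = buffer.sample(batch_size=32)
print(states.shape, policies.shape, values.shape)  # (32, 4, 84, 84), (32, 6), (32, 1)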
--------------------------------------------------------------------------------
/Utils/NetworkDistillation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Utils/NetworkDistillation/__init__.py
--------------------------------------------------------------------------------
/Utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Utils/__init__.py
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | import scipy.io as sio
4 | import os
5 |
6 | from Tree.WU_UCT import WU_UCT
7 | from Tree.UCT import UCT
8 |
9 | from Utils.NetworkDistillation.Distillation import train_distillation
10 |
11 |
12 | def main():
13 | parser = argparse.ArgumentParser(description = "P-MCTS")
14 | parser.add_argument("--model", type = str, default = "WU-UCT",
15 | help = "Base MCTS model WU-UCT/UCT (default: WU-UCT)")
16 |
17 | parser.add_argument("--env-name", type = str, default = "AlienNoFrameskip-v0",
18 | help = "Environment name (default: AlienNoFrameskip-v0)")
19 |
20 | parser.add_argument("--MCTS-max-steps", type = int, default = 128,
21 |                         help = "Max number of MCTS simulation steps (default: 128)")
22 | parser.add_argument("--MCTS-max-depth", type = int, default = 100,
23 | help = "Max depth of MCTS simulation (default: 100)")
24 | parser.add_argument("--MCTS-max-width", type = int, default = 20,
25 | help = "Max width of MCTS simulation (default: 20)")
26 |
27 | parser.add_argument("--gamma", type = float, default = 0.99,
28 |                         help = "Discount factor (default: 0.99)")
29 |
30 | parser.add_argument("--expansion-worker-num", type = int, default = 1,
31 | help = "Number of expansion workers (default: 1)")
32 | parser.add_argument("--simulation-worker-num", type = int, default = 16,
33 | help = "Number of simulation workers (default: 16)")
34 |
35 | parser.add_argument("--seed", type = int, default = 123,
36 | help = "random seed (default: 123)")
37 |
38 | parser.add_argument("--max-episode-length", type = int, default = 100000,
39 | help = "Maximum episode length (default: 100000)")
40 |
41 | parser.add_argument("--policy", type = str, default = "Random",
42 |                         help = "Prior-probability / rollout policy used in MCTS: Random, PPO, or DistillPPO (default: Random)")
43 |
44 | parser.add_argument("--device", type = str, default = "cpu",
45 |                         help = "PyTorch device; pass 'cuda' to use CUDA-based parallelization (default: cpu)")
46 |
47 | parser.add_argument("--record-video", default = False, action = "store_true",
48 | help = "Record video if supported (default: False)")
49 |
50 | parser.add_argument("--mode", type = str, default = "MCTS",
51 | help = "Mode MCTS/Distill (default: MCTS)")
52 |
53 | args = parser.parse_args()
54 |
55 | env_params = {
56 | "env_name": args.env_name,
57 | "max_episode_length": args.max_episode_length
58 | }
59 |
60 | if args.mode == "MCTS":
61 | # Model initialization
62 | if args.model == "WU-UCT":
63 | MCTStree = WU_UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth,
64 | args.MCTS_max_width, args.gamma, args.expansion_worker_num,
65 | args.simulation_worker_num, policy = args.policy,
66 | seed = args.seed, device = args.device,
67 | record_video = args.record_video)
68 | elif args.model == "UCT":
69 | MCTStree = UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth,
70 | args.MCTS_max_width, args.gamma, policy = args.policy, seed = args.seed)
71 | else:
72 | raise NotImplementedError()
73 |
74 | accu_reward, rewards, times = MCTStree.simulate_trajectory()
75 | print(accu_reward)
76 |
77 | with open("Results/" + args.model + ".txt", "a+") as f:
78 |             f.write("Model: {}, env: {}, result: {}, MCTS max steps: {}, policy: {}, worker num: {}\n".format(
79 | args.model, args.env_name, accu_reward, args.MCTS_max_steps, args.policy, args.simulation_worker_num
80 | ))
81 |
82 | if not os.path.exists("OutLogs/"):
83 | try:
84 | os.mkdir("OutLogs/")
85 |             except OSError:
86 | pass
87 |
88 | sio.savemat("OutLogs/" + args.model + "_" + args.env_name + "_" + str(args.seed) + "_" +
89 | str(args.simulation_worker_num) + ".mat",
90 | {"rewards": rewards, "times": times})
91 |
92 | MCTStree.close()
93 |
94 | elif args.mode == "Distill":
95 | train_distillation(args.env_name, args.device)
96 |
97 |
98 | if __name__ == "__main__":
99 | # Mandatory for Unix/Darwin
100 | multiprocessing.set_start_method("forkserver")
101 |
102 | main()
103 |
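
For reference, two illustrative invocations; every flag comes from the argparse definitions above, while the concrete values are only examples.

# Run WU-UCT with the distilled PPO policy and 16 simulation workers:
python main.py --model WU-UCT --env-name AlienNoFrameskip-v0 \
    --MCTS-max-steps 128 --simulation-worker-num 16 --expansion-worker-num 1 \
    --policy DistillPPO --device cuda --mode MCTS

# Distill the full PPO network into the small student network:
python main.py --mode Distill --env-name AlienNoFrameskip-v0 --device cuda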
--------------------------------------------------------------------------------