├── .gitignore ├── .idea ├── P-MCTS.iml ├── encodings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── CODE_OF_CONDUCT.md ├── Env ├── AtariEnv │ ├── AtariEnvWrapper.py │ ├── VideoRecord.py │ ├── atari_wrappers.py │ ├── subproc_vec_env.py │ ├── vec_env.py │ └── vec_frame_stack.py ├── EnvWrapper.py └── __init__.py ├── Figures ├── Figure_atari_results.png ├── Figure_puct_conceptual_idea.png ├── Figure_puct_pipeline.png ├── Figure_tap_results.png └── Figure_time_consumption.png ├── LICENSE ├── Logs └── .gitkeep ├── Mem ├── CheckpointManager.py └── __init__.py ├── Node ├── UCTnode.py ├── WU_UCTnode.py └── __init__.py ├── OutLogs ├── .gitkeep └── WU-UCT_PongNoFrameskip-v0_123_2.mat ├── ParallelPool ├── PoolManager.py ├── Worker.py └── __init__.py ├── Policy ├── PPO │ ├── PPOPolicy.py │ └── PolicyFiles │ │ └── PPO_AlienNoFrameskip-v0.pt └── PolicyWrapper.py ├── README.md ├── Records └── .gitkeep ├── Results └── .gitkeep ├── Tree ├── UCT.py ├── WU_UCT.py └── __init__.py ├── Utils ├── Atari_PPO_training │ ├── LICENSE │ ├── README.md │ ├── atari_wrappers.py │ ├── envs.py │ ├── main.py │ ├── models.py │ ├── ppo.py │ ├── save │ │ └── .gitkeep │ ├── subproc_vec_env.py │ ├── test.py │ ├── utils.py │ ├── vec_env.py │ └── vec_frame_stack.py ├── MovingAvegCalculator.py ├── NetworkDistillation │ ├── Distillation.py │ ├── ReplayBuffer.py │ └── __init__.py └── __init__.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | # *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | # env/ 107 | venv/ 108 | # ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | -------------------------------------------------------------------------------- /.idea/P-MCTS.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 13 | 14 | 19 | 20 | 21 | 23 | 24 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 52 | 53 | 54 | 55 | 56 | 75 | 76 | 77 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 1565057551875 109 | 113 | 114 | 1569541841304 115 | 120 | 121 | 1569541933765 122 | 127 | 130 | 131 | 140 | 141 | 142 | 143 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | dir(self.manager_proxies[0]) 156 | Python 157 | EXPRESSION 158 | 159 | 160 | node_manager_proxy.get_update_count() 161 | Python 162 | EXPRESSION 163 | 164 | 165 | node_manager_proxy.get_ 166 | Python 167 | EXPRESSION 168 | 169 | 170 | self.manager_proxies[0].get_update_count() 171 | Python 172 | EXPRESSION 173 | 174 | 175 | self.manager_proxies[tree_idx].get_update_count() 176 | Python 177 | EXPRESSION 178 | 179 | 180 | manager_proxy.get_update_count() 181 | Python 182 | EXPRESSION 183 | 184 | 185 | manager_proxy. 
186 | Python 187 | EXPRESSION 188 | 189 | 190 | manager_listener.listen() 191 | Python 192 | EXPRESSION 193 | 194 | 195 | np.array(r)[np.where(np.array(r) != 0.0)] 196 | Python 197 | EXPRESSION 198 | 199 | 200 | np.where(np.array(r) != 0.0) 201 | Python 202 | EXPRESSION 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at anjiliu219@gmail.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 
62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /Env/AtariEnv/AtariEnvWrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .atari_wrappers import make_atari, wrap_deepmind, FrameStack 3 | from .vec_env import VecEnv 4 | from .VideoRecord import VideoRecorder 5 | from multiprocessing import Process, Pipe 6 | 7 | 8 | # cf https://github.com/openai/baselines 9 | 10 | def make_atari_env(env_name, rank, seed, enable_record = False, record_path = "./1.mp4"): 11 | env = make_atari(env_name) 12 | env.seed(seed + rank) 13 | env = wrap_deepmind(env, episode_life = False, clip_rewards = False) 14 | env = FrameStack(env, 4) 15 | recorder = VideoRecorder(env, path = record_path, enabled = enable_record) 16 | return env, recorder 17 | 18 | 19 | def worker(remote, parent_remote, env_fn_wrapper): 20 | parent_remote.close() 21 | env = env_fn_wrapper.x() 22 | while True: 23 | cmd, data = remote.recv() 24 | if cmd == 'step': 25 | ob, reward, done, info = env.step(data) 26 | if done: 27 | ob = env.reset() 28 | remote.send((ob, reward, done, info)) 29 | elif cmd == 'reset': 30 | ob = env.reset() 31 | remote.send(ob) 32 | elif cmd == 'reset_task': 33 | ob = env.reset_task() 34 | remote.send(ob) 35 | elif cmd == 'close': 36 | remote.close() 37 | break 38 | elif cmd == 'get_spaces': 39 | remote.send((env.action_space, env.observation_space)) 40 | elif cmd == 'render': 41 | env.render() 42 | else: 43 | raise NotImplementedError 44 | 45 | 46 | class CloudpickleWrapper(object): 47 | """ 48 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 49 | """ 50 | def __init__(self, x): 51 | self.x = x 52 | def __getstate__(self): 53 | import cloudpickle 54 | return cloudpickle.dumps(self.x) 55 | def __setstate__(self, ob): 56 | import pickle 57 | self.x = pickle.loads(ob) 58 | 59 | 60 | class RenderSubprocVecEnv(VecEnv): 61 | def __init__(self, env_fns, render_interval): 62 | """ Minor addition to SubprocVecEnv, automatically renders environments 63 | 64 | envs: list of gym environments to run in subprocesses 65 | """ 66 | self.closed = False 67 | nenvs = len(env_fns) 68 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) 69 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 70 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 71 | for p in self.ps: 72 | p.daemon = True # if the main process crashes, we should not cause things to hang 73 | p.start() 74 | for remote in self.work_remotes: 75 | remote.close() 76 | 77 | self.remotes[0].send(('get_spaces', None)) 78 | self.action_space, self.observation_space = self.remotes[0].recv() 79 | 80 | self.render_interval = render_interval 81 | self.render_timer = 0 82 | 83 | def 
step(self, actions): 84 | for remote, action in zip(self.remotes, actions): 85 | remote.send(('step', action)) 86 | results = [remote.recv() for remote in self.remotes] 87 | obs, rews, dones, infos = zip(*results) 88 | 89 | self.render_timer += 1 90 | if self.render_timer == self.render_interval: 91 | for remote in self.remotes: 92 | remote.send(('render', None)) 93 | self.render_timer = 0 94 | 95 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 96 | 97 | def reset(self): 98 | for remote in self.remotes: 99 | remote.send(('reset', None)) 100 | return np.stack([remote.recv() for remote in self.remotes]) 101 | 102 | def reset_task(self): 103 | for remote in self.remotes: 104 | remote.send(('reset_task', None)) 105 | return np.stack([remote.recv() for remote in self.remotes]) 106 | 107 | def close(self): 108 | if self.closed: 109 | return 110 | 111 | for remote in self.remotes: 112 | remote.send(('close', None)) 113 | for p in self.ps: 114 | p.join() 115 | self.closed = True 116 | 117 | @property 118 | def num_envs(self): 119 | return len(self.remotes) 120 | -------------------------------------------------------------------------------- /Env/AtariEnv/VideoRecord.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | import tempfile 5 | import os.path 6 | import distutils.spawn, distutils.version 7 | import numpy as np 8 | from six import StringIO 9 | import six 10 | from gym import error, logger 11 | 12 | def touch(path): 13 | open(path, 'a').close() 14 | 15 | class VideoRecorder(): 16 | """VideoRecorder renders a nice movie of a rollout, frame by frame. It 17 | comes with an `enabled` option so you can still use the same code 18 | on episodes where you don't want to record video. 19 | 20 | Note: 21 | You are responsible for calling `close` on a created 22 | VideoRecorder, or else you may leak an encoder process. 23 | 24 | Args: 25 | env (Env): Environment to take video of. 26 | path (Optional[str]): Path to the video file; will be randomly chosen if omitted. 27 | base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added. 28 | metadata (Optional[dict]): Contents to save to the metadata file. 
29 | enabled (bool): Whether to actually record video, or just no-op (for convenience) 30 | """ 31 | 32 | def __init__(self, env, path=None, metadata=None, enabled=True, base_path=None): 33 | modes = env.metadata.get('render.modes', []) 34 | self._async = env.metadata.get('semantics.async') 35 | self.enabled = enabled 36 | 37 | # Don't bother setting anything else if not enabled 38 | if not self.enabled: 39 | return 40 | 41 | self.ansi_mode = False 42 | if 'rgb_array' not in modes: 43 | if 'ansi' in modes: 44 | self.ansi_mode = True 45 | else: 46 | logger.info('Disabling video recorder because {} neither supports video mode "rgb_array" nor "ansi".'.format(env)) 47 | # Whoops, turns out we shouldn't be enabled after all 48 | self.enabled = False 49 | return 50 | 51 | if path is not None and base_path is not None: 52 | raise error.Error("You can pass at most one of `path` or `base_path`.") 53 | 54 | self.last_frame = None 55 | self.env = env 56 | 57 | required_ext = '.json' if self.ansi_mode else '.mp4' 58 | if path is None: 59 | if base_path is not None: 60 | # Base path given, append ext 61 | path = base_path + required_ext 62 | else: 63 | # Otherwise, just generate a unique filename 64 | with tempfile.NamedTemporaryFile(suffix=required_ext, delete=False) as f: 65 | path = f.name 66 | self.path = path 67 | 68 | path_base, actual_ext = os.path.splitext(self.path) 69 | 70 | if actual_ext != required_ext: 71 | hint = " HINT: The environment is text-only, therefore we're recording its text output in a structured JSON format." if self.ansi_mode else '' 72 | raise error.Error("Invalid path given: {} -- must have file extension {}.{}".format(self.path, required_ext, hint)) 73 | # Touch the file in any case, so we know it's present. (This 74 | # corrects for platform platform differences. Using ffmpeg on 75 | # OS X, the file is precreated, but not on Linux. 76 | touch(path) 77 | 78 | self.frames_per_sec = env.metadata.get('video.frames_per_second', 30) 79 | self.encoder = None # lazily start the process 80 | self.broken = False 81 | 82 | # Dump metadata 83 | self.metadata = metadata or {} 84 | self.metadata['content_type'] = 'video/vnd.openai.ansivid' if self.ansi_mode else 'video/mp4' 85 | self.metadata_path = '{}.meta.json'.format(path_base) 86 | self.write_metadata() 87 | 88 | logger.info('Starting new video recorder writing to %s', self.path) 89 | self.empty = True 90 | 91 | @property 92 | def functional(self): 93 | return self.enabled and not self.broken 94 | 95 | def capture_frame(self): 96 | """Render the given `env` and add the resulting frame to the video.""" 97 | if not self.functional: return 98 | logger.debug('Capturing video frame: path=%s', self.path) 99 | 100 | render_mode = 'ansi' if self.ansi_mode else 'rgb_array' 101 | frame = self.env.render(mode=render_mode) 102 | 103 | if frame is None: 104 | if self._async: 105 | return 106 | else: 107 | # Indicates a bug in the environment: don't want to raise 108 | # an error here. 109 | logger.warn('Env returned None on render(). 
Disabling further rendering for video recorder by marking as disabled: path=%s metadata_path=%s', self.path, self.metadata_path) 110 | self.broken = True 111 | else: 112 | self.last_frame = frame 113 | if self.ansi_mode: 114 | self._encode_ansi_frame(frame) 115 | else: 116 | self._encode_image_frame(frame) 117 | 118 | def close(self): 119 | """Make sure to manually close, or else you'll leak the encoder process""" 120 | if not self.enabled: 121 | return 122 | 123 | if self.encoder: 124 | logger.debug('Closing video encoder: path=%s', self.path) 125 | self.encoder.close() 126 | self.encoder = None 127 | else: 128 | # No frames captured. Set metadata, and remove the empty output file. 129 | os.remove(self.path) 130 | 131 | if self.metadata is None: 132 | self.metadata = {} 133 | self.metadata['empty'] = True 134 | 135 | # If broken, get rid of the output file, otherwise we'd leak it. 136 | if self.broken: 137 | logger.info('Cleaning up paths for broken video recorder: path=%s metadata_path=%s', self.path, self.metadata_path) 138 | 139 | # Might have crashed before even starting the output file, don't try to remove in that case. 140 | if os.path.exists(self.path): 141 | os.remove(self.path) 142 | 143 | if self.metadata is None: 144 | self.metadata = {} 145 | self.metadata['broken'] = True 146 | 147 | self.write_metadata() 148 | 149 | def write_metadata(self): 150 | with open(self.metadata_path, 'w') as f: 151 | json.dump(self.metadata, f) 152 | 153 | def _encode_ansi_frame(self, frame): 154 | if not self.encoder: 155 | self.encoder = TextEncoder(self.path, self.frames_per_sec) 156 | self.metadata['encoder_version'] = self.encoder.version_info 157 | self.encoder.capture_frame(frame) 158 | self.empty = False 159 | 160 | def _encode_image_frame(self, frame): 161 | if not self.encoder: 162 | self.encoder = ImageEncoder(self.path, frame.shape, self.frames_per_sec) 163 | self.metadata['encoder_version'] = self.encoder.version_info 164 | 165 | try: 166 | self.encoder.capture_frame(frame) 167 | except error.InvalidFrame as e: 168 | logger.warn('Tried to pass invalid video frame, marking as broken: %s', e) 169 | self.broken = True 170 | else: 171 | self.empty = False 172 | 173 | 174 | class TextEncoder(object): 175 | """Store a moving picture made out of ANSI frames. 
Format adapted from 176 | https://github.com/asciinema/asciinema/blob/master/doc/asciicast-v1.md""" 177 | 178 | def __init__(self, output_path, frames_per_sec): 179 | self.output_path = output_path 180 | self.frames_per_sec = frames_per_sec 181 | self.frames = [] 182 | 183 | def capture_frame(self, frame): 184 | string = None 185 | if isinstance(frame, str): 186 | string = frame 187 | elif isinstance(frame, StringIO): 188 | string = frame.getvalue() 189 | else: 190 | raise error.InvalidFrame('Wrong type {} for {}: text frame must be a string or StringIO'.format(type(frame), frame)) 191 | 192 | frame_bytes = string.encode('utf-8') 193 | 194 | if frame_bytes[-1:] != six.b('\n'): 195 | raise error.InvalidFrame('Frame must end with a newline: """{}"""'.format(string)) 196 | 197 | if six.b('\r') in frame_bytes: 198 | raise error.InvalidFrame('Frame contains carriage returns (only newlines are allowed: """{}"""'.format(string)) 199 | 200 | self.frames.append(frame_bytes) 201 | 202 | def close(self): 203 | #frame_duration = float(1) / self.frames_per_sec 204 | frame_duration = .5 205 | 206 | # Turn frames into events: clear screen beforehand 207 | # https://rosettacode.org/wiki/Terminal_control/Clear_the_screen#Python 208 | # https://rosettacode.org/wiki/Terminal_control/Cursor_positioning#Python 209 | clear_code = six.b("%c[2J\033[1;1H" % (27)) 210 | # Decode the bytes as UTF-8 since JSON may only contain UTF-8 211 | events = [ (frame_duration, (clear_code+frame.replace(six.b('\n'),six.b('\r\n'))).decode('utf-8')) for frame in self.frames ] 212 | 213 | # Calculate frame size from the largest frames. 214 | # Add some padding since we'll get cut off otherwise. 215 | height = max([frame.count(six.b('\n')) for frame in self.frames]) + 1 216 | width = max([max([len(line) for line in frame.split(six.b('\n'))]) for frame in self.frames]) + 2 217 | 218 | data = { 219 | "version": 1, 220 | "width": width, 221 | "height": height, 222 | "duration": len(self.frames)*frame_duration, 223 | "command": "-", 224 | "title": "gym VideoRecorder episode", 225 | "env": {}, # could add some env metadata here 226 | "stdout": events, 227 | } 228 | 229 | with open(self.output_path, 'w') as f: 230 | json.dump(data, f) 231 | 232 | @property 233 | def version_info(self): 234 | return {'backend':'TextEncoder','version':1} 235 | 236 | class ImageEncoder(object): 237 | def __init__(self, output_path, frame_shape, frames_per_sec): 238 | self.proc = None 239 | self.output_path = output_path 240 | # Frame shape should be lines-first, so w and h are swapped 241 | h, w, pixfmt = frame_shape 242 | if pixfmt != 3 and pixfmt != 4: 243 | raise error.InvalidFrame("Your frame has shape {}, but we require (w,h,3) or (w,h,4), i.e. RGB values for a w-by-h image, with an optional alpha channl.".format(frame_shape)) 244 | self.wh = (w,h) 245 | self.includes_alpha = (pixfmt == 4) 246 | self.frame_shape = frame_shape 247 | self.frames_per_sec = frames_per_sec 248 | 249 | if distutils.spawn.find_executable('avconv') is not None: 250 | self.backend = 'avconv' 251 | elif distutils.spawn.find_executable('ffmpeg') is not None: 252 | self.backend = 'ffmpeg' 253 | else: 254 | raise error.DependencyNotInstalled("""Found neither the ffmpeg nor avconv executables. On OS X, you can install ffmpeg via `brew install ffmpeg`. On most Ubuntu variants, `sudo apt-get install ffmpeg` should do it. 
On Ubuntu 14.04, however, you'll need to install avconv with `sudo apt-get install libav-tools`.""") 255 | 256 | self.start() 257 | 258 | @property 259 | def version_info(self): 260 | return { 261 | 'backend':self.backend, 262 | 'version':str(subprocess.check_output([self.backend, '-version'], 263 | stderr=subprocess.STDOUT)), 264 | 'cmdline':self.cmdline 265 | } 266 | 267 | def start(self): 268 | self.cmdline = (self.backend, 269 | '-nostats', 270 | '-loglevel', 'error', # suppress warnings 271 | '-y', 272 | '-r', '%d' % self.frames_per_sec, 273 | 274 | # input 275 | '-f', 'rawvideo', 276 | '-s:v', '{}x{}'.format(*self.wh), 277 | '-pix_fmt',('rgb32' if self.includes_alpha else 'rgb24'), 278 | '-i', '-', # this used to be /dev/stdin, which is not Windows-friendly 279 | 280 | # output 281 | '-vcodec', 'libx264', 282 | '-pix_fmt', 'yuv420p', 283 | self.output_path 284 | ) 285 | 286 | logger.debug('Starting ffmpeg with "%s"', ' '.join(self.cmdline)) 287 | if hasattr(os,'setsid'): #setsid not present on Windows 288 | self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE, preexec_fn=os.setsid) 289 | else: 290 | self.proc = subprocess.Popen(self.cmdline, stdin=subprocess.PIPE) 291 | 292 | def capture_frame(self, frame): 293 | if not isinstance(frame, (np.ndarray, np.generic)): 294 | raise error.InvalidFrame('Wrong type {} for {} (must be np.ndarray or np.generic)'.format(type(frame), frame)) 295 | if frame.shape != self.frame_shape: 296 | raise error.InvalidFrame("Your frame has shape {}, but the VideoRecorder is configured for shape {}.".format(frame.shape, self.frame_shape)) 297 | if frame.dtype != np.uint8: 298 | raise error.InvalidFrame("Your frame has data type {}, but we require uint8 (i.e. RGB values from 0-255).".format(frame.dtype)) 299 | 300 | if distutils.version.LooseVersion(np.__version__) >= distutils.version.LooseVersion('1.9.0'): 301 | self.proc.stdin.write(frame.tobytes()) 302 | else: 303 | self.proc.stdin.write(frame.tostring()) 304 | 305 | def close(self): 306 | self.proc.stdin.close() 307 | ret = self.proc.wait() 308 | if ret != 0: 309 | logger.error("VideoRecorder encoder exited with status {}".format(ret)) 310 | -------------------------------------------------------------------------------- /Env/AtariEnv/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | os.environ.setdefault('PATH', '') 4 | from collections import deque 5 | import gym 6 | from gym import spaces 7 | import cv2 8 | cv2.ocl.setUseOpenCL(False) 9 | from copy import deepcopy 10 | 11 | class TimeLimit(gym.Wrapper): 12 | def __init__(self, env, max_episode_steps=None): 13 | super(TimeLimit, self).__init__(env) 14 | self._max_episode_steps = max_episode_steps 15 | self._elapsed_steps = 0 16 | 17 | def step(self, ac): 18 | observation, reward, done, info = self.env.step(ac) 19 | self._elapsed_steps += 1 20 | if self._elapsed_steps >= self._max_episode_steps: 21 | done = True 22 | info['TimeLimit.truncated'] = True 23 | return observation, reward, done, info 24 | 25 | def reset(self, **kwargs): 26 | self._elapsed_steps = 0 27 | return self.env.reset(**kwargs) 28 | 29 | class NoopResetEnv(gym.Wrapper): 30 | def __init__(self, env, noop_max=30): 31 | """Sample initial states by taking random number of no-ops on reset. 32 | No-op is assumed to be action 0. 
33 | """ 34 | gym.Wrapper.__init__(self, env) 35 | self.noop_max = noop_max 36 | self.override_num_noops = None 37 | self.noop_action = 0 38 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 39 | 40 | def reset(self, **kwargs): 41 | """ Do no-op action for a number of steps in [1, noop_max].""" 42 | self.env.reset(**kwargs) 43 | if self.override_num_noops is not None: 44 | noops = self.override_num_noops 45 | else: 46 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101 47 | assert noops > 0 48 | obs = None 49 | for _ in range(noops): 50 | obs, _, done, _ = self.env.step(self.noop_action) 51 | if done: 52 | obs = self.env.reset(**kwargs) 53 | return obs 54 | 55 | def step(self, ac): 56 | return self.env.step(ac) 57 | 58 | class FireResetEnv(gym.Wrapper): 59 | def __init__(self, env): 60 | """Take action on reset for environments that are fixed until firing.""" 61 | gym.Wrapper.__init__(self, env) 62 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 63 | assert len(env.unwrapped.get_action_meanings()) >= 3 64 | 65 | def reset(self, **kwargs): 66 | self.env.reset(**kwargs) 67 | obs, _, done, _ = self.env.step(1) 68 | if done: 69 | self.env.reset(**kwargs) 70 | obs, _, done, _ = self.env.step(2) 71 | if done: 72 | self.env.reset(**kwargs) 73 | return obs 74 | 75 | def step(self, ac): 76 | return self.env.step(ac) 77 | 78 | class EpisodicLifeEnv(gym.Wrapper): 79 | def __init__(self, env): 80 | """Make end-of-life == end-of-episode, but only reset on true game over. 81 | Done by DeepMind for the DQN and co. since it helps value estimation. 82 | """ 83 | gym.Wrapper.__init__(self, env) 84 | self.lives = 0 85 | self.was_real_done = True 86 | 87 | def step(self, action): 88 | obs, reward, done, info = self.env.step(action) 89 | self.was_real_done = done 90 | # check current lives, make loss of life terminal, 91 | # then update lives to handle bonus lives 92 | lives = self.env.unwrapped.ale.lives() 93 | if lives < self.lives and lives > 0: 94 | # for Qbert sometimes we stay in lives == 0 condition for a few frames 95 | # so it's important to keep lives > 0, so that we only reset once 96 | # the environment advertises done. 97 | done = True 98 | self.lives = lives 99 | return obs, reward, done, info 100 | 101 | def reset(self, **kwargs): 102 | """Reset only when lives are exhausted. 103 | This way all states are still reachable even though lives are episodic, 104 | and the learner need not know about any of this behind-the-scenes. 
105 | """ 106 | if self.was_real_done: 107 | obs = self.env.reset(**kwargs) 108 | else: 109 | # no-op step to advance from terminal/lost life state 110 | obs, _, _, _ = self.env.step(0) 111 | self.lives = self.env.unwrapped.ale.lives() 112 | return obs 113 | 114 | class MaxAndSkipEnv(gym.Wrapper): 115 | def __init__(self, env, skip=4): 116 | """Return only every `skip`-th frame""" 117 | gym.Wrapper.__init__(self, env) 118 | # most recent raw observations (for max pooling across time steps) 119 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) 120 | self._skip = skip 121 | 122 | def step(self, action): 123 | """Repeat action, sum reward, and max over last observations.""" 124 | total_reward = 0.0 125 | done = None 126 | for i in range(self._skip): 127 | obs, reward, done, info = self.env.step(action) 128 | if i == self._skip - 2: self._obs_buffer[0] = obs 129 | if i == self._skip - 1: self._obs_buffer[1] = obs 130 | total_reward += reward 131 | if done: 132 | break 133 | # Note that the observation on the done=True frame 134 | # doesn't matter 135 | max_frame = self._obs_buffer.max(axis=0) 136 | 137 | return max_frame, total_reward, done, info 138 | 139 | def reset(self, **kwargs): 140 | return self.env.reset(**kwargs) 141 | 142 | class ClipRewardEnv(gym.RewardWrapper): 143 | def __init__(self, env): 144 | gym.RewardWrapper.__init__(self, env) 145 | 146 | def reward(self, reward): 147 | """Bin reward to {+1, 0, -1} by its sign.""" 148 | return np.sign(reward) 149 | 150 | 151 | class WarpFrame(gym.ObservationWrapper): 152 | def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None): 153 | """ 154 | Warp frames to 84x84 as done in the Nature paper and later work. 155 | 156 | If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which 157 | observation should be warped. 158 | """ 159 | super().__init__(env) 160 | self._width = width 161 | self._height = height 162 | self._grayscale = grayscale 163 | self._key = dict_space_key 164 | if self._grayscale: 165 | num_colors = 1 166 | else: 167 | num_colors = 3 168 | 169 | new_space = gym.spaces.Box( 170 | low=0, 171 | high=255, 172 | shape=(self._height, self._width, num_colors), 173 | dtype=np.uint8, 174 | ) 175 | if self._key is None: 176 | original_space = self.observation_space 177 | self.observation_space = new_space 178 | else: 179 | original_space = self.observation_space.spaces[self._key] 180 | self.observation_space.spaces[self._key] = new_space 181 | assert original_space.dtype == np.uint8 and len(original_space.shape) == 3 182 | 183 | def observation(self, obs): 184 | if self._key is None: 185 | frame = obs 186 | else: 187 | frame = obs[self._key] 188 | 189 | if self._grayscale: 190 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 191 | frame = cv2.resize( 192 | frame, (self._width, self._height), interpolation=cv2.INTER_AREA 193 | ) 194 | if self._grayscale: 195 | frame = np.expand_dims(frame, -1) 196 | 197 | if self._key is None: 198 | obs = frame 199 | else: 200 | obs = obs.copy() 201 | obs[self._key] = frame 202 | return obs 203 | 204 | 205 | class FrameStack(gym.Wrapper): 206 | def __init__(self, env, k): 207 | """Stack k last frames. 208 | 209 | Returns lazy array, which is much more memory efficient. 
210 | 211 | See Also 212 | -------- 213 | baselines.common.atari_wrappers.LazyFrames 214 | """ 215 | gym.Wrapper.__init__(self, env) 216 | self.k = k 217 | self.frames = deque([], maxlen=k) 218 | shp = env.observation_space.shape 219 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) 220 | 221 | def reset(self): 222 | ob = self.env.reset() 223 | for _ in range(self.k): 224 | self.frames.append(ob / 255.0) 225 | return self._get_ob() 226 | 227 | def step(self, action): 228 | ob, reward, done, info = self.env.step(action) 229 | self.frames.append(ob / 255.0) 230 | return self._get_ob(), reward, done, info 231 | 232 | def _get_ob(self): 233 | assert len(self.frames) == self.k 234 | return LazyFrames(list(self.frames)) 235 | 236 | def clone_full_state(self): 237 | state_data = self.unwrapped.clone_full_state() 238 | frame_data = self.frames.copy() 239 | 240 | full_state_data = (state_data, frame_data) 241 | 242 | return full_state_data 243 | 244 | def restore_full_state(self, full_state_data): 245 | state_data, frame_data = full_state_data 246 | 247 | self.unwrapped.restore_full_state(state_data) 248 | self.frames = frame_data.copy() 249 | 250 | def get_state(self): 251 | return self._get_ob() 252 | 253 | 254 | class ScaledFloatFrame(gym.ObservationWrapper): 255 | def __init__(self, env): 256 | gym.ObservationWrapper.__init__(self, env) 257 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) 258 | 259 | def observation(self, observation): 260 | # careful! This undoes the memory optimization, use 261 | # with smaller replay buffers only. 262 | return np.array(observation).astype(np.float32) / 255.0 263 | 264 | class LazyFrames(object): 265 | def __init__(self, frames): 266 | """This object ensures that common frames between the observations are only stored once. 267 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 268 | buffers. 269 | 270 | This object should only be converted to numpy array before being passed to the model. 271 | 272 | You'd not believe how complex the previous solution was.""" 273 | self._frames = frames 274 | self._out = None 275 | 276 | def _force(self): 277 | if self._out is None: 278 | self._out = np.stack(self._frames, axis=0).reshape((-1, 84, 84)) 279 | self._frames = None 280 | return self._out 281 | 282 | def __array__(self, dtype=None): 283 | out = self._force() 284 | if dtype is not None: 285 | out = out.astype(dtype) 286 | return out 287 | 288 | def __len__(self): 289 | return len(self._force()) 290 | 291 | def __getitem__(self, i): 292 | return self._force()[i] 293 | 294 | def count(self): 295 | frames = self._force() 296 | return frames.shape[frames.ndim - 1] 297 | 298 | def frame(self, i): 299 | return self._force()[..., i] 300 | 301 | def make_atari(env_id, max_episode_steps=None): 302 | env = gym.make(env_id) 303 | assert 'NoFrameskip' in env.spec.id 304 | env = NoopResetEnv(env, noop_max=30) 305 | env = MaxAndSkipEnv(env, skip=4) 306 | if max_episode_steps is not None: 307 | env = TimeLimit(env, max_episode_steps=max_episode_steps) 308 | return env 309 | 310 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False): 311 | """Configure environment for DeepMind-style Atari. 
312 | """ 313 | if episode_life: 314 | env = EpisodicLifeEnv(env) 315 | if 'FIRE' in env.unwrapped.get_action_meanings(): 316 | env = FireResetEnv(env) 317 | env = WarpFrame(env) 318 | if scale: 319 | env = ScaledFloatFrame(env) 320 | if clip_rewards: 321 | env = ClipRewardEnv(env) 322 | if frame_stack: 323 | env = FrameStack(env, 4) 324 | return env 325 | 326 | -------------------------------------------------------------------------------- /Env/AtariEnv/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | import numpy as np 4 | from vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars 5 | 6 | 7 | def worker(remote, parent_remote, env_fn_wrapper): 8 | parent_remote.close() 9 | env = env_fn_wrapper.x() 10 | try: 11 | while True: 12 | cmd, data = remote.recv() 13 | if cmd == 'step': 14 | ob, reward, done, info = env.step(data) 15 | if done: 16 | ob = env.reset() 17 | remote.send((ob, reward, done, info)) 18 | elif cmd == 'reset': 19 | ob = env.reset() 20 | remote.send(ob) 21 | elif cmd == 'render': 22 | remote.send(env.render(mode='rgb_array')) 23 | elif cmd == 'close': 24 | remote.close() 25 | break 26 | elif cmd == 'get_spaces_spec': 27 | remote.send((env.observation_space, env.action_space, env.spec)) 28 | else: 29 | raise NotImplementedError 30 | except KeyboardInterrupt: 31 | print('SubprocVecEnv worker: got KeyboardInterrupt') 32 | finally: 33 | env.close() 34 | 35 | 36 | class SubprocVecEnv(VecEnv): 37 | """ 38 | VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes. 39 | Recommended to use when num_envs > 1 and step() can be a bottleneck. 40 | """ 41 | def __init__(self, env_fns, spaces=None, context='spawn'): 42 | """ 43 | Arguments: 44 | 45 | env_fns: iterable of callables - functions that create environments to run in subprocesses. 
Need to be cloud-pickleable 46 | """ 47 | self.waiting = False 48 | self.closed = False 49 | nenvs = len(env_fns) 50 | ctx = mp.get_context(context) 51 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)]) 52 | self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 53 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 54 | for p in self.ps: 55 | p.daemon = True # if the main process crashes, we should not cause things to hang 56 | with clear_mpi_env_vars(): 57 | p.start() 58 | for remote in self.work_remotes: 59 | remote.close() 60 | 61 | self.remotes[0].send(('get_spaces_spec', None)) 62 | observation_space, action_space, self.spec = self.remotes[0].recv() 63 | self.viewer = None 64 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 65 | 66 | def step_async(self, actions): 67 | self._assert_not_closed() 68 | for remote, action in zip(self.remotes, actions): 69 | remote.send(('step', action)) 70 | self.waiting = True 71 | 72 | def step_wait(self): 73 | self._assert_not_closed() 74 | results = [remote.recv() for remote in self.remotes] 75 | self.waiting = False 76 | obs, rews, dones, infos = zip(*results) 77 | return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos 78 | 79 | def reset(self): 80 | self._assert_not_closed() 81 | for remote in self.remotes: 82 | remote.send(('reset', None)) 83 | return _flatten_obs([remote.recv() for remote in self.remotes]) 84 | 85 | def close_extras(self): 86 | self.closed = True 87 | if self.waiting: 88 | for remote in self.remotes: 89 | remote.recv() 90 | for remote in self.remotes: 91 | remote.send(('close', None)) 92 | for p in self.ps: 93 | p.join() 94 | 95 | def get_images(self): 96 | self._assert_not_closed() 97 | for pipe in self.remotes: 98 | pipe.send(('render', None)) 99 | imgs = [pipe.recv() for pipe in self.remotes] 100 | return imgs 101 | 102 | def _assert_not_closed(self): 103 | assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()" 104 | 105 | def __del__(self): 106 | if not self.closed: 107 | self.close() 108 | 109 | def _flatten_obs(obs): 110 | assert isinstance(obs, (list, tuple)) 111 | assert len(obs) > 0 112 | 113 | if isinstance(obs[0], dict): 114 | keys = obs[0].keys() 115 | return {k: np.stack([o[k] for o in obs]) for k in keys} 116 | else: 117 | return np.stack(obs) 118 | -------------------------------------------------------------------------------- /Env/AtariEnv/vec_env.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | from abc import ABC, abstractmethod 4 | 5 | def tile_images(img_nhwc): 6 | """ 7 | Tile N images into one big PxQ image 8 | (P,Q) are chosen to be as close as possible, and if N 9 | is square, then P=Q. 
10 | 11 | input: img_nhwc, list or array of images, ndim=4 once turned into array 12 | n = batch index, h = height, w = width, c = channel 13 | returns: 14 | bigim_HWc, ndarray with ndim=3 15 | """ 16 | img_nhwc = np.asarray(img_nhwc) 17 | N, h, w, c = img_nhwc.shape 18 | H = int(np.ceil(np.sqrt(N))) 19 | W = int(np.ceil(float(N)/H)) 20 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 21 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 22 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 23 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 24 | return img_Hh_Ww_c 25 | 26 | class AlreadySteppingError(Exception): 27 | """ 28 | Raised when an asynchronous step is running while 29 | step_async() is called again. 30 | """ 31 | 32 | def __init__(self): 33 | msg = 'already running an async step' 34 | Exception.__init__(self, msg) 35 | 36 | 37 | class NotSteppingError(Exception): 38 | """ 39 | Raised when an asynchronous step is not running but 40 | step_wait() is called. 41 | """ 42 | 43 | def __init__(self): 44 | msg = 'not running an async step' 45 | Exception.__init__(self, msg) 46 | 47 | 48 | class VecEnv(ABC): 49 | """ 50 | An abstract asynchronous, vectorized environment. 51 | Used to batch data from multiple copies of an environment, so that 52 | each observation becomes an batch of observations, and expected action is a batch of actions to 53 | be applied per-environment. 54 | """ 55 | closed = False 56 | viewer = None 57 | 58 | metadata = { 59 | 'render.modes': ['human', 'rgb_array'] 60 | } 61 | 62 | def __init__(self, num_envs, observation_space, action_space): 63 | self.num_envs = num_envs 64 | self.observation_space = observation_space 65 | self.action_space = action_space 66 | 67 | @abstractmethod 68 | def reset(self): 69 | """ 70 | Reset all the environments and return an array of 71 | observations, or a dict of observation arrays. 72 | 73 | If step_async is still doing work, that work will 74 | be cancelled and step_wait() should not be called 75 | until step_async() is invoked again. 76 | """ 77 | pass 78 | 79 | @abstractmethod 80 | def step_async(self, actions): 81 | """ 82 | Tell all the environments to start taking a step 83 | with the given actions. 84 | Call step_wait() to get the results of the step. 85 | 86 | You should not call this if a step_async run is 87 | already pending. 88 | """ 89 | pass 90 | 91 | @abstractmethod 92 | def step_wait(self): 93 | """ 94 | Wait for the step taken with step_async(). 95 | 96 | Returns (obs, rews, dones, infos): 97 | - obs: an array of observations, or a dict of 98 | arrays of observations. 99 | - rews: an array of rewards 100 | - dones: an array of "episode done" booleans 101 | - infos: a sequence of info objects 102 | """ 103 | pass 104 | 105 | def close_extras(self): 106 | """ 107 | Clean up the extra resources, beyond what's in this base class. 108 | Only runs when not self.closed. 109 | """ 110 | pass 111 | 112 | def close(self): 113 | if self.closed: 114 | return 115 | if self.viewer is not None: 116 | self.viewer.close() 117 | self.close_extras() 118 | self.closed = True 119 | 120 | def step(self, actions): 121 | """ 122 | Step the environments synchronously. 123 | 124 | This is available for backwards compatibility. 
125 | """ 126 | self.step_async(actions) 127 | return self.step_wait() 128 | 129 | def render(self, mode='human'): 130 | imgs = self.get_images() 131 | bigimg = tile_images(imgs) 132 | if mode == 'human': 133 | self.get_viewer().imshow(bigimg) 134 | return self.get_viewer().isopen 135 | elif mode == 'rgb_array': 136 | return bigimg 137 | else: 138 | raise NotImplementedError 139 | 140 | def get_images(self): 141 | """ 142 | Return RGB images from each environment 143 | """ 144 | raise NotImplementedError 145 | 146 | @property 147 | def unwrapped(self): 148 | if isinstance(self, VecEnvWrapper): 149 | return self.venv.unwrapped 150 | else: 151 | return self 152 | 153 | def get_viewer(self): 154 | if self.viewer is None: 155 | from gym.envs.classic_control import rendering 156 | self.viewer = rendering.SimpleImageViewer() 157 | return self.viewer 158 | 159 | class VecEnvWrapper(VecEnv): 160 | """ 161 | An environment wrapper that applies to an entire batch 162 | of environments at once. 163 | """ 164 | 165 | def __init__(self, venv, observation_space=None, action_space=None): 166 | self.venv = venv 167 | super().__init__(num_envs=venv.num_envs, 168 | observation_space=observation_space or venv.observation_space, 169 | action_space=action_space or venv.action_space) 170 | 171 | def step_async(self, actions): 172 | self.venv.step_async(actions) 173 | 174 | @abstractmethod 175 | def reset(self): 176 | pass 177 | 178 | @abstractmethod 179 | def step_wait(self): 180 | pass 181 | 182 | def close(self): 183 | return self.venv.close() 184 | 185 | def render(self, mode='human'): 186 | return self.venv.render(mode=mode) 187 | 188 | def get_images(self): 189 | return self.venv.get_images() 190 | 191 | def __getattr__(self, name): 192 | if name.startswith('_'): 193 | raise AttributeError("attempted to get missing private attribute '{}'".format(name)) 194 | return getattr(self.venv, name) 195 | 196 | class VecEnvObservationWrapper(VecEnvWrapper): 197 | @abstractmethod 198 | def process(self, obs): 199 | pass 200 | 201 | def reset(self): 202 | obs = self.venv.reset() 203 | return self.process(obs) 204 | 205 | def step_wait(self): 206 | obs, rews, dones, infos = self.venv.step_wait() 207 | return self.process(obs), rews, dones, infos 208 | 209 | class CloudpickleWrapper(object): 210 | """ 211 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 212 | """ 213 | 214 | def __init__(self, x): 215 | self.x = x 216 | 217 | def __getstate__(self): 218 | import cloudpickle 219 | return cloudpickle.dumps(self.x) 220 | 221 | def __setstate__(self, ob): 222 | import pickle 223 | self.x = pickle.loads(ob) 224 | 225 | 226 | @contextlib.contextmanager 227 | def clear_mpi_env_vars(): 228 | """ 229 | from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. 230 | This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing 231 | Processes. 
232 | """ 233 | removed_environment = {} 234 | for k, v in list(os.environ.items()): 235 | for prefix in ['OMPI_', 'PMI_']: 236 | if k.startswith(prefix): 237 | removed_environment[k] = v 238 | del os.environ[k] 239 | try: 240 | yield 241 | finally: 242 | os.environ.update(removed_environment) 243 | -------------------------------------------------------------------------------- /Env/AtariEnv/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | from vec_env import VecEnvWrapper 2 | import numpy as np 3 | from gym import spaces 4 | 5 | 6 | class VecFrameStack(VecEnvWrapper): 7 | def __init__(self, venv, nstack): 8 | self.venv = venv 9 | self.nstack = nstack 10 | wos = venv.observation_space # wrapped ob space 11 | low = np.repeat(wos.low, self.nstack, axis=-1) 12 | high = np.repeat(wos.high, self.nstack, axis=-1) 13 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 14 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 15 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 16 | 17 | def step_wait(self): 18 | obs, rews, news, infos = self.venv.step_wait() 19 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1) 20 | for (i, new) in enumerate(news): 21 | if new: 22 | self.stackedobs[i] = 0 23 | self.stackedobs[..., -obs.shape[-1]:] = obs 24 | return self.stackedobs, rews, news, infos 25 | 26 | def reset(self): 27 | obs = self.venv.reset() 28 | self.stackedobs[...] = 0 29 | self.stackedobs[..., -obs.shape[-1]:] = obs 30 | return self.stackedobs 31 | -------------------------------------------------------------------------------- /Env/EnvWrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from copy import deepcopy 3 | 4 | from Env.AtariEnv.AtariEnvWrapper import make_atari_env 5 | 6 | 7 | # To allow easily extending to other tasks, we built a wrapper on top of the 'real' environment. 8 | class EnvWrapper(): 9 | def __init__(self, env_name, max_episode_length = 0, enable_record = False, record_path = "1.mp4"): 10 | self.env_name = env_name 11 | 12 | self.env_type = None 13 | 14 | try: 15 | self.env, self.recorder = make_atari_env(env_name, 0, 0, enable_record = enable_record, 16 | record_path = record_path) 17 | 18 | # Call reset to avoid gym bugs. 19 | self.env.reset() 20 | 21 | self.env_type = "Atari" 22 | except gym.error.Error: 23 | exit(1) 24 | 25 | assert isinstance(self.env.action_space, gym.spaces.Discrete), "Should be discrete action space." 
26 | self.action_n = self.env.action_space.n 27 | 28 | self.max_episode_length = self.env._max_episode_steps if max_episode_length == 0 else max_episode_length 29 | 30 | self.current_step_count = 0 31 | 32 | self.since_last_reset = 0 33 | 34 | def reset(self): 35 | state = self.env.reset() 36 | 37 | self.current_step_count = 0 38 | self.since_last_reset = 0 39 | 40 | return state 41 | 42 | def step(self, action): 43 | next_state, reward, done, _ = self.env.step(action) 44 | 45 | self.current_step_count += 1 46 | if self.current_step_count >= self.max_episode_length: 47 | done = True 48 | 49 | self.since_last_reset += 1 50 | 51 | return next_state, reward, done 52 | 53 | def checkpoint(self): 54 | return deepcopy(self.env.clone_full_state()), self.current_step_count 55 | 56 | def restore(self, checkpoint): 57 | if self.since_last_reset > 20000: 58 | self.reset() 59 | self.since_last_reset = 0 60 | 61 | self.env.restore_full_state(checkpoint[0]) 62 | 63 | self.current_step_count = checkpoint[1] 64 | 65 | return self.env.get_state() 66 | 67 | def render(self): 68 | self.env.render() 69 | 70 | def capture_frame(self): 71 | self.recorder.capture_frame() 72 | 73 | def store_video_files(self): 74 | self.recorder.write_metadata() 75 | 76 | def close(self): 77 | self.env.close() 78 | 79 | def seed(self, seed): 80 | self.env.seed(seed) 81 | 82 | def get_action_n(self): 83 | return self.action_n 84 | 85 | def get_max_episode_length(self): 86 | return self.max_episode_length 87 | -------------------------------------------------------------------------------- /Env/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Env/__init__.py -------------------------------------------------------------------------------- /Figures/Figure_atari_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_atari_results.png -------------------------------------------------------------------------------- /Figures/Figure_puct_conceptual_idea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_puct_conceptual_idea.png -------------------------------------------------------------------------------- /Figures/Figure_puct_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_puct_pipeline.png -------------------------------------------------------------------------------- /Figures/Figure_tap_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_tap_results.png -------------------------------------------------------------------------------- /Figures/Figure_time_consumption.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Figures/Figure_time_consumption.png -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Anji Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Logs/.gitkeep -------------------------------------------------------------------------------- /Mem/CheckpointManager.py: -------------------------------------------------------------------------------- 1 | # This is the centralized game-state storage. 2 | class CheckpointManager(): 3 | def __init__(self): 4 | self.buffer = dict() 5 | 6 | self.envs = dict() 7 | 8 | def hock_env(self, name, env): 9 | self.envs[name] = env 10 | 11 | def checkpoint_env(self, name, idx): 12 | self.store(idx, self.envs[name].checkpoint()) 13 | 14 | def load_checkpoint_env(self, name, idx): 15 | self.envs[name].restore(self.retrieve(idx)) 16 | 17 | def store(self, idx, checkpoint_data): 18 | assert idx not in self.buffer 19 | 20 | self.buffer[idx] = checkpoint_data 21 | 22 | def retrieve(self, idx): 23 | assert idx in self.buffer 24 | 25 | return self.buffer[idx] 26 | 27 | def length(self): 28 | return len(self.buffer) 29 | 30 | def clear(self): 31 | self.buffer.clear() 32 | -------------------------------------------------------------------------------- /Mem/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Mem/__init__.py -------------------------------------------------------------------------------- /Node/UCTnode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | import math 4 | import random 5 | 6 | from Utils.MovingAvegCalculator import MovingAvegCalculator 7 | 8 | 9 | class UCTnode(): 10 | def __init__(self, action_n, state, checkpoint_idx, parent, tree, 11 | prior_prob = None, is_head = False, allowed_actions = None): 12 | self.action_n = action_n 13 | self.state = state 14 | self.checkpoint_idx = checkpoint_idx 15 | self.parent = parent 16 | self.tree = tree 17 | self.is_head = is_head 18 | self.allowed_actions = allowed_actions 19 | 20 | if tree is not None: 21 | self.max_width = tree.max_width 22 | 
else: 23 | self.max_width = 0 24 | 25 | self.children = [None for _ in range(self.action_n)] 26 | self.rewards = [0.0 for _ in range(self.action_n)] 27 | self.dones = [False for _ in range(self.action_n)] 28 | self.children_visit_count = [0 for _ in range(self.action_n)] 29 | self.Q_values = [0 for _ in range(self.action_n)] 30 | self.visit_count = 0 31 | 32 | if prior_prob is not None: 33 | self.prior_prob = prior_prob 34 | else: 35 | self.prior_prob = np.ones([self.action_n], dtype = np.float32) / self.action_n 36 | 37 | # Record traverse history 38 | self.traverse_history = list() 39 | 40 | # Updated node count 41 | self.updated_node_count = 0 42 | 43 | # Moving average calculator 44 | self.moving_aveg_calculator = MovingAvegCalculator(window_length = 500) 45 | 46 | def no_child_available(self): 47 | # All child nodes have not been expanded. 48 | return self.updated_node_count == 0 49 | 50 | def all_child_visited(self): 51 | # All child nodes have been visited and updated. 52 | if self.is_head: 53 | if self.allowed_actions is None: 54 | return self.updated_node_count == self.action_n 55 | else: 56 | return self.updated_node_count == len(self.allowed_actions) 57 | else: 58 | return self.updated_node_count == self.max_width 59 | 60 | def select_action(self): 61 | best_score = -10000.0 62 | best_action = 0 63 | 64 | for action in range(self.action_n): 65 | if self.children[action] is None: 66 | continue 67 | 68 | if self.allowed_actions is not None and action not in self.allowed_actions: 69 | continue 70 | 71 | exploit_score = self.Q_values[action] / self.children_visit_count[action] 72 | explore_score = math.sqrt(1.0 * math.log(self.visit_count) / self.children_visit_count[action]) 73 | score_std = self.moving_aveg_calculator.get_standard_deviation() 74 | score = exploit_score + score_std * explore_score 75 | 76 | if score > best_score: 77 | best_score = score 78 | best_action = action 79 | 80 | return best_action 81 | 82 | def max_utility_action(self): 83 | best_score = -10000.0 84 | best_action = 0 85 | 86 | for action in range(self.action_n): 87 | if self.children[action] is None: 88 | continue 89 | 90 | score = self.Q_values[action] / self.children_visit_count[action] 91 | 92 | if score > best_score: 93 | best_score = score 94 | best_action = action 95 | 96 | return best_action 97 | 98 | def select_expand_action(self): 99 | count = 0 100 | 101 | while True: 102 | if self.allowed_actions is None: 103 | if count < 20: 104 | action = self.categorical(self.prior_prob) 105 | else: 106 | action = np.random.randint(0, self.action_n) 107 | else: 108 | action = random.choice(self.allowed_actions) 109 | 110 | if count > 100: 111 | return action 112 | 113 | if self.children_visit_count[action] > 0 and count < 10: 114 | count += 1 115 | continue 116 | 117 | if self.children[action] is None: 118 | return action 119 | 120 | count += 1 121 | 122 | def update_history(self, action_taken, reward): 123 | self.traverse_history = (action_taken, reward) 124 | 125 | def update(self, accu_reward): 126 | action_taken = self.traverse_history[0] 127 | reward = self.traverse_history[1] 128 | 129 | accu_reward = reward + self.tree.gamma * accu_reward 130 | 131 | if self.children_visit_count[action_taken] == 0: 132 | self.updated_node_count += 1 133 | 134 | self.children_visit_count[action_taken] += 1 135 | self.Q_values[action_taken] += accu_reward 136 | 137 | self.visit_count += 1 138 | 139 | self.moving_aveg_calculator.add_number(accu_reward) 140 | 141 | return accu_reward 142 | 143 | def add_child(self, 
action, child_state, checkpoint_idx, prior_prob): 144 | if self.children[action] is not None: 145 | node = self.children[action] 146 | else: 147 | node = UCTnode( 148 | action_n = self.action_n, 149 | state = child_state, 150 | checkpoint_idx = checkpoint_idx, 151 | parent = self, 152 | tree = self.tree, 153 | prior_prob = prior_prob 154 | ) 155 | 156 | self.children[action] = node 157 | 158 | return node 159 | 160 | @staticmethod 161 | def categorical(pvals): 162 | num = np.random.random() 163 | for i in range(pvals.size): 164 | if num < pvals[i]: 165 | return i 166 | else: 167 | num -= pvals[i] 168 | 169 | return pvals.size - 1 170 | -------------------------------------------------------------------------------- /Node/WU_UCTnode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | import math 4 | 5 | from Utils.MovingAvegCalculator import MovingAvegCalculator 6 | 7 | 8 | class WU_UCTnode(): 9 | def __init__(self, action_n, state, checkpoint_idx, parent, tree, 10 | prior_prob = None, is_head = False): 11 | self.action_n = action_n 12 | self.state = state 13 | self.checkpoint_idx = checkpoint_idx 14 | self.parent = parent 15 | self.tree = tree 16 | self.is_head = is_head 17 | 18 | if tree is not None: 19 | self.max_width = tree.max_width 20 | else: 21 | self.max_width = 0 22 | 23 | self.children = [None for _ in range(self.action_n)] 24 | self.rewards = [0.0 for _ in range(self.action_n)] 25 | self.dones = [False for _ in range(self.action_n)] 26 | self.children_visit_count = [0 for _ in range(self.action_n)] 27 | self.children_completed_visit_count = [0 for _ in range(self.action_n)] 28 | self.Q_values = [0 for _ in range(self.action_n)] 29 | self.visit_count = 0 30 | 31 | if prior_prob is not None: 32 | self.prior_prob = prior_prob 33 | else: 34 | self.prior_prob = np.ones([self.action_n], dtype=np.float32) / self.action_n 35 | 36 | # Record traverse history 37 | self.traverse_history = dict() 38 | 39 | # Visited node count 40 | self.visited_node_count = 0 41 | 42 | # Updated node count 43 | self.updated_node_count = 0 44 | 45 | # Moving average calculator 46 | self.moving_aveg_calculator = MovingAvegCalculator(window_length = 500) 47 | 48 | def no_child_available(self): 49 | # All child nodes have not been expanded. 50 | return self.updated_node_count == 0 51 | 52 | def all_child_visited(self): 53 | # All child nodes have been visited (not necessarily updated). 54 | if self.is_head: 55 | return self.visited_node_count == self.action_n 56 | else: 57 | return self.visited_node_count == self.max_width 58 | 59 | def all_child_updated(self): 60 | # All child nodes have been updated. 61 | if self.is_head: 62 | return self.updated_node_count == self.action_n 63 | else: 64 | return self.updated_node_count == self.max_width 65 | 66 | # Shallowly clone itself, contains necessary data only. 
67 | def shallow_clone(self): 68 | node = WU_UCTnode( 69 | action_n = self.action_n, 70 | state = deepcopy(self.state), 71 | checkpoint_idx = self.checkpoint_idx, 72 | parent = None, 73 | tree = None, 74 | prior_prob = None, 75 | is_head = False 76 | ) 77 | 78 | for action in range(self.action_n): 79 | if self.children[action] is not None: 80 | node.children[action] = 1 81 | 82 | node.children_visit_count = deepcopy(self.children_visit_count) 83 | node.children_completed_visit_count = deepcopy(self.children_completed_visit_count) 84 | 85 | node.visited_node_count = self.visited_node_count 86 | node.updated_node_count = self.updated_node_count 87 | 88 | node.action_n = self.action_n 89 | node.max_width = self.max_width 90 | 91 | node.prior_prob = self.prior_prob.copy() 92 | 93 | return node 94 | 95 | # Select action according to the P-UCT tree policy 96 | def select_action(self): 97 | best_score = -10000.0 98 | best_action = 0 99 | 100 | for action in range(self.action_n): 101 | if self.children[action] is None: 102 | continue 103 | 104 | exploit_score = self.Q_values[action] / self.children_completed_visit_count[action] 105 | explore_score = math.sqrt(2.0 * math.log(self.visit_count) / self.children_visit_count[action]) 106 | score_std = self.moving_aveg_calculator.get_standard_deviation() 107 | score = exploit_score + score_std * 2.0 * explore_score 108 | 109 | if score > best_score: 110 | best_score = score 111 | best_action = action 112 | 113 | return best_action 114 | 115 | # Return the action with maximum utility. 116 | def max_utility_action(self): 117 | best_score = -10000.0 118 | best_action = 0 119 | 120 | for action in range(self.action_n): 121 | if self.children[action] is None: 122 | continue 123 | 124 | score = self.Q_values[action] / self.children_completed_visit_count[action] 125 | 126 | if score > best_score: 127 | best_score = score 128 | best_action = action 129 | 130 | return best_action 131 | 132 | # Choose an action to expand 133 | def select_expand_action(self): 134 | count = 0 135 | 136 | while True: 137 | if count < 20: 138 | action = self.categorical(self.prior_prob) 139 | else: 140 | action = np.random.randint(0, self.action_n) 141 | 142 | if count > 100: 143 | return action 144 | 145 | if self.children_visit_count[action] > 0 and count < 10: 146 | count += 1 147 | continue 148 | 149 | if self.children[action] is None: 150 | return action 151 | 152 | count += 1 153 | 154 | # Update traverse history, used to perform update 155 | def update_history(self, idx, action_taken, reward): 156 | if idx in self.traverse_history: 157 | return False 158 | else: 159 | self.traverse_history[idx] = (action_taken, reward) 160 | return True 161 | 162 | # Incomplete update, called by WU_UCT.py 163 | def update_incomplete(self, idx): 164 | action_taken = self.traverse_history[idx][0] 165 | 166 | if self.children_visit_count[action_taken] == 0: 167 | self.visited_node_count += 1 168 | 169 | self.children_visit_count[action_taken] += 1 170 | self.visit_count += 1 171 | 172 | # Complete update, called by WU_UCT.py 173 | def update_complete(self, idx, accu_reward): 174 | if idx not in self.traverse_history: 175 | raise RuntimeError("idx {} should be in traverse_history".format(idx)) 176 | else: 177 | item = self.traverse_history.pop(idx) 178 | action_taken = item[0] 179 | reward = item[1] 180 | 181 | accu_reward = reward + self.tree.gamma * accu_reward 182 | 183 | if self.children_completed_visit_count[action_taken] == 0: 184 | self.updated_node_count += 1 185 | 186 | 
self.children_completed_visit_count[action_taken] += 1 187 | self.Q_values[action_taken] += accu_reward 188 | 189 | self.moving_aveg_calculator.add_number(accu_reward) 190 | 191 | return accu_reward 192 | 193 | # Add a child to current node. 194 | def add_child(self, action, child_state, checkpoint_idx, prior_prob = None): 195 | if self.children[action] is not None: 196 | node = self.children[action] 197 | else: 198 | node = WU_UCTnode( 199 | action_n = self.action_n, 200 | state = child_state, 201 | checkpoint_idx = checkpoint_idx, 202 | parent = self, 203 | tree = self.tree, 204 | prior_prob = prior_prob 205 | ) 206 | 207 | self.children[action] = node 208 | 209 | return node 210 | 211 | # Draw a sample from the categorical distribution parametrized by 'pvals'. 212 | @staticmethod 213 | def categorical(pvals): 214 | num = np.random.random() 215 | for i in range(pvals.size): 216 | if num < pvals[i]: 217 | return i 218 | else: 219 | num -= pvals[i] 220 | 221 | return pvals.size - 1 222 | -------------------------------------------------------------------------------- /Node/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Node/__init__.py -------------------------------------------------------------------------------- /OutLogs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/OutLogs/.gitkeep -------------------------------------------------------------------------------- /OutLogs/WU-UCT_PongNoFrameskip-v0_123_2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/OutLogs/WU-UCT_PongNoFrameskip-v0_123_2.mat -------------------------------------------------------------------------------- /ParallelPool/PoolManager.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pipe 2 | from copy import deepcopy 3 | import os 4 | import torch 5 | 6 | from ParallelPool.Worker import Worker 7 | 8 | 9 | # Works in the main process and manages sub-workers 10 | class PoolManager(): 11 | def __init__(self, worker_num, env_params, policy = "Random", 12 | gamma = 1.0, seed = 123, device = "cpu", need_policy = True): 13 | self.worker_num = worker_num 14 | self.env_params = env_params 15 | self.policy = policy 16 | self.gamma = gamma 17 | self.seed = seed 18 | self.need_policy = need_policy 19 | 20 | # Buffer for workers and pipes 21 | self.workers = [] 22 | self.pipes = [] 23 | 24 | # CUDA device parallelization 25 | # if multiple cuda devices exist, use them all 26 | if torch.cuda.is_available(): 27 | torch_device_num = torch.cuda.device_count() 28 | else: 29 | torch_device_num = 0 30 | 31 | # Initialize workers 32 | for worker_idx in range(worker_num): 33 | parent_pipe, child_pipe = Pipe() 34 | self.pipes.append(parent_pipe) 35 | 36 | worker = Worker( 37 | pipe = child_pipe, 38 | env_params = deepcopy(env_params), 39 | policy = policy, 40 | gamma = gamma, 41 | seed = seed + worker_idx, 42 | device = device + ":" + str(int(torch_device_num * worker_idx / worker_num)) 43 | if device == "cuda" else device, 44 | need_policy = need_policy 45 | ) 46 | self.workers.append(worker) 47 | 48 | # Start workers 49 | for worker in self.workers: 50 | worker.start() 51 | 52 | # 
Worker status: 0 for idle, 1 for busy 53 | self.worker_status = [0 for _ in range(worker_num)] 54 | 55 | def has_idle_server(self): 56 | for status in self.worker_status: 57 | if status == 0: 58 | return True 59 | 60 | return False 61 | 62 | def server_occupied_rate(self): 63 | occupied_count = 0.0 64 | 65 | for status in self.worker_status: 66 | occupied_count += status 67 | 68 | return occupied_count / self.worker_num 69 | 70 | def find_idle_worker(self): 71 | for idx, status in enumerate(self.worker_status): 72 | if status == 0: 73 | self.worker_status[idx] = 1 74 | return idx 75 | 76 | return None 77 | 78 | def assign_expansion_task(self, checkpoint_data, curr_node, 79 | saving_idx, task_simulation_idx): 80 | worker_idx = self.find_idle_worker() 81 | 82 | self.send_safe_protocol(worker_idx, "Expansion", ( 83 | checkpoint_data, 84 | curr_node, 85 | saving_idx, 86 | task_simulation_idx 87 | )) 88 | 89 | self.worker_status[worker_idx] = 1 90 | 91 | def assign_simulation_task(self, task_idx, checkpoint_data, first_action = None): 92 | worker_idx = self.find_idle_worker() 93 | 94 | self.send_safe_protocol(worker_idx, "Simulation", ( 95 | task_idx, 96 | checkpoint_data, 97 | first_action 98 | )) 99 | 100 | self.worker_status[worker_idx] = 1 101 | 102 | def get_complete_expansion_task(self): 103 | flag = False 104 | selected_worker_idx = -1 105 | 106 | while not flag: 107 | for worker_idx in range(self.worker_num): 108 | item = self.receive_safe_protocol_tapcheck(worker_idx) 109 | 110 | if item is not None: 111 | flag = True 112 | selected_worker_idx = worker_idx 113 | break 114 | 115 | command, args = item 116 | assert command == "ReturnExpansion" 117 | 118 | # Set to idle 119 | self.worker_status[selected_worker_idx] = 0 120 | 121 | return args 122 | 123 | def get_complete_simulation_task(self): 124 | flag = False 125 | selected_worker_idx = -1 126 | 127 | while not flag: 128 | for worker_idx in range(self.worker_num): 129 | item = self.receive_safe_protocol_tapcheck(worker_idx) 130 | 131 | if item is not None: 132 | flag = True 133 | selected_worker_idx = worker_idx 134 | break 135 | 136 | command, args = item 137 | assert command == "ReturnSimulation" 138 | 139 | # Set to idle 140 | self.worker_status[selected_worker_idx] = 0 141 | 142 | return args 143 | 144 | def send_safe_protocol(self, worker_idx, command, args): 145 | success = False 146 | 147 | while not success: 148 | self.pipes[worker_idx].send((command, args)) 149 | 150 | ret = self.pipes[worker_idx].recv() 151 | if ret == command: 152 | success = True 153 | 154 | def wait_until_all_envs_idle(self): 155 | for worker_idx in range(self.worker_num): 156 | if self.worker_status[worker_idx] == 0: 157 | continue 158 | 159 | self.receive_safe_protocol(worker_idx) 160 | 161 | self.worker_status[worker_idx] = 0 162 | 163 | def receive_safe_protocol(self, worker_idx): 164 | self.pipes[worker_idx].poll(None) 165 | 166 | command, args = self.pipes[worker_idx].recv() 167 | 168 | self.pipes[worker_idx].send(command) 169 | 170 | return deepcopy(command), deepcopy(args) 171 | 172 | def receive_safe_protocol_tapcheck(self, worker_idx): 173 | flag = self.pipes[worker_idx].poll() 174 | if not flag: 175 | return None 176 | 177 | command, args = self.pipes[worker_idx].recv() 178 | 179 | self.pipes[worker_idx].send(command) 180 | 181 | return deepcopy(command), deepcopy(args) 182 | 183 | def close_pool(self): 184 | for worker_idx in range(self.worker_num): 185 | self.send_safe_protocol(worker_idx, "KillProc", None) 186 | 187 | for worker in 
self.workers: 188 | worker.join() 189 | -------------------------------------------------------------------------------- /ParallelPool/Worker.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process 2 | from copy import deepcopy 3 | import random 4 | import numpy as np 5 | 6 | from Env.EnvWrapper import EnvWrapper 7 | 8 | from Policy.PPO.PPOPolicy import PPOAtariCNN, PPOSmallAtariCNN 9 | 10 | from Policy.PolicyWrapper import PolicyWrapper 11 | 12 | 13 | # Slave workers 14 | class Worker(Process): 15 | def __init__(self, pipe, env_params, policy = "Random", gamma = 1.0, seed = 123, 16 | device = "cpu", need_policy = True): 17 | super(Worker, self).__init__() 18 | 19 | self.pipe = pipe 20 | self.env_params = deepcopy(env_params) 21 | self.gamma = gamma 22 | self.seed = seed 23 | self.policy = deepcopy(policy) 24 | self.device = deepcopy(device) 25 | self.need_policy = need_policy 26 | 27 | self.wrapped_env = None 28 | self.action_n = None 29 | self.max_episode_length = None 30 | 31 | self.policy_wrapper = None 32 | 33 | # Initialize the environment 34 | def init_process(self): 35 | self.wrapped_env = EnvWrapper(**self.env_params) 36 | 37 | self.wrapped_env.seed(self.seed) 38 | 39 | self.action_n = self.wrapped_env.get_action_n() 40 | self.max_episode_length = self.wrapped_env.get_max_episode_length() 41 | 42 | # Initialize the default policy 43 | def init_policy(self): 44 | self.policy_wrapper = PolicyWrapper( 45 | self.policy, 46 | self.env_params["env_name"], 47 | self.action_n, 48 | self.device 49 | ) 50 | 51 | def run(self): 52 | self.init_process() 53 | self.init_policy() 54 | 55 | print("> Worker ready.") 56 | 57 | while True: 58 | # Wait for tasks 59 | command, args = self.receive_safe_protocol() 60 | 61 | if command == "KillProc": 62 | return 63 | elif command == "Expansion": 64 | checkpoint_data, curr_node, saving_idx, task_idx = args 65 | 66 | # Select expand action, and do expansion 67 | expand_action, next_state, reward, done, \ 68 | checkpoint_data = self.expand_node(checkpoint_data, curr_node) 69 | 70 | item = (expand_action, next_state, reward, done, checkpoint_data, 71 | saving_idx, task_idx) 72 | 73 | self.send_safe_protocol("ReturnExpansion", item) 74 | elif command == "Simulation": 75 | if args is None: 76 | raise RuntimeError 77 | else: 78 | task_idx, checkpoint_data, first_action = args 79 | 80 | state = self.wrapped_env.restore(checkpoint_data) 81 | 82 | # Prior probability is calculated for the new node 83 | prior_prob = self.get_prior_prob(state) 84 | 85 | # When simulation invoked because of reaching maximum search depth, 86 | # an action was actually selected. Therefore, we need to execute it 87 | # first anyway. 
88 | if first_action is not None: 89 | state, reward, done = self.wrapped_env.step(first_action) 90 | 91 | if first_action is not None and done: 92 | accu_reward = reward 93 | else: 94 | # Simulate until termination condition satisfied 95 | accu_reward = self.simulate(state) 96 | 97 | if first_action is not None: 98 | self.send_safe_protocol("ReturnSimulation", (task_idx, accu_reward, reward, done)) 99 | else: 100 | self.send_safe_protocol("ReturnSimulation", (task_idx, accu_reward, prior_prob)) 101 | 102 | def expand_node(self, checkpoint_data, curr_node): 103 | self.wrapped_env.restore(checkpoint_data) 104 | 105 | # Choose action to expand, according to the shallow copy node 106 | expand_action = curr_node.select_expand_action() 107 | 108 | # Execute the action, and observe new state, etc. 109 | next_state, reward, done = self.wrapped_env.step(expand_action) 110 | 111 | if not done: 112 | checkpoint_data = self.wrapped_env.checkpoint() 113 | else: 114 | checkpoint_data = None 115 | 116 | return expand_action, next_state, reward, done, checkpoint_data 117 | 118 | def simulate(self, state, max_simulation_step = 100, lamda = 0.5): 119 | step_count = 0 120 | accu_reward = 0.0 121 | accu_gamma = 1.0 122 | 123 | start_state_value = self.get_value(state) 124 | 125 | done = False 126 | # A strict upper bound for simulation count 127 | while not done and step_count < max_simulation_step: 128 | action = self.get_action(state) 129 | 130 | next_state, reward, done = self.wrapped_env.step(action) 131 | 132 | accu_reward += reward * accu_gamma 133 | accu_gamma *= self.gamma 134 | 135 | state = deepcopy(next_state) 136 | 137 | step_count += 1 138 | 139 | if not done: 140 | accu_reward += self.get_value(state) * accu_gamma 141 | 142 | # Use V(s) to stabilize simulation return 143 | accu_reward = accu_reward * lamda + start_state_value * (1.0 - lamda) 144 | 145 | return accu_reward 146 | 147 | def get_action(self, state): 148 | return self.policy_wrapper.get_action(state) 149 | 150 | def get_value(self, state): 151 | return self.policy_wrapper.get_value(state) 152 | 153 | def get_prior_prob(self, state): 154 | return self.policy_wrapper.get_prior_prob(state) 155 | 156 | # Send message through pipe 157 | def send_safe_protocol(self, command, args): 158 | success = False 159 | 160 | count = 0 161 | while not success: 162 | self.pipe.send((command, args)) 163 | 164 | ret = self.pipe.recv() 165 | if ret == command or count >= 10: 166 | success = True 167 | 168 | count += 1 169 | 170 | # Receive message from pipe 171 | def receive_safe_protocol(self): 172 | self.pipe.poll(None) 173 | 174 | command, args = self.pipe.recv() 175 | 176 | self.pipe.send(command) 177 | 178 | return deepcopy(command), deepcopy(args) 179 | -------------------------------------------------------------------------------- /ParallelPool/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/ParallelPool/__init__.py -------------------------------------------------------------------------------- /Policy/PPO/PPOPolicy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | import time 8 | import os 9 | 10 | from Env.AtariEnv.atari_wrappers import LazyFrames 11 | 12 | 13 | def ortho_weights(shape, scale=1.): 14 
| """ PyTorch port of ortho_init from baselines.a2c.utils """ 15 | shape = tuple(shape) 16 | 17 | if len(shape) == 2: 18 | flat_shape = shape[1], shape[0] 19 | elif len(shape) == 4: 20 | flat_shape = (np.prod(shape[1:]), shape[0]) 21 | else: 22 | raise NotImplementedError 23 | 24 | a = np.random.normal(0., 1., flat_shape) 25 | u, _, v = np.linalg.svd(a, full_matrices=False) 26 | q = u if u.shape == flat_shape else v 27 | q = q.transpose().copy().reshape(shape) 28 | 29 | if len(shape) == 2: 30 | return torch.from_numpy((scale * q).astype(np.float32)) 31 | if len(shape) == 4: 32 | return torch.from_numpy((scale * q[:, :shape[1], :shape[2]]).astype(np.float32)) 33 | 34 | 35 | def atari_initializer(module): 36 | """ Parameter initializer for Atari models 37 | 38 | Initializes Linear, Conv2d, and LSTM weights. 39 | """ 40 | classname = module.__class__.__name__ 41 | 42 | if classname == 'Linear': 43 | module.weight.data = ortho_weights(module.weight.data.size(), scale=np.sqrt(2.)) 44 | module.bias.data.zero_() 45 | 46 | elif classname == 'Conv2d': 47 | module.weight.data = ortho_weights(module.weight.data.size(), scale=np.sqrt(2.)) 48 | module.bias.data.zero_() 49 | 50 | elif classname == 'LSTM': 51 | for name, param in module.named_parameters(): 52 | if 'weight_ih' in name: 53 | param.data = ortho_weights(param.data.size(), scale=1.) 54 | if 'weight_hh' in name: 55 | param.data = ortho_weights(param.data.size(), scale=1.) 56 | if 'bias' in name: 57 | param.data.zero_() 58 | 59 | 60 | class PPOAtariCNN(): 61 | def __init__(self, num_actions, device = "cpu", checkpoint_dir = ""): 62 | self.num_actions = num_actions 63 | self.device = torch.device(device) 64 | self.checkpoint_dir = checkpoint_dir 65 | 66 | self.model = AtariCNN(num_actions, self.device) 67 | 68 | if checkpoint_dir != "" and os.path.exists(checkpoint_dir): 69 | checkpoint = torch.load(checkpoint_dir, map_location = "cpu") 70 | self.model.load_state_dict(checkpoint["policy"]) 71 | 72 | self.model.to(device) 73 | 74 | def get_action(self, state, logit = False): 75 | return self.model.get_action(state, logit = logit) 76 | 77 | def get_value(self, state): 78 | return self.model.get_value(state) 79 | 80 | 81 | class PPOSmallAtariCNN(): 82 | def __init__(self, num_actions, device = "cpu", checkpoint_dir = ""): 83 | self.num_actions = num_actions 84 | self.device = torch.device(device) 85 | self.checkpoint_dir = checkpoint_dir 86 | 87 | self.model = SmallPolicyAtariCNN(num_actions, self.device) 88 | 89 | if checkpoint_dir != "" and os.path.exists(checkpoint_dir): 90 | checkpoint = torch.load(checkpoint_dir, map_location = "cpu") 91 | # self.model.load_state_dict(checkpoint["policy"]) 92 | 93 | self.model.to(device) 94 | 95 | self.optimizer = optim.Adam(self.model.parameters(), lr = 1e-3) 96 | 97 | self.mseLoss = nn.MSELoss() 98 | 99 | def get_action(self, state): 100 | return self.model.get_action(state) 101 | 102 | def get_value(self, state): 103 | # raise RuntimeError("Small policy net does not support value evaluation.") 104 | return self.model.get_value(state) 105 | 106 | def train_step(self, state_batch, policy_batch, value_batch, temperature = 2.5): 107 | self.optimizer.zero_grad() 108 | 109 | # policy_batch = F.softmax(policy_batch / temperature, dim = 1) 110 | 111 | out_policy, out_value = self.model(state_batch) 112 | # out_policy = F.softmax(out_policy / temperature, dim = 1) 113 | 114 | # loss = -(policy_batch * torch.log(out_policy + 1e-8)).sum(dim = 1).mean() 115 | loss = self.mseLoss(policy_batch, out_policy) + 
self.mseLoss(value_batch, out_value) 116 | loss.backward() 117 | 118 | self.optimizer.step() 119 | 120 | return loss.detach().cpu().numpy() 121 | 122 | def save(self, path): 123 | torch.save({"policy": self.model.state_dict()}, path) 124 | 125 | 126 | class AtariCNN(nn.Module): 127 | def __init__(self, num_actions, device): 128 | """ Basic convolutional actor-critic network for Atari 2600 games 129 | 130 | Equivalent to the network in the original DQN paper. 131 | 132 | Args: 133 | num_actions (int): the number of available discrete actions 134 | """ 135 | super().__init__() 136 | 137 | self.conv = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4), 138 | nn.ReLU(inplace=True), 139 | nn.Conv2d(32, 64, 4, stride=2), 140 | nn.ReLU(inplace=True), 141 | nn.Conv2d(64, 64, 3, stride=1), 142 | nn.ReLU(inplace=True)) 143 | 144 | self.fc = nn.Sequential(nn.Linear(64 * 7 * 7, 512), 145 | nn.ReLU(inplace=True)) 146 | 147 | self.pi = nn.Linear(512, num_actions) 148 | self.v = nn.Linear(512, 1) 149 | 150 | self.num_actions = num_actions 151 | 152 | self.device = device 153 | 154 | def forward(self, conv_in): 155 | """ Module forward pass 156 | 157 | Args: 158 | conv_in (Variable): convolutional input, shaped [N x 4 x 84 x 84] 159 | 160 | Returns: 161 | pi (Variable): action probability logits, shaped [N x self.num_actions] 162 | v (Variable): value predictions, shaped [N x 1] 163 | """ 164 | N = conv_in.size()[0] 165 | 166 | conv_out = self.conv(conv_in).view(N, 64 * 7 * 7) 167 | 168 | fc_out = self.fc(conv_out) 169 | 170 | pi_out = self.pi(fc_out) 171 | v_out = self.v(fc_out) 172 | 173 | return pi_out, v_out 174 | 175 | def get_action(self, conv_in, logit = False): 176 | if isinstance(conv_in, LazyFrames): 177 | conv_in = torch.from_numpy(np.array(conv_in)).type(torch.float32).to(self.device).unsqueeze(0) 178 | elif isinstance(conv_in, np.ndarray): 179 | conv_in = torch.from_numpy(conv_in).to(self.device) 180 | 181 | N = conv_in.size(0) 182 | s = time.time() 183 | conv_out = self.conv(conv_in).view(N, 64 * 7 * 7) 184 | aa = time.time() 185 | fc_out = self.fc(conv_out) 186 | 187 | if logit: 188 | pi_out = self.pi(fc_out) 189 | else: 190 | pi_out = F.softmax(self.pi(fc_out), dim = 1) 191 | 192 | if N == 1: 193 | pi_out = pi_out.view(-1) 194 | e = time.time() 195 | # print("large", e - s, aa - s) 196 | return pi_out.detach().cpu().numpy() 197 | 198 | def get_value(self, conv_in): 199 | if isinstance(conv_in, LazyFrames): 200 | conv_in = torch.from_numpy(np.array(conv_in)).type(torch.float32).to(self.device).unsqueeze(0) 201 | else: 202 | raise NotImplementedError() 203 | 204 | N = conv_in.size(0) 205 | 206 | conv_out = self.conv(conv_in).view(N, 64 * 7 * 7) 207 | 208 | fc_out = self.fc(conv_out) 209 | 210 | v_out = self.v(fc_out) 211 | 212 | if N == 1: 213 | v_out = v_out.sum() 214 | 215 | return v_out.detach().cpu().numpy() 216 | 217 | 218 | class SmallPolicyAtariCNN(nn.Module): 219 | def __init__(self, num_actions, device): 220 | """ Basic convolutional actor-critic network for Atari 2600 games 221 | 222 | Equivalent to the network in the original DQN paper. 
223 | 224 | Args: 225 | num_actions (int): the number of available discrete actions 226 | """ 227 | super().__init__() 228 | 229 | self.conv = nn.Sequential(nn.Conv2d(4, 16, 3, stride = 4), 230 | nn.ReLU(inplace = True), 231 | nn.Conv2d(16, 16, 3, stride = 4), 232 | nn.ReLU(inplace = True)) 233 | 234 | self.fc = nn.Sequential(nn.Linear(16 * 5 * 5, 64), 235 | nn.ReLU(inplace = True)) 236 | 237 | self.pi = nn.Linear(64, num_actions) 238 | self.v = nn.Linear(64, 1) 239 | 240 | self.num_actions = num_actions 241 | 242 | self.device = device 243 | 244 | def forward(self, conv_in): 245 | """ Module forward pass 246 | 247 | Args: 248 | conv_in (Variable): convolutional input, shaped [N x 4 x 84 x 84] 249 | 250 | Returns: 251 | pi (Variable): action probability logits, shaped [N x self.num_actions] 252 | v (Variable): value predictions, shaped [N x 1] 253 | """ 254 | N = conv_in.size()[0] 255 | 256 | conv_out = self.conv(conv_in).view(N, 16 * 5 * 5) 257 | 258 | fc_out = self.fc(conv_out) 259 | 260 | pi_out = self.pi(fc_out) 261 | v_out = self.v(fc_out) 262 | 263 | return pi_out, v_out 264 | 265 | def get_action(self, conv_in, logit = False, get_tensor = False): 266 | if isinstance(conv_in, LazyFrames): 267 | conv_in = torch.from_numpy(np.array(conv_in)).type(torch.float32).to(self.device).unsqueeze(0) 268 | elif isinstance(conv_in, np.ndarray): 269 | conv_in = torch.from_numpy(conv_in).to(self.device) 270 | 271 | N = conv_in.size(0) 272 | s = time.time() 273 | # with torch.no_grad(): 274 | conv_out = self.conv(conv_in).view(N, 16 * 5 * 5) 275 | aa = time.time() 276 | fc_out = self.fc(conv_out) 277 | bb = time.time() 278 | if logit: 279 | pi_out = self.pi(fc_out) 280 | else: 281 | pi_out = F.softmax(self.pi(fc_out), dim = 1) 282 | cc = time.time() 283 | if N == 1: 284 | pi_out = pi_out.view(-1) 285 | e = time.time() 286 | print("small", e - s, aa - s, bb - aa, cc - bb) 287 | if get_tensor: 288 | return pi_out 289 | else: 290 | return pi_out.detach().cpu().numpy() 291 | 292 | def get_value(self, conv_in): 293 | if isinstance(conv_in, LazyFrames): 294 | conv_in = torch.from_numpy(np.array(conv_in)).type(torch.float32).to(self.device).unsqueeze(0) 295 | else: 296 | raise NotImplementedError() 297 | 298 | N = conv_in.size(0) 299 | 300 | conv_out = self.conv(conv_in).view(N, 16 * 5 * 5) 301 | 302 | fc_out = self.fc(conv_out) 303 | 304 | v_out = self.v(fc_out) 305 | # print("a") 306 | if N == 1: 307 | v_out = v_out.sum() 308 | 309 | return v_out.detach().cpu().numpy() 310 | -------------------------------------------------------------------------------- /Policy/PPO/PolicyFiles/PPO_AlienNoFrameskip-v0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Policy/PPO/PolicyFiles/PPO_AlienNoFrameskip-v0.pt -------------------------------------------------------------------------------- /Policy/PolicyWrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import os 4 | 5 | from Policy.PPO.PPOPolicy import PPOAtariCNN, PPOSmallAtariCNN 6 | 7 | 8 | class PolicyWrapper(): 9 | def __init__(self, policy_name, env_name, action_n, device): 10 | self.policy_name = policy_name 11 | self.env_name = env_name 12 | self.action_n = action_n 13 | self.device = device 14 | 15 | self.policy_func = None 16 | 17 | self.init_policy() 18 | 19 | def init_policy(self): 20 | if self.policy_name == "Random": 21 | self.policy_func = 
None 22 | 23 | elif self.policy_name == "PPO": 24 | assert os.path.exists("./Policy/PPO/PolicyFiles/PPO_" + self.env_name + ".pt"), "Policy file not found" 25 | 26 | self.policy_func = PPOAtariCNN( 27 | self.action_n, 28 | device = self.device, 29 | checkpoint_dir = "./Policy/PPO/PolicyFiles/PPO_" + self.env_name + ".pt" 30 | ) 31 | 32 | elif self.policy_name == "DistillPPO": 33 | assert os.path.exists("./Policy/PPO/PolicyFiles/PPO_" + self.env_name + ".pt"), "Policy file not found" 34 | assert os.path.exists("./Policy/PPO/PolicyFiles/SmallPPO_" + self.env_name + ".pt"), "Policy file not found" 35 | 36 | full_policy = PPOAtariCNN( 37 | self.action_n, 38 | device = "cpu", # To save memory 39 | checkpoint_dir = "./Policy/PPO/PolicyFiles/PPO_" + self.env_name + ".pt" 40 | ) 41 | 42 | small_policy = PPOSmallAtariCNN( 43 | self.action_n, 44 | device = self.device, 45 | checkpoint_dir = "./Policy/PPO/PolicyFiles/SmallPPO_" + self.env_name + ".pt" 46 | ) 47 | 48 | self.policy_func = [full_policy, small_policy] 49 | else: 50 | raise NotImplementedError() 51 | 52 | def get_action(self, state): 53 | if self.policy_name == "Random": 54 | return random.randint(0, self.action_n - 1) 55 | elif self.policy_name == "PPO": 56 | return self.categorical(self.policy_func.get_action(state)) 57 | elif self.policy_name == "DistillPPO": 58 | return self.categorical(self.policy_func[1].get_action(state)) 59 | else: 60 | raise NotImplementedError() 61 | 62 | def get_value(self, state): 63 | if self.policy_name == "Random": 64 | return 0.0 65 | elif self.policy_name == "PPO": 66 | return self.policy_func.get_value(state) 67 | elif self.policy_name == "DistillPPO": 68 | return self.policy_func[0].get_value(state) 69 | else: 70 | raise NotImplementedError() 71 | 72 | def get_prior_prob(self, state): 73 | if self.policy_name == "Random": 74 | return np.ones([self.action_n], dtype = np.float32) / self.action_n 75 | elif self.policy_name == "PPO": 76 | return self.policy_func.get_action(state) 77 | elif self.policy_name == "DistillPPO": 78 | return self.policy_func[0].get_action(state) 79 | else: 80 | raise NotImplementedError() 81 | 82 | @staticmethod 83 | def categorical(probs): 84 | val = random.random() 85 | chosen_idx = 0 86 | 87 | for prob in probs: 88 | val -= prob 89 | 90 | if val < 0.0: 91 | break 92 | 93 | chosen_idx += 1 94 | 95 | if chosen_idx >= len(probs): 96 | chosen_idx = len(probs) - 1 97 | 98 | return chosen_idx 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WU-UCT (Watch the Unobserved in UCT) 2 | A novel parallel UCT algorithm with linear speedup and negligible performance loss. This package provides a demo on Atari games (see [Running](#Running)). To allow easy extension to other environments, we wrote the code in an extendable way, and modification on only two files are needed for other environments (see [Run on your own environments](#Run-on-your-own-environments)). 3 | 4 | This work has been accepted by **ICLR 2020** for **oral full presentation (48/2594)**. 5 | 6 | # A quick demo! 7 | 8 | This [Google Colaboratory link](https://colab.research.google.com/drive/140Ea6pd8abvg_HdZDVsCVfdeLICSa-f0) contains a demo on the PongNoFrameskip-v0 environment. We thank [@lvisdd](https://github.com/lvisdd) for building this useful demo! 
9 | 10 | # Introduction 11 | Note: For full details of WU-UCT, please refer to our [arXiv](https://arxiv.org/abs/1810.11755) or [OpenReview](https://openreview.net/forum?id=BJlQtJSKDB) paper. 12 | 13 | ## Conceptual idea 14 | 15 | *(Figure omitted: conceptual comparison of sequential UCT (a), an idealized parallel UCT (b), naively parallelized UCT (c), and WU-UCT (d); the panels are described below.)* 16 |
17 | 18 | We use the above figure to demonstrate the main problem caused by parallelizing UCT. (a) illustrates the four main steps of UCT: selection, expansion, simulation, and backpropagation. (b) is a demonstration of the ideal (but unrealistic) parallel algorithm, where the return V (cumulative reward) is available as soon as a simulation starts (in real-world cases it is observable only after the simulation completes). (c) if we parallelize UCT naively, problems such as *collapse of exploration* or *exploitation failure* arise: since fewer statistics are available at the selection step, the algorithm cannot choose the "best" node to query. (d) we propose to keep track of the ongoing but not-yet-terminated simulations (called unobserved samples) to correct and compensate for the outdated statistics. This enables a principled selection step in the parallel setting, allowing WU-UCT to achieve linear speedup with negligible performance loss. 19 | 20 | *(Figure omitted: speedup and performance of WU-UCT with up to 16 workers.)* 21 | 22 |
23 | 24 | WU-UCT achieves ideal speedup with up to 16 workers, without performance degradation. 25 | 26 |
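In code, the correction described in panel (d) amounts to keeping two visit counters per child: one that already includes ongoing (unobserved) simulations and one that only counts completed ones. The snippet below is a condensed, illustrative restatement of the tree policy in `Node/WU_UCTnode.py` (`select_action`); the attribute names follow that file, but the standalone function and its guard are ours, so treat it as a sketch rather than a drop-in replacement.

```python
import math

def wu_uct_select(node):
    # Pick the child with the best UCB-style score. Exploitation averages
    # only *completed* returns (children_completed_visit_count), while the
    # exploration bonus uses counts that already include ongoing, not-yet-
    # returned simulations (children_visit_count) -- the "unobserved samples".
    best_score, best_action = -float("inf"), 0
    for action in range(node.action_n):
        if node.children[action] is None:
            continue
        if node.children_completed_visit_count[action] == 0:
            continue  # guarded by the caller in the real implementation
        exploit = node.Q_values[action] / node.children_completed_visit_count[action]
        explore = math.sqrt(2.0 * math.log(node.visit_count) /
                            node.children_visit_count[action])
        # The bonus is scaled by a running estimate of the return's standard
        # deviation (MovingAvegCalculator), as in the original code.
        score = exploit + 2.0 * node.moving_aveg_calculator.get_standard_deviation() * explore
        if score > best_score:
            best_score, best_action = score, action
    return best_action
```

Because in-flight simulations already count toward the exploration denominator, parallel workers are discouraged from piling onto the same branch, which is what preserves the exploration behaviour of sequential UCT.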

27 | 28 |

29 | 30 | WU-UCT also shows a clear advantage over baseline parallel approaches in terms of both speed and accuracy. 31 | 32 | ## System implementation 33 |

34 | 35 |

36 | 37 | Our implementation of the system consists of a master process and two sets of slave workers, i.e., expansion workers and simulation workers. With a clear division of labor, we parallelize the most time-consuming expansion and simulation steps while maintaining the sequential structure of the selection and backpropagation steps. 38 | 39 |
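To make the division of labor concrete, one planning iteration can be sketched as below. This is an illustrative condensation of `Tree/WU_UCT.py` (`simulate_single_step`) and `ParallelPool/PoolManager.py`: the helper `traverse_to_leaf` and the simplified control flow are ours, and the real loop additionally interleaves task scheduling, idle-worker checks, checkpoint storage, and handling of terminal states.

```python
def planning_iteration(tree, expansion_pool, simulation_pool, task_idx):
    # 1. Selection runs sequentially in the master process, using the
    #    (incomplete-update-corrected) statistics stored in the tree.
    leaf = traverse_to_leaf(tree.root_node)  # hypothetical helper

    # 2. Expansion is shipped to an expansion worker together with a
    #    game-state checkpoint and a shallow clone of the node.
    expansion_pool.assign_expansion_task(
        tree.checkpoint_data_manager.retrieve(leaf.checkpoint_idx),
        leaf.shallow_clone(), tree.global_saving_idx, task_idx)
    expand_action, next_state, reward, done, ckpt, saving_idx, task_idx = \
        expansion_pool.get_complete_expansion_task()
    # (reward/done bookkeeping omitted here)

    # 3. Simulation of the new child is shipped to a simulation worker;
    #    an *incomplete* update immediately records the ongoing rollout
    #    along the selected path.
    simulation_pool.assign_simulation_task(task_idx, ckpt)
    tree.incomplete_update(leaf, tree.root_node, task_idx)

    # 4. Backpropagation (the *complete* update) runs sequentially in the
    #    master process once the rollout return arrives.
    task_idx, accu_reward, prior_prob = simulation_pool.get_complete_simulation_task()
    leaf.add_child(expand_action, next_state, saving_idx, prior_prob = prior_prob)
    tree.complete_update(leaf, tree.root_node, accu_reward, task_idx)
```

Selection and backpropagation touch shared tree statistics and therefore stay in the master process, while the expensive environment stepping (expansion) and rollouts (simulation) are farmed out to the worker pools.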

40 | 41 |

42 | 43 | The breakdown of time consumption (tested with 16 expansion and simulation workers) indicates that we successfully parallelize the most time-consuming expansion and simulation steps while keeping the time consumption of the other steps relatively small. 44 | 45 | # Usage 46 | ## Prerequisites 47 | - Python 3.x 48 | - PyTorch 1.0 49 | - Gym (with atari) 0.14.0 50 | - Numpy 1.17.2 51 | - Scipy 1.3.1 52 | - OpenCV-Python 4.1.1.26 53 | 54 | ## Running 55 | 1. Download or clone the repository. 56 | 2. Run with the default settings: 57 | ``` 58 | python3 main.py --model WU-UCT 59 | ``` 60 | 3. For additional hyperparameters, please have a look at [main.py](https://github.com/liuanji/WU-UCT/tree/master/main.py) (they are also listed below), where descriptions are also included. For example, if you want to run the game PongNoFrameskip-v0 with 200 MCTS rollouts, simply run: 61 | ``` 62 | python3 main.py --model WU-UCT --env-name PongNoFrameskip-v0 --MCTS-max-steps 200 63 | ``` 64 | or if you want to record a video of the gameplay, run: 65 | ``` 66 | python3 main.py --model WU-UCT --env-name PongNoFrameskip-v0 --record-video 67 | ``` 68 | 69 | * A full list of parameters 70 | * --model: MCTS model to use (currently supports WU-UCT and UCT). 71 | * --env-name: name of the environment. 72 | * --MCTS-max-steps: number of simulation steps in the planning phase. 73 | * --MCTS-max-depth: maximum planning depth. 74 | * --MCTS-max-width: maximum width for each node. 75 | * --gamma: environment discount factor. 76 | * --expansion-worker-num: number of expansion workers. 77 | * --simulation-worker-num: number of simulation workers. 78 | * --seed: random seed for the environment. 79 | * --max-episode-length: a strict upper bound on the environment's episode length. 80 | * --policy: default policy (see above). 81 | * --device: supports "cpu", "cuda:x", and "cuda". If "cuda" is given, all available CUDA devices will be used. Usually used to load the policy. 82 | * --record-video: see above. 83 | * --mode: MCTS or Distill, see [Planning with prior policy](#Planning-with-prior-policy). 84 | 85 | ### Planning with prior policy 86 | The code currently supports three default policies (the policy used to perform simulations): *Random*, *PPO*, and *DistillPPO* (to use them, change the “--policy” parameter). To use the *PPO* and *DistillPPO* policies, the corresponding policy files need to be put in [./Policy/PPO/PolicyFiles](https://github.com/liuanji/WU-UCT/tree/master/Policy/PPO/PolicyFiles). PPO policy files can be generated by [Atari_PPO_training](https://github.com/liuanji/WU-UCT/tree/master/Utils/Atari_PPO_training) (or see [GitHub](https://github.com/lnpalmer/PPO)). For example, by running 87 | ``` 88 | cd Utils/Atari_PPO_training 89 | python3 main.py PongNoFrameskip-v0 90 | ``` 91 | a policy file will be generated in [./Utils/Atari_PPO_training/save](https://github.com/liuanji/WU-UCT/tree/master/Utils/Atari_PPO_training/save). To run DistillPPO, we first have to run the distillation training process with 92 | ``` 93 | python3 main.py --mode Distill --env-name PongNoFrameskip-v0 94 | ``` 95 | 96 | ## Run on your own environments 97 | We kindly provide an [environment wrapper](https://github.com/liuanji/WU-UCT/tree/master/Env/EnvWrapper.py) and a [policy wrapper](https://github.com/liuanji/WU-UCT/tree/master/Policy/PolicyWrapper.py) to make extensions to other environments easy. 
All you need to do is modify [./Env/EnvWrapper.py](https://github.com/liuanji/WU-UCT/tree/master/Env/EnvWrapper.py) and [./Policy/PolicyWrapper.py](https://github.com/liuanji/WU-UCT/tree/master/Policy/PolicyWrapper.py) to fit in your own environment. Please follow the instructions below. 98 | 99 | 1. Edit the class EnvWrapper in [./Env/EnvWrapper.py](https://github.com/liuanji/WU-UCT/tree/master/Env/EnvWrapper.py). 100 | 101 | Nest your environment into the wrapper by providing specific functionality in each of the member functions of EnvWrapper. There are currently four input arguments to EnvWrapper: *env_name*, *max_episode_length*, *enable_record*, and *record_path*. If additional information needs to be passed in, you may first consider encoding it in *env_name*. 102 | 103 | 2. Edit the class PolicyWrapper in [./Policy/PolicyWrapper.py](https://github.com/liuanji/WU-UCT/tree/master/Policy/PolicyWrapper.py). 104 | 105 | Similarly, nest your default policy in PolicyWrapper, and select it with the --policy argument. You will need to rewrite the three member functions *get_action*, *get_value*, and *get_prior_prob*. 106 | 107 | # Updates and to-dos 108 | ## Past updates 109 | 1. Refactor prior policy module to support easy reuse (Sep. 26, 2019). 110 | 111 | ## To-do list 112 | (empty) 113 | 114 | # Reference 115 | Please cite the paper in the following format if you use this code in your research :) 116 | ``` 117 | @inproceedings{liu2020watch, 118 | title = {Watch the Unobserved: A Simple Approach to Parallelizing Monte Carlo Tree Search}, 119 | author = {Anji Liu and Jianshu Chen and Mingze Yu and Yu Zhai and Xuewen Zhou and Ji Liu}, 120 | booktitle = {International Conference on Learning Representations}, 121 | month = apr, 122 | year = {2020}, 123 | url = "https://openreview.net/forum?id=BJlQtJSKDB" 124 | } 125 | ``` 126 | -------------------------------------------------------------------------------- /Records/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Records/.gitkeep -------------------------------------------------------------------------------- /Results/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Results/.gitkeep -------------------------------------------------------------------------------- /Tree/UCT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | from multiprocessing import Process 4 | import gc 5 | import time 6 | import random 7 | import os 8 | import torch 9 | 10 | from Node.UCTnode import UCTnode 11 | 12 | from Env.EnvWrapper import EnvWrapper 13 | 14 | from Mem.CheckpointManager import CheckpointManager 15 | 16 | from Policy.PolicyWrapper import PolicyWrapper 17 | 18 | 19 | class UCT(): 20 | def __init__(self, env_params, max_steps = 1000, max_depth = 20, max_width = 5, 21 | gamma = 1.0, policy = "Random", seed = 123, device = torch.device("cpu")): 22 | self.env_params = env_params 23 | self.max_steps = max_steps 24 | self.max_depth = max_depth 25 | self.max_width = max_width 26 | self.gamma = gamma 27 | self.policy = policy 28 | self.seed = seed 29 | self.device = device 30 | 31 | self.policy_wrapper = None 32 | 33 | # Environment 34 | self.wrapped_env = EnvWrapper(**env_params) 35 | 36 | # Environment 
properties 37 | self.action_n = self.wrapped_env.get_action_n() 38 | self.max_width = min(self.action_n, self.max_width) 39 | 40 | assert self.max_depth > 0 and 0 < self.max_width <= self.action_n 41 | 42 | # Checkpoint data manager 43 | self.checkpoint_data_manager = CheckpointManager() 44 | self.checkpoint_data_manager.hock_env("main", self.wrapped_env) 45 | 46 | # For MCTS tree 47 | self.root_node = None 48 | self.global_saving_idx = 0 49 | 50 | self.init_policy() 51 | 52 | def init_policy(self): 53 | self.policy_wrapper = PolicyWrapper( 54 | self.policy, 55 | self.env_params["env_name"], 56 | self.action_n, 57 | self.device 58 | ) 59 | 60 | # Entrance of the P-UCT algorithm 61 | def simulate_trajectory(self, max_episode_length = -1): 62 | state = self.wrapped_env.reset() 63 | accu_reward = 0.0 64 | done = False 65 | step_count = 0 66 | rewards = [] 67 | times = [] 68 | 69 | game_start_time = time.time() 70 | 71 | while not done and (max_episode_length == -1 or step_count < max_episode_length): 72 | simulation_start_time = time.time() 73 | action = self.simulate_single_move(state) 74 | simulation_end_time = time.time() 75 | 76 | next_state, reward, done = self.wrapped_env.step(action) 77 | rewards.append(reward) 78 | times.append(simulation_end_time - simulation_start_time) 79 | 80 | print("> Time step {}, take action {}, instance reward {}, cumulative reward {}, used {} seconds".format( 81 | step_count, action, reward, accu_reward + reward, simulation_end_time - simulation_start_time)) 82 | 83 | accu_reward += reward 84 | state = next_state 85 | step_count += 1 86 | 87 | game_end_time = time.time() 88 | print("> game ended. total reward: {}, used time {} s".format(accu_reward, game_end_time - game_start_time)) 89 | 90 | return accu_reward, np.array(rewards, dtype = np.float32), np.array(times, dtype = np.float32) 91 | 92 | def simulate_single_move(self, state): 93 | # Clear cache 94 | self.root_node = None 95 | self.global_saving_idx = 0 96 | self.checkpoint_data_manager.clear() 97 | 98 | gc.collect() 99 | 100 | # Construct root node 101 | self.checkpoint_data_manager.checkpoint_env("main", self.global_saving_idx) 102 | 103 | self.root_node = UCTnode( 104 | action_n = self.action_n, 105 | state = state, 106 | checkpoint_idx = self.global_saving_idx, 107 | parent = None, 108 | tree = self, 109 | is_head = True 110 | ) 111 | 112 | self.global_saving_idx += 1 113 | 114 | for _ in range(self.max_steps): 115 | self.simulate_single_step() 116 | 117 | best_action = self.root_node.max_utility_action() 118 | 119 | self.checkpoint_data_manager.load_checkpoint_env("main", self.root_node.checkpoint_idx) 120 | 121 | return best_action 122 | 123 | def simulate_single_step(self): 124 | # Go into root node 125 | curr_node = self.root_node 126 | 127 | # Selection 128 | curr_depth = 1 129 | while True: 130 | if curr_node.no_child_available() or (not curr_node.all_child_visited() and 131 | curr_node != self.root_node and np.random.random() < 0.5) or \ 132 | (not curr_node.all_child_visited() and curr_node == self.root_node): 133 | # If no child node has been updated, we have to perform expansion anyway. 134 | # Or if root node is not fully visited. 135 | # Or if non-root node is not fully visited and {with prob 1/2}. 
136 | 137 | need_expansion = True 138 | break 139 | 140 | else: 141 | action = curr_node.select_action() 142 | 143 | curr_node.update_history(action, curr_node.rewards[action]) 144 | 145 | if curr_node.dones[action] or curr_depth >= self.max_depth: 146 | need_expansion = False 147 | break 148 | 149 | next_node = curr_node.children[action] 150 | 151 | curr_depth += 1 152 | curr_node = next_node 153 | 154 | # Expansion 155 | if need_expansion: 156 | expand_action = curr_node.select_expand_action() 157 | 158 | self.checkpoint_data_manager.load_checkpoint_env("main", curr_node.checkpoint_idx) 159 | next_state, reward, done = self.wrapped_env.step(expand_action) 160 | self.checkpoint_data_manager.checkpoint_env("main", self.global_saving_idx) 161 | 162 | curr_node.rewards[expand_action] = reward 163 | curr_node.dones[expand_action] = done 164 | 165 | curr_node.update_history( 166 | action_taken = expand_action, 167 | reward = reward 168 | ) 169 | 170 | curr_node.add_child( 171 | expand_action, 172 | next_state, 173 | self.global_saving_idx, 174 | prior_prob = self.get_prior_prob(next_state) 175 | ) 176 | self.global_saving_idx += 1 177 | else: 178 | self.checkpoint_data_manager.load_checkpoint_env("main", curr_node.checkpoint_idx) 179 | next_state, reward, done = self.wrapped_env.step(action) 180 | 181 | curr_node.rewards[action] = reward 182 | curr_node.dones[action] = done 183 | 184 | # Simulation 185 | done = False 186 | accu_reward = 0.0 187 | accu_gamma = 1.0 188 | 189 | while not done: 190 | action = self.get_action(next_state) 191 | 192 | next_state, reward, done = self.wrapped_env.step(action) 193 | 194 | accu_reward += reward * accu_gamma 195 | 196 | accu_gamma *= self.gamma 197 | 198 | # Complete Update 199 | self.complete_update(curr_node, self.root_node, accu_reward) 200 | 201 | def get_action(self, state): 202 | return self.policy_wrapper.get_action(state) 203 | 204 | def get_prior_prob(self, state): 205 | return self.policy_wrapper.get_prior_prob(state) 206 | 207 | def close(self): 208 | pass 209 | 210 | @staticmethod 211 | def complete_update(curr_node, curr_node_head, accu_reward): 212 | while curr_node != curr_node_head: 213 | accu_reward = curr_node.update(accu_reward) 214 | curr_node = curr_node.parent 215 | 216 | curr_node_head.update(accu_reward) 217 | -------------------------------------------------------------------------------- /Tree/WU_UCT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | import gc 4 | import time 5 | import logging 6 | 7 | from Node.WU_UCTnode import WU_UCTnode 8 | 9 | from Env.EnvWrapper import EnvWrapper 10 | 11 | from ParallelPool.PoolManager import PoolManager 12 | 13 | from Mem.CheckpointManager import CheckpointManager 14 | 15 | 16 | class WU_UCT(): 17 | def __init__(self, env_params, max_steps = 1000, max_depth = 20, max_width = 5, 18 | gamma = 1.0, expansion_worker_num = 16, simulation_worker_num = 16, 19 | policy = "Random", seed = 123, device = "cpu", record_video = False): 20 | self.env_params = env_params 21 | self.max_steps = max_steps 22 | self.max_depth = max_depth 23 | self.max_width = max_width 24 | self.gamma = gamma 25 | self.expansion_worker_num = expansion_worker_num 26 | self.simulation_worker_num = simulation_worker_num 27 | self.policy = policy 28 | self.device = device 29 | self.record_video = record_video 30 | 31 | # Environment 32 | record_path = "Records/P-UCT_" + env_params["env_name"] + ".mp4" 33 | self.wrapped_env = 
EnvWrapper(**env_params, enable_record = record_video, 34 | record_path = record_path) 35 | 36 | # Environment properties 37 | self.action_n = self.wrapped_env.get_action_n() 38 | self.max_width = min(self.action_n, self.max_width) 39 | 40 | assert self.max_depth > 0 and 0 < self.max_width <= self.action_n 41 | 42 | # Expansion worker pool 43 | self.expansion_worker_pool = PoolManager( 44 | worker_num = expansion_worker_num, 45 | env_params = env_params, 46 | policy = policy, 47 | gamma = gamma, 48 | seed = seed, 49 | device = device, 50 | need_policy = False 51 | ) 52 | 53 | # Simulation worker pool 54 | self.simulation_worker_pool = PoolManager( 55 | worker_num = simulation_worker_num, 56 | env_params = env_params, 57 | policy = policy, 58 | gamma = gamma, 59 | seed = seed, 60 | device = device, 61 | need_policy = True 62 | ) 63 | 64 | # Checkpoint data manager 65 | self.checkpoint_data_manager = CheckpointManager() 66 | self.checkpoint_data_manager.hock_env("main", self.wrapped_env) 67 | 68 | # For MCTS tree 69 | self.root_node = None 70 | self.global_saving_idx = 0 71 | 72 | # Task recorder 73 | self.expansion_task_recorder = dict() 74 | self.unscheduled_expansion_tasks = list() 75 | self.simulation_task_recorder = dict() 76 | self.unscheduled_simulation_tasks = list() 77 | 78 | # Simulation count 79 | self.simulation_count = 0 80 | 81 | # Logging 82 | logging.basicConfig(filename = "Logs/P-UCT_" + self.env_params["env_name"] + "_" + 83 | str(self.simulation_worker_num) + ".log", level = logging.INFO) 84 | 85 | # Entrance of the P-UCT algorithm 86 | # This is the outer loop of P-UCT simulation, where the P-UCT agent consecutively plan a best action and 87 | # interact with the environment. 88 | def simulate_trajectory(self, max_episode_length = -1): 89 | state = self.wrapped_env.reset() 90 | accu_reward = 0.0 91 | done = False 92 | step_count = 0 93 | rewards = [] 94 | times = [] 95 | 96 | game_start_time = time.clock() 97 | 98 | logging.info("Start simulation") 99 | 100 | while not done and (max_episode_length == -1 or step_count < max_episode_length): 101 | # Plan a best action under the current state 102 | simulation_start_time = time.clock() 103 | action = self.simulate_single_move(state) 104 | simulation_end_time = time.clock() 105 | 106 | # Interact with the environment 107 | next_state, reward, done = self.wrapped_env.step(action) 108 | rewards.append(reward) 109 | times.append(simulation_end_time - simulation_start_time) 110 | 111 | print("> Time step {}, take action {}, instance reward {}, cumulative reward {}, used {} seconds".format( 112 | step_count, action, reward, accu_reward + reward, simulation_end_time - simulation_start_time)) 113 | logging.info("> Time step {}, take action {}, instance reward {}, cumulative reward {}, used {} seconds".format( 114 | step_count, action, reward, accu_reward + reward, simulation_end_time - simulation_start_time)) 115 | 116 | # Record video 117 | if self.record_video: 118 | self.wrapped_env.capture_frame() 119 | self.wrapped_env.store_video_files() 120 | 121 | # update game status 122 | accu_reward += reward 123 | state = next_state 124 | step_count += 1 125 | 126 | game_end_time = time.clock() 127 | print("> game ended. total reward: {}, used time {} s".format(accu_reward, game_end_time - game_start_time)) 128 | logging.info("> game ended. 
total reward: {}, used time {} s".format(accu_reward, 129 | game_end_time - game_start_time)) 130 | 131 | return accu_reward, np.array(rewards, dtype = np.float32), np.array(times, dtype = np.float32) 132 | 133 | # This is the planning process of P-UCT. Starts from a tree with a root node only, 134 | # P-UCT performs selection, expansion, simulation, and backpropagation on it. 135 | def simulate_single_move(self, state): 136 | # Clear cache 137 | self.root_node = None 138 | self.global_saving_idx = 0 139 | self.checkpoint_data_manager.clear() 140 | 141 | # Clear recorders 142 | self.expansion_task_recorder.clear() 143 | self.unscheduled_expansion_tasks.clear() 144 | self.simulation_task_recorder.clear() 145 | self.unscheduled_simulation_tasks.clear() 146 | 147 | gc.collect() 148 | 149 | # Free all workers 150 | self.expansion_worker_pool.wait_until_all_envs_idle() 151 | self.simulation_worker_pool.wait_until_all_envs_idle() 152 | 153 | # Construct root node 154 | self.checkpoint_data_manager.checkpoint_env("main", self.global_saving_idx) 155 | self.root_node = WU_UCTnode( 156 | action_n = self.action_n, 157 | state = state, 158 | checkpoint_idx = self.global_saving_idx, 159 | parent = None, 160 | tree = self, 161 | is_head = True 162 | ) 163 | 164 | # An index used to retrieve game-states 165 | self.global_saving_idx += 1 166 | 167 | # t_complete in the origin paper, measures the completed number of simulations 168 | self.simulation_count = 0 169 | 170 | # Repeatedly invoke the master loop (Figure 2 of the paper) 171 | sim_idx = 0 172 | while self.simulation_count < self.max_steps: 173 | self.simulate_single_step(sim_idx) 174 | 175 | sim_idx += 1 176 | 177 | # Select the best root action 178 | best_action = self.root_node.max_utility_action() 179 | 180 | # Retrieve the game-state before simulation begins 181 | self.checkpoint_data_manager.load_checkpoint_env("main", self.root_node.checkpoint_idx) 182 | 183 | return best_action 184 | 185 | def simulate_single_step(self, sim_idx): 186 | # Go into root node 187 | curr_node = self.root_node 188 | 189 | # Selection 190 | curr_depth = 1 191 | while True: 192 | if curr_node.no_child_available() or (not curr_node.all_child_visited() and 193 | curr_node != self.root_node and np.random.random() < 0.5) or \ 194 | (not curr_node.all_child_visited() and curr_node == self.root_node): 195 | # If no child node has been updated, we have to perform expansion anyway. 196 | # Or if root node is not fully visited. 197 | # Or if non-root node is not fully visited and {with prob 1/2}. 198 | 199 | cloned_curr_node = curr_node.shallow_clone() 200 | checkpoint_data = self.checkpoint_data_manager.retrieve(curr_node.checkpoint_idx) 201 | 202 | # Record the task 203 | self.expansion_task_recorder[sim_idx] = (checkpoint_data, cloned_curr_node, curr_node) 204 | self.unscheduled_expansion_tasks.append(sim_idx) 205 | 206 | need_expansion = True 207 | break 208 | 209 | else: 210 | action = curr_node.select_action() 211 | 212 | curr_node.update_history(sim_idx, action, curr_node.rewards[action]) 213 | 214 | if curr_node.dones[action] or curr_depth >= self.max_depth: 215 | # Exceed maximum depth 216 | need_expansion = False 217 | break 218 | 219 | if curr_node.children[action] is None: 220 | need_expansion = False 221 | break 222 | 223 | next_node = curr_node.children[action] 224 | 225 | curr_depth += 1 226 | curr_node = next_node 227 | 228 | # Expansion 229 | if not need_expansion: 230 | if not curr_node.dones[action]: 231 | # Reach maximum depth but have not terminate. 
232 | # Record simulation task. 233 | 234 | self.simulation_task_recorder[sim_idx] = ( 235 | action, 236 | curr_node, 237 | curr_node.checkpoint_idx, 238 | None 239 | ) 240 | self.unscheduled_simulation_tasks.append(sim_idx) 241 | else: 242 | # Reach terminal node. 243 | # In this case, update directly. 244 | 245 | self.incomplete_update(curr_node, self.root_node, sim_idx) 246 | self.complete_update(curr_node, self.root_node, 0.0, sim_idx) 247 | 248 | self.simulation_count += 1 249 | 250 | else: 251 | # Assign tasks to idle server 252 | while len(self.unscheduled_expansion_tasks) > 0 and self.expansion_worker_pool.has_idle_server(): 253 | # Get a task 254 | curr_idx = np.random.randint(0, len(self.unscheduled_expansion_tasks)) 255 | task_idx = self.unscheduled_expansion_tasks.pop(curr_idx) 256 | 257 | # Assign the task to server 258 | checkpoint_data, cloned_curr_node, _ = self.expansion_task_recorder[task_idx] 259 | self.expansion_worker_pool.assign_expansion_task( 260 | checkpoint_data, 261 | cloned_curr_node, 262 | self.global_saving_idx, 263 | task_idx 264 | ) 265 | self.global_saving_idx += 1 266 | 267 | # Wait for an expansion task to complete 268 | if self.expansion_worker_pool.server_occupied_rate() >= 0.99: 269 | expand_action, next_state, reward, done, checkpoint_data, \ 270 | saving_idx, task_idx = self.expansion_worker_pool.get_complete_expansion_task() 271 | 272 | curr_node = self.expansion_task_recorder.pop(task_idx)[2] 273 | curr_node.update_history(task_idx, expand_action, reward) 274 | 275 | # Record info 276 | curr_node.dones[expand_action] = done 277 | curr_node.rewards[expand_action] = reward 278 | 279 | if done: 280 | # If this expansion result in a terminal node, perform update directly. 281 | # (simulation is not needed) 282 | 283 | self.incomplete_update(curr_node, self.root_node, task_idx) 284 | self.complete_update(curr_node, self.root_node, 0.0, task_idx) 285 | 286 | self.simulation_count += 1 287 | 288 | else: 289 | # Schedule the task to the simulation task buffer. 
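# (The checkpoint stored under saving_idx lets a simulation worker restore the
#  environment to the newly expanded state; next_state is deep-copied so the
#  corresponding child node can be added once the rollout result comes back.)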
290 | 291 | self.checkpoint_data_manager.store(saving_idx, checkpoint_data) 292 | 293 | self.simulation_task_recorder[task_idx] = ( 294 | expand_action, 295 | curr_node, 296 | saving_idx, 297 | deepcopy(next_state) 298 | ) 299 | self.unscheduled_simulation_tasks.append(task_idx) 300 | 301 | # Assign simulation tasks to idle environment server 302 | while len(self.unscheduled_simulation_tasks) > 0 and self.simulation_worker_pool.has_idle_server(): 303 | # Get a task 304 | idx = np.random.randint(0, len(self.unscheduled_simulation_tasks)) 305 | task_idx = self.unscheduled_simulation_tasks.pop(idx) 306 | 307 | checkpoint_data = self.checkpoint_data_manager.retrieve(self.simulation_task_recorder[task_idx][2]) 308 | 309 | first_aciton = None if self.simulation_task_recorder[task_idx][3] \ 310 | is not None else self.simulation_task_recorder[task_idx][0] 311 | 312 | # Assign the task to server 313 | self.simulation_worker_pool.assign_simulation_task( 314 | task_idx, 315 | checkpoint_data, 316 | first_action = first_aciton 317 | ) 318 | 319 | # Perform incomplete update 320 | self.incomplete_update( 321 | self.simulation_task_recorder[task_idx][1], # This is the corresponding node 322 | self.root_node, 323 | task_idx 324 | ) 325 | 326 | # Wait for a simulation task to complete 327 | if self.simulation_worker_pool.server_occupied_rate() >= 0.99: 328 | args = self.simulation_worker_pool.get_complete_simulation_task() 329 | if len(args) == 3: 330 | task_idx, accu_reward, prior_prob = args 331 | else: 332 | task_idx, accu_reward, reward, done = args 333 | expand_action, curr_node, saving_idx, next_state = self.simulation_task_recorder.pop(task_idx) 334 | 335 | if len(args) == 4: 336 | curr_node.rewards[expand_action] = reward 337 | curr_node.dones[expand_action] = done 338 | 339 | # Add node 340 | if next_state is not None: 341 | curr_node.add_child( 342 | expand_action, 343 | next_state, 344 | saving_idx, 345 | prior_prob = prior_prob 346 | ) 347 | 348 | # Complete Update 349 | self.complete_update(curr_node, self.root_node, accu_reward, task_idx) 350 | 351 | self.simulation_count += 1 352 | 353 | def close(self): 354 | # Free sub-processes 355 | self.expansion_worker_pool.close_pool() 356 | self.simulation_worker_pool.close_pool() 357 | 358 | # Incomplete update allows to track unobserved samples (Algorithm 2 in the paper) 359 | @staticmethod 360 | def incomplete_update(curr_node, curr_node_head, idx): 361 | while curr_node != curr_node_head: 362 | curr_node.update_incomplete(idx) 363 | curr_node = curr_node.parent 364 | 365 | curr_node_head.update_incomplete(idx) 366 | 367 | # Complete update tracks the observed samples (Algorithm 3 in the paper) 368 | @staticmethod 369 | def complete_update(curr_node, curr_node_head, accu_reward, idx): 370 | while curr_node != curr_node_head: 371 | accu_reward = curr_node.update_complete(idx, accu_reward) 372 | curr_node = curr_node.parent 373 | 374 | curr_node_head.update_complete(idx, accu_reward) 375 | -------------------------------------------------------------------------------- /Tree/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Tree/__init__.py -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Lukas Palmer 4 | 5 | Permission is 
hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/README.md: -------------------------------------------------------------------------------- 1 | # PPO 2 | PyTorch implementation of Proximal Policy Optimization 3 | 4 | ![live agents](assets/agents.gif) 5 | 6 | ## Usage 7 | 8 | Example command line usage: 9 | ```` 10 | python main.py BreakoutNoFrameskip-v0 --num-workers 8 --render 11 | ```` 12 | 13 | This will run PPO with 8 parallel training environments, which will be rendered on the screen. Run with `-h` for usage information. 14 | 15 | ## Performance 16 | 17 | Results are comparable to those of the original PPO paper. The horizontal axis here is labeled by environment steps, whereas the graphs in the paper label it with frames, with 4 frames per step. 18 | 19 | Training episode reward versus environment steps for `BreakoutNoFrameskip-v3`: 20 | 21 | ![Breakout training curve](assets/breakout_reward.png) 22 | 23 | ## References 24 | 25 | [Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347) 26 | 27 | [OpenAI Baselines](https://github.com/openai/baselines) 28 | 29 | This code uses some environment utilities such as `SubprocVecEnv` and `VecFrameStack` from OpenAI's Baselines. 30 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | os.environ.setdefault('PATH', '') 4 | from collections import deque 5 | import gym 6 | from gym import spaces 7 | import cv2 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | class TimeLimit(gym.Wrapper): 11 | def __init__(self, env, max_episode_steps=None): 12 | super(TimeLimit, self).__init__(env) 13 | self._max_episode_steps = max_episode_steps 14 | self._elapsed_steps = 0 15 | 16 | def step(self, ac): 17 | observation, reward, done, info = self.env.step(ac) 18 | self._elapsed_steps += 1 19 | if self._elapsed_steps >= self._max_episode_steps: 20 | done = True 21 | info['TimeLimit.truncated'] = True 22 | return observation, reward, done, info 23 | 24 | def reset(self, **kwargs): 25 | self._elapsed_steps = 0 26 | return self.env.reset(**kwargs) 27 | 28 | class NoopResetEnv(gym.Wrapper): 29 | def __init__(self, env, noop_max=30): 30 | """Sample initial states by taking random number of no-ops on reset. 
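(Typically applied as in make_atari below: NoopResetEnv(gym.make(env_id), noop_max=30).)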
31 | No-op is assumed to be action 0. 32 | """ 33 | gym.Wrapper.__init__(self, env) 34 | self.noop_max = noop_max 35 | self.override_num_noops = None 36 | self.noop_action = 0 37 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 38 | 39 | def reset(self, **kwargs): 40 | """ Do no-op action for a number of steps in [1, noop_max].""" 41 | self.env.reset(**kwargs) 42 | if self.override_num_noops is not None: 43 | noops = self.override_num_noops 44 | else: 45 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101 46 | assert noops > 0 47 | obs = None 48 | for _ in range(noops): 49 | obs, _, done, _ = self.env.step(self.noop_action) 50 | if done: 51 | obs = self.env.reset(**kwargs) 52 | return obs 53 | 54 | def step(self, ac): 55 | return self.env.step(ac) 56 | 57 | class FireResetEnv(gym.Wrapper): 58 | def __init__(self, env): 59 | """Take action on reset for environments that are fixed until firing.""" 60 | gym.Wrapper.__init__(self, env) 61 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 62 | assert len(env.unwrapped.get_action_meanings()) >= 3 63 | 64 | def reset(self, **kwargs): 65 | self.env.reset(**kwargs) 66 | obs, _, done, _ = self.env.step(1) 67 | if done: 68 | self.env.reset(**kwargs) 69 | obs, _, done, _ = self.env.step(2) 70 | if done: 71 | self.env.reset(**kwargs) 72 | return obs 73 | 74 | def step(self, ac): 75 | return self.env.step(ac) 76 | 77 | class EpisodicLifeEnv(gym.Wrapper): 78 | def __init__(self, env): 79 | """Make end-of-life == end-of-episode, but only reset on true game over. 80 | Done by DeepMind for the DQN and co. since it helps value estimation. 81 | """ 82 | gym.Wrapper.__init__(self, env) 83 | self.lives = 0 84 | self.was_real_done = True 85 | 86 | def step(self, action): 87 | obs, reward, done, info = self.env.step(action) 88 | self.was_real_done = done 89 | # check current lives, make loss of life terminal, 90 | # then update lives to handle bonus lives 91 | lives = self.env.unwrapped.ale.lives() 92 | if lives < self.lives and lives > 0: 93 | # for Qbert sometimes we stay in lives == 0 condition for a few frames 94 | # so it's important to keep lives > 0, so that we only reset once 95 | # the environment advertises done. 96 | done = True 97 | self.lives = lives 98 | return obs, reward, done, info 99 | 100 | def reset(self, **kwargs): 101 | """Reset only when lives are exhausted. 102 | This way all states are still reachable even though lives are episodic, 103 | and the learner need not know about any of this behind-the-scenes. 
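A full env reset is performed only when was_real_done is True; after a mere
life loss, a single no-op step is taken so the episode continues.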
104 | """ 105 | if self.was_real_done: 106 | obs = self.env.reset(**kwargs) 107 | else: 108 | # no-op step to advance from terminal/lost life state 109 | obs, _, _, _ = self.env.step(0) 110 | self.lives = self.env.unwrapped.ale.lives() 111 | return obs 112 | 113 | class MaxAndSkipEnv(gym.Wrapper): 114 | def __init__(self, env, skip=4): 115 | """Return only every `skip`-th frame""" 116 | gym.Wrapper.__init__(self, env) 117 | # most recent raw observations (for max pooling across time steps) 118 | self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) 119 | self._skip = skip 120 | 121 | def step(self, action): 122 | """Repeat action, sum reward, and max over last observations.""" 123 | total_reward = 0.0 124 | done = None 125 | for i in range(self._skip): 126 | obs, reward, done, info = self.env.step(action) 127 | if i == self._skip - 2: self._obs_buffer[0] = obs 128 | if i == self._skip - 1: self._obs_buffer[1] = obs 129 | total_reward += reward 130 | if done: 131 | break 132 | # Note that the observation on the done=True frame 133 | # doesn't matter 134 | max_frame = self._obs_buffer.max(axis=0) 135 | 136 | return max_frame, total_reward, done, info 137 | 138 | def reset(self, **kwargs): 139 | return self.env.reset(**kwargs) 140 | 141 | class ClipRewardEnv(gym.RewardWrapper): 142 | def __init__(self, env): 143 | gym.RewardWrapper.__init__(self, env) 144 | 145 | def reward(self, reward): 146 | """Bin reward to {+1, 0, -1} by its sign.""" 147 | return np.sign(reward) 148 | 149 | 150 | class WarpFrame(gym.ObservationWrapper): 151 | def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None): 152 | """ 153 | Warp frames to 84x84 as done in the Nature paper and later work. 154 | 155 | If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which 156 | observation should be warped. 157 | """ 158 | super().__init__(env) 159 | self._width = width 160 | self._height = height 161 | self._grayscale = grayscale 162 | self._key = dict_space_key 163 | if self._grayscale: 164 | num_colors = 1 165 | else: 166 | num_colors = 3 167 | 168 | new_space = gym.spaces.Box( 169 | low=0, 170 | high=255, 171 | shape=(self._height, self._width, num_colors), 172 | dtype=np.uint8, 173 | ) 174 | if self._key is None: 175 | original_space = self.observation_space 176 | self.observation_space = new_space 177 | else: 178 | original_space = self.observation_space.spaces[self._key] 179 | self.observation_space.spaces[self._key] = new_space 180 | assert original_space.dtype == np.uint8 and len(original_space.shape) == 3 181 | 182 | def observation(self, obs): 183 | if self._key is None: 184 | frame = obs 185 | else: 186 | frame = obs[self._key] 187 | 188 | if self._grayscale: 189 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 190 | frame = cv2.resize( 191 | frame, (self._width, self._height), interpolation=cv2.INTER_AREA 192 | ) 193 | if self._grayscale: 194 | frame = np.expand_dims(frame, -1) 195 | 196 | if self._key is None: 197 | obs = frame 198 | else: 199 | obs = obs.copy() 200 | obs[self._key] = frame 201 | return obs 202 | 203 | 204 | class FrameStack(gym.Wrapper): 205 | def __init__(self, env, k): 206 | """Stack k last frames. 207 | 208 | Returns lazy array, which is much more memory efficient. 
209 | 210 | See Also 211 | -------- 212 | baselines.common.atari_wrappers.LazyFrames 213 | """ 214 | gym.Wrapper.__init__(self, env) 215 | self.k = k 216 | self.frames = deque([], maxlen=k) 217 | shp = env.observation_space.shape 218 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype) 219 | 220 | def reset(self): 221 | ob = self.env.reset() 222 | for _ in range(self.k): 223 | self.frames.append(ob) 224 | return self._get_ob() 225 | 226 | def step(self, action): 227 | ob, reward, done, info = self.env.step(action) 228 | self.frames.append(ob) 229 | return self._get_ob(), reward, done, info 230 | 231 | def _get_ob(self): 232 | assert len(self.frames) == self.k 233 | return LazyFrames(list(self.frames)) 234 | 235 | class ScaledFloatFrame(gym.ObservationWrapper): 236 | def __init__(self, env): 237 | gym.ObservationWrapper.__init__(self, env) 238 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32) 239 | 240 | def observation(self, observation): 241 | # careful! This undoes the memory optimization, use 242 | # with smaller replay buffers only. 243 | return np.array(observation).astype(np.float32) / 255.0 244 | 245 | class LazyFrames(object): 246 | def __init__(self, frames): 247 | """This object ensures that common frames between the observations are only stored once. 248 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 249 | buffers. 250 | 251 | This object should only be converted to numpy array before being passed to the model. 252 | 253 | You'd not believe how complex the previous solution was.""" 254 | self._frames = frames 255 | self._out = None 256 | 257 | def _force(self): 258 | if self._out is None: 259 | self._out = np.concatenate(self._frames, axis=-1) 260 | self._frames = None 261 | return self._out 262 | 263 | def __array__(self, dtype=None): 264 | out = self._force() 265 | if dtype is not None: 266 | out = out.astype(dtype) 267 | return out 268 | 269 | def __len__(self): 270 | return len(self._force()) 271 | 272 | def __getitem__(self, i): 273 | return self._force()[i] 274 | 275 | def count(self): 276 | frames = self._force() 277 | return frames.shape[frames.ndim - 1] 278 | 279 | def frame(self, i): 280 | return self._force()[..., i] 281 | 282 | def make_atari(env_id, max_episode_steps=None): 283 | env = gym.make(env_id) 284 | assert 'NoFrameskip' in env.spec.id 285 | env = NoopResetEnv(env, noop_max=30) 286 | env = MaxAndSkipEnv(env, skip=4) 287 | if max_episode_steps is not None: 288 | env = TimeLimit(env, max_episode_steps=max_episode_steps) 289 | return env 290 | 291 | def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False): 292 | """Configure environment for DeepMind-style Atari. 
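Typical usage (as in envs.make_env): wrap_deepmind(make_atari(env_id), episode_life=False, clip_rewards=False).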
293 | """ 294 | if episode_life: 295 | env = EpisodicLifeEnv(env) 296 | if 'FIRE' in env.unwrapped.get_action_meanings(): 297 | env = FireResetEnv(env) 298 | env = WarpFrame(env) 299 | if scale: 300 | env = ScaledFloatFrame(env) 301 | if clip_rewards: 302 | env = ClipRewardEnv(env) 303 | if frame_stack: 304 | env = FrameStack(env, 4) 305 | return env 306 | 307 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/envs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | from gym import spaces 4 | from atari_wrappers import make_atari, wrap_deepmind 5 | from vec_env import VecEnv 6 | from multiprocessing import Process, Pipe 7 | 8 | 9 | # cf https://github.com/openai/baselines 10 | 11 | def make_env(env_name, rank, seed): 12 | env = make_atari(env_name) 13 | env.seed(seed + rank) 14 | env = wrap_deepmind(env, episode_life=False, clip_rewards=False) 15 | return env 16 | 17 | 18 | def worker(remote, parent_remote, env_fn_wrapper): 19 | parent_remote.close() 20 | env = env_fn_wrapper.x() 21 | while True: 22 | cmd, data = remote.recv() 23 | if cmd == 'step': 24 | ob, reward, done, info = env.step(data) 25 | if done: 26 | ob = env.reset() 27 | remote.send((ob, reward, done, info)) 28 | elif cmd == 'reset': 29 | ob = env.reset() 30 | remote.send(ob) 31 | elif cmd == 'reset_task': 32 | ob = env.reset_task() 33 | remote.send(ob) 34 | elif cmd == 'close': 35 | remote.close() 36 | break 37 | elif cmd == 'get_spaces': 38 | remote.send((env.action_space, env.observation_space)) 39 | elif cmd == 'render': 40 | env.render() 41 | else: 42 | raise NotImplementedError 43 | 44 | 45 | class CloudpickleWrapper(object): 46 | """ 47 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 48 | """ 49 | def __init__(self, x): 50 | self.x = x 51 | def __getstate__(self): 52 | import cloudpickle 53 | return cloudpickle.dumps(self.x) 54 | def __setstate__(self, ob): 55 | import pickle 56 | self.x = pickle.loads(ob) 57 | 58 | 59 | class RenderSubprocVecEnv(VecEnv): 60 | def __init__(self, env_fns, render_interval): 61 | """ Minor addition to SubprocVecEnv, automatically renders environments 62 | 63 | envs: list of gym environments to run in subprocesses 64 | """ 65 | self.closed = False 66 | nenvs = len(env_fns) 67 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) 68 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 69 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 70 | for p in self.ps: 71 | p.daemon = True # if the main process crashes, we should not cause things to hang 72 | p.start() 73 | for remote in self.work_remotes: 74 | remote.close() 75 | 76 | self.remotes[0].send(('get_spaces', None)) 77 | self.action_space, self.observation_space = self.remotes[0].recv() 78 | 79 | self.render_interval = render_interval 80 | self.render_timer = 0 81 | 82 | def step(self, actions): 83 | for remote, action in zip(self.remotes, actions): 84 | remote.send(('step', action)) 85 | results = [remote.recv() for remote in self.remotes] 86 | obs, rews, dones, infos = zip(*results) 87 | 88 | self.render_timer += 1 89 | if self.render_timer == self.render_interval: 90 | for remote in self.remotes: 91 | remote.send(('render', None)) 92 | self.render_timer = 0 93 | 94 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 95 | 96 | def reset(self): 97 | for remote in 
self.remotes: 98 | remote.send(('reset', None)) 99 | return np.stack([remote.recv() for remote in self.remotes]) 100 | 101 | def reset_task(self): 102 | for remote in self.remotes: 103 | remote.send(('reset_task', None)) 104 | return np.stack([remote.recv() for remote in self.remotes]) 105 | 106 | def close(self): 107 | if self.closed: 108 | return 109 | 110 | for remote in self.remotes: 111 | remote.send(('close', None)) 112 | for p in self.ps: 113 | p.join() 114 | self.closed = True 115 | 116 | @property 117 | def num_envs(self): 118 | return len(self.remotes) 119 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | from atari_wrappers import FrameStack 5 | from subproc_vec_env import SubprocVecEnv 6 | from vec_frame_stack import VecFrameStack 7 | import multiprocessing 8 | 9 | from envs import make_env, RenderSubprocVecEnv 10 | from models import AtariCNN 11 | from ppo import PPO 12 | from utils import set_seed, cuda_if 13 | 14 | if __name__ == "__main__": 15 | multiprocessing.set_start_method("forkserver") 16 | 17 | parser = argparse.ArgumentParser(description='PPO', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser.add_argument('env_id', type=str, help='Gym environment id') 19 | parser.add_argument('--arch', type=str, default='cnn', help='policy architecture, {lstm, cnn}') 20 | parser.add_argument('--num-workers', type=int, default=8, help='number of parallel actors') 21 | parser.add_argument('--opt-epochs', type=int, default=3, help='optimization epochs between environment interaction') 22 | parser.add_argument('--total-steps', type=int, default=int(10e6), help='total number of environment steps to take') 23 | parser.add_argument('--worker-steps', type=int, default=128, help='steps per worker between optimization rounds') 24 | parser.add_argument('--sequence-steps', type=int, default=32, help='steps per sequence (for backprop through time)') 25 | parser.add_argument('--minibatch-steps', type=int, default=256, help='steps per optimization minibatch') 26 | parser.add_argument('--lr', type=float, default=2.5e-4, help='initial learning rate') 27 | parser.add_argument('--lr-func', type=str, default='linear', help='learning rate schedule function, {linear, constant}') 28 | parser.add_argument('--clip', type=float, default=.1, help='initial probability ratio clipping range') 29 | parser.add_argument('--clip-func', type=str, default='linear', help='clip range schedule function, {linear, constant}') 30 | parser.add_argument('--gamma', type=float, default=.99, help='discount factor') 31 | parser.add_argument('--lambd', type=float, default=.95, help='GAE lambda parameter') 32 | parser.add_argument('--value-coef', type=float, default=1., help='value loss coeffecient') 33 | parser.add_argument('--entropy-coef', type=float, default=.01, help='entropy loss coeffecient') 34 | parser.add_argument('--max-grad-norm', type=float, default=.5, help='grad norm to clip at') 35 | parser.add_argument('--no-cuda', action='store_true', help='disable CUDA acceleration') 36 | parser.add_argument('--render', action='store_true', help='render training environments') 37 | parser.add_argument('--render-interval', type=int, default=4, help='steps between environment renders') 38 | parser.add_argument('--plot-reward', action='store_true', help='plot episode reward') 39 | 
parser.add_argument('--plot-points', type=int, default=20, help='number of plot points (groups with mean, std)') 40 | parser.add_argument('--plot-path', type=str, default='ep_reward.png', help='path to save reward plot to') 41 | parser.add_argument('--seed', type=int, default=0, help='random seed') 42 | args = parser.parse_args() 43 | 44 | set_seed(args.seed) 45 | 46 | cuda = torch.cuda.is_available() and not args.no_cuda 47 | 48 | env_fns = [] 49 | for rank in range(args.num_workers): 50 | env_fns.append(lambda: make_env(args.env_id, rank, args.seed + rank)) 51 | if args.render: 52 | venv = RenderSubprocVecEnv(env_fns, args.render_interval) 53 | else: 54 | venv = SubprocVecEnv(env_fns) 55 | venv = VecFrameStack(venv, 4) 56 | 57 | test_env = make_env(args.env_id, 0, args.seed) 58 | test_env = FrameStack(test_env, 4) 59 | 60 | policy = {'cnn': AtariCNN}[args.arch](venv.action_space.n) 61 | policy = cuda_if(policy, cuda) 62 | 63 | optimizer = optim.Adam(policy.parameters()) 64 | 65 | if args.lr_func == 'linear': 66 | lr_func = lambda a: args.lr * (1. - a) 67 | elif args.lr_func == 'constant': 68 | lr_func = lambda a: args.lr 69 | 70 | if args.clip_func == 'linear': 71 | clip_func = lambda a: args.clip * (1. - a) 72 | elif args.clip_func == 'constant': 73 | clip_func = lambda a: args.clip 74 | 75 | algorithm = PPO(policy, venv, test_env, optimizer, 76 | lr_func=lr_func, clip_func=clip_func, gamma=args.gamma, lambd=args.lambd, 77 | worker_steps=args.worker_steps, sequence_steps=args.sequence_steps, 78 | minibatch_steps=args.minibatch_steps, 79 | value_coef=args.value_coef, entropy_coef=args.entropy_coef, 80 | max_grad_norm=args.max_grad_norm, 81 | cuda=cuda, 82 | plot_reward=args.plot_reward, plot_points=args.plot_points, plot_path=args.plot_path, env_name = args.env_id) 83 | algorithm.run(args.total_steps) 84 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | def ortho_weights(shape, scale=1.): 8 | """ PyTorch port of ortho_init from baselines.a2c.utils """ 9 | shape = tuple(shape) 10 | 11 | if len(shape) == 2: 12 | flat_shape = shape[1], shape[0] 13 | elif len(shape) == 4: 14 | flat_shape = (np.prod(shape[1:]), shape[0]) 15 | else: 16 | raise NotImplementedError 17 | 18 | a = np.random.normal(0., 1., flat_shape) 19 | u, _, v = np.linalg.svd(a, full_matrices=False) 20 | q = u if u.shape == flat_shape else v 21 | q = q.transpose().copy().reshape(shape) 22 | 23 | if len(shape) == 2: 24 | return torch.from_numpy((scale * q).astype(np.float32)) 25 | if len(shape) == 4: 26 | return torch.from_numpy((scale * q[:, :shape[1], :shape[2]]).astype(np.float32)) 27 | 28 | 29 | def atari_initializer(module): 30 | """ Parameter initializer for Atari models 31 | 32 | Initializes Linear, Conv2d, and LSTM weights. 33 | """ 34 | classname = module.__class__.__name__ 35 | 36 | if classname == 'Linear': 37 | module.weight.data = ortho_weights(module.weight.data.size(), scale=np.sqrt(2.)) 38 | module.bias.data.zero_() 39 | 40 | elif classname == 'Conv2d': 41 | module.weight.data = ortho_weights(module.weight.data.size(), scale=np.sqrt(2.)) 42 | module.bias.data.zero_() 43 | 44 | elif classname == 'LSTM': 45 | for name, param in module.named_parameters(): 46 | if 'weight_ih' in name: 47 | param.data = ortho_weights(param.data.size(), scale=1.) 
48 | if 'weight_hh' in name: 49 | param.data = ortho_weights(param.data.size(), scale=1.) 50 | if 'bias' in name: 51 | param.data.zero_() 52 | 53 | 54 | class AtariCNN(nn.Module): 55 | def __init__(self, num_actions): 56 | """ Basic convolutional actor-critic network for Atari 2600 games 57 | 58 | Equivalent to the network in the original DQN paper. 59 | 60 | Args: 61 | num_actions (int): the number of available discrete actions 62 | """ 63 | super().__init__() 64 | 65 | self.conv = nn.Sequential(nn.Conv2d(4, 32, 8, stride=4), 66 | nn.ReLU(inplace=True), 67 | nn.Conv2d(32, 64, 4, stride=2), 68 | nn.ReLU(inplace=True), 69 | nn.Conv2d(64, 64, 3, stride=1), 70 | nn.ReLU(inplace=True)) 71 | 72 | self.fc = nn.Sequential(nn.Linear(64 * 7 * 7, 512), 73 | nn.ReLU(inplace=True)) 74 | 75 | self.pi = nn.Linear(512, num_actions) 76 | self.v = nn.Linear(512, 1) 77 | 78 | self.num_actions = num_actions 79 | 80 | # parameter initialization 81 | self.apply(atari_initializer) 82 | self.pi.weight.data = ortho_weights(self.pi.weight.size(), scale=.01) 83 | self.v.weight.data = ortho_weights(self.v.weight.size()) 84 | 85 | def forward(self, conv_in): 86 | """ Module forward pass 87 | 88 | Args: 89 | conv_in (Variable): convolutional input, shaped [N x 4 x 84 x 84] 90 | 91 | Returns: 92 | pi (Variable): action probability logits, shaped [N x self.num_actions] 93 | v (Variable): value predictions, shaped [N x 1] 94 | """ 95 | N = conv_in.size()[0] 96 | 97 | conv_out = self.conv(conv_in).view(N, 64 * 7 * 7) 98 | 99 | fc_out = self.fc(conv_out) 100 | 101 | pi_out = self.pi(fc_out) 102 | v_out = self.v(fc_out) 103 | 104 | return pi_out, v_out 105 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/ppo.py: -------------------------------------------------------------------------------- 1 | import random 2 | import copy 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as Fnn 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | from torch.autograd import Variable 11 | 12 | from utils import gae, cuda_if, mean_std_groups, set_lr 13 | 14 | 15 | class PPO: 16 | def __init__(self, policy, venv, test_env, optimizer, 17 | lr_func=None, clip_func=None, gamma=.99, lambd=.95, 18 | worker_steps=128, sequence_steps=32, minibatch_steps=256, 19 | opt_epochs=3, value_coef=1., entropy_coef=.01, max_grad_norm=.5, 20 | cuda=False, plot_reward=False, plot_points=20, plot_path='ep_reward.png', 21 | test_repeat_max=100, env_name = ""): 22 | """ Proximal Policy Optimization algorithm class 23 | 24 | Evaluates a policy over a vectorized environment and 25 | optimizes over policy, value, entropy objectives. 26 | 27 | Assumes discrete action space. 
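Each optimization round minimizes (see PPOObjective below)

            loss = -E[min(r * A, clamp(r, 1 - clip, 1 + clip) * A)]
                   + value_coef * E[0.5 * (V - R)^2]
                   + entropy_coef * E[sum_a pi(a|s) * log pi(a|s)],

        where r = pi(a|s) / pi_old(a|s) and A is the (normalized) GAE advantage.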
28 | 29 | Args: 30 | policy (nn.Module): the policy to optimize 31 | venv (vec_env): the vectorized environment to use 32 | test_env (Env): the environment to use for policy testing 33 | optimizer (optim.Optimizer): the optimizer to use 34 | clip (float): probability ratio clipping range 35 | gamma (float): discount factor 36 | lambd (float): GAE lambda parameter 37 | worker_steps (int): steps per worker between optimization rounds 38 | sequence_steps (int): steps per sequence (for backprop through time) 39 | batch_steps (int): steps per sequence (for backprop through time) 40 | """ 41 | self.policy = policy 42 | self.policy_old = copy.deepcopy(policy) 43 | self.venv = venv 44 | self.test_env = test_env 45 | self.optimizer = optimizer 46 | 47 | self.env_name = env_name 48 | 49 | self.lr_func = lr_func 50 | self.clip_func = clip_func 51 | 52 | self.num_workers = venv.num_envs 53 | self.worker_steps = worker_steps 54 | self.sequence_steps = sequence_steps 55 | self.minibatch_steps = minibatch_steps 56 | 57 | self.opt_epochs = opt_epochs 58 | self.gamma = gamma 59 | self.lambd = lambd 60 | self.value_coef = value_coef 61 | self.entropy_coef = entropy_coef 62 | self.max_grad_norm = max_grad_norm 63 | self.cuda = cuda 64 | 65 | self.plot_reward = plot_reward 66 | self.plot_points = plot_points 67 | self.plot_path = plot_path 68 | self.ep_reward = np.zeros(self.num_workers) 69 | self.reward_histr = [] 70 | self.steps_histr = [] 71 | 72 | self.objective = PPOObjective() 73 | 74 | self.last_ob = self.venv.reset() 75 | 76 | self.taken_steps = 0 77 | 78 | self.test_repeat_max = test_repeat_max 79 | 80 | def run(self, total_steps): 81 | """ Runs PPO 82 | 83 | Args: 84 | total_steps (int): total number of environment steps to run for 85 | """ 86 | N = self.num_workers 87 | T = self.worker_steps 88 | E = self.opt_epochs 89 | A = self.venv.action_space.n 90 | 91 | while self.taken_steps < total_steps: 92 | progress = self.taken_steps / total_steps 93 | 94 | obs, rewards, masks, actions, steps = self.interact() 95 | ob_shape = obs.size()[2:] 96 | 97 | ep_reward = self.test() 98 | self.reward_histr.append(ep_reward) 99 | self.steps_histr.append(self.taken_steps) 100 | 101 | # statistic logic 102 | group_size = len(self.steps_histr) // self.plot_points 103 | if self.plot_reward and len(self.steps_histr) % (self.plot_points * 10) == 0 and group_size >= 10: 104 | x_means, _, y_means, y_stds = \ 105 | mean_std_groups(np.array(self.steps_histr), np.array(self.reward_histr), group_size) 106 | fig = plt.figure() 107 | fig.set_size_inches(8, 6) 108 | plt.ticklabel_format(axis='x', style='sci', scilimits=(-2, 6)) 109 | plt.errorbar(x_means, y_means, yerr=y_stds, ecolor='xkcd:blue', fmt='xkcd:black', capsize=5, elinewidth=1.5, mew=1.5, linewidth=1.5) 110 | plt.title('Training progress') 111 | plt.xlabel('Total steps') 112 | plt.ylabel('Episode reward') 113 | plt.savefig(self.plot_path, dpi=200) 114 | plt.clf() 115 | plt.close() 116 | plot_timer = 0 117 | 118 | # TEMP upgrade to support recurrence 119 | 120 | # compute advantages, returns with GAE 121 | obs_ = obs.view(((T + 1) * N,) + ob_shape) 122 | obs_ = Variable(obs_) 123 | _, values = self.policy(obs_) 124 | values = values.view(T + 1, N, 1) 125 | advantages, returns = gae(rewards, masks, values, self.gamma, self.lambd) 126 | 127 | self.policy_old.load_state_dict(self.policy.state_dict()) 128 | for e in range(E): 129 | self.policy.zero_grad() 130 | 131 | MB = steps // self.minibatch_steps 132 | 133 | b_obs = Variable(obs[:T].view((steps,) + ob_shape)) 134 | 
b_rewards = Variable(rewards.view(steps, 1)) 135 | b_masks = Variable(masks.view(steps, 1)) 136 | b_actions = Variable(actions.view(steps, 1)) 137 | b_advantages = Variable(advantages.view(steps, 1)) 138 | b_returns = Variable(returns.view(steps, 1)) 139 | 140 | b_inds = np.arange(steps) 141 | np.random.shuffle(b_inds) 142 | 143 | for start in range(0, steps, self.minibatch_steps): 144 | mb_inds = b_inds[start:start + self.minibatch_steps] 145 | mb_inds = cuda_if(torch.from_numpy(mb_inds).long(), self.cuda) 146 | mb_obs, mb_rewards, mb_masks, mb_actions, mb_advantages, mb_returns = \ 147 | [arr[mb_inds] for arr in [b_obs, b_rewards, b_masks, b_actions, b_advantages, b_returns]] 148 | 149 | mb_pis, mb_vs = self.policy(mb_obs) 150 | mb_pi_olds, mb_v_olds = self.policy_old(mb_obs) 151 | mb_pi_olds, mb_v_olds = mb_pi_olds.detach(), mb_v_olds.detach() 152 | 153 | losses = self.objective(self.clip_func(progress), 154 | mb_pis, mb_vs, mb_pi_olds, mb_v_olds, 155 | mb_actions, mb_advantages, mb_returns) 156 | policy_loss, value_loss, entropy_loss = losses 157 | loss = policy_loss + value_loss * self.value_coef + entropy_loss * self.entropy_coef 158 | 159 | set_lr(self.optimizer, self.lr_func(progress)) 160 | self.optimizer.zero_grad() 161 | loss.backward() 162 | torch.nn.utils.clip_grad_norm(self.policy.parameters(), self.max_grad_norm) 163 | self.optimizer.step() 164 | 165 | self.taken_steps += steps 166 | print(self.taken_steps) 167 | 168 | torch.save({'policy': self.policy.state_dict()}, "./save/PPO_" + self.env_name + ".pt") 169 | 170 | def interact(self): 171 | """ Interacts with the environment 172 | 173 | Returns: 174 | obs (ArgumentDefaultsHelpFormatternsor): observations shaped [T + 1 x N x ...] 175 | rewards (FloatTensor): rewards shaped [T x N x 1] 176 | masks (FloatTensor): continuation masks shaped [T x N x 1] 177 | zero at done timesteps, one otherwise 178 | actions (LongTensor): discrete actions shaped [T x N x 1] 179 | steps (int): total number of steps taken 180 | """ 181 | N = self.num_workers 182 | T = self.worker_steps 183 | 184 | # TEMP needs to be generalized, does conv-specific transpose for PyTorch 185 | obs = torch.zeros(T + 1, N, 4, 84, 84) 186 | obs = cuda_if(obs, self.cuda) 187 | rewards = torch.zeros(T, N, 1) 188 | rewards = cuda_if(rewards, self.cuda) 189 | masks = torch.zeros(T, N, 1) 190 | masks = cuda_if(masks, self.cuda) 191 | actions = torch.zeros(T, N, 1).long() 192 | actions = cuda_if(actions, self.cuda) 193 | 194 | for t in range(T): 195 | # interaction logic 196 | ob = torch.from_numpy(self.last_ob.transpose((0, 3, 1, 2))).float() 197 | ob = Variable(ob / 255.) 198 | ob = cuda_if(ob, self.cuda) 199 | obs[t] = ob.data 200 | 201 | pi, v = self.policy(ob) 202 | u = cuda_if(torch.rand(pi.size()), self.cuda) 203 | _, action = torch.max(pi.data - (-u.log()).log(), 1) 204 | action = action.unsqueeze(1) 205 | actions[t] = action 206 | 207 | self.last_ob, reward, done, _ = self.venv.step(action.cpu().numpy()) 208 | reward = torch.from_numpy(reward).unsqueeze(1) 209 | rewards[t] = torch.clamp(reward, min=-1., max=1.) 210 | masks[t] = mask = torch.from_numpy((1. - done)).unsqueeze(1) 211 | 212 | ob = torch.from_numpy(self.last_ob.transpose((0, 3, 1, 2))).float() 213 | ob = Variable(ob / 255.) 
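# The extra observation stored as obs[T] just below lets gae() bootstrap the
# value of the final state: delta_t = r_t + gamma * V(s_{t+1}) * mask_t - V(s_t).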
214 | ob = cuda_if(ob, self.cuda) 215 | obs[T] = ob.data 216 | 217 | steps = N * T 218 | 219 | return obs, rewards, masks, actions, steps 220 | 221 | def test(self): 222 | ob = self.test_env.reset() 223 | done = False 224 | ep_reward = 0 225 | last_action = np.array([-1]) 226 | action_repeat = 0 227 | 228 | while not done: 229 | ob = np.array(ob) 230 | ob = torch.from_numpy(ob.transpose((2, 0, 1))).float().unsqueeze(0) 231 | ob = Variable(ob / 255., volatile=True) 232 | ob = cuda_if(ob, self.cuda) 233 | 234 | pi, v = self.policy(ob) 235 | _, action = torch.max(pi, dim=1) 236 | 237 | # abort after {self.test_repeat_max} discrete action repeats 238 | if action.data[0] == last_action.data[0]: 239 | action_repeat += 1 240 | if action_repeat == self.test_repeat_max: 241 | return ep_reward 242 | else: 243 | action_repeat = 0 244 | last_action = action 245 | 246 | ob, reward, done, _ = self.test_env.step(action.data.cpu().numpy()) 247 | 248 | ep_reward += reward 249 | 250 | return ep_reward 251 | 252 | 253 | class PPOObjective(nn.Module): 254 | def forward(self, clip, pi, v, pi_old, v_old, action, advantage, returns): 255 | """ Computes PPO objectives 256 | 257 | Assumes discrete action space. 258 | 259 | Args: 260 | clip (float): probability ratio clipping range 261 | pi (Variable): discrete action logits, shaped [N x num_actions] 262 | v (Variable): value predictions, shaped [N x 1] 263 | pi_old (Variable): old discrete action logits, shaped [N x num_actions] 264 | v_old (Variable): old value predictions, shaped [N x 1] 265 | action (Variable): discrete actions, shaped [N x 1] 266 | advantage (Variable): action advantages, shaped [N x 1] 267 | returns (Variable): discounted returns, shaped [N x 1] 268 | 269 | Returns: 270 | policy_loss (Variable): policy surrogate loss, shaped [1] 271 | value_loss (Variable): value loss, shaped [1] 272 | entropy_loss (Variable): entropy loss, shaped [1] 273 | """ 274 | prob = Fnn.softmax(pi) 275 | log_prob = Fnn.log_softmax(pi) 276 | action_prob = prob.gather(1, action) 277 | 278 | prob_old = Fnn.softmax(pi_old) 279 | action_prob_old = prob_old.gather(1, action) 280 | 281 | ratio = action_prob / (action_prob_old + 1e-10) 282 | 283 | advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-5) 284 | 285 | surr1 = ratio * advantage 286 | surr2 = torch.clamp(ratio, min=1. - clip, max=1. 
+ clip) * advantage 287 | 288 | policy_loss = -torch.min(surr1, surr2).mean() 289 | value_loss = (.5 * (v - returns) ** 2.).mean() 290 | entropy_loss = (prob * log_prob).sum(1).mean() 291 | 292 | return policy_loss, value_loss, entropy_loss 293 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/save/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Utils/Atari_PPO_training/save/.gitkeep -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | import numpy as np 4 | from vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars 5 | 6 | 7 | def worker(remote, parent_remote, env_fn_wrapper): 8 | parent_remote.close() 9 | env = env_fn_wrapper.x() 10 | try: 11 | while True: 12 | cmd, data = remote.recv() 13 | if cmd == 'step': 14 | ob, reward, done, info = env.step(data) 15 | if done: 16 | ob = env.reset() 17 | remote.send((ob, reward, done, info)) 18 | elif cmd == 'reset': 19 | ob = env.reset() 20 | remote.send(ob) 21 | elif cmd == 'render': 22 | remote.send(env.render(mode='rgb_array')) 23 | elif cmd == 'close': 24 | remote.close() 25 | break 26 | elif cmd == 'get_spaces_spec': 27 | remote.send((env.observation_space, env.action_space, env.spec)) 28 | else: 29 | raise NotImplementedError 30 | except KeyboardInterrupt: 31 | print('SubprocVecEnv worker: got KeyboardInterrupt') 32 | finally: 33 | env.close() 34 | 35 | 36 | class SubprocVecEnv(VecEnv): 37 | """ 38 | VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes. 39 | Recommended to use when num_envs > 1 and step() can be a bottleneck. 40 | """ 41 | def __init__(self, env_fns, spaces=None, context='spawn'): 42 | """ 43 | Arguments: 44 | 45 | env_fns: iterable of callables - functions that create environments to run in subprocesses. 
Need to be cloud-pickleable 46 | """ 47 | self.waiting = False 48 | self.closed = False 49 | nenvs = len(env_fns) 50 | ctx = mp.get_context(context) 51 | self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)]) 52 | self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 53 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 54 | for p in self.ps: 55 | p.daemon = True # if the main process crashes, we should not cause things to hang 56 | with clear_mpi_env_vars(): 57 | p.start() 58 | for remote in self.work_remotes: 59 | remote.close() 60 | 61 | self.remotes[0].send(('get_spaces_spec', None)) 62 | observation_space, action_space, self.spec = self.remotes[0].recv() 63 | self.viewer = None 64 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 65 | 66 | def step_async(self, actions): 67 | self._assert_not_closed() 68 | for remote, action in zip(self.remotes, actions): 69 | remote.send(('step', action)) 70 | self.waiting = True 71 | 72 | def step_wait(self): 73 | self._assert_not_closed() 74 | results = [remote.recv() for remote in self.remotes] 75 | self.waiting = False 76 | obs, rews, dones, infos = zip(*results) 77 | return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos 78 | 79 | def reset(self): 80 | self._assert_not_closed() 81 | for remote in self.remotes: 82 | remote.send(('reset', None)) 83 | return _flatten_obs([remote.recv() for remote in self.remotes]) 84 | 85 | def close_extras(self): 86 | self.closed = True 87 | if self.waiting: 88 | for remote in self.remotes: 89 | remote.recv() 90 | for remote in self.remotes: 91 | remote.send(('close', None)) 92 | for p in self.ps: 93 | p.join() 94 | 95 | def get_images(self): 96 | self._assert_not_closed() 97 | for pipe in self.remotes: 98 | pipe.send(('render', None)) 99 | imgs = [pipe.recv() for pipe in self.remotes] 100 | return imgs 101 | 102 | def _assert_not_closed(self): 103 | assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()" 104 | 105 | def __del__(self): 106 | if not self.closed: 107 | self.close() 108 | 109 | def _flatten_obs(obs): 110 | assert isinstance(obs, (list, tuple)) 111 | assert len(obs) > 0 112 | 113 | if isinstance(obs[0], dict): 114 | keys = obs[0].keys() 115 | return {k: np.stack([o[k] for o in obs]) for k in keys} 116 | else: 117 | return np.stack(obs) 118 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | from atari_wrappers import FrameStack 5 | from subproc_vec_env import SubprocVecEnv 6 | from vec_frame_stack import VecFrameStack 7 | import multiprocessing 8 | 9 | from envs import make_env, RenderSubprocVecEnv 10 | from models import AtariCNN 11 | from ppo import PPO 12 | from utils import set_seed, cuda_if 13 | 14 | if __name__ == "__main__": 15 | multiprocessing.set_start_method("forkserver") 16 | 17 | parser = argparse.ArgumentParser(description='PPO', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser.add_argument('env_id', type=str, help='Gym environment id') 19 | parser.add_argument('--arch', type=str, default='cnn', help='policy architecture, {lstm, cnn}') 20 | parser.add_argument('--num-workers', type=int, default=8, help='number of parallel actors') 21 | parser.add_argument('--opt-epochs', type=int, default=3, 
help='optimization epochs between environment interaction') 22 | parser.add_argument('--total-steps', type=int, default=int(10e6), help='total number of environment steps to take') 23 | parser.add_argument('--worker-steps', type=int, default=128, help='steps per worker between optimization rounds') 24 | parser.add_argument('--sequence-steps', type=int, default=32, help='steps per sequence (for backprop through time)') 25 | parser.add_argument('--minibatch-steps', type=int, default=256, help='steps per optimization minibatch') 26 | parser.add_argument('--lr', type=float, default=2.5e-4, help='initial learning rate') 27 | parser.add_argument('--lr-func', type=str, default='linear', help='learning rate schedule function, {linear, constant}') 28 | parser.add_argument('--clip', type=float, default=.1, help='initial probability ratio clipping range') 29 | parser.add_argument('--clip-func', type=str, default='linear', help='clip range schedule function, {linear, constant}') 30 | parser.add_argument('--gamma', type=float, default=.99, help='discount factor') 31 | parser.add_argument('--lambd', type=float, default=.95, help='GAE lambda parameter') 32 | parser.add_argument('--value-coef', type=float, default=1., help='value loss coeffecient') 33 | parser.add_argument('--entropy-coef', type=float, default=.01, help='entropy loss coeffecient') 34 | parser.add_argument('--max-grad-norm', type=float, default=.5, help='grad norm to clip at') 35 | parser.add_argument('--no-cuda', action='store_true', help='disable CUDA acceleration') 36 | parser.add_argument('--render', action='store_true', help='render training environments') 37 | parser.add_argument('--render-interval', type=int, default=4, help='steps between environment renders') 38 | parser.add_argument('--plot-reward', action='store_true', help='plot episode reward') 39 | parser.add_argument('--plot-points', type=int, default=20, help='number of plot points (groups with mean, std)') 40 | parser.add_argument('--plot-path', type=str, default='ep_reward.png', help='path to save reward plot to') 41 | parser.add_argument('--seed', type=int, default=0, help='random seed') 42 | args = parser.parse_args() 43 | 44 | set_seed(args.seed) 45 | 46 | cuda = torch.cuda.is_available() and not args.no_cuda 47 | 48 | test_env = make_env(args.env_id, 0, args.seed) 49 | test_env = FrameStack(test_env, 4) 50 | 51 | policy = {'cnn': AtariCNN}[args.arch](venv.action_space.n) 52 | checkpoint = torch.load("./save/PPO_" + self.env_name + ".pt") 53 | policy.load_check_point(checkpoint["policy"]) 54 | policy = cuda_if(policy, cuda) 55 | 56 | ob = self.test_env.reset() 57 | done = False 58 | ep_reward = 0 59 | last_action = np.array([-1]) 60 | action_repeat = 0 61 | 62 | while not done: 63 | ob = np.array(ob) 64 | ob = torch.from_numpy(ob.transpose((2, 0, 1))).float().unsqueeze(0) 65 | ob = Variable(ob / 255., volatile=True) 66 | ob = cuda_if(ob, self.cuda) 67 | 68 | pi, v = policy(ob) 69 | _, action = torch.max(pi, dim=1) 70 | 71 | # abort after {self.test_repeat_max} discrete action repeats 72 | if action.data[0] == last_action.data[0]: 73 | action_repeat += 1 74 | if action_repeat == self.test_repeat_max: 75 | return ep_reward 76 | else: 77 | action_repeat = 0 78 | last_action = action 79 | 80 | ob, reward, done, _ = self.test_env.step(action.data.cpu().numpy()) 81 | 82 | ep_reward += reward 83 | 84 | print(ep_reward) 85 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/utils.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | 5 | def set_seed(seed): 6 | random.seed(seed) 7 | np.random.seed(seed) 8 | torch.manual_seed(seed) 9 | 10 | def cuda_if(torch_object, cuda): 11 | return torch_object.cuda() if cuda else torch_object 12 | 13 | def gae(rewards, masks, values, gamma, lambd): 14 | """ Generalized Advantage Estimation 15 | 16 | Args: 17 | rewards (FloatTensor): rewards shaped [T x N x 1] 18 | masks (FloatTensor): continuation masks shaped [T x N x 1] 19 | zero at done timesteps, one otherwise 20 | values (Variable): value predictions shaped [(T + 1) x N x 1] 21 | gamma (float): discount factor 22 | lambd (float): GAE lambda parameter 23 | 24 | Returns: 25 | advantages (FloatTensor): advantages shaped [T x N x 1] 26 | returns (FloatTensor): returns shaped [T x N x 1] 27 | """ 28 | T, N, _ = rewards.size() 29 | 30 | cuda = rewards.is_cuda 31 | 32 | advantages = torch.zeros(T, N, 1) 33 | advantages = cuda_if(advantages, cuda) 34 | advantage_t = torch.zeros(N, 1) 35 | advantage_t = cuda_if(advantage_t, cuda) 36 | 37 | for t in reversed(range(T)): 38 | delta = rewards[t] + values[t + 1].data * gamma * masks[t] - values[t].data 39 | advantage_t = delta + advantage_t * gamma * lambd * masks[t] 40 | advantages[t] = advantage_t 41 | 42 | returns = values[:T].data + advantages 43 | 44 | return advantages, returns 45 | 46 | 47 | def mean_std_groups(x, y, group_size): 48 | num_groups = int(len(x) / group_size) 49 | 50 | x, x_tail = x[:group_size * num_groups], x[group_size * num_groups:] 51 | x = x.reshape((num_groups, group_size)) 52 | 53 | y, y_tail = y[:group_size * num_groups], y[group_size * num_groups:] 54 | y = y.reshape((num_groups, group_size)) 55 | 56 | x_means = x.mean(axis=1) 57 | x_stds = x.std(axis=1) 58 | 59 | if len(x_tail) > 0: 60 | x_means = np.concatenate([x_means, x_tail.mean(axis=0, keepdims=True)]) 61 | x_stds = np.concatenate([x_stds, x_tail.std(axis=0, keepdims=True)]) 62 | 63 | y_means = y.mean(axis=1) 64 | y_stds = y.std(axis=1) 65 | 66 | if len(y_tail) > 0: 67 | y_means = np.concatenate([y_means, y_tail.mean(axis=0, keepdims=True)]) 68 | y_stds = np.concatenate([y_stds, y_tail.std(axis=0, keepdims=True)]) 69 | 70 | return x_means, x_stds, y_means, y_stds 71 | 72 | def set_lr(optimizer, lr): 73 | for param_group in optimizer.param_groups: 74 | param_group['lr'] = lr 75 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/vec_env.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | from abc import ABC, abstractmethod 4 | 5 | def tile_images(img_nhwc): 6 | """ 7 | Tile N images into one big PxQ image 8 | (P,Q) are chosen to be as close as possible, and if N 9 | is square, then P=Q. 
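For example, N = 5 images are laid out on a 3 x 2 grid (ceil(sqrt(5)) = 3 rows,
ceil(5 / 3) = 2 columns), with one padding tile left blank.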
10 | 11 | input: img_nhwc, list or array of images, ndim=4 once turned into array 12 | n = batch index, h = height, w = width, c = channel 13 | returns: 14 | bigim_HWc, ndarray with ndim=3 15 | """ 16 | img_nhwc = np.asarray(img_nhwc) 17 | N, h, w, c = img_nhwc.shape 18 | H = int(np.ceil(np.sqrt(N))) 19 | W = int(np.ceil(float(N)/H)) 20 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 21 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 22 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 23 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 24 | return img_Hh_Ww_c 25 | 26 | class AlreadySteppingError(Exception): 27 | """ 28 | Raised when an asynchronous step is running while 29 | step_async() is called again. 30 | """ 31 | 32 | def __init__(self): 33 | msg = 'already running an async step' 34 | Exception.__init__(self, msg) 35 | 36 | 37 | class NotSteppingError(Exception): 38 | """ 39 | Raised when an asynchronous step is not running but 40 | step_wait() is called. 41 | """ 42 | 43 | def __init__(self): 44 | msg = 'not running an async step' 45 | Exception.__init__(self, msg) 46 | 47 | 48 | class VecEnv(ABC): 49 | """ 50 | An abstract asynchronous, vectorized environment. 51 | Used to batch data from multiple copies of an environment, so that 52 | each observation becomes an batch of observations, and expected action is a batch of actions to 53 | be applied per-environment. 54 | """ 55 | closed = False 56 | viewer = None 57 | 58 | metadata = { 59 | 'render.modes': ['human', 'rgb_array'] 60 | } 61 | 62 | def __init__(self, num_envs, observation_space, action_space): 63 | self.num_envs = num_envs 64 | self.observation_space = observation_space 65 | self.action_space = action_space 66 | 67 | @abstractmethod 68 | def reset(self): 69 | """ 70 | Reset all the environments and return an array of 71 | observations, or a dict of observation arrays. 72 | 73 | If step_async is still doing work, that work will 74 | be cancelled and step_wait() should not be called 75 | until step_async() is invoked again. 76 | """ 77 | pass 78 | 79 | @abstractmethod 80 | def step_async(self, actions): 81 | """ 82 | Tell all the environments to start taking a step 83 | with the given actions. 84 | Call step_wait() to get the results of the step. 85 | 86 | You should not call this if a step_async run is 87 | already pending. 88 | """ 89 | pass 90 | 91 | @abstractmethod 92 | def step_wait(self): 93 | """ 94 | Wait for the step taken with step_async(). 95 | 96 | Returns (obs, rews, dones, infos): 97 | - obs: an array of observations, or a dict of 98 | arrays of observations. 99 | - rews: an array of rewards 100 | - dones: an array of "episode done" booleans 101 | - infos: a sequence of info objects 102 | """ 103 | pass 104 | 105 | def close_extras(self): 106 | """ 107 | Clean up the extra resources, beyond what's in this base class. 108 | Only runs when not self.closed. 109 | """ 110 | pass 111 | 112 | def close(self): 113 | if self.closed: 114 | return 115 | if self.viewer is not None: 116 | self.viewer.close() 117 | self.close_extras() 118 | self.closed = True 119 | 120 | def step(self, actions): 121 | """ 122 | Step the environments synchronously. 123 | 124 | This is available for backwards compatibility. 
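Equivalent to step_async(actions) followed by step_wait(), e.g.
obs, rews, dones, infos = venv.step(actions).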
125 | """ 126 | self.step_async(actions) 127 | return self.step_wait() 128 | 129 | def render(self, mode='human'): 130 | imgs = self.get_images() 131 | bigimg = tile_images(imgs) 132 | if mode == 'human': 133 | self.get_viewer().imshow(bigimg) 134 | return self.get_viewer().isopen 135 | elif mode == 'rgb_array': 136 | return bigimg 137 | else: 138 | raise NotImplementedError 139 | 140 | def get_images(self): 141 | """ 142 | Return RGB images from each environment 143 | """ 144 | raise NotImplementedError 145 | 146 | @property 147 | def unwrapped(self): 148 | if isinstance(self, VecEnvWrapper): 149 | return self.venv.unwrapped 150 | else: 151 | return self 152 | 153 | def get_viewer(self): 154 | if self.viewer is None: 155 | from gym.envs.classic_control import rendering 156 | self.viewer = rendering.SimpleImageViewer() 157 | return self.viewer 158 | 159 | class VecEnvWrapper(VecEnv): 160 | """ 161 | An environment wrapper that applies to an entire batch 162 | of environments at once. 163 | """ 164 | 165 | def __init__(self, venv, observation_space=None, action_space=None): 166 | self.venv = venv 167 | super().__init__(num_envs=venv.num_envs, 168 | observation_space=observation_space or venv.observation_space, 169 | action_space=action_space or venv.action_space) 170 | 171 | def step_async(self, actions): 172 | self.venv.step_async(actions) 173 | 174 | @abstractmethod 175 | def reset(self): 176 | pass 177 | 178 | @abstractmethod 179 | def step_wait(self): 180 | pass 181 | 182 | def close(self): 183 | return self.venv.close() 184 | 185 | def render(self, mode='human'): 186 | return self.venv.render(mode=mode) 187 | 188 | def get_images(self): 189 | return self.venv.get_images() 190 | 191 | def __getattr__(self, name): 192 | if name.startswith('_'): 193 | raise AttributeError("attempted to get missing private attribute '{}'".format(name)) 194 | return getattr(self.venv, name) 195 | 196 | class VecEnvObservationWrapper(VecEnvWrapper): 197 | @abstractmethod 198 | def process(self, obs): 199 | pass 200 | 201 | def reset(self): 202 | obs = self.venv.reset() 203 | return self.process(obs) 204 | 205 | def step_wait(self): 206 | obs, rews, dones, infos = self.venv.step_wait() 207 | return self.process(obs), rews, dones, infos 208 | 209 | class CloudpickleWrapper(object): 210 | """ 211 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 212 | """ 213 | 214 | def __init__(self, x): 215 | self.x = x 216 | 217 | def __getstate__(self): 218 | import cloudpickle 219 | return cloudpickle.dumps(self.x) 220 | 221 | def __setstate__(self, ob): 222 | import pickle 223 | self.x = pickle.loads(ob) 224 | 225 | 226 | @contextlib.contextmanager 227 | def clear_mpi_env_vars(): 228 | """ 229 | from mpi4py import MPI will call MPI_Init by default. If the child process has MPI environment variables, MPI will think that the child process is an MPI process just like the parent and do bad things such as hang. 230 | This context manager is a hacky way to clear those environment variables temporarily such as when we are starting multiprocessing 231 | Processes. 
232 | """ 233 | removed_environment = {} 234 | for k, v in list(os.environ.items()): 235 | for prefix in ['OMPI_', 'PMI_']: 236 | if k.startswith(prefix): 237 | removed_environment[k] = v 238 | del os.environ[k] 239 | try: 240 | yield 241 | finally: 242 | os.environ.update(removed_environment) 243 | -------------------------------------------------------------------------------- /Utils/Atari_PPO_training/vec_frame_stack.py: -------------------------------------------------------------------------------- 1 | from vec_env import VecEnvWrapper 2 | import numpy as np 3 | from gym import spaces 4 | 5 | 6 | class VecFrameStack(VecEnvWrapper): 7 | def __init__(self, venv, nstack): 8 | self.venv = venv 9 | self.nstack = nstack 10 | wos = venv.observation_space # wrapped ob space 11 | low = np.repeat(wos.low, self.nstack, axis=-1) 12 | high = np.repeat(wos.high, self.nstack, axis=-1) 13 | self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) 14 | observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) 15 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 16 | 17 | def step_wait(self): 18 | obs, rews, news, infos = self.venv.step_wait() 19 | self.stackedobs = np.roll(self.stackedobs, shift=-1, axis=-1) 20 | for (i, new) in enumerate(news): 21 | if new: 22 | self.stackedobs[i] = 0 23 | self.stackedobs[..., -obs.shape[-1]:] = obs 24 | return self.stackedobs, rews, news, infos 25 | 26 | def reset(self): 27 | obs = self.venv.reset() 28 | self.stackedobs[...] = 0 29 | self.stackedobs[..., -obs.shape[-1]:] = obs 30 | return self.stackedobs 31 | -------------------------------------------------------------------------------- /Utils/MovingAvegCalculator.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class MovingAvegCalculator(): 5 | def __init__(self, window_length): 6 | self.num_added = 0 7 | self.window_length = window_length 8 | self.window = [0.0 for _ in range(window_length)] 9 | 10 | self.aveg = 0.0 11 | self.var = 0.0 12 | 13 | self.last_std = 0.0 14 | 15 | def add_number(self, num): 16 | idx = self.num_added % self.window_length 17 | old_num = self.window[idx] 18 | self.window[idx] = num 19 | self.num_added += 1 20 | 21 | old_aveg = self.aveg 22 | if self.num_added <= self.window_length: 23 | delta = num - old_aveg 24 | self.aveg += delta / self.num_added 25 | self.var += delta * (num - self.aveg) 26 | else: 27 | delta = num - old_num 28 | self.aveg += delta / self.window_length 29 | self.var += delta * ((num - self.aveg) + (old_num - old_aveg)) 30 | 31 | if self.num_added <= self.window_length: 32 | if self.num_added == 1: 33 | variance = 0.1 34 | else: 35 | variance = self.var / (self.num_added - 1) 36 | else: 37 | variance = self.var / self.window_length 38 | 39 | try: 40 | std = math.sqrt(variance) 41 | if math.isnan(std): 42 | std = 0.1 43 | except: 44 | std = 0.1 45 | 46 | self.last_std = std 47 | 48 | return self.aveg, std 49 | 50 | def get_standard_deviation(self): 51 | return self.last_std 52 | -------------------------------------------------------------------------------- /Utils/NetworkDistillation/Distillation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | import sys 5 | sys.path.append("../Env") 6 | 7 | from Env.EnvWrapper import EnvWrapper 8 | 9 | from Policy.PPO.PPOPolicy import PPOAtariCNN, PPOSmallAtariCNN 10 | 11 | from .ReplayBuffer import 
ReplayBuffer 12 | 13 | 14 | class Distillation(): 15 | def __init__(self, wrapped_env, teacher_network, student_network, 16 | temperature = 2.5, buffer_size = 1e5, batch_size = 32, 17 | device = torch.device("cpu")): 18 | self.wrapped_env = wrapped_env 19 | self.teacher_network = teacher_network 20 | self.student_network = student_network 21 | self.temperature = temperature 22 | self.buffer_size = buffer_size 23 | self.batch_size = batch_size 24 | self.device = device 25 | 26 | # Replay buffer 27 | self.replay_buffer = ReplayBuffer(max_size = buffer_size, device = self.device) 28 | 29 | def train_step(self): 30 | state_batch, policy_batch, value_batch = self.replay_buffer.sample(self.batch_size) 31 | 32 | loss = self.student_network.train_step(state_batch, policy_batch, value_batch, temperature = self.temperature) 33 | 34 | return loss 35 | 36 | def gather_samples(self, max_step_count = 10000): 37 | state = self.wrapped_env.reset() 38 | 39 | step_count = 0 40 | 41 | done = False 42 | while not done: 43 | if np.random.random() < 0.9: 44 | action = self.categorical(self.student_network.get_action(state)) 45 | else: 46 | action = np.random.randint(0, self.wrapped_env.action_n) 47 | target_policy = self.teacher_network.get_action(state, logit = True) 48 | target_value = self.teacher_network.get_value(state) 49 | 50 | self.replay_buffer.add((np.array(state), target_policy, target_value)) 51 | 52 | state, _, done = self.wrapped_env.step(action) 53 | 54 | step_count += 1 55 | 56 | if step_count > max_step_count: 57 | return 58 | 59 | @staticmethod 60 | def categorical(probs): 61 | val = random.random() 62 | chosen_idx = 0 63 | 64 | for prob in probs: 65 | val -= prob 66 | 67 | if val < 0.0: 68 | break 69 | 70 | chosen_idx += 1 71 | 72 | if chosen_idx >= len(probs): 73 | chosen_idx = len(probs) - 1 74 | 75 | return chosen_idx 76 | 77 | 78 | def train_distillation(env_name, device): 79 | device = torch.device("cuda:0" if device == "cuda" else device) 80 | 81 | wrapped_env = EnvWrapper(env_name = env_name, max_episode_length = 100000) 82 | 83 | teacher_network = PPOAtariCNN(wrapped_env.action_n, device, 84 | checkpoint_dir = "./Policy/PPO/PolicyFiles/PPO_" + env_name + ".pt") 85 | student_network = PPOSmallAtariCNN(wrapped_env.action_n, device, 86 | checkpoint_dir = "") 87 | 88 | distillation = Distillation(wrapped_env, teacher_network, student_network, device = device) 89 | 90 | for _ in range(1000): 91 | for _ in range(10): 92 | distillation.gather_samples() 93 | 94 | for _ in range(1000): 95 | loss = distillation.train_step() 96 | print(loss) 97 | 98 | student_network.save("./Policy/PPO/PolicyFiles/SmallPPO_" + env_name + ".pt") 99 | -------------------------------------------------------------------------------- /Utils/NetworkDistillation/ReplayBuffer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, max_size = 1e6, device = None): 7 | self.storage = [] 8 | self.max_size = max_size 9 | self.ptr = 0 10 | 11 | self.device = device 12 | 13 | def add(self, data): 14 | if len(self.storage) == self.max_size: 15 | self.storage[int(self.ptr)] = data 16 | self.ptr = (self.ptr + 1) % self.max_size 17 | else: 18 | self.storage.append(data) 19 | 20 | def sample(self, batch_size): 21 | ind = np.random.randint(0, len(self.storage), size = batch_size) 22 | state_batch, policy_batch, value_batch = [], [], [] 23 | 24 | for i in ind: 25 | state, policy, value = self.storage[i] 
26 |             state_batch.append(np.array(state, copy = False))
27 |             policy_batch.append(np.array(policy, copy = False))
28 |             value_batch.append(value)
29 | 
30 |         state_batch = torch.from_numpy(np.array(state_batch, dtype = np.float32))
31 |         policy_batch = torch.from_numpy(np.array(policy_batch, dtype = np.float32))
32 |         value_batch = torch.from_numpy(np.array(value_batch, dtype = np.float32))
33 | 
34 |         if self.device is not None:
35 |             state_batch = state_batch.to(self.device)
36 |             policy_batch = policy_batch.to(self.device)
37 |             value_batch = value_batch.to(self.device).unsqueeze(-1)
38 | 
39 |         return state_batch, policy_batch, value_batch
40 | 
--------------------------------------------------------------------------------
/Utils/NetworkDistillation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Utils/NetworkDistillation/__init__.py
--------------------------------------------------------------------------------
/Utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liuanji/WU-UCT/a185cf11feef548ca3f7f59bec4f9d8a92c8a2a1/Utils/__init__.py
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | import scipy.io as sio
4 | import os
5 | 
6 | from Tree.WU_UCT import WU_UCT
7 | from Tree.UCT import UCT
8 | 
9 | from Utils.NetworkDistillation.Distillation import train_distillation
10 | 
11 | 
12 | def main():
13 |     parser = argparse.ArgumentParser(description = "P-MCTS")
14 |     parser.add_argument("--model", type = str, default = "WU-UCT",
15 |                         help = "Base MCTS model: WU-UCT/UCT (default: WU-UCT)")
16 | 
17 |     parser.add_argument("--env-name", type = str, default = "AlienNoFrameskip-v0",
18 |                         help = "Environment name (default: AlienNoFrameskip-v0)")
19 | 
20 |     parser.add_argument("--MCTS-max-steps", type = int, default = 128,
21 |                         help = "Max simulation steps of MCTS (default: 128)")
22 |     parser.add_argument("--MCTS-max-depth", type = int, default = 100,
23 |                         help = "Max depth of MCTS simulation (default: 100)")
24 |     parser.add_argument("--MCTS-max-width", type = int, default = 20,
25 |                         help = "Max width of MCTS simulation (default: 20)")
26 | 
27 |     parser.add_argument("--gamma", type = float, default = 0.99,
28 |                         help = "Discount factor (default: 0.99)")
29 | 
30 |     parser.add_argument("--expansion-worker-num", type = int, default = 1,
31 |                         help = "Number of expansion workers (default: 1)")
32 |     parser.add_argument("--simulation-worker-num", type = int, default = 16,
33 |                         help = "Number of simulation workers (default: 16)")
34 | 
35 |     parser.add_argument("--seed", type = int, default = 123,
36 |                         help = "Random seed (default: 123)")
37 | 
38 |     parser.add_argument("--max-episode-length", type = int, default = 100000,
39 |                         help = "Maximum episode length (default: 100000)")
40 | 
41 |     parser.add_argument("--policy", type = str, default = "Random",
42 |                         help = "Prior prob/simulation policy used in MCTS: Random/PPO/DistillPPO (default: Random)")
43 | 
44 |     parser.add_argument("--device", type = str, default = "cpu",
45 |                         help = "PyTorch device; if 'cuda' is given, use CUDA device parallelization (default: cpu)")
46 | 
47 |     parser.add_argument("--record-video", default = False, action = "store_true",
48 |                         help = "Record video if supported (default: False)")
49 | 
50 |     parser.add_argument("--mode", type = str, default = "MCTS",
51 |                         help = "Mode: MCTS/Distill (default: MCTS)")
52 | 
53 |     args = parser.parse_args()
54 | 
55 |     env_params = {
56 |         "env_name": args.env_name,
57 |         "max_episode_length": args.max_episode_length
58 |     }
59 | 
60 |     if args.mode == "MCTS":
61 |         # Model initialization
62 |         if args.model == "WU-UCT":
63 |             MCTStree = WU_UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth,
64 |                               args.MCTS_max_width, args.gamma, args.expansion_worker_num,
65 |                               args.simulation_worker_num, policy = args.policy,
66 |                               seed = args.seed, device = args.device,
67 |                               record_video = args.record_video)
68 |         elif args.model == "UCT":
69 |             MCTStree = UCT(env_params, args.MCTS_max_steps, args.MCTS_max_depth,
70 |                            args.MCTS_max_width, args.gamma, policy = args.policy, seed = args.seed)
71 |         else:
72 |             raise NotImplementedError()
73 | 
74 |         accu_reward, rewards, times = MCTStree.simulate_trajectory()
75 |         print(accu_reward)
76 | 
77 |         with open("Results/" + args.model + ".txt", "a+") as f:
78 |             f.write("Model: {}, env: {}, result: {}, MCTS max steps: {}, policy: {}, worker num: {}\n".format(
79 |                 args.model, args.env_name, accu_reward, args.MCTS_max_steps, args.policy, args.simulation_worker_num
80 |             ))
81 | 
82 |         if not os.path.exists("OutLogs/"):
83 |             try:
84 |                 os.mkdir("OutLogs/")
85 |             except OSError:
86 |                 pass
87 | 
88 |         sio.savemat("OutLogs/" + args.model + "_" + args.env_name + "_" + str(args.seed) + "_" +
89 |                     str(args.simulation_worker_num) + ".mat",
90 |                     {"rewards": rewards, "times": times})
91 | 
92 |         MCTStree.close()
93 | 
94 |     elif args.mode == "Distill":
95 |         train_distillation(args.env_name, args.device)
96 | 
97 | 
98 | if __name__ == "__main__":
99 |     # The "forkserver" start method is mandatory on Unix/Darwin
100 |     multiprocessing.set_start_method("forkserver")
101 | 
102 |     main()
103 | 
--------------------------------------------------------------------------------
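
Usage sketch: the MCTS branch of main() above can also be driven programmatically. The snippet below is a minimal, illustrative sketch that mirrors the WU_UCT(...) call in main.py; the environment name and all numeric values are example choices (not recommendations from the repository), and the constructor signature is assumed to match the call shown above, so check Tree/WU_UCT.py if it differs.

# Minimal programmatic sketch of the "MCTS" mode in main.py (illustrative values only).
import multiprocessing

from Tree.WU_UCT import WU_UCT


def run_wu_uct_once():
    env_params = {
        "env_name": "PongNoFrameskip-v0",   # example Atari environment name
        "max_episode_length": 100000
    }

    # Argument order copied from the WU_UCT(...) call in main():
    # max steps, max depth, max width, gamma, expansion workers, simulation workers.
    mcts_tree = WU_UCT(env_params, 128, 100, 20, 0.99, 1, 16,
                       policy = "Random", seed = 123,
                       device = "cpu", record_video = False)

    accu_reward, rewards, times = mcts_tree.simulate_trajectory()
    mcts_tree.close()
    return accu_reward, rewards, times


if __name__ == "__main__":
    # main.py sets the "forkserver" start method before building the tree; do the same here,
    # since the worker processes are spawned when WU_UCT is constructed.
    multiprocessing.set_start_method("forkserver")
    print(run_wu_uct_once()[0])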