├── .gitignore ├── LICENSE ├── README.md ├── doomFiles ├── README.md ├── __init__.py ├── action_space.py ├── doom_env.py ├── doom_my_way_home_sparse.py ├── doom_my_way_home_verySparse.py └── wads │ ├── my_way_home_dense.wad │ ├── my_way_home_sparse.wad │ └── my_way_home_verySparse.wad ├── images ├── mario1.gif ├── mario2.gif └── vizdoom.gif ├── models └── download_models.sh └── src ├── .gitignore ├── a3c.py ├── constants.py ├── demo.py ├── env_wrapper.py ├── envs.py ├── inference.py ├── mario.py ├── model.py ├── requirements.txt ├── train.py ├── utils.py └── worker.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | **pyc 3 | **npy 4 | tmp/ 5 | curiosity 6 | src/vizdoom.ini 7 | models/*.tar.gz 8 | models/output 9 | models/doom 10 | models/mario 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Deepak Pathak 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | -------------------------------------------------------------------------------- 27 | Original openai License: 28 | -------------------------------------------------------------------------------- 29 | MIT License 30 | 31 | Copyright (c) 2016 openai 32 | 33 | Permission is hereby granted, free of charge, to any person obtaining a copy 34 | of this software and associated documentation files (the "Software"), to deal 35 | in the Software without restriction, including without limitation the rights 36 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 37 | copies of the Software, and to permit persons to whom the Software is 38 | furnished to do so, subject to the following conditions: 39 | 40 | The above copyright notice and this permission notice shall be included in all 41 | copies or substantial portions of the Software. 42 | 43 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 44 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 45 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 46 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 47 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 48 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 49 | SOFTWARE. 50 | -------------------------------------------------------------------------------- 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Curiosity-driven Exploration by Self-supervised Prediction ## 2 | #### In ICML 2017 [[Project Website]](http://pathak22.github.io/noreward-rl/) [[Demo Video]](http://pathak22.github.io/noreward-rl/index.html#demoVideo) 3 | 4 | [Deepak Pathak](https://people.eecs.berkeley.edu/~pathak/), [Pulkit Agrawal](https://people.eecs.berkeley.edu/~pulkitag/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/)
5 | University of California, Berkeley
6 | 7 | 8 | 9 | This is a TensorFlow-based implementation for our [ICML 2017 paper on curiosity-driven exploration for reinforcement learning](http://pathak22.github.io/noreward-rl/). The idea is to train the agent with an intrinsic curiosity-based motivation (ICM) signal when external rewards from the environment are sparse. Surprisingly, you can use ICM even when no rewards are available from the environment at all; in that case, the agent learns to explore purely out of curiosity: 'RL without rewards'. If you find this work useful in your research, please cite: 10 | 11 | @inproceedings{pathakICMl17curiosity, 12 | Author = {Pathak, Deepak and Agrawal, Pulkit and 13 | Efros, Alexei A. and Darrell, Trevor}, 14 | Title = {Curiosity-driven Exploration by Self-supervised Prediction}, 15 | Booktitle = {International Conference on Machine Learning ({ICML})}, 16 | Year = {2017} 17 | } 18 | 19 | ### 1) Installation and Usage 20 | 1. This code is based on [TensorFlow](https://www.tensorflow.org/). To install, run these commands: 21 | ```Shell 22 | # you might not need many of these, e.g., fceux is only for mario 23 | sudo apt-get install -y python-numpy python-dev cmake zlib1g-dev libjpeg-dev xvfb \ 24 | libav-tools xorg-dev python-opengl libboost-all-dev libsdl2-dev swig python3-dev \ 25 | python3-venv make golang libjpeg-turbo8-dev gcc wget unzip git fceux virtualenv \ 26 | tmux 27 | 28 | # install the code 29 | git clone -b master --single-branch https://github.com/pathak22/noreward-rl.git 30 | cd noreward-rl/ 31 | virtualenv curiosity 32 | source $PWD/curiosity/bin/activate 33 | pip install numpy 34 | pip install -r src/requirements.txt 35 | python curiosity/src/go-vncdriver/build.py 36 | 37 | # download models 38 | bash models/download_models.sh 39 | 40 | # setup customized doom environment 41 | cd doomFiles/ 42 | # then follow commands in doomFiles/README.md 43 | ``` 44 | 45 | 2. Running the demo 46 | ```Shell 47 | cd noreward-rl/src/ 48 | python demo.py --ckpt ../models/doom/doom_ICM 49 | python demo.py --env-id SuperMarioBros-1-1-v0 --ckpt ../models/mario/mario_ICM 50 | ``` 51 | 52 | 3. Training code 53 | ```Shell 54 | cd noreward-rl/src/ 55 | # For Doom: doom or doomSparse or doomVerySparse 56 | python train.py --default --env-id doom 57 | 58 | # For Mario, change src/constants.py as follows: 59 | # PREDICTION_BETA = 0.2 60 | # ENTROPY_BETA = 0.0005 61 | python train.py --default --env-id mario --noReward 62 | 63 | xvfb-run -s "-screen 0 1400x900x24" bash # only for remote desktops 64 | # useful xvfb link: http://stackoverflow.com/a/30336424 65 | python inference.py --default --env-id doom --record 66 | ``` 67 | 68 | ### 2) Other helpful pointers 69 | - [Paper](https://pathak22.github.io/noreward-rl/resources/icml17.pdf) 70 | - [Project Website](http://pathak22.github.io/noreward-rl/) 71 | - [Demo Video](http://pathak22.github.io/noreward-rl/index.html#demoVideo) 72 | - [Reddit Discussion](https://redd.it/6bc8ul) 73 | - [Media Articles (New Scientist, MIT Tech Review and others)](http://pathak22.github.io/noreward-rl/index.html#media) 74 | 75 | ### 3) Acknowledgement 76 | The vanilla A3C code is based on the open-source implementation of [universe-starter-agent](https://github.com/openai/universe-starter-agent). 77 | -------------------------------------------------------------------------------- /doomFiles/README.md: -------------------------------------------------------------------------------- 1 | ### VizDoom Scenarios 2 | This directory provides the files needed to replicate the Doom scenarios from the ICML'17 paper.
Run following commands: 3 | 4 | ```Shell 5 | cp wads/*.wad ../curiosity/lib/python2.7/site-packages/doom_py/scenarios/ 6 | cp __init__.py ../curiosity/lib/python2.7/site-packages/ppaquette_gym_doom/ 7 | cp doom*.py ../curiosity/lib/python2.7/site-packages/ppaquette_gym_doom/ 8 | cp action_space.py ../curiosity/lib/python2.7/site-packages/ppaquette_gym_doom/wrappers/ 9 | ``` 10 | -------------------------------------------------------------------------------- /doomFiles/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Place this file in: 3 | /home/pathak/projects/unsup-rl/unsuprl/local/lib/python2.7/site-packages/ppaquette_gym_doom/__init__.py 4 | ''' 5 | 6 | from gym.envs.registration import register 7 | from gym.scoreboard.registration import add_task, add_group 8 | from .package_info import USERNAME 9 | from .doom_env import DoomEnv, MetaDoomEnv 10 | from .doom_basic import DoomBasicEnv 11 | from .doom_corridor import DoomCorridorEnv 12 | from .doom_defend_center import DoomDefendCenterEnv 13 | from .doom_defend_line import DoomDefendLineEnv 14 | from .doom_health_gathering import DoomHealthGatheringEnv 15 | from .doom_my_way_home import DoomMyWayHomeEnv 16 | from .doom_predict_position import DoomPredictPositionEnv 17 | from .doom_take_cover import DoomTakeCoverEnv 18 | from .doom_deathmatch import DoomDeathmatchEnv 19 | from .doom_my_way_home_sparse import DoomMyWayHomeFixedEnv 20 | from .doom_my_way_home_verySparse import DoomMyWayHomeFixed15Env 21 | 22 | # Env registration 23 | # ========================== 24 | 25 | register( 26 | id='{}/meta-Doom-v0'.format(USERNAME), 27 | entry_point='{}_gym_doom:MetaDoomEnv'.format(USERNAME), 28 | timestep_limit=999999, 29 | reward_threshold=9000.0, 30 | kwargs={ 31 | 'average_over': 3, 32 | 'passing_grade': 600, 33 | 'min_tries_for_avg': 3 34 | }, 35 | ) 36 | 37 | register( 38 | id='{}/DoomBasic-v0'.format(USERNAME), 39 | entry_point='{}_gym_doom:DoomBasicEnv'.format(USERNAME), 40 | timestep_limit=10000, 41 | reward_threshold=10.0, 42 | ) 43 | 44 | register( 45 | id='{}/DoomCorridor-v0'.format(USERNAME), 46 | entry_point='{}_gym_doom:DoomCorridorEnv'.format(USERNAME), 47 | timestep_limit=10000, 48 | reward_threshold=1000.0, 49 | ) 50 | 51 | register( 52 | id='{}/DoomDefendCenter-v0'.format(USERNAME), 53 | entry_point='{}_gym_doom:DoomDefendCenterEnv'.format(USERNAME), 54 | timestep_limit=10000, 55 | reward_threshold=10.0, 56 | ) 57 | 58 | register( 59 | id='{}/DoomDefendLine-v0'.format(USERNAME), 60 | entry_point='{}_gym_doom:DoomDefendLineEnv'.format(USERNAME), 61 | timestep_limit=10000, 62 | reward_threshold=15.0, 63 | ) 64 | 65 | register( 66 | id='{}/DoomHealthGathering-v0'.format(USERNAME), 67 | entry_point='{}_gym_doom:DoomHealthGatheringEnv'.format(USERNAME), 68 | timestep_limit=10000, 69 | reward_threshold=1000.0, 70 | ) 71 | 72 | register( 73 | id='{}/DoomMyWayHome-v0'.format(USERNAME), 74 | entry_point='{}_gym_doom:DoomMyWayHomeEnv'.format(USERNAME), 75 | timestep_limit=10000, 76 | reward_threshold=0.5, 77 | ) 78 | 79 | register( 80 | id='{}/DoomPredictPosition-v0'.format(USERNAME), 81 | entry_point='{}_gym_doom:DoomPredictPositionEnv'.format(USERNAME), 82 | timestep_limit=10000, 83 | reward_threshold=0.5, 84 | ) 85 | 86 | register( 87 | id='{}/DoomTakeCover-v0'.format(USERNAME), 88 | entry_point='{}_gym_doom:DoomTakeCoverEnv'.format(USERNAME), 89 | timestep_limit=10000, 90 | reward_threshold=750.0, 91 | ) 92 | 93 | register( 94 | id='{}/DoomDeathmatch-v0'.format(USERNAME), 95 | 
entry_point='{}_gym_doom:DoomDeathmatchEnv'.format(USERNAME), 96 | timestep_limit=10000, 97 | reward_threshold=20.0, 98 | ) 99 | 100 | register( 101 | id='{}/DoomMyWayHomeFixed-v0'.format(USERNAME), 102 | entry_point='{}_gym_doom:DoomMyWayHomeFixedEnv'.format(USERNAME), 103 | timestep_limit=10000, 104 | reward_threshold=0.5, 105 | ) 106 | 107 | register( 108 | id='{}/DoomMyWayHomeFixed15-v0'.format(USERNAME), 109 | entry_point='{}_gym_doom:DoomMyWayHomeFixed15Env'.format(USERNAME), 110 | timestep_limit=10000, 111 | reward_threshold=0.5, 112 | ) 113 | 114 | # Scoreboard registration 115 | # ========================== 116 | add_group( 117 | id= 'doom', 118 | name= 'Doom', 119 | description= 'Doom environments based on VizDoom.' 120 | ) 121 | 122 | add_task( 123 | id='{}/meta-Doom-v0'.format(USERNAME), 124 | group='doom', 125 | summary='Mission #1 to #9 - Beat all 9 Doom missions.', 126 | description=""" 127 | This is a meta map that combines all 9 Doom levels. 128 | 129 | Levels: 130 | - #0 Doom Basic 131 | - #1 Doom Corridor 132 | - #2 Doom DefendCenter 133 | - #3 Doom DefendLine 134 | - #4 Doom HealthGathering 135 | - #5 Doom MyWayHome 136 | - #6 Doom PredictPosition 137 | - #7 Doom TakeCover 138 | - #8 Doom Deathmatch 139 | - #9 Doom MyWayHomeFixed (customized) 140 | - #10 Doom MyWayHomeFixed15 (customized) 141 | 142 | Goal: 9,000 points 143 | - Pass all levels 144 | 145 | Scoring: 146 | - Each level score has been standardized on a scale of 0 to 1,000 147 | - The passing score for a level is 990 (99th percentile) 148 | - A bonus of 450 (50 * 9 levels) is given if all levels are passed 149 | - The score for a level is the average of the last 3 tries 150 | """ 151 | ) 152 | 153 | add_task( 154 | id='{}/DoomBasic-v0'.format(USERNAME), 155 | group='doom', 156 | summary='Mission #1 - Kill a single monster using your pistol.', 157 | description=""" 158 | This map is rectangular with gray walls, ceiling and floor. 159 | You are spawned in the center of the longer wall, and a red 160 | circular monster is spawned randomly on the opposite wall. 161 | You need to kill the monster (one bullet is enough). 162 | 163 | Goal: 10 points 164 | - Kill the monster in 3 secs with 1 shot 165 | 166 | Rewards: 167 | - Plus 101 pts for killing the monster 168 | - Minus 5 pts for missing a shot 169 | - Minus 1 pts every 0.028 secs 170 | 171 | Ends when: 172 | - Monster is dead 173 | - Player is dead 174 | - Timeout (10 seconds - 350 frames) 175 | 176 | Allowed actions: 177 | - ATTACK 178 | - MOVE_RIGHT 179 | - MOVE_LEFT 180 | """ 181 | ) 182 | 183 | add_task( 184 | id='{}/DoomCorridor-v0'.format(USERNAME), 185 | group='doom', 186 | summary='Mission #2 - Run as fast as possible to grab a vest.', 187 | description=""" 188 | This map is designed to improve your navigation. There is a vest 189 | at the end of the corridor, with 6 enemies (3 groups of 2). Your goal 190 | is to get to the vest as soon as possible, without being killed. 
191 | 192 | Goal: 1,000 points 193 | - Reach the vest (or get very close to it) 194 | 195 | Rewards: 196 | - Plus distance for getting closer to the vest 197 | - Minus distance for getting further from the vest 198 | - Minus 100 pts for getting killed 199 | 200 | Ends when: 201 | - Player touches vest 202 | - Player is dead 203 | - Timeout (1 minutes - 2,100 frames) 204 | 205 | Allowed actions: 206 | - ATTACK 207 | - MOVE_RIGHT 208 | - MOVE_LEFT 209 | - MOVE_FORWARD 210 | - TURN_RIGHT 211 | - TURN_LEFT 212 | """ 213 | ) 214 | 215 | add_task( 216 | id='{}/DoomDefendCenter-v0'.format(USERNAME), 217 | group='doom', 218 | summary='Mission #3 - Kill enemies coming at your from all sides.', 219 | description=""" 220 | This map is designed to teach you how to kill and how to stay alive. 221 | You will also need to keep an eye on your ammunition level. You are only 222 | rewarded for kills, so figure out how to stay alive. 223 | 224 | The map is a circle with monsters. You are in the middle. Monsters will 225 | respawn with additional health when killed. Kill as many as you can 226 | before you run out of ammo. 227 | 228 | Goal: 10 points 229 | - Kill 11 monsters (you have 26 ammo) 230 | 231 | Rewards: 232 | - Plus 1 point for killing a monster 233 | - Minus 1 point for getting killed 234 | 235 | Ends when: 236 | - Player is dead 237 | - Timeout (60 seconds - 2100 frames) 238 | 239 | Allowed actions: 240 | - ATTACK 241 | - TURN_RIGHT 242 | - TURN_LEFT 243 | """ 244 | ) 245 | 246 | add_task( 247 | id='{}/DoomDefendLine-v0'.format(USERNAME), 248 | group='doom', 249 | summary='Mission #4 - Kill enemies on the other side of the room.', 250 | description=""" 251 | This map is designed to teach you how to kill and how to stay alive. 252 | Your ammo will automatically replenish. You are only rewarded for kills, 253 | so figure out how to stay alive. 254 | 255 | The map is a rectangle with monsters on the other side. Monsters will 256 | respawn with additional health when killed. Kill as many as you can 257 | before they kill you. This map is harder than the previous. 258 | 259 | Goal: 15 points 260 | - Kill 16 monsters 261 | 262 | Rewards: 263 | - Plus 1 point for killing a monster 264 | - Minus 1 point for getting killed 265 | 266 | Ends when: 267 | - Player is dead 268 | - Timeout (60 seconds - 2100 frames) 269 | 270 | Allowed actions: 271 | - ATTACK 272 | - TURN_RIGHT 273 | - TURN_LEFT 274 | """ 275 | ) 276 | 277 | add_task( 278 | id='{}/DoomHealthGathering-v0'.format(USERNAME), 279 | group='doom', 280 | summary='Mission #5 - Learn to grad medkits to survive as long as possible.', 281 | description=""" 282 | This map is a guide on how to survive by collecting health packs. 283 | It is a rectangle with green, acidic floor which hurts the player 284 | periodically. There are also medkits spread around the map, and 285 | additional kits will spawn at interval. 286 | 287 | Goal: 1000 points 288 | - Stay alive long enough for approx. 30 secs 289 | 290 | Rewards: 291 | - Plus 1 point every 0.028 secs 292 | - Minus 100 pts for dying 293 | 294 | Ends when: 295 | - Player is dead 296 | - Timeout (60 seconds - 2,100 frames) 297 | 298 | Allowed actions: 299 | - MOVE_FORWARD 300 | - TURN_RIGHT 301 | - TURN_LEFT 302 | """ 303 | ) 304 | 305 | add_task( 306 | id='{}/DoomMyWayHome-v0'.format(USERNAME), 307 | group='doom', 308 | summary='Mission #6 - Find the vest in one the 4 rooms.', 309 | description=""" 310 | This map is designed to improve navigational skills. 
It is a series of 311 | interconnected rooms and 1 corridor with a dead end. Each room 312 | has a separate color. There is a green vest in one of the room. 313 | The vest is always in the same room. Player must find the vest. 314 | 315 | Goal: 0.50 point 316 | - Find the vest 317 | 318 | Rewards: 319 | - Plus 1 point for finding the vest 320 | - Minus 0.0001 point every 0.028 secs 321 | 322 | Ends when: 323 | - Vest is found 324 | - Timeout (1 minutes - 2,100 frames) 325 | 326 | Allowed actions: 327 | - MOVE_FORWARD 328 | - TURN_RIGHT 329 | - TURN_LEFT 330 | """ 331 | ) 332 | 333 | add_task( 334 | id='{}/DoomPredictPosition-v0'.format(USERNAME), 335 | group='doom', 336 | summary='Mission #7 - Learn how to kill an enemy with a rocket launcher.', 337 | description=""" 338 | This map is designed to train you on using a rocket launcher. 339 | It is a rectangular map with a monster on the opposite side. You need 340 | to use your rocket launcher to kill it. The rocket adds a delay between 341 | the moment it is fired and the moment it reaches the other side of the room. 342 | You need to predict the position of the monster to kill it. 343 | 344 | Goal: 0.5 point 345 | - Kill the monster 346 | 347 | Rewards: 348 | - Plus 1 point for killing the monster 349 | - Minus 0.0001 point every 0.028 secs 350 | 351 | Ends when: 352 | - Monster is dead 353 | - Out of missile (you only have one) 354 | - Timeout (20 seconds - 700 frames) 355 | 356 | Hint: Wait 1 sec for the missile launcher to load. 357 | 358 | Allowed actions: 359 | - ATTACK 360 | - TURN_RIGHT 361 | - TURN_LEFT 362 | """ 363 | ) 364 | 365 | add_task( 366 | id='{}/DoomTakeCover-v0'.format(USERNAME), 367 | group='doom', 368 | summary='Mission #8 - Survive as long as possible with enemies shooting at you.', 369 | description=""" 370 | This map is to train you on the damage of incoming missiles. 371 | It is a rectangular map with monsters firing missiles and fireballs 372 | at you. You need to survive as long as possible. 373 | 374 | Goal: 750 points 375 | - Survive for approx. 20 seconds 376 | 377 | Rewards: 378 | - Plus 1 point every 0.028 secs 379 | 380 | Ends when: 381 | - Player is dead (1 or 2 fireballs is enough) 382 | - Timeout (60 seconds - 2,100 frames) 383 | 384 | Allowed actions: 385 | - MOVE_RIGHT 386 | - MOVE_LEFT 387 | """ 388 | ) 389 | 390 | add_task( 391 | id='{}/DoomDeathmatch-v0'.format(USERNAME), 392 | group='doom', 393 | summary='Mission #9 - Kill as many enemies as possible without being killed.', 394 | description=""" 395 | Kill as many monsters as possible without being killed. 396 | 397 | Goal: 20 points 398 | - Kill 20 monsters 399 | 400 | Rewards: 401 | - Plus 1 point for killing a monster 402 | 403 | Ends when: 404 | - Player is dead 405 | - Timeout (3 minutes - 6,300 frames) 406 | 407 | Allowed actions: 408 | - ALL 409 | """ 410 | ) 411 | 412 | add_task( 413 | id='{}/DoomMyWayHomeFixed-v0'.format(USERNAME), 414 | group='doom', 415 | summary='Mission #10 - Find the vest in one the 4 rooms.', 416 | description=""" 417 | This map is designed to improve navigational skills. It is a series of 418 | interconnected rooms and 1 corridor with a dead end. Each room 419 | has a separate color. There is a green vest in one of the room. 420 | The vest is always in the same room. Player must find the vest. 421 | You always start from fixed room (room no. 10 -- farthest). 
422 | 423 | Goal: 0.50 point 424 | - Find the vest 425 | 426 | Rewards: 427 | - Plus 1 point for finding the vest 428 | - Minus 0.0001 point every 0.028 secs 429 | 430 | Ends when: 431 | - Vest is found 432 | - Timeout (1 minutes - 2,100 frames) 433 | 434 | Allowed actions: 435 | - MOVE_FORWARD 436 | - TURN_RIGHT 437 | - TURN_LEFT 438 | """ 439 | ) 440 | 441 | add_task( 442 | id='{}/DoomMyWayHomeFixed15-v0'.format(USERNAME), 443 | group='doom', 444 | summary='Mission #11 - Find the vest in one the 4 rooms.', 445 | description=""" 446 | This map is designed to improve navigational skills. It is a series of 447 | interconnected rooms and 1 corridor with a dead end. Each room 448 | has a separate color. There is a green vest in one of the room. 449 | The vest is always in the same room. Player must find the vest. 450 | You always start from fixed room (room no. 10 -- farthest). 451 | 452 | Goal: 0.50 point 453 | - Find the vest 454 | 455 | Rewards: 456 | - Plus 1 point for finding the vest 457 | - Minus 0.0001 point every 0.028 secs 458 | 459 | Ends when: 460 | - Vest is found 461 | - Timeout (1 minutes - 2,100 frames) 462 | 463 | Allowed actions: 464 | - MOVE_FORWARD 465 | - TURN_RIGHT 466 | - TURN_LEFT 467 | """ 468 | ) 469 | -------------------------------------------------------------------------------- /doomFiles/action_space.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Place this file in: 3 | /home/pathak/projects/unsup-rl/unsuprl/local/lib/python2.7/site-packages/ppaquette_gym_doom/wrappers/action_space.py 4 | ''' 5 | 6 | import gym 7 | 8 | # Constants 9 | NUM_ACTIONS = 43 10 | ALLOWED_ACTIONS = [ 11 | [0, 10, 11], # 0 - Basic 12 | [0, 10, 11, 13, 14, 15], # 1 - Corridor 13 | [0, 14, 15], # 2 - DefendCenter 14 | [0, 14, 15], # 3 - DefendLine 15 | [13, 14, 15], # 4 - HealthGathering 16 | [13, 14, 15], # 5 - MyWayHome 17 | [0, 14, 15], # 6 - PredictPosition 18 | [10, 11], # 7 - TakeCover 19 | [x for x in range(NUM_ACTIONS) if x != 33], # 8 - Deathmatch 20 | [13, 14, 15], # 9 - MyWayHomeFixed 21 | [13, 14, 15], # 10 - MyWayHomeFixed15 22 | ] 23 | 24 | __all__ = [ 'ToDiscrete', 'ToBox' ] 25 | 26 | def ToDiscrete(config): 27 | # Config can be 'minimal', 'constant-7', 'constant-17', 'full' 28 | 29 | class ToDiscreteWrapper(gym.Wrapper): 30 | """ 31 | Doom wrapper to convert MultiDiscrete action space to Discrete 32 | 33 | config: 34 | - minimal - Will only use the levels' allowed actions (+ NOOP) 35 | - constant-7 - Will use the 7 minimum actions (+NOOP) to complete all levels 36 | - constant-17 - Will use the 17 most common actions (+NOOP) to complete all levels 37 | - full - Will use all available actions (+ NOOP) 38 | 39 | list of commands: 40 | - minimal: 41 | Basic: NOOP, ATTACK, MOVE_RIGHT, MOVE_LEFT 42 | Corridor: NOOP, ATTACK, MOVE_RIGHT, MOVE_LEFT, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT 43 | DefendCenter NOOP, ATTACK, TURN_RIGHT, TURN_LEFT 44 | DefendLine: NOOP, ATTACK, TURN_RIGHT, TURN_LEFT 45 | HealthGathering: NOOP, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT 46 | MyWayHome: NOOP, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT 47 | PredictPosition: NOOP, ATTACK, TURN_RIGHT, TURN_LEFT 48 | TakeCover: NOOP, MOVE_RIGHT, MOVE_LEFT 49 | Deathmatch: NOOP, ALL COMMANDS (Deltas are limited to [0,1] range and will not work properly) 50 | 51 | - constant-7: NOOP, ATTACK, MOVE_RIGHT, MOVE_LEFT, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT, SELECT_NEXT_WEAPON 52 | 53 | - constant-17: NOOP, ATTACK, JUMP, CROUCH, TURN180, RELOAD, SPEED, STRAFE, MOVE_RIGHT, MOVE_LEFT, 
MOVE_BACKWARD 54 | MOVE_FORWARD, TURN_RIGHT, TURN_LEFT, LOOK_UP, LOOK_DOWN, SELECT_NEXT_WEAPON, SELECT_PREV_WEAPON 55 | """ 56 | def __init__(self, env): 57 | super(ToDiscreteWrapper, self).__init__(env) 58 | if config == 'minimal': 59 | allowed_actions = ALLOWED_ACTIONS[self.unwrapped.level] 60 | elif config == 'constant-7': 61 | allowed_actions = [0, 10, 11, 13, 14, 15, 31] 62 | elif config == 'constant-17': 63 | allowed_actions = [0, 2, 3, 4, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 31, 32] 64 | elif config == 'full': 65 | allowed_actions = None 66 | else: 67 | raise gym.error.Error('Invalid configuration. Valid options are "minimal", "constant-7", "constant-17", "full"') 68 | self.action_space = gym.spaces.multi_discrete.DiscreteToMultiDiscrete(self.action_space, allowed_actions) 69 | def _step(self, action): 70 | return self.env._step(self.action_space(action)) 71 | 72 | return ToDiscreteWrapper 73 | 74 | def ToBox(config): 75 | # Config can be 'minimal', 'constant-7', 'constant-17', 'full' 76 | 77 | class ToBoxWrapper(gym.Wrapper): 78 | """ 79 | Doom wrapper to convert MultiDiscrete action space to Box 80 | 81 | config: 82 | - minimal - Will only use the levels' allowed actions 83 | - constant-7 - Will use the 7 minimum actions to complete all levels 84 | - constant-17 - Will use the 17 most common actions to complete all levels 85 | - full - Will use all available actions 86 | 87 | list of commands: 88 | - minimal: 89 | Basic: ATTACK, MOVE_RIGHT, MOVE_LEFT 90 | Corridor: ATTACK, MOVE_RIGHT, MOVE_LEFT, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT 91 | DefendCenter ATTACK, TURN_RIGHT, TURN_LEFT 92 | DefendLine: ATTACK, TURN_RIGHT, TURN_LEFT 93 | HealthGathering: MOVE_FORWARD, TURN_RIGHT, TURN_LEFT 94 | MyWayHome: MOVE_FORWARD, TURN_RIGHT, TURN_LEFT 95 | PredictPosition: ATTACK, TURN_RIGHT, TURN_LEFT 96 | TakeCover: MOVE_RIGHT, MOVE_LEFT 97 | Deathmatch: ALL COMMANDS 98 | 99 | - constant-7: ATTACK, MOVE_RIGHT, MOVE_LEFT, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT, SELECT_NEXT_WEAPON 100 | 101 | - constant-17: ATTACK, JUMP, CROUCH, TURN180, RELOAD, SPEED, STRAFE, MOVE_RIGHT, MOVE_LEFT, MOVE_BACKWARD 102 | MOVE_FORWARD, TURN_RIGHT, TURN_LEFT, LOOK_UP, LOOK_DOWN, SELECT_NEXT_WEAPON, SELECT_PREV_WEAPON 103 | """ 104 | def __init__(self, env): 105 | super(ToBoxWrapper, self).__init__(env) 106 | if config == 'minimal': 107 | allowed_actions = ALLOWED_ACTIONS[self.unwrapped.level] 108 | elif config == 'constant-7': 109 | allowed_actions = [0, 10, 11, 13, 14, 15, 31] 110 | elif config == 'constant-17': 111 | allowed_actions = [0, 2, 3, 4, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 31, 32] 112 | elif config == 'full': 113 | allowed_actions = None 114 | else: 115 | raise gym.error.Error('Invalid configuration. 
Valid options are "minimal", "constant-7", "constant-17", "full"') 116 | self.action_space = gym.spaces.multi_discrete.BoxToMultiDiscrete(self.action_space, allowed_actions) 117 | def _step(self, action): 118 | return self.env._step(self.action_space(action)) 119 | 120 | return ToBoxWrapper 121 | -------------------------------------------------------------------------------- /doomFiles/doom_env.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Place this file in: 3 | /home/pathak/projects/unsup-rl/unsuprl/local/lib/python2.7/site-packages/ppaquette_gym_doom/doom_env.py 4 | ''' 5 | 6 | import logging 7 | import os 8 | from time import sleep 9 | import multiprocessing 10 | 11 | import numpy as np 12 | 13 | import gym 14 | from gym import spaces, error 15 | from gym.utils import seeding 16 | 17 | try: 18 | import doom_py 19 | from doom_py import DoomGame, Mode, Button, GameVariable, ScreenFormat, ScreenResolution, Loader, doom_fixed_to_double 20 | from doom_py.vizdoom import ViZDoomUnexpectedExitException, ViZDoomErrorException 21 | except ImportError as e: 22 | raise gym.error.DependencyNotInstalled("{}. (HINT: you can install Doom dependencies " + 23 | "with 'pip install doom_py.)'".format(e)) 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | # Arguments: 28 | RANDOMIZE_MAPS = 80 # 0 means load default, otherwise randomly load in the id mentioned 29 | NO_MONSTERS = True # remove monster spawning 30 | 31 | # Constants 32 | NUM_ACTIONS = 43 33 | NUM_LEVELS = 9 34 | CONFIG = 0 35 | SCENARIO = 1 36 | MAP = 2 37 | DIFFICULTY = 3 38 | ACTIONS = 4 39 | MIN_SCORE = 5 40 | TARGET_SCORE = 6 41 | 42 | # Format (config, scenario, map, difficulty, actions, min, target) 43 | DOOM_SETTINGS = [ 44 | ['basic.cfg', 'basic.wad', 'map01', 5, [0, 10, 11], -485, 10], # 0 - Basic 45 | ['deadly_corridor.cfg', 'deadly_corridor.wad', '', 1, [0, 10, 11, 13, 14, 15], -120, 1000], # 1 - Corridor 46 | ['defend_the_center.cfg', 'defend_the_center.wad', '', 5, [0, 14, 15], -1, 10], # 2 - DefendCenter 47 | ['defend_the_line.cfg', 'defend_the_line.wad', '', 5, [0, 14, 15], -1, 15], # 3 - DefendLine 48 | ['health_gathering.cfg', 'health_gathering.wad', 'map01', 5, [13, 14, 15], 0, 1000], # 4 - HealthGathering 49 | ['my_way_home.cfg', 'my_way_home_dense.wad', '', 5, [13, 14, 15], -0.22, 0.5], # 5 - MyWayHome 50 | ['predict_position.cfg', 'predict_position.wad', 'map01', 3, [0, 14, 15], -0.075, 0.5], # 6 - PredictPosition 51 | ['take_cover.cfg', 'take_cover.wad', 'map01', 5, [10, 11], 0, 750], # 7 - TakeCover 52 | ['deathmatch.cfg', 'deathmatch.wad', '', 5, [x for x in range(NUM_ACTIONS) if x != 33], 0, 20], # 8 - Deathmatch 53 | ['my_way_home.cfg', 'my_way_home_sparse.wad', '', 5, [13, 14, 15], -0.22, 0.5], # 9 - MyWayHomeFixed 54 | ['my_way_home.cfg', 'my_way_home_verySparse.wad', '', 5, [13, 14, 15], -0.22, 0.5], # 10 - MyWayHomeFixed15 55 | ] 56 | 57 | # Singleton pattern 58 | class DoomLock: 59 | class __DoomLock: 60 | def __init__(self): 61 | self.lock = multiprocessing.Lock() 62 | instance = None 63 | def __init__(self): 64 | if not DoomLock.instance: 65 | DoomLock.instance = DoomLock.__DoomLock() 66 | def get_lock(self): 67 | return DoomLock.instance.lock 68 | 69 | 70 | class DoomEnv(gym.Env): 71 | metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 35} 72 | 73 | def __init__(self, level): 74 | self.previous_level = -1 75 | self.level = level 76 | self.game = DoomGame() 77 | self.loader = Loader() 78 | self.doom_dir = 
os.path.dirname(os.path.abspath(__file__)) 79 | self._mode = 'algo' # 'algo' or 'human' 80 | self.no_render = False # To disable double rendering in human mode 81 | self.viewer = None 82 | self.is_initialized = False # Indicates that reset() has been called 83 | self.curr_seed = 0 84 | self.lock = (DoomLock()).get_lock() 85 | self.action_space = spaces.MultiDiscrete([[0, 1]] * 38 + [[-10, 10]] * 2 + [[-100, 100]] * 3) 86 | self.allowed_actions = list(range(NUM_ACTIONS)) 87 | self.screen_height = 480 88 | self.screen_width = 640 89 | self.screen_resolution = ScreenResolution.RES_640X480 90 | self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_height, self.screen_width, 3)) 91 | self._seed() 92 | self._configure() 93 | 94 | def _configure(self, lock=None, **kwargs): 95 | if 'screen_resolution' in kwargs: 96 | logger.warn('Deprecated - Screen resolution must now be set using a wrapper. See documentation for details.') 97 | # Multiprocessing lock 98 | if lock is not None: 99 | self.lock = lock 100 | 101 | def _load_level(self): 102 | # Closing if is_initialized 103 | if self.is_initialized: 104 | self.is_initialized = False 105 | self.game.close() 106 | self.game = DoomGame() 107 | 108 | # Customizing level 109 | if getattr(self, '_customize_game', None) is not None and callable(self._customize_game): 110 | self.level = -1 111 | self._customize_game() 112 | 113 | else: 114 | # Loading Paths 115 | if not self.is_initialized: 116 | self.game.set_vizdoom_path(self.loader.get_vizdoom_path()) 117 | self.game.set_doom_game_path(self.loader.get_freedoom_path()) 118 | 119 | # Common settings 120 | self.game.load_config(os.path.join(self.doom_dir, 'assets/%s' % DOOM_SETTINGS[self.level][CONFIG])) 121 | self.game.set_doom_scenario_path(self.loader.get_scenario_path(DOOM_SETTINGS[self.level][SCENARIO])) 122 | if DOOM_SETTINGS[self.level][MAP] != '': 123 | if RANDOMIZE_MAPS > 0 and 'labyrinth' in DOOM_SETTINGS[self.level][CONFIG].lower(): 124 | if 'fix' in DOOM_SETTINGS[self.level][SCENARIO].lower(): 125 | # mapId = 'map%02d'%np.random.randint(1, 23) 126 | mapId = 'map%02d'%np.random.randint(4, 8) 127 | else: 128 | mapId = 'map%02d'%np.random.randint(1, RANDOMIZE_MAPS+1) 129 | print('\t=> Special Config: Randomly Loading Maps. MapID = ' + mapId) 130 | self.game.set_doom_map(mapId) 131 | else: 132 | print('\t=> Default map loaded. MapID = ' + DOOM_SETTINGS[self.level][MAP]) 133 | self.game.set_doom_map(DOOM_SETTINGS[self.level][MAP]) 134 | self.game.set_doom_skill(DOOM_SETTINGS[self.level][DIFFICULTY]) 135 | self.allowed_actions = DOOM_SETTINGS[self.level][ACTIONS] 136 | self.game.set_screen_resolution(self.screen_resolution) 137 | 138 | self.previous_level = self.level 139 | self._closed = False 140 | 141 | # Algo mode 142 | if 'human' != self._mode: 143 | if NO_MONSTERS: 144 | print('\t=> Special Config: Monsters Removed.') 145 | self.game.add_game_args('-nomonsters 1') 146 | self.game 147 | self.game.set_window_visible(False) 148 | self.game.set_mode(Mode.PLAYER) 149 | self.no_render = False 150 | try: 151 | with self.lock: 152 | self.game.init() 153 | except (ViZDoomUnexpectedExitException, ViZDoomErrorException): 154 | raise error.Error( 155 | 'VizDoom exited unexpectedly. This is likely caused by a missing multiprocessing lock. ' + 156 | 'To run VizDoom across multiple processes, you need to pass a lock when you configure the env ' + 157 | '[e.g. env.configure(lock=my_multiprocessing_lock)], or create and close an env ' + 158 | 'before starting your processes [e.g. 
env = gym.make("DoomBasic-v0"); env.close()] to cache a ' + 159 | 'singleton lock in memory.') 160 | self._start_episode() 161 | self.is_initialized = True 162 | return self.game.get_state().image_buffer.copy() 163 | 164 | # Human mode 165 | else: 166 | if NO_MONSTERS: 167 | print('\t=> Special Config: Monsters Removed.') 168 | self.game.add_game_args('-nomonsters 1') 169 | self.game.add_game_args('+freelook 1') 170 | self.game.set_window_visible(True) 171 | self.game.set_mode(Mode.SPECTATOR) 172 | self.no_render = True 173 | with self.lock: 174 | self.game.init() 175 | self._start_episode() 176 | self.is_initialized = True 177 | self._play_human_mode() 178 | return np.zeros(shape=self.observation_space.shape, dtype=np.uint8) 179 | 180 | def _start_episode(self): 181 | if self.curr_seed > 0: 182 | self.game.set_seed(self.curr_seed) 183 | self.curr_seed = 0 184 | self.game.new_episode() 185 | return 186 | 187 | def _play_human_mode(self): 188 | while not self.game.is_episode_finished(): 189 | self.game.advance_action() 190 | state = self.game.get_state() 191 | total_reward = self.game.get_total_reward() 192 | info = self._get_game_variables(state.game_variables) 193 | info["TOTAL_REWARD"] = round(total_reward, 4) 194 | print('===============================') 195 | print('State: #' + str(state.number)) 196 | print('Action: \t' + str(self.game.get_last_action()) + '\t (=> only allowed actions)') 197 | print('Reward: \t' + str(self.game.get_last_reward())) 198 | print('Total Reward: \t' + str(total_reward)) 199 | print('Variables: \n' + str(info)) 200 | sleep(0.02857) # 35 fps = 0.02857 sleep between frames 201 | print('===============================') 202 | print('Done') 203 | return 204 | 205 | def _step(self, action): 206 | if NUM_ACTIONS != len(action): 207 | logger.warn('Doom action list must contain %d items. Padding missing items with 0' % NUM_ACTIONS) 208 | old_action = action 209 | action = [0] * NUM_ACTIONS 210 | for i in range(len(old_action)): 211 | action[i] = old_action[i] 212 | # action is a list of numbers but DoomGame.make_action expects a list of ints 213 | if len(self.allowed_actions) > 0: 214 | list_action = [int(action[action_idx]) for action_idx in self.allowed_actions] 215 | else: 216 | list_action = [int(x) for x in action] 217 | try: 218 | reward = self.game.make_action(list_action) 219 | state = self.game.get_state() 220 | info = self._get_game_variables(state.game_variables) 221 | info["TOTAL_REWARD"] = round(self.game.get_total_reward(), 4) 222 | 223 | if self.game.is_episode_finished(): 224 | is_finished = True 225 | return np.zeros(shape=self.observation_space.shape, dtype=np.uint8), reward, is_finished, info 226 | else: 227 | is_finished = False 228 | return state.image_buffer.copy(), reward, is_finished, info 229 | 230 | except doom_py.vizdoom.ViZDoomIsNotRunningException: 231 | return np.zeros(shape=self.observation_space.shape, dtype=np.uint8), 0, True, {} 232 | 233 | def _reset(self): 234 | if self.is_initialized and not self._closed: 235 | self._start_episode() 236 | image_buffer = self.game.get_state().image_buffer 237 | if image_buffer is None: 238 | raise error.Error( 239 | 'VizDoom incorrectly initiated. This is likely caused by a missing multiprocessing lock. ' + 240 | 'To run VizDoom across multiple processes, you need to pass a lock when you configure the env ' + 241 | '[e.g. env.configure(lock=my_multiprocessing_lock)], or create and close an env ' + 242 | 'before starting your processes [e.g. 
env = gym.make("DoomBasic-v0"); env.close()] to cache a ' + 243 | 'singleton lock in memory.') 244 | return image_buffer.copy() 245 | else: 246 | return self._load_level() 247 | 248 | def _render(self, mode='human', close=False): 249 | if close: 250 | if self.viewer is not None: 251 | self.viewer.close() 252 | self.viewer = None # If we don't None out this reference pyglet becomes unhappy 253 | return 254 | try: 255 | if 'human' == mode and self.no_render: 256 | return 257 | state = self.game.get_state() 258 | img = state.image_buffer 259 | # VizDoom returns None if the episode is finished, let's make it 260 | # an empty image so the recorder doesn't stop 261 | if img is None: 262 | img = np.zeros(shape=self.observation_space.shape, dtype=np.uint8) 263 | if mode == 'rgb_array': 264 | return img 265 | elif mode is 'human': 266 | from gym.envs.classic_control import rendering 267 | if self.viewer is None: 268 | self.viewer = rendering.SimpleImageViewer() 269 | self.viewer.imshow(img) 270 | except doom_py.vizdoom.ViZDoomIsNotRunningException: 271 | pass # Doom has been closed 272 | 273 | def _close(self): 274 | # Lock required for VizDoom to close processes properly 275 | with self.lock: 276 | self.game.close() 277 | 278 | def _seed(self, seed=None): 279 | self.curr_seed = seeding.hash_seed(seed) % 2 ** 32 280 | return [self.curr_seed] 281 | 282 | def _get_game_variables(self, state_variables): 283 | info = { 284 | "LEVEL": self.level 285 | } 286 | if state_variables is None: 287 | return info 288 | info['KILLCOUNT'] = state_variables[0] 289 | info['ITEMCOUNT'] = state_variables[1] 290 | info['SECRETCOUNT'] = state_variables[2] 291 | info['FRAGCOUNT'] = state_variables[3] 292 | info['HEALTH'] = state_variables[4] 293 | info['ARMOR'] = state_variables[5] 294 | info['DEAD'] = state_variables[6] 295 | info['ON_GROUND'] = state_variables[7] 296 | info['ATTACK_READY'] = state_variables[8] 297 | info['ALTATTACK_READY'] = state_variables[9] 298 | info['SELECTED_WEAPON'] = state_variables[10] 299 | info['SELECTED_WEAPON_AMMO'] = state_variables[11] 300 | info['AMMO1'] = state_variables[12] 301 | info['AMMO2'] = state_variables[13] 302 | info['AMMO3'] = state_variables[14] 303 | info['AMMO4'] = state_variables[15] 304 | info['AMMO5'] = state_variables[16] 305 | info['AMMO6'] = state_variables[17] 306 | info['AMMO7'] = state_variables[18] 307 | info['AMMO8'] = state_variables[19] 308 | info['AMMO9'] = state_variables[20] 309 | info['AMMO0'] = state_variables[21] 310 | info['POSITION_X'] = doom_fixed_to_double(self.game.get_game_variable(GameVariable.USER1)) 311 | info['POSITION_Y'] = doom_fixed_to_double(self.game.get_game_variable(GameVariable.USER2)) 312 | return info 313 | 314 | 315 | class MetaDoomEnv(DoomEnv): 316 | 317 | def __init__(self, average_over=10, passing_grade=600, min_tries_for_avg=5): 318 | super(MetaDoomEnv, self).__init__(0) 319 | self.average_over = average_over 320 | self.passing_grade = passing_grade 321 | self.min_tries_for_avg = min_tries_for_avg # Need to use at least this number of tries to calc avg 322 | self.scores = [[]] * NUM_LEVELS 323 | self.locked_levels = [True] * NUM_LEVELS # Locking all levels but the first 324 | self.locked_levels[0] = False 325 | self.total_reward = 0 326 | self.find_new_level = False # Indicates that we need a level change 327 | self._unlock_levels() 328 | 329 | def _play_human_mode(self): 330 | while not self.game.is_episode_finished(): 331 | self.game.advance_action() 332 | state = self.game.get_state() 333 | episode_reward = 
self.game.get_total_reward() 334 | (reward, self.total_reward) = self._calculate_reward(episode_reward, self.total_reward) 335 | info = self._get_game_variables(state.game_variables) 336 | info["SCORES"] = self.get_scores() 337 | info["TOTAL_REWARD"] = round(self.total_reward, 4) 338 | info["LOCKED_LEVELS"] = self.locked_levels 339 | print('===============================') 340 | print('State: #' + str(state.number)) 341 | print('Action: \t' + str(self.game.get_last_action()) + '\t (=> only allowed actions)') 342 | print('Reward: \t' + str(reward)) 343 | print('Total Reward: \t' + str(self.total_reward)) 344 | print('Variables: \n' + str(info)) 345 | sleep(0.02857) # 35 fps = 0.02857 sleep between frames 346 | print('===============================') 347 | print('Done') 348 | return 349 | 350 | def _get_next_level(self): 351 | # Finds the unlocked level with the lowest average 352 | averages = self.get_scores() 353 | lowest_level = 0 # Defaulting to first level 354 | lowest_score = 1001 355 | for i in range(NUM_LEVELS): 356 | if not self.locked_levels[i]: 357 | if averages[i] < lowest_score: 358 | lowest_level = i 359 | lowest_score = averages[i] 360 | return lowest_level 361 | 362 | def _unlock_levels(self): 363 | averages = self.get_scores() 364 | for i in range(NUM_LEVELS - 2, -1, -1): 365 | if self.locked_levels[i + 1] and averages[i] >= self.passing_grade: 366 | self.locked_levels[i + 1] = False 367 | return 368 | 369 | def _start_episode(self): 370 | if 0 == len(self.scores[self.level]): 371 | self.scores[self.level] = [0] * self.min_tries_for_avg 372 | else: 373 | self.scores[self.level].insert(0, 0) 374 | self.scores[self.level] = self.scores[self.level][:self.min_tries_for_avg] 375 | self.is_new_episode = True 376 | return super(MetaDoomEnv, self)._start_episode() 377 | 378 | def change_level(self, new_level=None): 379 | if new_level is not None and self.locked_levels[new_level] == False: 380 | self.find_new_level = False 381 | self.level = new_level 382 | self.reset() 383 | else: 384 | self.find_new_level = False 385 | self.level = self._get_next_level() 386 | self.reset() 387 | return 388 | 389 | def _get_standard_reward(self, episode_reward): 390 | # Returns a standardized reward for an episode (i.e. 
between 0 and 1,000) 391 | min_score = float(DOOM_SETTINGS[self.level][MIN_SCORE]) 392 | target_score = float(DOOM_SETTINGS[self.level][TARGET_SCORE]) 393 | max_score = min_score + (target_score - min_score) / 0.99 # Target is 99th percentile (Scale 0-1000) 394 | std_reward = round(1000 * (episode_reward - min_score) / (max_score - min_score), 4) 395 | std_reward = min(1000, std_reward) # Cannot be more than 1,000 396 | std_reward = max(0, std_reward) # Cannot be less than 0 397 | return std_reward 398 | 399 | def get_total_reward(self): 400 | # Returns the sum of the average of all levels 401 | total_score = 0 402 | passed_levels = 0 403 | for i in range(NUM_LEVELS): 404 | if len(self.scores[i]) > 0: 405 | level_total = 0 406 | level_count = min(len(self.scores[i]), self.average_over) 407 | for j in range(level_count): 408 | level_total += self.scores[i][j] 409 | level_average = level_total / level_count 410 | if level_average >= 990: 411 | passed_levels += 1 412 | total_score += level_average 413 | # Bonus for passing all levels (50 * num of levels) 414 | if NUM_LEVELS == passed_levels: 415 | total_score += NUM_LEVELS * 50 416 | return round(total_score, 4) 417 | 418 | def _calculate_reward(self, episode_reward, prev_total_reward): 419 | # Calculates the action reward and the new total reward 420 | std_reward = self._get_standard_reward(episode_reward) 421 | self.scores[self.level][0] = std_reward 422 | total_reward = self.get_total_reward() 423 | reward = total_reward - prev_total_reward 424 | return reward, total_reward 425 | 426 | def get_scores(self): 427 | # Returns a list with the averages per level 428 | averages = [0] * NUM_LEVELS 429 | for i in range(NUM_LEVELS): 430 | if len(self.scores[i]) > 0: 431 | level_total = 0 432 | level_count = min(len(self.scores[i]), self.average_over) 433 | for j in range(level_count): 434 | level_total += self.scores[i][j] 435 | level_average = level_total / level_count 436 | averages[i] = round(level_average, 4) 437 | return averages 438 | 439 | def _reset(self): 440 | # Reset is called on first step() after level is finished 441 | # or when change_level() is called. 
Returning if neither have been called to 442 | # avoid resetting the level twice 443 | if self.find_new_level: 444 | return 445 | 446 | if self.is_initialized and not self._closed and self.previous_level == self.level: 447 | self._start_episode() 448 | return self.game.get_state().image_buffer.copy() 449 | else: 450 | return self._load_level() 451 | 452 | def _step(self, action): 453 | # Changing level 454 | if self.find_new_level: 455 | self.change_level() 456 | 457 | if 'human' == self._mode: 458 | self._play_human_mode() 459 | obs = np.zeros(shape=self.observation_space.shape, dtype=np.uint8) 460 | reward = 0 461 | is_finished = True 462 | info = self._get_game_variables(None) 463 | else: 464 | obs, step_reward, is_finished, info = super(MetaDoomEnv, self)._step(action) 465 | reward, self.total_reward = self._calculate_reward(self.game.get_total_reward(), self.total_reward) 466 | # First step() after new episode returns the entire total reward 467 | # because stats_recorder resets the episode score to 0 after reset() is called 468 | if self.is_new_episode: 469 | reward = self.total_reward 470 | 471 | self.is_new_episode = False 472 | info["SCORES"] = self.get_scores() 473 | info["TOTAL_REWARD"] = round(self.total_reward, 4) 474 | info["LOCKED_LEVELS"] = self.locked_levels 475 | 476 | # Indicating new level required 477 | if is_finished: 478 | self._unlock_levels() 479 | self.find_new_level = True 480 | 481 | return obs, reward, is_finished, info 482 | -------------------------------------------------------------------------------- /doomFiles/doom_my_way_home_sparse.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .doom_env import DoomEnv 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class DoomMyWayHomeFixedEnv(DoomEnv): 8 | """ 9 | ------------ Training Mission 10 - My Way Home Fixed ------------ 10 | Exactly same as Mission#6, but with fixed start from room-10 (farthest). 11 | This map is designed to improve navigational skills. It is a series of 12 | interconnected rooms and 1 corridor with a dead end. Each room 13 | has a separate color. There is a green vest in one of the room. 14 | The vest is always in the same room. Player must find the vest. 15 | 16 | Allowed actions: 17 | [13] - MOVE_FORWARD - Move forward - Values 0 or 1 18 | [14] - TURN_RIGHT - Turn right - Values 0 or 1 19 | [15] - TURN_LEFT - Turn left - Values 0 or 1 20 | Note: see controls.md for details 21 | 22 | Rewards: 23 | + 1 - Finding the vest 24 | -0.0001 - 35 times per second - Find the vest quick! 25 | 26 | Goal: 0.50 point 27 | Find the vest 28 | 29 | Ends when: 30 | - Vest is found 31 | - Timeout (1 minutes - 2,100 frames) 32 | 33 | Actions: 34 | actions = [0] * 43 35 | actions[13] = 0 # MOVE_FORWARD 36 | actions[14] = 1 # TURN_RIGHT 37 | actions[15] = 0 # TURN_LEFT 38 | 39 | Configuration: 40 | After creating the env, you can call env.configure() to configure some parameters. 41 | 42 | - lock [e.g. env.configure(lock=multiprocessing_lock)] 43 | 44 | VizDoom requires a multiprocessing lock when running across multiple processes, otherwise the vizdoom instance 45 | might crash on launch 46 | 47 | You can either: 48 | 49 | 1) [Preferred] Create a multiprocessing.Lock() and pass it as a parameter to the configure() method 50 | [e.g. 
env.configure(lock=multiprocessing_lock)] 51 | 52 | 2) Create and close a Doom environment before running your multiprocessing routine, this will create 53 | a singleton lock that will be cached in memory, and be used by all Doom environments afterwards 54 | [e.g. env = gym.make('Doom-...'); env.close()] 55 | 56 | 3) Manually wrap calls to reset() and close() in a multiprocessing.Lock() 57 | 58 | Wrappers: 59 | 60 | You can use wrappers to further customize the environment. Wrappers need to be manually copied from the wrappers folder. 61 | 62 | theWrapperOne = WrapperOneName(init_options) 63 | theWrapperTwo = WrapperTwoName(init_options) 64 | env = gym.make('ppaquette/DoomMyWayHome-v0') 65 | env = theWrapperTwo(theWrapperOne((env)) 66 | 67 | - Observation space: 68 | 69 | You can change the resolution by using the SetResolution wrapper. 70 | 71 | wrapper = SetResolution(target_resolution) 72 | env = wrapper(env) 73 | 74 | The following are valid target_resolution that can be used: 75 | 76 | '160x120', '200x125', '200x150', '256x144', '256x160', '256x192', '320x180', '320x200', 77 | '320x240', '320x256', '400x225', '400x250', '400x300', '512x288', '512x320', '512x384', 78 | '640x360', '640x400', '640x480', '800x450', '800x500', '800x600', '1024x576', '1024x640', 79 | '1024x768', '1280x720', '1280x800', '1280x960', '1280x1024', '1400x787', '1400x875', 80 | '1400x1050', '1600x900', '1600x1000', '1600x1200', '1920x1080' 81 | 82 | - Action space: 83 | 84 | You can change the action space by using the ToDiscrete or ToBox wrapper 85 | 86 | wrapper = ToBox(config_options) 87 | env = wrapper(env) 88 | 89 | The following are valid config options (for both ToDiscrete and ToBox) 90 | 91 | - minimal - Only the level's allowed actions (and NOOP for discrete) 92 | - constant-7 - 7 minimum actions required to complete all levels (and NOOP for discrete) 93 | - constant-17 - 17 most common actions required to complete all levels (and NOOP for discrete) 94 | - full - All available actions (and NOOP for discrete) 95 | 96 | Note: Discrete action spaces only allow one action at a time, Box action spaces support simultaneous actions 97 | 98 | - Control: 99 | 100 | You can play the game manually with the SetPlayingMode wrapper. 101 | 102 | wrapper = SetPlayingMode('human') 103 | env = wrapper(env) 104 | 105 | Valid options are 'human' or 'algo' (default) 106 | 107 | ----------------------------------------------------- 108 | """ 109 | def __init__(self): 110 | super(DoomMyWayHomeFixedEnv, self).__init__(9) 111 | -------------------------------------------------------------------------------- /doomFiles/doom_my_way_home_verySparse.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .doom_env import DoomEnv 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | class DoomMyWayHomeFixed15Env(DoomEnv): 8 | """ 9 | ------------ Training Mission 11 - My Way Home Fixed15 ------------ 10 | Exactly same as Mission#6, but with fixed start from room-10 (farthest). 11 | This map is designed to improve navigational skills. It is a series of 12 | interconnected rooms and 1 corridor with a dead end. Each room 13 | has a separate color. There is a green vest in one of the room. 14 | The vest is always in the same room. Player must find the vest. 
15 | 16 | Allowed actions: 17 | [13] - MOVE_FORWARD - Move forward - Values 0 or 1 18 | [14] - TURN_RIGHT - Turn right - Values 0 or 1 19 | [15] - TURN_LEFT - Turn left - Values 0 or 1 20 | Note: see controls.md for details 21 | 22 | Rewards: 23 | + 1 - Finding the vest 24 | -0.0001 - 35 times per second - Find the vest quick! 25 | 26 | Goal: 0.50 point 27 | Find the vest 28 | 29 | Ends when: 30 | - Vest is found 31 | - Timeout (1 minutes - 2,100 frames) 32 | 33 | Actions: 34 | actions = [0] * 43 35 | actions[13] = 0 # MOVE_FORWARD 36 | actions[14] = 1 # TURN_RIGHT 37 | actions[15] = 0 # TURN_LEFT 38 | 39 | Configuration: 40 | After creating the env, you can call env.configure() to configure some parameters. 41 | 42 | - lock [e.g. env.configure(lock=multiprocessing_lock)] 43 | 44 | VizDoom requires a multiprocessing lock when running across multiple processes, otherwise the vizdoom instance 45 | might crash on launch 46 | 47 | You can either: 48 | 49 | 1) [Preferred] Create a multiprocessing.Lock() and pass it as a parameter to the configure() method 50 | [e.g. env.configure(lock=multiprocessing_lock)] 51 | 52 | 2) Create and close a Doom environment before running your multiprocessing routine, this will create 53 | a singleton lock that will be cached in memory, and be used by all Doom environments afterwards 54 | [e.g. env = gym.make('Doom-...'); env.close()] 55 | 56 | 3) Manually wrap calls to reset() and close() in a multiprocessing.Lock() 57 | 58 | Wrappers: 59 | 60 | You can use wrappers to further customize the environment. Wrappers need to be manually copied from the wrappers folder. 61 | 62 | theWrapperOne = WrapperOneName(init_options) 63 | theWrapperTwo = WrapperTwoName(init_options) 64 | env = gym.make('ppaquette/DoomMyWayHome-v0') 65 | env = theWrapperTwo(theWrapperOne((env)) 66 | 67 | - Observation space: 68 | 69 | You can change the resolution by using the SetResolution wrapper. 70 | 71 | wrapper = SetResolution(target_resolution) 72 | env = wrapper(env) 73 | 74 | The following are valid target_resolution that can be used: 75 | 76 | '160x120', '200x125', '200x150', '256x144', '256x160', '256x192', '320x180', '320x200', 77 | '320x240', '320x256', '400x225', '400x250', '400x300', '512x288', '512x320', '512x384', 78 | '640x360', '640x400', '640x480', '800x450', '800x500', '800x600', '1024x576', '1024x640', 79 | '1024x768', '1280x720', '1280x800', '1280x960', '1280x1024', '1400x787', '1400x875', 80 | '1400x1050', '1600x900', '1600x1000', '1600x1200', '1920x1080' 81 | 82 | - Action space: 83 | 84 | You can change the action space by using the ToDiscrete or ToBox wrapper 85 | 86 | wrapper = ToBox(config_options) 87 | env = wrapper(env) 88 | 89 | The following are valid config options (for both ToDiscrete and ToBox) 90 | 91 | - minimal - Only the level's allowed actions (and NOOP for discrete) 92 | - constant-7 - 7 minimum actions required to complete all levels (and NOOP for discrete) 93 | - constant-17 - 17 most common actions required to complete all levels (and NOOP for discrete) 94 | - full - All available actions (and NOOP for discrete) 95 | 96 | Note: Discrete action spaces only allow one action at a time, Box action spaces support simultaneous actions 97 | 98 | - Control: 99 | 100 | You can play the game manually with the SetPlayingMode wrapper. 
101 | 102 | wrapper = SetPlayingMode('human') 103 | env = wrapper(env) 104 | 105 | Valid options are 'human' or 'algo' (default) 106 | 107 | ----------------------------------------------------- 108 | """ 109 | def __init__(self): 110 | super(DoomMyWayHomeFixed15Env, self).__init__(10) 111 | -------------------------------------------------------------------------------- /doomFiles/wads/my_way_home_dense.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/doomFiles/wads/my_way_home_dense.wad -------------------------------------------------------------------------------- /doomFiles/wads/my_way_home_sparse.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/doomFiles/wads/my_way_home_sparse.wad -------------------------------------------------------------------------------- /doomFiles/wads/my_way_home_verySparse.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/doomFiles/wads/my_way_home_verySparse.wad -------------------------------------------------------------------------------- /images/mario1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/images/mario1.gif -------------------------------------------------------------------------------- /images/mario2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/images/mario2.gif -------------------------------------------------------------------------------- /images/vizdoom.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/images/vizdoom.gif -------------------------------------------------------------------------------- /models/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/" && pwd )" 4 | cd $DIR 5 | 6 | FILE=models.tar.gz 7 | URL=https://people.eecs.berkeley.edu/~pathak/noreward-rl/resources/$FILE 8 | CHECKSUM=26bdf54e9562e23750ebc2ef503204b1 9 | 10 | if [ ! -f $FILE ]; then 11 | echo "Downloading the curiosity-driven RL trained models (6MB)..." 12 | wget $URL -O $FILE 13 | echo "Unzipping..." 14 | tar zxvf $FILE 15 | mv models/* . 16 | rm -rf models 17 | echo "Downloading Done." 18 | else 19 | echo "File already exists. Checking md5..." 20 | fi 21 | 22 | os=`uname -s` 23 | if [ "$os" = "Linux" ]; then 24 | checksum=`md5sum $FILE | awk '{ print $1 }'` 25 | elif [ "$os" = "Darwin" ]; then 26 | checksum=`cat $FILE | md5` 27 | elif [ "$os" = "SunOS" ]; then 28 | checksum=`digest -a md5 -v $FILE | awk '{ print $4 }'` 29 | fi 30 | if [ "$checksum" = "$CHECKSUM" ]; then 31 | echo "Checksum is correct. File was correctly downloaded." 32 | exit 0 33 | else 34 | echo "Checksum is incorrect. DELETE and download again." 
35 | fi 36 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | unsuprl/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | tmp/ 92 | 93 | # vizdoom cache 94 | vizdoom.ini 95 | -------------------------------------------------------------------------------- /src/a3c.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import namedtuple 3 | import numpy as np 4 | import tensorflow as tf 5 | from model import LSTMPolicy, StateActionPredictor, StatePredictor 6 | import six.moves.queue as queue 7 | import scipy.signal 8 | import threading 9 | import distutils.version 10 | from constants import constants 11 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0') 12 | 13 | def discount(x, gamma): 14 | """ 15 | x = [r1, r2, r3, ..., rN] 16 | returns [r1 + r2*gamma + r3*gamma^2 + ..., 17 | r2 + r3*gamma + r4*gamma^2 + ..., 18 | r3 + r4*gamma + r5*gamma^2 + ..., 19 | ..., ..., rN] 20 | """ 21 | return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1] 22 | 23 | def process_rollout(rollout, gamma, lambda_=1.0, clip=False): 24 | """ 25 | Given a rollout, compute its returns and the advantage. 26 | """ 27 | # collecting transitions 28 | if rollout.unsup: 29 | batch_si = np.asarray(rollout.states + [rollout.end_state]) 30 | else: 31 | batch_si = np.asarray(rollout.states) 32 | batch_a = np.asarray(rollout.actions) 33 | 34 | # collecting target for value network 35 | # V_t <-> r_t + gamma*r_{t+1} + ... 
+ gamma^n*r_{t+n} + gamma^{n+1}*V_{n+1} 36 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) # bootstrapping 37 | if rollout.unsup: 38 | rewards_plus_v += np.asarray(rollout.bonuses + [0]) 39 | if clip: 40 | rewards_plus_v[:-1] = np.clip(rewards_plus_v[:-1], -constants['REWARD_CLIP'], constants['REWARD_CLIP']) 41 | batch_r = discount(rewards_plus_v, gamma)[:-1] # value network target 42 | 43 | # collecting target for policy network 44 | rewards = np.asarray(rollout.rewards) 45 | if rollout.unsup: 46 | rewards += np.asarray(rollout.bonuses) 47 | if clip: 48 | rewards = np.clip(rewards, -constants['REWARD_CLIP'], constants['REWARD_CLIP']) 49 | vpred_t = np.asarray(rollout.values + [rollout.r]) 50 | # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438 51 | # Eq (10): delta_t = Rt + gamma*V_{t+1} - V_t 52 | # Eq (16): batch_adv_t = delta_t + gamma*delta_{t+1} + gamma^2*delta_{t+2} + ... 53 | delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1] 54 | batch_adv = discount(delta_t, gamma * lambda_) 55 | 56 | features = rollout.features[0] 57 | 58 | return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal, features) 59 | 60 | Batch = namedtuple("Batch", ["si", "a", "adv", "r", "terminal", "features"]) 61 | 62 | class PartialRollout(object): 63 | """ 64 | A piece of a complete rollout. We run our agent, and process its experience 65 | once it has processed enough steps. 66 | """ 67 | def __init__(self, unsup=False): 68 | self.states = [] 69 | self.actions = [] 70 | self.rewards = [] 71 | self.values = [] 72 | self.r = 0.0 73 | self.terminal = False 74 | self.features = [] 75 | self.unsup = unsup 76 | if self.unsup: 77 | self.bonuses = [] 78 | self.end_state = None 79 | 80 | 81 | def add(self, state, action, reward, value, terminal, features, 82 | bonus=None, end_state=None): 83 | self.states += [state] 84 | self.actions += [action] 85 | self.rewards += [reward] 86 | self.values += [value] 87 | self.terminal = terminal 88 | self.features += [features] 89 | if self.unsup: 90 | self.bonuses += [bonus] 91 | self.end_state = end_state 92 | 93 | def extend(self, other): 94 | assert not self.terminal 95 | self.states.extend(other.states) 96 | self.actions.extend(other.actions) 97 | self.rewards.extend(other.rewards) 98 | self.values.extend(other.values) 99 | self.r = other.r 100 | self.terminal = other.terminal 101 | self.features.extend(other.features) 102 | if self.unsup: 103 | self.bonuses.extend(other.bonuses) 104 | self.end_state = other.end_state 105 | 106 | class RunnerThread(threading.Thread): 107 | """ 108 | One of the key distinctions between a normal environment and a universe environment 109 | is that a universe environment is _real time_. This means that there should be a thread 110 | that would constantly interact with the environment and tell it what to do. This thread is here. 111 | """ 112 | def __init__(self, env, policy, num_local_steps, visualise, predictor, envWrap, 113 | noReward): 114 | threading.Thread.__init__(self) 115 | self.queue = queue.Queue(5) # ideally, should be 1. Mostly doesn't matter in our case. 
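# Producer/consumer note (added for clarity): this thread is the producer and
# A3C.pull_batch_from_queue() is the consumer that drains the queue on every
# training step; keeping the queue short keeps rollouts close to on-policy,
# since queued rollouts were generated by slightly older policy weights.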
116 | self.num_local_steps = num_local_steps 117 | self.env = env 118 | self.last_features = None 119 | self.policy = policy 120 | self.daemon = True 121 | self.sess = None 122 | self.summary_writer = None 123 | self.visualise = visualise 124 | self.predictor = predictor 125 | self.envWrap = envWrap 126 | self.noReward = noReward 127 | 128 | def start_runner(self, sess, summary_writer): 129 | self.sess = sess 130 | self.summary_writer = summary_writer 131 | self.start() 132 | 133 | def run(self): 134 | with self.sess.as_default(): 135 | self._run() 136 | 137 | def _run(self): 138 | rollout_provider = env_runner(self.env, self.policy, self.num_local_steps, 139 | self.summary_writer, self.visualise, self.predictor, 140 | self.envWrap, self.noReward) 141 | while True: 142 | # the timeout variable exists because apparently, if one worker dies, the other workers 143 | # won't die with it, unless the timeout is set to some large number. This is an empirical 144 | # observation. 145 | 146 | self.queue.put(next(rollout_provider), timeout=600.0) 147 | 148 | 149 | def env_runner(env, policy, num_local_steps, summary_writer, render, predictor, 150 | envWrap, noReward): 151 | """ 152 | The logic of the thread runner. In brief, it constantly keeps on running 153 | the policy, and as long as the rollout exceeds a certain length, the thread 154 | runner appends the policy to the queue. 155 | """ 156 | last_state = env.reset() 157 | last_features = policy.get_initial_features() # reset lstm memory 158 | length = 0 159 | rewards = 0 160 | values = 0 161 | if predictor is not None: 162 | ep_bonus = 0 163 | life_bonus = 0 164 | 165 | while True: 166 | terminal_end = False 167 | rollout = PartialRollout(predictor is not None) 168 | 169 | for _ in range(num_local_steps): 170 | # run policy 171 | fetched = policy.act(last_state, *last_features) 172 | action, value_, features = fetched[0], fetched[1], fetched[2:] 173 | 174 | # run environment: get action_index from sampled one-hot 'action' 175 | stepAct = action.argmax() 176 | state, reward, terminal, info = env.step(stepAct) 177 | if noReward: 178 | reward = 0. 179 | if render: 180 | env.render() 181 | 182 | curr_tuple = [last_state, action, reward, value_, terminal, last_features] 183 | if predictor is not None: 184 | bonus = predictor.pred_bonus(last_state, state, action) 185 | curr_tuple += [bonus, state] 186 | life_bonus += bonus 187 | ep_bonus += bonus 188 | 189 | # collect the experience 190 | rollout.add(*curr_tuple) 191 | rewards += reward 192 | length += 1 193 | values += value_[0] 194 | 195 | last_state = state 196 | last_features = features 197 | 198 | timestep_limit = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps') 199 | if timestep_limit is None: timestep_limit = env.spec.timestep_limit 200 | if terminal or length >= timestep_limit: 201 | # prints summary of each life if envWrap==True else each game 202 | if predictor is not None: 203 | print("Episode finished. Sum of shaped rewards: %.2f. Length: %d. Bonus: %.4f." % (rewards, length, life_bonus)) 204 | life_bonus = 0 205 | else: 206 | print("Episode finished. Sum of shaped rewards: %.2f. Length: %d." % (rewards, length)) 207 | if 'distance' in info: print('Mario Distance Covered:', info['distance']) 208 | length = 0 209 | rewards = 0 210 | terminal_end = True 211 | last_features = policy.get_initial_features() # reset lstm memory 212 | # TODO: don't reset when gym timestep_limit increases, bootstrap -- doesn't matter for atari? 
213 | # reset only if it hasn't already reseted 214 | if length >= timestep_limit or not env.metadata.get('semantics.autoreset'): 215 | last_state = env.reset() 216 | 217 | if info: 218 | # summarize full game including all lives (even if envWrap=True) 219 | summary = tf.Summary() 220 | for k, v in info.items(): 221 | summary.value.add(tag=k, simple_value=float(v)) 222 | if terminal: 223 | summary.value.add(tag='global/episode_value', simple_value=float(values)) 224 | values = 0 225 | if predictor is not None: 226 | summary.value.add(tag='global/episode_bonus', simple_value=float(ep_bonus)) 227 | ep_bonus = 0 228 | summary_writer.add_summary(summary, policy.global_step.eval()) 229 | summary_writer.flush() 230 | 231 | if terminal_end: 232 | break 233 | 234 | if not terminal_end: 235 | rollout.r = policy.value(last_state, *last_features) 236 | 237 | # once we have enough experience, yield it, and have the ThreadRunner place it on a queue 238 | yield rollout 239 | 240 | 241 | class A3C(object): 242 | def __init__(self, env, task, visualise, unsupType, envWrap=False, designHead='universe', noReward=False): 243 | """ 244 | An implementation of the A3C algorithm that is reasonably well-tuned for the VNC environments. 245 | Below, we will have a modest amount of complexity due to the way TensorFlow handles data parallelism. 246 | But overall, we'll define the model, specify its inputs, and describe how the policy gradients step 247 | should be computed. 248 | """ 249 | self.task = task 250 | self.unsup = unsupType is not None 251 | self.envWrap = envWrap 252 | self.env = env 253 | 254 | predictor = None 255 | numaction = env.action_space.n 256 | worker_device = "/job:worker/task:{}/cpu:0".format(task) 257 | 258 | with tf.device(tf.train.replica_device_setter(1, worker_device=worker_device)): 259 | with tf.variable_scope("global"): 260 | self.network = LSTMPolicy(env.observation_space.shape, numaction, designHead) 261 | self.global_step = tf.get_variable("global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), 262 | trainable=False) 263 | if self.unsup: 264 | with tf.variable_scope("predictor"): 265 | if 'state' in unsupType: 266 | self.ap_network = StatePredictor(env.observation_space.shape, numaction, designHead, unsupType) 267 | else: 268 | self.ap_network = StateActionPredictor(env.observation_space.shape, numaction, designHead) 269 | 270 | with tf.device(worker_device): 271 | with tf.variable_scope("local"): 272 | self.local_network = pi = LSTMPolicy(env.observation_space.shape, numaction, designHead) 273 | pi.global_step = self.global_step 274 | if self.unsup: 275 | with tf.variable_scope("predictor"): 276 | if 'state' in unsupType: 277 | self.local_ap_network = predictor = StatePredictor(env.observation_space.shape, numaction, designHead, unsupType) 278 | else: 279 | self.local_ap_network = predictor = StateActionPredictor(env.observation_space.shape, numaction, designHead) 280 | 281 | # Computing a3c loss: https://arxiv.org/abs/1506.02438 282 | self.ac = tf.placeholder(tf.float32, [None, numaction], name="ac") 283 | self.adv = tf.placeholder(tf.float32, [None], name="adv") 284 | self.r = tf.placeholder(tf.float32, [None], name="r") 285 | log_prob_tf = tf.nn.log_softmax(pi.logits) 286 | prob_tf = tf.nn.softmax(pi.logits) 287 | # 1) the "policy gradients" loss: its derivative is precisely the policy gradient 288 | # notice that self.ac is a placeholder that is provided externally. 
289 | # adv will contain the advantages, as calculated in process_rollout 290 | pi_loss = - tf.reduce_mean(tf.reduce_sum(log_prob_tf * self.ac, 1) * self.adv) # Eq (19) 291 | # 2) loss of value function: l2_loss = (x-y)^2/2 292 | vf_loss = 0.5 * tf.reduce_mean(tf.square(pi.vf - self.r)) # Eq (28) 293 | # 3) entropy to ensure randomness 294 | entropy = - tf.reduce_mean(tf.reduce_sum(prob_tf * log_prob_tf, 1)) 295 | # final a3c loss: lr of critic is half of actor 296 | self.loss = pi_loss + 0.5 * vf_loss - entropy * constants['ENTROPY_BETA'] 297 | 298 | # compute gradients 299 | grads = tf.gradients(self.loss * 20.0, pi.var_list) # batchsize=20. Factored out to make hyperparams not depend on it. 300 | 301 | # computing predictor loss 302 | if self.unsup: 303 | if 'state' in unsupType: 304 | self.predloss = constants['PREDICTION_LR_SCALE'] * predictor.forwardloss 305 | else: 306 | self.predloss = constants['PREDICTION_LR_SCALE'] * (predictor.invloss * (1-constants['FORWARD_LOSS_WT']) + 307 | predictor.forwardloss * constants['FORWARD_LOSS_WT']) 308 | predgrads = tf.gradients(self.predloss * 20.0, predictor.var_list) # batchsize=20. Factored out to make hyperparams not depend on it. 309 | 310 | # do not backprop to policy 311 | if constants['POLICY_NO_BACKPROP_STEPS'] > 0: 312 | grads = [tf.scalar_mul(tf.to_float(tf.greater(self.global_step, constants['POLICY_NO_BACKPROP_STEPS'])), grads_i) 313 | for grads_i in grads] 314 | 315 | 316 | self.runner = RunnerThread(env, pi, constants['ROLLOUT_MAXLEN'], visualise, 317 | predictor, envWrap, noReward) 318 | 319 | # storing summaries 320 | bs = tf.to_float(tf.shape(pi.x)[0]) 321 | if use_tf12_api: 322 | tf.summary.scalar("model/policy_loss", pi_loss) 323 | tf.summary.scalar("model/value_loss", vf_loss) 324 | tf.summary.scalar("model/entropy", entropy) 325 | tf.summary.image("model/state", pi.x) # max_outputs=10 326 | tf.summary.scalar("model/grad_global_norm", tf.global_norm(grads)) 327 | tf.summary.scalar("model/var_global_norm", tf.global_norm(pi.var_list)) 328 | if self.unsup: 329 | tf.summary.scalar("model/predloss", self.predloss) 330 | if 'action' in unsupType: 331 | tf.summary.scalar("model/inv_loss", predictor.invloss) 332 | tf.summary.scalar("model/forward_loss", predictor.forwardloss) 333 | tf.summary.scalar("model/predgrad_global_norm", tf.global_norm(predgrads)) 334 | tf.summary.scalar("model/predvar_global_norm", tf.global_norm(predictor.var_list)) 335 | self.summary_op = tf.summary.merge_all() 336 | else: 337 | tf.scalar_summary("model/policy_loss", pi_loss) 338 | tf.scalar_summary("model/value_loss", vf_loss) 339 | tf.scalar_summary("model/entropy", entropy) 340 | tf.image_summary("model/state", pi.x) 341 | tf.scalar_summary("model/grad_global_norm", tf.global_norm(grads)) 342 | tf.scalar_summary("model/var_global_norm", tf.global_norm(pi.var_list)) 343 | if self.unsup: 344 | tf.scalar_summary("model/predloss", self.predloss) 345 | if 'action' in unsupType: 346 | tf.scalar_summary("model/inv_loss", predictor.invloss) 347 | tf.scalar_summary("model/forward_loss", predictor.forwardloss) 348 | tf.scalar_summary("model/predgrad_global_norm", tf.global_norm(predgrads)) 349 | tf.scalar_summary("model/predvar_global_norm", tf.global_norm(predictor.var_list)) 350 | self.summary_op = tf.merge_all_summaries() 351 | 352 | # clip gradients 353 | grads, _ = tf.clip_by_global_norm(grads, constants['GRAD_NORM_CLIP']) 354 | grads_and_vars = list(zip(grads, self.network.var_list)) 355 | if self.unsup: 356 | predgrads, _ = 
tf.clip_by_global_norm(predgrads, constants['GRAD_NORM_CLIP']) 357 | pred_grads_and_vars = list(zip(predgrads, self.ap_network.var_list)) 358 | grads_and_vars = grads_and_vars + pred_grads_and_vars 359 | 360 | # update global step by batch size 361 | inc_step = self.global_step.assign_add(tf.shape(pi.x)[0]) 362 | 363 | # each worker has a different set of adam optimizer parameters 364 | # TODO: make optimizer global shared, if needed 365 | print("Optimizer: ADAM with lr: %f" % (constants['LEARNING_RATE'])) 366 | print("Input observation shape: ",env.observation_space.shape) 367 | opt = tf.train.AdamOptimizer(constants['LEARNING_RATE']) 368 | self.train_op = tf.group(opt.apply_gradients(grads_and_vars), inc_step) 369 | 370 | # copy weights from the parameter server to the local model 371 | sync_var_list = [v1.assign(v2) for v1, v2 in zip(pi.var_list, self.network.var_list)] 372 | if self.unsup: 373 | sync_var_list += [v1.assign(v2) for v1, v2 in zip(predictor.var_list, self.ap_network.var_list)] 374 | self.sync = tf.group(*sync_var_list) 375 | 376 | # initialize extras 377 | self.summary_writer = None 378 | self.local_steps = 0 379 | 380 | def start(self, sess, summary_writer): 381 | self.runner.start_runner(sess, summary_writer) 382 | self.summary_writer = summary_writer 383 | 384 | def pull_batch_from_queue(self): 385 | """ 386 | Take a rollout from the queue of the thread runner. 387 | """ 388 | # get top rollout from queue (FIFO) 389 | rollout = self.runner.queue.get(timeout=600.0) 390 | while not rollout.terminal: 391 | try: 392 | # Now, get remaining *available* rollouts from queue and append them into 393 | # the same one above. If queue.Queue(5): len=5 and everything is 394 | # superfast (not usually the case), then all 5 will be returned and 395 | # exception is raised. In such a case, effective batch_size would become 396 | # constants['ROLLOUT_MAXLEN'] * queue_maxlen(5). But it is almost never the 397 | # case, i.e., collecting a rollout of length=ROLLOUT_MAXLEN takes more time 398 | # than get(). So, there are no more available rollouts in queue usually and 399 | # exception gets always raised. Hence, one should keep queue_maxlen = 1 ideally. 400 | # Also note that the next rollout generation gets invoked automatically because 401 | # its a thread which is always running using 'yield' at end of generation process. 402 | # To conclude, effective batch_size = constants['ROLLOUT_MAXLEN'] 403 | rollout.extend(self.runner.queue.get_nowait()) 404 | except queue.Empty: 405 | break 406 | return rollout 407 | 408 | def process(self, sess): 409 | """ 410 | Process grabs a rollout that's been produced by the thread runner, 411 | and updates the parameters. The update is then sent to the parameter 412 | server. 
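In the usual case (see the comment in pull_batch_from_queue), each call consumes
a single rollout, so the effective batch size per update is ROLLOUT_MAXLEN
transitions.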
413 | """ 414 | sess.run(self.sync) # copy weights from shared to local 415 | rollout = self.pull_batch_from_queue() 416 | batch = process_rollout(rollout, gamma=constants['GAMMA'], lambda_=constants['LAMBDA'], clip=self.envWrap) 417 | 418 | should_compute_summary = self.task == 0 and self.local_steps % 11 == 0 419 | 420 | if should_compute_summary: 421 | fetches = [self.summary_op, self.train_op, self.global_step] 422 | else: 423 | fetches = [self.train_op, self.global_step] 424 | 425 | feed_dict = { 426 | self.local_network.x: batch.si, 427 | self.ac: batch.a, 428 | self.adv: batch.adv, 429 | self.r: batch.r, 430 | self.local_network.state_in[0]: batch.features[0], 431 | self.local_network.state_in[1]: batch.features[1], 432 | } 433 | if self.unsup: 434 | feed_dict[self.local_network.x] = batch.si[:-1] 435 | feed_dict[self.local_ap_network.s1] = batch.si[:-1] 436 | feed_dict[self.local_ap_network.s2] = batch.si[1:] 437 | feed_dict[self.local_ap_network.asample] = batch.a 438 | 439 | fetched = sess.run(fetches, feed_dict=feed_dict) 440 | if batch.terminal: 441 | print("Global Step Counter: %d"%fetched[-1]) 442 | 443 | if should_compute_summary: 444 | self.summary_writer.add_summary(tf.Summary.FromString(fetched[0]), fetched[-1]) 445 | self.summary_writer.flush() 446 | self.local_steps += 1 447 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | constants = { 2 | 'GAMMA': 0.99, # discount factor for rewards 3 | 'LAMBDA': 1.0, # lambda of Generalized Advantage Estimation: https://arxiv.org/abs/1506.02438 4 | 'ENTROPY_BETA': 0.01, # entropy regurarlization constant. 5 | 'ROLLOUT_MAXLEN': 20, # 20 represents the number of 'local steps': the number of timesteps 6 | # we run the policy before we update the parameters. 7 | # The larger local steps is, the lower is the variance in our policy gradients estimate 8 | # on the one hand; but on the other hand, we get less frequent parameter updates, which 9 | # slows down learning. In this code, we found that making local steps be much 10 | # smaller than 20 makes the algorithm more difficult to tune and to get to work. 11 | 'GRAD_NORM_CLIP': 40.0, # gradient norm clipping 12 | 'REWARD_CLIP': 1.0, # reward value clipping in [-x,x] 13 | 'MAX_GLOBAL_STEPS': 100000000, # total steps taken across all workers 14 | 'LEARNING_RATE': 1e-4, # learning rate for adam 15 | 16 | 'PREDICTION_BETA': 0.01, # weight of prediction bonus 17 | # set 0.5 for unsup=state 18 | 'PREDICTION_LR_SCALE': 10.0, # scale lr of predictor wrt to policy network 19 | # set 30-50 for unsup=state 20 | 'FORWARD_LOSS_WT': 0.2, # should be between [0,1] 21 | # predloss = ( (1-FORWARD_LOSS_WT) * inv_loss + FORWARD_LOSS_WT * forward_loss) * PREDICTION_LR_SCALE 22 | 'POLICY_NO_BACKPROP_STEPS': 0, # number of global steps after which we start backpropagating to policy 23 | } 24 | -------------------------------------------------------------------------------- /src/demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import print_function 3 | import tensorflow as tf 4 | import gym 5 | import numpy as np 6 | import argparse 7 | import logging 8 | from envs import create_env 9 | logger = logging.getLogger(__name__) 10 | logger.setLevel(logging.INFO) 11 | 12 | 13 | def inference(args): 14 | """ 15 | It restore policy weights, and does inference. 
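Unlike inference.py, which rebuilds LSTMPolicy and restores weights into it, this
script imports the saved meta-graph and pulls the ops it needs (probs, sample, vf,
state_out_0/1) from the graph collections.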
16 | """ 17 | # define environment 18 | env = create_env(args.env_id, client_id='0', remotes=None, envWrap=True, 19 | acRepeat=1, record=args.record, outdir=args.outdir) 20 | numaction = env.action_space.n 21 | 22 | with tf.device("/cpu:0"): 23 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 24 | with tf.Session(config=config) as sess: 25 | logger.info("Restoring trainable global parameters.") 26 | saver = tf.train.import_meta_graph(args.ckpt+'.meta') 27 | saver.restore(sess, args.ckpt) 28 | 29 | probs = tf.get_collection("probs")[0] 30 | sample = tf.get_collection("sample")[0] 31 | vf = tf.get_collection("vf")[0] 32 | state_out_0 = tf.get_collection("state_out_0")[0] 33 | state_out_1 = tf.get_collection("state_out_1")[0] 34 | 35 | last_state = env.reset() 36 | if args.render or args.record: 37 | env.render() 38 | last_features = np.zeros((1, 256), np.float32); last_features = [last_features, last_features] 39 | length = 0 40 | rewards = 0 41 | mario_distances = np.zeros((args.num_episodes,)) 42 | for i in range(args.num_episodes): 43 | print("Starting episode %d" % (i + 1)) 44 | 45 | if args.random: 46 | print('I am a random policy!') 47 | else: 48 | if args.greedy: 49 | print('I am a greedy policy!') 50 | else: 51 | print('I am a sampled policy!') 52 | while True: 53 | # run policy 54 | fetched = sess.run([probs, sample, vf, state_out_0, state_out_1] , 55 | {"global/x:0": [last_state], "global/c_in:0": last_features[0], "global/h_in:0": last_features[1]}) 56 | prob_action, action, value_, features = fetched[0], fetched[1], fetched[2], fetched[3:] 57 | 58 | # run environment 59 | if args.random: 60 | stepAct = np.random.randint(0, numaction) # random policy 61 | else: 62 | if args.greedy: 63 | stepAct = prob_action.argmax() # greedy policy 64 | else: 65 | stepAct = action.argmax() 66 | state, reward, terminal, info = env.step(stepAct) 67 | 68 | # update stats 69 | length += 1 70 | rewards += reward 71 | last_state = state 72 | last_features = features 73 | if args.render or args.record: 74 | env.render() 75 | 76 | timestep_limit = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps') 77 | if timestep_limit is None: timestep_limit = env.spec.timestep_limit 78 | if terminal or length >= timestep_limit: 79 | if length >= timestep_limit or not env.metadata.get('semantics.autoreset'): 80 | last_state = env.reset() 81 | last_features = np.zeros((1, 256), np.float32); last_features = [last_features, last_features] 82 | print("Episode finished. Sum of rewards: %.2f. Length: %d." % (rewards, length)) 83 | length = 0 84 | rewards = 0 85 | if args.render or args.record: 86 | env.render() 87 | break 88 | 89 | logger.info('Finished %d true episodes.', args.num_episodes) 90 | env.close() 91 | 92 | 93 | def main(_): 94 | parser = argparse.ArgumentParser(description=None) 95 | parser.add_argument('--ckpt', default="../models/doom/doom_ICM", help='checkpoint name') 96 | parser.add_argument('--outdir', default="../models/output", help='Output log directory') 97 | parser.add_argument('--env-id', default="doom", help='Environment id') 98 | parser.add_argument('--record', action='store_true', help="Record the policy running video") 99 | parser.add_argument('--render', action='store_true', 100 | help="Render the gym environment video online") 101 | parser.add_argument('--num-episodes', type=int, default=2, help="Number of episodes to run") 102 | parser.add_argument('--greedy', action='store_true', 103 | help="Default sampled policy. 
This option does argmax.") 104 | parser.add_argument('--random', action='store_true', 105 | help="Default sampled policy. This option does random policy.") 106 | args = parser.parse_args() 107 | inference(args) 108 | 109 | if __name__ == "__main__": 110 | tf.app.run() 111 | -------------------------------------------------------------------------------- /src/env_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Deepak Pathak 3 | 4 | Acknowledgement: 5 | - The wrappers (BufferedObsEnv, SkipEnv) were originally written by 6 | Evan Shelhamer and modified by Deepak. Thanks Evan! 7 | - This file is derived from 8 | https://github.com/shelhamer/ourl/envs.py 9 | https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers_deprecated.py 10 | """ 11 | from __future__ import print_function 12 | import numpy as np 13 | from collections import deque 14 | from PIL import Image 15 | from gym.spaces.box import Box 16 | import gym 17 | import time, sys 18 | 19 | 20 | class BufferedObsEnv(gym.ObservationWrapper): 21 | """Buffer observations and stack e.g. for frame skipping. 22 | 23 | n is the length of the buffer, and number of observations stacked. 24 | skip is the number of steps between buffered observations (min=1). 25 | 26 | n.b. first obs is the oldest, last obs is the newest. 27 | the buffer is zeroed out on reset. 28 | *must* call reset() for init! 29 | """ 30 | def __init__(self, env=None, n=4, skip=4, shape=(84, 84), 31 | channel_last=True, maxFrames=True): 32 | super(BufferedObsEnv, self).__init__(env) 33 | self.obs_shape = shape 34 | # most recent raw observations (for max pooling across time steps) 35 | self.obs_buffer = deque(maxlen=2) 36 | self.maxFrames = maxFrames 37 | self.n = n 38 | self.skip = skip 39 | self.buffer = deque(maxlen=self.n) 40 | self.counter = 0 # init and reset should agree on this 41 | shape = shape + (n,) if channel_last else (n,) + shape 42 | self.observation_space = Box(0.0, 255.0, shape) 43 | self.ch_axis = -1 if channel_last else 0 44 | self.scale = 1.0 / 255 45 | self.observation_space.high[...] = 1.0 46 | 47 | def _step(self, action): 48 | obs, reward, done, info = self.env.step(action) 49 | return self._observation(obs), reward, done, info 50 | 51 | def _observation(self, obs): 52 | obs = self._convert(obs) 53 | self.counter += 1 54 | if self.counter % self.skip == 0: 55 | self.buffer.append(obs) 56 | obsNew = np.stack(self.buffer, axis=self.ch_axis) 57 | return obsNew.astype(np.float32) * self.scale 58 | 59 | def _reset(self): 60 | """Clear buffer and re-fill by duplicating the first observation.""" 61 | self.obs_buffer.clear() 62 | obs = self._convert(self.env.reset()) 63 | self.buffer.clear() 64 | self.counter = 0 65 | for _ in range(self.n - 1): 66 | self.buffer.append(np.zeros_like(obs)) 67 | self.buffer.append(obs) 68 | obsNew = np.stack(self.buffer, axis=self.ch_axis) 69 | return obsNew.astype(np.float32) * self.scale 70 | 71 | def _convert(self, obs): 72 | self.obs_buffer.append(obs) 73 | if self.maxFrames: 74 | max_frame = np.max(np.stack(self.obs_buffer), axis=0) 75 | else: 76 | max_frame = obs 77 | intensity_frame = self._rgb2y(max_frame).astype(np.uint8) 78 | small_frame = np.array(Image.fromarray(intensity_frame).resize( 79 | self.obs_shape, resample=Image.BILINEAR), dtype=np.uint8) 80 | return small_frame 81 | 82 | def _rgb2y(self, im): 83 | """Converts an RGB image to a Y image (as in YUV). 84 | 85 | These coefficients are taken from the torch/image library. 
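Concretely, Y = 0.299*R + 0.587*G + 0.114*B, i.e. the weighted channel sum below.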
86 | Beware: these are more critical than you might think, as the 87 | monochromatic contrast can be surprisingly low. 88 | """ 89 | if len(im.shape) < 3: 90 | return im 91 | return np.sum(im * [0.299, 0.587, 0.114], axis=2) 92 | 93 | 94 | class NoNegativeRewardEnv(gym.RewardWrapper): 95 | """Clip reward in negative direction.""" 96 | def __init__(self, env=None, neg_clip=0.0): 97 | super(NoNegativeRewardEnv, self).__init__(env) 98 | self.neg_clip = neg_clip 99 | 100 | def _reward(self, reward): 101 | new_reward = self.neg_clip if reward < self.neg_clip else reward 102 | return new_reward 103 | 104 | 105 | class SkipEnv(gym.Wrapper): 106 | """Skip timesteps: repeat action, accumulate reward, take last obs.""" 107 | def __init__(self, env=None, skip=4): 108 | super(SkipEnv, self).__init__(env) 109 | self.skip = skip 110 | 111 | def _step(self, action): 112 | total_reward = 0 113 | for i in range(0, self.skip): 114 | obs, reward, done, info = self.env.step(action) 115 | total_reward += reward 116 | info['steps'] = i + 1 117 | if done: 118 | break 119 | return obs, total_reward, done, info 120 | 121 | 122 | class MarioEnv(gym.Wrapper): 123 | def __init__(self, env=None, tilesEnv=False): 124 | """Reset mario environment without actually restarting fceux everytime. 125 | This speeds up unrolling by approximately 10 times. 126 | """ 127 | super(MarioEnv, self).__init__(env) 128 | self.resetCount = -1 129 | # reward is distance travelled. So normalize it with total distance 130 | # https://github.com/ppaquette/gym-super-mario/blob/master/ppaquette_gym_super_mario/lua/super-mario-bros.lua 131 | # However, we will not use this reward at all. It is only for completion. 132 | self.maxDistance = 3000.0 133 | self.tilesEnv = tilesEnv 134 | 135 | def _reset(self): 136 | if self.resetCount < 0: 137 | print('\nDoing hard mario fceux reset (40 seconds wait) !') 138 | sys.stdout.flush() 139 | self.env.reset() 140 | time.sleep(40) 141 | obs, _, _, info = self.env.step(7) # take right once to start game 142 | if info.get('ignore', False): # assuming this happens only in beginning 143 | self.resetCount = -1 144 | self.env.close() 145 | return self._reset() 146 | self.resetCount = info.get('iteration', -1) 147 | if self.tilesEnv: 148 | return obs 149 | return obs[24:-12,8:-8,:] 150 | 151 | def _step(self, action): 152 | obs, reward, done, info = self.env.step(action) 153 | # print('info:', info) 154 | done = info['iteration'] > self.resetCount 155 | reward = float(reward)/self.maxDistance # note: we do not use this rewards at all. 
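# The slicing below (obs[24:-12, 8:-8, :]) appears to crop the HUD rows and screen
# borders from the raw (224, 256, 3) frame before the 42x42 preprocessing wrappers
# see it.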
156 | if self.tilesEnv: 157 | return obs, reward, done, info 158 | return obs[24:-12,8:-8,:], reward, done, info 159 | 160 | def _close(self): 161 | self.resetCount = -1 162 | return self.env.close() 163 | 164 | 165 | class MakeEnvDynamic(gym.ObservationWrapper): 166 | """Make observation dynamic by adding noise""" 167 | def __init__(self, env=None, percentPad=5): 168 | super(MakeEnvDynamic, self).__init__(env) 169 | self.origShape = env.observation_space.shape 170 | newside = int(round(max(self.origShape[:-1])*100./(100.-percentPad))) 171 | self.newShape = [newside, newside, 3] 172 | self.observation_space = Box(0.0, 255.0, self.newShape) 173 | self.bottomIgnore = 20 # doom 20px bottom is useless 174 | self.ob = None 175 | 176 | def _observation(self, obs): 177 | imNoise = np.random.randint(0,256,self.newShape).astype(obs.dtype) 178 | imNoise[:self.origShape[0]-self.bottomIgnore, :self.origShape[1], :] = obs[:-self.bottomIgnore,:,:] 179 | self.ob = imNoise 180 | return imNoise 181 | 182 | # def render(self, mode='human', close=False): 183 | # temp = self.env.render(mode, close) 184 | # return self.ob 185 | -------------------------------------------------------------------------------- /src/envs.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | from gym.spaces.box import Box 4 | import numpy as np 5 | import gym 6 | from gym import spaces 7 | import logging 8 | import universe 9 | from universe import vectorized 10 | from universe.wrappers import BlockingReset, GymCoreAction, EpisodeID, Unvectorize, Vectorize, Vision, Logger 11 | from universe import spaces as vnc_spaces 12 | from universe.spaces.vnc_event import keycode 13 | import env_wrapper 14 | import time 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | universe.configure_logging() 18 | 19 | def create_env(env_id, client_id, remotes, **kwargs): 20 | if 'doom' in env_id.lower() or 'labyrinth' in env_id.lower(): 21 | return create_doom(env_id, client_id, **kwargs) 22 | if 'mario' in env_id.lower(): 23 | return create_mario(env_id, client_id, **kwargs) 24 | 25 | spec = gym.spec(env_id) 26 | if spec.tags.get('flashgames', False): 27 | return create_flash_env(env_id, client_id, remotes, **kwargs) 28 | elif spec.tags.get('atari', False) and spec.tags.get('vnc', False): 29 | return create_vncatari_env(env_id, client_id, remotes, **kwargs) 30 | else: 31 | # Assume atari. 32 | assert "." not in env_id # universe environments have dots in names. 
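# Example usage (mirrors the call in demo.py): env_id='doom' is dispatched above to
# create_doom(), while a plain gym id such as 'PongDeterministic-v3' falls through
# to the Atari path below, e.g.:
#   env = create_env('doom', client_id='0', remotes=None, envWrap=True, acRepeat=1)
#   obs = env.reset()  # with envWrap=True: 42x42x4 stacked grayscale frames in [0, 1]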
33 | return create_atari_env(env_id, **kwargs) 34 | 35 | def create_doom(env_id, client_id, envWrap=True, record=False, outdir=None, 36 | noLifeReward=False, acRepeat=0, **_): 37 | from ppaquette_gym_doom import wrappers 38 | if 'labyrinth' in env_id.lower(): 39 | if 'single' in env_id.lower(): 40 | env_id = 'ppaquette/LabyrinthSingle-v0' 41 | elif 'fix' in env_id.lower(): 42 | env_id = 'ppaquette/LabyrinthManyFixed-v0' 43 | else: 44 | env_id = 'ppaquette/LabyrinthMany-v0' 45 | elif 'very' in env_id.lower(): 46 | env_id = 'ppaquette/DoomMyWayHomeFixed15-v0' 47 | elif 'sparse' in env_id.lower(): 48 | env_id = 'ppaquette/DoomMyWayHomeFixed-v0' 49 | elif 'fix' in env_id.lower(): 50 | if '1' in env_id or '2' in env_id: 51 | env_id = 'ppaquette/DoomMyWayHomeFixed' + str(env_id[-2:]) + '-v0' 52 | elif 'new' in env_id.lower(): 53 | env_id = 'ppaquette/DoomMyWayHomeFixedNew-v0' 54 | else: 55 | env_id = 'ppaquette/DoomMyWayHomeFixed-v0' 56 | else: 57 | env_id = 'ppaquette/DoomMyWayHome-v0' 58 | 59 | # VizDoom workaround: Simultaneously launching multiple vizdoom processes 60 | # makes program stuck, so use the global lock in multi-threading/processing 61 | client_id = int(client_id) 62 | time.sleep(client_id * 10) 63 | env = gym.make(env_id) 64 | modewrapper = wrappers.SetPlayingMode('algo') 65 | obwrapper = wrappers.SetResolution('160x120') 66 | acwrapper = wrappers.ToDiscrete('minimal') 67 | env = modewrapper(obwrapper(acwrapper(env))) 68 | # env = env_wrapper.MakeEnvDynamic(env) # to add stochasticity 69 | 70 | if record and outdir is not None: 71 | env = gym.wrappers.Monitor(env, outdir, force=True) 72 | 73 | if envWrap: 74 | fshape = (42, 42) 75 | frame_skip = acRepeat if acRepeat>0 else 4 76 | env.seed(None) 77 | if noLifeReward: 78 | env = env_wrapper.NoNegativeRewardEnv(env) 79 | env = env_wrapper.BufferedObsEnv(env, skip=frame_skip, shape=fshape) 80 | env = env_wrapper.SkipEnv(env, skip=frame_skip) 81 | elif noLifeReward: 82 | env = env_wrapper.NoNegativeRewardEnv(env) 83 | 84 | env = Vectorize(env) 85 | env = DiagnosticsInfo(env) 86 | env = Unvectorize(env) 87 | return env 88 | 89 | def create_mario(env_id, client_id, envWrap=True, record=False, outdir=None, 90 | noLifeReward=False, acRepeat=0, **_): 91 | import ppaquette_gym_super_mario 92 | from ppaquette_gym_super_mario import wrappers 93 | if '-v' in env_id.lower(): 94 | env_id = 'ppaquette/' + env_id 95 | else: 96 | env_id = 'ppaquette/SuperMarioBros-1-1-v0' # shape: (224,256,3)=(h,w,c) 97 | 98 | # Mario workaround: Simultaneously launching multiple vizdoom processes makes program stuck, 99 | # so use the global lock in multi-threading/multi-processing 100 | # see: https://github.com/ppaquette/gym-super-mario/tree/master/ppaquette_gym_super_mario 101 | client_id = int(client_id) 102 | time.sleep(client_id * 50) 103 | env = gym.make(env_id) 104 | modewrapper = wrappers.SetPlayingMode('algo') 105 | acwrapper = wrappers.ToDiscrete() 106 | env = modewrapper(acwrapper(env)) 107 | env = env_wrapper.MarioEnv(env) 108 | 109 | if record and outdir is not None: 110 | env = gym.wrappers.Monitor(env, outdir, force=True) 111 | 112 | if envWrap: 113 | frame_skip = acRepeat if acRepeat>0 else 6 114 | fshape = (42, 42) 115 | env.seed(None) 116 | if noLifeReward: 117 | env = env_wrapper.NoNegativeRewardEnv(env) 118 | env = env_wrapper.BufferedObsEnv(env, skip=frame_skip, shape=fshape, maxFrames=False) 119 | if frame_skip > 1: 120 | env = env_wrapper.SkipEnv(env, skip=frame_skip) 121 | elif noLifeReward: 122 | env = 
env_wrapper.NoNegativeRewardEnv(env) 123 | 124 | env = Vectorize(env) 125 | env = DiagnosticsInfo(env) 126 | env = Unvectorize(env) 127 | # env.close() # TODO: think about where to put env.close ! 128 | return env 129 | 130 | def create_flash_env(env_id, client_id, remotes, **_): 131 | env = gym.make(env_id) 132 | env = Vision(env) 133 | env = Logger(env) 134 | env = BlockingReset(env) 135 | 136 | reg = universe.runtime_spec('flashgames').server_registry 137 | height = reg[env_id]["height"] 138 | width = reg[env_id]["width"] 139 | env = CropScreen(env, height, width, 84, 18) 140 | env = FlashRescale(env) 141 | 142 | keys = ['left', 'right', 'up', 'down', 'x'] 143 | if env_id == 'flashgames.NeonRace-v0': 144 | # Better key space for this game. 145 | keys = ['left', 'right', 'up', 'left up', 'right up', 'down', 'up x'] 146 | logger.info('create_flash_env(%s): keys=%s', env_id, keys) 147 | 148 | env = DiscreteToFixedKeysVNCActions(env, keys) 149 | env = EpisodeID(env) 150 | env = DiagnosticsInfo(env) 151 | env = Unvectorize(env) 152 | env.configure(fps=5.0, remotes=remotes, start_timeout=15 * 60, client_id=client_id, 153 | vnc_driver='go', vnc_kwargs={ 154 | 'encoding': 'tight', 'compress_level': 0, 155 | 'fine_quality_level': 50, 'subsample_level': 3}) 156 | return env 157 | 158 | def create_vncatari_env(env_id, client_id, remotes, **_): 159 | env = gym.make(env_id) 160 | env = Vision(env) 161 | env = Logger(env) 162 | env = BlockingReset(env) 163 | env = GymCoreAction(env) 164 | env = AtariRescale42x42(env) 165 | env = EpisodeID(env) 166 | env = DiagnosticsInfo(env) 167 | env = Unvectorize(env) 168 | 169 | logger.info('Connecting to remotes: %s', remotes) 170 | fps = env.metadata['video.frames_per_second'] 171 | env.configure(remotes=remotes, start_timeout=15 * 60, fps=fps, client_id=client_id) 172 | return env 173 | 174 | def create_atari_env(env_id, record=False, outdir=None, **_): 175 | env = gym.make(env_id) 176 | if record and outdir is not None: 177 | env = gym.wrappers.Monitor(env, outdir, force=True) 178 | env = Vectorize(env) 179 | env = AtariRescale42x42(env) 180 | env = DiagnosticsInfo(env) 181 | env = Unvectorize(env) 182 | return env 183 | 184 | def DiagnosticsInfo(env, *args, **kwargs): 185 | return vectorized.VectorizeFilter(env, DiagnosticsInfoI, *args, **kwargs) 186 | 187 | class DiagnosticsInfoI(vectorized.Filter): 188 | def __init__(self, log_interval=503): 189 | super(DiagnosticsInfoI, self).__init__() 190 | 191 | self._episode_time = time.time() 192 | self._last_time = time.time() 193 | self._local_t = 0 194 | self._log_interval = log_interval 195 | self._episode_reward = 0 196 | self._episode_length = 0 197 | self._all_rewards = [] 198 | self._num_vnc_updates = 0 199 | self._last_episode_id = -1 200 | 201 | def _after_reset(self, observation): 202 | logger.info('Resetting environment logs') 203 | self._episode_reward = 0 204 | self._episode_length = 0 205 | self._all_rewards = [] 206 | return observation 207 | 208 | def _after_step(self, observation, reward, done, info): 209 | to_log = {} 210 | if self._episode_length == 0: 211 | self._episode_time = time.time() 212 | 213 | self._local_t += 1 214 | if info.get("stats.vnc.updates.n") is not None: 215 | self._num_vnc_updates += info.get("stats.vnc.updates.n") 216 | 217 | if self._local_t % self._log_interval == 0: 218 | cur_time = time.time() 219 | elapsed = cur_time - self._last_time 220 | fps = self._log_interval / elapsed 221 | self._last_time = cur_time 222 | cur_episode_id = info.get('vectorized.episode_id', 0) 223 
| to_log["diagnostics/fps"] = fps 224 | if self._last_episode_id == cur_episode_id: 225 | to_log["diagnostics/fps_within_episode"] = fps 226 | self._last_episode_id = cur_episode_id 227 | if info.get("stats.gauges.diagnostics.lag.action") is not None: 228 | to_log["diagnostics/action_lag_lb"] = info["stats.gauges.diagnostics.lag.action"][0] 229 | to_log["diagnostics/action_lag_ub"] = info["stats.gauges.diagnostics.lag.action"][1] 230 | if info.get("reward.count") is not None: 231 | to_log["diagnostics/reward_count"] = info["reward.count"] 232 | if info.get("stats.gauges.diagnostics.clock_skew") is not None: 233 | to_log["diagnostics/clock_skew_lb"] = info["stats.gauges.diagnostics.clock_skew"][0] 234 | to_log["diagnostics/clock_skew_ub"] = info["stats.gauges.diagnostics.clock_skew"][1] 235 | if info.get("stats.gauges.diagnostics.lag.observation") is not None: 236 | to_log["diagnostics/observation_lag_lb"] = info["stats.gauges.diagnostics.lag.observation"][0] 237 | to_log["diagnostics/observation_lag_ub"] = info["stats.gauges.diagnostics.lag.observation"][1] 238 | 239 | if info.get("stats.vnc.updates.n") is not None: 240 | to_log["diagnostics/vnc_updates_n"] = info["stats.vnc.updates.n"] 241 | to_log["diagnostics/vnc_updates_n_ps"] = self._num_vnc_updates / elapsed 242 | self._num_vnc_updates = 0 243 | if info.get("stats.vnc.updates.bytes") is not None: 244 | to_log["diagnostics/vnc_updates_bytes"] = info["stats.vnc.updates.bytes"] 245 | if info.get("stats.vnc.updates.pixels") is not None: 246 | to_log["diagnostics/vnc_updates_pixels"] = info["stats.vnc.updates.pixels"] 247 | if info.get("stats.vnc.updates.rectangles") is not None: 248 | to_log["diagnostics/vnc_updates_rectangles"] = info["stats.vnc.updates.rectangles"] 249 | if info.get("env_status.state_id") is not None: 250 | to_log["diagnostics/env_state_id"] = info["env_status.state_id"] 251 | 252 | if reward is not None: 253 | self._episode_reward += reward 254 | if observation is not None: 255 | self._episode_length += 1 256 | self._all_rewards.append(reward) 257 | 258 | if done: 259 | logger.info('True Game terminating: env_episode_reward=%s episode_length=%s', self._episode_reward, self._episode_length) 260 | total_time = time.time() - self._episode_time 261 | to_log["global/episode_reward"] = self._episode_reward 262 | to_log["global/episode_length"] = self._episode_length 263 | to_log["global/episode_time"] = total_time 264 | to_log["global/reward_per_time"] = self._episode_reward / total_time 265 | self._episode_reward = 0 266 | self._episode_length = 0 267 | self._all_rewards = [] 268 | 269 | if 'distance' in info: to_log['distance'] = info['distance'] # mario 270 | if 'POSITION_X' in info: # doom 271 | to_log['POSITION_X'] = info['POSITION_X'] 272 | to_log['POSITION_Y'] = info['POSITION_Y'] 273 | return observation, reward, done, to_log 274 | 275 | def _process_frame42(frame): 276 | frame = frame[34:34+160, :160] 277 | # Resize by half, then down to 42x42 (essentially mipmapping). If 278 | # we resize directly we lose pixels that, when mapped to 42x42, 279 | # aren't close enough to the pixel boundary. 
280 | frame = np.asarray(Image.fromarray(frame).resize((80, 80), resample=Image.BILINEAR).resize( 281 | (42,42), resample=Image.BILINEAR)) 282 | frame = frame.mean(2) # take mean along channels 283 | frame = frame.astype(np.float32) 284 | frame *= (1.0 / 255.0) 285 | frame = np.reshape(frame, [42, 42, 1]) 286 | return frame 287 | 288 | class AtariRescale42x42(vectorized.ObservationWrapper): 289 | def __init__(self, env=None): 290 | super(AtariRescale42x42, self).__init__(env) 291 | self.observation_space = Box(0.0, 1.0, [42, 42, 1]) 292 | 293 | def _observation(self, observation_n): 294 | return [_process_frame42(observation) for observation in observation_n] 295 | 296 | class FixedKeyState(object): 297 | def __init__(self, keys): 298 | self._keys = [keycode(key) for key in keys] 299 | self._down_keysyms = set() 300 | 301 | def apply_vnc_actions(self, vnc_actions): 302 | for event in vnc_actions: 303 | if isinstance(event, vnc_spaces.KeyEvent): 304 | if event.down: 305 | self._down_keysyms.add(event.key) 306 | else: 307 | self._down_keysyms.discard(event.key) 308 | 309 | def to_index(self): 310 | action_n = 0 311 | for key in self._down_keysyms: 312 | if key in self._keys: 313 | # If multiple keys are pressed, just use the first one 314 | action_n = self._keys.index(key) + 1 315 | break 316 | return action_n 317 | 318 | class DiscreteToFixedKeysVNCActions(vectorized.ActionWrapper): 319 | """ 320 | Define a fixed action space. Action 0 is all keys up. Each element of keys can be a single key or a space-separated list of keys 321 | 322 | For example, 323 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right']) 324 | will have 3 actions: [none, left, right] 325 | 326 | You can define a state with more than one key down by separating with spaces. For example, 327 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right', 'space', 'left space', 'right space']) 328 | will have 6 actions: [none, left, right, space, left space, right space] 329 | """ 330 | def __init__(self, env, keys): 331 | super(DiscreteToFixedKeysVNCActions, self).__init__(env) 332 | 333 | self._keys = keys 334 | self._generate_actions() 335 | self.action_space = spaces.Discrete(len(self._actions)) 336 | 337 | def _generate_actions(self): 338 | self._actions = [] 339 | uniq_keys = set() 340 | for key in self._keys: 341 | for cur_key in key.split(' '): 342 | uniq_keys.add(cur_key) 343 | 344 | for key in [''] + self._keys: 345 | split_keys = key.split(' ') 346 | cur_action = [] 347 | for cur_key in uniq_keys: 348 | cur_action.append(vnc_spaces.KeyEvent.by_name(cur_key, down=(cur_key in split_keys))) 349 | self._actions.append(cur_action) 350 | self.key_state = FixedKeyState(uniq_keys) 351 | 352 | def _action(self, action_n): 353 | # Each action might be a length-1 np.array. Cast to int to 354 | # avoid warnings. 
355 | return [self._actions[int(action)] for action in action_n] 356 | 357 | class CropScreen(vectorized.ObservationWrapper): 358 | """Crops out a [height]x[width] area starting from (top,left) """ 359 | def __init__(self, env, height, width, top=0, left=0): 360 | super(CropScreen, self).__init__(env) 361 | self.height = height 362 | self.width = width 363 | self.top = top 364 | self.left = left 365 | self.observation_space = Box(0, 255, shape=(height, width, 3)) 366 | 367 | def _observation(self, observation_n): 368 | return [ob[self.top:self.top+self.height, self.left:self.left+self.width, :] if ob is not None else None 369 | for ob in observation_n] 370 | 371 | def _process_frame_flash(frame): 372 | frame = np.array(Image.fromarray(frame).resize((200, 128), resample=Image.BILINEAR)) 373 | frame = frame.mean(2).astype(np.float32) 374 | frame *= (1.0 / 255.0) 375 | frame = np.reshape(frame, [128, 200, 1]) 376 | return frame 377 | 378 | class FlashRescale(vectorized.ObservationWrapper): 379 | def __init__(self, env=None): 380 | super(FlashRescale, self).__init__(env) 381 | self.observation_space = Box(0.0, 1.0, [128, 200, 1]) 382 | 383 | def _observation(self, observation_n): 384 | return [_process_frame_flash(observation) for observation in observation_n] 385 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import print_function 3 | import go_vncdriver 4 | import tensorflow as tf 5 | import numpy as np 6 | import argparse 7 | import logging 8 | import os 9 | import gym 10 | from envs import create_env 11 | from worker import FastSaver 12 | from model import LSTMPolicy 13 | import utils 14 | import distutils.version 15 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0') 16 | 17 | logger = logging.getLogger(__name__) 18 | logger.setLevel(logging.INFO) 19 | 20 | 21 | def inference(args): 22 | """ 23 | It only restores LSTMPolicy architecture, and does inference using that. 24 | """ 25 | # get address of checkpoints 26 | indir = os.path.join(args.log_dir, 'train') 27 | outdir = os.path.join(args.log_dir, 'inference') if args.out_dir is None else args.out_dir 28 | with open(indir + '/checkpoint', 'r') as f: 29 | first_line = f.readline().strip() 30 | ckpt = first_line.split(' ')[-1].split('/')[-1][:-1] 31 | ckpt = ckpt.split('-')[-1] 32 | ckpt = indir + '/model.ckpt-' + ckpt 33 | 34 | # define environment 35 | if args.record: 36 | env = create_env(args.env_id, client_id='0', remotes=None, envWrap=args.envWrap, designHead=args.designHead, 37 | record=True, noop=args.noop, acRepeat=args.acRepeat, outdir=outdir) 38 | else: 39 | env = create_env(args.env_id, client_id='0', remotes=None, envWrap=args.envWrap, designHead=args.designHead, 40 | record=True, noop=args.noop, acRepeat=args.acRepeat) 41 | numaction = env.action_space.n 42 | 43 | with tf.device("/cpu:0"): 44 | # define policy network 45 | with tf.variable_scope("global"): 46 | policy = LSTMPolicy(env.observation_space.shape, numaction, args.designHead) 47 | policy.global_step = tf.get_variable("global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32), 48 | trainable=False) 49 | 50 | # Variable names that start with "local" are not saved in checkpoints. 
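# (Reason: in a3c.py each worker builds its policy under the "local" scope and syncs
# it from the "global" copy on the parameter server, so only "global"-scoped
# variables end up in the checkpoint restored here.)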
51 | if use_tf12_api: 52 | variables_to_restore = [v for v in tf.global_variables() if not v.name.startswith("local")] 53 | init_all_op = tf.global_variables_initializer() 54 | else: 55 | variables_to_restore = [v for v in tf.all_variables() if not v.name.startswith("local")] 56 | init_all_op = tf.initialize_all_variables() 57 | saver = FastSaver(variables_to_restore) 58 | 59 | # print trainable variables 60 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 61 | logger.info('Trainable vars:') 62 | for v in var_list: 63 | logger.info(' %s %s', v.name, v.get_shape()) 64 | 65 | # summary of rewards 66 | action_writers = [] 67 | if use_tf12_api: 68 | summary_writer = tf.summary.FileWriter(outdir) 69 | for ac_id in range(numaction): 70 | action_writers.append(tf.summary.FileWriter(os.path.join(outdir,'action_{}'.format(ac_id)))) 71 | else: 72 | summary_writer = tf.train.SummaryWriter(outdir) 73 | for ac_id in range(numaction): 74 | action_writers.append(tf.train.SummaryWriter(os.path.join(outdir,'action_{}'.format(ac_id)))) 75 | logger.info("Inference events directory: %s", outdir) 76 | 77 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 78 | with tf.Session(config=config) as sess: 79 | logger.info("Initializing all parameters.") 80 | sess.run(init_all_op) 81 | logger.info("Restoring trainable global parameters.") 82 | saver.restore(sess, ckpt) 83 | logger.info("Restored model was trained for %.2fM global steps", sess.run(policy.global_step)/1000000.) 84 | # saving with meta graph: 85 | # metaSaver = tf.train.Saver(variables_to_restore) 86 | # metaSaver.save(sess, 'models/doomICM') 87 | 88 | last_state = env.reset() 89 | if args.render or args.record: 90 | env.render() 91 | last_features = policy.get_initial_features() # reset lstm memory 92 | length = 0 93 | rewards = 0 94 | mario_distances = np.zeros((args.num_episodes,)) 95 | for i in range(args.num_episodes): 96 | print("Starting episode %d" % (i + 1)) 97 | if args.recordSignal: 98 | from PIL import Image 99 | signalCount = 1 100 | utils.mkdir_p(outdir + '/recordedSignal/ep_%02d/'%i) 101 | Image.fromarray((255*last_state[..., -1]).astype('uint8')).save(outdir + '/recordedSignal/ep_%02d/%06d.jpg'%(i,signalCount)) 102 | 103 | if args.random: 104 | print('I am random policy!') 105 | else: 106 | if args.greedy: 107 | print('I am greedy policy!') 108 | else: 109 | print('I am sampled policy!') 110 | while True: 111 | # run policy 112 | fetched = policy.act_inference(last_state, *last_features) 113 | prob_action, action, value_, features = fetched[0], fetched[1], fetched[2], fetched[3:] 114 | 115 | # run environment: sampled one-hot 'action' (not greedy) 116 | if args.random: 117 | stepAct = np.random.randint(0, numaction) # random policy 118 | else: 119 | if args.greedy: 120 | stepAct = prob_action.argmax() # greedy policy 121 | else: 122 | stepAct = action.argmax() 123 | # print(stepAct, prob_action.argmax(), prob_action) 124 | state, reward, terminal, info = env.step(stepAct) 125 | 126 | # update stats 127 | length += 1 128 | rewards += reward 129 | last_state = state 130 | last_features = features 131 | if args.render or args.record: 132 | env.render() 133 | if args.recordSignal: 134 | signalCount += 1 135 | Image.fromarray((255*last_state[..., -1]).astype('uint8')).save(outdir + '/recordedSignal/ep_%02d/%06d.jpg'%(i,signalCount)) 136 | 137 | # store summary 138 | summary = tf.Summary() 139 | summary.value.add(tag='ep_{}/reward'.format(i), simple_value=reward) 140 | 
summary.value.add(tag='ep_{}/netreward'.format(i), simple_value=rewards) 141 | summary.value.add(tag='ep_{}/value'.format(i), simple_value=float(value_[0])) 142 | if 'NoFrameskip-v' in args.env_id: # atari 143 | summary.value.add(tag='ep_{}/lives'.format(i), simple_value=env.unwrapped.ale.lives()) 144 | summary_writer.add_summary(summary, length) 145 | summary_writer.flush() 146 | summary = tf.Summary() 147 | for ac_id in range(numaction): 148 | summary.value.add(tag='action_prob', simple_value=float(prob_action[ac_id])) 149 | action_writers[ac_id].add_summary(summary, length) 150 | action_writers[ac_id].flush() 151 | 152 | timestep_limit = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps') 153 | if timestep_limit is None: timestep_limit = env.spec.timestep_limit 154 | if terminal or length >= timestep_limit: 155 | if length >= timestep_limit or not env.metadata.get('semantics.autoreset'): 156 | last_state = env.reset() 157 | last_features = policy.get_initial_features() # reset lstm memory 158 | print("Episode finished. Sum of rewards: %.2f. Length: %d." % (rewards, length)) 159 | if 'distance' in info: 160 | print('Mario Distance Covered:', info['distance']) 161 | mario_distances[i] = info['distance'] 162 | length = 0 163 | rewards = 0 164 | if args.render or args.record: 165 | env.render() 166 | if args.recordSignal: 167 | signalCount += 1 168 | Image.fromarray((255*last_state[..., -1]).astype('uint8')).save(outdir + '/recordedSignal/ep_%02d/%06d.jpg'%(i,signalCount)) 169 | break 170 | 171 | logger.info('Finished %d true episodes.', args.num_episodes) 172 | if 'distance' in info: 173 | print('Mario Distances:', mario_distances) 174 | np.save(outdir + '/distances.npy', mario_distances) 175 | env.close() 176 | 177 | 178 | def main(_): 179 | parser = argparse.ArgumentParser(description=None) 180 | parser.add_argument('--log-dir', default="tmp/doom", help='input model directory') 181 | parser.add_argument('--out-dir', default=None, help='output log directory. Default: log_dir/inference/') 182 | parser.add_argument('--env-id', default="PongDeterministic-v3", help='Environment id') 183 | parser.add_argument('--record', action='store_true', 184 | help="Record the gym environment video -- user friendly") 185 | parser.add_argument('--recordSignal', action='store_true', 186 | help="Record images of true processed input to network") 187 | parser.add_argument('--render', action='store_true', 188 | help="Render the gym environment video online") 189 | parser.add_argument('--envWrap', action='store_true', 190 | help="Preprocess input in env_wrapper (no change in input size or network)") 191 | parser.add_argument('--designHead', type=str, default='universe', 192 | help="Network deign head: nips or nature or doom or universe(default)") 193 | parser.add_argument('--num-episodes', type=int, default=2, 194 | help="Number of episodes to run") 195 | parser.add_argument('--noop', action='store_true', 196 | help="Add 30-noop for inference too (recommended by Nature paper, don't know?)") 197 | parser.add_argument('--acRepeat', type=int, default=0, 198 | help="Actions to be repeated at inference. 0 means default. applies iff envWrap is True.") 199 | parser.add_argument('--greedy', action='store_true', 200 | help="Default sampled policy. This option does argmax.") 201 | parser.add_argument('--random', action='store_true', 202 | help="Default sampled policy. 
This option does random policy.") 203 | parser.add_argument('--default', action='store_true', help="run with default params") 204 | args = parser.parse_args() 205 | if args.default: 206 | args.envWrap = True 207 | args.acRepeat = 1 208 | if args.acRepeat <= 0: 209 | print('Using default action repeat (i.e. 4). Min value that can be set is 1.') 210 | inference(args) 211 | 212 | if __name__ == "__main__": 213 | tf.app.run() 214 | -------------------------------------------------------------------------------- /src/mario.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Script to test if mario installation works fine. It 3 | displays the game play simultaneously. 4 | ''' 5 | 6 | from __future__ import print_function 7 | import gym, universe 8 | import env_wrapper 9 | import ppaquette_gym_super_mario 10 | from ppaquette_gym_super_mario import wrappers 11 | import numpy as np 12 | import time 13 | from PIL import Image 14 | import utils 15 | 16 | outputdir = './gray42/' 17 | env_id = 'ppaquette/SuperMarioBros-1-1-v0' 18 | env = gym.make(env_id) 19 | modewrapper = wrappers.SetPlayingMode('algo') 20 | acwrapper = wrappers.ToDiscrete() 21 | env = modewrapper(acwrapper(env)) 22 | env = env_wrapper.MarioEnv(env) 23 | 24 | freshape = fshape = (42, 42) 25 | env.seed(None) 26 | env = env_wrapper.NoNegativeRewardEnv(env) 27 | env = env_wrapper.DQNObsEnv(env, shape=freshape) 28 | env = env_wrapper.BufferedObsEnv(env, n=4, skip=1, shape=fshape, channel_last=True) 29 | env = env_wrapper.EltwiseScaleObsEnv(env) 30 | 31 | start = time.time() 32 | episodes = 0 33 | maxepisodes = 1 34 | env.reset() 35 | imCount = 1 36 | utils.mkdir_p(outputdir + '/ep_%02d/'%(episodes+1)) 37 | while(1): 38 | obs, reward, done, info = env.step(env.action_space.sample()) 39 | Image.fromarray((255*obs).astype('uint8')).save(outputdir + '/ep_%02d/%06d.jpg'%(episodes+1,imCount)) 40 | imCount += 1 41 | if done: 42 | episodes += 1 43 | print('Ep: %d, Distance: %d'%(episodes, info['distance'])) 44 | if episodes >= maxepisodes: 45 | break 46 | env.reset() 47 | imCount = 1 48 | utils.mkdir_p(outputdir + '/ep_%02d/'%(episodes+1)) 49 | end = time.time() 50 | print('\nTotal Time spent: %0.2f seconds'% (end-start)) 51 | env.close() 52 | print('Done!') 53 | exit(1) 54 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow.contrib.rnn as rnn 5 | from constants import constants 6 | 7 | 8 | def normalized_columns_initializer(std=1.0): 9 | def _initializer(shape, dtype=None, partition_info=None): 10 | out = np.random.randn(*shape).astype(np.float32) 11 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) 12 | return tf.constant(out) 13 | return _initializer 14 | 15 | 16 | def cosineLoss(A, B, name): 17 | ''' A, B : (BatchSize, d) ''' 18 | dotprod = tf.reduce_sum(tf.multiply(tf.nn.l2_normalize(A,1), tf.nn.l2_normalize(B,1)), 1) 19 | loss = 1-tf.reduce_mean(dotprod, name=name) 20 | return loss 21 | 22 | 23 | def flatten(x): 24 | return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])]) 25 | 26 | 27 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None): 28 | with tf.variable_scope(name): 29 | stride_shape = [1, stride[0], stride[1], 1] 30 | filter_shape = [filter_size[0], 
filter_size[1], int(x.get_shape()[3]), num_filters] 31 | 32 | # there are "num input feature maps * filter height * filter width" 33 | # inputs to each hidden unit 34 | fan_in = np.prod(filter_shape[:3]) 35 | # each unit in the lower layer receives a gradient from: 36 | # "num output feature maps * filter height * filter width" / 37 | # pooling size 38 | fan_out = np.prod(filter_shape[:2]) * num_filters 39 | # initialize weights with random weights 40 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 41 | 42 | w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound), 43 | collections=collections) 44 | b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.constant_initializer(0.0), 45 | collections=collections) 46 | return tf.nn.conv2d(x, w, stride_shape, pad) + b 47 | 48 | 49 | def deconv2d(x, out_shape, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None, prevNumFeat=None): 50 | with tf.variable_scope(name): 51 | num_filters = out_shape[-1] 52 | prevNumFeat = int(x.get_shape()[3]) if prevNumFeat is None else prevNumFeat 53 | stride_shape = [1, stride[0], stride[1], 1] 54 | # transpose_filter : [height, width, out_channels, in_channels] 55 | filter_shape = [filter_size[0], filter_size[1], num_filters, prevNumFeat] 56 | 57 | # there are "num input feature maps * filter height * filter width" 58 | # inputs to each hidden unit 59 | fan_in = np.prod(filter_shape[:2]) * prevNumFeat 60 | # each unit in the lower layer receives a gradient from: 61 | # "num output feature maps * filter height * filter width" 62 | fan_out = np.prod(filter_shape[:3]) 63 | # initialize weights with random weights 64 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 65 | 66 | w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound), 67 | collections=collections) 68 | b = tf.get_variable("b", [num_filters], initializer=tf.constant_initializer(0.0), 69 | collections=collections) 70 | deconv2d = tf.nn.conv2d_transpose(x, w, tf.pack(out_shape), stride_shape, pad) 71 | # deconv2d = tf.reshape(tf.nn.bias_add(deconv2d, b), deconv2d.get_shape()) 72 | return deconv2d 73 | 74 | 75 | def linear(x, size, name, initializer=None, bias_init=0): 76 | w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=initializer) 77 | b = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(bias_init)) 78 | return tf.matmul(x, w) + b 79 | 80 | 81 | def categorical_sample(logits, d): 82 | value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1], keep_dims=True), 1), [1]) 83 | return tf.one_hot(value, d) 84 | 85 | 86 | def inverseUniverseHead(x, final_shape, nConvs=4): 87 | ''' universe agent example 88 | input: [None, 288]; output: [None, 42, 42, 1]; 89 | ''' 90 | print('Using inverse-universe head design') 91 | bs = tf.shape(x)[0] 92 | deconv_shape1 = [final_shape[1]] 93 | deconv_shape2 = [final_shape[2]] 94 | for i in range(nConvs): 95 | deconv_shape1.append((deconv_shape1[-1]-1)/2 + 1) 96 | deconv_shape2.append((deconv_shape2[-1]-1)/2 + 1) 97 | inshapeprod = np.prod(x.get_shape().as_list()[1:]) / 32.0 98 | assert(inshapeprod == deconv_shape1[-1]*deconv_shape2[-1]) 99 | # print('deconv_shape1: ',deconv_shape1) 100 | # print('deconv_shape2: ',deconv_shape2) 101 | 102 | x = tf.reshape(x, [-1, deconv_shape1[-1], deconv_shape2[-1], 32]) 103 | deconv_shape1 = deconv_shape1[:-1] 104 | deconv_shape2 = deconv_shape2[:-1] 105 | for i in range(nConvs-1): 106 | x = tf.nn.elu(deconv2d(x, [bs, 
deconv_shape1[-1], deconv_shape2[-1], 32], 107 | "dl{}".format(i + 1), [3, 3], [2, 2], prevNumFeat=32)) 108 | deconv_shape1 = deconv_shape1[:-1] 109 | deconv_shape2 = deconv_shape2[:-1] 110 | x = deconv2d(x, [bs] + final_shape[1:], "dl4", [3, 3], [2, 2], prevNumFeat=32) 111 | return x 112 | 113 | 114 | def universeHead(x, nConvs=4): 115 | ''' universe agent example 116 | input: [None, 42, 42, 1]; output: [None, 288]; 117 | ''' 118 | print('Using universe head design') 119 | for i in range(nConvs): 120 | x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2])) 121 | # print('Loop{} '.format(i+1),tf.shape(x)) 122 | # print('Loop{}'.format(i+1),x.get_shape()) 123 | x = flatten(x) 124 | return x 125 | 126 | 127 | def nipsHead(x): 128 | ''' DQN NIPS 2013 and A3C paper 129 | input: [None, 84, 84, 4]; output: [None, 2592] -> [None, 256]; 130 | ''' 131 | print('Using nips head design') 132 | x = tf.nn.relu(conv2d(x, 16, "l1", [8, 8], [4, 4], pad="VALID")) 133 | x = tf.nn.relu(conv2d(x, 32, "l2", [4, 4], [2, 2], pad="VALID")) 134 | x = flatten(x) 135 | x = tf.nn.relu(linear(x, 256, "fc", normalized_columns_initializer(0.01))) 136 | return x 137 | 138 | 139 | def natureHead(x): 140 | ''' DQN Nature 2015 paper 141 | input: [None, 84, 84, 4]; output: [None, 3136] -> [None, 512]; 142 | ''' 143 | print('Using nature head design') 144 | x = tf.nn.relu(conv2d(x, 32, "l1", [8, 8], [4, 4], pad="VALID")) 145 | x = tf.nn.relu(conv2d(x, 64, "l2", [4, 4], [2, 2], pad="VALID")) 146 | x = tf.nn.relu(conv2d(x, 64, "l3", [3, 3], [1, 1], pad="VALID")) 147 | x = flatten(x) 148 | x = tf.nn.relu(linear(x, 512, "fc", normalized_columns_initializer(0.01))) 149 | return x 150 | 151 | 152 | def doomHead(x): 153 | ''' Learning by Prediction ICLR 2017 paper 154 | (their final output was 64 changed to 256 here) 155 | input: [None, 120, 160, 1]; output: [None, 1280] -> [None, 256]; 156 | ''' 157 | print('Using doom head design') 158 | x = tf.nn.elu(conv2d(x, 8, "l1", [5, 5], [4, 4])) 159 | x = tf.nn.elu(conv2d(x, 16, "l2", [3, 3], [2, 2])) 160 | x = tf.nn.elu(conv2d(x, 32, "l3", [3, 3], [2, 2])) 161 | x = tf.nn.elu(conv2d(x, 64, "l4", [3, 3], [2, 2])) 162 | x = flatten(x) 163 | x = tf.nn.elu(linear(x, 256, "fc", normalized_columns_initializer(0.01))) 164 | return x 165 | 166 | 167 | class LSTMPolicy(object): 168 | def __init__(self, ob_space, ac_space, designHead='universe'): 169 | self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space), name='x') 170 | size = 256 171 | if designHead == 'nips': 172 | x = nipsHead(x) 173 | elif designHead == 'nature': 174 | x = natureHead(x) 175 | elif designHead == 'doom': 176 | x = doomHead(x) 177 | elif 'tile' in designHead: 178 | x = universeHead(x, nConvs=2) 179 | else: 180 | x = universeHead(x) 181 | 182 | # introduce a "fake" batch dimension of 1 to do LSTM over time dim 183 | x = tf.expand_dims(x, [0]) 184 | lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True) 185 | self.state_size = lstm.state_size 186 | step_size = tf.shape(self.x)[:1] 187 | 188 | c_init = np.zeros((1, lstm.state_size.c), np.float32) 189 | h_init = np.zeros((1, lstm.state_size.h), np.float32) 190 | self.state_init = [c_init, h_init] 191 | c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c], name='c_in') 192 | h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h], name='h_in') 193 | self.state_in = [c_in, h_in] 194 | 195 | state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in) 196 | lstm_outputs, lstm_state = tf.nn.dynamic_rnn( 197 | lstm, x, initial_state=state_in, 
sequence_length=step_size, 198 | time_major=False) 199 | lstm_c, lstm_h = lstm_state 200 | x = tf.reshape(lstm_outputs, [-1, size]) 201 | self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1]) 202 | self.state_out = [lstm_c[:1, :], lstm_h[:1, :]] 203 | 204 | # [0, :] means pick action of first state from batch. Hardcoded b/c 205 | # batch=1 during rollout collection. Its not used during batch training. 206 | self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01)) 207 | self.sample = categorical_sample(self.logits, ac_space)[0, :] 208 | self.probs = tf.nn.softmax(self.logits, dim=-1)[0, :] 209 | 210 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 211 | # tf.add_to_collection('probs', self.probs) 212 | # tf.add_to_collection('sample', self.sample) 213 | # tf.add_to_collection('state_out_0', self.state_out[0]) 214 | # tf.add_to_collection('state_out_1', self.state_out[1]) 215 | # tf.add_to_collection('vf', self.vf) 216 | 217 | def get_initial_features(self): 218 | # Call this function to get reseted lstm memory cells 219 | return self.state_init 220 | 221 | def act(self, ob, c, h): 222 | sess = tf.get_default_session() 223 | return sess.run([self.sample, self.vf] + self.state_out, 224 | {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h}) 225 | 226 | def act_inference(self, ob, c, h): 227 | sess = tf.get_default_session() 228 | return sess.run([self.probs, self.sample, self.vf] + self.state_out, 229 | {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h}) 230 | 231 | def value(self, ob, c, h): 232 | sess = tf.get_default_session() 233 | return sess.run(self.vf, {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})[0] 234 | 235 | 236 | class StateActionPredictor(object): 237 | def __init__(self, ob_space, ac_space, designHead='universe'): 238 | # input: s1,s2: : [None, h, w, ch] (usually ch=1 or 4) 239 | # asample: 1-hot encoding of sampled action from policy: [None, ac_space] 240 | input_shape = [None] + list(ob_space) 241 | self.s1 = phi1 = tf.placeholder(tf.float32, input_shape) 242 | self.s2 = phi2 = tf.placeholder(tf.float32, input_shape) 243 | self.asample = asample = tf.placeholder(tf.float32, [None, ac_space]) 244 | 245 | # feature encoding: phi1, phi2: [None, LEN] 246 | size = 256 247 | if designHead == 'nips': 248 | phi1 = nipsHead(phi1) 249 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 250 | phi2 = nipsHead(phi2) 251 | elif designHead == 'nature': 252 | phi1 = natureHead(phi1) 253 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 254 | phi2 = natureHead(phi2) 255 | elif designHead == 'doom': 256 | phi1 = doomHead(phi1) 257 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 258 | phi2 = doomHead(phi2) 259 | elif 'tile' in designHead: 260 | phi1 = universeHead(phi1, nConvs=2) 261 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 262 | phi2 = universeHead(phi2, nConvs=2) 263 | else: 264 | phi1 = universeHead(phi1) 265 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 266 | phi2 = universeHead(phi2) 267 | 268 | # inverse model: g(phi1,phi2) -> a_inv: [None, ac_space] 269 | g = tf.concat(1,[phi1, phi2]) 270 | g = tf.nn.relu(linear(g, size, "g1", normalized_columns_initializer(0.01))) 271 | aindex = tf.argmax(asample, axis=1) # aindex: [batch_size,] 272 | logits = linear(g, ac_space, "glast", normalized_columns_initializer(0.01)) 273 | self.invloss = 
tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( 274 | logits, aindex), name="invloss") 275 | self.ainvprobs = tf.nn.softmax(logits, dim=-1) 276 | 277 | # forward model: f(phi1,asample) -> phi2 278 | # Note: no backprop to asample of policy: it is treated as fixed for predictor training 279 | f = tf.concat(1, [phi1, asample]) 280 | f = tf.nn.relu(linear(f, size, "f1", normalized_columns_initializer(0.01))) 281 | f = linear(f, phi1.get_shape()[1].value, "flast", normalized_columns_initializer(0.01)) 282 | self.forwardloss = 0.5 * tf.reduce_mean(tf.square(tf.subtract(f, phi2)), name='forwardloss') 283 | # self.forwardloss = 0.5 * tf.reduce_mean(tf.sqrt(tf.abs(tf.subtract(f, phi2))), name='forwardloss') 284 | # self.forwardloss = cosineLoss(f, phi2, name='forwardloss') 285 | self.forwardloss = self.forwardloss * 288.0 # lenFeatures=288. Factored out to make hyperparams not depend on it. 286 | 287 | # variable list 288 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 289 | 290 | def pred_act(self, s1, s2): 291 | ''' 292 | returns action probability distribution predicted by inverse model 293 | input: s1,s2: [h, w, ch] 294 | output: ainvprobs: [ac_space] 295 | ''' 296 | sess = tf.get_default_session() 297 | return sess.run(self.ainvprobs, {self.s1: [s1], self.s2: [s2]})[0, :] 298 | 299 | def pred_bonus(self, s1, s2, asample): 300 | ''' 301 | returns bonus predicted by forward model 302 | input: s1,s2: [h, w, ch], asample: [ac_space] 1-hot encoding 303 | output: scalar bonus 304 | ''' 305 | sess = tf.get_default_session() 306 | # error = sess.run([self.forwardloss, self.invloss], 307 | # {self.s1: [s1], self.s2: [s2], self.asample: [asample]}) 308 | # print('ErrorF: ', error[0], ' ErrorI:', error[1]) 309 | error = sess.run(self.forwardloss, 310 | {self.s1: [s1], self.s2: [s2], self.asample: [asample]}) 311 | error = error * constants['PREDICTION_BETA'] 312 | return error 313 | 314 | 315 | class StatePredictor(object): 316 | ''' 317 | Loss is normalized across the spatial dimensions (42x42), but not across batches. 318 | This is unlike ICM, where there is no normalization across the 288 feature 319 | dimensions nor across batches. 
320 | ''' 321 | 322 | def __init__(self, ob_space, ac_space, designHead='universe', unsupType='state'): 323 | # input: s1,s2: : [None, h, w, ch] (usually ch=1 or 4) 324 | # asample: 1-hot encoding of sampled action from policy: [None, ac_space] 325 | input_shape = [None] + list(ob_space) 326 | self.s1 = phi1 = tf.placeholder(tf.float32, input_shape) 327 | self.s2 = phi2 = tf.placeholder(tf.float32, input_shape) 328 | self.asample = asample = tf.placeholder(tf.float32, [None, ac_space]) 329 | self.stateAenc = unsupType == 'stateAenc' 330 | 331 | # feature encoding: phi1: [None, LEN] 332 | if designHead == 'universe': 333 | phi1 = universeHead(phi1) 334 | if self.stateAenc: 335 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 336 | phi2_aenc = universeHead(phi2) 337 | elif 'tile' in designHead: # for mario tiles 338 | phi1 = universeHead(phi1, nConvs=2) 339 | if self.stateAenc: 340 | with tf.variable_scope(tf.get_variable_scope(), reuse=True): 341 | phi2_aenc = universeHead(phi2) 342 | else: 343 | print('Only universe designHead implemented for state prediction baseline.') 344 | exit(1) 345 | 346 | # forward model: f(phi1,asample) -> phi2 347 | # Note: no backprop to asample of policy: it is treated as fixed for predictor training 348 | f = tf.concat(1, [phi1, asample]) 349 | f = tf.nn.relu(linear(f, phi1.get_shape()[1].value, "f1", normalized_columns_initializer(0.01))) 350 | if 'tile' in designHead: 351 | f = inverseUniverseHead(f, input_shape, nConvs=2) 352 | else: 353 | f = inverseUniverseHead(f, input_shape) 354 | self.forwardloss = 0.5 * tf.reduce_mean(tf.square(tf.subtract(f, phi2)), name='forwardloss') 355 | if self.stateAenc: 356 | self.aencBonus = 0.5 * tf.reduce_mean(tf.square(tf.subtract(phi1, phi2_aenc)), name='aencBonus') 357 | self.predstate = phi1 358 | 359 | # variable list 360 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 361 | 362 | def pred_state(self, s1, asample): 363 | ''' 364 | returns state predicted by forward model 365 | input: s1: [h, w, ch], asample: [ac_space] 1-hot encoding 366 | output: s2: [h, w, ch] 367 | ''' 368 | sess = tf.get_default_session() 369 | return sess.run(self.predstate, {self.s1: [s1], 370 | self.asample: [asample]})[0, :] 371 | 372 | def pred_bonus(self, s1, s2, asample): 373 | ''' 374 | returns bonus predicted by forward model 375 | input: s1,s2: [h, w, ch], asample: [ac_space] 1-hot encoding 376 | output: scalar bonus 377 | ''' 378 | sess = tf.get_default_session() 379 | bonus = self.aencBonus if self.stateAenc else self.forwardloss 380 | error = sess.run(bonus, 381 | {self.s1: [s1], self.s2: [s2], self.asample: [asample]}) 382 | # print('ErrorF: ', error) 383 | error = error * constants['PREDICTION_BETA'] 384 | return error 385 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | atari-py==0.1.1 2 | attrs==17.2.0 3 | autobahn==17.6.2 4 | Automat==0.6.0 5 | backports.ssl-match-hostname==3.5.0.1 6 | certifi==2017.4.17 7 | chardet==3.0.4 8 | constantly==15.1.0 9 | docker-py==1.10.3 10 | docker-pycreds==0.2.1 11 | doom-py==0.0.15 12 | fastzbarlight==0.0.14 13 | funcsigs==1.0.2 14 | -e git+https://github.com/openai/go-vncdriver.git@33bd0dd9620e97acd9b4e559bca217df09ba89e6#egg=go_vncdriver 15 | -e git+https://github.com/openai/gym.git@6f277090ed3323009a324ea31d00363afd8dfb3a#egg=gym 16 | -e 
git+https://github.com/pathak22/gym-pull.git@589039c29567c67fb3d5c0a315806419e0999415#egg=gym_pull 17 | hyperlink==17.2.1 18 | idna==2.5 19 | incremental==17.5.0 20 | ipaddress==1.0.18 21 | mock==2.0.0 22 | numpy==1.13.1 23 | olefile==0.44 24 | pbr==3.1.1 25 | Pillow==4.2.1 26 | ppaquette-gym-doom==0.0.3 27 | -e git+https://github.com/ppaquette/gym-super-mario.git@2e5ee823b6090af3f99b1f62c465fc4b033532f4#egg=ppaquette_gym_super_mario 28 | protobuf==3.1.0 29 | pyglet==1.2.4 30 | PyOpenGL==3.1.0 31 | PyYAML==3.12 32 | requests>=2.20.0 33 | scipy==0.19.1 34 | six==1.10.0 35 | tensorflow==0.12.0rc1 36 | Twisted==17.5.0 37 | txaio==2.8.0 38 | ujson==1.35 39 | -e git+https://github.com/openai/universe.git@e8037a103d8871a29396c39b2a58df439bde3380#egg=universe 40 | urllib3==1.21.1 41 | websocket-client==0.44.0 42 | zope.interface==4.4.2 43 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from six.moves import shlex_quote 5 | 6 | parser = argparse.ArgumentParser(description="Run commands") 7 | parser.add_argument('-w', '--num-workers', default=20, type=int, 8 | help="Number of workers") 9 | parser.add_argument('-r', '--remotes', default=None, 10 | help='The address of pre-existing VNC servers and ' 11 | 'rewarders to use (e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901).') 12 | parser.add_argument('-e', '--env-id', type=str, default="doom", 13 | help="Environment id") 14 | parser.add_argument('-l', '--log-dir', type=str, default="tmp/doom", 15 | help="Log directory path") 16 | parser.add_argument('-n', '--dry-run', action='store_true', 17 | help="Print out commands rather than executing them") 18 | parser.add_argument('-m', '--mode', type=str, default='tmux', 19 | help="tmux: run workers in a tmux session. nohup: run workers with nohup. child: run workers as child processes") 20 | parser.add_argument('--visualise', action='store_true', 21 | help="Visualise the gym environment by running env.render() between each timestep") 22 | parser.add_argument('--envWrap', action='store_true', 23 | help="Preprocess input in env_wrapper (no change in input size or network)") 24 | parser.add_argument('--designHead', type=str, default='universe', 25 | help="Network design head: nips or nature or doom or universe (default)") 26 | parser.add_argument('--unsup', type=str, default=None, 27 | help="Unsup. exploration mode: action or state or stateAenc or None") 28 | parser.add_argument('--noReward', action='store_true', help="Remove all extrinsic reward") 29 | parser.add_argument('--noLifeReward', action='store_true', 30 | help="Remove all negative reward (in doom: it is living reward)") 31 | parser.add_argument('--expName', type=str, default='a3c', 32 | help="Experiment tmux session-name. Default a3c.") 33 | parser.add_argument('--expId', type=int, default=0, 34 | help="Experiment Id >=0. 
Needed while running more than one run per machine.") 35 | parser.add_argument('--savio', action='store_true', 36 | help="Savio or KNL cpu cluster hacks") 37 | parser.add_argument('--default', action='store_true', help="run with default params") 38 | parser.add_argument('--pretrain', type=str, default=None, help="Checkpoint dir (generally ..../train/) to load from.") 39 | 40 | def new_cmd(session, name, cmd, mode, logdir, shell): 41 | if isinstance(cmd, (list, tuple)): 42 | cmd = " ".join(shlex_quote(str(v)) for v in cmd) 43 | if mode == 'tmux': 44 | return name, "tmux send-keys -t {}:{} {} Enter".format(session, name, shlex_quote(cmd)) 45 | elif mode == 'child': 46 | return name, "{} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(cmd, logdir, session, name, logdir) 47 | elif mode == 'nohup': 48 | return name, "nohup {} -c {} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(shell, shlex_quote(cmd), logdir, session, name, logdir) 49 | 50 | 51 | def create_commands(session, num_workers, remotes, env_id, logdir, shell='bash', 52 | mode='tmux', visualise=False, envWrap=False, designHead=None, 53 | unsup=None, noReward=False, noLifeReward=False, psPort=12222, 54 | delay=0, savio=False, pretrain=None): 55 | # for launching the TF workers and for launching tensorboard 56 | py_cmd = 'python' if savio else sys.executable 57 | base_cmd = [ 58 | 'CUDA_VISIBLE_DEVICES=', 59 | py_cmd, 'worker.py', 60 | '--log-dir', logdir, 61 | '--env-id', env_id, 62 | '--num-workers', str(num_workers), 63 | '--psPort', psPort] 64 | 65 | if delay > 0: 66 | base_cmd += ['--delay', delay] 67 | if visualise: 68 | base_cmd += ['--visualise'] 69 | if envWrap: 70 | base_cmd += ['--envWrap'] 71 | if designHead is not None: 72 | base_cmd += ['--designHead', designHead] 73 | if unsup is not None: 74 | base_cmd += ['--unsup', unsup] 75 | if noReward: 76 | base_cmd += ['--noReward'] 77 | if noLifeReward: 78 | base_cmd += ['--noLifeReward'] 79 | if pretrain is not None: 80 | base_cmd += ['--pretrain', pretrain] 81 | 82 | if remotes is None: 83 | remotes = ["1"] * num_workers 84 | else: 85 | remotes = remotes.split(',') 86 | assert len(remotes) == num_workers 87 | 88 | cmds_map = [new_cmd(session, "ps", base_cmd + ["--job-name", "ps"], mode, logdir, shell)] 89 | for i in range(num_workers): 90 | cmds_map += [new_cmd(session, 91 | "w-%d" % i, base_cmd + ["--job-name", "worker", "--task", str(i), "--remotes", remotes[i]], mode, logdir, shell)] 92 | 93 | # No tensorboard or htop window if running multiple experiments per machine 94 | if session == 'a3c': 95 | cmds_map += [new_cmd(session, "tb", ["tensorboard", "--logdir", logdir, "--port", "12345"], mode, logdir, shell)] 96 | if session == 'a3c' and mode == 'tmux': 97 | cmds_map += [new_cmd(session, "htop", ["htop"], mode, logdir, shell)] 98 | 99 | windows = [v[0] for v in cmds_map] 100 | 101 | notes = [] 102 | cmds = [ 103 | "mkdir -p {}".format(logdir), 104 | "echo {} {} > {}/cmd.sh".format(sys.executable, ' '.join([shlex_quote(arg) for arg in sys.argv if arg != '-n']), logdir), 105 | ] 106 | if mode == 'nohup' or mode == 'child': 107 | cmds += ["echo '#!/bin/sh' >{}/kill.sh".format(logdir)] 108 | notes += ["Run `source {}/kill.sh` to kill the job".format(logdir)] 109 | if mode == 'tmux': 110 | notes += ["Use `tmux attach -t {}` to watch process output".format(session)] 111 | notes += ["Use `tmux kill-session -t {}` to kill the job".format(session)] 112 | else: 113 | notes += ["Use `tail -f {}/*.out` to watch process output".format(logdir)] 114 | notes += ["Point your 
browser to http://localhost:12345 to see Tensorboard"] 115 | 116 | if mode == 'tmux': 117 | cmds += [ 118 | "kill -9 $( lsof -i:12345 -t ) > /dev/null 2>&1", # kill any process using tensorboard's port 119 | "kill -9 $( lsof -i:{}-{} -t ) > /dev/null 2>&1".format(psPort, num_workers+psPort), # kill any processes using ps / worker ports 120 | "tmux kill-session -t {}".format(session), 121 | "tmux new-session -s {} -n {} -d {}".format(session, windows[0], shell) 122 | ] 123 | for w in windows[1:]: 124 | cmds += ["tmux new-window -t {} -n {} {}".format(session, w, shell)] 125 | cmds += ["sleep 1"] 126 | for window, cmd in cmds_map: 127 | cmds += [cmd] 128 | 129 | return cmds, notes 130 | 131 | 132 | def run(): 133 | args = parser.parse_args() 134 | if args.default: 135 | args.envWrap = True 136 | args.savio = True 137 | args.noLifeReward = True 138 | args.unsup = 'action' 139 | 140 | # handling nuances of running multiple jobs per-machine 141 | psPort = 12222 + 50*args.expId 142 | delay = 220*args.expId if 'doom' in args.env_id.lower() or 'labyrinth' in args.env_id.lower() else 5*args.expId 143 | delay = 6*delay if 'mario' in args.env_id else delay 144 | 145 | cmds, notes = create_commands(args.expName, args.num_workers, args.remotes, args.env_id, 146 | args.log_dir, mode=args.mode, visualise=args.visualise, 147 | envWrap=args.envWrap, designHead=args.designHead, 148 | unsup=args.unsup, noReward=args.noReward, 149 | noLifeReward=args.noLifeReward, psPort=psPort, 150 | delay=delay, savio=args.savio, pretrain=args.pretrain) 151 | if args.dry_run: 152 | print("Dry-run mode due to -n flag, otherwise the following commands would be executed:") 153 | else: 154 | print("Executing the following commands:") 155 | print("\n".join(cmds)) 156 | print("") 157 | if not args.dry_run: 158 | if args.mode == "tmux": 159 | os.environ["TMUX"] = "" 160 | os.system("\n".join(cmds)) 161 | print('\n'.join(notes)) 162 | 163 | 164 | if __name__ == "__main__": 165 | run() 166 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import os 4 | import errno 5 | 6 | 7 | def mkdir_p(path): 8 | """ 9 | It creates directory recursively if it does not already exist 10 | """ 11 | try: 12 | os.makedirs(path) 13 | except OSError as exc: 14 | if exc.errno == errno.EEXIST and os.path.isdir(path): 15 | pass 16 | else: 17 | raise 18 | -------------------------------------------------------------------------------- /src/worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import go_vncdriver 3 | import tensorflow as tf 4 | import argparse 5 | import logging 6 | import sys, signal 7 | import time 8 | import os 9 | from a3c import A3C 10 | from envs import create_env 11 | from constants import constants 12 | import distutils.version 13 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0') 14 | 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | 18 | # Disables write_meta_graph argument, which freezes entire process and is mostly useless. 
19 | class FastSaver(tf.train.Saver): 20 | def save(self, sess, save_path, global_step=None, latest_filename=None, 21 | meta_graph_suffix="meta", write_meta_graph=True): 22 | super(FastSaver, self).save(sess, save_path, global_step, latest_filename, 23 | meta_graph_suffix, False) 24 | 25 | def run(args, server): 26 | env = create_env(args.env_id, client_id=str(args.task), remotes=args.remotes, envWrap=args.envWrap, designHead=args.designHead, 27 | noLifeReward=args.noLifeReward) 28 | trainer = A3C(env, args.task, args.visualise, args.unsup, args.envWrap, args.designHead, args.noReward) 29 | 30 | # logging 31 | if args.task == 0: 32 | with open(args.log_dir + '/log.txt', 'w') as fid: 33 | for key, val in constants.items(): 34 | fid.write('%s: %s\n'%(str(key), str(val))) 35 | fid.write('designHead: %s\n'%args.designHead) 36 | fid.write('input observation: %s\n'%str(env.observation_space.shape)) 37 | fid.write('env name: %s\n'%str(env.spec.id)) 38 | fid.write('unsup method type: %s\n'%str(args.unsup)) 39 | 40 | # Variable names that start with "local" are not saved in checkpoints. 41 | if use_tf12_api: 42 | variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")] 43 | init_op = tf.variables_initializer(variables_to_save) 44 | init_all_op = tf.global_variables_initializer() 45 | else: 46 | variables_to_save = [v for v in tf.all_variables() if not v.name.startswith("local")] 47 | init_op = tf.initialize_variables(variables_to_save) 48 | init_all_op = tf.initialize_all_variables() 49 | saver = FastSaver(variables_to_save) 50 | if args.pretrain is not None: 51 | variables_to_restore = [v for v in tf.trainable_variables() if not v.name.startswith("local")] 52 | pretrain_saver = FastSaver(variables_to_restore) 53 | 54 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) 55 | logger.info('Trainable vars:') 56 | for v in var_list: 57 | logger.info(' %s %s', v.name, v.get_shape()) 58 | 59 | def init_fn(ses): 60 | logger.info("Initializing all parameters.") 61 | ses.run(init_all_op) 62 | if args.pretrain is not None: 63 | pretrain = tf.train.latest_checkpoint(args.pretrain) 64 | logger.info("==> Restoring from given pretrained checkpoint.") 65 | logger.info(" Pretraining address: %s", pretrain) 66 | pretrain_saver.restore(ses, pretrain) 67 | logger.info("==> Done restoring model! Restored %d variables.", len(variables_to_restore)) 68 | 69 | config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)]) 70 | logdir = os.path.join(args.log_dir, 'train') 71 | 72 | if use_tf12_api: 73 | summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task) 74 | else: 75 | summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task) 76 | 77 | logger.info("Events directory: %s_%s", logdir, args.task) 78 | sv = tf.train.Supervisor(is_chief=(args.task == 0), 79 | logdir=logdir, 80 | saver=saver, 81 | summary_op=None, 82 | init_op=init_op, 83 | init_fn=init_fn, 84 | summary_writer=summary_writer, 85 | ready_op=tf.report_uninitialized_variables(variables_to_save), 86 | global_step=trainer.global_step, 87 | save_model_secs=30, 88 | save_summaries_secs=30) 89 | 90 | num_global_steps = constants['MAX_GLOBAL_STEPS'] 91 | 92 | logger.info( 93 | "Starting session. If this hangs, we're most likely waiting to connect to the parameter server. 
" + 94 | "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.") 95 | with sv.managed_session(server.target, config=config) as sess, sess.as_default(): 96 | # Workaround for FailedPreconditionError 97 | # see: https://github.com/openai/universe-starter-agent/issues/44 and 31 98 | sess.run(trainer.sync) 99 | 100 | trainer.start(sess, summary_writer) 101 | global_step = sess.run(trainer.global_step) 102 | logger.info("Starting training at gobal_step=%d", global_step) 103 | while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps): 104 | trainer.process(sess) 105 | global_step = sess.run(trainer.global_step) 106 | 107 | # Ask for all the services to stop. 108 | sv.stop() 109 | logger.info('reached %s steps. worker stopped.', global_step) 110 | 111 | def cluster_spec(num_workers, num_ps, port=12222): 112 | """ 113 | More tensorflow setup for data parallelism 114 | """ 115 | cluster = {} 116 | 117 | all_ps = [] 118 | host = '127.0.0.1' 119 | for _ in range(num_ps): 120 | all_ps.append('{}:{}'.format(host, port)) 121 | port += 1 122 | cluster['ps'] = all_ps 123 | 124 | all_workers = [] 125 | for _ in range(num_workers): 126 | all_workers.append('{}:{}'.format(host, port)) 127 | port += 1 128 | cluster['worker'] = all_workers 129 | return cluster 130 | 131 | def main(_): 132 | """ 133 | Setting up Tensorflow for data parallel work 134 | """ 135 | 136 | parser = argparse.ArgumentParser(description=None) 137 | parser.add_argument('-v', '--verbose', action='count', dest='verbosity', default=0, help='Set verbosity.') 138 | parser.add_argument('--task', default=0, type=int, help='Task index') 139 | parser.add_argument('--job-name', default="worker", help='worker or ps') 140 | parser.add_argument('--num-workers', default=1, type=int, help='Number of workers') 141 | parser.add_argument('--log-dir', default="tmp/doom", help='Log directory path') 142 | parser.add_argument('--env-id', default="doom", help='Environment id') 143 | parser.add_argument('-r', '--remotes', default=None, 144 | help='References to environments to create (e.g. -r 20), ' 145 | 'or the address of pre-existing VNC servers and ' 146 | 'rewarders to use (e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901)') 147 | parser.add_argument('--visualise', action='store_true', 148 | help="Visualise the gym environment by running env.render() between each timestep") 149 | parser.add_argument('--envWrap', action='store_true', 150 | help="Preprocess input in env_wrapper (no change in input size or network)") 151 | parser.add_argument('--designHead', type=str, default='universe', 152 | help="Network deign head: nips or nature or doom or universe(default)") 153 | parser.add_argument('--unsup', type=str, default=None, 154 | help="Unsup. 
exploration mode: action or state or stateAenc or None") 155 | parser.add_argument('--noReward', action='store_true', help="Remove all extrinsic reward") 156 | parser.add_argument('--noLifeReward', action='store_true', 157 | help="Remove all negative reward (in doom: it is living reward)") 158 | parser.add_argument('--psPort', default=12222, type=int, help='Port number for parameter server') 159 | parser.add_argument('--delay', default=0, type=int, help='delay start by these many seconds') 160 | parser.add_argument('--pretrain', type=str, default=None, help="Checkpoint dir (generally ..../train/) to load from.") 161 | args = parser.parse_args() 162 | 163 | spec = cluster_spec(args.num_workers, 1, args.psPort) 164 | cluster = tf.train.ClusterSpec(spec).as_cluster_def() 165 | 166 | def shutdown(signal, frame): 167 | logger.warn('Received signal %s: exiting', signal) 168 | sys.exit(128+signal) 169 | signal.signal(signal.SIGHUP, shutdown) 170 | signal.signal(signal.SIGINT, shutdown) 171 | signal.signal(signal.SIGTERM, shutdown) 172 | 173 | if args.job_name == "worker": 174 | server = tf.train.Server(cluster, job_name="worker", task_index=args.task, 175 | config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=2)) 176 | if args.delay > 0: 177 | print('Startup delay in worker: {}s'.format(args.delay)) 178 | time.sleep(args.delay) 179 | print('.. wait over !') 180 | run(args, server) 181 | else: 182 | server = tf.train.Server(cluster, job_name="ps", task_index=args.task, 183 | config=tf.ConfigProto(device_filters=["/job:ps"])) 184 | while True: 185 | time.sleep(1000) 186 | 187 | if __name__ == "__main__": 188 | tf.app.run() 189 | --------------------------------------------------------------------------------
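For orientation, below is a minimal launch sketch assembled only from the argparse definitions in src/train.py and src/inference.py above and from the checkpoint path written by src/worker.py. It assumes the pinned packages in src/requirements.txt are installed and that commands are run from src/; the exact environment-id strings accepted by envs.create_env are not shown in this listing, so the 'doom' id (train.py's default) should be treated as an assumption rather than a verified invocation.

# Launch the parameter server and 20 A3C workers with the curiosity module enabled
# (--default turns on envWrap, noLifeReward, and unsup='action'), plus TensorBoard
# and htop windows, all inside a tmux session named 'a3c':
python train.py --default --env-id doom --num-workers 20 --log-dir tmp/doom

# The chief worker checkpoints the shared (non-local) variables under tmp/doom/train
# every 30 seconds. Roll out two episodes from the trained model and write inference
# summaries to the default output directory, tmp/doom/inference/:
python inference.py --default --env-id doom --log-dir tmp/doom --num-episodes 2 --record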