├── .gitignore
├── LICENSE
├── README.md
├── doomFiles
│   ├── README.md
│   ├── __init__.py
│   ├── action_space.py
│   ├── doom_env.py
│   ├── doom_my_way_home_sparse.py
│   ├── doom_my_way_home_verySparse.py
│   └── wads
│       ├── my_way_home_dense.wad
│       ├── my_way_home_sparse.wad
│       └── my_way_home_verySparse.wad
├── images
│   ├── mario1.gif
│   ├── mario2.gif
│   └── vizdoom.gif
├── models
│   └── download_models.sh
└── src
    ├── .gitignore
    ├── a3c.py
    ├── constants.py
    ├── demo.py
    ├── env_wrapper.py
    ├── envs.py
    ├── inference.py
    ├── mario.py
    ├── model.py
    ├── requirements.txt
    ├── train.py
    ├── utils.py
    └── worker.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | **pyc
3 | **npy
4 | tmp/
5 | curiosity
6 | src/vizdoom.ini
7 | models/*.tar.gz
8 | models/output
9 | models/doom
10 | models/mario
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Deepak Pathak
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
26 | --------------------------------------------------------------------------------
27 | Original openai License:
28 | --------------------------------------------------------------------------------
29 | MIT License
30 |
31 | Copyright (c) 2016 openai
32 |
33 | Permission is hereby granted, free of charge, to any person obtaining a copy
34 | of this software and associated documentation files (the "Software"), to deal
35 | in the Software without restriction, including without limitation the rights
36 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
37 | copies of the Software, and to permit persons to whom the Software is
38 | furnished to do so, subject to the following conditions:
39 |
40 | The above copyright notice and this permission notice shall be included in all
41 | copies or substantial portions of the Software.
42 |
43 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
44 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
45 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
46 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
47 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
48 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
49 | SOFTWARE.
50 | --------------------------------------------------------------------------------
51 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Curiosity-driven Exploration by Self-supervised Prediction ##
2 | #### In ICML 2017 [[Project Website]](http://pathak22.github.io/noreward-rl/) [[Demo Video]](http://pathak22.github.io/noreward-rl/index.html#demoVideo)
3 |
4 | [Deepak Pathak](https://people.eecs.berkeley.edu/~pathak/), [Pulkit Agrawal](https://people.eecs.berkeley.edu/~pulkitag/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/), [Trevor Darrell](https://people.eecs.berkeley.edu/~trevor/)
5 | University of California, Berkeley
6 |
7 |
8 |
9 | This is a TensorFlow-based implementation of our [ICML 2017 paper on curiosity-driven exploration for reinforcement learning](http://pathak22.github.io/noreward-rl/). The idea is to train an agent with an intrinsic curiosity module (ICM) when external rewards from the environment are sparse. Surprisingly, ICM can be used even when no rewards are available from the environment at all, in which case the agent learns to explore purely out of curiosity: 'RL without rewards'. If you find this work useful in your research, please cite:
10 |
11 | @inproceedings{pathakICMl17curiosity,
12 | Author = {Pathak, Deepak and Agrawal, Pulkit and
13 | Efros, Alexei A. and Darrell, Trevor},
14 | Title = {Curiosity-driven Exploration by Self-supervised Prediction},
15 | Booktitle = {International Conference on Machine Learning ({ICML})},
16 | Year = {2017}
17 | }
18 |
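For intuition, the curiosity reward used during training is the prediction error of a learned forward model: given features of the current state and the action taken, the model predicts the features of the next state, and the agent is rewarded in proportion to how wrong that prediction is. A minimal NumPy sketch of the idea (illustrative only: the feature size, `eta`, and variable names are assumptions, not this repository's API):

```Python
import numpy as np

def intrinsic_reward(phi_next, phi_next_pred, eta=0.01):
    # curiosity bonus = eta/2 * || predicted next-state features - actual ones ||^2
    return 0.5 * eta * np.sum((phi_next_pred - phi_next) ** 2)

# A poorly predicted (i.e., unfamiliar) transition earns a larger bonus:
phi_next = np.random.randn(288)                         # actual features of s_{t+1}
phi_next_pred = phi_next + 0.5 * np.random.randn(288)   # forward-model prediction
print(intrinsic_reward(phi_next, phi_next_pred))
```
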
19 | ### 1) Installation and Usage
20 | 1. This code is based on [TensorFlow](https://www.tensorflow.org/). To install, run these commands:
21 | ```Shell
22 | # you might not need many of these, e.g., fceux is only for mario
23 | sudo apt-get install -y python-numpy python-dev cmake zlib1g-dev libjpeg-dev xvfb \
24 | libav-tools xorg-dev python-opengl libboost-all-dev libsdl2-dev swig python3-dev \
25 | python3-venv make golang libjpeg-turbo8-dev gcc wget unzip git fceux virtualenv \
26 | tmux
27 |
28 | # install the code
29 | git clone -b master --single-branch https://github.com/pathak22/noreward-rl.git
30 | cd noreward-rl/
31 | virtualenv curiosity
32 | source $PWD/curiosity/bin/activate
33 | pip install numpy
34 | pip install -r src/requirements.txt
35 | python curiosity/src/go-vncdriver/build.py
36 |
37 | # download models
38 | bash models/download_models.sh
39 |
40 | # setup customized doom environment
41 | cd doomFiles/
42 | # then follow commands in doomFiles/README.md
43 | ```
44 |
45 | 2. Running demo
46 | ```Shell
47 | cd noreward-rl/src/
48 | python demo.py --ckpt ../models/doom/doom_ICM
49 | python demo.py --env-id SuperMarioBros-1-1-v0 --ckpt ../models/mario/mario_ICM
50 | ```
51 |
52 | 3. Training code
53 | ```Shell
54 | cd noreward-rl/src/
55 | # For Doom: doom or doomSparse or doomVerySparse
56 | python train.py --default --env-id doom
57 |
58 | # For Mario, change src/constants.py as follows:
59 | # PREDICTION_BETA = 0.2
60 | # ENTROPY_BETA = 0.0005
61 | python train.py --default --env-id mario --noReward
62 |
63 | xvfb-run -s "-screen 0 1400x900x24" bash # only for remote desktops
64 | # useful xvfb link: http://stackoverflow.com/a/30336424
65 | python inference.py --default --env-id doom --record
66 | ```
67 |
68 | ### 2) Other helpful pointers
69 | - [Paper](https://pathak22.github.io/noreward-rl/resources/icml17.pdf)
70 | - [Project Website](http://pathak22.github.io/noreward-rl/)
71 | - [Demo Video](http://pathak22.github.io/noreward-rl/index.html#demoVideo)
72 | - [Reddit Discussion](https://redd.it/6bc8ul)
73 | - [Media Articles (New Scientist, MIT Tech Review and others)](http://pathak22.github.io/noreward-rl/index.html#media)
74 |
75 | ### 3) Acknowledgement
76 | Vanilla A3C code is based on the open source implementation of [universe-starter-agent](https://github.com/openai/universe-starter-agent).
77 |
--------------------------------------------------------------------------------
/doomFiles/README.md:
--------------------------------------------------------------------------------
1 | ### VizDoom Scenarios
2 | This directory provides the files needed to replicate the customized Doom scenarios from the ICML'17 paper. Run the following commands:
3 |
4 | ```Shell
5 | cp wads/*.wad ../curiosity/lib/python2.7/site-packages/doom_py/scenarios/
6 | cp __init__.py ../curiosity/lib/python2.7/site-packages/ppaquette_gym_doom/
7 | cp doom*.py ../curiosity/lib/python2.7/site-packages/ppaquette_gym_doom/
8 | cp action_space.py ../curiosity/lib/python2.7/site-packages/ppaquette_gym_doom/wrappers/
9 | ```
10 |
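After copying, the customized environments should be available through gym's registry. A quick sanity check (a hedged example; it assumes the `curiosity` virtualenv from the top-level README is active and uses the env id registered in `__init__.py`):

```Python
import gym
import ppaquette_gym_doom  # importing this package registers the Doom envs

env = gym.make('ppaquette/DoomMyWayHomeFixed-v0')  # the sparse scenario copied above
obs = env.reset()
print(obs.shape)  # expected: (480, 640, 3) RGB frames
env.close()
```
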
--------------------------------------------------------------------------------
/doomFiles/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Place this file in:
3 | /home/pathak/projects/unsup-rl/unsuprl/local/lib/python2.7/site-packages/ppaquette_gym_doom/__init__.py
4 | '''
5 |
6 | from gym.envs.registration import register
7 | from gym.scoreboard.registration import add_task, add_group
8 | from .package_info import USERNAME
9 | from .doom_env import DoomEnv, MetaDoomEnv
10 | from .doom_basic import DoomBasicEnv
11 | from .doom_corridor import DoomCorridorEnv
12 | from .doom_defend_center import DoomDefendCenterEnv
13 | from .doom_defend_line import DoomDefendLineEnv
14 | from .doom_health_gathering import DoomHealthGatheringEnv
15 | from .doom_my_way_home import DoomMyWayHomeEnv
16 | from .doom_predict_position import DoomPredictPositionEnv
17 | from .doom_take_cover import DoomTakeCoverEnv
18 | from .doom_deathmatch import DoomDeathmatchEnv
19 | from .doom_my_way_home_sparse import DoomMyWayHomeFixedEnv
20 | from .doom_my_way_home_verySparse import DoomMyWayHomeFixed15Env
21 |
22 | # Env registration
23 | # ==========================
24 |
25 | register(
26 | id='{}/meta-Doom-v0'.format(USERNAME),
27 | entry_point='{}_gym_doom:MetaDoomEnv'.format(USERNAME),
28 | timestep_limit=999999,
29 | reward_threshold=9000.0,
30 | kwargs={
31 | 'average_over': 3,
32 | 'passing_grade': 600,
33 | 'min_tries_for_avg': 3
34 | },
35 | )
36 |
37 | register(
38 | id='{}/DoomBasic-v0'.format(USERNAME),
39 | entry_point='{}_gym_doom:DoomBasicEnv'.format(USERNAME),
40 | timestep_limit=10000,
41 | reward_threshold=10.0,
42 | )
43 |
44 | register(
45 | id='{}/DoomCorridor-v0'.format(USERNAME),
46 | entry_point='{}_gym_doom:DoomCorridorEnv'.format(USERNAME),
47 | timestep_limit=10000,
48 | reward_threshold=1000.0,
49 | )
50 |
51 | register(
52 | id='{}/DoomDefendCenter-v0'.format(USERNAME),
53 | entry_point='{}_gym_doom:DoomDefendCenterEnv'.format(USERNAME),
54 | timestep_limit=10000,
55 | reward_threshold=10.0,
56 | )
57 |
58 | register(
59 | id='{}/DoomDefendLine-v0'.format(USERNAME),
60 | entry_point='{}_gym_doom:DoomDefendLineEnv'.format(USERNAME),
61 | timestep_limit=10000,
62 | reward_threshold=15.0,
63 | )
64 |
65 | register(
66 | id='{}/DoomHealthGathering-v0'.format(USERNAME),
67 | entry_point='{}_gym_doom:DoomHealthGatheringEnv'.format(USERNAME),
68 | timestep_limit=10000,
69 | reward_threshold=1000.0,
70 | )
71 |
72 | register(
73 | id='{}/DoomMyWayHome-v0'.format(USERNAME),
74 | entry_point='{}_gym_doom:DoomMyWayHomeEnv'.format(USERNAME),
75 | timestep_limit=10000,
76 | reward_threshold=0.5,
77 | )
78 |
79 | register(
80 | id='{}/DoomPredictPosition-v0'.format(USERNAME),
81 | entry_point='{}_gym_doom:DoomPredictPositionEnv'.format(USERNAME),
82 | timestep_limit=10000,
83 | reward_threshold=0.5,
84 | )
85 |
86 | register(
87 | id='{}/DoomTakeCover-v0'.format(USERNAME),
88 | entry_point='{}_gym_doom:DoomTakeCoverEnv'.format(USERNAME),
89 | timestep_limit=10000,
90 | reward_threshold=750.0,
91 | )
92 |
93 | register(
94 | id='{}/DoomDeathmatch-v0'.format(USERNAME),
95 | entry_point='{}_gym_doom:DoomDeathmatchEnv'.format(USERNAME),
96 | timestep_limit=10000,
97 | reward_threshold=20.0,
98 | )
99 |
100 | register(
101 | id='{}/DoomMyWayHomeFixed-v0'.format(USERNAME),
102 | entry_point='{}_gym_doom:DoomMyWayHomeFixedEnv'.format(USERNAME),
103 | timestep_limit=10000,
104 | reward_threshold=0.5,
105 | )
106 |
107 | register(
108 | id='{}/DoomMyWayHomeFixed15-v0'.format(USERNAME),
109 | entry_point='{}_gym_doom:DoomMyWayHomeFixed15Env'.format(USERNAME),
110 | timestep_limit=10000,
111 | reward_threshold=0.5,
112 | )
113 |
114 | # Scoreboard registration
115 | # ==========================
116 | add_group(
117 | id= 'doom',
118 | name= 'Doom',
119 | description= 'Doom environments based on VizDoom.'
120 | )
121 |
122 | add_task(
123 | id='{}/meta-Doom-v0'.format(USERNAME),
124 | group='doom',
125 | summary='Mission #1 to #9 - Beat all 9 Doom missions.',
126 | description="""
127 | This is a meta map that combines all 9 Doom levels.
128 |
129 | Levels:
130 | - #0 Doom Basic
131 | - #1 Doom Corridor
132 | - #2 Doom DefendCenter
133 | - #3 Doom DefendLine
134 | - #4 Doom HealthGathering
135 | - #5 Doom MyWayHome
136 | - #6 Doom PredictPosition
137 | - #7 Doom TakeCover
138 | - #8 Doom Deathmatch
139 | - #9 Doom MyWayHomeFixed (customized)
140 | - #10 Doom MyWayHomeFixed15 (customized)
141 |
142 | Goal: 9,000 points
143 | - Pass all levels
144 |
145 | Scoring:
146 | - Each level score has been standardized on a scale of 0 to 1,000
147 | - The passing score for a level is 990 (99th percentile)
148 | - A bonus of 450 (50 * 9 levels) is given if all levels are passed
149 | - The score for a level is the average of the last 3 tries
150 | """
151 | )
152 |
153 | add_task(
154 | id='{}/DoomBasic-v0'.format(USERNAME),
155 | group='doom',
156 | summary='Mission #1 - Kill a single monster using your pistol.',
157 | description="""
158 | This map is rectangular with gray walls, ceiling and floor.
159 | You are spawned in the center of the longer wall, and a red
160 | circular monster is spawned randomly on the opposite wall.
161 | You need to kill the monster (one bullet is enough).
162 |
163 | Goal: 10 points
164 | - Kill the monster in 3 secs with 1 shot
165 |
166 | Rewards:
167 | - Plus 101 pts for killing the monster
168 | - Minus 5 pts for missing a shot
169 |    - Minus 1 pt every 0.028 secs
170 |
171 | Ends when:
172 | - Monster is dead
173 | - Player is dead
174 | - Timeout (10 seconds - 350 frames)
175 |
176 | Allowed actions:
177 | - ATTACK
178 | - MOVE_RIGHT
179 | - MOVE_LEFT
180 | """
181 | )
182 |
183 | add_task(
184 | id='{}/DoomCorridor-v0'.format(USERNAME),
185 | group='doom',
186 | summary='Mission #2 - Run as fast as possible to grab a vest.',
187 | description="""
188 | This map is designed to improve your navigation. There is a vest
189 | at the end of the corridor, with 6 enemies (3 groups of 2). Your goal
190 | is to get to the vest as soon as possible, without being killed.
191 |
192 | Goal: 1,000 points
193 | - Reach the vest (or get very close to it)
194 |
195 | Rewards:
196 | - Plus distance for getting closer to the vest
197 | - Minus distance for getting further from the vest
198 | - Minus 100 pts for getting killed
199 |
200 | Ends when:
201 | - Player touches vest
202 | - Player is dead
203 |     - Timeout (1 minute - 2,100 frames)
204 |
205 | Allowed actions:
206 | - ATTACK
207 | - MOVE_RIGHT
208 | - MOVE_LEFT
209 | - MOVE_FORWARD
210 | - TURN_RIGHT
211 | - TURN_LEFT
212 | """
213 | )
214 |
215 | add_task(
216 | id='{}/DoomDefendCenter-v0'.format(USERNAME),
217 | group='doom',
218 |     summary='Mission #3 - Kill enemies coming at you from all sides.',
219 | description="""
220 | This map is designed to teach you how to kill and how to stay alive.
221 | You will also need to keep an eye on your ammunition level. You are only
222 | rewarded for kills, so figure out how to stay alive.
223 |
224 | The map is a circle with monsters. You are in the middle. Monsters will
225 | respawn with additional health when killed. Kill as many as you can
226 | before you run out of ammo.
227 |
228 | Goal: 10 points
229 | - Kill 11 monsters (you have 26 ammo)
230 |
231 | Rewards:
232 | - Plus 1 point for killing a monster
233 | - Minus 1 point for getting killed
234 |
235 | Ends when:
236 | - Player is dead
237 | - Timeout (60 seconds - 2100 frames)
238 |
239 | Allowed actions:
240 | - ATTACK
241 | - TURN_RIGHT
242 | - TURN_LEFT
243 | """
244 | )
245 |
246 | add_task(
247 | id='{}/DoomDefendLine-v0'.format(USERNAME),
248 | group='doom',
249 | summary='Mission #4 - Kill enemies on the other side of the room.',
250 | description="""
251 | This map is designed to teach you how to kill and how to stay alive.
252 | Your ammo will automatically replenish. You are only rewarded for kills,
253 | so figure out how to stay alive.
254 |
255 | The map is a rectangle with monsters on the other side. Monsters will
256 | respawn with additional health when killed. Kill as many as you can
257 |     before they kill you. This map is harder than the previous one.
258 |
259 | Goal: 15 points
260 | - Kill 16 monsters
261 |
262 | Rewards:
263 | - Plus 1 point for killing a monster
264 | - Minus 1 point for getting killed
265 |
266 | Ends when:
267 | - Player is dead
268 | - Timeout (60 seconds - 2100 frames)
269 |
270 | Allowed actions:
271 | - ATTACK
272 | - TURN_RIGHT
273 | - TURN_LEFT
274 | """
275 | )
276 |
277 | add_task(
278 | id='{}/DoomHealthGathering-v0'.format(USERNAME),
279 | group='doom',
280 |     summary='Mission #5 - Learn to grab medkits to survive as long as possible.',
281 | description="""
282 | This map is a guide on how to survive by collecting health packs.
283 | It is a rectangle with green, acidic floor which hurts the player
284 | periodically. There are also medkits spread around the map, and
285 |     additional kits will spawn at intervals.
286 |
287 | Goal: 1000 points
288 |    - Stay alive for approx. 30 secs
289 |
290 | Rewards:
291 | - Plus 1 point every 0.028 secs
292 | - Minus 100 pts for dying
293 |
294 | Ends when:
295 | - Player is dead
296 | - Timeout (60 seconds - 2,100 frames)
297 |
298 | Allowed actions:
299 | - MOVE_FORWARD
300 | - TURN_RIGHT
301 | - TURN_LEFT
302 | """
303 | )
304 |
305 | add_task(
306 | id='{}/DoomMyWayHome-v0'.format(USERNAME),
307 | group='doom',
308 |     summary='Mission #6 - Find the vest in one of the 4 rooms.',
309 | description="""
310 | This map is designed to improve navigational skills. It is a series of
311 | interconnected rooms and 1 corridor with a dead end. Each room
312 |         has a separate color. There is a green vest in one of the rooms.
313 |         The vest is always in the same room. The player must find the vest.
314 |
315 | Goal: 0.50 point
316 | - Find the vest
317 |
318 | Rewards:
319 | - Plus 1 point for finding the vest
320 | - Minus 0.0001 point every 0.028 secs
321 |
322 | Ends when:
323 | - Vest is found
324 |      - Timeout (1 minute - 2,100 frames)
325 |
326 | Allowed actions:
327 | - MOVE_FORWARD
328 | - TURN_RIGHT
329 | - TURN_LEFT
330 | """
331 | )
332 |
333 | add_task(
334 | id='{}/DoomPredictPosition-v0'.format(USERNAME),
335 | group='doom',
336 | summary='Mission #7 - Learn how to kill an enemy with a rocket launcher.',
337 | description="""
338 | This map is designed to train you on using a rocket launcher.
339 | It is a rectangular map with a monster on the opposite side. You need
340 | to use your rocket launcher to kill it. The rocket adds a delay between
341 | the moment it is fired and the moment it reaches the other side of the room.
342 | You need to predict the position of the monster to kill it.
343 |
344 | Goal: 0.5 point
345 | - Kill the monster
346 |
347 | Rewards:
348 | - Plus 1 point for killing the monster
349 | - Minus 0.0001 point every 0.028 secs
350 |
351 | Ends when:
352 | - Monster is dead
353 |    - Out of missiles (you only have one)
354 | - Timeout (20 seconds - 700 frames)
355 |
356 | Hint: Wait 1 sec for the missile launcher to load.
357 |
358 | Allowed actions:
359 | - ATTACK
360 | - TURN_RIGHT
361 | - TURN_LEFT
362 | """
363 | )
364 |
365 | add_task(
366 | id='{}/DoomTakeCover-v0'.format(USERNAME),
367 | group='doom',
368 | summary='Mission #8 - Survive as long as possible with enemies shooting at you.',
369 | description="""
370 | This map is to train you on the damage of incoming missiles.
371 | It is a rectangular map with monsters firing missiles and fireballs
372 | at you. You need to survive as long as possible.
373 |
374 | Goal: 750 points
375 | - Survive for approx. 20 seconds
376 |
377 | Rewards:
378 | - Plus 1 point every 0.028 secs
379 |
380 | Ends when:
381 | - Player is dead (1 or 2 fireballs is enough)
382 | - Timeout (60 seconds - 2,100 frames)
383 |
384 | Allowed actions:
385 | - MOVE_RIGHT
386 | - MOVE_LEFT
387 | """
388 | )
389 |
390 | add_task(
391 | id='{}/DoomDeathmatch-v0'.format(USERNAME),
392 | group='doom',
393 | summary='Mission #9 - Kill as many enemies as possible without being killed.',
394 | description="""
395 | Kill as many monsters as possible without being killed.
396 |
397 | Goal: 20 points
398 | - Kill 20 monsters
399 |
400 | Rewards:
401 | - Plus 1 point for killing a monster
402 |
403 | Ends when:
404 | - Player is dead
405 | - Timeout (3 minutes - 6,300 frames)
406 |
407 | Allowed actions:
408 | - ALL
409 | """
410 | )
411 |
412 | add_task(
413 | id='{}/DoomMyWayHomeFixed-v0'.format(USERNAME),
414 | group='doom',
415 |     summary='Mission #10 - Find the vest in one of the 4 rooms.',
416 | description="""
417 | This map is designed to improve navigational skills. It is a series of
418 | interconnected rooms and 1 corridor with a dead end. Each room
419 |         has a separate color. There is a green vest in one of the rooms.
420 |         The vest is always in the same room. The player must find the vest.
421 |         You always start from a fixed room (room no. 10 -- the farthest).
422 |
423 | Goal: 0.50 point
424 | - Find the vest
425 |
426 | Rewards:
427 | - Plus 1 point for finding the vest
428 | - Minus 0.0001 point every 0.028 secs
429 |
430 | Ends when:
431 | - Vest is found
432 |      - Timeout (1 minute - 2,100 frames)
433 |
434 | Allowed actions:
435 | - MOVE_FORWARD
436 | - TURN_RIGHT
437 | - TURN_LEFT
438 | """
439 | )
440 |
441 | add_task(
442 | id='{}/DoomMyWayHomeFixed15-v0'.format(USERNAME),
443 | group='doom',
444 |     summary='Mission #11 - Find the vest in one of the 4 rooms.',
445 | description="""
446 | This map is designed to improve navigational skills. It is a series of
447 | interconnected rooms and 1 corridor with a dead end. Each room
448 |         has a separate color. There is a green vest in one of the rooms.
449 |         The vest is always in the same room. The player must find the vest.
450 |         You always start from a fixed room (room no. 10 -- the farthest).
451 |
452 | Goal: 0.50 point
453 | - Find the vest
454 |
455 | Rewards:
456 | - Plus 1 point for finding the vest
457 | - Minus 0.0001 point every 0.028 secs
458 |
459 | Ends when:
460 | - Vest is found
461 |      - Timeout (1 minute - 2,100 frames)
462 |
463 | Allowed actions:
464 | - MOVE_FORWARD
465 | - TURN_RIGHT
466 | - TURN_LEFT
467 | """
468 | )
469 |
--------------------------------------------------------------------------------
/doomFiles/action_space.py:
--------------------------------------------------------------------------------
1 | '''
2 | Place this file in:
3 | /home/pathak/projects/unsup-rl/unsuprl/local/lib/python2.7/site-packages/ppaquette_gym_doom/wrappers/action_space.py
4 | '''
5 |
6 | import gym
7 |
8 | # Constants
9 | NUM_ACTIONS = 43
10 | ALLOWED_ACTIONS = [
11 | [0, 10, 11], # 0 - Basic
12 | [0, 10, 11, 13, 14, 15], # 1 - Corridor
13 | [0, 14, 15], # 2 - DefendCenter
14 | [0, 14, 15], # 3 - DefendLine
15 | [13, 14, 15], # 4 - HealthGathering
16 | [13, 14, 15], # 5 - MyWayHome
17 | [0, 14, 15], # 6 - PredictPosition
18 | [10, 11], # 7 - TakeCover
19 | [x for x in range(NUM_ACTIONS) if x != 33], # 8 - Deathmatch
20 | [13, 14, 15], # 9 - MyWayHomeFixed
21 | [13, 14, 15], # 10 - MyWayHomeFixed15
22 | ]
23 |
24 | __all__ = [ 'ToDiscrete', 'ToBox' ]
25 |
26 | def ToDiscrete(config):
27 | # Config can be 'minimal', 'constant-7', 'constant-17', 'full'
28 |
29 | class ToDiscreteWrapper(gym.Wrapper):
30 | """
31 | Doom wrapper to convert MultiDiscrete action space to Discrete
32 |
33 | config:
34 | - minimal - Will only use the levels' allowed actions (+ NOOP)
35 | - constant-7 - Will use the 7 minimum actions (+NOOP) to complete all levels
36 | - constant-17 - Will use the 17 most common actions (+NOOP) to complete all levels
37 | - full - Will use all available actions (+ NOOP)
38 |
39 | list of commands:
40 | - minimal:
41 | Basic: NOOP, ATTACK, MOVE_RIGHT, MOVE_LEFT
42 | Corridor: NOOP, ATTACK, MOVE_RIGHT, MOVE_LEFT, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT
43 | DefendCenter NOOP, ATTACK, TURN_RIGHT, TURN_LEFT
44 | DefendLine: NOOP, ATTACK, TURN_RIGHT, TURN_LEFT
45 | HealthGathering: NOOP, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT
46 | MyWayHome: NOOP, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT
47 | PredictPosition: NOOP, ATTACK, TURN_RIGHT, TURN_LEFT
48 | TakeCover: NOOP, MOVE_RIGHT, MOVE_LEFT
49 | Deathmatch: NOOP, ALL COMMANDS (Deltas are limited to [0,1] range and will not work properly)
50 |
51 | - constant-7: NOOP, ATTACK, MOVE_RIGHT, MOVE_LEFT, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT, SELECT_NEXT_WEAPON
52 |
53 | - constant-17: NOOP, ATTACK, JUMP, CROUCH, TURN180, RELOAD, SPEED, STRAFE, MOVE_RIGHT, MOVE_LEFT, MOVE_BACKWARD
54 | MOVE_FORWARD, TURN_RIGHT, TURN_LEFT, LOOK_UP, LOOK_DOWN, SELECT_NEXT_WEAPON, SELECT_PREV_WEAPON
55 | """
56 | def __init__(self, env):
57 | super(ToDiscreteWrapper, self).__init__(env)
58 | if config == 'minimal':
59 | allowed_actions = ALLOWED_ACTIONS[self.unwrapped.level]
60 | elif config == 'constant-7':
61 | allowed_actions = [0, 10, 11, 13, 14, 15, 31]
62 | elif config == 'constant-17':
63 | allowed_actions = [0, 2, 3, 4, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 31, 32]
64 | elif config == 'full':
65 | allowed_actions = None
66 | else:
67 | raise gym.error.Error('Invalid configuration. Valid options are "minimal", "constant-7", "constant-17", "full"')
68 | self.action_space = gym.spaces.multi_discrete.DiscreteToMultiDiscrete(self.action_space, allowed_actions)
69 | def _step(self, action):
70 | return self.env._step(self.action_space(action))
71 |
72 | return ToDiscreteWrapper
73 |
74 | def ToBox(config):
75 | # Config can be 'minimal', 'constant-7', 'constant-17', 'full'
76 |
77 | class ToBoxWrapper(gym.Wrapper):
78 | """
79 | Doom wrapper to convert MultiDiscrete action space to Box
80 |
81 | config:
82 | - minimal - Will only use the levels' allowed actions
83 | - constant-7 - Will use the 7 minimum actions to complete all levels
84 | - constant-17 - Will use the 17 most common actions to complete all levels
85 | - full - Will use all available actions
86 |
87 | list of commands:
88 | - minimal:
89 | Basic: ATTACK, MOVE_RIGHT, MOVE_LEFT
90 | Corridor: ATTACK, MOVE_RIGHT, MOVE_LEFT, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT
91 | DefendCenter ATTACK, TURN_RIGHT, TURN_LEFT
92 | DefendLine: ATTACK, TURN_RIGHT, TURN_LEFT
93 | HealthGathering: MOVE_FORWARD, TURN_RIGHT, TURN_LEFT
94 | MyWayHome: MOVE_FORWARD, TURN_RIGHT, TURN_LEFT
95 | PredictPosition: ATTACK, TURN_RIGHT, TURN_LEFT
96 | TakeCover: MOVE_RIGHT, MOVE_LEFT
97 | Deathmatch: ALL COMMANDS
98 |
99 | - constant-7: ATTACK, MOVE_RIGHT, MOVE_LEFT, MOVE_FORWARD, TURN_RIGHT, TURN_LEFT, SELECT_NEXT_WEAPON
100 |
101 | - constant-17: ATTACK, JUMP, CROUCH, TURN180, RELOAD, SPEED, STRAFE, MOVE_RIGHT, MOVE_LEFT, MOVE_BACKWARD
102 | MOVE_FORWARD, TURN_RIGHT, TURN_LEFT, LOOK_UP, LOOK_DOWN, SELECT_NEXT_WEAPON, SELECT_PREV_WEAPON
103 | """
104 | def __init__(self, env):
105 | super(ToBoxWrapper, self).__init__(env)
106 | if config == 'minimal':
107 | allowed_actions = ALLOWED_ACTIONS[self.unwrapped.level]
108 | elif config == 'constant-7':
109 | allowed_actions = [0, 10, 11, 13, 14, 15, 31]
110 | elif config == 'constant-17':
111 | allowed_actions = [0, 2, 3, 4, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 31, 32]
112 | elif config == 'full':
113 | allowed_actions = None
114 | else:
115 | raise gym.error.Error('Invalid configuration. Valid options are "minimal", "constant-7", "constant-17", "full"')
116 | self.action_space = gym.spaces.multi_discrete.BoxToMultiDiscrete(self.action_space, allowed_actions)
117 | def _step(self, action):
118 | return self.env._step(self.action_space(action))
119 |
120 | return ToBoxWrapper
121 |
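if __name__ == '__main__':
    # Usage sketch (hedged): ToDiscrete/ToBox return wrapper *classes* parameterized
    # by `config`; instantiate the returned class with a registered Doom env.
    import ppaquette_gym_doom  # noqa: F401 -- importing registers the Doom envs
    env = ToDiscrete('minimal')(gym.make('ppaquette/DoomMyWayHome-v0'))
    print(env.action_space)  # Discrete space over NOOP + the level's allowed actions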
--------------------------------------------------------------------------------
/doomFiles/doom_env.py:
--------------------------------------------------------------------------------
1 | '''
2 | Place this file in:
3 | /home/pathak/projects/unsup-rl/unsuprl/local/lib/python2.7/site-packages/ppaquette_gym_doom/doom_env.py
4 | '''
5 |
6 | import logging
7 | import os
8 | from time import sleep
9 | import multiprocessing
10 |
11 | import numpy as np
12 |
13 | import gym
14 | from gym import spaces, error
15 | from gym.utils import seeding
16 |
17 | try:
18 | import doom_py
19 | from doom_py import DoomGame, Mode, Button, GameVariable, ScreenFormat, ScreenResolution, Loader, doom_fixed_to_double
20 | from doom_py.vizdoom import ViZDoomUnexpectedExitException, ViZDoomErrorException
21 | except ImportError as e:
22 |     raise gym.error.DependencyNotInstalled(("{}. (HINT: you can install Doom dependencies "
23 |                                              "with 'pip install doom_py'.)").format(e))
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 | # Arguments:
28 | RANDOMIZE_MAPS = 80 # 0 means load default, otherwise randomly load in the id mentioned
29 | NO_MONSTERS = True # remove monster spawning
30 |
31 | # Constants
32 | NUM_ACTIONS = 43
33 | NUM_LEVELS = 9
34 | CONFIG = 0
35 | SCENARIO = 1
36 | MAP = 2
37 | DIFFICULTY = 3
38 | ACTIONS = 4
39 | MIN_SCORE = 5
40 | TARGET_SCORE = 6
41 |
42 | # Format (config, scenario, map, difficulty, actions, min, target)
43 | DOOM_SETTINGS = [
44 | ['basic.cfg', 'basic.wad', 'map01', 5, [0, 10, 11], -485, 10], # 0 - Basic
45 | ['deadly_corridor.cfg', 'deadly_corridor.wad', '', 1, [0, 10, 11, 13, 14, 15], -120, 1000], # 1 - Corridor
46 | ['defend_the_center.cfg', 'defend_the_center.wad', '', 5, [0, 14, 15], -1, 10], # 2 - DefendCenter
47 | ['defend_the_line.cfg', 'defend_the_line.wad', '', 5, [0, 14, 15], -1, 15], # 3 - DefendLine
48 | ['health_gathering.cfg', 'health_gathering.wad', 'map01', 5, [13, 14, 15], 0, 1000], # 4 - HealthGathering
49 | ['my_way_home.cfg', 'my_way_home_dense.wad', '', 5, [13, 14, 15], -0.22, 0.5], # 5 - MyWayHome
50 | ['predict_position.cfg', 'predict_position.wad', 'map01', 3, [0, 14, 15], -0.075, 0.5], # 6 - PredictPosition
51 | ['take_cover.cfg', 'take_cover.wad', 'map01', 5, [10, 11], 0, 750], # 7 - TakeCover
52 | ['deathmatch.cfg', 'deathmatch.wad', '', 5, [x for x in range(NUM_ACTIONS) if x != 33], 0, 20], # 8 - Deathmatch
53 | ['my_way_home.cfg', 'my_way_home_sparse.wad', '', 5, [13, 14, 15], -0.22, 0.5], # 9 - MyWayHomeFixed
54 | ['my_way_home.cfg', 'my_way_home_verySparse.wad', '', 5, [13, 14, 15], -0.22, 0.5], # 10 - MyWayHomeFixed15
55 | ]
56 |
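# For example, reading the table above for level 9 (the customized MyWayHomeFixed
# scenario): DOOM_SETTINGS[9][SCENARIO] is 'my_way_home_sparse.wad',
# DOOM_SETTINGS[9][ACTIONS] is [13, 14, 15] (MOVE_FORWARD, TURN_RIGHT, TURN_LEFT),
# and DOOM_SETTINGS[9][MIN_SCORE] is -0.22.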
57 | # Singleton pattern
58 | class DoomLock:
59 | class __DoomLock:
60 | def __init__(self):
61 | self.lock = multiprocessing.Lock()
62 | instance = None
63 | def __init__(self):
64 | if not DoomLock.instance:
65 | DoomLock.instance = DoomLock.__DoomLock()
66 | def get_lock(self):
67 | return DoomLock.instance.lock
68 |
69 |
70 | class DoomEnv(gym.Env):
71 | metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 35}
72 |
73 | def __init__(self, level):
74 | self.previous_level = -1
75 | self.level = level
76 | self.game = DoomGame()
77 | self.loader = Loader()
78 | self.doom_dir = os.path.dirname(os.path.abspath(__file__))
79 | self._mode = 'algo' # 'algo' or 'human'
80 | self.no_render = False # To disable double rendering in human mode
81 | self.viewer = None
82 | self.is_initialized = False # Indicates that reset() has been called
83 | self.curr_seed = 0
84 | self.lock = (DoomLock()).get_lock()
85 | self.action_space = spaces.MultiDiscrete([[0, 1]] * 38 + [[-10, 10]] * 2 + [[-100, 100]] * 3)
86 | self.allowed_actions = list(range(NUM_ACTIONS))
87 | self.screen_height = 480
88 | self.screen_width = 640
89 | self.screen_resolution = ScreenResolution.RES_640X480
90 | self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_height, self.screen_width, 3))
91 | self._seed()
92 | self._configure()
93 |
94 | def _configure(self, lock=None, **kwargs):
95 | if 'screen_resolution' in kwargs:
96 | logger.warn('Deprecated - Screen resolution must now be set using a wrapper. See documentation for details.')
97 | # Multiprocessing lock
98 | if lock is not None:
99 | self.lock = lock
100 |
101 | def _load_level(self):
102 | # Closing if is_initialized
103 | if self.is_initialized:
104 | self.is_initialized = False
105 | self.game.close()
106 | self.game = DoomGame()
107 |
108 | # Customizing level
109 | if getattr(self, '_customize_game', None) is not None and callable(self._customize_game):
110 | self.level = -1
111 | self._customize_game()
112 |
113 | else:
114 | # Loading Paths
115 | if not self.is_initialized:
116 | self.game.set_vizdoom_path(self.loader.get_vizdoom_path())
117 | self.game.set_doom_game_path(self.loader.get_freedoom_path())
118 |
119 | # Common settings
120 | self.game.load_config(os.path.join(self.doom_dir, 'assets/%s' % DOOM_SETTINGS[self.level][CONFIG]))
121 | self.game.set_doom_scenario_path(self.loader.get_scenario_path(DOOM_SETTINGS[self.level][SCENARIO]))
122 | if DOOM_SETTINGS[self.level][MAP] != '':
123 | if RANDOMIZE_MAPS > 0 and 'labyrinth' in DOOM_SETTINGS[self.level][CONFIG].lower():
124 | if 'fix' in DOOM_SETTINGS[self.level][SCENARIO].lower():
125 | # mapId = 'map%02d'%np.random.randint(1, 23)
126 | mapId = 'map%02d'%np.random.randint(4, 8)
127 | else:
128 | mapId = 'map%02d'%np.random.randint(1, RANDOMIZE_MAPS+1)
129 | print('\t=> Special Config: Randomly Loading Maps. MapID = ' + mapId)
130 | self.game.set_doom_map(mapId)
131 | else:
132 | print('\t=> Default map loaded. MapID = ' + DOOM_SETTINGS[self.level][MAP])
133 | self.game.set_doom_map(DOOM_SETTINGS[self.level][MAP])
134 | self.game.set_doom_skill(DOOM_SETTINGS[self.level][DIFFICULTY])
135 | self.allowed_actions = DOOM_SETTINGS[self.level][ACTIONS]
136 | self.game.set_screen_resolution(self.screen_resolution)
137 |
138 | self.previous_level = self.level
139 | self._closed = False
140 |
141 | # Algo mode
142 | if 'human' != self._mode:
143 | if NO_MONSTERS:
144 | print('\t=> Special Config: Monsters Removed.')
145 | self.game.add_game_args('-nomonsters 1')
147 | self.game.set_window_visible(False)
148 | self.game.set_mode(Mode.PLAYER)
149 | self.no_render = False
150 | try:
151 | with self.lock:
152 | self.game.init()
153 | except (ViZDoomUnexpectedExitException, ViZDoomErrorException):
154 | raise error.Error(
155 | 'VizDoom exited unexpectedly. This is likely caused by a missing multiprocessing lock. ' +
156 | 'To run VizDoom across multiple processes, you need to pass a lock when you configure the env ' +
157 | '[e.g. env.configure(lock=my_multiprocessing_lock)], or create and close an env ' +
158 | 'before starting your processes [e.g. env = gym.make("DoomBasic-v0"); env.close()] to cache a ' +
159 | 'singleton lock in memory.')
160 | self._start_episode()
161 | self.is_initialized = True
162 | return self.game.get_state().image_buffer.copy()
163 |
164 | # Human mode
165 | else:
166 | if NO_MONSTERS:
167 | print('\t=> Special Config: Monsters Removed.')
168 | self.game.add_game_args('-nomonsters 1')
169 | self.game.add_game_args('+freelook 1')
170 | self.game.set_window_visible(True)
171 | self.game.set_mode(Mode.SPECTATOR)
172 | self.no_render = True
173 | with self.lock:
174 | self.game.init()
175 | self._start_episode()
176 | self.is_initialized = True
177 | self._play_human_mode()
178 | return np.zeros(shape=self.observation_space.shape, dtype=np.uint8)
179 |
180 | def _start_episode(self):
181 | if self.curr_seed > 0:
182 | self.game.set_seed(self.curr_seed)
183 | self.curr_seed = 0
184 | self.game.new_episode()
185 | return
186 |
187 | def _play_human_mode(self):
188 | while not self.game.is_episode_finished():
189 | self.game.advance_action()
190 | state = self.game.get_state()
191 | total_reward = self.game.get_total_reward()
192 | info = self._get_game_variables(state.game_variables)
193 | info["TOTAL_REWARD"] = round(total_reward, 4)
194 | print('===============================')
195 | print('State: #' + str(state.number))
196 | print('Action: \t' + str(self.game.get_last_action()) + '\t (=> only allowed actions)')
197 | print('Reward: \t' + str(self.game.get_last_reward()))
198 | print('Total Reward: \t' + str(total_reward))
199 | print('Variables: \n' + str(info))
200 | sleep(0.02857) # 35 fps = 0.02857 sleep between frames
201 | print('===============================')
202 | print('Done')
203 | return
204 |
205 | def _step(self, action):
206 | if NUM_ACTIONS != len(action):
207 | logger.warn('Doom action list must contain %d items. Padding missing items with 0' % NUM_ACTIONS)
208 | old_action = action
209 | action = [0] * NUM_ACTIONS
210 | for i in range(len(old_action)):
211 | action[i] = old_action[i]
212 | # action is a list of numbers but DoomGame.make_action expects a list of ints
213 | if len(self.allowed_actions) > 0:
214 | list_action = [int(action[action_idx]) for action_idx in self.allowed_actions]
215 | else:
216 | list_action = [int(x) for x in action]
217 | try:
218 | reward = self.game.make_action(list_action)
219 | state = self.game.get_state()
220 | info = self._get_game_variables(state.game_variables)
221 | info["TOTAL_REWARD"] = round(self.game.get_total_reward(), 4)
222 |
223 | if self.game.is_episode_finished():
224 | is_finished = True
225 | return np.zeros(shape=self.observation_space.shape, dtype=np.uint8), reward, is_finished, info
226 | else:
227 | is_finished = False
228 | return state.image_buffer.copy(), reward, is_finished, info
229 |
230 | except doom_py.vizdoom.ViZDoomIsNotRunningException:
231 | return np.zeros(shape=self.observation_space.shape, dtype=np.uint8), 0, True, {}
232 |
233 | def _reset(self):
234 | if self.is_initialized and not self._closed:
235 | self._start_episode()
236 | image_buffer = self.game.get_state().image_buffer
237 | if image_buffer is None:
238 | raise error.Error(
239 | 'VizDoom incorrectly initiated. This is likely caused by a missing multiprocessing lock. ' +
240 | 'To run VizDoom across multiple processes, you need to pass a lock when you configure the env ' +
241 | '[e.g. env.configure(lock=my_multiprocessing_lock)], or create and close an env ' +
242 | 'before starting your processes [e.g. env = gym.make("DoomBasic-v0"); env.close()] to cache a ' +
243 | 'singleton lock in memory.')
244 | return image_buffer.copy()
245 | else:
246 | return self._load_level()
247 |
248 | def _render(self, mode='human', close=False):
249 | if close:
250 | if self.viewer is not None:
251 | self.viewer.close()
252 | self.viewer = None # If we don't None out this reference pyglet becomes unhappy
253 | return
254 | try:
255 | if 'human' == mode and self.no_render:
256 | return
257 | state = self.game.get_state()
258 | img = state.image_buffer
259 | # VizDoom returns None if the episode is finished, let's make it
260 | # an empty image so the recorder doesn't stop
261 | if img is None:
262 | img = np.zeros(shape=self.observation_space.shape, dtype=np.uint8)
263 | if mode == 'rgb_array':
264 | return img
265 |             elif mode == 'human':
266 | from gym.envs.classic_control import rendering
267 | if self.viewer is None:
268 | self.viewer = rendering.SimpleImageViewer()
269 | self.viewer.imshow(img)
270 | except doom_py.vizdoom.ViZDoomIsNotRunningException:
271 | pass # Doom has been closed
272 |
273 | def _close(self):
274 | # Lock required for VizDoom to close processes properly
275 | with self.lock:
276 | self.game.close()
277 |
278 | def _seed(self, seed=None):
279 | self.curr_seed = seeding.hash_seed(seed) % 2 ** 32
280 | return [self.curr_seed]
281 |
282 | def _get_game_variables(self, state_variables):
283 | info = {
284 | "LEVEL": self.level
285 | }
286 | if state_variables is None:
287 | return info
288 | info['KILLCOUNT'] = state_variables[0]
289 | info['ITEMCOUNT'] = state_variables[1]
290 | info['SECRETCOUNT'] = state_variables[2]
291 | info['FRAGCOUNT'] = state_variables[3]
292 | info['HEALTH'] = state_variables[4]
293 | info['ARMOR'] = state_variables[5]
294 | info['DEAD'] = state_variables[6]
295 | info['ON_GROUND'] = state_variables[7]
296 | info['ATTACK_READY'] = state_variables[8]
297 | info['ALTATTACK_READY'] = state_variables[9]
298 | info['SELECTED_WEAPON'] = state_variables[10]
299 | info['SELECTED_WEAPON_AMMO'] = state_variables[11]
300 | info['AMMO1'] = state_variables[12]
301 | info['AMMO2'] = state_variables[13]
302 | info['AMMO3'] = state_variables[14]
303 | info['AMMO4'] = state_variables[15]
304 | info['AMMO5'] = state_variables[16]
305 | info['AMMO6'] = state_variables[17]
306 | info['AMMO7'] = state_variables[18]
307 | info['AMMO8'] = state_variables[19]
308 | info['AMMO9'] = state_variables[20]
309 | info['AMMO0'] = state_variables[21]
310 | info['POSITION_X'] = doom_fixed_to_double(self.game.get_game_variable(GameVariable.USER1))
311 | info['POSITION_Y'] = doom_fixed_to_double(self.game.get_game_variable(GameVariable.USER2))
312 | return info
313 |
314 |
315 | class MetaDoomEnv(DoomEnv):
316 |
317 | def __init__(self, average_over=10, passing_grade=600, min_tries_for_avg=5):
318 | super(MetaDoomEnv, self).__init__(0)
319 | self.average_over = average_over
320 | self.passing_grade = passing_grade
321 | self.min_tries_for_avg = min_tries_for_avg # Need to use at least this number of tries to calc avg
322 |         self.scores = [[] for _ in range(NUM_LEVELS)]  # independent list per level
323 | self.locked_levels = [True] * NUM_LEVELS # Locking all levels but the first
324 | self.locked_levels[0] = False
325 | self.total_reward = 0
326 | self.find_new_level = False # Indicates that we need a level change
327 | self._unlock_levels()
328 |
329 | def _play_human_mode(self):
330 | while not self.game.is_episode_finished():
331 | self.game.advance_action()
332 | state = self.game.get_state()
333 | episode_reward = self.game.get_total_reward()
334 | (reward, self.total_reward) = self._calculate_reward(episode_reward, self.total_reward)
335 | info = self._get_game_variables(state.game_variables)
336 | info["SCORES"] = self.get_scores()
337 | info["TOTAL_REWARD"] = round(self.total_reward, 4)
338 | info["LOCKED_LEVELS"] = self.locked_levels
339 | print('===============================')
340 | print('State: #' + str(state.number))
341 | print('Action: \t' + str(self.game.get_last_action()) + '\t (=> only allowed actions)')
342 | print('Reward: \t' + str(reward))
343 | print('Total Reward: \t' + str(self.total_reward))
344 | print('Variables: \n' + str(info))
345 | sleep(0.02857) # 35 fps = 0.02857 sleep between frames
346 | print('===============================')
347 | print('Done')
348 | return
349 |
350 | def _get_next_level(self):
351 | # Finds the unlocked level with the lowest average
352 | averages = self.get_scores()
353 | lowest_level = 0 # Defaulting to first level
354 | lowest_score = 1001
355 | for i in range(NUM_LEVELS):
356 | if not self.locked_levels[i]:
357 | if averages[i] < lowest_score:
358 | lowest_level = i
359 | lowest_score = averages[i]
360 | return lowest_level
361 |
362 | def _unlock_levels(self):
363 | averages = self.get_scores()
364 | for i in range(NUM_LEVELS - 2, -1, -1):
365 | if self.locked_levels[i + 1] and averages[i] >= self.passing_grade:
366 | self.locked_levels[i + 1] = False
367 | return
368 |
369 | def _start_episode(self):
370 | if 0 == len(self.scores[self.level]):
371 | self.scores[self.level] = [0] * self.min_tries_for_avg
372 | else:
373 | self.scores[self.level].insert(0, 0)
374 | self.scores[self.level] = self.scores[self.level][:self.min_tries_for_avg]
375 | self.is_new_episode = True
376 | return super(MetaDoomEnv, self)._start_episode()
377 |
378 | def change_level(self, new_level=None):
379 | if new_level is not None and self.locked_levels[new_level] == False:
380 | self.find_new_level = False
381 | self.level = new_level
382 | self.reset()
383 | else:
384 | self.find_new_level = False
385 | self.level = self._get_next_level()
386 | self.reset()
387 | return
388 |
389 | def _get_standard_reward(self, episode_reward):
390 | # Returns a standardized reward for an episode (i.e. between 0 and 1,000)
391 | min_score = float(DOOM_SETTINGS[self.level][MIN_SCORE])
392 | target_score = float(DOOM_SETTINGS[self.level][TARGET_SCORE])
393 | max_score = min_score + (target_score - min_score) / 0.99 # Target is 99th percentile (Scale 0-1000)
394 | std_reward = round(1000 * (episode_reward - min_score) / (max_score - min_score), 4)
395 | std_reward = min(1000, std_reward) # Cannot be more than 1,000
396 | std_reward = max(0, std_reward) # Cannot be less than 0
397 | return std_reward
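        # Worked example (level 0, Basic): min_score = -485, target_score = 10, so
        # max_score = -485 + (10 - (-485)) / 0.99 = 15. An episode reward of 10 then
        # standardizes to 1000 * (10 - (-485)) / (15 - (-485)) = 990, which is the
        # per-level passing score quoted in the scoreboard descriptions.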
398 |
399 | def get_total_reward(self):
400 | # Returns the sum of the average of all levels
401 | total_score = 0
402 | passed_levels = 0
403 | for i in range(NUM_LEVELS):
404 | if len(self.scores[i]) > 0:
405 | level_total = 0
406 | level_count = min(len(self.scores[i]), self.average_over)
407 | for j in range(level_count):
408 | level_total += self.scores[i][j]
409 | level_average = level_total / level_count
410 | if level_average >= 990:
411 | passed_levels += 1
412 | total_score += level_average
413 | # Bonus for passing all levels (50 * num of levels)
414 | if NUM_LEVELS == passed_levels:
415 | total_score += NUM_LEVELS * 50
416 | return round(total_score, 4)
417 |
418 | def _calculate_reward(self, episode_reward, prev_total_reward):
419 | # Calculates the action reward and the new total reward
420 | std_reward = self._get_standard_reward(episode_reward)
421 | self.scores[self.level][0] = std_reward
422 | total_reward = self.get_total_reward()
423 | reward = total_reward - prev_total_reward
424 | return reward, total_reward
425 |
426 | def get_scores(self):
427 | # Returns a list with the averages per level
428 | averages = [0] * NUM_LEVELS
429 | for i in range(NUM_LEVELS):
430 | if len(self.scores[i]) > 0:
431 | level_total = 0
432 | level_count = min(len(self.scores[i]), self.average_over)
433 | for j in range(level_count):
434 | level_total += self.scores[i][j]
435 | level_average = level_total / level_count
436 | averages[i] = round(level_average, 4)
437 | return averages
438 |
439 | def _reset(self):
440 | # Reset is called on first step() after level is finished
441 | # or when change_level() is called. Returning if neither have been called to
442 | # avoid resetting the level twice
443 | if self.find_new_level:
444 | return
445 |
446 | if self.is_initialized and not self._closed and self.previous_level == self.level:
447 | self._start_episode()
448 | return self.game.get_state().image_buffer.copy()
449 | else:
450 | return self._load_level()
451 |
452 | def _step(self, action):
453 | # Changing level
454 | if self.find_new_level:
455 | self.change_level()
456 |
457 | if 'human' == self._mode:
458 | self._play_human_mode()
459 | obs = np.zeros(shape=self.observation_space.shape, dtype=np.uint8)
460 | reward = 0
461 | is_finished = True
462 | info = self._get_game_variables(None)
463 | else:
464 | obs, step_reward, is_finished, info = super(MetaDoomEnv, self)._step(action)
465 | reward, self.total_reward = self._calculate_reward(self.game.get_total_reward(), self.total_reward)
466 | # First step() after new episode returns the entire total reward
467 | # because stats_recorder resets the episode score to 0 after reset() is called
468 | if self.is_new_episode:
469 | reward = self.total_reward
470 |
471 | self.is_new_episode = False
472 | info["SCORES"] = self.get_scores()
473 | info["TOTAL_REWARD"] = round(self.total_reward, 4)
474 | info["LOCKED_LEVELS"] = self.locked_levels
475 |
476 | # Indicating new level required
477 | if is_finished:
478 | self._unlock_levels()
479 | self.find_new_level = True
480 |
481 | return obs, reward, is_finished, info
482 |
--------------------------------------------------------------------------------
/doomFiles/doom_my_way_home_sparse.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .doom_env import DoomEnv
3 |
4 | logger = logging.getLogger(__name__)
5 |
6 |
7 | class DoomMyWayHomeFixedEnv(DoomEnv):
8 | """
9 | ------------ Training Mission 10 - My Way Home Fixed ------------
10 |     Exactly the same as Mission #6, but with a fixed start from room 10 (the farthest).
11 |     This map is designed to improve navigational skills. It is a series of
12 |     interconnected rooms and 1 corridor with a dead end. Each room
13 |     has a separate color. There is a green vest in one of the rooms.
14 |     The vest is always in the same room. The player must find the vest.
15 |
16 | Allowed actions:
17 | [13] - MOVE_FORWARD - Move forward - Values 0 or 1
18 | [14] - TURN_RIGHT - Turn right - Values 0 or 1
19 | [15] - TURN_LEFT - Turn left - Values 0 or 1
20 | Note: see controls.md for details
21 |
22 | Rewards:
23 | + 1 - Finding the vest
24 | -0.0001 - 35 times per second - Find the vest quick!
25 |
26 | Goal: 0.50 point
27 | Find the vest
28 |
29 | Ends when:
30 | - Vest is found
31 |        - Timeout (1 minute - 2,100 frames)
32 |
33 | Actions:
34 | actions = [0] * 43
35 | actions[13] = 0 # MOVE_FORWARD
36 | actions[14] = 1 # TURN_RIGHT
37 | actions[15] = 0 # TURN_LEFT
38 |
39 | Configuration:
40 | After creating the env, you can call env.configure() to configure some parameters.
41 |
42 | - lock [e.g. env.configure(lock=multiprocessing_lock)]
43 |
44 | VizDoom requires a multiprocessing lock when running across multiple processes, otherwise the vizdoom instance
45 | might crash on launch
46 |
47 | You can either:
48 |
49 | 1) [Preferred] Create a multiprocessing.Lock() and pass it as a parameter to the configure() method
50 | [e.g. env.configure(lock=multiprocessing_lock)]
51 |
52 | 2) Create and close a Doom environment before running your multiprocessing routine, this will create
53 | a singleton lock that will be cached in memory, and be used by all Doom environments afterwards
54 | [e.g. env = gym.make('Doom-...'); env.close()]
55 |
56 | 3) Manually wrap calls to reset() and close() in a multiprocessing.Lock()
57 |
58 | Wrappers:
59 |
60 | You can use wrappers to further customize the environment. Wrappers need to be manually copied from the wrappers folder.
61 |
62 | theWrapperOne = WrapperOneName(init_options)
63 | theWrapperTwo = WrapperTwoName(init_options)
64 | env = gym.make('ppaquette/DoomMyWayHome-v0')
65 |         env = theWrapperTwo(theWrapperOne(env))
66 |
67 | - Observation space:
68 |
69 | You can change the resolution by using the SetResolution wrapper.
70 |
71 | wrapper = SetResolution(target_resolution)
72 | env = wrapper(env)
73 |
74 | The following are valid target_resolution that can be used:
75 |
76 | '160x120', '200x125', '200x150', '256x144', '256x160', '256x192', '320x180', '320x200',
77 | '320x240', '320x256', '400x225', '400x250', '400x300', '512x288', '512x320', '512x384',
78 | '640x360', '640x400', '640x480', '800x450', '800x500', '800x600', '1024x576', '1024x640',
79 | '1024x768', '1280x720', '1280x800', '1280x960', '1280x1024', '1400x787', '1400x875',
80 | '1400x1050', '1600x900', '1600x1000', '1600x1200', '1920x1080'
81 |
82 | - Action space:
83 |
84 | You can change the action space by using the ToDiscrete or ToBox wrapper
85 |
86 | wrapper = ToBox(config_options)
87 | env = wrapper(env)
88 |
89 | The following are valid config options (for both ToDiscrete and ToBox)
90 |
91 | - minimal - Only the level's allowed actions (and NOOP for discrete)
92 | - constant-7 - 7 minimum actions required to complete all levels (and NOOP for discrete)
93 | - constant-17 - 17 most common actions required to complete all levels (and NOOP for discrete)
94 | - full - All available actions (and NOOP for discrete)
95 |
96 | Note: Discrete action spaces only allow one action at a time, Box action spaces support simultaneous actions
97 |
98 | - Control:
99 |
100 | You can play the game manually with the SetPlayingMode wrapper.
101 |
102 | wrapper = SetPlayingMode('human')
103 | env = wrapper(env)
104 |
105 | Valid options are 'human' or 'algo' (default)
106 |
107 | -----------------------------------------------------
108 | """
109 | def __init__(self):
110 | super(DoomMyWayHomeFixedEnv, self).__init__(9)
111 |
--------------------------------------------------------------------------------
/doomFiles/doom_my_way_home_verySparse.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .doom_env import DoomEnv
3 |
4 | logger = logging.getLogger(__name__)
5 |
6 |
7 | class DoomMyWayHomeFixed15Env(DoomEnv):
8 | """
9 | ------------ Training Mission 11 - My Way Home Fixed15 ------------
10 |     Exactly the same as Mission #6, but with a fixed start from room 10 (the farthest).
11 |     This map is designed to improve navigational skills. It is a series of
12 |     interconnected rooms and 1 corridor with a dead end. Each room
13 |     has a separate color. There is a green vest in one of the rooms.
14 |     The vest is always in the same room. The player must find the vest.
15 |
16 | Allowed actions:
17 | [13] - MOVE_FORWARD - Move forward - Values 0 or 1
18 | [14] - TURN_RIGHT - Turn right - Values 0 or 1
19 | [15] - TURN_LEFT - Turn left - Values 0 or 1
20 | Note: see controls.md for details
21 |
22 | Rewards:
23 | + 1 - Finding the vest
24 | -0.0001 - 35 times per second - Find the vest quick!
25 |
26 | Goal: 0.50 point
27 | Find the vest
28 |
29 | Ends when:
30 | - Vest is found
31 |        - Timeout (1 minute - 2,100 frames)
32 |
33 | Actions:
34 | actions = [0] * 43
35 | actions[13] = 0 # MOVE_FORWARD
36 | actions[14] = 1 # TURN_RIGHT
37 | actions[15] = 0 # TURN_LEFT
38 |
39 | Configuration:
40 | After creating the env, you can call env.configure() to configure some parameters.
41 |
42 | - lock [e.g. env.configure(lock=multiprocessing_lock)]
43 |
44 | VizDoom requires a multiprocessing lock when running across multiple processes, otherwise the vizdoom instance
45 | might crash on launch
46 |
47 | You can either:
48 |
49 | 1) [Preferred] Create a multiprocessing.Lock() and pass it as a parameter to the configure() method
50 | [e.g. env.configure(lock=multiprocessing_lock)]
51 |
52 | 2) Create and close a Doom environment before running your multiprocessing routine, this will create
53 | a singleton lock that will be cached in memory, and be used by all Doom environments afterwards
54 | [e.g. env = gym.make('Doom-...'); env.close()]
55 |
56 | 3) Manually wrap calls to reset() and close() in a multiprocessing.Lock()
57 |
58 | Wrappers:
59 |
60 | You can use wrappers to further customize the environment. Wrappers need to be manually copied from the wrappers folder.
61 |
62 | theWrapperOne = WrapperOneName(init_options)
63 | theWrapperTwo = WrapperTwoName(init_options)
64 | env = gym.make('ppaquette/DoomMyWayHome-v0')
65 |         env = theWrapperTwo(theWrapperOne(env))
66 |
67 | - Observation space:
68 |
69 | You can change the resolution by using the SetResolution wrapper.
70 |
71 | wrapper = SetResolution(target_resolution)
72 | env = wrapper(env)
73 |
74 | The following are valid target_resolution that can be used:
75 |
76 | '160x120', '200x125', '200x150', '256x144', '256x160', '256x192', '320x180', '320x200',
77 | '320x240', '320x256', '400x225', '400x250', '400x300', '512x288', '512x320', '512x384',
78 | '640x360', '640x400', '640x480', '800x450', '800x500', '800x600', '1024x576', '1024x640',
79 | '1024x768', '1280x720', '1280x800', '1280x960', '1280x1024', '1400x787', '1400x875',
80 | '1400x1050', '1600x900', '1600x1000', '1600x1200', '1920x1080'
81 |
82 | - Action space:
83 |
84 | You can change the action space by using the ToDiscrete or ToBox wrapper
85 |
86 | wrapper = ToBox(config_options)
87 | env = wrapper(env)
88 |
89 | The following are valid config options (for both ToDiscrete and ToBox)
90 |
91 | - minimal - Only the level's allowed actions (and NOOP for discrete)
92 | - constant-7 - 7 minimum actions required to complete all levels (and NOOP for discrete)
93 | - constant-17 - 17 most common actions required to complete all levels (and NOOP for discrete)
94 | - full - All available actions (and NOOP for discrete)
95 |
96 | Note: Discrete action spaces only allow one action at a time, Box action spaces support simultaneous actions
97 |
98 | - Control:
99 |
100 | You can play the game manually with the SetPlayingMode wrapper.
101 |
102 | wrapper = SetPlayingMode('human')
103 | env = wrapper(env)
104 |
105 | Valid options are 'human' or 'algo' (default)
106 |
107 | -----------------------------------------------------
108 | """
109 | def __init__(self):
110 | super(DoomMyWayHomeFixed15Env, self).__init__(10)
111 |
--------------------------------------------------------------------------------
/doomFiles/wads/my_way_home_dense.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/doomFiles/wads/my_way_home_dense.wad
--------------------------------------------------------------------------------
/doomFiles/wads/my_way_home_sparse.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/doomFiles/wads/my_way_home_sparse.wad
--------------------------------------------------------------------------------
/doomFiles/wads/my_way_home_verySparse.wad:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/doomFiles/wads/my_way_home_verySparse.wad
--------------------------------------------------------------------------------
/images/mario1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/images/mario1.gif
--------------------------------------------------------------------------------
/images/mario2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/images/mario2.gif
--------------------------------------------------------------------------------
/images/vizdoom.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pathak22/noreward-rl/3e220c2177fc253916f12d980957fc40579d577a/images/vizdoom.gif
--------------------------------------------------------------------------------
/models/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/" && pwd )"
4 | cd $DIR
5 |
6 | FILE=models.tar.gz
7 | URL=https://people.eecs.berkeley.edu/~pathak/noreward-rl/resources/$FILE
8 | CHECKSUM=26bdf54e9562e23750ebc2ef503204b1
9 |
10 | if [ ! -f $FILE ]; then
11 | echo "Downloading the curiosity-driven RL trained models (6MB)..."
12 | wget $URL -O $FILE
13 | echo "Unzipping..."
14 | tar zxvf $FILE
15 | mv models/* .
16 | rm -rf models
17 | echo "Downloading Done."
18 | else
19 | echo "File already exists. Checking md5..."
20 | fi
21 |
22 | os=`uname -s`
23 | if [ "$os" = "Linux" ]; then
24 | checksum=`md5sum $FILE | awk '{ print $1 }'`
25 | elif [ "$os" = "Darwin" ]; then
26 | checksum=`cat $FILE | md5`
27 | elif [ "$os" = "SunOS" ]; then
28 | checksum=`digest -a md5 -v $FILE | awk '{ print $4 }'`
29 | fi
30 | if [ "$checksum" = "$CHECKSUM" ]; then
31 | echo "Checksum is correct. File was correctly downloaded."
32 | exit 0
33 | else
34 | echo "Checksum is incorrect. DELETE and download again."
35 | fi
36 |
--------------------------------------------------------------------------------
/src/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 | unsuprl/
85 |
86 | # Spyder project settings
87 | .spyderproject
88 |
89 | # Rope project settings
90 | .ropeproject
91 | tmp/
92 |
93 | # vizdoom cache
94 | vizdoom.ini
95 |
--------------------------------------------------------------------------------
/src/a3c.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from collections import namedtuple
3 | import numpy as np
4 | import tensorflow as tf
5 | from model import LSTMPolicy, StateActionPredictor, StatePredictor
6 | import six.moves.queue as queue
7 | import scipy.signal
8 | import threading
9 | import distutils.version
10 | from constants import constants
11 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0')
12 |
13 | def discount(x, gamma):
14 | """
15 | x = [r1, r2, r3, ..., rN]
16 | returns [r1 + r2*gamma + r3*gamma^2 + ...,
17 | r2 + r3*gamma + r4*gamma^2 + ...,
18 | r3 + r4*gamma + r5*gamma^2 + ...,
19 | ..., ..., rN]
20 | """
21 | return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
22 |
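# For example (a quick check of the formula above):
#   discount(np.asarray([1., 1., 1.]), 0.5)  ->  [1.75, 1.5, 1.0]
#   (r3 = 1, r2 = 1 + 0.5*1 = 1.5, r1 = 1 + 0.5*1 + 0.25*1 = 1.75)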
23 | def process_rollout(rollout, gamma, lambda_=1.0, clip=False):
24 | """
25 | Given a rollout, compute its returns and the advantage.
26 | """
27 | # collecting transitions
28 | if rollout.unsup:
29 | batch_si = np.asarray(rollout.states + [rollout.end_state])
30 | else:
31 | batch_si = np.asarray(rollout.states)
32 | batch_a = np.asarray(rollout.actions)
33 |
34 | # collecting target for value network
35 | # V_t <-> r_t + gamma*r_{t+1} + ... + gamma^n*r_{t+n} + gamma^{n+1}*V_{n+1}
36 | rewards_plus_v = np.asarray(rollout.rewards + [rollout.r]) # bootstrapping
37 | if rollout.unsup:
38 | rewards_plus_v += np.asarray(rollout.bonuses + [0])
39 | if clip:
40 | rewards_plus_v[:-1] = np.clip(rewards_plus_v[:-1], -constants['REWARD_CLIP'], constants['REWARD_CLIP'])
41 | batch_r = discount(rewards_plus_v, gamma)[:-1] # value network target
42 |
43 | # collecting target for policy network
44 | rewards = np.asarray(rollout.rewards)
45 | if rollout.unsup:
46 | rewards += np.asarray(rollout.bonuses)
47 | if clip:
48 | rewards = np.clip(rewards, -constants['REWARD_CLIP'], constants['REWARD_CLIP'])
49 | vpred_t = np.asarray(rollout.values + [rollout.r])
50 | # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
51 | # Eq (10): delta_t = Rt + gamma*V_{t+1} - V_t
52 | # Eq (16): batch_adv_t = delta_t + (gamma*lambda)*delta_{t+1} + (gamma*lambda)^2*delta_{t+2} + ...
53 | delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
54 | batch_adv = discount(delta_t, gamma * lambda_)
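# i.e. the advantage satisfies the recursion A_t = delta_t + (gamma*lambda)*A_{t+1},
# which is exactly what reusing discount() on the per-step deltas computes
# (with the defaults in constants.py: gamma=0.99, lambda=1.0).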
55 |
56 | features = rollout.features[0]
57 |
58 | return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal, features)
59 |
60 | Batch = namedtuple("Batch", ["si", "a", "adv", "r", "terminal", "features"])
61 |
62 | class PartialRollout(object):
63 | """
64 | A piece of a complete rollout. We run our agent, and process its experience
65 | once it has collected enough steps.
66 | """
67 | def __init__(self, unsup=False):
68 | self.states = []
69 | self.actions = []
70 | self.rewards = []
71 | self.values = []
72 | self.r = 0.0
73 | self.terminal = False
74 | self.features = []
75 | self.unsup = unsup
76 | if self.unsup:
77 | self.bonuses = []
78 | self.end_state = None
79 |
80 |
81 | def add(self, state, action, reward, value, terminal, features,
82 | bonus=None, end_state=None):
83 | self.states += [state]
84 | self.actions += [action]
85 | self.rewards += [reward]
86 | self.values += [value]
87 | self.terminal = terminal
88 | self.features += [features]
89 | if self.unsup:
90 | self.bonuses += [bonus]
91 | self.end_state = end_state
92 |
93 | def extend(self, other):
94 | assert not self.terminal
95 | self.states.extend(other.states)
96 | self.actions.extend(other.actions)
97 | self.rewards.extend(other.rewards)
98 | self.values.extend(other.values)
99 | self.r = other.r
100 | self.terminal = other.terminal
101 | self.features.extend(other.features)
102 | if self.unsup:
103 | self.bonuses.extend(other.bonuses)
104 | self.end_state = other.end_state
105 |
106 | class RunnerThread(threading.Thread):
107 | """
108 | One of the key distinctions between a normal environment and a universe environment
109 | is that a universe environment is _real time_. This means that there should be a thread
110 | that would constantly interact with the environment and tell it what to do. This thread is here.
111 | """
112 | def __init__(self, env, policy, num_local_steps, visualise, predictor, envWrap,
113 | noReward):
114 | threading.Thread.__init__(self)
115 | self.queue = queue.Queue(5) # ideally, should be 1. Mostly doesn't matter in our case.
116 | self.num_local_steps = num_local_steps
117 | self.env = env
118 | self.last_features = None
119 | self.policy = policy
120 | self.daemon = True
121 | self.sess = None
122 | self.summary_writer = None
123 | self.visualise = visualise
124 | self.predictor = predictor
125 | self.envWrap = envWrap
126 | self.noReward = noReward
127 |
128 | def start_runner(self, sess, summary_writer):
129 | self.sess = sess
130 | self.summary_writer = summary_writer
131 | self.start()
132 |
133 | def run(self):
134 | with self.sess.as_default():
135 | self._run()
136 |
137 | def _run(self):
138 | rollout_provider = env_runner(self.env, self.policy, self.num_local_steps,
139 | self.summary_writer, self.visualise, self.predictor,
140 | self.envWrap, self.noReward)
141 | while True:
142 | # the timeout variable exists because apparently, if one worker dies, the other workers
143 | # won't die with it, unless the timeout is set to some large number. This is an empirical
144 | # observation.
145 |
146 | self.queue.put(next(rollout_provider), timeout=600.0)
147 |
148 |
149 | def env_runner(env, policy, num_local_steps, summary_writer, render, predictor,
150 | envWrap, noReward):
151 | """
152 | The logic of the thread runner. In brief, it constantly keeps running
153 | the policy, and once the rollout reaches a certain length, the thread
154 | runner puts the rollout on the queue.
155 | """
156 | last_state = env.reset()
157 | last_features = policy.get_initial_features() # reset lstm memory
158 | length = 0
159 | rewards = 0
160 | values = 0
161 | if predictor is not None:
162 | ep_bonus = 0
163 | life_bonus = 0
164 |
165 | while True:
166 | terminal_end = False
167 | rollout = PartialRollout(predictor is not None)
168 |
169 | for _ in range(num_local_steps):
170 | # run policy
171 | fetched = policy.act(last_state, *last_features)
172 | action, value_, features = fetched[0], fetched[1], fetched[2:]
173 |
174 | # run environment: get action_index from sampled one-hot 'action'
175 | stepAct = action.argmax()
176 | state, reward, terminal, info = env.step(stepAct)
177 | if noReward:
178 | reward = 0.
179 | if render:
180 | env.render()
181 |
182 | curr_tuple = [last_state, action, reward, value_, terminal, last_features]
183 | if predictor is not None:
184 | bonus = predictor.pred_bonus(last_state, state, action)
185 | curr_tuple += [bonus, state]
186 | life_bonus += bonus
187 | ep_bonus += bonus
188 |
189 | # collect the experience
190 | rollout.add(*curr_tuple)
191 | rewards += reward
192 | length += 1
193 | values += value_[0]
194 |
195 | last_state = state
196 | last_features = features
197 |
198 | timestep_limit = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
199 | if timestep_limit is None: timestep_limit = env.spec.timestep_limit
200 | if terminal or length >= timestep_limit:
201 | # prints summary of each life if envWrap==True else each game
202 | if predictor is not None:
203 | print("Episode finished. Sum of shaped rewards: %.2f. Length: %d. Bonus: %.4f." % (rewards, length, life_bonus))
204 | life_bonus = 0
205 | else:
206 | print("Episode finished. Sum of shaped rewards: %.2f. Length: %d." % (rewards, length))
207 | if 'distance' in info: print('Mario Distance Covered:', info['distance'])
208 | length = 0
209 | rewards = 0
210 | terminal_end = True
211 | last_features = policy.get_initial_features() # reset lstm memory
212 | # TODO: don't reset when gym timestep_limit increases, bootstrap -- doesn't matter for atari?
213 | # reset only if it hasn't already been reset
214 | if length >= timestep_limit or not env.metadata.get('semantics.autoreset'):
215 | last_state = env.reset()
216 |
217 | if info:
218 | # summarize full game including all lives (even if envWrap=True)
219 | summary = tf.Summary()
220 | for k, v in info.items():
221 | summary.value.add(tag=k, simple_value=float(v))
222 | if terminal:
223 | summary.value.add(tag='global/episode_value', simple_value=float(values))
224 | values = 0
225 | if predictor is not None:
226 | summary.value.add(tag='global/episode_bonus', simple_value=float(ep_bonus))
227 | ep_bonus = 0
228 | summary_writer.add_summary(summary, policy.global_step.eval())
229 | summary_writer.flush()
230 |
231 | if terminal_end:
232 | break
233 |
234 | if not terminal_end:
235 | rollout.r = policy.value(last_state, *last_features)
236 |
237 | # once we have enough experience, yield it, and have the RunnerThread place it on a queue
238 | yield rollout
239 |
240 |
241 | class A3C(object):
242 | def __init__(self, env, task, visualise, unsupType, envWrap=False, designHead='universe', noReward=False):
243 | """
244 | An implementation of the A3C algorithm that is reasonably well-tuned for the VNC environments.
245 | Below, we will have a modest amount of complexity due to the way TensorFlow handles data parallelism.
246 | But overall, we'll define the model, specify its inputs, and describe how the policy gradients step
247 | should be computed.
248 | """
249 | self.task = task
250 | self.unsup = unsupType is not None
251 | self.envWrap = envWrap
252 | self.env = env
253 |
254 | predictor = None
255 | numaction = env.action_space.n
256 | worker_device = "/job:worker/task:{}/cpu:0".format(task)
257 |
258 | with tf.device(tf.train.replica_device_setter(1, worker_device=worker_device)):
259 | with tf.variable_scope("global"):
260 | self.network = LSTMPolicy(env.observation_space.shape, numaction, designHead)
261 | self.global_step = tf.get_variable("global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32),
262 | trainable=False)
263 | if self.unsup:
264 | with tf.variable_scope("predictor"):
265 | if 'state' in unsupType:
266 | self.ap_network = StatePredictor(env.observation_space.shape, numaction, designHead, unsupType)
267 | else:
268 | self.ap_network = StateActionPredictor(env.observation_space.shape, numaction, designHead)
269 |
270 | with tf.device(worker_device):
271 | with tf.variable_scope("local"):
272 | self.local_network = pi = LSTMPolicy(env.observation_space.shape, numaction, designHead)
273 | pi.global_step = self.global_step
274 | if self.unsup:
275 | with tf.variable_scope("predictor"):
276 | if 'state' in unsupType:
277 | self.local_ap_network = predictor = StatePredictor(env.observation_space.shape, numaction, designHead, unsupType)
278 | else:
279 | self.local_ap_network = predictor = StateActionPredictor(env.observation_space.shape, numaction, designHead)
280 |
281 | # Computing a3c loss: https://arxiv.org/abs/1506.02438
282 | self.ac = tf.placeholder(tf.float32, [None, numaction], name="ac")
283 | self.adv = tf.placeholder(tf.float32, [None], name="adv")
284 | self.r = tf.placeholder(tf.float32, [None], name="r")
285 | log_prob_tf = tf.nn.log_softmax(pi.logits)
286 | prob_tf = tf.nn.softmax(pi.logits)
287 | # 1) the "policy gradients" loss: its derivative is precisely the policy gradient
288 | # notice that self.ac is a placeholder that is provided externally.
289 | # adv will contain the advantages, as calculated in process_rollout
290 | pi_loss = - tf.reduce_mean(tf.reduce_sum(log_prob_tf * self.ac, 1) * self.adv) # Eq (19)
291 | # 2) loss of value function: l2_loss = (x-y)^2/2
292 | vf_loss = 0.5 * tf.reduce_mean(tf.square(pi.vf - self.r)) # Eq (28)
293 | # 3) entropy to ensure randomness
294 | entropy = - tf.reduce_mean(tf.reduce_sum(prob_tf * log_prob_tf, 1))
295 | # final a3c loss: lr of critic is half of actor
296 | self.loss = pi_loss + 0.5 * vf_loss - entropy * constants['ENTROPY_BETA']
297 |
298 | # compute gradients
299 | grads = tf.gradients(self.loss * 20.0, pi.var_list) # batchsize=20. Factored out to make hyperparams not depend on it.
300 |
301 | # computing predictor loss
302 | if self.unsup:
303 | if 'state' in unsupType:
304 | self.predloss = constants['PREDICTION_LR_SCALE'] * predictor.forwardloss
305 | else:
306 | self.predloss = constants['PREDICTION_LR_SCALE'] * (predictor.invloss * (1-constants['FORWARD_LOSS_WT']) +
307 | predictor.forwardloss * constants['FORWARD_LOSS_WT'])
308 | predgrads = tf.gradients(self.predloss * 20.0, predictor.var_list) # batchsize=20. Factored out to make hyperparams not depend on it.
309 |
310 | # do not backprop to policy
311 | if constants['POLICY_NO_BACKPROP_STEPS'] > 0:
312 | grads = [tf.scalar_mul(tf.to_float(tf.greater(self.global_step, constants['POLICY_NO_BACKPROP_STEPS'])), grads_i)
313 | for grads_i in grads]
314 |
315 |
316 | self.runner = RunnerThread(env, pi, constants['ROLLOUT_MAXLEN'], visualise,
317 | predictor, envWrap, noReward)
318 |
319 | # storing summaries
320 | bs = tf.to_float(tf.shape(pi.x)[0])
321 | if use_tf12_api:
322 | tf.summary.scalar("model/policy_loss", pi_loss)
323 | tf.summary.scalar("model/value_loss", vf_loss)
324 | tf.summary.scalar("model/entropy", entropy)
325 | tf.summary.image("model/state", pi.x) # max_outputs=10
326 | tf.summary.scalar("model/grad_global_norm", tf.global_norm(grads))
327 | tf.summary.scalar("model/var_global_norm", tf.global_norm(pi.var_list))
328 | if self.unsup:
329 | tf.summary.scalar("model/predloss", self.predloss)
330 | if 'action' in unsupType:
331 | tf.summary.scalar("model/inv_loss", predictor.invloss)
332 | tf.summary.scalar("model/forward_loss", predictor.forwardloss)
333 | tf.summary.scalar("model/predgrad_global_norm", tf.global_norm(predgrads))
334 | tf.summary.scalar("model/predvar_global_norm", tf.global_norm(predictor.var_list))
335 | self.summary_op = tf.summary.merge_all()
336 | else:
337 | tf.scalar_summary("model/policy_loss", pi_loss)
338 | tf.scalar_summary("model/value_loss", vf_loss)
339 | tf.scalar_summary("model/entropy", entropy)
340 | tf.image_summary("model/state", pi.x)
341 | tf.scalar_summary("model/grad_global_norm", tf.global_norm(grads))
342 | tf.scalar_summary("model/var_global_norm", tf.global_norm(pi.var_list))
343 | if self.unsup:
344 | tf.scalar_summary("model/predloss", self.predloss)
345 | if 'action' in unsupType:
346 | tf.scalar_summary("model/inv_loss", predictor.invloss)
347 | tf.scalar_summary("model/forward_loss", predictor.forwardloss)
348 | tf.scalar_summary("model/predgrad_global_norm", tf.global_norm(predgrads))
349 | tf.scalar_summary("model/predvar_global_norm", tf.global_norm(predictor.var_list))
350 | self.summary_op = tf.merge_all_summaries()
351 |
352 | # clip gradients
353 | grads, _ = tf.clip_by_global_norm(grads, constants['GRAD_NORM_CLIP'])
354 | grads_and_vars = list(zip(grads, self.network.var_list))
355 | if self.unsup:
356 | predgrads, _ = tf.clip_by_global_norm(predgrads, constants['GRAD_NORM_CLIP'])
357 | pred_grads_and_vars = list(zip(predgrads, self.ap_network.var_list))
358 | grads_and_vars = grads_and_vars + pred_grads_and_vars
359 |
360 | # update global step by batch size
361 | inc_step = self.global_step.assign_add(tf.shape(pi.x)[0])
362 |
363 | # each worker has a different set of adam optimizer parameters
364 | # TODO: make optimizer global shared, if needed
365 | print("Optimizer: ADAM with lr: %f" % (constants['LEARNING_RATE']))
366 | print("Input observation shape: ",env.observation_space.shape)
367 | opt = tf.train.AdamOptimizer(constants['LEARNING_RATE'])
368 | self.train_op = tf.group(opt.apply_gradients(grads_and_vars), inc_step)
369 |
370 | # copy weights from the parameter server to the local model
371 | sync_var_list = [v1.assign(v2) for v1, v2 in zip(pi.var_list, self.network.var_list)]
372 | if self.unsup:
373 | sync_var_list += [v1.assign(v2) for v1, v2 in zip(predictor.var_list, self.ap_network.var_list)]
374 | self.sync = tf.group(*sync_var_list)
375 |
376 | # initialize extras
377 | self.summary_writer = None
378 | self.local_steps = 0
379 |
380 | def start(self, sess, summary_writer):
381 | self.runner.start_runner(sess, summary_writer)
382 | self.summary_writer = summary_writer
383 |
384 | def pull_batch_from_queue(self):
385 | """
386 | Take a rollout from the queue of the thread runner.
387 | """
388 | # get top rollout from queue (FIFO)
389 | rollout = self.runner.queue.get(timeout=600.0)
390 | while not rollout.terminal:
391 | try:
392 | # Now, get the remaining *available* rollouts from the queue and append them
393 | # to the one above. If queue.Queue(5) is full and everything is
394 | # super fast (not usually the case), then all 5 will be returned before the
395 | # exception is raised. In that case, the effective batch_size would become
396 | # constants['ROLLOUT_MAXLEN'] * queue_maxlen(5). But this is almost never the
397 | # case, i.e., collecting a rollout of length=ROLLOUT_MAXLEN takes more time
398 | # than get(). So there are usually no more available rollouts in the queue and
399 | # the exception is raised right away. Hence, one should ideally keep queue_maxlen = 1.
400 | # Also note that the next rollout generation is invoked automatically because
401 | # it's a thread that keeps running via the 'yield' at the end of the generation process.
402 | # To conclude, effective batch_size = constants['ROLLOUT_MAXLEN'].
403 | rollout.extend(self.runner.queue.get_nowait())
404 | except queue.Empty:
405 | break
406 | return rollout
407 |
408 | def process(self, sess):
409 | """
410 | Process grabs a rollout that's been produced by the thread runner,
411 | and updates the parameters. The update is then sent to the parameter
412 | server.
413 | """
414 | sess.run(self.sync) # copy weights from shared to local
415 | rollout = self.pull_batch_from_queue()
416 | batch = process_rollout(rollout, gamma=constants['GAMMA'], lambda_=constants['LAMBDA'], clip=self.envWrap)
417 |
418 | should_compute_summary = self.task == 0 and self.local_steps % 11 == 0
419 |
420 | if should_compute_summary:
421 | fetches = [self.summary_op, self.train_op, self.global_step]
422 | else:
423 | fetches = [self.train_op, self.global_step]
424 |
425 | feed_dict = {
426 | self.local_network.x: batch.si,
427 | self.ac: batch.a,
428 | self.adv: batch.adv,
429 | self.r: batch.r,
430 | self.local_network.state_in[0]: batch.features[0],
431 | self.local_network.state_in[1]: batch.features[1],
432 | }
433 | if self.unsup:
434 | feed_dict[self.local_network.x] = batch.si[:-1]
435 | feed_dict[self.local_ap_network.s1] = batch.si[:-1]
436 | feed_dict[self.local_ap_network.s2] = batch.si[1:]
437 | feed_dict[self.local_ap_network.asample] = batch.a
438 |
439 | fetched = sess.run(fetches, feed_dict=feed_dict)
440 | if batch.terminal:
441 | print("Global Step Counter: %d"%fetched[-1])
442 |
443 | if should_compute_summary:
444 | self.summary_writer.add_summary(tf.Summary.FromString(fetched[0]), fetched[-1])
445 | self.summary_writer.flush()
446 | self.local_steps += 1
447 |
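# Roughly how a worker drives this class (a simplified sketch; the actual distributed
# session setup lives in src/worker.py, and 'action' is assumed here to be a valid unsupType):
#
#     trainer = A3C(env, task=0, visualise=False, unsupType='action', envWrap=True)
#     # ... create the tf.Session / supervisor and initialize variables, then:
#     trainer.start(sess, summary_writer)
#     while training:
#         trainer.process(sess)  # pull one rollout batch and apply a gradient update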
--------------------------------------------------------------------------------
/src/constants.py:
--------------------------------------------------------------------------------
1 | constants = {
2 | 'GAMMA': 0.99, # discount factor for rewards
3 | 'LAMBDA': 1.0, # lambda of Generalized Advantage Estimation: https://arxiv.org/abs/1506.02438
4 | 'ENTROPY_BETA': 0.01, # entropy regularization constant.
5 | 'ROLLOUT_MAXLEN': 20, # 20 represents the number of 'local steps': the number of timesteps
6 | # we run the policy before we update the parameters.
7 | # The larger the number of local steps, the lower the variance in our policy gradient estimate
8 | # on the one hand; but on the other hand, we get less frequent parameter updates, which
9 | # slows down learning. In this code, we found that making local steps much
10 | # smaller than 20 makes the algorithm more difficult to tune and to get to work.
11 | 'GRAD_NORM_CLIP': 40.0, # gradient norm clipping
12 | 'REWARD_CLIP': 1.0, # reward value clipping in [-x,x]
13 | 'MAX_GLOBAL_STEPS': 100000000, # total steps taken across all workers
14 | 'LEARNING_RATE': 1e-4, # learning rate for adam
15 |
16 | 'PREDICTION_BETA': 0.01, # weight of prediction bonus
17 | # set 0.5 for unsup=state
18 | 'PREDICTION_LR_SCALE': 10.0, # scale lr of predictor wrt to policy network
19 | # set 30-50 for unsup=state
20 | 'FORWARD_LOSS_WT': 0.2, # should be between [0,1]
21 | # predloss = ( (1-FORWARD_LOSS_WT) * inv_loss + FORWARD_LOSS_WT * forward_loss) * PREDICTION_LR_SCALE
22 | 'POLICY_NO_BACKPROP_STEPS': 0, # number of global steps after which we start backpropagating to policy
23 | }
24 |
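# With the defaults above, the ICM ('action') prediction loss in src/a3c.py works out to
#   PREDICTION_LR_SCALE * ((1 - FORWARD_LOSS_WT) * inv_loss + FORWARD_LOSS_WT * forward_loss)
#   = 10.0 * (0.8 * inv_loss + 0.2 * forward_loss)
#   = 8.0 * inv_loss + 2.0 * forward_loss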
--------------------------------------------------------------------------------
/src/demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import print_function
3 | import tensorflow as tf
4 | import gym
5 | import numpy as np
6 | import argparse
7 | import logging
8 | from envs import create_env
9 | logger = logging.getLogger(__name__)
10 | logger.setLevel(logging.INFO)
11 |
12 |
13 | def inference(args):
14 | """
15 | Restores policy weights and runs inference.
16 | """
17 | # define environment
18 | env = create_env(args.env_id, client_id='0', remotes=None, envWrap=True,
19 | acRepeat=1, record=args.record, outdir=args.outdir)
20 | numaction = env.action_space.n
21 |
22 | with tf.device("/cpu:0"):
23 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
24 | with tf.Session(config=config) as sess:
25 | logger.info("Restoring trainable global parameters.")
26 | saver = tf.train.import_meta_graph(args.ckpt+'.meta')
27 | saver.restore(sess, args.ckpt)
28 |
29 | probs = tf.get_collection("probs")[0]
30 | sample = tf.get_collection("sample")[0]
31 | vf = tf.get_collection("vf")[0]
32 | state_out_0 = tf.get_collection("state_out_0")[0]
33 | state_out_1 = tf.get_collection("state_out_1")[0]
34 |
35 | last_state = env.reset()
36 | if args.render or args.record:
37 | env.render()
38 | last_features = np.zeros((1, 256), np.float32); last_features = [last_features, last_features]
39 | length = 0
40 | rewards = 0
41 | mario_distances = np.zeros((args.num_episodes,))
42 | for i in range(args.num_episodes):
43 | print("Starting episode %d" % (i + 1))
44 |
45 | if args.random:
46 | print('I am a random policy!')
47 | else:
48 | if args.greedy:
49 | print('I am a greedy policy!')
50 | else:
51 | print('I am a sampled policy!')
52 | while True:
53 | # run policy
54 | fetched = sess.run([probs, sample, vf, state_out_0, state_out_1],
55 | {"global/x:0": [last_state], "global/c_in:0": last_features[0], "global/h_in:0": last_features[1]})
56 | prob_action, action, value_, features = fetched[0], fetched[1], fetched[2], fetched[3:]
57 |
58 | # run environment
59 | if args.random:
60 | stepAct = np.random.randint(0, numaction) # random policy
61 | else:
62 | if args.greedy:
63 | stepAct = prob_action.argmax() # greedy policy
64 | else:
65 | stepAct = action.argmax()
66 | state, reward, terminal, info = env.step(stepAct)
67 |
68 | # update stats
69 | length += 1
70 | rewards += reward
71 | last_state = state
72 | last_features = features
73 | if args.render or args.record:
74 | env.render()
75 |
76 | timestep_limit = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
77 | if timestep_limit is None: timestep_limit = env.spec.timestep_limit
78 | if terminal or length >= timestep_limit:
79 | if length >= timestep_limit or not env.metadata.get('semantics.autoreset'):
80 | last_state = env.reset()
81 | last_features = np.zeros((1, 256), np.float32); last_features = [last_features, last_features]
82 | print("Episode finished. Sum of rewards: %.2f. Length: %d." % (rewards, length))
83 | length = 0
84 | rewards = 0
85 | if args.render or args.record:
86 | env.render()
87 | break
88 |
89 | logger.info('Finished %d true episodes.', args.num_episodes)
90 | env.close()
91 |
92 |
93 | def main(_):
94 | parser = argparse.ArgumentParser(description=None)
95 | parser.add_argument('--ckpt', default="../models/doom/doom_ICM", help='checkpoint name')
96 | parser.add_argument('--outdir', default="../models/output", help='Output log directory')
97 | parser.add_argument('--env-id', default="doom", help='Environment id')
98 | parser.add_argument('--record', action='store_true', help="Record the policy running video")
99 | parser.add_argument('--render', action='store_true',
100 | help="Render the gym environment video online")
101 | parser.add_argument('--num-episodes', type=int, default=2, help="Number of episodes to run")
102 | parser.add_argument('--greedy', action='store_true',
103 | help="Default is the sampled policy. This option takes the argmax action instead.")
104 | parser.add_argument('--random', action='store_true',
105 | help="Default is the sampled policy. This option uses a random policy instead.")
106 | args = parser.parse_args()
107 | inference(args)
108 |
109 | if __name__ == "__main__":
110 | tf.app.run()
111 |
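# Example invocation (assumes the pretrained models were fetched with models/download_models.sh
# and that a display is available for --render):
#
#     python demo.py --ckpt ../models/doom/doom_ICM --env-id doom --render --num-episodes 2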
--------------------------------------------------------------------------------
/src/env_wrapper.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Deepak Pathak
3 |
4 | Acknowledgement:
5 | - The wrappers (BufferedObsEnv, SkipEnv) were originally written by
6 | Evan Shelhamer and modified by Deepak. Thanks Evan!
7 | - This file is derived from
8 | https://github.com/shelhamer/ourl/envs.py
9 | https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers_deprecated.py
10 | """
11 | from __future__ import print_function
12 | import numpy as np
13 | from collections import deque
14 | from PIL import Image
15 | from gym.spaces.box import Box
16 | import gym
17 | import time, sys
18 |
19 |
20 | class BufferedObsEnv(gym.ObservationWrapper):
21 | """Buffer observations and stack e.g. for frame skipping.
22 |
23 | n is the length of the buffer, and number of observations stacked.
24 | skip is the number of steps between buffered observations (min=1).
25 |
26 | n.b. first obs is the oldest, last obs is the newest.
27 | the buffer is zeroed out on reset.
28 | *must* call reset() for init!
29 | """
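    # For example, with the defaults n=4, skip=4, shape=(84, 84), channel_last=True:
    # every 4th frame seen by this wrapper is max-pooled over the two most recent frames,
    # converted to grayscale, resized and appended to the buffer, so each returned
    # observation is an (84, 84, 4) float32 array scaled to [0, 1] that spans roughly
    # the last n*skip = 16 steps (oldest frame first, as noted above).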
30 | def __init__(self, env=None, n=4, skip=4, shape=(84, 84),
31 | channel_last=True, maxFrames=True):
32 | super(BufferedObsEnv, self).__init__(env)
33 | self.obs_shape = shape
34 | # most recent raw observations (for max pooling across time steps)
35 | self.obs_buffer = deque(maxlen=2)
36 | self.maxFrames = maxFrames
37 | self.n = n
38 | self.skip = skip
39 | self.buffer = deque(maxlen=self.n)
40 | self.counter = 0 # init and reset should agree on this
41 | shape = shape + (n,) if channel_last else (n,) + shape
42 | self.observation_space = Box(0.0, 255.0, shape)
43 | self.ch_axis = -1 if channel_last else 0
44 | self.scale = 1.0 / 255
45 | self.observation_space.high[...] = 1.0
46 |
47 | def _step(self, action):
48 | obs, reward, done, info = self.env.step(action)
49 | return self._observation(obs), reward, done, info
50 |
51 | def _observation(self, obs):
52 | obs = self._convert(obs)
53 | self.counter += 1
54 | if self.counter % self.skip == 0:
55 | self.buffer.append(obs)
56 | obsNew = np.stack(self.buffer, axis=self.ch_axis)
57 | return obsNew.astype(np.float32) * self.scale
58 |
59 | def _reset(self):
60 | """Clear buffer and re-fill by duplicating the first observation."""
61 | self.obs_buffer.clear()
62 | obs = self._convert(self.env.reset())
63 | self.buffer.clear()
64 | self.counter = 0
65 | for _ in range(self.n - 1):
66 | self.buffer.append(np.zeros_like(obs))
67 | self.buffer.append(obs)
68 | obsNew = np.stack(self.buffer, axis=self.ch_axis)
69 | return obsNew.astype(np.float32) * self.scale
70 |
71 | def _convert(self, obs):
72 | self.obs_buffer.append(obs)
73 | if self.maxFrames:
74 | max_frame = np.max(np.stack(self.obs_buffer), axis=0)
75 | else:
76 | max_frame = obs
77 | intensity_frame = self._rgb2y(max_frame).astype(np.uint8)
78 | small_frame = np.array(Image.fromarray(intensity_frame).resize(
79 | self.obs_shape, resample=Image.BILINEAR), dtype=np.uint8)
80 | return small_frame
81 |
82 | def _rgb2y(self, im):
83 | """Converts an RGB image to a Y image (as in YUV).
84 |
85 | These coefficients are taken from the torch/image library.
86 | Beware: these are more critical than you might think, as the
87 | monochromatic contrast can be surprisingly low.
88 | """
89 | if len(im.shape) < 3:
90 | return im
91 | return np.sum(im * [0.299, 0.587, 0.114], axis=2)
92 |
93 |
94 | class NoNegativeRewardEnv(gym.RewardWrapper):
95 | """Clip reward in negative direction."""
96 | def __init__(self, env=None, neg_clip=0.0):
97 | super(NoNegativeRewardEnv, self).__init__(env)
98 | self.neg_clip = neg_clip
99 |
100 | def _reward(self, reward):
101 | new_reward = self.neg_clip if reward < self.neg_clip else reward
102 | return new_reward
103 |
104 |
105 | class SkipEnv(gym.Wrapper):
106 | """Skip timesteps: repeat action, accumulate reward, take last obs."""
107 | def __init__(self, env=None, skip=4):
108 | super(SkipEnv, self).__init__(env)
109 | self.skip = skip
110 |
111 | def _step(self, action):
112 | total_reward = 0
113 | for i in range(0, self.skip):
114 | obs, reward, done, info = self.env.step(action)
115 | total_reward += reward
116 | info['steps'] = i + 1
117 | if done:
118 | break
119 | return obs, total_reward, done, info
120 |
121 |
122 | class MarioEnv(gym.Wrapper):
123 | def __init__(self, env=None, tilesEnv=False):
124 | """Reset mario environment without actually restarting fceux everytime.
125 | This speeds up unrolling by approximately 10 times.
126 | """
127 | super(MarioEnv, self).__init__(env)
128 | self.resetCount = -1
129 | # reward is distance travelled. So normalize it with total distance
130 | # https://github.com/ppaquette/gym-super-mario/blob/master/ppaquette_gym_super_mario/lua/super-mario-bros.lua
131 | # However, we will not use this reward at all. It is only included for completeness.
132 | self.maxDistance = 3000.0
133 | self.tilesEnv = tilesEnv
134 |
135 | def _reset(self):
136 | if self.resetCount < 0:
137 | print('\nDoing hard mario fceux reset (40 second wait)!')
138 | sys.stdout.flush()
139 | self.env.reset()
140 | time.sleep(40)
141 | obs, _, _, info = self.env.step(7) # take right once to start game
142 | if info.get('ignore', False): # assuming this happens only in beginning
143 | self.resetCount = -1
144 | self.env.close()
145 | return self._reset()
146 | self.resetCount = info.get('iteration', -1)
147 | if self.tilesEnv:
148 | return obs
149 | return obs[24:-12,8:-8,:]
150 |
151 | def _step(self, action):
152 | obs, reward, done, info = self.env.step(action)
153 | # print('info:', info)
154 | done = info['iteration'] > self.resetCount
155 | reward = float(reward)/self.maxDistance # note: we do not use this reward at all.
156 | if self.tilesEnv:
157 | return obs, reward, done, info
158 | return obs[24:-12,8:-8,:], reward, done, info
159 |
160 | def _close(self):
161 | self.resetCount = -1
162 | return self.env.close()
163 |
164 |
165 | class MakeEnvDynamic(gym.ObservationWrapper):
166 | """Make observation dynamic by adding noise"""
167 | def __init__(self, env=None, percentPad=5):
168 | super(MakeEnvDynamic, self).__init__(env)
169 | self.origShape = env.observation_space.shape
170 | newside = int(round(max(self.origShape[:-1])*100./(100.-percentPad)))
171 | self.newShape = [newside, newside, 3]
172 | self.observation_space = Box(0.0, 255.0, self.newShape)
173 | self.bottomIgnore = 20 # doom 20px bottom is useless
174 | self.ob = None
175 |
176 | def _observation(self, obs):
177 | imNoise = np.random.randint(0,256,self.newShape).astype(obs.dtype)
178 | imNoise[:self.origShape[0]-self.bottomIgnore, :self.origShape[1], :] = obs[:-self.bottomIgnore,:,:]
179 | self.ob = imNoise
180 | return imNoise
181 |
182 | # def render(self, mode='human', close=False):
183 | # temp = self.env.render(mode, close)
184 | # return self.ob
185 |
--------------------------------------------------------------------------------
/src/envs.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from PIL import Image
3 | from gym.spaces.box import Box
4 | import numpy as np
5 | import gym
6 | from gym import spaces
7 | import logging
8 | import universe
9 | from universe import vectorized
10 | from universe.wrappers import BlockingReset, GymCoreAction, EpisodeID, Unvectorize, Vectorize, Vision, Logger
11 | from universe import spaces as vnc_spaces
12 | from universe.spaces.vnc_event import keycode
13 | import env_wrapper
14 | import time
15 | logger = logging.getLogger(__name__)
16 | logger.setLevel(logging.INFO)
17 | universe.configure_logging()
18 |
19 | def create_env(env_id, client_id, remotes, **kwargs):
20 | if 'doom' in env_id.lower() or 'labyrinth' in env_id.lower():
21 | return create_doom(env_id, client_id, **kwargs)
22 | if 'mario' in env_id.lower():
23 | return create_mario(env_id, client_id, **kwargs)
24 |
25 | spec = gym.spec(env_id)
26 | if spec.tags.get('flashgames', False):
27 | return create_flash_env(env_id, client_id, remotes, **kwargs)
28 | elif spec.tags.get('atari', False) and spec.tags.get('vnc', False):
29 | return create_vncatari_env(env_id, client_id, remotes, **kwargs)
30 | else:
31 | # Assume atari.
32 | assert "." not in env_id # universe environments have dots in names.
33 | return create_atari_env(env_id, **kwargs)
34 |
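# For example, the demo/inference scripts call (see src/demo.py):
#   env = create_env('doom', client_id='0', remotes=None, envWrap=True, acRepeat=1)
# which routes to create_doom() below and returns a Doom env emitting 42x42x4 stacked
# grayscale observations with a 'minimal' discrete action space.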
35 | def create_doom(env_id, client_id, envWrap=True, record=False, outdir=None,
36 | noLifeReward=False, acRepeat=0, **_):
37 | from ppaquette_gym_doom import wrappers
38 | if 'labyrinth' in env_id.lower():
39 | if 'single' in env_id.lower():
40 | env_id = 'ppaquette/LabyrinthSingle-v0'
41 | elif 'fix' in env_id.lower():
42 | env_id = 'ppaquette/LabyrinthManyFixed-v0'
43 | else:
44 | env_id = 'ppaquette/LabyrinthMany-v0'
45 | elif 'very' in env_id.lower():
46 | env_id = 'ppaquette/DoomMyWayHomeFixed15-v0'
47 | elif 'sparse' in env_id.lower():
48 | env_id = 'ppaquette/DoomMyWayHomeFixed-v0'
49 | elif 'fix' in env_id.lower():
50 | if '1' in env_id or '2' in env_id:
51 | env_id = 'ppaquette/DoomMyWayHomeFixed' + str(env_id[-2:]) + '-v0'
52 | elif 'new' in env_id.lower():
53 | env_id = 'ppaquette/DoomMyWayHomeFixedNew-v0'
54 | else:
55 | env_id = 'ppaquette/DoomMyWayHomeFixed-v0'
56 | else:
57 | env_id = 'ppaquette/DoomMyWayHome-v0'
58 |
59 | # VizDoom workaround: simultaneously launching multiple vizdoom processes
60 | # makes the program hang, so use the global lock in multi-threading/multi-processing
61 | client_id = int(client_id)
62 | time.sleep(client_id * 10)
63 | env = gym.make(env_id)
64 | modewrapper = wrappers.SetPlayingMode('algo')
65 | obwrapper = wrappers.SetResolution('160x120')
66 | acwrapper = wrappers.ToDiscrete('minimal')
67 | env = modewrapper(obwrapper(acwrapper(env)))
68 | # env = env_wrapper.MakeEnvDynamic(env) # to add stochasticity
69 |
70 | if record and outdir is not None:
71 | env = gym.wrappers.Monitor(env, outdir, force=True)
72 |
73 | if envWrap:
74 | fshape = (42, 42)
75 | frame_skip = acRepeat if acRepeat>0 else 4
76 | env.seed(None)
77 | if noLifeReward:
78 | env = env_wrapper.NoNegativeRewardEnv(env)
79 | env = env_wrapper.BufferedObsEnv(env, skip=frame_skip, shape=fshape)
80 | env = env_wrapper.SkipEnv(env, skip=frame_skip)
81 | elif noLifeReward:
82 | env = env_wrapper.NoNegativeRewardEnv(env)
83 |
84 | env = Vectorize(env)
85 | env = DiagnosticsInfo(env)
86 | env = Unvectorize(env)
87 | return env
88 |
89 | def create_mario(env_id, client_id, envWrap=True, record=False, outdir=None,
90 | noLifeReward=False, acRepeat=0, **_):
91 | import ppaquette_gym_super_mario
92 | from ppaquette_gym_super_mario import wrappers
93 | if '-v' in env_id.lower():
94 | env_id = 'ppaquette/' + env_id
95 | else:
96 | env_id = 'ppaquette/SuperMarioBros-1-1-v0' # shape: (224,256,3)=(h,w,c)
97 |
98 | # Mario workaround: simultaneously launching multiple fceux processes makes the program hang,
99 | # so use the global lock in multi-threading/multi-processing
100 | # see: https://github.com/ppaquette/gym-super-mario/tree/master/ppaquette_gym_super_mario
101 | client_id = int(client_id)
102 | time.sleep(client_id * 50)
103 | env = gym.make(env_id)
104 | modewrapper = wrappers.SetPlayingMode('algo')
105 | acwrapper = wrappers.ToDiscrete()
106 | env = modewrapper(acwrapper(env))
107 | env = env_wrapper.MarioEnv(env)
108 |
109 | if record and outdir is not None:
110 | env = gym.wrappers.Monitor(env, outdir, force=True)
111 |
112 | if envWrap:
113 | frame_skip = acRepeat if acRepeat>0 else 6
114 | fshape = (42, 42)
115 | env.seed(None)
116 | if noLifeReward:
117 | env = env_wrapper.NoNegativeRewardEnv(env)
118 | env = env_wrapper.BufferedObsEnv(env, skip=frame_skip, shape=fshape, maxFrames=False)
119 | if frame_skip > 1:
120 | env = env_wrapper.SkipEnv(env, skip=frame_skip)
121 | elif noLifeReward:
122 | env = env_wrapper.NoNegativeRewardEnv(env)
123 |
124 | env = Vectorize(env)
125 | env = DiagnosticsInfo(env)
126 | env = Unvectorize(env)
127 | # env.close() # TODO: think about where to put env.close !
128 | return env
129 |
130 | def create_flash_env(env_id, client_id, remotes, **_):
131 | env = gym.make(env_id)
132 | env = Vision(env)
133 | env = Logger(env)
134 | env = BlockingReset(env)
135 |
136 | reg = universe.runtime_spec('flashgames').server_registry
137 | height = reg[env_id]["height"]
138 | width = reg[env_id]["width"]
139 | env = CropScreen(env, height, width, 84, 18)
140 | env = FlashRescale(env)
141 |
142 | keys = ['left', 'right', 'up', 'down', 'x']
143 | if env_id == 'flashgames.NeonRace-v0':
144 | # Better key space for this game.
145 | keys = ['left', 'right', 'up', 'left up', 'right up', 'down', 'up x']
146 | logger.info('create_flash_env(%s): keys=%s', env_id, keys)
147 |
148 | env = DiscreteToFixedKeysVNCActions(env, keys)
149 | env = EpisodeID(env)
150 | env = DiagnosticsInfo(env)
151 | env = Unvectorize(env)
152 | env.configure(fps=5.0, remotes=remotes, start_timeout=15 * 60, client_id=client_id,
153 | vnc_driver='go', vnc_kwargs={
154 | 'encoding': 'tight', 'compress_level': 0,
155 | 'fine_quality_level': 50, 'subsample_level': 3})
156 | return env
157 |
158 | def create_vncatari_env(env_id, client_id, remotes, **_):
159 | env = gym.make(env_id)
160 | env = Vision(env)
161 | env = Logger(env)
162 | env = BlockingReset(env)
163 | env = GymCoreAction(env)
164 | env = AtariRescale42x42(env)
165 | env = EpisodeID(env)
166 | env = DiagnosticsInfo(env)
167 | env = Unvectorize(env)
168 |
169 | logger.info('Connecting to remotes: %s', remotes)
170 | fps = env.metadata['video.frames_per_second']
171 | env.configure(remotes=remotes, start_timeout=15 * 60, fps=fps, client_id=client_id)
172 | return env
173 |
174 | def create_atari_env(env_id, record=False, outdir=None, **_):
175 | env = gym.make(env_id)
176 | if record and outdir is not None:
177 | env = gym.wrappers.Monitor(env, outdir, force=True)
178 | env = Vectorize(env)
179 | env = AtariRescale42x42(env)
180 | env = DiagnosticsInfo(env)
181 | env = Unvectorize(env)
182 | return env
183 |
184 | def DiagnosticsInfo(env, *args, **kwargs):
185 | return vectorized.VectorizeFilter(env, DiagnosticsInfoI, *args, **kwargs)
186 |
187 | class DiagnosticsInfoI(vectorized.Filter):
188 | def __init__(self, log_interval=503):
189 | super(DiagnosticsInfoI, self).__init__()
190 |
191 | self._episode_time = time.time()
192 | self._last_time = time.time()
193 | self._local_t = 0
194 | self._log_interval = log_interval
195 | self._episode_reward = 0
196 | self._episode_length = 0
197 | self._all_rewards = []
198 | self._num_vnc_updates = 0
199 | self._last_episode_id = -1
200 |
201 | def _after_reset(self, observation):
202 | logger.info('Resetting environment logs')
203 | self._episode_reward = 0
204 | self._episode_length = 0
205 | self._all_rewards = []
206 | return observation
207 |
208 | def _after_step(self, observation, reward, done, info):
209 | to_log = {}
210 | if self._episode_length == 0:
211 | self._episode_time = time.time()
212 |
213 | self._local_t += 1
214 | if info.get("stats.vnc.updates.n") is not None:
215 | self._num_vnc_updates += info.get("stats.vnc.updates.n")
216 |
217 | if self._local_t % self._log_interval == 0:
218 | cur_time = time.time()
219 | elapsed = cur_time - self._last_time
220 | fps = self._log_interval / elapsed
221 | self._last_time = cur_time
222 | cur_episode_id = info.get('vectorized.episode_id', 0)
223 | to_log["diagnostics/fps"] = fps
224 | if self._last_episode_id == cur_episode_id:
225 | to_log["diagnostics/fps_within_episode"] = fps
226 | self._last_episode_id = cur_episode_id
227 | if info.get("stats.gauges.diagnostics.lag.action") is not None:
228 | to_log["diagnostics/action_lag_lb"] = info["stats.gauges.diagnostics.lag.action"][0]
229 | to_log["diagnostics/action_lag_ub"] = info["stats.gauges.diagnostics.lag.action"][1]
230 | if info.get("reward.count") is not None:
231 | to_log["diagnostics/reward_count"] = info["reward.count"]
232 | if info.get("stats.gauges.diagnostics.clock_skew") is not None:
233 | to_log["diagnostics/clock_skew_lb"] = info["stats.gauges.diagnostics.clock_skew"][0]
234 | to_log["diagnostics/clock_skew_ub"] = info["stats.gauges.diagnostics.clock_skew"][1]
235 | if info.get("stats.gauges.diagnostics.lag.observation") is not None:
236 | to_log["diagnostics/observation_lag_lb"] = info["stats.gauges.diagnostics.lag.observation"][0]
237 | to_log["diagnostics/observation_lag_ub"] = info["stats.gauges.diagnostics.lag.observation"][1]
238 |
239 | if info.get("stats.vnc.updates.n") is not None:
240 | to_log["diagnostics/vnc_updates_n"] = info["stats.vnc.updates.n"]
241 | to_log["diagnostics/vnc_updates_n_ps"] = self._num_vnc_updates / elapsed
242 | self._num_vnc_updates = 0
243 | if info.get("stats.vnc.updates.bytes") is not None:
244 | to_log["diagnostics/vnc_updates_bytes"] = info["stats.vnc.updates.bytes"]
245 | if info.get("stats.vnc.updates.pixels") is not None:
246 | to_log["diagnostics/vnc_updates_pixels"] = info["stats.vnc.updates.pixels"]
247 | if info.get("stats.vnc.updates.rectangles") is not None:
248 | to_log["diagnostics/vnc_updates_rectangles"] = info["stats.vnc.updates.rectangles"]
249 | if info.get("env_status.state_id") is not None:
250 | to_log["diagnostics/env_state_id"] = info["env_status.state_id"]
251 |
252 | if reward is not None:
253 | self._episode_reward += reward
254 | if observation is not None:
255 | self._episode_length += 1
256 | self._all_rewards.append(reward)
257 |
258 | if done:
259 | logger.info('True Game terminating: env_episode_reward=%s episode_length=%s', self._episode_reward, self._episode_length)
260 | total_time = time.time() - self._episode_time
261 | to_log["global/episode_reward"] = self._episode_reward
262 | to_log["global/episode_length"] = self._episode_length
263 | to_log["global/episode_time"] = total_time
264 | to_log["global/reward_per_time"] = self._episode_reward / total_time
265 | self._episode_reward = 0
266 | self._episode_length = 0
267 | self._all_rewards = []
268 |
269 | if 'distance' in info: to_log['distance'] = info['distance'] # mario
270 | if 'POSITION_X' in info: # doom
271 | to_log['POSITION_X'] = info['POSITION_X']
272 | to_log['POSITION_Y'] = info['POSITION_Y']
273 | return observation, reward, done, to_log
274 |
275 | def _process_frame42(frame):
276 | frame = frame[34:34+160, :160]
277 | # Resize by half, then down to 42x42 (essentially mipmapping). If
278 | # we resize directly we lose pixels that, when mapped to 42x42,
279 | # aren't close enough to the pixel boundary.
280 | frame = np.asarray(Image.fromarray(frame).resize((80, 80), resample=Image.BILINEAR).resize(
281 | (42,42), resample=Image.BILINEAR))
282 | frame = frame.mean(2) # take mean along channels
283 | frame = frame.astype(np.float32)
284 | frame *= (1.0 / 255.0)
285 | frame = np.reshape(frame, [42, 42, 1])
286 | return frame
287 |
288 | class AtariRescale42x42(vectorized.ObservationWrapper):
289 | def __init__(self, env=None):
290 | super(AtariRescale42x42, self).__init__(env)
291 | self.observation_space = Box(0.0, 1.0, [42, 42, 1])
292 |
293 | def _observation(self, observation_n):
294 | return [_process_frame42(observation) for observation in observation_n]
295 |
296 | class FixedKeyState(object):
297 | def __init__(self, keys):
298 | self._keys = [keycode(key) for key in keys]
299 | self._down_keysyms = set()
300 |
301 | def apply_vnc_actions(self, vnc_actions):
302 | for event in vnc_actions:
303 | if isinstance(event, vnc_spaces.KeyEvent):
304 | if event.down:
305 | self._down_keysyms.add(event.key)
306 | else:
307 | self._down_keysyms.discard(event.key)
308 |
309 | def to_index(self):
310 | action_n = 0
311 | for key in self._down_keysyms:
312 | if key in self._keys:
313 | # If multiple keys are pressed, just use the first one
314 | action_n = self._keys.index(key) + 1
315 | break
316 | return action_n
317 |
318 | class DiscreteToFixedKeysVNCActions(vectorized.ActionWrapper):
319 | """
320 | Define a fixed action space. Action 0 is all keys up. Each element of keys can be a single key or a space-separated list of keys
321 |
322 | For example,
323 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right'])
324 | will have 3 actions: [none, left, right]
325 |
326 | You can define a state with more than one key down by separating with spaces. For example,
327 | e=DiscreteToFixedKeysVNCActions(e, ['left', 'right', 'space', 'left space', 'right space'])
328 | will have 6 actions: [none, left, right, space, left space, right space]
329 | """
330 | def __init__(self, env, keys):
331 | super(DiscreteToFixedKeysVNCActions, self).__init__(env)
332 |
333 | self._keys = keys
334 | self._generate_actions()
335 | self.action_space = spaces.Discrete(len(self._actions))
336 |
337 | def _generate_actions(self):
338 | self._actions = []
339 | uniq_keys = set()
340 | for key in self._keys:
341 | for cur_key in key.split(' '):
342 | uniq_keys.add(cur_key)
343 |
344 | for key in [''] + self._keys:
345 | split_keys = key.split(' ')
346 | cur_action = []
347 | for cur_key in uniq_keys:
348 | cur_action.append(vnc_spaces.KeyEvent.by_name(cur_key, down=(cur_key in split_keys)))
349 | self._actions.append(cur_action)
350 | self.key_state = FixedKeyState(uniq_keys)
351 |
352 | def _action(self, action_n):
353 | # Each action might be a length-1 np.array. Cast to int to
354 | # avoid warnings.
355 | return [self._actions[int(action)] for action in action_n]
356 |
357 | class CropScreen(vectorized.ObservationWrapper):
358 | """Crops out a [height]x[width] area starting from (top,left) """
359 | def __init__(self, env, height, width, top=0, left=0):
360 | super(CropScreen, self).__init__(env)
361 | self.height = height
362 | self.width = width
363 | self.top = top
364 | self.left = left
365 | self.observation_space = Box(0, 255, shape=(height, width, 3))
366 |
367 | def _observation(self, observation_n):
368 | return [ob[self.top:self.top+self.height, self.left:self.left+self.width, :] if ob is not None else None
369 | for ob in observation_n]
370 |
371 | def _process_frame_flash(frame):
372 | frame = np.array(Image.fromarray(frame).resize((200, 128), resample=Image.BILINEAR))
373 | frame = frame.mean(2).astype(np.float32)
374 | frame *= (1.0 / 255.0)
375 | frame = np.reshape(frame, [128, 200, 1])
376 | return frame
377 |
378 | class FlashRescale(vectorized.ObservationWrapper):
379 | def __init__(self, env=None):
380 | super(FlashRescale, self).__init__(env)
381 | self.observation_space = Box(0.0, 1.0, [128, 200, 1])
382 |
383 | def _observation(self, observation_n):
384 | return [_process_frame_flash(observation) for observation in observation_n]
385 |
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from __future__ import print_function
3 | import go_vncdriver
4 | import tensorflow as tf
5 | import numpy as np
6 | import argparse
7 | import logging
8 | import os
9 | import gym
10 | from envs import create_env
11 | from worker import FastSaver
12 | from model import LSTMPolicy
13 | import utils
14 | import distutils.version
15 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0')
16 |
17 | logger = logging.getLogger(__name__)
18 | logger.setLevel(logging.INFO)
19 |
20 |
21 | def inference(args):
22 | """
23 | Restores only the LSTMPolicy architecture and runs inference with it.
24 | """
25 | # get address of checkpoints
26 | indir = os.path.join(args.log_dir, 'train')
27 | outdir = os.path.join(args.log_dir, 'inference') if args.out_dir is None else args.out_dir
28 | with open(indir + '/checkpoint', 'r') as f:
29 | first_line = f.readline().strip()
30 | ckpt = first_line.split(' ')[-1].split('/')[-1][:-1]
31 | ckpt = ckpt.split('-')[-1]
32 | ckpt = indir + '/model.ckpt-' + ckpt
33 |
34 | # define environment
35 | if args.record:
36 | env = create_env(args.env_id, client_id='0', remotes=None, envWrap=args.envWrap, designHead=args.designHead,
37 | record=True, noop=args.noop, acRepeat=args.acRepeat, outdir=outdir)
38 | else:
39 | env = create_env(args.env_id, client_id='0', remotes=None, envWrap=args.envWrap, designHead=args.designHead,
40 | record=True, noop=args.noop, acRepeat=args.acRepeat)
41 | numaction = env.action_space.n
42 |
43 | with tf.device("/cpu:0"):
44 | # define policy network
45 | with tf.variable_scope("global"):
46 | policy = LSTMPolicy(env.observation_space.shape, numaction, args.designHead)
47 | policy.global_step = tf.get_variable("global_step", [], tf.int32, initializer=tf.constant_initializer(0, dtype=tf.int32),
48 | trainable=False)
49 |
50 | # Variable names that start with "local" are not saved in checkpoints.
51 | if use_tf12_api:
52 | variables_to_restore = [v for v in tf.global_variables() if not v.name.startswith("local")]
53 | init_all_op = tf.global_variables_initializer()
54 | else:
55 | variables_to_restore = [v for v in tf.all_variables() if not v.name.startswith("local")]
56 | init_all_op = tf.initialize_all_variables()
57 | saver = FastSaver(variables_to_restore)
58 |
59 | # print trainable variables
60 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
61 | logger.info('Trainable vars:')
62 | for v in var_list:
63 | logger.info(' %s %s', v.name, v.get_shape())
64 |
65 | # summary of rewards
66 | action_writers = []
67 | if use_tf12_api:
68 | summary_writer = tf.summary.FileWriter(outdir)
69 | for ac_id in range(numaction):
70 | action_writers.append(tf.summary.FileWriter(os.path.join(outdir,'action_{}'.format(ac_id))))
71 | else:
72 | summary_writer = tf.train.SummaryWriter(outdir)
73 | for ac_id in range(numaction):
74 | action_writers.append(tf.train.SummaryWriter(os.path.join(outdir,'action_{}'.format(ac_id))))
75 | logger.info("Inference events directory: %s", outdir)
76 |
77 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
78 | with tf.Session(config=config) as sess:
79 | logger.info("Initializing all parameters.")
80 | sess.run(init_all_op)
81 | logger.info("Restoring trainable global parameters.")
82 | saver.restore(sess, ckpt)
83 | logger.info("Restored model was trained for %.2fM global steps", sess.run(policy.global_step)/1000000.)
84 | # saving with meta graph:
85 | # metaSaver = tf.train.Saver(variables_to_restore)
86 | # metaSaver.save(sess, 'models/doomICM')
87 |
88 | last_state = env.reset()
89 | if args.render or args.record:
90 | env.render()
91 | last_features = policy.get_initial_features() # reset lstm memory
92 | length = 0
93 | rewards = 0
94 | mario_distances = np.zeros((args.num_episodes,))
95 | for i in range(args.num_episodes):
96 | print("Starting episode %d" % (i + 1))
97 | if args.recordSignal:
98 | from PIL import Image
99 | signalCount = 1
100 | utils.mkdir_p(outdir + '/recordedSignal/ep_%02d/'%i)
101 | Image.fromarray((255*last_state[..., -1]).astype('uint8')).save(outdir + '/recordedSignal/ep_%02d/%06d.jpg'%(i,signalCount))
102 |
103 | if args.random:
104 | print('I am a random policy!')
105 | else:
106 | if args.greedy:
107 | print('I am a greedy policy!')
108 | else:
109 | print('I am a sampled policy!')
110 | while True:
111 | # run policy
112 | fetched = policy.act_inference(last_state, *last_features)
113 | prob_action, action, value_, features = fetched[0], fetched[1], fetched[2], fetched[3:]
114 |
115 |             # choose the environment action: random, greedy (argmax of probs), or sampled one-hot 'action' (default)
116 | if args.random:
117 | stepAct = np.random.randint(0, numaction) # random policy
118 | else:
119 | if args.greedy:
120 | stepAct = prob_action.argmax() # greedy policy
121 | else:
122 | stepAct = action.argmax()
123 | # print(stepAct, prob_action.argmax(), prob_action)
124 | state, reward, terminal, info = env.step(stepAct)
125 |
126 | # update stats
127 | length += 1
128 | rewards += reward
129 | last_state = state
130 | last_features = features
131 | if args.render or args.record:
132 | env.render()
133 | if args.recordSignal:
134 | signalCount += 1
135 | Image.fromarray((255*last_state[..., -1]).astype('uint8')).save(outdir + '/recordedSignal/ep_%02d/%06d.jpg'%(i,signalCount))
136 |
137 | # store summary
138 | summary = tf.Summary()
139 | summary.value.add(tag='ep_{}/reward'.format(i), simple_value=reward)
140 | summary.value.add(tag='ep_{}/netreward'.format(i), simple_value=rewards)
141 | summary.value.add(tag='ep_{}/value'.format(i), simple_value=float(value_[0]))
142 | if 'NoFrameskip-v' in args.env_id: # atari
143 | summary.value.add(tag='ep_{}/lives'.format(i), simple_value=env.unwrapped.ale.lives())
144 | summary_writer.add_summary(summary, length)
145 | summary_writer.flush()
146 | summary = tf.Summary()
147 | for ac_id in range(numaction):
148 | summary.value.add(tag='action_prob', simple_value=float(prob_action[ac_id]))
149 | action_writers[ac_id].add_summary(summary, length)
150 | action_writers[ac_id].flush()
151 |
152 | timestep_limit = env.spec.tags.get('wrapper_config.TimeLimit.max_episode_steps')
153 | if timestep_limit is None: timestep_limit = env.spec.timestep_limit
154 | if terminal or length >= timestep_limit:
155 | if length >= timestep_limit or not env.metadata.get('semantics.autoreset'):
156 | last_state = env.reset()
157 | last_features = policy.get_initial_features() # reset lstm memory
158 | print("Episode finished. Sum of rewards: %.2f. Length: %d." % (rewards, length))
159 | if 'distance' in info:
160 | print('Mario Distance Covered:', info['distance'])
161 | mario_distances[i] = info['distance']
162 | length = 0
163 | rewards = 0
164 | if args.render or args.record:
165 | env.render()
166 | if args.recordSignal:
167 | signalCount += 1
168 | Image.fromarray((255*last_state[..., -1]).astype('uint8')).save(outdir + '/recordedSignal/ep_%02d/%06d.jpg'%(i,signalCount))
169 | break
170 |
171 | logger.info('Finished %d true episodes.', args.num_episodes)
172 | if 'distance' in info:
173 | print('Mario Distances:', mario_distances)
174 | np.save(outdir + '/distances.npy', mario_distances)
175 | env.close()
176 |
177 |
178 | def main(_):
179 | parser = argparse.ArgumentParser(description=None)
180 | parser.add_argument('--log-dir', default="tmp/doom", help='input model directory')
181 | parser.add_argument('--out-dir', default=None, help='output log directory. Default: log_dir/inference/')
182 | parser.add_argument('--env-id', default="PongDeterministic-v3", help='Environment id')
183 | parser.add_argument('--record', action='store_true',
184 |                         help="Record the gym environment video (user-friendly output)")
185 | parser.add_argument('--recordSignal', action='store_true',
186 |                         help="Record images of the processed observations actually fed to the network")
187 | parser.add_argument('--render', action='store_true',
188 | help="Render the gym environment video online")
189 | parser.add_argument('--envWrap', action='store_true',
190 | help="Preprocess input in env_wrapper (no change in input size or network)")
191 | parser.add_argument('--designHead', type=str, default='universe',
192 |                         help="Network design head: nips, nature, doom, or universe (default)")
193 | parser.add_argument('--num-episodes', type=int, default=2,
194 | help="Number of episodes to run")
195 | parser.add_argument('--noop', action='store_true',
196 |                         help="Also apply the up-to-30 no-op starts at inference (as recommended by the Nature DQN paper)")
197 | parser.add_argument('--acRepeat', type=int, default=0,
198 |                         help="Number of times each action is repeated at inference. 0 means default. Applies only if envWrap is True.")
199 | parser.add_argument('--greedy', action='store_true',
200 |                         help="The default is the sampled policy; this option takes the greedy (argmax) action instead.")
201 | parser.add_argument('--random', action='store_true',
202 |                         help="The default is the sampled policy; this option takes uniformly random actions instead.")
203 | parser.add_argument('--default', action='store_true', help="run with default params")
204 | args = parser.parse_args()
205 | if args.default:
206 | args.envWrap = True
207 | args.acRepeat = 1
208 | if args.acRepeat <= 0:
209 |         print('Using default action repeat (i.e. 4). The minimum value that can be set is 1.')
210 | inference(args)
211 |
212 | if __name__ == "__main__":
213 | tf.app.run()
214 |
--------------------------------------------------------------------------------
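A minimal sketch (not part of the repo) of driving the inference() routine above from Python instead of the command line. The Namespace fields simply mirror the argparse destinations defined in main(); the log directory and environment id are hypothetical, and a trained checkpoint is assumed to exist under the log directory.

import argparse
from inference import inference  # assumes this runs from src/ with the pinned dependencies installed

# Field names mirror the CLI flags above; values are illustrative assumptions.
args = argparse.Namespace(
    log_dir='tmp/doom', out_dir=None, env_id='doom',
    record=False, recordSignal=False, render=False,
    envWrap=True, designHead='universe', num_episodes=2,
    noop=False, acRepeat=1, greedy=True, random=False, default=False)

inference(args)  # expects a trained model checkpoint under args.log_dir
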
/src/mario.py:
--------------------------------------------------------------------------------
1 | '''
2 | Script to test whether the Mario installation works correctly. It
3 | displays the gameplay simultaneously.
4 | '''
5 |
6 | from __future__ import print_function
7 | import gym, universe
8 | import env_wrapper
9 | import ppaquette_gym_super_mario
10 | from ppaquette_gym_super_mario import wrappers
11 | import numpy as np
12 | import time
13 | from PIL import Image
14 | import utils
15 |
16 | outputdir = './gray42/'
17 | env_id = 'ppaquette/SuperMarioBros-1-1-v0'
18 | env = gym.make(env_id)
19 | modewrapper = wrappers.SetPlayingMode('algo')
20 | acwrapper = wrappers.ToDiscrete()
21 | env = modewrapper(acwrapper(env))
22 | env = env_wrapper.MarioEnv(env)
23 |
24 | freshape = fshape = (42, 42)
25 | env.seed(None)
26 | env = env_wrapper.NoNegativeRewardEnv(env)
27 | env = env_wrapper.DQNObsEnv(env, shape=freshape)
28 | env = env_wrapper.BufferedObsEnv(env, n=4, skip=1, shape=fshape, channel_last=True)
29 | env = env_wrapper.EltwiseScaleObsEnv(env)
30 |
31 | start = time.time()
32 | episodes = 0
33 | maxepisodes = 1
34 | env.reset()
35 | imCount = 1
36 | utils.mkdir_p(outputdir + '/ep_%02d/'%(episodes+1))
37 | while True:
38 | obs, reward, done, info = env.step(env.action_space.sample())
39 | Image.fromarray((255*obs).astype('uint8')).save(outputdir + '/ep_%02d/%06d.jpg'%(episodes+1,imCount))
40 | imCount += 1
41 | if done:
42 | episodes += 1
43 | print('Ep: %d, Distance: %d'%(episodes, info['distance']))
44 | if episodes >= maxepisodes:
45 | break
46 | env.reset()
47 | imCount = 1
48 | utils.mkdir_p(outputdir + '/ep_%02d/'%(episodes+1))
49 | end = time.time()
50 | print('\nTotal Time spent: %0.2f seconds'% (end-start))
51 | env.close()
52 | print('Done!')
53 | exit(0)
54 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | import tensorflow as tf
4 | import tensorflow.contrib.rnn as rnn
5 | from constants import constants
6 |
7 |
8 | def normalized_columns_initializer(std=1.0):
9 | def _initializer(shape, dtype=None, partition_info=None):
10 | out = np.random.randn(*shape).astype(np.float32)
11 | out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
12 | return tf.constant(out)
13 | return _initializer
14 |
15 |
16 | def cosineLoss(A, B, name):
17 | ''' A, B : (BatchSize, d) '''
18 | dotprod = tf.reduce_sum(tf.multiply(tf.nn.l2_normalize(A,1), tf.nn.l2_normalize(B,1)), 1)
19 | loss = 1-tf.reduce_mean(dotprod, name=name)
20 | return loss
21 |
22 |
23 | def flatten(x):
24 | return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
25 |
26 |
27 | def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None):
28 | with tf.variable_scope(name):
29 | stride_shape = [1, stride[0], stride[1], 1]
30 | filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]), num_filters]
31 |
32 | # there are "num input feature maps * filter height * filter width"
33 | # inputs to each hidden unit
34 | fan_in = np.prod(filter_shape[:3])
35 | # each unit in the lower layer receives a gradient from:
36 | # "num output feature maps * filter height * filter width" /
37 | # pooling size
38 | fan_out = np.prod(filter_shape[:2]) * num_filters
39 | # initialize weights with random weights
40 | w_bound = np.sqrt(6. / (fan_in + fan_out))
41 |
42 | w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound),
43 | collections=collections)
44 | b = tf.get_variable("b", [1, 1, 1, num_filters], initializer=tf.constant_initializer(0.0),
45 | collections=collections)
46 | return tf.nn.conv2d(x, w, stride_shape, pad) + b
47 |
48 |
49 | def deconv2d(x, out_shape, name, filter_size=(3, 3), stride=(1, 1), pad="SAME", dtype=tf.float32, collections=None, prevNumFeat=None):
50 | with tf.variable_scope(name):
51 | num_filters = out_shape[-1]
52 | prevNumFeat = int(x.get_shape()[3]) if prevNumFeat is None else prevNumFeat
53 | stride_shape = [1, stride[0], stride[1], 1]
54 | # transpose_filter : [height, width, out_channels, in_channels]
55 | filter_shape = [filter_size[0], filter_size[1], num_filters, prevNumFeat]
56 |
57 | # there are "num input feature maps * filter height * filter width"
58 | # inputs to each hidden unit
59 | fan_in = np.prod(filter_shape[:2]) * prevNumFeat
60 | # each unit in the lower layer receives a gradient from:
61 | # "num output feature maps * filter height * filter width"
62 | fan_out = np.prod(filter_shape[:3])
63 | # initialize weights with random weights
64 | w_bound = np.sqrt(6. / (fan_in + fan_out))
65 |
66 | w = tf.get_variable("W", filter_shape, dtype, tf.random_uniform_initializer(-w_bound, w_bound),
67 | collections=collections)
68 | b = tf.get_variable("b", [num_filters], initializer=tf.constant_initializer(0.0),
69 | collections=collections)
70 | deconv2d = tf.nn.conv2d_transpose(x, w, tf.pack(out_shape), stride_shape, pad)
71 | # deconv2d = tf.reshape(tf.nn.bias_add(deconv2d, b), deconv2d.get_shape())
72 | return deconv2d
73 |
74 |
75 | def linear(x, size, name, initializer=None, bias_init=0):
76 | w = tf.get_variable(name + "/w", [x.get_shape()[1], size], initializer=initializer)
77 | b = tf.get_variable(name + "/b", [size], initializer=tf.constant_initializer(bias_init))
78 | return tf.matmul(x, w) + b
79 |
80 |
81 | def categorical_sample(logits, d):
82 | value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1], keep_dims=True), 1), [1])
83 | return tf.one_hot(value, d)
84 |
85 |
86 | def inverseUniverseHead(x, final_shape, nConvs=4):
87 | ''' universe agent example
88 | input: [None, 288]; output: [None, 42, 42, 1];
89 | '''
90 | print('Using inverse-universe head design')
91 | bs = tf.shape(x)[0]
92 | deconv_shape1 = [final_shape[1]]
93 | deconv_shape2 = [final_shape[2]]
94 | for i in range(nConvs):
95 | deconv_shape1.append((deconv_shape1[-1]-1)/2 + 1)
96 | deconv_shape2.append((deconv_shape2[-1]-1)/2 + 1)
97 | inshapeprod = np.prod(x.get_shape().as_list()[1:]) / 32.0
98 | assert(inshapeprod == deconv_shape1[-1]*deconv_shape2[-1])
99 | # print('deconv_shape1: ',deconv_shape1)
100 | # print('deconv_shape2: ',deconv_shape2)
101 |
102 | x = tf.reshape(x, [-1, deconv_shape1[-1], deconv_shape2[-1], 32])
103 | deconv_shape1 = deconv_shape1[:-1]
104 | deconv_shape2 = deconv_shape2[:-1]
105 | for i in range(nConvs-1):
106 | x = tf.nn.elu(deconv2d(x, [bs, deconv_shape1[-1], deconv_shape2[-1], 32],
107 | "dl{}".format(i + 1), [3, 3], [2, 2], prevNumFeat=32))
108 | deconv_shape1 = deconv_shape1[:-1]
109 | deconv_shape2 = deconv_shape2[:-1]
110 | x = deconv2d(x, [bs] + final_shape[1:], "dl4", [3, 3], [2, 2], prevNumFeat=32)
111 | return x
112 |
113 |
114 | def universeHead(x, nConvs=4):
115 | ''' universe agent example
116 | input: [None, 42, 42, 1]; output: [None, 288];
117 | '''
118 | print('Using universe head design')
119 | for i in range(nConvs):
120 | x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
121 | # print('Loop{} '.format(i+1),tf.shape(x))
122 | # print('Loop{}'.format(i+1),x.get_shape())
123 | x = flatten(x)
124 | return x
125 |
126 |
127 | def nipsHead(x):
128 | ''' DQN NIPS 2013 and A3C paper
129 | input: [None, 84, 84, 4]; output: [None, 2592] -> [None, 256];
130 | '''
131 | print('Using nips head design')
132 | x = tf.nn.relu(conv2d(x, 16, "l1", [8, 8], [4, 4], pad="VALID"))
133 | x = tf.nn.relu(conv2d(x, 32, "l2", [4, 4], [2, 2], pad="VALID"))
134 | x = flatten(x)
135 | x = tf.nn.relu(linear(x, 256, "fc", normalized_columns_initializer(0.01)))
136 | return x
137 |
138 |
139 | def natureHead(x):
140 | ''' DQN Nature 2015 paper
141 | input: [None, 84, 84, 4]; output: [None, 3136] -> [None, 512];
142 | '''
143 | print('Using nature head design')
144 | x = tf.nn.relu(conv2d(x, 32, "l1", [8, 8], [4, 4], pad="VALID"))
145 | x = tf.nn.relu(conv2d(x, 64, "l2", [4, 4], [2, 2], pad="VALID"))
146 | x = tf.nn.relu(conv2d(x, 64, "l3", [3, 3], [1, 1], pad="VALID"))
147 | x = flatten(x)
148 | x = tf.nn.relu(linear(x, 512, "fc", normalized_columns_initializer(0.01)))
149 | return x
150 |
151 |
152 | def doomHead(x):
153 | ''' Learning by Prediction ICLR 2017 paper
154 |         (their final output was 64; changed to 256 here)
155 | input: [None, 120, 160, 1]; output: [None, 1280] -> [None, 256];
156 | '''
157 | print('Using doom head design')
158 | x = tf.nn.elu(conv2d(x, 8, "l1", [5, 5], [4, 4]))
159 | x = tf.nn.elu(conv2d(x, 16, "l2", [3, 3], [2, 2]))
160 | x = tf.nn.elu(conv2d(x, 32, "l3", [3, 3], [2, 2]))
161 | x = tf.nn.elu(conv2d(x, 64, "l4", [3, 3], [2, 2]))
162 | x = flatten(x)
163 | x = tf.nn.elu(linear(x, 256, "fc", normalized_columns_initializer(0.01)))
164 | return x
165 |
166 |
167 | class LSTMPolicy(object):
168 | def __init__(self, ob_space, ac_space, designHead='universe'):
169 | self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space), name='x')
170 | size = 256
171 | if designHead == 'nips':
172 | x = nipsHead(x)
173 | elif designHead == 'nature':
174 | x = natureHead(x)
175 | elif designHead == 'doom':
176 | x = doomHead(x)
177 | elif 'tile' in designHead:
178 | x = universeHead(x, nConvs=2)
179 | else:
180 | x = universeHead(x)
181 |
182 | # introduce a "fake" batch dimension of 1 to do LSTM over time dim
183 | x = tf.expand_dims(x, [0])
184 | lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
185 | self.state_size = lstm.state_size
186 | step_size = tf.shape(self.x)[:1]
187 |
188 | c_init = np.zeros((1, lstm.state_size.c), np.float32)
189 | h_init = np.zeros((1, lstm.state_size.h), np.float32)
190 | self.state_init = [c_init, h_init]
191 | c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c], name='c_in')
192 | h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h], name='h_in')
193 | self.state_in = [c_in, h_in]
194 |
195 | state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
196 | lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
197 | lstm, x, initial_state=state_in, sequence_length=step_size,
198 | time_major=False)
199 | lstm_c, lstm_h = lstm_state
200 | x = tf.reshape(lstm_outputs, [-1, size])
201 | self.vf = tf.reshape(linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
202 | self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
203 |
204 | # [0, :] means pick action of first state from batch. Hardcoded b/c
205 |         # batch=1 during rollout collection. It's not used during batch training.
206 | self.logits = linear(x, ac_space, "action", normalized_columns_initializer(0.01))
207 | self.sample = categorical_sample(self.logits, ac_space)[0, :]
208 | self.probs = tf.nn.softmax(self.logits, dim=-1)[0, :]
209 |
210 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
211 | # tf.add_to_collection('probs', self.probs)
212 | # tf.add_to_collection('sample', self.sample)
213 | # tf.add_to_collection('state_out_0', self.state_out[0])
214 | # tf.add_to_collection('state_out_1', self.state_out[1])
215 | # tf.add_to_collection('vf', self.vf)
216 |
217 | def get_initial_features(self):
218 |         # Call this function to get reset LSTM memory cells
219 | return self.state_init
220 |
221 | def act(self, ob, c, h):
222 | sess = tf.get_default_session()
223 | return sess.run([self.sample, self.vf] + self.state_out,
224 | {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})
225 |
226 | def act_inference(self, ob, c, h):
227 | sess = tf.get_default_session()
228 | return sess.run([self.probs, self.sample, self.vf] + self.state_out,
229 | {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})
230 |
231 | def value(self, ob, c, h):
232 | sess = tf.get_default_session()
233 | return sess.run(self.vf, {self.x: [ob], self.state_in[0]: c, self.state_in[1]: h})[0]
234 |
235 |
236 | class StateActionPredictor(object):
237 | def __init__(self, ob_space, ac_space, designHead='universe'):
238 |         # input: s1, s2: [None, h, w, ch] (usually ch=1 or 4)
239 | # asample: 1-hot encoding of sampled action from policy: [None, ac_space]
240 | input_shape = [None] + list(ob_space)
241 | self.s1 = phi1 = tf.placeholder(tf.float32, input_shape)
242 | self.s2 = phi2 = tf.placeholder(tf.float32, input_shape)
243 | self.asample = asample = tf.placeholder(tf.float32, [None, ac_space])
244 |
245 | # feature encoding: phi1, phi2: [None, LEN]
246 | size = 256
247 | if designHead == 'nips':
248 | phi1 = nipsHead(phi1)
249 | with tf.variable_scope(tf.get_variable_scope(), reuse=True):
250 | phi2 = nipsHead(phi2)
251 | elif designHead == 'nature':
252 | phi1 = natureHead(phi1)
253 | with tf.variable_scope(tf.get_variable_scope(), reuse=True):
254 | phi2 = natureHead(phi2)
255 | elif designHead == 'doom':
256 | phi1 = doomHead(phi1)
257 | with tf.variable_scope(tf.get_variable_scope(), reuse=True):
258 | phi2 = doomHead(phi2)
259 | elif 'tile' in designHead:
260 | phi1 = universeHead(phi1, nConvs=2)
261 | with tf.variable_scope(tf.get_variable_scope(), reuse=True):
262 | phi2 = universeHead(phi2, nConvs=2)
263 | else:
264 | phi1 = universeHead(phi1)
265 | with tf.variable_scope(tf.get_variable_scope(), reuse=True):
266 | phi2 = universeHead(phi2)
267 |
268 | # inverse model: g(phi1,phi2) -> a_inv: [None, ac_space]
269 | g = tf.concat(1,[phi1, phi2])
270 | g = tf.nn.relu(linear(g, size, "g1", normalized_columns_initializer(0.01)))
271 | aindex = tf.argmax(asample, axis=1) # aindex: [batch_size,]
272 | logits = linear(g, ac_space, "glast", normalized_columns_initializer(0.01))
273 | self.invloss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
274 | logits, aindex), name="invloss")
275 | self.ainvprobs = tf.nn.softmax(logits, dim=-1)
276 |
277 | # forward model: f(phi1,asample) -> phi2
278 | # Note: no backprop to asample of policy: it is treated as fixed for predictor training
279 | f = tf.concat(1, [phi1, asample])
280 | f = tf.nn.relu(linear(f, size, "f1", normalized_columns_initializer(0.01)))
281 | f = linear(f, phi1.get_shape()[1].value, "flast", normalized_columns_initializer(0.01))
282 | self.forwardloss = 0.5 * tf.reduce_mean(tf.square(tf.subtract(f, phi2)), name='forwardloss')
283 | # self.forwardloss = 0.5 * tf.reduce_mean(tf.sqrt(tf.abs(tf.subtract(f, phi2))), name='forwardloss')
284 | # self.forwardloss = cosineLoss(f, phi2, name='forwardloss')
285 | self.forwardloss = self.forwardloss * 288.0 # lenFeatures=288. Factored out to make hyperparams not depend on it.
286 |
287 | # variable list
288 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
289 |
290 | def pred_act(self, s1, s2):
291 | '''
292 | returns action probability distribution predicted by inverse model
293 | input: s1,s2: [h, w, ch]
294 | output: ainvprobs: [ac_space]
295 | '''
296 | sess = tf.get_default_session()
297 | return sess.run(self.ainvprobs, {self.s1: [s1], self.s2: [s2]})[0, :]
298 |
299 | def pred_bonus(self, s1, s2, asample):
300 | '''
301 | returns bonus predicted by forward model
302 | input: s1,s2: [h, w, ch], asample: [ac_space] 1-hot encoding
303 | output: scalar bonus
304 | '''
305 | sess = tf.get_default_session()
306 | # error = sess.run([self.forwardloss, self.invloss],
307 | # {self.s1: [s1], self.s2: [s2], self.asample: [asample]})
308 | # print('ErrorF: ', error[0], ' ErrorI:', error[1])
309 | error = sess.run(self.forwardloss,
310 | {self.s1: [s1], self.s2: [s2], self.asample: [asample]})
311 | error = error * constants['PREDICTION_BETA']
312 | return error
313 |
314 |
315 | class StatePredictor(object):
316 | '''
317 |     Loss is normalized across the spatial dimensions (42x42), but not across batches.
318 |     This is unlike ICM, where there is no normalization across the 288 feature
319 |     dimensions nor across batches.
320 | '''
321 |
322 | def __init__(self, ob_space, ac_space, designHead='universe', unsupType='state'):
323 |         # input: s1, s2: [None, h, w, ch] (usually ch=1 or 4)
324 | # asample: 1-hot encoding of sampled action from policy: [None, ac_space]
325 | input_shape = [None] + list(ob_space)
326 | self.s1 = phi1 = tf.placeholder(tf.float32, input_shape)
327 | self.s2 = phi2 = tf.placeholder(tf.float32, input_shape)
328 | self.asample = asample = tf.placeholder(tf.float32, [None, ac_space])
329 | self.stateAenc = unsupType == 'stateAenc'
330 |
331 | # feature encoding: phi1: [None, LEN]
332 | if designHead == 'universe':
333 | phi1 = universeHead(phi1)
334 | if self.stateAenc:
335 | with tf.variable_scope(tf.get_variable_scope(), reuse=True):
336 | phi2_aenc = universeHead(phi2)
337 | elif 'tile' in designHead: # for mario tiles
338 | phi1 = universeHead(phi1, nConvs=2)
339 | if self.stateAenc:
340 | with tf.variable_scope(tf.get_variable_scope(), reuse=True):
341 | phi2_aenc = universeHead(phi2)
342 | else:
343 | print('Only universe designHead implemented for state prediction baseline.')
344 | exit(1)
345 |
346 | # forward model: f(phi1,asample) -> phi2
347 | # Note: no backprop to asample of policy: it is treated as fixed for predictor training
348 | f = tf.concat(1, [phi1, asample])
349 | f = tf.nn.relu(linear(f, phi1.get_shape()[1].value, "f1", normalized_columns_initializer(0.01)))
350 | if 'tile' in designHead:
351 | f = inverseUniverseHead(f, input_shape, nConvs=2)
352 | else:
353 | f = inverseUniverseHead(f, input_shape)
354 | self.forwardloss = 0.5 * tf.reduce_mean(tf.square(tf.subtract(f, phi2)), name='forwardloss')
355 | if self.stateAenc:
356 | self.aencBonus = 0.5 * tf.reduce_mean(tf.square(tf.subtract(phi1, phi2_aenc)), name='aencBonus')
357 | self.predstate = phi1
358 |
359 | # variable list
360 | self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
361 |
362 | def pred_state(self, s1, asample):
363 | '''
364 | returns state predicted by forward model
365 | input: s1: [h, w, ch], asample: [ac_space] 1-hot encoding
366 | output: s2: [h, w, ch]
367 | '''
368 | sess = tf.get_default_session()
369 | return sess.run(self.predstate, {self.s1: [s1],
370 | self.asample: [asample]})[0, :]
371 |
372 | def pred_bonus(self, s1, s2, asample):
373 | '''
374 | returns bonus predicted by forward model
375 | input: s1,s2: [h, w, ch], asample: [ac_space] 1-hot encoding
376 | output: scalar bonus
377 | '''
378 | sess = tf.get_default_session()
379 | bonus = self.aencBonus if self.stateAenc else self.forwardloss
380 | error = sess.run(bonus,
381 | {self.s1: [s1], self.s2: [s2], self.asample: [asample]})
382 | # print('ErrorF: ', error)
383 | error = error * constants['PREDICTION_BETA']
384 | return error
385 |
--------------------------------------------------------------------------------
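A minimal usage sketch (not part of the repo) showing how LSTMPolicy and StateActionPredictor above fit together for a single curiosity-rewarded step. The dummy 42x42x1 observation, the scope names, and the action-space size are illustrative assumptions; a3c.py wires these classes up differently inside the full training loop.

import numpy as np
import tensorflow as tf
from model import LSTMPolicy, StateActionPredictor

ob_space, ac_space = (42, 42, 1), 4  # the universe head expects 42x42x1 observations

with tf.variable_scope("global"):
    policy = LSTMPolicy(ob_space, ac_space)
    with tf.variable_scope("predictor"):           # separate scope to avoid variable-name clashes
        predictor = StateActionPredictor(ob_space, ac_space)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())    # tf.initialize_all_variables() on pre-0.12 TF
    s1 = np.random.rand(*ob_space).astype(np.float32)   # stand-in for an env observation
    c, h = policy.get_initial_features()                 # fresh LSTM state
    action, value, c, h = policy.act(s1, c, h)           # one-hot action, value estimate, new LSTM state
    s2 = np.random.rand(*ob_space).astype(np.float32)    # would come from env.step(...)
    bonus = predictor.pred_bonus(s1, s2, action)         # forward-model error * PREDICTION_BETA
    print('intrinsic reward:', bonus)

pred_bonus() returns the forward-model prediction error scaled by PREDICTION_BETA from constants.py, which the training code uses as the intrinsic (curiosity) reward.
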
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | atari-py==0.1.1
2 | attrs==17.2.0
3 | autobahn==17.6.2
4 | Automat==0.6.0
5 | backports.ssl-match-hostname==3.5.0.1
6 | certifi==2017.4.17
7 | chardet==3.0.4
8 | constantly==15.1.0
9 | docker-py==1.10.3
10 | docker-pycreds==0.2.1
11 | doom-py==0.0.15
12 | fastzbarlight==0.0.14
13 | funcsigs==1.0.2
14 | -e git+https://github.com/openai/go-vncdriver.git@33bd0dd9620e97acd9b4e559bca217df09ba89e6#egg=go_vncdriver
15 | -e git+https://github.com/openai/gym.git@6f277090ed3323009a324ea31d00363afd8dfb3a#egg=gym
16 | -e git+https://github.com/pathak22/gym-pull.git@589039c29567c67fb3d5c0a315806419e0999415#egg=gym_pull
17 | hyperlink==17.2.1
18 | idna==2.5
19 | incremental==17.5.0
20 | ipaddress==1.0.18
21 | mock==2.0.0
22 | numpy==1.13.1
23 | olefile==0.44
24 | pbr==3.1.1
25 | Pillow==4.2.1
26 | ppaquette-gym-doom==0.0.3
27 | -e git+https://github.com/ppaquette/gym-super-mario.git@2e5ee823b6090af3f99b1f62c465fc4b033532f4#egg=ppaquette_gym_super_mario
28 | protobuf==3.1.0
29 | pyglet==1.2.4
30 | PyOpenGL==3.1.0
31 | PyYAML==3.12
32 | requests>=2.20.0
33 | scipy==0.19.1
34 | six==1.10.0
35 | tensorflow==0.12.0rc1
36 | Twisted==17.5.0
37 | txaio==2.8.0
38 | ujson==1.35
39 | -e git+https://github.com/openai/universe.git@e8037a103d8871a29396c39b2a58df439bde3380#egg=universe
40 | urllib3==1.21.1
41 | websocket-client==0.44.0
42 | zope.interface==4.4.2
43 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | from six.moves import shlex_quote
5 |
6 | parser = argparse.ArgumentParser(description="Run commands")
7 | parser.add_argument('-w', '--num-workers', default=20, type=int,
8 | help="Number of workers")
9 | parser.add_argument('-r', '--remotes', default=None,
10 | help='The address of pre-existing VNC servers and '
11 | 'rewarders to use (e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901).')
12 | parser.add_argument('-e', '--env-id', type=str, default="doom",
13 | help="Environment id")
14 | parser.add_argument('-l', '--log-dir', type=str, default="tmp/doom",
15 | help="Log directory path")
16 | parser.add_argument('-n', '--dry-run', action='store_true',
17 | help="Print out commands rather than executing them")
18 | parser.add_argument('-m', '--mode', type=str, default='tmux',
19 | help="tmux: run workers in a tmux session. nohup: run workers with nohup. child: run workers as child processes")
20 | parser.add_argument('--visualise', action='store_true',
21 | help="Visualise the gym environment by running env.render() between each timestep")
22 | parser.add_argument('--envWrap', action='store_true',
23 | help="Preprocess input in env_wrapper (no change in input size or network)")
24 | parser.add_argument('--designHead', type=str, default='universe',
25 |                     help="Network design head: nips, nature, doom, or universe (default)")
26 | parser.add_argument('--unsup', type=str, default=None,
27 | help="Unsup. exploration mode: action or state or stateAenc or None")
28 | parser.add_argument('--noReward', action='store_true', help="Remove all extrinsic reward")
29 | parser.add_argument('--noLifeReward', action='store_true',
30 |                     help="Remove all negative reward (in Doom, this is the living reward)")
31 | parser.add_argument('--expName', type=str, default='a3c',
32 | help="Experiment tmux session-name. Default a3c.")
33 | parser.add_argument('--expId', type=int, default=0,
34 |                     help="Experiment Id >=0. Needed when running more than one run per machine.")
35 | parser.add_argument('--savio', action='store_true',
36 | help="Savio or KNL cpu cluster hacks")
37 | parser.add_argument('--default', action='store_true', help="run with default params")
38 | parser.add_argument('--pretrain', type=str, default=None, help="Checkpoint dir (generally ..../train/) to load from.")
39 |
40 | def new_cmd(session, name, cmd, mode, logdir, shell):
41 | if isinstance(cmd, (list, tuple)):
42 | cmd = " ".join(shlex_quote(str(v)) for v in cmd)
43 | if mode == 'tmux':
44 | return name, "tmux send-keys -t {}:{} {} Enter".format(session, name, shlex_quote(cmd))
45 | elif mode == 'child':
46 | return name, "{} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(cmd, logdir, session, name, logdir)
47 | elif mode == 'nohup':
48 | return name, "nohup {} -c {} >{}/{}.{}.out 2>&1 & echo kill $! >>{}/kill.sh".format(shell, shlex_quote(cmd), logdir, session, name, logdir)
49 |
50 |
51 | def create_commands(session, num_workers, remotes, env_id, logdir, shell='bash',
52 | mode='tmux', visualise=False, envWrap=False, designHead=None,
53 | unsup=None, noReward=False, noLifeReward=False, psPort=12222,
54 | delay=0, savio=False, pretrain=None):
55 | # for launching the TF workers and for launching tensorboard
56 | py_cmd = 'python' if savio else sys.executable
57 | base_cmd = [
58 | 'CUDA_VISIBLE_DEVICES=',
59 | py_cmd, 'worker.py',
60 | '--log-dir', logdir,
61 | '--env-id', env_id,
62 | '--num-workers', str(num_workers),
63 | '--psPort', psPort]
64 |
65 | if delay > 0:
66 | base_cmd += ['--delay', delay]
67 | if visualise:
68 | base_cmd += ['--visualise']
69 | if envWrap:
70 | base_cmd += ['--envWrap']
71 | if designHead is not None:
72 | base_cmd += ['--designHead', designHead]
73 | if unsup is not None:
74 | base_cmd += ['--unsup', unsup]
75 | if noReward:
76 | base_cmd += ['--noReward']
77 | if noLifeReward:
78 | base_cmd += ['--noLifeReward']
79 | if pretrain is not None:
80 | base_cmd += ['--pretrain', pretrain]
81 |
82 | if remotes is None:
83 | remotes = ["1"] * num_workers
84 | else:
85 | remotes = remotes.split(',')
86 | assert len(remotes) == num_workers
87 |
88 | cmds_map = [new_cmd(session, "ps", base_cmd + ["--job-name", "ps"], mode, logdir, shell)]
89 | for i in range(num_workers):
90 | cmds_map += [new_cmd(session,
91 | "w-%d" % i, base_cmd + ["--job-name", "worker", "--task", str(i), "--remotes", remotes[i]], mode, logdir, shell)]
92 |
93 | # No tensorboard or htop window if running multiple experiments per machine
94 | if session == 'a3c':
95 | cmds_map += [new_cmd(session, "tb", ["tensorboard", "--logdir", logdir, "--port", "12345"], mode, logdir, shell)]
96 | if session == 'a3c' and mode == 'tmux':
97 | cmds_map += [new_cmd(session, "htop", ["htop"], mode, logdir, shell)]
98 |
99 | windows = [v[0] for v in cmds_map]
100 |
101 | notes = []
102 | cmds = [
103 | "mkdir -p {}".format(logdir),
104 | "echo {} {} > {}/cmd.sh".format(sys.executable, ' '.join([shlex_quote(arg) for arg in sys.argv if arg != '-n']), logdir),
105 | ]
106 | if mode == 'nohup' or mode == 'child':
107 | cmds += ["echo '#!/bin/sh' >{}/kill.sh".format(logdir)]
108 | notes += ["Run `source {}/kill.sh` to kill the job".format(logdir)]
109 | if mode == 'tmux':
110 | notes += ["Use `tmux attach -t {}` to watch process output".format(session)]
111 | notes += ["Use `tmux kill-session -t {}` to kill the job".format(session)]
112 | else:
113 | notes += ["Use `tail -f {}/*.out` to watch process output".format(logdir)]
114 | notes += ["Point your browser to http://localhost:12345 to see Tensorboard"]
115 |
116 | if mode == 'tmux':
117 | cmds += [
118 | "kill -9 $( lsof -i:12345 -t ) > /dev/null 2>&1", # kill any process using tensorboard's port
119 | "kill -9 $( lsof -i:{}-{} -t ) > /dev/null 2>&1".format(psPort, num_workers+psPort), # kill any processes using ps / worker ports
120 | "tmux kill-session -t {}".format(session),
121 | "tmux new-session -s {} -n {} -d {}".format(session, windows[0], shell)
122 | ]
123 | for w in windows[1:]:
124 | cmds += ["tmux new-window -t {} -n {} {}".format(session, w, shell)]
125 | cmds += ["sleep 1"]
126 | for window, cmd in cmds_map:
127 | cmds += [cmd]
128 |
129 | return cmds, notes
130 |
131 |
132 | def run():
133 | args = parser.parse_args()
134 | if args.default:
135 | args.envWrap = True
136 | args.savio = True
137 | args.noLifeReward = True
138 | args.unsup = 'action'
139 |
140 | # handling nuances of running multiple jobs per-machine
141 | psPort = 12222 + 50*args.expId
142 | delay = 220*args.expId if 'doom' in args.env_id.lower() or 'labyrinth' in args.env_id.lower() else 5*args.expId
143 | delay = 6*delay if 'mario' in args.env_id else delay
144 |
145 | cmds, notes = create_commands(args.expName, args.num_workers, args.remotes, args.env_id,
146 | args.log_dir, mode=args.mode, visualise=args.visualise,
147 | envWrap=args.envWrap, designHead=args.designHead,
148 | unsup=args.unsup, noReward=args.noReward,
149 | noLifeReward=args.noLifeReward, psPort=psPort,
150 | delay=delay, savio=args.savio, pretrain=args.pretrain)
151 | if args.dry_run:
152 | print("Dry-run mode due to -n flag, otherwise the following commands would be executed:")
153 | else:
154 | print("Executing the following commands:")
155 | print("\n".join(cmds))
156 | print("")
157 | if not args.dry_run:
158 | if args.mode == "tmux":
159 | os.environ["TMUX"] = ""
160 | os.system("\n".join(cmds))
161 | print('\n'.join(notes))
162 |
163 |
164 | if __name__ == "__main__":
165 | run()
166 |
--------------------------------------------------------------------------------
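A small sketch (hypothetical session name, worker count, and log directory) of calling create_commands() above directly to preview, without executing, the tmux commands that train.py would launch for a curiosity ('action') run:

from train import create_commands  # importing train.py only defines the parser; nothing is executed

cmds, notes = create_commands('a3c', num_workers=2, remotes=None, env_id='doom',
                              logdir='tmp/doom', mode='tmux', envWrap=True,
                              designHead='universe', unsup='action', noLifeReward=True)
print('\n'.join(cmds))    # tmux new-session / new-window / send-keys commands for ps, w-0, w-1, tb, htop
print('\n'.join(notes))   # e.g. how to attach to or kill the tmux session
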
/src/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import os
4 | import errno
5 |
6 |
7 | def mkdir_p(path):
8 | """
9 |     Creates the directory recursively if it does not already exist (like 'mkdir -p').
10 | """
11 | try:
12 | os.makedirs(path)
13 | except OSError as exc:
14 | if exc.errno == errno.EEXIST and os.path.isdir(path):
15 | pass
16 | else:
17 | raise
18 |
--------------------------------------------------------------------------------
/src/worker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import go_vncdriver
3 | import tensorflow as tf
4 | import argparse
5 | import logging
6 | import sys, signal
7 | import time
8 | import os
9 | from a3c import A3C
10 | from envs import create_env
11 | from constants import constants
12 | import distutils.version
13 | use_tf12_api = distutils.version.LooseVersion(tf.VERSION) >= distutils.version.LooseVersion('0.12.0')
14 |
15 | logger = logging.getLogger(__name__)
16 | logger.setLevel(logging.INFO)
17 |
18 | # Disables the write_meta_graph argument, which freezes the entire process and is mostly useless.
19 | class FastSaver(tf.train.Saver):
20 | def save(self, sess, save_path, global_step=None, latest_filename=None,
21 | meta_graph_suffix="meta", write_meta_graph=True):
22 | super(FastSaver, self).save(sess, save_path, global_step, latest_filename,
23 | meta_graph_suffix, False)
24 |
25 | def run(args, server):
26 | env = create_env(args.env_id, client_id=str(args.task), remotes=args.remotes, envWrap=args.envWrap, designHead=args.designHead,
27 | noLifeReward=args.noLifeReward)
28 | trainer = A3C(env, args.task, args.visualise, args.unsup, args.envWrap, args.designHead, args.noReward)
29 |
30 | # logging
31 | if args.task == 0:
32 | with open(args.log_dir + '/log.txt', 'w') as fid:
33 | for key, val in constants.items():
34 | fid.write('%s: %s\n'%(str(key), str(val)))
35 | fid.write('designHead: %s\n'%args.designHead)
36 | fid.write('input observation: %s\n'%str(env.observation_space.shape))
37 | fid.write('env name: %s\n'%str(env.spec.id))
38 | fid.write('unsup method type: %s\n'%str(args.unsup))
39 |
40 | # Variable names that start with "local" are not saved in checkpoints.
41 | if use_tf12_api:
42 | variables_to_save = [v for v in tf.global_variables() if not v.name.startswith("local")]
43 | init_op = tf.variables_initializer(variables_to_save)
44 | init_all_op = tf.global_variables_initializer()
45 | else:
46 | variables_to_save = [v for v in tf.all_variables() if not v.name.startswith("local")]
47 | init_op = tf.initialize_variables(variables_to_save)
48 | init_all_op = tf.initialize_all_variables()
49 | saver = FastSaver(variables_to_save)
50 | if args.pretrain is not None:
51 | variables_to_restore = [v for v in tf.trainable_variables() if not v.name.startswith("local")]
52 | pretrain_saver = FastSaver(variables_to_restore)
53 |
54 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
55 | logger.info('Trainable vars:')
56 | for v in var_list:
57 | logger.info(' %s %s', v.name, v.get_shape())
58 |
59 | def init_fn(ses):
60 | logger.info("Initializing all parameters.")
61 | ses.run(init_all_op)
62 | if args.pretrain is not None:
63 | pretrain = tf.train.latest_checkpoint(args.pretrain)
64 | logger.info("==> Restoring from given pretrained checkpoint.")
65 | logger.info(" Pretraining address: %s", pretrain)
66 | pretrain_saver.restore(ses, pretrain)
67 | logger.info("==> Done restoring model! Restored %d variables.", len(variables_to_restore))
68 |
69 | config = tf.ConfigProto(device_filters=["/job:ps", "/job:worker/task:{}/cpu:0".format(args.task)])
70 | logdir = os.path.join(args.log_dir, 'train')
71 |
72 | if use_tf12_api:
73 | summary_writer = tf.summary.FileWriter(logdir + "_%d" % args.task)
74 | else:
75 | summary_writer = tf.train.SummaryWriter(logdir + "_%d" % args.task)
76 |
77 | logger.info("Events directory: %s_%s", logdir, args.task)
78 | sv = tf.train.Supervisor(is_chief=(args.task == 0),
79 | logdir=logdir,
80 | saver=saver,
81 | summary_op=None,
82 | init_op=init_op,
83 | init_fn=init_fn,
84 | summary_writer=summary_writer,
85 | ready_op=tf.report_uninitialized_variables(variables_to_save),
86 | global_step=trainer.global_step,
87 | save_model_secs=30,
88 | save_summaries_secs=30)
89 |
90 | num_global_steps = constants['MAX_GLOBAL_STEPS']
91 |
92 | logger.info(
93 |         "Starting session. If this hangs, we're most likely waiting to connect to the parameter server. " +
94 | "One common cause is that the parameter server DNS name isn't resolving yet, or is misspecified.")
95 | with sv.managed_session(server.target, config=config) as sess, sess.as_default():
96 | # Workaround for FailedPreconditionError
97 | # see: https://github.com/openai/universe-starter-agent/issues/44 and 31
98 | sess.run(trainer.sync)
99 |
100 | trainer.start(sess, summary_writer)
101 | global_step = sess.run(trainer.global_step)
102 |         logger.info("Starting training at global_step=%d", global_step)
103 | while not sv.should_stop() and (not num_global_steps or global_step < num_global_steps):
104 | trainer.process(sess)
105 | global_step = sess.run(trainer.global_step)
106 |
107 | # Ask for all the services to stop.
108 | sv.stop()
109 | logger.info('reached %s steps. worker stopped.', global_step)
110 |
111 | def cluster_spec(num_workers, num_ps, port=12222):
112 | """
113 |     Builds the TensorFlow cluster specification (parameter-server and worker addresses) for data parallelism
114 | """
115 | cluster = {}
116 |
117 | all_ps = []
118 | host = '127.0.0.1'
119 | for _ in range(num_ps):
120 | all_ps.append('{}:{}'.format(host, port))
121 | port += 1
122 | cluster['ps'] = all_ps
123 |
124 | all_workers = []
125 | for _ in range(num_workers):
126 | all_workers.append('{}:{}'.format(host, port))
127 | port += 1
128 | cluster['worker'] = all_workers
129 | return cluster
130 |
131 | def main(_):
132 | """
133 | Setting up Tensorflow for data parallel work
134 | """
135 |
136 | parser = argparse.ArgumentParser(description=None)
137 | parser.add_argument('-v', '--verbose', action='count', dest='verbosity', default=0, help='Set verbosity.')
138 | parser.add_argument('--task', default=0, type=int, help='Task index')
139 | parser.add_argument('--job-name', default="worker", help='worker or ps')
140 | parser.add_argument('--num-workers', default=1, type=int, help='Number of workers')
141 | parser.add_argument('--log-dir', default="tmp/doom", help='Log directory path')
142 | parser.add_argument('--env-id', default="doom", help='Environment id')
143 | parser.add_argument('-r', '--remotes', default=None,
144 | help='References to environments to create (e.g. -r 20), '
145 | 'or the address of pre-existing VNC servers and '
146 | 'rewarders to use (e.g. -r vnc://localhost:5900+15900,vnc://localhost:5901+15901)')
147 | parser.add_argument('--visualise', action='store_true',
148 | help="Visualise the gym environment by running env.render() between each timestep")
149 | parser.add_argument('--envWrap', action='store_true',
150 | help="Preprocess input in env_wrapper (no change in input size or network)")
151 | parser.add_argument('--designHead', type=str, default='universe',
152 |                         help="Network design head: nips, nature, doom, or universe (default)")
153 | parser.add_argument('--unsup', type=str, default=None,
154 | help="Unsup. exploration mode: action or state or stateAenc or None")
155 | parser.add_argument('--noReward', action='store_true', help="Remove all extrinsic reward")
156 | parser.add_argument('--noLifeReward', action='store_true',
157 |                         help="Remove all negative reward (in Doom, this is the living reward)")
158 | parser.add_argument('--psPort', default=12222, type=int, help='Port number for parameter server')
159 | parser.add_argument('--delay', default=0, type=int, help='delay start by these many seconds')
160 | parser.add_argument('--pretrain', type=str, default=None, help="Checkpoint dir (generally ..../train/) to load from.")
161 | args = parser.parse_args()
162 |
163 | spec = cluster_spec(args.num_workers, 1, args.psPort)
164 | cluster = tf.train.ClusterSpec(spec).as_cluster_def()
165 |
166 | def shutdown(signal, frame):
167 | logger.warn('Received signal %s: exiting', signal)
168 | sys.exit(128+signal)
169 | signal.signal(signal.SIGHUP, shutdown)
170 | signal.signal(signal.SIGINT, shutdown)
171 | signal.signal(signal.SIGTERM, shutdown)
172 |
173 | if args.job_name == "worker":
174 | server = tf.train.Server(cluster, job_name="worker", task_index=args.task,
175 | config=tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
176 | if args.delay > 0:
177 | print('Startup delay in worker: {}s'.format(args.delay))
178 | time.sleep(args.delay)
179 |             print('.. wait over!')
180 | run(args, server)
181 | else:
182 | server = tf.train.Server(cluster, job_name="ps", task_index=args.task,
183 | config=tf.ConfigProto(device_filters=["/job:ps"]))
184 | while True:
185 | time.sleep(1000)
186 |
187 | if __name__ == "__main__":
188 | tf.app.run()
189 |
--------------------------------------------------------------------------------
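For illustration, the address layout that cluster_spec() above produces for two workers and one parameter server on the default port (importing worker.py pulls in the pinned dependencies such as go_vncdriver and TensorFlow):

from worker import cluster_spec

spec = cluster_spec(num_workers=2, num_ps=1, port=12222)
# spec == {'ps': ['127.0.0.1:12222'],
#          'worker': ['127.0.0.1:12223', '127.0.0.1:12224']}
print(spec)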