├── .gitignore ├── README.md ├── car_racing.py ├── generate-rollouts.ipynb ├── generate-rollouts.py ├── pushover.py ├── utils.py ├── vae-cnn.ipynb ├── vae.ipynb └── vae.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | __pycache__/ 3 | 4 | data/ 5 | 6 | rollouts/ 7 | 8 | *.zip 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## A CNN Variational Autoencoder in PyTorch 2 | 3 | Implemetation of David Ha's research on [Word Model](https://arxiv.org/pdf/1803.10122.pdf). 4 | 5 | -------------------------------------------------------------------------------- /car_racing.py: -------------------------------------------------------------------------------- 1 | import sys, math 2 | import numpy as np 3 | 4 | import Box2D 5 | from Box2D.b2 import (edgeShape, circleShape, fixtureDef, polygonShape, revoluteJointDef, contactListener) 6 | 7 | import gym 8 | from gym import spaces 9 | from gym.envs.box2d.car_dynamics import Car 10 | from gym.utils import colorize, seeding 11 | 12 | import pyglet 13 | from pyglet import gl 14 | 15 | # Easiest continuous control task to learn from pixels, a top-down racing environment. 16 | # Discreet control is reasonable in this environment as well, on/off discretisation is 17 | # fine. 18 | # 19 | # State consists of STATE_W x STATE_H pixels. 20 | # 21 | # Reward is -0.1 every frame and +1000/N for every track tile visited, where N is 22 | # the total number of tiles in track. For example, if you have finished in 732 frames, 23 | # your reward is 1000 - 0.1*732 = 926.8 points. 24 | # 25 | # Game is solved when agent consistently gets 900+ points. Track is random every episode. 26 | # 27 | # Episode finishes when all tiles are visited. Car also can go outside of PLAYFIELD, that 28 | # is far off the track, then it will get -100 and die. 29 | # 30 | # Some indicators shown at the bottom of the window and the state RGB buffer. From 31 | # left to right: true speed, four ABS sensors, steering wheel position, gyroscope. 32 | # 33 | # To play yourself (it's rather fast for humans), type: 34 | # 35 | # python gym/envs/box2d/car_racing.py 36 | # 37 | # Remember it's powerful rear-wheel drive car, don't press accelerator and turn at the 38 | # same time. 39 | # 40 | # Created by Oleg Klimov. Licensed on the same terms as the rest of OpenAI Gym. 41 | 42 | STATE_W = 512 # less than Atari 160x192 43 | STATE_H = 512 44 | VIDEO_W = 600 45 | VIDEO_H = 400 46 | WINDOW_W = 1200 47 | WINDOW_H = 1000 48 | 49 | SCALE = 6.0 # Track scale 50 | TRACK_RAD = 900/SCALE # Track is heavily morphed circle with this radius 51 | PLAYFIELD = 2000/SCALE # Game over boundary 52 | FPS = 50 53 | ZOOM = 2.7 # Camera zoom 54 | ZOOM_FOLLOW = True # Set to False for fixed view (don't use zoom) 55 | 56 | 57 | TRACK_DETAIL_STEP = 21/SCALE 58 | TRACK_TURN_RATE = 0.31 59 | TRACK_WIDTH = 40/SCALE 60 | BORDER = 8/SCALE 61 | BORDER_MIN_COUNT = 4 62 | 63 | ROAD_COLOR = [0.4, 0.4, 0.4] 64 | 65 | class FrictionDetector(contactListener): 66 | def __init__(self, env): 67 | contactListener.__init__(self) 68 | self.env = env 69 | def BeginContact(self, contact): 70 | self._contact(contact, True) 71 | def EndContact(self, contact): 72 | self._contact(contact, False) 73 | def _contact(self, contact, begin): 74 | tile = None 75 | obj = None 76 | u1 = contact.fixtureA.body.userData 77 | u2 = contact.fixtureB.body.userData 78 | if u1 and "road_friction" in u1.__dict__: 79 | tile = u1 80 | obj = u2 81 | if u2 and "road_friction" in u2.__dict__: 82 | tile = u2 83 | obj = u1 84 | if not tile: return 85 | 86 | tile.color[0] = ROAD_COLOR[0] 87 | tile.color[1] = ROAD_COLOR[1] 88 | tile.color[2] = ROAD_COLOR[2] 89 | if not obj or "tiles" not in obj.__dict__: return 90 | if begin: 91 | obj.tiles.add(tile) 92 | #print tile.road_friction, "ADD", len(obj.tiles) 93 | if not tile.road_visited: 94 | tile.road_visited = True 95 | self.env.reward += 1000.0/len(self.env.track) 96 | self.env.tile_visited_count += 1 97 | else: 98 | obj.tiles.remove(tile) 99 | #print tile.road_friction, "DEL", len(obj.tiles) -- should delete to zero when on grass (this works) 100 | 101 | class CarRacing(gym.Env): 102 | metadata = { 103 | 'render.modes': ['human', 'rgb_array', 'state_pixels'], 104 | 'video.frames_per_second' : FPS 105 | } 106 | 107 | def __init__(self): 108 | self.seed() 109 | self.contactListener_keepref = FrictionDetector(self) 110 | self.world = Box2D.b2World((0,0), contactListener=self.contactListener_keepref) 111 | self.viewer = None 112 | self.invisible_state_window = None 113 | self.invisible_video_window = None 114 | self.road = None 115 | self.car = None 116 | self.reward = 0.0 117 | self.prev_reward = 0.0 118 | 119 | self.action_space = spaces.Box( np.array([-1,0,0]), np.array([+1,+1,+1])) # steer, gas, brake 120 | self.observation_space = spaces.Box(low=0, high=255, shape=(STATE_H, STATE_W, 3), dtype=np.uint8) 121 | 122 | def seed(self, seed=None): 123 | self.np_random, seed = seeding.np_random(seed) 124 | return [seed] 125 | 126 | def _destroy(self): 127 | if not self.road: return 128 | for t in self.road: 129 | self.world.DestroyBody(t) 130 | self.road = [] 131 | self.car.destroy() 132 | 133 | def _create_track(self): 134 | CHECKPOINTS = 12 135 | 136 | # Create checkpoints 137 | checkpoints = [] 138 | for c in range(CHECKPOINTS): 139 | alpha = 2*math.pi*c/CHECKPOINTS + self.np_random.uniform(0, 2*math.pi*1/CHECKPOINTS) 140 | rad = self.np_random.uniform(TRACK_RAD/3, TRACK_RAD) 141 | if c==0: 142 | alpha = 0 143 | rad = 1.5*TRACK_RAD 144 | if c==CHECKPOINTS-1: 145 | alpha = 2*math.pi*c/CHECKPOINTS 146 | self.start_alpha = 2*math.pi*(-0.5)/CHECKPOINTS 147 | rad = 1.5*TRACK_RAD 148 | checkpoints.append( (alpha, rad*math.cos(alpha), rad*math.sin(alpha)) ) 149 | 150 | #print "\n".join(str(h) for h in checkpoints) 151 | #self.road_poly = [ ( # uncomment this to see checkpoints 152 | # [ (tx,ty) for a,tx,ty in checkpoints ], 153 | # (0.7,0.7,0.9) ) ] 154 | self.road = [] 155 | 156 | # Go from one checkpoint to another to create track 157 | x, y, beta = 1.5*TRACK_RAD, 0, 0 158 | dest_i = 0 159 | laps = 0 160 | track = [] 161 | no_freeze = 2500 162 | visited_other_side = False 163 | while 1: 164 | alpha = math.atan2(y, x) 165 | if visited_other_side and alpha > 0: 166 | laps += 1 167 | visited_other_side = False 168 | if alpha < 0: 169 | visited_other_side = True 170 | alpha += 2*math.pi 171 | while True: # Find destination from checkpoints 172 | failed = True 173 | while True: 174 | dest_alpha, dest_x, dest_y = checkpoints[dest_i % len(checkpoints)] 175 | if alpha <= dest_alpha: 176 | failed = False 177 | break 178 | dest_i += 1 179 | if dest_i % len(checkpoints) == 0: break 180 | if not failed: break 181 | alpha -= 2*math.pi 182 | continue 183 | r1x = math.cos(beta) 184 | r1y = math.sin(beta) 185 | p1x = -r1y 186 | p1y = r1x 187 | dest_dx = dest_x - x # vector towards destination 188 | dest_dy = dest_y - y 189 | proj = r1x*dest_dx + r1y*dest_dy # destination vector projected on rad 190 | while beta - alpha > 1.5*math.pi: beta -= 2*math.pi 191 | while beta - alpha < -1.5*math.pi: beta += 2*math.pi 192 | prev_beta = beta 193 | proj *= SCALE 194 | if proj > 0.3: beta -= min(TRACK_TURN_RATE, abs(0.001*proj)) 195 | if proj < -0.3: beta += min(TRACK_TURN_RATE, abs(0.001*proj)) 196 | x += p1x*TRACK_DETAIL_STEP 197 | y += p1y*TRACK_DETAIL_STEP 198 | track.append( (alpha,prev_beta*0.5 + beta*0.5,x,y) ) 199 | if laps > 4: break 200 | no_freeze -= 1 201 | if no_freeze==0: break 202 | #print "\n".join([str(t) for t in enumerate(track)]) 203 | 204 | # Find closed loop range i1..i2, first loop should be ignored, second is OK 205 | i1, i2 = -1, -1 206 | i = len(track) 207 | while True: 208 | i -= 1 209 | if i==0: return False # Failed 210 | pass_through_start = track[i][0] > self.start_alpha and track[i-1][0] <= self.start_alpha 211 | if pass_through_start and i2==-1: 212 | i2 = i 213 | elif pass_through_start and i1==-1: 214 | i1 = i 215 | break 216 | # print("Track generation: %i..%i -> %i-tiles track" % (i1, i2, i2-i1)) 217 | assert i1!=-1 218 | assert i2!=-1 219 | 220 | track = track[i1:i2-1] 221 | 222 | first_beta = track[0][1] 223 | first_perp_x = math.cos(first_beta) 224 | first_perp_y = math.sin(first_beta) 225 | # Length of perpendicular jump to put together head and tail 226 | well_glued_together = np.sqrt( 227 | np.square( first_perp_x*(track[0][2] - track[-1][2]) ) + 228 | np.square( first_perp_y*(track[0][3] - track[-1][3]) )) 229 | if well_glued_together > TRACK_DETAIL_STEP: 230 | return False 231 | 232 | # Red-white border on hard turns 233 | border = [False]*len(track) 234 | for i in range(len(track)): 235 | good = True 236 | oneside = 0 237 | for neg in range(BORDER_MIN_COUNT): 238 | beta1 = track[i-neg-0][1] 239 | beta2 = track[i-neg-1][1] 240 | good &= abs(beta1 - beta2) > TRACK_TURN_RATE*0.2 241 | oneside += np.sign(beta1 - beta2) 242 | good &= abs(oneside) == BORDER_MIN_COUNT 243 | border[i] = good 244 | for i in range(len(track)): 245 | for neg in range(BORDER_MIN_COUNT): 246 | border[i-neg] |= border[i] 247 | 248 | # Create tiles 249 | for i in range(len(track)): 250 | alpha1, beta1, x1, y1 = track[i] 251 | alpha2, beta2, x2, y2 = track[i-1] 252 | road1_l = (x1 - TRACK_WIDTH*math.cos(beta1), y1 - TRACK_WIDTH*math.sin(beta1)) 253 | road1_r = (x1 + TRACK_WIDTH*math.cos(beta1), y1 + TRACK_WIDTH*math.sin(beta1)) 254 | road2_l = (x2 - TRACK_WIDTH*math.cos(beta2), y2 - TRACK_WIDTH*math.sin(beta2)) 255 | road2_r = (x2 + TRACK_WIDTH*math.cos(beta2), y2 + TRACK_WIDTH*math.sin(beta2)) 256 | t = self.world.CreateStaticBody( fixtures = fixtureDef( 257 | shape=polygonShape(vertices=[road1_l, road1_r, road2_r, road2_l]) 258 | )) 259 | t.userData = t 260 | c = 0.01*(i%3) 261 | t.color = [ROAD_COLOR[0] + c, ROAD_COLOR[1] + c, ROAD_COLOR[2] + c] 262 | t.road_visited = False 263 | t.road_friction = 1.0 264 | t.fixtures[0].sensor = True 265 | self.road_poly.append(( [road1_l, road1_r, road2_r, road2_l], t.color )) 266 | self.road.append(t) 267 | if border[i]: 268 | side = np.sign(beta2 - beta1) 269 | b1_l = (x1 + side* TRACK_WIDTH *math.cos(beta1), y1 + side* TRACK_WIDTH *math.sin(beta1)) 270 | b1_r = (x1 + side*(TRACK_WIDTH+BORDER)*math.cos(beta1), y1 + side*(TRACK_WIDTH+BORDER)*math.sin(beta1)) 271 | b2_l = (x2 + side* TRACK_WIDTH *math.cos(beta2), y2 + side* TRACK_WIDTH *math.sin(beta2)) 272 | b2_r = (x2 + side*(TRACK_WIDTH+BORDER)*math.cos(beta2), y2 + side*(TRACK_WIDTH+BORDER)*math.sin(beta2)) 273 | self.road_poly.append(( [b1_l, b1_r, b2_r, b2_l], (1,1,1) if i%2==0 else (1,0,0) )) 274 | self.track = track 275 | return True 276 | 277 | def reset(self): 278 | self._destroy() 279 | self.reward = 0.0 280 | self.prev_reward = 0.0 281 | self.tile_visited_count = 0 282 | self.t = 0.0 283 | self.road_poly = [] 284 | self.human_render = False 285 | 286 | while True: 287 | success = self._create_track() 288 | if success: break 289 | print("retry to generate track (normal if there are not many of this messages)") 290 | self.car = Car(self.world, *self.track[0][1:4]) 291 | 292 | return self.step(None)[0] 293 | 294 | def step(self, action): 295 | if action is not None: 296 | self.car.steer(-action[0]) 297 | self.car.gas(action[1]) 298 | self.car.brake(action[2]) 299 | 300 | self.car.step(1.0/FPS) 301 | self.world.Step(1.0/FPS, 6*30, 2*30) 302 | self.t += 1.0/FPS 303 | 304 | self.state = self.render("state_pixels") 305 | 306 | step_reward = 0 307 | done = False 308 | if action is not None: # First step without action, called from reset() 309 | self.reward -= 0.1 310 | # We actually don't want to count fuel spent, we want car to be faster. 311 | #self.reward -= 10 * self.car.fuel_spent / ENGINE_POWER 312 | self.car.fuel_spent = 0.0 313 | step_reward = self.reward - self.prev_reward 314 | self.prev_reward = self.reward 315 | if self.tile_visited_count==len(self.track): 316 | done = True 317 | x, y = self.car.hull.position 318 | if abs(x) > PLAYFIELD or abs(y) > PLAYFIELD: 319 | done = True 320 | step_reward = -100 321 | 322 | return self.state, step_reward, done, {} 323 | 324 | def render(self, mode='human'): 325 | if self.viewer is None: 326 | from gym.envs.classic_control import rendering 327 | self.viewer = rendering.Viewer(WINDOW_W, WINDOW_H) 328 | self.score_label = pyglet.text.Label('0000', font_size=36, 329 | x=20, y=WINDOW_H*2.5/40.00, anchor_x='left', anchor_y='center', 330 | color=(255,255,255,255)) 331 | self.transform = rendering.Transform() 332 | 333 | if "t" not in self.__dict__: return # reset() not called yet 334 | 335 | zoom = ZOOM*SCALE # Animate zoom first second 336 | zoom_state = ZOOM*SCALE*STATE_W/WINDOW_W 337 | zoom_video = ZOOM*SCALE*VIDEO_W/WINDOW_W 338 | scroll_x = self.car.hull.position[0] 339 | scroll_y = self.car.hull.position[1] 340 | angle = -self.car.hull.angle 341 | vel = self.car.hull.linearVelocity 342 | if np.linalg.norm(vel) > 0.5: 343 | angle = math.atan2(vel[0], vel[1]) 344 | self.transform.set_scale(zoom, zoom) 345 | self.transform.set_translation( 346 | WINDOW_W/2 - (scroll_x*zoom*math.cos(angle) - scroll_y*zoom*math.sin(angle)), 347 | WINDOW_H/4 - (scroll_x*zoom*math.sin(angle) + scroll_y*zoom*math.cos(angle)) ) 348 | self.transform.set_rotation(angle) 349 | 350 | self.car.draw(self.viewer, mode!="state_pixels") 351 | 352 | arr = None 353 | win = self.viewer.window 354 | if mode != 'state_pixels': 355 | win.switch_to() 356 | win.dispatch_events() 357 | if mode=="rgb_array" or mode=="state_pixels": 358 | win.clear() 359 | t = self.transform 360 | if mode=='rgb_array': 361 | VP_W = VIDEO_W 362 | VP_H = VIDEO_H 363 | else: 364 | VP_W = STATE_W 365 | VP_H = STATE_H 366 | gl.glViewport(0, 0, VP_W, VP_H) 367 | t.enable() 368 | self.render_road() 369 | for geom in self.viewer.onetime_geoms: 370 | geom.render() 371 | t.disable() 372 | self.render_indicators(WINDOW_W, WINDOW_H) # TODO: find why 2x needed, wtf 373 | image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data() 374 | arr = np.fromstring(image_data.data, dtype=np.uint8, sep='') 375 | arr = arr.reshape(VP_H, VP_W, 4) 376 | arr = arr[::-1, :, 0:3] 377 | 378 | if mode=="rgb_array" and not self.human_render: # agent can call or not call env.render() itself when recording video. 379 | win.flip() 380 | 381 | if mode=='human': 382 | self.human_render = True 383 | win.clear() 384 | t = self.transform 385 | gl.glViewport(0, 0, WINDOW_W, WINDOW_H) 386 | t.enable() 387 | self.render_road() 388 | for geom in self.viewer.onetime_geoms: 389 | geom.render() 390 | t.disable() 391 | self.render_indicators(WINDOW_W, WINDOW_H) 392 | win.flip() 393 | 394 | self.viewer.onetime_geoms = [] 395 | return arr 396 | 397 | def close(self): 398 | if self.viewer is not None: 399 | self.viewer.close() 400 | self.viewer = None 401 | 402 | def render_road(self): 403 | gl.glBegin(gl.GL_QUADS) 404 | gl.glColor4f(0.4, 0.8, 0.4, 1.0) 405 | gl.glVertex3f(-PLAYFIELD, +PLAYFIELD, 0) 406 | gl.glVertex3f(+PLAYFIELD, +PLAYFIELD, 0) 407 | gl.glVertex3f(+PLAYFIELD, -PLAYFIELD, 0) 408 | gl.glVertex3f(-PLAYFIELD, -PLAYFIELD, 0) 409 | gl.glColor4f(0.4, 0.9, 0.4, 1.0) 410 | k = PLAYFIELD/20.0 411 | for x in range(-20, 20, 2): 412 | for y in range(-20, 20, 2): 413 | gl.glVertex3f(k*x + k, k*y + 0, 0) 414 | gl.glVertex3f(k*x + 0, k*y + 0, 0) 415 | gl.glVertex3f(k*x + 0, k*y + k, 0) 416 | gl.glVertex3f(k*x + k, k*y + k, 0) 417 | for poly, color in self.road_poly: 418 | gl.glColor4f(color[0], color[1], color[2], 1) 419 | for p in poly: 420 | gl.glVertex3f(p[0], p[1], 0) 421 | gl.glEnd() 422 | 423 | def render_indicators(self, W, H): 424 | gl.glBegin(gl.GL_QUADS) 425 | s = W/40.0 426 | h = H/40.0 427 | gl.glColor4f(0,0,0,1) 428 | gl.glVertex3f(W, 0, 0) 429 | gl.glVertex3f(W, 5*h, 0) 430 | gl.glVertex3f(0, 5*h, 0) 431 | gl.glVertex3f(0, 0, 0) 432 | def vertical_ind(place, val, color): 433 | gl.glColor4f(color[0], color[1], color[2], 1) 434 | gl.glVertex3f((place+0)*s, h + h*val, 0) 435 | gl.glVertex3f((place+1)*s, h + h*val, 0) 436 | gl.glVertex3f((place+1)*s, h, 0) 437 | gl.glVertex3f((place+0)*s, h, 0) 438 | def horiz_ind(place, val, color): 439 | gl.glColor4f(color[0], color[1], color[2], 1) 440 | gl.glVertex3f((place+0)*s, 4*h , 0) 441 | gl.glVertex3f((place+val)*s, 4*h, 0) 442 | gl.glVertex3f((place+val)*s, 2*h, 0) 443 | gl.glVertex3f((place+0)*s, 2*h, 0) 444 | true_speed = np.sqrt(np.square(self.car.hull.linearVelocity[0]) + np.square(self.car.hull.linearVelocity[1])) 445 | vertical_ind(5, 0.02*true_speed, (1,1,1)) 446 | vertical_ind(7, 0.01*self.car.wheels[0].omega, (0.0,0,1)) # ABS sensors 447 | vertical_ind(8, 0.01*self.car.wheels[1].omega, (0.0,0,1)) 448 | vertical_ind(9, 0.01*self.car.wheels[2].omega, (0.2,0,1)) 449 | vertical_ind(10,0.01*self.car.wheels[3].omega, (0.2,0,1)) 450 | horiz_ind(20, -10.0*self.car.wheels[0].joint.angle, (0,1,0)) 451 | horiz_ind(30, -0.8*self.car.hull.angularVelocity, (1,0,0)) 452 | gl.glEnd() 453 | self.score_label.text = "%04i" % self.reward 454 | self.score_label.draw() 455 | 456 | 457 | if __name__=="__main__": 458 | from pyglet.window import key 459 | a = np.array( [0.0, 0.0, 0.0] ) 460 | def key_press(k, mod): 461 | global restart 462 | if k==0xff0d: restart = True 463 | if k==key.LEFT: a[0] = -1.0 464 | if k==key.RIGHT: a[0] = +1.0 465 | if k==key.UP: a[1] = +1.0 466 | if k==key.DOWN: a[2] = +0.8 # set 1.0 for wheels to block to zero rotation 467 | def key_release(k, mod): 468 | if k==key.LEFT and a[0]==-1.0: a[0] = 0 469 | if k==key.RIGHT and a[0]==+1.0: a[0] = 0 470 | if k==key.UP: a[1] = 0 471 | if k==key.DOWN: a[2] = 0 472 | env = CarRacing() 473 | env.render() 474 | record_video = False 475 | if record_video: 476 | env.monitor.start('/tmp/video-test', force=True) 477 | env.viewer.window.on_key_press = key_press 478 | env.viewer.window.on_key_release = key_release 479 | while True: 480 | env.reset() 481 | total_reward = 0.0 482 | steps = 0 483 | restart = False 484 | while True: 485 | s, r, done, info = env.step(a) 486 | total_reward += r 487 | if steps % 200 == 0 or done: 488 | print("\naction " + str(["{:+0.2f}".format(x) for x in a])) 489 | print("step {} total_reward {:+0.2f}".format(steps, total_reward)) 490 | #import matplotlib.pyplot as plt 491 | #plt.imshow(s) 492 | #plt.savefig("test.jpeg") 493 | steps += 1 494 | if not record_video: # Faster, but you can as well call env.render() every time to play full window. 495 | env.render() 496 | if done or restart: break 497 | env.close() 498 | -------------------------------------------------------------------------------- /generate-rollouts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import gym\n", 10 | "from car_racing import CarRacing" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 65, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from random import random, randint, choice" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "env = gym.make('CarRacing-v0')" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 34, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "episodes = 1\n", 46 | "steps = 100" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 68, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import numpy as np" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "array([0.08976637, 0.4236548 , 0.6458941 ], dtype=float32)" 67 | ] 68 | }, 69 | "execution_count": 4, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 75 | "env.action_space.sample()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 37, 81 | "metadata": { 82 | "collapsed": true 83 | }, 84 | "outputs": [ 85 | { 86 | "name": "stdout", 87 | "output_type": "stream", 88 | "text": [ 89 | "Track generation: 1094..1380 -> 286-tiles track\n" 90 | ] 91 | }, 92 | { 93 | "ename": "NotImplementedError", 94 | "evalue": "abstract", 95 | "output_type": "error", 96 | "traceback": [ 97 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 98 | "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", 99 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0meps\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepisodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mobs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msteps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 100 | "\u001b[0;32m~/Documents/Github/gym/gym/wrappers/time_limit.py\u001b[0m in \u001b[0;36mreset\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_episode_started_at\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_elapsed_steps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 44\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 101 | "\u001b[0;32m~/Documents/Github/gym/gym/envs/box2d/car_racing.py\u001b[0m in \u001b[0;36mreset\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 290\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworld\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 292\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 293\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 294\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 102 | "\u001b[0;32m~/Documents/Github/gym/gym/envs/box2d/car_racing.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 302\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mFPS\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 303\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 304\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"state_pixels\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 305\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[0mstep_reward\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 103 | "\u001b[0;32m~/Documents/Github/gym/gym/envs/box2d/car_racing.py\u001b[0m in \u001b[0;36mrender\u001b[0;34m(self, mode)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 326\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menvs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclassic_control\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 327\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mviewer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrendering\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mViewer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mWINDOW_W\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mWINDOW_H\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 328\u001b[0m self.score_label = pyglet.text.Label('0000', font_size=36,\n\u001b[1;32m 329\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mWINDOW_H\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;36m2.5\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m40.00\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0manchor_x\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'left'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0manchor_y\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'center'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 104 | "\u001b[0;32m~/Documents/Github/gym/gym/envs/classic_control/rendering.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, width, height, display)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwidth\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheight\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mheight\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyglet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mWindow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwidth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwidth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mheight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdisplay\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_close\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow_closed_by_user\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misopen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 105 | "\u001b[0;32m~/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/pyglet/window/__init__.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, width, height, caption, resizable, style, fullscreen, visible, vsync, display, screen, config, context, mode)\u001b[0m\n\u001b[1;32m 502\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mscreen\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 504\u001b[0;31m \u001b[0mscreen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdisplay\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_default_screen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 505\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 506\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 106 | "\u001b[0;32m~/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/pyglet/canvas/base.py\u001b[0m in \u001b[0;36mget_default_screen\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mrtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mclass\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mScreen\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m '''\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_screens\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_windows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 107 | "\u001b[0;32m~/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/pyglet/canvas/base.py\u001b[0m in \u001b[0;36mget_screens\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mrtype\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m \u001b[0mof\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;32mclass\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0mScreen\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m '''\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'abstract'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_default_screen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 108 | "\u001b[0;31mNotImplementedError\u001b[0m: abstract" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "for eps in range(episodes):\n", 114 | " obs = env.reset()\n", 115 | " print(obs)\n", 116 | " for t in range(steps):\n", 117 | " action = env.action_space.sample()\n", 118 | " a, b, xc, d = env.step(action)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [] 127 | } 128 | ], 129 | "metadata": { 130 | "kernelspec": { 131 | "display_name": "Python 3", 132 | "language": "python", 133 | "name": "python3" 134 | }, 135 | "language_info": { 136 | "codemirror_mode": { 137 | "name": "ipython", 138 | "version": 3 139 | }, 140 | "file_extension": ".py", 141 | "mimetype": "text/x-python", 142 | "name": "python", 143 | "nbconvert_exporter": "python", 144 | "pygments_lexer": "ipython3", 145 | "version": "3.6.1" 146 | } 147 | }, 148 | "nbformat": 4, 149 | "nbformat_minor": 2 150 | } 151 | -------------------------------------------------------------------------------- /generate-rollouts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Shubham Chandel 3 | # @Date: 2018-04-21 18:55:18 4 | # @Last Modified by: Shubham Chandel 5 | # @Last Modified time: 2018-04-22 03:02:35 6 | 7 | from car_racing import CarRacing 8 | 9 | import scipy.misc 10 | import numpy as np 11 | from random import choice, random, randint 12 | 13 | env = CarRacing() 14 | 15 | episodes = 100 16 | steps = 150 17 | 18 | def get_action(): 19 | # return choice([[1, 0, 0], [-1, 0, 0], [0, 1, 0]]) 20 | # return [2*random()-1.05, 1, 0] 21 | action = env.action_space.sample() 22 | action[0] -= 0.05 23 | action[1] = 1 24 | action[2] = 0 25 | return action 26 | 27 | for eps in range(episodes): 28 | obs = env.reset() 29 | env.render() 30 | r = 0 31 | for t in range(steps): 32 | action = get_action() 33 | obs, reward, done, _ = env.step(action) 34 | # env.render() 35 | r += reward 36 | if t%5 == 0: 37 | i = ('000' + str(t//5))[-3:] 38 | scipy.misc.imsave(f'rollouts/CarRacing/car_{eps}_{i}.jpg', obs) 39 | print("Episode [{}/{}]: CummReward {:.2f}".format(eps+1, episodes, r)) 40 | 41 | 42 | -------------------------------------------------------------------------------- /pushover.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: chanshub 3 | # @Date: 2018-01-29 23:03:52 4 | # @Last Modified by: Shubham Chandel 5 | # @Last Modified time: 2018-04-21 06:23:57 6 | 7 | import requests 8 | 9 | data = { 10 | 'token': "", 11 | 'user': "", 12 | } 13 | 14 | JUSTBECAUSE = "Just because I have to put something." 15 | 16 | def notify(t, s=JUSTBECAUSE, priority=-1, sound="classical"): 17 | data['message'] = str(s) 18 | data['title'] = str(t) 19 | data['priority'] = priority 20 | data['sound'] = sound 21 | r = requests.post('https://api.pushover.net/1/messages.json', data=data) 22 | print("Notifing", t, s, r) 23 | 24 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | from pathlib import Path 3 | 4 | 5 | def makegif(folder='reconstructed', name='recon_image', ext='jpg'): 6 | images = [] 7 | for file in sorted([file for file in Path(folder).glob(f'*.{ext}')]): 8 | images.append(imageio.imread(str(file))) 9 | imageio.mimsave(f'{name}.gif', images) 10 | print("Done!") 11 | -------------------------------------------------------------------------------- /vae-cnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "from torch.autograd import Variable\n", 13 | "\n", 14 | "import torchvision\n", 15 | "from torchvision import datasets\n", 16 | "from torchvision import transforms\n", 17 | "from torchvision.utils import save_image\n", 18 | "from torchsummary import summary\n", 19 | "\n", 20 | "from pushover import notify\n", 21 | "from utils import makegif\n", 22 | "from random import randint\n", 23 | "\n", 24 | "from IPython.display import Image\n", 25 | "from IPython.core.display import Image, display\n", 26 | "\n", 27 | "%load_ext autoreload\n", 28 | "%autoreload 2" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 10, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "device(type='cpu')" 40 | ] 41 | }, 42 | "execution_count": 10, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "# Device configuration\n", 49 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 50 | "device" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "bs = 32" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "data": { 69 | "text/plain": [ 70 | "(30000, 938)" 71 | ] 72 | }, 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "output_type": "execute_result" 76 | } 77 | ], 78 | "source": [ 79 | "# Load Data\n", 80 | "dataset = datasets.ImageFolder(root='./rollouts', transform=transforms.Compose([\n", 81 | " transforms.Resize(64),\n", 82 | " transforms.ToTensor(), \n", 83 | "]))\n", 84 | "dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True)\n", 85 | "len(dataset.imgs), len(dataloader)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "image/png": "\n", 96 | "text/plain": [ 97 | "" 98 | ] 99 | }, 100 | "execution_count": 4, 101 | "metadata": {}, 102 | "output_type": "execute_result" 103 | } 104 | ], 105 | "source": [ 106 | "# Fixed input for debugging\n", 107 | "fixed_x, _ = next(iter(dataloader))\n", 108 | "save_image(fixed_x, 'real_image.png')\n", 109 | "\n", 110 | "Image('real_image.png')" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "class Flatten(nn.Module):\n", 120 | " def forward(self, input):\n", 121 | " return input.view(input.size(0), -1)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 6, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "class UnFlatten(nn.Module):\n", 131 | " def forward(self, input, size=1024):\n", 132 | " return input.view(input.size(0), size, 1, 1)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 7, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "class VAE(nn.Module):\n", 142 | " def __init__(self, image_channels=3, h_dim=1024, z_dim=32):\n", 143 | " super(VAE, self).__init__()\n", 144 | " self.encoder = nn.Sequential(\n", 145 | " nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),\n", 146 | " nn.ReLU(),\n", 147 | " nn.Conv2d(32, 64, kernel_size=4, stride=2),\n", 148 | " nn.ReLU(),\n", 149 | " nn.Conv2d(64, 128, kernel_size=4, stride=2),\n", 150 | " nn.ReLU(),\n", 151 | " nn.Conv2d(128, 256, kernel_size=4, stride=2),\n", 152 | " nn.ReLU(),\n", 153 | " Flatten()\n", 154 | " )\n", 155 | " \n", 156 | " self.fc1 = nn.Linear(h_dim, z_dim)\n", 157 | " self.fc2 = nn.Linear(h_dim, z_dim)\n", 158 | " self.fc3 = nn.Linear(z_dim, h_dim)\n", 159 | " \n", 160 | " self.decoder = nn.Sequential(\n", 161 | " UnFlatten(),\n", 162 | " nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2),\n", 163 | " nn.ReLU(),\n", 164 | " nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),\n", 165 | " nn.ReLU(),\n", 166 | " nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),\n", 167 | " nn.ReLU(),\n", 168 | " nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),\n", 169 | " nn.Sigmoid(),\n", 170 | " )\n", 171 | " \n", 172 | " def reparameterize(self, mu, logvar):\n", 173 | " std = logvar.mul(0.5).exp_()\n", 174 | " # return torch.normal(mu, std)\n", 175 | " esp = torch.randn(*mu.size())\n", 176 | " z = mu + std * esp\n", 177 | " return z\n", 178 | " \n", 179 | " def bottleneck(self, h):\n", 180 | " mu, logvar = self.fc1(h), self.fc2(h)\n", 181 | " z = self.reparameterize(mu, logvar)\n", 182 | " return z, mu, logvar\n", 183 | "\n", 184 | " def encode(self, x):\n", 185 | " h = self.encoder(x)\n", 186 | " z, mu, logvar = self.bottleneck(h)\n", 187 | " return z, mu, logvar\n", 188 | "\n", 189 | " def decode(self, z):\n", 190 | " z = self.fc3(z)\n", 191 | " z = self.decoder(z)\n", 192 | " return z\n", 193 | "\n", 194 | " def forward(self, x):\n", 195 | " z, mu, logvar = self.encode(x)\n", 196 | " z = self.decode(z)\n", 197 | " return z, mu, logvar" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 8, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "image_channels = fixed_x.size(1)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 12, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "model = VAE(image_channels=image_channels).to(device)\n", 216 | "model.load_state_dict(torch.load('vae.torch', map_location='cpu'))" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 13, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3) " 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 14, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "def loss_fn(recon_x, x, mu, logvar):\n", 235 | " BCE = F.binary_cross_entropy(recon_x, x, size_average=False)\n", 236 | " # BCE = F.mse_loss(recon_x, x, size_average=False)\n", 237 | "\n", 238 | " # see Appendix B from VAE paper:\n", 239 | " # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014\n", 240 | " # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n", 241 | " KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())\n", 242 | "\n", 243 | " return BCE + KLD, BCE, KLD" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 14, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "!rm -rfr reconstructed\n", 253 | "!mkdir reconstructed" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 12, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "epochs = 50" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 27, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "for epoch in range(epochs):\n", 272 | " for idx, (images, _) in enumerate(dataloader):\n", 273 | " recon_images, mu, logvar = vae(images)\n", 274 | " loss, bce, kld = loss_fn(recon_images, images, mu, logvar)\n", 275 | " optimizer.zero_grad()\n", 276 | " loss.backward()\n", 277 | " optimizer.step()\n", 278 | "\n", 279 | " to_print = \"Epoch[{}/{}] Loss: {:.3f} {:.3f} {:.3f}\".format(epoch+1, \n", 280 | " epochs, loss.data[0]/bs, bce.data[0]/bs, kld.data[0]/bs)\n", 281 | " print(to_print)\n", 282 | "\n", 283 | "# notify to android when finished training\n", 284 | "notify(to_print, priority=1)\n", 285 | "\n", 286 | "torch.save(vae.state_dict(), 'vae.torch')" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 15, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "def compare(x):\n", 296 | " recon_x, _, _ = vae(x)\n", 297 | " return torch.cat([x, recon_x])" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 22, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "data": { 307 | "image/png": "\n", 308 | "text/plain": [ 309 | "" 310 | ] 311 | }, 312 | "metadata": { 313 | "image/png": { 314 | "unconfined": true, 315 | "width": 700 316 | } 317 | }, 318 | "output_type": "display_data" 319 | } 320 | ], 321 | "source": [ 322 | "# sample = torch.randn(bs, 1024)\n", 323 | "# compare_x = vae.decoder(sample)\n", 324 | "\n", 325 | "# fixed_x, _ = next(iter(dataloader))\n", 326 | "# fixed_x = fixed_x[:8]\n", 327 | "fixed_x = dataset[randint(1, 100)][0].unsqueeze(0)\n", 328 | "compare_x = compare(fixed_x)\n", 329 | "\n", 330 | "save_image(compare_x.data.cpu(), 'sample_image.png')\n", 331 | "display(Image('sample_image.png', width=700, unconfined=True))" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [] 340 | } 341 | ], 342 | "metadata": { 343 | "kernelspec": { 344 | "display_name": "Python 3", 345 | "language": "python", 346 | "name": "python3" 347 | }, 348 | "language_info": { 349 | "codemirror_mode": { 350 | "name": "ipython", 351 | "version": 3 352 | }, 353 | "file_extension": ".py", 354 | "mimetype": "text/x-python", 355 | "name": "python", 356 | "nbconvert_exporter": "python", 357 | "pygments_lexer": "ipython3", 358 | "version": "3.6.4" 359 | } 360 | }, 361 | "nbformat": 4, 362 | "nbformat_minor": 2 363 | } 364 | -------------------------------------------------------------------------------- /vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class Flatten(nn.Module): 7 | def forward(self, input): 8 | return input.view(input.size(0), -1) 9 | 10 | 11 | class UnFlatten(nn.Module): 12 | def forward(self, input, size=1024): 13 | return input.view(input.size(0), size, 1, 1) 14 | 15 | class VAE(nn.Module): 16 | def __init__(self, image_channels=3, h_dim=1024, z_dim=32): 17 | super(VAE, self).__init__() 18 | self.encoder = nn.Sequential( 19 | nn.Conv2d(image_channels, 32, kernel_size=4, stride=2), 20 | nn.ReLU(), 21 | nn.Conv2d(32, 64, kernel_size=4, stride=2), 22 | nn.ReLU(), 23 | nn.Conv2d(64, 128, kernel_size=4, stride=2), 24 | nn.ReLU(), 25 | nn.Conv2d(128, 256, kernel_size=4, stride=2), 26 | nn.ReLU(), 27 | Flatten() 28 | ) 29 | 30 | self.fc1 = nn.Linear(h_dim, z_dim) 31 | self.fc2 = nn.Linear(h_dim, z_dim) 32 | self.fc3 = nn.Linear(z_dim, h_dim) 33 | 34 | self.decoder = nn.Sequential( 35 | UnFlatten(), 36 | nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2), 37 | nn.ReLU(), 38 | nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2), 39 | nn.ReLU(), 40 | nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2), 41 | nn.ReLU(), 42 | nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2), 43 | nn.Sigmoid(), 44 | ) 45 | 46 | def reparameterize(self, mu, logvar): 47 | std = logvar.mul(0.5).exp_() 48 | # return torch.normal(mu, std) 49 | esp = torch.randn(*mu.size()) 50 | z = mu + std * esp 51 | return z 52 | 53 | def bottleneck(self, h): 54 | mu, logvar = self.fc1(h), self.fc2(h) 55 | z = self.reparameterize(mu, logvar) 56 | return z, mu, logvar 57 | 58 | def representation(self, x): 59 | return self.bottleneck(self.encoder(x))[0] 60 | 61 | def forward(self, x): 62 | h = self.encoder(x) 63 | z, mu, logvar = self.bottleneck(h) 64 | z = self.fc3(z) 65 | return self.decoder(z), mu, logvar 66 | 67 | --------------------------------------------------------------------------------