├── .gitignore
├── README.md
├── asset
│   ├── CarRacing-v0_rollout.gif
│   ├── DoomTakeCover_dream_rollout.gif
│   ├── DoomTakeCover_rollout.gif
│   ├── cma-es.gif
│   ├── mdn-data_only.png
│   ├── mdn-linear_with_preds.png
│   ├── mdn-with_preds.png
│   └── world_model_schematic.png
├── controller.py
├── lib
│   ├── constants.py
│   ├── data.py
│   ├── env_wrappers.py
│   └── utils.py
├── model.py
├── random_rollouts.py
├── test.py
├── toy
│   ├── cma-es.py
│   └── mdn.py
└── vision.py

/.gitignore:
--------------------------------------------------------------------------------
.idea
__pycache__/
.ipynb_checkpoints
result/
_vizdoom.ini
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# World Models Implementation In Chainer
This is a fairly complete implementation, in [Chainer](https://chainer.org), of the World Models framework described by David Ha and Jürgen Schmidhuber: https://arxiv.org/abs/1803.10122

This project was created as part of my MSc Artificial Intelligence dissertation at the University of Edinburgh, under the supervision of [Subramanian Ramamoorthy](http://homepages.inf.ed.ac.uk/sramamoo/), with guidance from fellow PhD student [Svetlin Penkov](https://www.linkedin.com/in/svpenkov/). The scope of my research is model-based learning: in particular, investigating how an RNN can use external memory to learn complex models. I will be extending this framework accordingly, though those extensions are not included in this repository.

> ![](asset/CarRacing-v0_rollout.gif) My best trained World Models agent playing [CarRacing-v0](https://gym.openai.com/envs/CarRacing-v0/).

## World Models Summary

Here is a quick summary of Ha & Schmidhuber's World Models framework. The framework aims to train an agent that can perform well in virtual gaming environments.
Ha & Schmidhuber's experiments were done in the [CarRacing-v0](https://gym.openai.com/envs/CarRacing-v0/) (from [OpenAI gym](https://gym.openai.com/)) and [ViZDoom: Take Cover](https://github.com/mwydmuch/ViZDoom/tree/master/scenarios#take-cover) environments.


World Models consists of three main components, Vision (**V**), Model (**M**), and Controller (**C**), which interact to form an agent:

> ![World Models](asset/world_model_schematic.png)

**V** is a convolutional [Variational Autoencoder (VAE)](https://arxiv.org/abs/1606.05908), which compresses each frame of gameplay into a latent vector *z*. **M** is a [Mixture Density Network (MDN)](https://publications.aston.ac.uk/373/1/NCRG_94_004.pdf) built on a [Recurrent Neural Network (RNN)](https://en.wikipedia.org/wiki/Recurrent_neural_network): instead of a single prediction, the RNN outputs the parameters of a [Mixture Density Model](https://en.wikipedia.org/wiki/Mixture_model). This MDN-RNN takes the latent vectors *z* from **V**, together with the actions taken, and predicts a probability distribution over the next latent vector, i.e. over the next frame in compressed form. Finally, **C** is a simple single-layer linear model that maps the latent vector *z* from **V** and the hidden state of **M** to actions to perform in the environment. **C** is trained using [Evolution Strategies](https://blog.openai.com/evolution-strategies/), specifically the [CMA-ES](https://arxiv.org/abs/1604.00772) algorithm.

The two most interesting parts (to me, at least) are the MDN-RNN used by **M** and the CMA-ES algorithm used by **C**, which are briefly summarized here.

###### MDN
[Mixture Density Networks](https://publications.aston.ac.uk/373/1/NCRG_94_004.pdf) were developed by [Christopher Bishop](https://en.wikipedia.org/wiki/Christopher_Bishop) (who received his PhD from the University of Edinburgh). MDNs combine a mixture density model with a neural network. Rather than outputting the desired result directly, the neural network outputs the parameters of a mixture model, which is then sampled to obtain the final output.
The parameters the network outputs are a set of mixing probabilities, a set of means, and a set of standard deviations (one of each per mixture component).

This approach is particularly useful when the average of several likely correct outputs is not itself a correct output. Mean squared error (MSE) does not work on such data, because minimizing MSE pushes the prediction towards the average of the possible targets. Take for example the toy problem below:
> ![](asset/mdn-data_only.png)

The blue dots represent the desired y value given an x value. So at x=0.25, y could be {0, 0.5, 1} (roughly), but predicting the average of the three does not necessarily make sense.

If we try to predict the outputs directly with a simple feedforward neural network trained with MSE, here's what we get:
> ![](asset/mdn-linear_with_preds.png)

Instead, if we use a simple MDN and sample from it, we get a good fit:
> ![](asset/mdn-with_preds.png)

*The code for this toy problem is available in [toy/mdn.py](toy/mdn.py).*

###### CMA-ES
According to OpenAI, [Evolution Strategies are a scalable alternative to Reinforcement Learning](https://arxiv.org/abs/1703.03864). Where Reinforcement Learning is a guess and check on the actions, Evolution Strategies are a guess and check on the model parameters themselves. A "population" of "mutations" of a seed parameter vector is created, every mutated parameter vector is evaluated for fitness, and the seed is moved towards the mean of the fittest mutations. [CMA-ES](https://arxiv.org/abs/1604.00772) is a particular evolution strategy that also adapts the covariance matrix of the distribution the mutations are sampled from, widening or narrowing the net it casts in its search for the solution.

To demonstrate, here is a toy problem. Consider a shifted [Schaffer](https://en.wikipedia.org/wiki/Test_functions_for_optimization) function with its solution at (10,10), so the parameter values being sought are (10,10). Our fitness function returns the squared error between the parameters being tested and the actual solution, as measured through the Schaffer function.
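
The plain evolution strategy loop described above (sample a population of mutations around a seed, keep the fittest, move the seed to their mean) can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions, not the code in toy/cma-es.py: it assumes the Schaffer N.2 variant, and it reads "squared error as measured through the Schaffer function" as the squared difference between the shifted function's value at the candidate and at (10,10). Its step size `sigma` is fixed, so it is a simple (mu/mu, lambda)-ES rather than CMA-ES. The names `schaffer_n2`, `schaffer_fitness`, and `simple_es` are hypothetical.

```python
import numpy as np

SOLUTION = np.array([10.0, 10.0])  # the shifted optimum from the toy problem


def schaffer_n2(x, y):
    # Assumption: Schaffer function N.2 (minimum 0 at the origin); the README does not name the variant.
    return 0.5 + (np.sin(x ** 2 - y ** 2) ** 2 - 0.5) / (1.0 + 0.001 * (x ** 2 + y ** 2)) ** 2


def schaffer_fitness(params):
    """Higher is better. Negated squared error between the shifted Schaffer value at `params`
    and its value at the true solution. Accepts a single (2,) vector or a (n, 2) population."""
    shifted = np.asarray(params, dtype=float) - SOLUTION
    value = schaffer_n2(shifted[..., 0], shifted[..., 1])
    return -(value - schaffer_n2(0.0, 0.0)) ** 2


def simple_es(fitness_fn, seed, sigma=1.0, lam=64, mu_frac=0.25, generations=200, rng=None):
    """(mu/mu, lambda)-ES with a FIXED step size: no sigma or covariance adaptation, unlike CMA-ES."""
    rng = np.random.default_rng(0) if rng is None else rng
    mean = np.asarray(seed, dtype=float)
    mu = max(1, int(lam * mu_frac))
    for _ in range(generations):
        # Sample a population of lam mutations around the current seed.
        population = mean + sigma * rng.standard_normal((lam, mean.size))
        scores = fitness_fn(population)
        # Keep the mu fittest mutations and move the seed to their mean.
        fittest = population[np.argsort(scores)[-mu:]]
        mean = fittest.mean(axis=0)
    return mean
```

For example, `simple_es(lambda p: -np.sum((p - SOLUTION) ** 2, axis=-1), np.zeros(2))` should end up close to (10,10) on a smooth bowl. The defaults `lam=64` and `mu_frac=0.25` mirror the `--lambda_ 64 --mu 0.25` settings used for the controller in the Usage section. What CMA-ES adds on top of this loop is the adaptation of `sigma` and of the full covariance matrix, i.e. the widening and narrowing "net" that the animation illustrates.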

The animation below depicts how CMA-ES creates populations of parameters that are tested against the fitness function. The blue dot represents the solution, the red dots the entire population being tested, and the green dot the mean of the population as it evolves, which eventually converges on the solution. Notice how the "net" the algorithm casts (the covariance matrix from which the population is sampled) is adapted as the population gets further from or closer to the solution, based on the fitness scores. How cool!

> ![](asset/cma-es.gif)

*The code for this toy problem is available in [toy/cma-es.py](toy/cma-es.py). I translated the `(mu/mu_w, lambda)-CMA-ES` algorithm to Python as simply as I could.*

## Results

I had access to a cluster with NVidia 1060 Ti GPUs, and a CPU cluster of 12 machines with 40-48 CPUs per machine. Unfortunately, the machines in the CPU cluster do not have GPUs, which would have sped up training of the controller.

All setup and hyperparameters were kept exactly the same as Ha & Schmidhuber's, except where they were not stated explicitly. For example, 10,000 rollouts were used to collect frames to train **V**, 5 mixture components were used in the MDN-RNN, N_z was set to 32 and 64 for CarRacing-v0 and ViZDoom: Take Cover respectively, and so on.

###### CarRacing-v0

The task is considered solved if the average score over 100 consecutive rollouts is greater than 900.

* **Random agent**: Mean score -72 +/- 3 over 100 rollouts
* **Trained agent**: Mean score 753 +/- 13 over 100 rollouts\*

*\*Ended controller training early due to time constraints; will run longer and update final results.*

###### ViZDoom: Take Cover

The task is considered solved if the average score over 100 consecutive rollouts is greater than 750.
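
Both "solved" criteria reduce to a mean over a window of consecutive rollouts, and the results are reported as mean +/- standard deviation. A minimal sketch of both computations follows. `solved` and `report` are hypothetical helpers, not functions from this repository, and the sketch assumes the reported spread is the population standard deviation (NumPy's default `ddof=0`), which the README does not specify.

```python
import numpy as np


def solved(scores, threshold, window=100):
    """True if the average of the last `window` consecutive rollout scores exceeds `threshold`."""
    if len(scores) < window:
        return False  # not enough consecutive rollouts yet
    return float(np.mean(scores[-window:])) > threshold


def report(scores):
    """Format scores as 'Mean score M +/- S over N rollouts' (assumes population std, ddof=0)."""
    scores = np.asarray(scores, dtype=float)
    return "Mean score {:.0f} +/- {:.0f} over {} rollouts".format(scores.mean(), scores.std(), scores.size)
```

With the thresholds from this section, `solved(scores, 900)` checks CarRacing-v0 and `solved(scores, 750)` checks ViZDoom: Take Cover.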

* **Random agent**: Mean score 278 +/- 100 over 100 rollouts
* **Trained agent**: Mean score 680 +/- 411 over 100 rollouts\*

*\*Ended controller training early due to time constraints; will run longer and update final results.*

> ![](asset/DoomTakeCover_dream_rollout.gif) An agent roaming around in its dream.

> ![](asset/DoomTakeCover_rollout.gif) An agent trained in its dream playing the actual game.

## Usage

### Setup

* `conda install numpy chainer scipy Pillow imageio numba cupy` *(cupy only if using a GPU)*
* `pip install gym Box2D vizdoom` *(ViZDoom needs some [build pre-setup](https://github.com/mwydmuch/ViZDoom/blob/master/doc/Building.md))*

### Running

Some general notes:
* The OpenAI gym file [car_racing.py](https://github.com/openai/gym/blob/master/gym/envs/box2d/car_racing.py) has a bug. Edit it in your local environment with my proposed [solution](https://github.com/openai/gym/issues/976#issuecomment-395486438).
* For CarRacing-v0, prepend *xvfb-run -e /dev/stdout -s "-screen 0 1400x900x24"* to steps #1, #4, and #5 if running headlessly on a server.
* *--data_dir* defines the base directory where all output (results, snapshots, samples, etc.) is stored.
* In most cases, samples are generated throughout training and placed in the output folder.
* In most cases, one sub-process per CPU core is started for parallelization; this can be changed with the *--cores* argument.
* GPU support is implemented in most cases, via the *--gpu* or *--gpus* flags.
* *--snapshot_interval* sets how often snapshots are kept during training for steps #2-#4; these snapshots can be used to resume.
* Ha & Schmidhuber used [DoomTakeCover-v0](https://gym.openai.com/envs/DoomTakeCover-v0/) from OpenAI gym, which appears to be deprecated, so I wrote my own simple wrapper around ViZDoom that loads the *Take Cover* configuration.


### 1. Random Rollouts

`python random_rollouts.py --game CarRacing-v0 --num_rollouts 10000`
or
`python random_rollouts.py --game DoomTakeCover --num_rollouts 10000`

**Notes:**
* Performs random rollouts and records the gameplay used to train the components.

### 2. Vision (V)

`python vision.py --game CarRacing-v0 --z_dim 32 --epoch 1`
or
`python vision.py --game DoomTakeCover --z_dim 64 --epoch 1`

**Notes:**
* The main hurdle was fitting all frames from 10,000 rollouts in memory. In the case of CarRacing-v0, that's 10 million frames, so I implemented parallelized, batched loading of frames.
* You'll want to adjust *--load_batch_size* according to the memory you have available. A low setting works with little memory but will be slower.

### 3. Model (M)

`python model.py --game CarRacing-v0 --z_dim 32 --hidden_dim 256 --mixtures 5 --epoch 20`
or
`python model.py --game DoomTakeCover --z_dim 64 --hidden_dim 512 --mixtures 5 --predict_done --epoch 20`

**Notes:**

* At the end of training, a dream rollout is generated just for fun.

### 4. Controller (C)

`python controller.py --game CarRacing-v0 --lambda_ 64 --mu 0.25 --trials 16 --target_cumulative_reward 900 --z_dim 32 --hidden_dim 256 --mixtures 5 --temperature 1.0 --weights_type 1 [--cluster_mode]`
or
`python controller.py --game DoomTakeCover --lambda_ 64 --mu 0.25 --trials 16 --target_cumulative_reward 2050 --z_dim 64 --hidden_dim 512 --mixtures 5 --temperature 1.15 --weights_type 2 --in_dream --dream_max_len 2100 --initial_z_noise 0.5 --predict_done --done_threshold 0.5 [--cluster_mode]`

**Notes:**
* *--cluster_mode* allows you to split a generation of CMA-ES over a compute cluster.
* A dispatcher is set in the *CLUSTER_DISPATCHER* variable, and workers in *CLUSTER_WORKERS* (the hostnames or IPs).
* The full generation of 64 mutations is dynamically split across all cores on all machines, to spread the load evenly.
* Start controller.py on all the worker nodes first, then on the dispatcher node.
* Make sure a firewall is not blocking the ports configured in the *CLUSTER_DISPATCHER_PORT* and *CLUSTER_WORKER_PORT* variables in controller.py.
* I'm not sure what *mu* value Ha & Schmidhuber used for CMA-ES, but David Ha's blog posts on other experiments seem to indicate he prefers 25%, so I used 0.25, with good results.
* Training the controller will take a while (on the scale of days or weeks), even with multiple GPUs and clusters of CPUs (using both together would be ideal, though this is untested).
* For dream training, I had to add Gaussian noise to the latent vector of the initial frame picked from a real game (*--initial_z_noise*) to get a greater variety of scenarios. As far as I know, this differs from Ha & Schmidhuber.
* Since the controller can take excessively long to train, I added a simple form of curriculum learning to speed it up, which also differs from Ha & Schmidhuber. To enable it, use the *--curriculum* flag. For CarRacing-v0, "50,5" (initial max timesteps 50, increased by 5 whenever the average cumulative score increases over a generation) lets the agent learn the game progressively and seems to work well. For DoomTakeCover, try "500,10".
* After your controller is trained, celebrate by proclaiming [Evolution Complete!](https://coub.com/view/o8exp)

### 5. Test

`python test.py --game CarRacing-v0 --z_dim 32 --hidden_dim 256 --mixtures 5 --temperature 1.0 --weights_type 1 --rollouts 100 [--record]`
or
`python test.py --game DoomTakeCover --z_dim 64 --hidden_dim 512 --mixtures 5 --temperature 1.15 --weights_type 2 --predict_done --rollouts 100 [--record]`

**Notes:**
* Take your final "agent" (trained V + M + C) for a spin.
* It reports the mean score and standard deviation over the desired number of *--rollouts*.
* *--record* saves all rollouts as gifs.

## License

The original research for World Models was conducted by Ha & Schmidhuber, but this code is entirely written by me (Adeel Mufti). I am releasing it under the [MIT](https://opensource.org/licenses/MIT) license.
--------------------------------------------------------------------------------
/asset/CarRacing-v0_rollout.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/WorldModels/fa3d2b95633ad8b7f8d95783aa180369c91b5476/asset/CarRacing-v0_rollout.gif
--------------------------------------------------------------------------------
/asset/DoomTakeCover_dream_rollout.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/WorldModels/fa3d2b95633ad8b7f8d95783aa180369c91b5476/asset/DoomTakeCover_dream_rollout.gif
--------------------------------------------------------------------------------
/asset/DoomTakeCover_rollout.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/WorldModels/fa3d2b95633ad8b7f8d95783aa180369c91b5476/asset/DoomTakeCover_rollout.gif
--------------------------------------------------------------------------------
/asset/cma-es.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/WorldModels/fa3d2b95633ad8b7f8d95783aa180369c91b5476/asset/cma-es.gif
--------------------------------------------------------------------------------
/asset/mdn-data_only.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/WorldModels/fa3d2b95633ad8b7f8d95783aa180369c91b5476/asset/mdn-data_only.png
--------------------------------------------------------------------------------
/asset/mdn-linear_with_preds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/WorldModels/fa3d2b95633ad8b7f8d95783aa180369c91b5476/asset/mdn-linear_with_preds.png
--------------------------------------------------------------------------------
/asset/mdn-with_preds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/WorldModels/fa3d2b95633ad8b7f8d95783aa180369c91b5476/asset/mdn-with_preds.png
--------------------------------------------------------------------------------
/asset/world_model_schematic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/WorldModels/fa3d2b95633ad8b7f8d95783aa180369c91b5476/asset/world_model_schematic.png
--------------------------------------------------------------------------------
/controller.py:
--------------------------------------------------------------------------------
import argparse
import os
import time
import re
from multiprocessing import cpu_count, Pool
from multiprocessing.pool import ThreadPool
from threading import Thread, Lock, Event
import socket
from io import BytesIO
import math
import ast
import traceback

import chainer
import chainer.functions as F

try:
    # cupy is only needed when running on a GPU
    import cupy as cp
    from chainer.backends import cuda
except Exception:
    pass
import numpy as np
import gym
# NOTE: scipy.misc.imresize was removed in SciPy 1.3, so this import requires an older SciPy
from scipy.misc import imresize
import imageio

from lib.utils import log, mkdir, pre_process_image_tensor, post_process_image_tensor
try:
    # ViZDoom is only needed for DoomTakeCover
    from lib.env_wrappers import ViZDoomWrapper
except Exception:
    pass
from lib.constants import DOOM_GAMES
from model import MDN_RNN
from vision import CVAE
from lib.data import ModelDataset

ID = "controller"

CLUSTER_WORKERS = ['machine01','machine02','machine03','machine04','machine05','machine06',
                   'machine07','machine08','machine09','machine10','machine11','machine12']
CLUSTER_DISPATCHER = 'machine01'
CLUSTER_DISPATCHER_PORT = 9955
CLUSTER_WORKER_PORT = 9956
cluster_cumulative_rewards = {}
lock = Lock()

initial_z_t = None

def action(args, W_c, b_c, z_t, h_t, c_t, gpu):
    if args.weights_type == 1:
        # Single linear layer: a_t = tanh(W_c [z_t, h_t] + b_c)
        input = F.concat((z_t, h_t), axis=0).data
        action = F.tanh(W_c.dot(input) + b_c).data
    elif args.weights_type == 2:
        # W_c is a vector, so W_c.dot(input) is a scalar; tanh squashes it into [-1, 1],
        # that range is split into action_dim + 1 equal bins, and the chosen bin gives a one-hot action.
        input = F.concat((z_t, h_t, c_t), axis=0).data
        dot = W_c.dot(input)
        if gpu is not None:
            dot = cp.asarray(dot)
        else:
            dot = np.asarray(dot)
        output = F.tanh(dot).data
        if output == 1.:
            output = 0.999
        action_dim = args.action_dim + 1
        action_range = 2 / action_dim
        action = [0. for i in range(action_dim)]
        start = -1.
        for i in range(action_dim):
            if start <= output and output <= (start + action_range):
                action[i] = 1.
                break
            start += action_range
        mid = action_dim // 2  # reserve action[mid] for no action
        action = action[0:mid] + action[mid + 1:action_dim]
        if gpu is not None:
            action = cp.asarray(action).astype(cp.float32)
        else:
            action = np.asarray(action).astype(np.float32)
    return action


def transform_to_weights(args, parameters):
    # weights_type 1: an (action_dim x (z_dim + hidden_dim)) matrix followed by an action_dim bias vector
    # weights_type 2: a single weight vector over [z_t, h_t, c_t], with no bias
    if args.weights_type == 1:
        W_c = parameters[0:args.action_dim * (args.z_dim + args.hidden_dim)].reshape(args.action_dim,
                                                                                   args.z_dim + args.hidden_dim)
        b_c = parameters[args.action_dim * (args.z_dim + args.hidden_dim):]
    elif args.weights_type == 2:
        W_c = parameters
        b_c = None
    return W_c, b_c


def rollout(rollout_arg_tuple):
    try:
        global initial_z_t
        generation, mutation_idx, trial, args, vision, model, gpu, W_c, b_c, max_timesteps, with_frames = rollout_arg_tuple

        # The same starting seed gets passed in multiprocessing, need to reset it for each process:
        np.random.seed()

        if not with_frames:
            log(ID, ">>> Starting generation #" + str(generation) + ", mutation #" + str(
                mutation_idx + 1) + ", trial #" + str(trial + 1))
        else:
            frames_array = []
        start_time = time.time()

        model.reset_state()

        if args.in_dream:
            z_t, _, _, _ = initial_z_t[np.random.randint(len(initial_z_t))]
            z_t = z_t[0]
            if gpu is not None:
                z_t = cuda.to_gpu(z_t)
            if with_frames:
                observation = vision.decode(z_t).data
                if gpu is not None:
                    observation = cp.asnumpy(observation)
                observation = post_process_image_tensor(observation)[0]
            else:
                # free up precious GPU memory:
                if gpu is not None:
                    vision.to_cpu()
                vision = None
            if args.initial_z_noise > 0.:
                if gpu is not None:
                    z_t += cp.random.normal(0., args.initial_z_noise, z_t.shape).astype(cp.float32)
                else:
                    z_t += np.random.normal(0., args.initial_z_noise, z_t.shape).astype(np.float32)
        else:
            if args.game in DOOM_GAMES:
                env = ViZDoomWrapper(args.game)
            else:
                env = gym.make(args.game)
            observation = env.reset()
        if with_frames:
            frames_array.append(observation)

        if gpu is not None:
            h_t = cp.zeros(args.hidden_dim).astype(cp.float32)
            c_t = cp.zeros(args.hidden_dim).astype(cp.float32)
        else:
            h_t = np.zeros(args.hidden_dim).astype(np.float32)
            c_t = np.zeros(args.hidden_dim).astype(np.float32)

        done = False
        cumulative_reward = 0
        t = 0
        while not done:
            if not args.in_dream:
                observation = imresize(observation, (args.frame_resize, args.frame_resize))
                observation = pre_process_image_tensor(np.expand_dims(observation, 0))

                if gpu is not None:
                    observation = cuda.to_gpu(observation)
                z_t = vision.encode(observation, return_z=True).data[0]

            a_t = action(args, W_c, b_c, z_t, h_t, c_t, gpu)

            if args.in_dream:
                z_t, done = model(z_t, a_t, temperature=args.temperature)
                done = done.data[0]
                if with_frames:
                    observation = post_process_image_tensor(vision.decode(z_t).data)[0]
                reward = 1
                if done >= args.done_threshold:
                    done = True
                else:
                    done = False
            else:
                observation, reward, done, _ = env.step(a_t if gpu is None else cp.asnumpy(a_t))
                model(z_t, a_t, temperature=args.temperature)
            if with_frames:
                frames_array.append(observation)

            cumulative_reward += reward

            h_t = model.get_h().data[0]
            c_t = model.get_c().data[0]

            t += 1
            if max_timesteps is not None and t == max_timesteps:
                break
            elif args.in_dream and t == args.dream_max_len:
                log(ID,
                    ">>> generation #{}, mutation #{}, trial #{}: maximum length of {} timesteps reached in dream!"
                    .format(generation, str(mutation_idx + 1), str(trial + 1), t))
                break

        if not args.in_dream:
            env.close()

        if not with_frames:
            log(ID,
                ">>> Finished generation #{}, mutation #{}, trial #{} in {} timesteps in {:.2f}s with cumulative reward {:.2f}"
                .format(generation, str(mutation_idx + 1), str(trial + 1), t, (time.time() - start_time),
                        cumulative_reward))
            return cumulative_reward
        else:
            frames_array = np.asarray(frames_array)
            if args.game in DOOM_GAMES and not args.in_dream:
                frames_array = post_process_image_tensor(frames_array)
            return cumulative_reward, np.asarray(frames_array)
    except Exception:
        print(traceback.format_exc())
        return 0.


def rollout_worker(worker_arg_tuple):
    generation, mutation_idx, args, vision, model, mutation, max_timesteps, in_parallel = worker_arg_tuple
    W_c, b_c = transform_to_weights(args, mutation)

    log(ID, ">> Starting generation #" + str(generation) + ", mutation #" + str(mutation_idx + 1))
    start_time = time.time()

    rollout_arg_tuples = []
    cumulative_rewards = []
    for trial in range(args.trials):
        this_vision = vision.copy()
        this_model = model.copy()
        gpu = None
        if isinstance(args.gpus, (list,)):
            gpu = args.gpus[mutation_idx % len(args.gpus)]
        elif args.gpu >= 0:
            gpu = args.gpu
        if gpu is not None:
            # log(ID,"Assigning GPU "+str(gpu))
            cp.cuda.Device(gpu).use()
            this_vision.to_gpu()
            this_model.to_gpu()
            W_c = cuda.to_gpu(W_c)
            if b_c is not None:
                b_c = cuda.to_gpu(b_c)
        if in_parallel:
            rollout_arg_tuples.append(
                (generation, mutation_idx, trial, args, this_vision, this_model, gpu, W_c, b_c, max_timesteps, False))
        else:
            cumulative_reward = rollout(
                (generation, mutation_idx, trial, args, this_vision, this_model, gpu, W_c, b_c, max_timesteps,
                 False))
            cumulative_rewards.append(cumulative_reward)
    if in_parallel:
        pool = Pool(args.trials)
        cumulative_rewards = pool.map(rollout, rollout_arg_tuples)
        pool.close()
        pool.join()

    avg_cumulative_reward = np.mean(cumulative_rewards)

    log(ID, ">> Finished generation #{}, mutation #{}, in {:.2f}s with average cumulative reward {:.2f} over {} trials"
        .format(generation, (mutation_idx + 1), (time.time() - start_time), avg_cumulative_reward, args.trials))

    return avg_cumulative_reward


class WorkerServer(object):
    def __init__(self, port, args, vision, model):
        self.args = args
        self.vision = vision
        self.model = model
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind(('', port))
        self.listen()

    def listen(self):
        self.sock.listen(10)
        while True:
            client, address = self.sock.accept()
            client.settimeout(10)
            Thread(target=self.listenToClient, args=(client, address)).start()

    def listenToClient(self, client, address):
        # Read until the message ends with b"\r\n"; the payload is an .npz archive from the dispatcher.
        data = b''
        while True:
            input = client.recv(1024)
            data += input
            if input.endswith(b"\r\n"):
                data = data.strip()
                break
            if not input: break

        npz = np.load(BytesIO(data))
        chunked_mutations = npz['chunked_mutations']
        indices = npz['indices']
        generation = npz['generation']
        max_timesteps = npz['max_timesteps']
        npz.close()
        client.send(b"OK")
        client.close()

        log(ID, "> Received " + str(len(chunked_mutations)) + " mutations from dispatcher")
        length = len(chunked_mutations)
        cores = cpu_count()
        # Split the mutations so that each chunk runs about `cores` rollouts (mutations x trials) at once.
        if cores < self.args.trials:
            splits = length
        else:
            splits = math.ceil((length * self.args.trials) / cores)
        chunked_mutations = np.array_split(chunked_mutations,
                                           splits)
        indices = np.array_split(indices, splits)
        cumulative_rewards = {}
        for i, this_chunked_mutations in enumerate(chunked_mutations):
            this_indices = indices[i]
            worker_arg_tuples = []
            for j, mutation in enumerate(this_chunked_mutations):
                worker_arg_tuples.append(
                    (generation, this_indices[j], self.args, self.vision, self.model, mutation, max_timesteps, True))
            pool = ThreadPool(len(this_chunked_mutations))
            this_cumulative_rewards = pool.map(rollout_worker, worker_arg_tuples)
            pool.close()
            pool.join()
            for j, index in enumerate(this_indices):
                cumulative_rewards[index] = this_cumulative_rewards[j]

        log(ID, "> Sending results back to dispatcher: " + str(cumulative_rewards))

        succeeded = False
        for retries in range(3):
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.settimeout(10)
                sock.connect((CLUSTER_DISPATCHER, CLUSTER_DISPATCHER_PORT))
                sock.sendall(str(cumulative_rewards).encode())
                sock.sendall(b"\r\n")
                data = sock.recv(1024).decode("utf-8")
                sock.close()
                if data == "OK":
                    succeeded = True
                    break
            except Exception as e:
                log(ID, e)
                log(ID, "Unable to send results back to dispatcher. Retrying after sleeping for 30s")
                time.sleep(30)
        if not succeeded:
            log(ID, "Unable to send results back to dispatcher!")


class DispatcherServer(object):
    def __init__(self, port, args, cluster_event):
        self.args = args
        self.cluster_event = cluster_event
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind(('', port))
        self.listen()

    def listen(self):
        try:
            count = 10 * len(CLUSTER_WORKERS)
            self.sock.listen(count)
            while True:
                client, address = self.sock.accept()
                client.settimeout(10)
                Thread(target=self.listenToClient, args=(client, address)).start()
        except Exception as e:
            print(e)

    def listenToClient(self, client, address):
        # Workers send str() of a {mutation index: average reward} dict, terminated by b"\r\n".
        global cluster_cumulative_rewards
        data = b''
        while True:
            input = client.recv(1024)
            data += input
            if input.endswith(b"\r\n"):
                data = data.strip()
                break
            if not input: break

        cumulative_rewards = ast.literal_eval(data.decode("utf-8"))
        client.send(b"OK")
        client.close()
        log(ID, "> DispatcherServer received results: " + str(cumulative_rewards))
        with lock:
            for index in cumulative_rewards:
                cluster_cumulative_rewards[index] = cumulative_rewards[index]
            if len(cluster_cumulative_rewards) == self.args.lambda_:
                log(ID, "> All results received. Waking up CMA-ES loop")
                self.cluster_event.set()


def main():
    parser = argparse.ArgumentParser(description='World Models ' + ID)
    parser.add_argument('--data_dir', '-d', default="/data/wm", help='The base data/output directory')
    parser.add_argument('--game', default='CarRacing-v0',
                        help='Game to use')  # https://gym.openai.com/envs/CarRacing-v0/
    parser.add_argument('--experiment_name', default='experiment_1', help='To isolate its files from others')
    parser.add_argument('--model', '-m', default='', help='Initialize the model from given file')
    parser.add_argument('--no_resume', action='store_true', help="Don't auto resume from the latest snapshot")
    parser.add_argument('--resume_from', '-r', default='', help='Resume the optimization from a specific snapshot')
    parser.add_argument('--hidden_dim', default=256, type=int, help='LSTM hidden units')
    parser.add_argument('--z_dim', '-z', default=32, type=int, help='dimension of encoded vector')
    parser.add_argument('--mixtures', default=5, type=int, help='number of gaussian mixtures for MDN')
    parser.add_argument('--lambda_', "-l", default=7, type=int, help='Population size for CMA-ES')
    parser.add_argument('--mu', default=0.5, type=float, help='Keep this percent of fittest mutations for CMA-ES')
    parser.add_argument('--trials', default=3, type=int,
                        help='The number of trials per mutation for CMA-ES, to average fitness score over')
    parser.add_argument('--target_cumulative_reward', default=900, type=int, help='Target cumulative reward')
    parser.add_argument('--frame_resize', default=64, type=int, help='h x w resize of each observation frame')
    parser.add_argument('--temperature', '-t', default=1.0, type=float, help='Temperature (tau) for MDN-RNN (model)')
    parser.add_argument('--snapshot_interval', '-s', default=5, type=int,
                        help='snapshot every x generations of evolution')
    parser.add_argument('--cluster_mode', action='store_true',
                        help='If in a distributed cpu cluster. Set CLUSTER_ variables accordingly.')
    parser.add_argument('--test', action='store_true',
                        help='Generate a rollout gif only (must have access to saved snapshot or model)')
    parser.add_argument('--gpu', '-g', default=-1, type=int, help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--gpus', default="", help='A list of gpus to use, e.g. "0,1,2,3"')
    parser.add_argument('--curriculum', default="", help='initial,step e.g. 50,5 starts at 50 steps and adds 5 steps')
    parser.add_argument('--predict_done', action='store_true', help='Whether MDN-RNN should also predict done state')
    parser.add_argument('--done_threshold', default=0.5, type=float, help='What done probability really means done')
    parser.add_argument('--weights_type', default=1, type=int,
                        help="1=action_dim*(z_dim+hidden_dim), 2=z_dim+2*hidden_dim")
    parser.add_argument('--in_dream', action='store_true', help='Whether to train in dream, or real environment')
    parser.add_argument('--dream_max_len', default=2100, type=int, help="Maximum timesteps for dream to avoid runaway")
    parser.add_argument('--cores', default=0, type=int,
                        help='# CPU cores for main CMA-ES loop in non-cluster_mode. 0=all cores')
    parser.add_argument('--initial_z_size', default=10000, type=int,
                        help="How many real initial frames to load for dream training")
    parser.add_argument('--initial_z_noise', default=0., type=float,
                        help="Gaussian noise std for initial z for dream training")
    parser.add_argument('--cluster_max_wait', default=5400, type=int,
                        help="Move on after this many seconds of no response from worker(s)")

    args = parser.parse_args()
    log(ID, "args =\n " + str(vars(args)).replace(",", ",\n "))

    hostname = socket.gethostname().split(".")[0]
    if args.gpus:
        args.gpus = [int(item) for item in args.gpus.split(',')]
    if args.curriculum:
        curriculum_start = int(args.curriculum.split(',')[0])
        curriculum_step = int(args.curriculum.split(',')[1])

    output_dir = os.path.join(args.data_dir, args.game, args.experiment_name, ID)
    mkdir(output_dir)
    model_dir = os.path.join(args.data_dir, args.game, args.experiment_name, 'model')
    vision_dir = os.path.join(args.data_dir, args.game, args.experiment_name, 'vision')
    random_rollouts_dir = os.path.join(args.data_dir, args.game, args.experiment_name, 'random_rollouts')

    model = MDN_RNN(args.hidden_dim, args.z_dim, args.mixtures, args.predict_done)
    chainer.serializers.load_npz(os.path.join(model_dir, "model.model"), model)
    vision = CVAE(args.z_dim)
    chainer.serializers.load_npz(os.path.join(vision_dir, "vision.model"), vision)

    global initial_z_t
    if args.in_dream:
        log(ID, "Loading random rollouts for initial frames for dream training")
        initial_z_t = ModelDataset(dir=random_rollouts_dir,
                                   load_batch_size=args.initial_z_size,
                                   verbose=False)

    if args.game in DOOM_GAMES:
        env = ViZDoomWrapper(args.game)
    else:
        env = gym.make(args.game)
    action_dim = len(env.action_space.low)
    args.action_dim = action_dim
    env = None
| 456 | auto_resume_file = None 457 | if not args.cluster_mode or (args.cluster_mode and hostname == CLUSTER_DISPATCHER): 458 | max_iter = 0 459 | files = os.listdir(output_dir) 460 | for file in files: 461 | if re.match(r'^snapshot_iter_', file): 462 | iter = int(re.search(r'\d+', file).group()) 463 | if (iter > max_iter): 464 | max_iter = iter 465 | if max_iter > 0: 466 | auto_resume_file = os.path.join(output_dir, "snapshot_iter_{}.npz".format(max_iter)) 467 | 468 | resume = None 469 | if args.model: 470 | if args.model == 'default': 471 | args.model = os.path.join(output_dir, ID + ".model") 472 | log(ID, "Loading saved model from: " + args.model) 473 | resume = args.model 474 | elif args.resume_from: 475 | log(ID, "Resuming manually from snapshot: " + args.resume_from) 476 | resume = args.resume_from 477 | elif not args.no_resume and auto_resume_file is not None: 478 | log(ID, "Auto resuming from last snapshot: " + auto_resume_file) 479 | resume = auto_resume_file 480 | 481 | if resume is not None: 482 | npz = np.load(resume) 483 | pc = npz['pc'] 484 | ps = npz['ps'] 485 | B = npz['B'] 486 | D = npz['D'] 487 | C = npz['C'] 488 | invsqrtC = npz['invsqrtC'] 489 | eigeneval = npz['eigeneval'] 490 | xmean = npz['xmean'] 491 | sigma = npz['sigma'] 492 | counteval = npz['counteval'] 493 | generation = npz['generation'] + 1 494 | cumulative_rewards_over_generations = npz['cumulative_rewards_over_generations'] 495 | if args.curriculum: 496 | if 'max_timesteps' in npz and npz['max_timesteps'] is not None: 497 | max_timesteps = npz['max_timesteps'] 498 | else: 499 | max_timesteps = curriculum_start 500 | last_highest_avg_cumulative_reward = max(cumulative_rewards_over_generations.mean(axis=1)) 501 | else: 502 | max_timesteps = None 503 | npz.close() 504 | 505 | log(ID, "Starting") 506 | 507 | if args.cluster_mode and hostname != CLUSTER_DISPATCHER and not args.test: 508 | log(ID, "Starting cluster worker") 509 | WorkerServer(CLUSTER_WORKER_PORT, args, vision, model) 510 
| elif not args.test: 511 | if args.cluster_mode: 512 | global cluster_cumulative_rewards 513 | cluster_event = Event() 514 | 515 | log(ID, "Starting cluster dispatcher") 516 | dispatcher_thread = Thread(target=DispatcherServer, args=(CLUSTER_DISPATCHER_PORT, args, cluster_event)) 517 | dispatcher_thread.start() 518 | 519 | # Make the dispatcher a worker too 520 | log(ID, "Starting cluster worker") 521 | worker_thread = Thread(target=WorkerServer, args=(CLUSTER_WORKER_PORT, args, vision, model)) 522 | worker_thread.start() 523 | 524 | if args.weights_type == 1: 525 | N = action_dim * (args.z_dim + args.hidden_dim) + action_dim 526 | elif args.weights_type == 2: 527 | N = args.z_dim + 2 * args.hidden_dim 528 | 529 | stopeval = 1e3 * N ** 2 530 | stopfitness = args.target_cumulative_reward 531 | 532 | lambda_ = args.lambda_ # 4+int(3*np.log(N)) 533 | mu = int(lambda_ * args.mu) # //2 534 | weights = np.log(mu + 1 / 2) - np.log(np.asarray(range(1, mu + 1))).astype(np.float32) 535 | weights = weights / np.sum(weights) 536 | mueff = (np.sum(weights) ** 2) / np.sum(weights ** 2) 537 | 538 | cc = (4 + mueff / N) / (N + 4 + 2 * mueff / N) 539 | cs = (mueff + 2) / (N + mueff + 5) 540 | c1 = 2 / ((N + 1.3) ** 2 + mueff) 541 | cmu = min(1 - c1, 2 * (mueff - 2 + 1 / mueff) / ((N + 2) ** 2 + mueff)) 542 | damps = 1 + 2 * max(0, ((mueff - 1) / (N + 1)) ** 0.5 - 1) + cs 543 | chiN = N ** 0.5 * (1 - 1 / (4 * N) + 1 / (21 * N ** 2)) 544 | 545 | if resume is None: 546 | pc = np.zeros(N).astype(np.float32) 547 | ps = np.zeros(N).astype(np.float32) 548 | B = np.eye(N, N).astype(np.float32) 549 | D = np.ones(N).astype(np.float32) 550 | C = B.dot(np.diag(D ** 2)).dot(B.T) # matrix products, not element-wise 551 | invsqrtC = B.dot(np.diag(D ** -1)).dot(B.T) 552 | eigeneval = 0 553 | xmean = np.random.randn(N).astype(np.float32) 554 | sigma = 0.3 555 | counteval = 0 556 | generation = 1 557 | cumulative_rewards_over_generations = None 558 | if args.curriculum: 559 | max_timesteps = curriculum_start 560 |
last_highest_avg_cumulative_reward = None 561 | else: 562 | max_timesteps = None 563 | 564 | solution_found = False 565 | while counteval < stopeval: 566 | log(ID, "> Starting evolution generation #" + str(generation)) 567 | 568 | arfitness = np.zeros(lambda_).astype(np.float32) 569 | arx = np.zeros((lambda_, N)).astype(np.float32) 570 | for k in range(lambda_): 571 | arx[k] = xmean + sigma * B.dot(D * np.random.randn(N).astype(np.float32)) 572 | counteval += 1 573 | 574 | if not args.cluster_mode: 575 | if args.cores == 0: 576 | cores = cpu_count() 577 | else: 578 | cores = args.cores 579 | pool = Pool(cores) 580 | worker_arg_tuples = [] 581 | for k in range(lambda_): 582 | worker_arg_tuples.append((generation, k, args, vision, model, arx[k], max_timesteps, False)) 583 | cumulative_rewards = pool.map(rollout_worker, worker_arg_tuples) 584 | pool.close() 585 | pool.join() 586 | for k, cumulative_reward in enumerate(cumulative_rewards): 587 | arfitness[k] = cumulative_reward 588 | else: 589 | arx_splits = np.array_split(arx, len(CLUSTER_WORKERS)) 590 | indices = np.array_split(np.arange(lambda_), len(CLUSTER_WORKERS)) 591 | cluster_cumulative_rewards = {} 592 | for i, chunked_mutations in enumerate(arx_splits): 593 | log(ID, "> Dispatching " + str(len(chunked_mutations)) + " mutations to " + CLUSTER_WORKERS[i]) 594 | compressed_array = BytesIO() 595 | np.savez_compressed(compressed_array, 596 | chunked_mutations=chunked_mutations, 597 | indices=indices[i], 598 | generation=generation, 599 | max_timesteps=max_timesteps) 600 | compressed_array.seek(0) 601 | out = compressed_array.read() 602 | 603 | succeeded = False 604 | for retries in range(3): 605 | try: 606 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 607 | sock.settimeout(10) 608 | sock.connect((CLUSTER_WORKERS[i], CLUSTER_WORKER_PORT)) 609 | sock.sendall(out) 610 | sock.sendall(b"\r\n") 611 | data = sock.recv(1024).decode("utf-8") 612 | sock.close() 613 | if data == "OK": 614 | succeeded = True 615 
| break 616 | except Exception as e: 617 | log(ID, e) 618 | log(ID, "Unable to dispatch mutations to " + CLUSTER_WORKERS[i] + ". Retrying after sleeping for 30s") 619 | time.sleep(30) 620 | if not succeeded: 621 | log(ID, "Unable to dispatch mutations to " + CLUSTER_WORKERS[i] + "!") 622 | log(ID, "> Dispatched all mutations to cluster. Waiting for results.") 623 | cluster_event.clear() 624 | cluster_event.wait(args.cluster_max_wait) # Cut our losses if some results never get returned 625 | for k in range(lambda_): 626 | if k in cluster_cumulative_rewards: 627 | arfitness[k] = cluster_cumulative_rewards[k] 628 | else: 629 | arfitness[k] = 0. 630 | 631 | if cumulative_rewards_over_generations is None: 632 | cumulative_rewards_over_generations = np.expand_dims(arfitness, 0) 633 | else: 634 | cumulative_rewards_over_generations = np.concatenate( 635 | (cumulative_rewards_over_generations, np.expand_dims(arfitness, 0)), 636 | axis=0) 637 | 638 | arindex = np.argsort(-arfitness) 639 | # arfitness = arfitness[arindex] 640 | 641 | xold = xmean 642 | xmean = weights.dot(arx[arindex[0:mu]]) 643 | 644 | avg_cumulative_reward = np.mean(arfitness) 645 | 646 | log(ID, "> Finished evolution generation #{}, average cumulative reward = {:.2f}" 647 | .format(generation, avg_cumulative_reward)) 648 | 649 | if generation > 1 and args.curriculum: 650 | if last_highest_avg_cumulative_reward is None: 651 | last_highest_avg_cumulative_reward = np.mean(cumulative_rewards_over_generations[-2]) 652 | log(ID, "> Highest average cumulative reward from previous generations = {:.2f}".format( 653 | last_highest_avg_cumulative_reward)) 654 | if avg_cumulative_reward > (last_highest_avg_cumulative_reward*0.99): # Let it pass if within 1% of the old average 655 | max_timesteps += curriculum_step 656 | log(ID, "> Average cumulative reward increased.
Increasing max timesteps to " + str(max_timesteps)) 657 | last_highest_avg_cumulative_reward = None 658 | else: 659 | log(ID, 660 | "> Average cumulative reward did not increase. Keeping max timesteps at " + str(max_timesteps)) 661 | 662 | # The stopping test averages over the whole population, but breaking here means the 663 | # final xmean is computed from only the top mu mutations 664 | if avg_cumulative_reward >= stopfitness: 665 | solution_found = True 666 | break 667 | 668 | ps = (1 - cs) * ps + np.sqrt(cs * (2 - cs) * mueff) * invsqrtC.dot((xmean - xold) / sigma) 669 | hsig = np.linalg.norm(ps) / np.sqrt(1 - (1 - cs) ** (2 * counteval / lambda_)) / chiN < 1.4 + 2 / (N + 1) 670 | pc = (1 - cc) * pc + hsig * np.sqrt(cc * (2 - cc) * mueff) * ((xmean - xold) / sigma) 671 | artmp = (1 / sigma) * (arx[arindex[0:mu]] - xold) 672 | C = (1 - c1 - cmu) * C + c1 * (np.outer(pc, pc) + (1 - hsig) * cc * (2 - cc) * C) + cmu * artmp.T.dot( 673 | np.diag(weights)).dot(artmp) # rank-one update needs the outer product pc pc^T, not a scalar dot product 674 | sigma = sigma * np.exp((cs / damps) * (np.linalg.norm(ps) / chiN - 1)) 675 | 676 | if counteval - eigeneval > lambda_ / (c1 + cmu) / N / 10: 677 | eigeneval = counteval 678 | C = np.triu(C) + np.triu(C, 1).T 679 | D, B = np.linalg.eigh(C) # C is symmetric, so eigh returns real eigenvalues and eigenvectors 680 | D = np.sqrt(D) 681 | invsqrtC = B.dot(np.diag(D ** -1).dot(B.T)) 682 | 683 | if generation % args.snapshot_interval == 0: 684 | snapshot_file = os.path.join(output_dir, "snapshot_iter_" + str(generation) + ".npz") 685 | log(ID, "> Saving snapshot to " + str(snapshot_file)) 686 | np.savez_compressed(snapshot_file, 687 | pc=pc, 688 | ps=ps, 689 | B=B, 690 | D=D, 691 | C=C, 692 | invsqrtC=invsqrtC, 693 | eigeneval=eigeneval, 694 | xmean=xmean, 695 | sigma=sigma, 696 | counteval=counteval, 697 | generation=generation, 698 | cumulative_rewards_over_generations=cumulative_rewards_over_generations, 699 | max_timesteps=max_timesteps) 700 | 701 | generation += 1 702 | 703 | if solution_found: 704 | log(ID, "Evolution Complete!") 705 | log(ID, "Solution found at
generation #" + str(generation) + ", with average cumulative reward = " + 706 | str(avg_cumulative_reward) + " over " + str(args.lambda_ * args.trials) + " rollouts") 707 | else: 708 | log(ID, "Solution not found") 709 | 710 | controller_model_file = os.path.join(output_dir, ID + ".model") 711 | if os.path.exists(controller_model_file): 712 | os.remove(controller_model_file) 713 | log(ID, "Saving model to: " + controller_model_file) 714 | np.savez_compressed(controller_model_file, 715 | pc=pc, 716 | ps=ps, 717 | B=B, 718 | D=D, 719 | C=C, 720 | invsqrtC=invsqrtC, 721 | eigeneval=eigeneval, 722 | xmean=xmean, 723 | sigma=sigma, 724 | counteval=counteval, 725 | generation=generation, 726 | cumulative_rewards_over_generations=cumulative_rewards_over_generations, 727 | max_timesteps=max_timesteps) 728 | os.rename(os.path.join(output_dir, ID + ".model.npz"), controller_model_file) 729 | 730 | # xmean = np.random.randn(action_dim * (args.z_dim + args.hidden_dim) + action_dim).astype(np.float32) 731 | # xmean = np.random.randn(args.z_dim + 2 * args.hidden_dim).astype(np.float32) 732 | parameters = xmean 733 | 734 | if args.in_dream: 735 | log(ID, "Generating a rollout gif with the controller model in a dream") 736 | W_c, b_c = transform_to_weights(args, parameters) 737 | cumulative_reward, frames = rollout( 738 | (0, 0, 0, args, vision.to_cpu(), model.to_cpu(), None, W_c, b_c, None, True)) 739 | imageio.mimsave(os.path.join(output_dir, 'dream_rollout.gif'), frames, fps=20) 740 | log(ID, "Final cumulative reward in dream: " + str(cumulative_reward)) 741 | args.in_dream = False 742 | 743 | log(ID, "Generating a rollout gif with the controller model in the environment") 744 | W_c, b_c = transform_to_weights(args, parameters) 745 | cumulative_reward, frames = rollout((0, 0, 0, args, vision.to_cpu(), model.to_cpu(), None, W_c, b_c, None, True)) 746 | imageio.mimsave(os.path.join(output_dir, 'env_rollout.gif'), frames, fps=20) 747 | log(ID, "Final cumulative reward in 
environment: " + str(cumulative_reward)) 748 | 749 | log(ID, "Done") 750 | 751 | 752 | if __name__ == '__main__': 753 | main() 754 | -------------------------------------------------------------------------------- /lib/constants.py: -------------------------------------------------------------------------------- 1 | DOOM_GAMES = ['DoomTakeCover'] 2 | -------------------------------------------------------------------------------- /lib/data.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import numpy as np 3 | import os 4 | import random 5 | import gc 6 | from multiprocessing import cpu_count, Pool 7 | 8 | from chainer import dataset 9 | import chainer.functions as F 10 | 11 | from lib.utils import log, pre_process_image_tensor 12 | 13 | 14 | def load_frames_worker(frames_file): 15 | with gzip.GzipFile(frames_file, "r") as file: 16 | rollout_frames = pre_process_image_tensor(np.load(file)) 17 | return rollout_frames 18 | 19 | 20 | def load_model_npz_worker(files): 21 | npz1, npz2 = files 22 | 23 | npz = np.load(npz1) 24 | mu = npz['mu'] 25 | ln_var = npz['ln_var'] 26 | npz.close() 27 | 28 | npz = np.load(npz2) 29 | action = npz['action'] 30 | npz.close() 31 | 32 | return mu, ln_var, action 33 | 34 | 35 | class VisionDataset(dataset.DatasetMixin): 36 | def __init__(self, dir='', load_batch_size=10, shuffle=True, verbose=True): 37 | rollouts = os.listdir(dir) 38 | rollouts_counts = {} 39 | 40 | for rollout in rollouts: 41 | count_file = os.path.join(dir, rollout, "count") 42 | if os.path.exists(count_file): 43 | with open(count_file, 'r') as count_file: 44 | count = int(count_file.read()) 45 | rollouts_counts[rollout] = count 46 | 47 | rollouts = list(rollouts_counts.keys()) 48 | 49 | if shuffle: 50 | random.shuffle(rollouts) 51 | else: 52 | rollouts = sorted(rollouts, key=lambda x: int(x)) 53 | 54 | total_batches = len(rollouts) // load_batch_size 55 | if len(rollouts) % load_batch_size != 0: 56 | total_batches 
+= 1 57 | 58 | self.batch = -1 59 | self.total_batches = total_batches 60 | self.dir = dir 61 | self.shuffle = shuffle 62 | self.verbose = verbose 63 | self.load_batch_size = load_batch_size 64 | self.rollouts = rollouts 65 | self.total_count = sum(rollouts_counts.values()) 66 | self.rollouts_counts = rollouts_counts 67 | 68 | self.reset_indices() 69 | 70 | self.load_batch(0) 71 | 72 | def reset_indices(self): 73 | if self.verbose: 74 | log("VisionDataset", "*** Creating list of indices") 75 | running_count = 0 76 | absolute_indices = [] 77 | for batch in range(self.total_batches): 78 | batch_start_idx = batch * self.load_batch_size 79 | batch_end_idx = (batch + 1) * self.load_batch_size 80 | batch_rollouts = self.rollouts[batch_start_idx:batch_end_idx] 81 | count = sum([self.rollouts_counts[rollout] for rollout in batch_rollouts]) 82 | absolute_start_idx = running_count 83 | absolute_end_idx = running_count + count 84 | running_count = running_count + count 85 | if self.verbose: 86 | log("VisionDataset", "*** Batch " + str(batch) + ", from index " + str(batch_start_idx) + ":" + str( 87 | batch_end_idx) + ", of rollouts " + str(batch_rollouts) + ", with " + str( 88 | count) + " frames in this batch, goes from absolute index " + str(absolute_start_idx) + ":" + str( 89 | absolute_end_idx)) 90 | absolute_indices.append([absolute_start_idx, absolute_end_idx]) 91 | self.absolute_indices = absolute_indices 92 | 93 | def load_batch(self, batch): 94 | if self.batch == batch: 95 | return 96 | 97 | if batch == 0 and self.batch > 0 and self.shuffle: 98 | random.shuffle(self.rollouts) 99 | self.reset_indices() 100 | 101 | self.batch_frames = None 102 | gc.collect() 103 | 104 | if self.verbose: 105 | log("VisionDataset", "*** Loading batch " + str(batch)) 106 | 107 | batch_start_idx = batch * self.load_batch_size 108 | batch_end_idx = (batch + 1) * self.load_batch_size 109 | batch_rollouts = self.rollouts[batch_start_idx:batch_end_idx] 110 | batch_frames = None 111 | 
batch_rollouts_counts = {} 112 | frames_files = [] 113 | for rollout in batch_rollouts: 114 | batch_rollouts_counts[rollout] = self.rollouts_counts[rollout] 115 | frames_files.append(os.path.join(self.dir, rollout, "frames.npy.gz")) 116 | pool = Pool(cpu_count()) 117 | all_rollout_frames = pool.map(load_frames_worker, frames_files) 118 | pool.close() 119 | pool.join() 120 | batch_frames = np.concatenate(all_rollout_frames) 121 | if self.shuffle: 122 | shuffled_indices = np.random.permutation(np.arange(batch_frames.shape[0])) 123 | batch_frames = batch_frames[shuffled_indices] 124 | if self.verbose: 125 | log("VisionDataset", "*** Loaded batch " + str(batch) + ", from index " + str(batch_start_idx) + ":" + str( 126 | batch_end_idx) + ", of rollouts " + str(batch_rollouts) + ", with " + str( 127 | batch_frames.shape[0]) + " total frames in this batch. Each rollout has count: " + str( 128 | batch_rollouts_counts)) 129 | self.batch_rollouts = batch_rollouts 130 | self.batch_rollouts_counts = batch_rollouts_counts 131 | self.batch_frames = batch_frames 132 | self.batch = batch 133 | 134 | def get_current_batch_size(self): 135 | return self.batch_frames.shape[0] 136 | 137 | def get_total_batches(self): 138 | return self.total_batches 139 | 140 | def get_current_batch(self): 141 | return self.batch_frames, self.batch_rollouts, self.batch_rollouts_counts 142 | 143 | def __len__(self): 144 | return self.total_count 145 | 146 | def get_example(self, i): 147 | absolute_start_idx = self.absolute_indices[self.batch][0] 148 | absolute_end_idx = self.absolute_indices[self.batch][1] 149 | if i < absolute_start_idx or i >= absolute_end_idx: 150 | for batch, absolute_indices in enumerate(self.absolute_indices): 151 | absolute_start_idx = absolute_indices[0] 152 | absolute_end_idx = absolute_indices[1] 153 | if i >= absolute_start_idx and i < absolute_end_idx: 154 | self.load_batch(batch) 155 | break 156 | return self.batch_frames[i - absolute_start_idx] 157 | 158 | 159 | class 
ModelDataset(dataset.DatasetMixin): 160 | def __init__(self, dir='', load_batch_size=10, verbose=True): 161 | rollouts = os.listdir(dir) 162 | rollouts_counts = {} 163 | 164 | for rollout in rollouts: 165 | count_file = os.path.join(dir, rollout, "count") 166 | if os.path.exists(count_file): 167 | with open(count_file, 'r') as count_file: 168 | count = int(count_file.read()) 169 | rollouts_counts[rollout] = count - 1 # -1 b/c last frame doesn't have a next frame 170 | 171 | rollouts = list(rollouts_counts.keys()) 172 | 173 | # Sort by the longest rollouts up front, for chainer's LSTM, at least for the first epoch 174 | rollouts = sorted(rollouts, key=lambda x: -rollouts_counts[x]) 175 | 176 | total_batches = len(rollouts) // load_batch_size 177 | if len(rollouts) % load_batch_size != 0: 178 | total_batches += 1 179 | 180 | for batch in range(total_batches): 181 | batch_start_idx = batch * load_batch_size 182 | batch_end_idx = (batch + 1) * load_batch_size 183 | batch_rollouts = rollouts[batch_start_idx:batch_end_idx] 184 | if verbose: 185 | log("ModelDataset", "*** Batch " + str(batch) + ", from index " + str(batch_start_idx) + ":" + str( 186 | batch_end_idx) + ", will be of rollouts " + str(batch_rollouts)) 187 | 188 | self.batch = -1 189 | self.last_index = -1 190 | self.total_batches = total_batches 191 | self.dir = dir 192 | self.verbose = verbose 193 | self.load_batch_size = load_batch_size 194 | self.rollouts = rollouts 195 | self.rollouts_counts = rollouts_counts 196 | 197 | self.load_batch(0) 198 | 199 | def load_batch(self, batch): 200 | if self.batch == batch: 201 | return 202 | 203 | if batch == 0 and self.batch > 0: 204 | random.shuffle(self.rollouts) 205 | 206 | self.z_t = None 207 | self.z_t_plus_1 = None 208 | self.action = None 209 | gc.collect() 210 | 211 | if self.verbose: 212 | log("ModelDataset", "*** Loading batch " + str(batch)) 213 | 214 | batch_start_idx = batch * self.load_batch_size 215 | batch_end_idx = (batch + 1) * self.load_batch_size 
216 | batch_rollouts = self.rollouts[batch_start_idx:batch_end_idx] 217 | batch_rollouts_counts = {} 218 | files = [] 219 | for rollout in batch_rollouts: 220 | batch_rollouts_counts[rollout] = self.rollouts_counts[rollout] 221 | files.append( 222 | (os.path.join(self.dir, rollout, "mu+ln_var.npz"), 223 | os.path.join(self.dir, rollout, "misc.npz"))) 224 | pool = Pool(cpu_count()) 225 | data = pool.map(load_model_npz_worker, files) 226 | pool.close() 227 | pool.join() 228 | mu = [] 229 | ln_var = [] 230 | action = [] 231 | for rollout_mu, rollout_ln_var, rollout_action in data: 232 | mu.append(rollout_mu) 233 | ln_var.append(rollout_ln_var) 234 | action.append(rollout_action) 235 | 236 | if self.verbose: 237 | log("ModelDataset", "*** Loaded batch " + str(batch) + ", from index " + str(batch_start_idx) + ":" + str( 238 | batch_end_idx) + ", of rollouts " + str(batch_rollouts) + ", with " + str( 239 | len(mu)) + " total rollouts in this batch. Each rollout has count: " + str( 240 | batch_rollouts_counts)) 241 | 242 | self.batch_rollouts = batch_rollouts 243 | self.batch_rollouts_counts = batch_rollouts_counts 244 | self.mu = mu 245 | self.ln_var = ln_var 246 | self.action = action 247 | self.batch = batch 248 | 249 | def get_current_batch_size(self): 250 | return len(self.batch_rollouts) 251 | 252 | def __len__(self): 253 | return len(self.rollouts) 254 | 255 | def get_example(self, i): 256 | batch = i // self.load_batch_size 257 | self.load_batch(batch) 258 | index = i % self.load_batch_size 259 | 260 | # In case we have all rollouts loaded in memory, and 261 | # are not doing batched loading, shuffle every epoch: 262 | if self.load_batch_size >= len(self.rollouts): 263 | if index == 0 and self.last_index > 0: 264 | shuffled = list(zip(self.mu, self.ln_var, self.action)) 265 | random.shuffle(shuffled) 266 | self.mu, self.ln_var, self.action = zip(*shuffled) 267 | self.last_index = index 268 | 269 | # Resample z from N(mu, exp(ln_var)) on every access, to help prevent overfitting: 270 | z_t =
F.gaussian(self.mu[index], self.ln_var[index]).data 271 | z_t_plus_1 = z_t[1:] 272 | z_t = z_t[0:z_t.shape[0] - 1] 273 | 274 | done = np.zeros((z_t_plus_1.shape[0], 1)).astype(np.int32) 275 | done[-1, 0] = 1. 276 | 277 | return z_t, z_t_plus_1, self.action[index], done 278 | -------------------------------------------------------------------------------- /lib/env_wrappers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | from vizdoom import DoomGame, ScreenResolution 5 | 6 | CONFIGURATIONS_DIR = os.path.join( 7 | os.environ['CONDA_PREFIX'], 8 | "lib/python" + str(sys.version_info.major) + "." + str(sys.version_info.minor) + "/site-packages/vizdoom/scenarios") 9 | CONFIGURATIONS = { # https://github.com/mwydmuch/ViZDoom/tree/master/scenarios: 10 | "DoomTakeCover": "take_cover" 11 | } 12 | 13 | 14 | class AttrDict(dict): 15 | def __init__(self, *args, **kwargs): 16 | super(AttrDict, self).__init__(*args, **kwargs) 17 | self.__dict__ = self 18 | 19 | 20 | class ViZDoomWrapper(object): 21 | def __init__(self, configuration): 22 | configuration = CONFIGURATIONS[configuration] 23 | game = DoomGame() 24 | game.load_config( 25 | os.path.join(CONFIGURATIONS_DIR, configuration + ".cfg")) 26 | game.set_screen_resolution(ScreenResolution.RES_160X120) 27 | game.set_window_visible(False) 28 | game.init() 29 | action_dim = game.get_available_buttons_size() 30 | action_space = AttrDict() 31 | action_space.low = [0 for i in range(action_dim)] 32 | action_space.high = [1 for i in range(action_dim)] 33 | self.action_space = action_space 34 | self.game = game 35 | 36 | def reset(self): 37 | self.game.new_episode() 38 | return self.game.get_state().screen_buffer 39 | 40 | def step(self, action): 41 | action = action.astype(bool).tolist() 42 | reward = self.game.make_action(action) 43 | if self.game.get_state() is not None: 44 | self.last_screen_buffer = self.game.get_state().screen_buffer 45 | return 
self.last_screen_buffer, \ 46 | reward, \ 47 | self.game.is_episode_finished(), \ 48 | None 49 | 50 | def close(self): 51 | self.game.close() 52 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timezone 2 | import os 3 | from scipy.misc import imsave 4 | import numpy as np 5 | 6 | 7 | def pre_process_image_tensor(images): 8 | if images.dtype != np.float32: 9 | images = images.astype(np.float32) / 255. 10 | if images.shape[-1] == 3: 11 | images = np.rollaxis(images, 3, 1) 12 | return images 13 | 14 | 15 | def post_process_image_tensor(images): 16 | if images.dtype != np.uint8: 17 | images = (images * 255).astype('uint8') 18 | if images.shape[-1] != 3: 19 | images = np.rollaxis(images, 1, 4) 20 | return images 21 | 22 | 23 | def save_images_collage(images, save_path, pre_processed=True): 24 | if pre_processed: 25 | images = post_process_image_tensor(images) 26 | 27 | npad = ((0, 0), (2, 2), (2, 2), (0, 0)) 28 | images = np.pad(images, pad_width=npad, mode='constant', constant_values=255) 29 | 30 | n_samples = images.shape[0] 31 | rows = int(np.sqrt(n_samples)) 32 | while n_samples % rows != 0: 33 | rows -= 1 34 | 35 | nh, nw = rows, n_samples // rows 36 | 37 | if images.ndim == 2: 38 | images = np.reshape(images, (images.shape[0], int(np.sqrt(images.shape[1])), int(np.sqrt(images.shape[1])))) 39 | 40 | if images.ndim == 4: 41 | h, w = images[0].shape[:2] 42 | img = np.zeros((h * nh, w * nw, 3)) 43 | elif images.ndim == 3: 44 | h, w = images[0].shape[:2] 45 | img = np.zeros((h * nh, w * nw)) 46 | 47 | for n, image in enumerate(images): 48 | j = n // nw 49 | i = n % nw 50 | img[j * h:j * h + h, i * w:i * w + w] = image 51 | 52 | imsave(save_path, img) 53 | 54 | 55 | def mkdir(dir_name): 56 | if not os.path.exists(dir_name): 57 | os.makedirs(dir_name) 58 | 59 | 60 | def log(id, message): 61 |
print(str(datetime.now(timezone.utc)) + " [" + str(id) + "] " + str(message)) 62 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import math 5 | 6 | import chainer 7 | import chainer.functions as F 8 | import chainer.links as L 9 | from chainer import training 10 | from chainer.training import extensions 11 | try: 12 | import cupy as cp 13 | except Exception: 14 | pass # CuPy is optional; it is only needed when running on a GPU 15 | 16 | import numpy as np 17 | import imageio 18 | import numba 19 | 20 | from lib.utils import log, mkdir, save_images_collage, post_process_image_tensor 21 | from lib.data import ModelDataset 22 | from vision import CVAE 23 | 24 | ID = "model" 25 | 26 | 27 | @numba.jit(nopython=True) 28 | def optimized_sampling(output_dim, temperature, coef, mu, ln_var): 29 | mus = np.zeros(output_dim) 30 | ln_vars = np.zeros(output_dim) 31 | for i in range(output_dim): 32 | cumulative_probability = 0. 33 | r = np.random.uniform(0., 1.)
34 | index = coef.shape[1] - 1 # fall back to the last mixture component (k - 1) if rounding leaves r above the total 35 | for j, probability in enumerate(coef[i]): 36 | cumulative_probability = cumulative_probability + probability 37 | if r <= cumulative_probability: 38 | index = j 39 | break 40 | for j, this_mu in enumerate(mu[i]): 41 | if j == index: 42 | mus[i] = this_mu 43 | break 44 | for j, this_ln_var in enumerate(ln_var[i]): 45 | if j == index: 46 | ln_vars[i] = this_ln_var 47 | break 48 | z_t_plus_1 = mus + np.exp(ln_vars / 2.) * np.random.randn(output_dim) * np.sqrt(temperature) # ln_var is a log-variance, so std = exp(ln_var / 2) 49 | return z_t_plus_1 50 | 51 | 52 | class MDN_RNN(chainer.Chain): 53 | def __init__(self, hidden_dim=256, output_dim=32, k=5, predict_done=False): 54 | self.output_dim = output_dim 55 | self.hidden_dim = hidden_dim 56 | self.k = k 57 | self.predict_done = predict_done 58 | init_dict = { 59 | "rnn_layer": L.LSTM(None, hidden_dim), 60 | "coef_layer": L.Linear(None, k * output_dim), 61 | "mu_layer": L.Linear(None, k * output_dim), 62 | "ln_var_layer": L.Linear(None, k * output_dim) 63 | } 64 | if predict_done: 65 | init_dict["done_layer"] = L.Linear(None, 1) 66 | super(MDN_RNN, self).__init__(**init_dict) 67 | 68 | def __call__(self, z_t, action, temperature=1.0): 69 | k = self.k 70 | output_dim = self.output_dim 71 | 72 | if len(z_t.shape) == 1: 73 | z_t = F.expand_dims(z_t, 0) 74 | if len(action.shape) == 1: 75 | action = F.expand_dims(action, 0) 76 | 77 | output = self.fprop(F.concat((z_t, action))) 78 | if self.predict_done: 79 | coef, mu, ln_var, done = output 80 | else: 81 | coef, mu, ln_var = output 82 | 83 | coef = F.reshape(coef, (-1, k)) 84 | mu = F.reshape(mu, (-1, k)) 85 | ln_var = F.reshape(ln_var, (-1, k)) 86 | 87 | coef /= temperature 88 | coef = F.softmax(coef, axis=1) 89 | 90 | if self._cpu: 91 | z_t_plus_1 = optimized_sampling(output_dim, temperature, coef.data, mu.data, ln_var.data).astype(np.float32) 92 | else: 93 | coef = cp.asnumpy(coef.data) 94 | mu = cp.asnumpy(mu.data) 95 | ln_var = cp.asnumpy(ln_var.data) 96 | z_t_plus_1 =
optimized_sampling(output_dim, temperature, coef, mu, ln_var).astype(np.float32) 97 | z_t_plus_1 = chainer.Variable(cp.asarray(z_t_plus_1)) 98 | 99 | if self.predict_done: 100 | return z_t_plus_1, F.sigmoid(done) 101 | else: 102 | return z_t_plus_1 103 | 104 | def fprop(self, input): 105 | h = self.rnn_layer(input) 106 | coef = self.coef_layer(h) 107 | mu = self.mu_layer(h) 108 | ln_var = self.ln_var_layer(h) 109 | 110 | if self.predict_done: 111 | done = self.done_layer(h) 112 | 113 | if self.predict_done: 114 | return coef, mu, ln_var, done 115 | else: 116 | return coef, mu, ln_var 117 | 118 | def get_loss_func(self): 119 | def lf(z_t, z_t_plus_1, action, done_label, reset=True): 120 | k = self.k 121 | output_dim = self.output_dim 122 | if reset: 123 | self.reset_state() 124 | 125 | output = self.fprop(F.concat((z_t, action))) 126 | if self.predict_done: 127 | coef, mu, ln_var, done = output 128 | else: 129 | coef, mu, ln_var = output 130 | 131 | coef = F.reshape(coef, (-1, output_dim, k)) 132 | coef = F.softmax(coef, axis=2) 133 | mu = F.reshape(mu, (-1, output_dim, k)) 134 | ln_var = F.reshape(ln_var, (-1, output_dim, k)) 135 | 136 | z_t_plus_1 = F.repeat(z_t_plus_1, k, 1).reshape(-1, output_dim, k) 137 | 138 | normals = F.sum( 139 | coef * F.exp(-F.gaussian_nll(z_t_plus_1, mu, ln_var, reduce='no')) 140 | , axis=2) 141 | log_densities = F.sum(F.log(normals), axis=1) # each z dim has its own mixture, so the joint density is a product over dims (a sum of logs) 142 | nll = -log_densities 143 | 144 | loss = F.sum(nll) 145 | 146 | if self.predict_done: 147 | done_loss = F.sigmoid_cross_entropy(done.reshape(-1, 1), done_label, reduce="no") 148 | done_loss *= (1. + done_label.astype("float32")*9.)
149 | done_loss = F.mean(done_loss) 150 | loss = loss + done_loss 151 | 152 | return loss 153 | return lf 154 | 155 | def reset_state(self): 156 | self.rnn_layer.reset_state() 157 | 158 | def get_h(self): 159 | return self.rnn_layer.h 160 | 161 | def get_c(self): 162 | return self.rnn_layer.c 163 | 164 | 165 | class ImageSampler(chainer.training.Extension): 166 | def __init__(self, model, vision, args, output_dir, z_t, action): 167 | self.model = model 168 | self.vision = vision 169 | self.args = args 170 | self.output_dir = output_dir 171 | self.z_t = z_t 172 | self.action = action 173 | 174 | def __call__(self, trainer): 175 | if self.args.gpu >= 0: 176 | self.model.to_cpu() 177 | with chainer.using_config('train', False), chainer.no_backprop_mode(): 178 | self.model.reset_state() 179 | z_t_plus_1s = [] 180 | dones = [] 181 | for i in range(self.z_t.shape[0]): 182 | output = self.model(self.z_t[i], self.action[i], temperature=self.args.sample_temperature) 183 | if self.args.predict_done: 184 | z_t_plus_1, done = output 185 | z_t_plus_1 = z_t_plus_1.data 186 | done = done.data 187 | else: 188 | z_t_plus_1 = output.data 189 | z_t_plus_1s.append(z_t_plus_1) 190 | if self.args.predict_done: 191 | dones.append(done[0]) 192 | z_t_plus_1s = np.asarray(z_t_plus_1s) 193 | dones = np.asarray(dones).reshape(-1) 194 | img_t_plus_1 = post_process_image_tensor(self.vision.decode(z_t_plus_1s).data) 195 | if self.args.predict_done: 196 | img_t_plus_1[np.where(dones >= 0.5), :, :, :] = 0 # Make all the done's black 197 | save_images_collage(img_t_plus_1, 198 | os.path.join(self.output_dir, 199 | 'train_t_plus_1_{}.png'.format(trainer.updater.iteration)), 200 | pre_processed=False) 201 | if self.args.gpu >= 0: 202 | self.model.to_gpu() 203 | 204 | 205 | class TBPTTUpdater(training.updaters.StandardUpdater): 206 | def __init__(self, train_iter, optimizer, device, loss_func, sequence_length): 207 | self.sequence_length = sequence_length 208 | super(TBPTTUpdater, self).__init__( 209 
| train_iter, optimizer, device=device, 210 | loss_func=loss_func) 211 | 212 | def update_core(self): 213 | train_iter = self.get_iterator('main') 214 | optimizer = self.get_optimizer('main') 215 | 216 | batch = train_iter.__next__() 217 | total_loss = 0 218 | z_t, z_t_plus_1, action, done = self.converter(batch, self.device) 219 | z_t = chainer.Variable(z_t[0]) 220 | z_t_plus_1 = chainer.Variable(z_t_plus_1[0]) 221 | action = chainer.Variable(action[0]) 222 | done = chainer.Variable(done[0]) 223 | for i in range(math.ceil(z_t.shape[0]/self.sequence_length)): 224 | start_idx = i*self.sequence_length 225 | end_idx = (i+1)*self.sequence_length 226 | loss = self.loss_func(z_t[start_idx:end_idx].data, 227 | z_t_plus_1[start_idx:end_idx].data, 228 | action[start_idx:end_idx].data, 229 | done[start_idx:end_idx].data, 230 | True if i==0 else False) 231 | optimizer.target.cleargrads() 232 | loss.backward() 233 | loss.unchain_backward() 234 | optimizer.update() 235 | total_loss += loss 236 | 237 | chainer.report({'loss': total_loss}) 238 | 239 | 240 | def main(): 241 | parser = argparse.ArgumentParser(description='World Models ' + ID) 242 | parser.add_argument('--data_dir', '-d', default="/data/wm", help='The base data/output directory') 243 | parser.add_argument('--game', default='CarRacing-v0', 244 | help='Game to use') # https://gym.openai.com/envs/CarRacing-v0/ 245 | parser.add_argument('--experiment_name', default='experiment_1', help='To isolate its files from others') 246 | parser.add_argument('--load_batch_size', default=100, type=int, 247 | help='Load rollouts in batches so as not to run out of memory') 248 | parser.add_argument('--model', '-m', default='', 249 | help='Initialize the model from given file, or "default" for one in data folder') 250 | parser.add_argument('--no_resume', action='store_true', help="Don't auto resume from the latest snapshot") 251 | parser.add_argument('--resume_from', '-r', default='', help='Resume the optimization from a specific
snapshot') 252 | parser.add_argument('--test', action='store_true', help='Generate samples only') 253 | parser.add_argument('--gpu', '-g', default=-1, type=int, help='GPU ID (negative value indicates CPU)') 254 | parser.add_argument('--epoch', '-e', default=20, type=int, help='number of epochs to learn') 255 | parser.add_argument('--snapshot_interval', '-s', default=200, type=int, help='snapshot every x games') 256 | parser.add_argument('--z_dim', '-z', default=32, type=int, help='dimension of encoded vector') 257 | parser.add_argument('--hidden_dim', default=256, type=int, help='LSTM hidden units') 258 | parser.add_argument('--mixtures', default=5, type=int, help='number of gaussian mixtures for MDN') 259 | parser.add_argument('--no_progress_bar', '-p', action='store_true', help='Hide the progress bar during training') 260 | parser.add_argument('--predict_done', action='store_true', help='Whether MDN-RNN should also predict done state') 261 | parser.add_argument('--sample_temperature', default=1., type=float, help='Temperature for generating samples') 262 | parser.add_argument('--gradient_clip', default=0., type=float, help='Clip grads L2 norm threshold.
0 = no clip') 263 | parser.add_argument('--sequence_length', type=int, default=128, help='sequence length for LSTM for TBPTT') 264 | 265 | args = parser.parse_args() 266 | log(ID, "args =\n " + str(vars(args)).replace(",", ",\n ")) 267 | 268 | output_dir = os.path.join(args.data_dir, args.game, args.experiment_name, ID) 269 | mkdir(output_dir) 270 | random_rollouts_dir = os.path.join(args.data_dir, args.game, args.experiment_name, 'random_rollouts') 271 | vision_dir = os.path.join(args.data_dir, args.game, args.experiment_name, 'vision') 272 | 273 | log(ID, "Starting") 274 | 275 | max_iter = 0 276 | auto_resume_file = None 277 | files = os.listdir(output_dir) 278 | for file in files: 279 | if re.match(r'^snapshot_iter_', file): 280 | iter = int(re.search(r'\d+', file).group()) 281 | if (iter > max_iter): 282 | max_iter = iter 283 | if max_iter > 0: 284 | auto_resume_file = os.path.join(output_dir, "snapshot_iter_{}".format(max_iter)) 285 | 286 | model = MDN_RNN(args.hidden_dim, args.z_dim, args.mixtures, args.predict_done) 287 | vision = CVAE(args.z_dim) 288 | chainer.serializers.load_npz(os.path.join(vision_dir, "vision.model"), vision) 289 | 290 | if args.model: 291 | if args.model == 'default': 292 | args.model = os.path.join(output_dir, ID + ".model") 293 | log(ID, "Loading saved model from: " + args.model) 294 | chainer.serializers.load_npz(args.model, model) 295 | 296 | optimizer = chainer.optimizers.Adam() 297 | optimizer.setup(model) 298 | if args.gradient_clip > 0.: 299 | optimizer.add_hook(chainer.optimizer_hooks.GradientClipping(args.gradient_clip)) 300 | 301 | log(ID, "Loading training data") 302 | train = ModelDataset(dir=random_rollouts_dir, load_batch_size=args.load_batch_size, verbose=False) 303 | train_iter = chainer.iterators.SerialIterator(train, batch_size=1, shuffle=False) 304 | 305 | updater = TBPTTUpdater(train_iter, optimizer, args.gpu, model.get_loss_func(), args.sequence_length) 306 | 307 | trainer = training.Trainer(updater, (args.epoch, 
'epoch'), out=output_dir) 308 | trainer.extend(extensions.snapshot(), trigger=(args.snapshot_interval, 'iteration')) 309 | trainer.extend(extensions.LogReport(trigger=(10 if args.gpu >= 0 else 1, 'iteration'))) 310 | trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'loss', 'elapsed_time'])) 311 | if not args.no_progress_bar: 312 | trainer.extend(extensions.ProgressBar(update_interval=10 if args.gpu >= 0 else 1)) 313 | 314 | sample_size = 256 315 | rollout_z_t, rollout_z_t_plus_1, rollout_action, done = train[0] 316 | sample_z_t = rollout_z_t[0:sample_size] 317 | sample_z_t_plus_1 = rollout_z_t_plus_1[0:sample_size] 318 | sample_action = rollout_action[0:sample_size] 319 | img_t = vision.decode(sample_z_t).data 320 | img_t_plus_1 = vision.decode(sample_z_t_plus_1).data 321 | if args.predict_done: 322 | done = done.reshape(-1) 323 | img_t_plus_1[np.where(done[0:sample_size] >= 0.5), :, :, :] = 0 # Make done black 324 | save_images_collage(img_t, os.path.join(output_dir, 'train_t.png')) 325 | save_images_collage(img_t_plus_1, os.path.join(output_dir, 'train_t_plus_1.png')) 326 | image_sampler = ImageSampler(model.copy(), vision, args, output_dir, sample_z_t, sample_action) 327 | trainer.extend(image_sampler, trigger=(args.snapshot_interval, 'iteration')) 328 | 329 | if args.resume_from: 330 | log(ID, "Resuming trainer manually from snapshot: " + args.resume_from) 331 | chainer.serializers.load_npz(args.resume_from, trainer) 332 | elif not args.no_resume and auto_resume_file is not None: 333 | log(ID, "Auto resuming trainer from last snapshot: " + auto_resume_file) 334 | chainer.serializers.load_npz(auto_resume_file, trainer) 335 | 336 | if not args.test: 337 | log(ID, "Starting training") 338 | trainer.run() 339 | log(ID, "Done training") 340 | log(ID, "Saving model") 341 | chainer.serializers.save_npz(os.path.join(output_dir, ID + ".model"), model) 342 | 343 | if args.test: 344 | log(ID, "Saving test samples") 345 | image_sampler(trainer) 346 | 347 | 
log(ID, "Generating gif for a rollout generated in dream") 348 | if args.gpu >= 0: 349 | model.to_cpu() 350 | model.reset_state() 351 | # current_z_t = np.random.randn(args.z_dim).astype(np.float32) # Noise as starting frame 352 | rollout_z_t, rollout_z_t_plus_1, rollout_action, done = train[np.random.randint(len(train))] # Pick a random real rollout 353 | current_z_t = rollout_z_t[0] # Starting frame from the real rollout 354 | current_z_t += np.random.normal(0, 0.5, current_z_t.shape).astype(np.float32) # Add some noise to the real rollout starting frame 355 | all_z_t = [current_z_t] 356 | # current_action = np.asarray([0., 1.]).astype(np.float32) 357 | for i in range(rollout_z_t.shape[0]): 358 | # if i != 0 and i % 200 == 0: current_action = 1 - current_action # Flip actions every 200 frames 359 | current_action = np.expand_dims(rollout_action[i], 0) # follow actions performed in a real rollout 360 | output = model(current_z_t, current_action, temperature=args.sample_temperature) 361 | if args.predict_done: 362 | current_z_t, done = output 363 | done = done.data 364 | # print(i, current_action, done) 365 | else: 366 | current_z_t = output 367 | all_z_t.append(current_z_t.data) 368 | if args.predict_done and done[0] >= 0.5: 369 | break 370 | dream_rollout_imgs = vision.decode(np.asarray(all_z_t).astype(np.float32)).data 371 | dream_rollout_imgs = post_process_image_tensor(dream_rollout_imgs) 372 | imageio.mimsave(os.path.join(output_dir, 'dream_rollout.gif'), dream_rollout_imgs, fps=20) 373 | 374 | log(ID, "Done") 375 | 376 | 377 | if __name__ == '__main__': 378 | main() -------------------------------------------------------------------------------- /random_rollouts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | from multiprocessing import cpu_count, Pool 5 | 6 | import gzip 7 | import gym 8 | from scipy.misc import imresize 9 | import numpy as np 10 | 11 | from lib.utils import
log, mkdir 12 | from lib.constants import DOOM_GAMES 13 | try: 14 | from lib.env_wrappers import ViZDoomWrapper 15 | except Exception as e: 16 | pass # ViZDoom is optional; it is only needed for the Doom games 17 | 18 | ID = "random_rollouts" 19 | 20 | 21 | def generate_action(low, high, prev_action, balance_no_actions=False, force_actions=False): 22 | if np.random.randint(10) % 10 and prev_action is not None: 23 | return (prev_action) 24 | 25 | action_len = len(low) 26 | action = [0 for i in range(action_len)] 27 | while True: 28 | for i in range(action_len): 29 | random = np.random.randint(low[i], high[i]+1) 30 | if random % action_len: 31 | action[i] = random 32 | # Because in many games all 0's or all 1's are the same action (no action), we want to limit 33 | # those to the same probability as any other combination of actions. Maybe there's a 34 | # better way, or I'm not thinking of something: 35 | if balance_no_actions and (all(a == 0 for a in action) or all(a == 1 for a in action)) and np.random.randint(2)==0: 36 | action = [0 for i in range(action_len)] 37 | continue 38 | if force_actions and all(a == 0 for a in action): 39 | continue 40 | break 41 | 42 | return (np.array(action).astype(np.float32)) 43 | 44 | 45 | def worker(worker_arg_tuple): 46 | rollouts_per_core, args, output_dir = worker_arg_tuple 47 | 48 | np.random.seed() 49 | 50 | if args.game in DOOM_GAMES: 51 | env = ViZDoomWrapper(args.game) 52 | else: 53 | env = gym.make(args.game) 54 | 55 | for rollout_num in rollouts_per_core: 56 | t = 1 57 | 58 | actions_array = [] 59 | frames_array = [] 60 | rewards_array = [] 61 | 62 | observation = env.reset() 63 | frames_array.append(imresize(observation, (args.frame_resize, args.frame_resize))) 64 | 65 | start_time = time.time() 66 | prev_action = None 67 | while True: 68 | # action = env.action_space.sample() 69 | action = generate_action(env.action_space.low, 70 | env.action_space.high, 71 | prev_action, 72 | balance_no_actions=True if args.game in DOOM_GAMES else False, 73 | force_actions=False if
args.game in DOOM_GAMES else True) 74 | prev_action = action 75 | observation, reward, done, _ = env.step(action) 76 | actions_array.append(action) 77 | frames_array.append(imresize(observation, (args.frame_resize, args.frame_resize))) 78 | rewards_array.append(reward) 79 | 80 | if done: 81 | log(ID, 82 | "\t> Rollout {}/{} finished after {} timesteps in {:.2f}s".format(rollout_num, args.num_rollouts, t, 83 | (time.time() - start_time))) 84 | break 85 | t = t + 1 86 | 87 | actions_array = np.asarray(actions_array) 88 | frames_array = np.asarray(frames_array) 89 | rewards_array = np.asarray(rewards_array).astype(np.float32) 90 | 91 | rollout_dir = os.path.join(output_dir, str(rollout_num)) 92 | mkdir(rollout_dir) 93 | 94 | # from lib.utils import post_process_image_tensor 95 | # import imageio 96 | # imageio.mimsave(os.path.join(output_dir, str(rollout_num), 'rollout.gif'), post_process_image_tensor(frames_array), fps=20) 97 | 98 | with gzip.GzipFile(os.path.join(rollout_dir, "frames.npy.gz"), "w") as file: 99 | np.save(file, frames_array) 100 | np.savez_compressed(os.path.join(rollout_dir, "misc.npz"), 101 | action=actions_array, 102 | reward=rewards_array) 103 | with open(os.path.join(rollout_dir, "count"), "w") as file: 104 | print("{}".format(frames_array.shape[0]), file=file) 105 | 106 | env.close() 107 | 108 | 109 | def main(): 110 | parser = argparse.ArgumentParser(description='World Models ' + ID) 111 | parser.add_argument('--data_dir', '-d', default="/data/wm", help='The base data/output directory') 112 | parser.add_argument('--game', default='CarRacing-v0', 113 | help='Game to use') # https://gym.openai.com/envs/CarRacing-v0/ 114 | parser.add_argument('--experiment_name', default='experiment_1', help='To isolate its files from others') 115 | parser.add_argument('--num_rollouts', '-n', default=100, type=int, help='Number of rollouts to collect') 116 | parser.add_argument('--offset', '-o', default=0, type=int, 117 | help='Offset rollout count, in case 
running on distributed cluster') 118 | parser.add_argument('--frame_resize', '-r', default=64, type=int, help='h x w resize of each observation frame') 119 | parser.add_argument('--cores', default=0, type=int, help='Number of CPU cores to use. 0=all cores') 120 | args = parser.parse_args() 121 | log(ID, "args =\n " + str(vars(args)).replace(",", ",\n ")) 122 | 123 | output_dir = os.path.join(args.data_dir, args.game, args.experiment_name, ID) 124 | mkdir(output_dir) 125 | 126 | log(ID, "Starting") 127 | 128 | if args.cores == 0: 129 | cores = cpu_count() 130 | else: 131 | cores = args.cores 132 | start = 1 + args.offset 133 | end = args.num_rollouts + 1 + args.offset 134 | rollouts_per_core = np.array_split(range(start, end), cores) 135 | pool = Pool(cores) 136 | worker_arg_tuples = [] 137 | for i in rollouts_per_core: 138 | if len(i) != 0: 139 | worker_arg_tuples.append((i, args, output_dir)) 140 | pool.map(worker, worker_arg_tuples) 141 | pool.close() 142 | pool.join() 143 | 144 | log(ID, "Done") 145 | 146 | 147 | if __name__ == '__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | from multiprocessing import cpu_count, Pool 5 | import gzip 6 | import traceback 7 | 8 | import chainer 9 | 10 | import numpy as np 11 | from scipy.misc import imresize 12 | import gym 13 | import imageio 14 | 15 | from lib.utils import log, mkdir, pre_process_image_tensor, post_process_image_tensor 16 | from lib.constants import DOOM_GAMES 17 | try: 18 | from lib.env_wrappers import ViZDoomWrapper 19 | except Exception as e: 20 | None 21 | from model import MDN_RNN 22 | from vision import CVAE 23 | from controller import transform_to_weights, action 24 | 25 | ID = "test" 26 | 27 | 28 | def worker(worker_arg_tuple): 29 | try: 30 | rollout_num, args, vision, model, W_c, b_c, output_dir = 
worker_arg_tuple 31 | 32 | np.random.seed() 33 | 34 | model.reset_state() 35 | 36 | if args.game in DOOM_GAMES: 37 | env = ViZDoomWrapper(args.game) 38 | else: 39 | env = gym.make(args.game) 40 | 41 | h_t = np.zeros(args.hidden_dim).astype(np.float32) 42 | c_t = np.zeros(args.hidden_dim).astype(np.float32) 43 | 44 | t = 0 45 | cumulative_reward = 0 46 | if args.record: 47 | frames_array = [] 48 | 49 | observation = env.reset() 50 | if args.record: 51 | frames_array.append(observation) 52 | 53 | start_time = time.time() 54 | while True: 55 | observation = imresize(observation, (args.frame_resize, args.frame_resize)) 56 | observation = pre_process_image_tensor(np.expand_dims(observation, 0)) 57 | 58 | z_t = vision.encode(observation, return_z=True).data[0] 59 | 60 | a_t = action(args, W_c, b_c, z_t, h_t, c_t, None) 61 | 62 | observation, reward, done, _ = env.step(a_t) 63 | model(z_t, a_t, temperature=args.temperature) 64 | 65 | if args.record: 66 | frames_array.append(observation) 67 | cumulative_reward += reward 68 | 69 | h_t = model.get_h().data[0] 70 | c_t = model.get_c().data[0] 71 | 72 | t += 1 73 | 74 | if done: 75 | break 76 | 77 | log(ID, 78 | "> Rollout #{} finished after {} timesteps in {:.2f}s with cumulative reward {:.2f}".format( 79 | (rollout_num + 1), t, 80 | (time.time() - start_time), 81 | cumulative_reward) 82 | ) 83 | 84 | env.close() 85 | 86 | if args.record: 87 | frames_array = np.asarray(frames_array) 88 | imageio.mimsave(os.path.join(output_dir, str(rollout_num + 1) + '.gif'), 89 | post_process_image_tensor(frames_array), 90 | fps=20) 91 | 92 | return cumulative_reward 93 | except Exception: 94 | print(traceback.format_exc()) 95 | return 0. 
96 | 97 | 98 | def main(): 99 | parser = argparse.ArgumentParser(description='World Models ' + ID) 100 | parser.add_argument('--data_dir', '-d', default="/data/wm", help='The base data/output directory') 101 | parser.add_argument('--game', default='CarRacing-v0', 102 | help='Game to use') # https://gym.openai.com/envs/CarRacing-v0/ 103 | parser.add_argument('--experiment_name', default='experiment_1', help='To isolate its files from others') 104 | parser.add_argument('--rollouts', '-n', default=100, type=int, help='Number of times to rollout') 105 | parser.add_argument('--frame_resize', default=64, type=int, help='h x w resize of each observation frame') 106 | parser.add_argument('--hidden_dim', default=256, type=int, help='LSTM hidden units') 107 | parser.add_argument('--z_dim', '-z', default=32, type=int, help='dimension of encoded vector') 108 | parser.add_argument('--mixtures', default=5, type=int, help='number of gaussian mixtures for MDN') 109 | parser.add_argument('--temperature', '-t', default=1.0, type=float, help='Temperature (tau) for MDN-RNN (model)') 110 | parser.add_argument('--predict_done', action='store_true', help='Whether MDN-RNN should also predict done state') 111 | parser.add_argument('--cores', default=0, type=int, help='Number of CPU cores to use. 
0=all cores') 112 | parser.add_argument('--weights_type', default=1, type=int, 113 | help="1=action_dim*(z_dim+hidden_dim), 2=z_dim+2*hidden_dim") 114 | parser.add_argument('--record', action='store_true', help='Record as gifs') 115 | 116 | args = parser.parse_args() 117 | log(ID, "args =\n " + str(vars(args)).replace(",", ",\n ")) 118 | 119 | if args.game in DOOM_GAMES: 120 | env = ViZDoomWrapper(args.game) 121 | else: 122 | env = gym.make(args.game) 123 | action_dim = len(env.action_space.low) 124 | args.action_dim = action_dim 125 | env = None 126 | 127 | if args.cores == 0: 128 | cores = cpu_count() 129 | else: 130 | cores = args.cores 131 | 132 | output_dir = os.path.join(args.data_dir, args.game, args.experiment_name, ID) 133 | mkdir(output_dir) 134 | model_dir = os.path.join(args.data_dir, args.game, args.experiment_name, 'model') 135 | vision_dir = os.path.join(args.data_dir, args.game, args.experiment_name, 'vision') 136 | controller_dir = os.path.join(args.data_dir, args.game, args.experiment_name, 'controller') 137 | 138 | model = MDN_RNN(args.hidden_dim, args.z_dim, args.mixtures, args.predict_done) 139 | chainer.serializers.load_npz(os.path.join(model_dir, "model.model"), model) 140 | vision = CVAE(args.z_dim) 141 | chainer.serializers.load_npz(os.path.join(vision_dir, "vision.model"), vision) 142 | # controller = np.random.randn(action_dim * (args.z_dim + args.hidden_dim) + action_dim).astype(np.float32) 143 | # controller = np.random.randn(args.z_dim + 2 * args.hidden_dim).astype(np.float32) 144 | controller = np.load(os.path.join(controller_dir, "controller.model"))['xmean'] 145 | W_c, b_c = transform_to_weights(args, controller) 146 | 147 | log(ID, "Starting") 148 | 149 | worker_arg_tuples = [] 150 | for rollout_num in range(args.rollouts): 151 | worker_arg_tuples.append((rollout_num, args, vision, model.copy(), W_c, b_c, output_dir)) 152 | pool = Pool(cores) 153 | cumulative_rewards = pool.map(worker, worker_arg_tuples) 154 | pool.close() 155 | 
pool.join() 156 | 157 | log(ID, "Cumulative Rewards:") 158 | for rollout_num in range(args.rollouts): 159 | log(ID, "> #{} = {:.2f}".format((rollout_num + 1), cumulative_rewards[rollout_num])) 160 | 161 | log(ID, "Mean: {:.2f} Std: {:.2f}".format(np.mean(cumulative_rewards), np.std(cumulative_rewards))) 162 | log(ID, "Highest: #{} = {:.2f} Lowest: #{} = {:.2f}" 163 | .format(np.argmax(cumulative_rewards) + 1, np.amax(cumulative_rewards), 164 | np.argmin(cumulative_rewards) + 1, np.amin(cumulative_rewards))) 165 | 166 | cumulative_rewards_file = os.path.join(output_dir, "cumulative_rewards.npy.gz") 167 | log(ID, "Saving cumulative rewards to: " + os.path.join(output_dir, "cumulative_rewards.npy.gz")) 168 | with gzip.GzipFile(cumulative_rewards_file, "w") as file: 169 | np.save(file, cumulative_rewards) 170 | 171 | # To load: 172 | # with gzip.GzipFile(cumulative_rewards_file, "r") as file: 173 | # cumulative_rewards = np.load(file) 174 | 175 | log(ID, "Done") 176 | 177 | 178 | if __name__ == '__main__': 179 | main() 180 | -------------------------------------------------------------------------------- /toy/cma-es.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import matplotlib.pyplot as plt 4 | 5 | plt.switch_backend('agg') 6 | from scipy.misc import imread 7 | 8 | import numpy as np 9 | import imageio 10 | 11 | solution = [10, 10] 12 | 13 | 14 | def schaffer(x, y): 15 | x -= solution[0] 16 | y -= solution[1] 17 | return 0.5 + ((np.sin((x ** 2) - (y ** 2)) ** 2) - 0.5) \ 18 | / \ 19 | ((1 + 0.001 * ((x ** 2) + (y ** 2))) ** 2) 20 | 21 | 22 | def f(w): 23 | return -(schaffer(w[0], w[1]) - schaffer(solution[0], solution[1])) ** 2 24 | 25 | 26 | def main(): 27 | plt.switch_backend('agg') 28 | 29 | N = 2 30 | xmean = np.random.randn(N) 31 | sigma = 0.3 32 | stopeval = 1e3 * N ** 2 33 | stopfitness = 1e-10 34 | 35 | λ = 64 # 4+int(3*np.log(N)) 36 | mu = λ // 4 37 | weights = np.log(mu + 1 / 2) - 
np.log(np.asarray(range(1, mu + 1))).astype(np.float32) 38 | weights = weights / np.sum(weights) 39 | mueff = (np.sum(weights) ** 2) / np.sum(weights ** 2) 40 | 41 | cc = (4 + mueff / N) / (N + 4 + 2 * mueff / N) 42 | cs = (mueff + 2) / (N + mueff + 5) 43 | c1 = 2 / ((N + 1.3) ** 2 + mueff) 44 | cmu = min(1 - c1, 2 * (mueff - 2 + 1 / mueff) / ((N + 2) ** 2 + mueff)) 45 | damps = 1 + 2 * max(0, ((mueff - 1) / (N + 1)) ** 0.5 - 1) + cs 46 | 47 | pc = np.zeros(N).astype(np.float32) 48 | ps = np.zeros(N).astype(np.float32) 49 | B = np.eye(N, N).astype(np.float32) 50 | D = np.ones(N).astype(np.float32) 51 | 52 | C = B.dot(np.diag(D ** 2)).dot(B.T) # matrix products, not element-wise 53 | invsqrtC = B.dot(np.diag(D ** -1)).dot(B.T) 54 | eigeneval = 0 55 | chiN = N ** 0.5 * (1 - 1 / (4 * N) + 1 / (21 * N ** 2)) 56 | 57 | counteval = 0 58 | generation = 0 59 | solution_found = False 60 | graphs = [] 61 | while counteval < stopeval: 62 | arx = np.zeros((λ, N)) 63 | arfitness = np.zeros(λ) 64 | for k in range(λ): 65 | arx[k] = xmean + sigma * B.dot(D * np.random.randn(N)) 66 | arfitness[k] = f(arx[k]) 67 | counteval += 1 68 | 69 | plt.ylim(-1, 20) 70 | plt.xlim(-1, 20) 71 | plt.plot(solution[0], solution[1], "b.") 72 | plt.plot(arx[:, 0], arx[:, 1], "r.") 73 | plt.plot(np.mean(arx[:, 0]), np.mean(arx[:, 1]), "g.") 74 | buf = io.BytesIO() 75 | plt.savefig(buf, format='png') 76 | plt.clf() 77 | buf.seek(0) 78 | img = imread(buf) 79 | buf.close() 80 | graphs.append(img) 81 | 82 | arindex = np.argsort(-arfitness) 83 | arfitness = arfitness[arindex] 84 | 85 | xold = xmean 86 | xmean = weights.dot(arx[arindex[0:mu]]) 87 | 88 | ps = (1 - cs) * ps + np.sqrt(cs * (2 - cs) * mueff) * invsqrtC.dot((xmean - xold) / sigma) 89 | hsig = np.linalg.norm(ps) / np.sqrt(1 - (1 - cs) ** (2 * counteval / λ)) / chiN < 1.4 + 2 / (N + 1) 90 | pc = (1 - cc) * pc + hsig * np.sqrt(cc * (2 - cc) * mueff) * ((xmean - xold) / sigma) 91 | artmp = (1 / sigma) * (arx[arindex[0:mu]] - xold) 92 | C = (1 - c1 - cmu) * C + c1 * (np.outer(pc, pc) + (1 - hsig) * cc
* (2 - cc) * C) + cmu * artmp.T.dot( 93 | np.diag(weights)).dot(artmp) 94 | sigma = sigma * np.exp((cs / damps) * (np.linalg.norm(ps) / chiN - 1)) 95 | 96 | if counteval - eigeneval > λ / (c1 + cmu) / N / 10: 97 | eigeneval = counteval 98 | C = np.triu(C) + np.triu(C, 1).T 99 | D, B = np.linalg.eig(C) 100 | D = np.sqrt(D) 101 | invsqrtC = B.dot(np.diag(D ** -1).dot(B.T)) 102 | 103 | generation += 1 104 | 105 | if arfitness[0] >= -stopfitness: 106 | solution_found = True 107 | break 108 | 109 | if solution_found: 110 | print("Solution found at generation #" + str(generation)) 111 | else: 112 | print("Solution not found") 113 | 114 | if not os.path.exists("result"): 115 | os.makedirs("result") 116 | imageio.mimsave('result/cma-es.gif', graphs) 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /toy/mdn.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | plt.switch_backend('agg') 4 | 5 | import chainer 6 | import chainer.functions as F 7 | import chainer.links as L 8 | from chainer import training, datasets, iterators, report 9 | from chainer.training import extensions 10 | 11 | import numpy as np 12 | 13 | 14 | class MDN(chainer.Chain): 15 | def __init__(self, hidden_dim, output_dim, k): 16 | self.output_dim = output_dim 17 | self.hidden_dim = hidden_dim 18 | self.k = k 19 | super(MDN, self).__init__( 20 | input_layer=L.Linear(None, hidden_dim), 21 | coef_layer=L.Linear(hidden_dim, k * output_dim), 22 | mu_layer=L.Linear(hidden_dim, k * output_dim), 23 | ln_var_layer=L.Linear(hidden_dim, k * output_dim), 24 | ) 25 | 26 | def __call__(self, input): 27 | coef, mu, ln_var = self.fprop(input) 28 | 29 | def sample(row_num): 30 | cum_prod = 0 31 | r = np.random.uniform() 32 | index = None 33 | for i, probability in enumerate(coef[row_num]): 34 | cum_prod += sum(probability) 35 | if r <= cum_prod.data: 36 
| index = i 37 | break 38 | 39 | return F.gaussian(mu[row_num][index], ln_var[row_num][index]) 40 | 41 | output = F.expand_dims(sample(0), 0) 42 | for row_num in range(1, input.shape[0]): 43 | this_output = F.expand_dims(sample(row_num), 0) 44 | output = F.concat((output, this_output), axis=0) 45 | 46 | return output 47 | 48 | def fprop(self, input): 49 | k = self.k 50 | output_dim = self.output_dim 51 | 52 | h = self.input_layer(input) 53 | 54 | coef = F.softmax(self.coef_layer(h)) 55 | mu = self.mu_layer(h) 56 | ln_var = self.ln_var_layer(h) 57 | 58 | mu = F.reshape(mu, (-1, k, output_dim)) 59 | coef = F.reshape(coef, (-1, k, output_dim)) 60 | ln_var = F.reshape(ln_var, (-1, k, output_dim)) 61 | 62 | return coef, mu, ln_var 63 | 64 | def get_loss_func(self): 65 | def lf(input, output, epsilon=1e-8): 66 | output_dim = self.output_dim 67 | 68 | coef, mu, ln_var = self.fprop(input) 69 | 70 | output = F.reshape(output, (-1, 1, output_dim)) 71 | mu, output = F.broadcast(mu, output) 72 | 73 | var = F.exp(ln_var) 74 | 75 | density = F.sum( 76 | coef * 77 | (1 / (np.sqrt(2 * np.pi) * F.sqrt(var))) * 78 | F.exp(-0.5 * F.square(output - mu) / var) 79 | , axis=1) 80 | 81 | nll = -F.sum(F.log(density)) 82 | report({'loss': nll}, self) 83 | return nll 84 | 85 | return lf 86 | 87 | 88 | class Linear(chainer.Chain): 89 | def __init__(self, hidden_dim, output_dim): 90 | self.output_dim = output_dim 91 | super(Linear, self).__init__( 92 | input_layer=L.Linear(None, hidden_dim), 93 | output_layer=L.Linear(hidden_dim, output_dim), 94 | ) 95 | 96 | def __call__(self, input): 97 | return self.fprop(input) 98 | 99 | def fprop(self, input): 100 | h = self.input_layer(input) 101 | return self.output_layer(h) 102 | 103 | def get_loss_func(self): 104 | def lf(input, output): 105 | pred = self.fprop(input) 106 | loss = F.mean_squared_error(output.reshape(-1, 1), pred) 107 | report({'loss': loss}, self) 108 | return loss 109 | 110 | return lf 111 | 112 | 113 | def main(): 114 | model = 
MDN(256, 1, 5) 115 | # model = Linear(256, 1) 116 | 117 | points = 500 118 | 119 | y = np.random.rand(points).astype(np.float32) 120 | x = np.sin(2 * np.pi * y) + 0.2 * np.random.rand(points) * (np.cos(2 * np.pi * y) + 2) 121 | x = x.astype(np.float32) 122 | 123 | optimizer = chainer.optimizers.Adam() 124 | optimizer.setup(model) 125 | dataset = datasets.tuple_dataset.TupleDataset(x.reshape(-1, 1), y) 126 | train_iter = iterators.SerialIterator(dataset, batch_size=100) 127 | updater = training.StandardUpdater(train_iter, optimizer, loss_func=model.get_loss_func()) 128 | trainer = training.Trainer(updater, (2000, 'epoch')) 129 | trainer.extend(extensions.LogReport()) 130 | trainer.extend(extensions.PrintReport(['epoch', 'main/loss'])) 131 | trainer.run() 132 | 133 | plt.ylim(-0.1, 1.1) 134 | plt.plot(x, y, "b.") 135 | plt.savefig("result/mdn-data_only.png") 136 | plt.clf() 137 | 138 | x_test = np.linspace(min(x), max(x), points).astype(np.float32) 139 | y_pred = model(x_test.reshape(-1, 1)).data 140 | 141 | plt.ylim(-0.1, 1.1) 142 | plt.plot(x, y, "b.") 143 | plt.plot(x_test, y_pred, "r.") 144 | plt.savefig("result/mdn-with_preds.png") 145 | 146 | 147 | if __name__ == '__main__': 148 | main() 149 | -------------------------------------------------------------------------------- /vision.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import gc 5 | 6 | import chainer.functions as F 7 | import chainer.links as L 8 | import chainer 9 | from chainer import training 10 | from chainer.training import extensions 11 | try: 12 | import cupy as cp 13 | except Exception as e: 14 | None 15 | 16 | import numpy as np 17 | 18 | from lib.data import VisionDataset 19 | from lib.utils import save_images_collage, mkdir, log, pre_process_image_tensor, post_process_image_tensor 20 | 21 | ID = "vision" 22 | 23 | 24 | class CVAE(chainer.Chain): 25 | def __init__(self, n_latent): 26 | self.n_latent = 
n_latent 27 | super(CVAE, self).__init__( 28 | e_c0=L.Convolution2D(None, 32, 4, 2), 29 | e_c1=L.Convolution2D(None, 64, 4, 2), 30 | e_c2=L.Convolution2D(None, 128, 4, 2), 31 | e_c3=L.Convolution2D(None, 256, 4, 2), 32 | e_mu=L.Linear(None, n_latent), 33 | e_ln_var=L.Linear(None, n_latent), 34 | 35 | d_l0=L.Linear(n_latent, 1024), 36 | d_dc0=L.Deconvolution2D(None, 128, 5, 2), 37 | d_dc1=L.Deconvolution2D(None, 64, 5, 2), 38 | d_dc2=L.Deconvolution2D(None, 32, 6, 2), 39 | d_dc3=L.Deconvolution2D(None, 3, 6, 2), 40 | ) 41 | 42 | def __call__(self, frames, pre_process=False): 43 | if len(frames.shape) == 3: 44 | frames = F.expand_dims(frames, 0) 45 | if pre_process: 46 | frames = pre_process_image_tensor(frames) 47 | frames_variational = self.decode(self.encode(frames, return_z=True)) 48 | if pre_process: 49 | frames_variational = post_process_image_tensor(frames_variational) 50 | return frames_variational 51 | 52 | def encode(self, frames, return_z=False): 53 | if len(frames.shape) == 3: 54 | frames = F.expand_dims(frames, 0) 55 | h = F.relu(self.e_c0(frames)) 56 | h = F.relu(self.e_c1(h)) 57 | h = F.relu(self.e_c2(h)) 58 | h = F.relu(self.e_c3(h)) 59 | h = F.reshape(h, (-1, 1024)) 60 | mu = self.e_mu(h) 61 | ln_var = self.e_ln_var(h) 62 | if return_z: 63 | return F.gaussian(mu, ln_var) 64 | else: 65 | return mu, ln_var 66 | 67 | def decode(self, z): 68 | if len(z.shape) == 1: 69 | z = F.expand_dims(z, 0) 70 | h = self.d_l0(z) 71 | h = F.reshape(h, (-1, 1024, 1, 1)) 72 | h = F.relu(self.d_dc0(h)) 73 | h = F.relu(self.d_dc1(h)) 74 | h = F.relu(self.d_dc2(h)) 75 | h = F.sigmoid(self.d_dc3(h)) 76 | return h 77 | 78 | def get_loss_func(self, kl_tolerance=0.5): 79 | self.kl_tolerance = kl_tolerance 80 | def lf(frames): 81 | mu, ln_var = self.encode(frames) 82 | z = F.gaussian(mu, ln_var) 83 | frames_flat = F.reshape(frames, (-1, frames.shape[1] * frames.shape[2] * frames.shape[3])) 84 | variational_flat = F.reshape(self.decode(z), (-1, frames.shape[1] * frames.shape[2] * 
frames.shape[3])) 85 | rec_loss = F.sum(F.square(frames_flat - variational_flat), axis=1) # l2 reconstruction loss 86 | rec_loss = F.mean(rec_loss) 87 | kl_loss = F.sum(F.gaussian_kl_divergence(mu, ln_var, reduce="no"), axis=1) 88 | xp = self.xp # numpy on CPU, cupy on GPU; avoids relying on the private Link._cpu attribute 89 | kl_tolerance = xp.asarray(self.kl_tolerance * self.n_latent, dtype=xp.float32) 90 | # Clamp each sample's KL term from below at kl_tolerance * n_latent: 91 | # KL below this budget contributes no gradient. 92 | kl_loss = F.maximum(kl_loss, F.broadcast_to(kl_tolerance, kl_loss.shape)) 93 | kl_loss = F.mean(kl_loss) 94 | loss = rec_loss + kl_loss 95 | chainer.report({'loss': loss}, observer=self) 96 | chainer.report({'kl_loss': kl_loss}, observer=self) 97 | chainer.report({'rec_loss': rec_loss}, observer=self) 98 | return loss 99 | return lf 100 | 101 | 102 | class Sampler(chainer.training.Extension): 103 | def __init__(self, model, args, output_dir, frames, z): 104 | self.model = model 105 | self.args = args 106 | self.output_dir = output_dir 107 | self.frames = frames 108 | self.z = z 109 | 110 | def __call__(self, trainer): 111 | if self.args.gpu >= 0: 112 | self.model.to_cpu() 113 | 114 | with chainer.using_config('train', False), chainer.no_backprop_mode(): 115 | frames_variational = self.model(self.frames) 116 | save_images_collage(frames_variational.data, 117 | os.path.join(self.output_dir, 118 | 'train_reconstructed_{}.png'.format(trainer.updater.iteration))) 119 | 120 | with chainer.using_config('train', False), chainer.no_backprop_mode(): 121 | frames_variational = self.model.decode(self.z) 122 | save_images_collage(frames_variational.data, 123 | os.path.join(self.output_dir, 'sampled_{}.png'.format(trainer.updater.iteration))) 124 | 125 | if self.args.gpu >= 0: 126 | self.model.to_gpu() 127 | 128 | 129 | def main(): 130 | parser = argparse.ArgumentParser(description='World Models ' + ID) 131 | parser.add_argument('--data_dir', '-d', default="/data/wm", help='The base data/output directory') 132 | parser.add_argument('--game', 
default='CarRacing-v0', 133 | help='Game to use') # https://gym.openai.com/envs/CarRacing-v0/ 134 | parser.add_argument('--experiment_name', default='experiment_1', help='Subdirectory name that keeps this experiment\'s files separate from others') 135 | parser.add_argument('--load_batch_size', default=10, type=int, 136 | help='Load game frames in batches so as not to run out of memory') 137 | parser.add_argument('--model', '-m', default='', 138 | help='Initialize the model from given file, or "default" for one in data folder') 139 | parser.add_argument('--no_resume', action='store_true', help="Don't auto-resume from the latest snapshot") 140 | parser.add_argument('--resume_from', '-r', default='', help='Resume the optimization from a specific snapshot') 141 | parser.add_argument('--test', action='store_true', help='Generate samples only') 142 | parser.add_argument('--gpu', '-g', default=-1, type=int, help='GPU ID (negative value indicates CPU)') 143 | parser.add_argument('--epoch', '-e', default=1, type=int, help='number of epochs to learn') 144 | parser.add_argument('--snapshot_interval', '-s', default=100, type=int, 145 | help='Snapshot and save samples every N iterations, i.e. every N * batch_size frames processed') 146 | parser.add_argument('--z_dim', '-z', default=32, type=int, help='dimension of encoded vector') 147 | parser.add_argument('--batch_size', '-b', type=int, default=100, help='learning minibatch size') 148 | parser.add_argument('--no_progress_bar', '-p', action='store_true', help='Disable the progress bar during training') 149 | parser.add_argument('--kl_tolerance', type=float, default=0.5, help='Per-dimension KL floor: each sample\'s KL loss is clamped below at kl_tolerance * z_dim') 150 | 151 | args = parser.parse_args() 152 | log(ID, "args =\n " + str(vars(args)).replace(",", ",\n ")) 153 | 154 | output_dir = os.path.join(args.data_dir, args.game, args.experiment_name, ID) 155 | random_rollouts_dir = os.path.join(args.data_dir, args.game, args.experiment_name, 'random_rollouts') 156 | mkdir(output_dir) 157 | 158 | max_iter = 0 159 | auto_resume_file = None 160 | files = os.listdir(output_dir) 161 | for file
in files: 162 | if re.match(r'^snapshot_iter_', file): 163 | iter = int(re.search(r'\d+', file).group()) 164 | if (iter > max_iter): 165 | max_iter = iter 166 | if max_iter > 0: 167 | auto_resume_file = os.path.join(output_dir, "snapshot_iter_{}".format(max_iter)) 168 | 169 | model = CVAE(args.z_dim) 170 | 171 | if args.model: 172 | if args.model == 'default': 173 | args.model = os.path.join(output_dir, ID + ".model") 174 | log(ID, "Loading saved model from: " + args.model) 175 | chainer.serializers.load_npz(args.model, model) 176 | 177 | optimizer = chainer.optimizers.Adam(alpha=0.0001) 178 | optimizer.setup(model) 179 | 180 | log(ID, "Loading training data") 181 | train = VisionDataset(dir=random_rollouts_dir, load_batch_size=args.load_batch_size, shuffle=True, verbose=True) 182 | train_iter = chainer.iterators.SerialIterator(train, args.batch_size, shuffle=False) 183 | 184 | updater = training.StandardUpdater( 185 | train_iter, optimizer, 186 | device=args.gpu, loss_func=model.get_loss_func(args.kl_tolerance)) 187 | 188 | trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=output_dir) 189 | trainer.extend(extensions.snapshot(), trigger=(args.snapshot_interval, 'iteration')) 190 | trainer.extend(extensions.LogReport(trigger=(100 if args.gpu >= 0 else 10, 'iteration'))) 191 | trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'main/loss', 'main/kl_loss', 'main/rec_loss', 'elapsed_time'])) 192 | if not args.no_progress_bar: 193 | trainer.extend(extensions.ProgressBar(update_interval=100 if args.gpu >= 0 else 10)) 194 | 195 | sample_idx = np.random.choice(range(train.get_current_batch_size()), 64, replace=False) 196 | sample_frames = chainer.Variable(np.asarray(train[sample_idx])) 197 | np.random.seed(31337) 198 | sample_z = chainer.Variable(np.random.normal(0, 1, (64, args.z_dim)).astype(np.float32)) 199 | save_images_collage(sample_frames.data, os.path.join(output_dir, 'train.png')) 200 | sampler = Sampler(model, args, output_dir, 
sample_frames, sample_z) 201 | trainer.extend(sampler, trigger=(args.snapshot_interval, 'iteration')) 202 | 203 | if args.resume_from: 204 | log(ID, "Resuming trainer manually from snapshot: " + args.resume_from) 205 | chainer.serializers.load_npz(args.resume_from, trainer) 206 | elif not args.no_resume and auto_resume_file is not None: 207 | log(ID, "Auto resuming trainer from last snapshot: " + auto_resume_file) 208 | chainer.serializers.load_npz(auto_resume_file, trainer) 209 | 210 | if not args.test: 211 | log(ID, "Starting training") 212 | trainer.run() 213 | log(ID, "Done training") 214 | log(ID, "Saving model") 215 | chainer.serializers.save_npz(os.path.join(output_dir, ID + ".model"), model) 216 | 217 | if args.test: 218 | log(ID, "Saving test samples") 219 | sampler(trainer) 220 | 221 | if not args.test: 222 | log(ID, "Saving latent z's for all training data") 223 | train = VisionDataset(dir=random_rollouts_dir, load_batch_size=args.load_batch_size, shuffle=False, 224 | verbose=True) 225 | total_batches = train.get_total_batches() 226 | for batch in range(total_batches): 227 | gc.collect() 228 | train.load_batch(batch) 229 | batch_frames, batch_rollouts, batch_rollouts_counts = train.get_current_batch() 230 | mu = None 231 | ln_var = None 232 | splits = batch_frames.shape[0] // args.batch_size 233 | if batch_frames.shape[0] % args.batch_size != 0: 234 | splits += 1 235 | for i in range(splits): 236 | start_idx = i * args.batch_size 237 | end_idx = (i + 1) * args.batch_size 238 | sample_frames = batch_frames[start_idx:end_idx] 239 | if args.gpu >= 0: 240 | sample_frames = chainer.Variable(cp.asarray(sample_frames)) 241 | else: 242 | sample_frames = chainer.Variable(sample_frames) 243 | this_mu, this_ln_var = model.encode(sample_frames) 244 | this_mu = this_mu.data 245 | this_ln_var = this_ln_var.data 246 | if args.gpu >= 0: 247 | this_mu = cp.asnumpy(this_mu) 248 | this_ln_var = cp.asnumpy(this_ln_var) 249 | if mu is None: 250 | mu = this_mu 251 | ln_var = 
this_ln_var 252 | else: 253 | mu = np.concatenate((mu, this_mu), axis=0) 254 | ln_var = np.concatenate((ln_var, this_ln_var), axis=0) 255 | running_count = 0 256 | for rollout in batch_rollouts: 257 | rollout_dir = os.path.join(random_rollouts_dir, rollout) 258 | rollout_count = batch_rollouts_counts[rollout] 259 | start_idx = running_count 260 | end_idx = running_count + rollout_count 261 | this_mu = mu[start_idx:end_idx] 262 | this_ln_var = ln_var[start_idx:end_idx] 263 | np.savez_compressed(os.path.join(rollout_dir, "mu+ln_var.npz"), mu=this_mu, ln_var=this_ln_var) 264 | running_count = running_count + rollout_count 265 | log(ID, "> Processed z's for rollouts " + str(batch_rollouts)) 266 | # Free up memory: 267 | batch_frames = None 268 | mu = None 269 | ln_var = None 270 | 271 | log(ID, "Done") 272 | 273 | 274 | if __name__ == '__main__': 275 | main() 276 | --------------------------------------------------------------------------------
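A quick way to see why `CVAE.encode` reshapes the last convolution output to 1024 features, and why `CVAE.decode` reshapes `d_l0`'s output to `(-1, 1024, 1, 1)`, is to trace the spatial sizes through the layers. This is a sketch that assumes 64×64 input frames (the size the decoder produces, which the reconstruction loss requires the input to match) and the default zero padding used by `vision.py`. The helper names `conv_out`, `deconv_out`, `encoder_sizes` and `decoder_sizes` are hypothetical; the formulas are the standard no-padding output-size rules for a strided convolution and a transposed convolution.

```python
def conv_out(size, kernel, stride, pad=0):
    # Output size of a strided convolution without cover_all.
    return (size + 2 * pad - kernel) // stride + 1


def deconv_out(size, kernel, stride, pad=0):
    # Output size of a transposed convolution when no explicit outsize is given.
    return stride * (size - 1) + kernel - 2 * pad


def encoder_sizes(size=64):
    # e_c0..e_c3 all use kernel 4, stride 2.
    sizes = [size]
    for _ in range(4):
        sizes.append(conv_out(sizes[-1], 4, 2))
    return sizes


def decoder_sizes():
    # d_l0 output is reshaped to (N, 1024, 1, 1); d_dc0..d_dc3 use kernels 5, 5, 6, 6, stride 2.
    sizes = [1]
    for kernel in (5, 5, 6, 6):
        sizes.append(deconv_out(sizes[-1], kernel, 2))
    return sizes


enc = encoder_sizes(64)
print(enc)                  # [64, 31, 14, 6, 2]
print(256 * enc[-1] ** 2)   # 1024 features (256 channels x 2 x 2) fed to e_mu / e_ln_var
print(decoder_sizes())      # [1, 5, 13, 30, 64]
```

Under the same assumptions, a 96×96 frame (the native CarRacing-v0 observation size) would end the encoder at 4×4, giving 4096 features rather than 1024, so frames would need to be resized to 64×64 before they reach `CVAE`.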