├── .gitignore ├── LICENSE ├── README.md ├── gailtf ├── __init__.py ├── algo │ ├── __init__.py │ ├── behavior_clone.py │ └── trpo_mpi.py ├── baselines │ ├── __init__.py │ ├── bench │ │ ├── __init__.py │ │ ├── benchmarks.py │ │ └── monitor.py │ ├── common │ │ ├── __init__.py │ │ ├── atari_wrappers.py │ │ ├── atari_wrappers_deprecated.py │ │ ├── azure_utils.py │ │ ├── cg.py │ │ ├── console_util.py │ │ ├── dataset.py │ │ ├── distributions.py │ │ ├── math_util.py │ │ ├── misc_util.py │ │ ├── mpi_adam.py │ │ ├── mpi_fork.py │ │ ├── mpi_moments.py │ │ ├── mpi_running_mean_std.py │ │ ├── schedules.py │ │ ├── segment_tree.py │ │ ├── tests │ │ │ ├── test_schedules.py │ │ │ ├── test_segment_tree.py │ │ │ └── test_tf_util.py │ │ ├── tf_util.py │ │ └── vec_env │ │ │ ├── __init__.py │ │ │ └── subproc_vec_env.py │ ├── logger.py │ ├── ppo1 │ │ ├── README.md │ │ ├── __init__.py │ │ ├── cnn_policy.py │ │ ├── mlp_policy.py │ │ ├── pposgd_simple.py │ │ └── run_mujoco.py │ └── trpo_mpi │ │ ├── README.md │ │ ├── __init__.py │ │ ├── nosharing_cnn_policy.py │ │ ├── run_mujoco.py │ │ └── trpo_mpi.py ├── common │ ├── __init__.py │ ├── statistics.py │ └── tf_util.py ├── dataset │ ├── __init__.py │ └── mujoco.py └── network │ └── adversary.py ├── main.py └── misc ├── HalfCheetah-D.png ├── HalfCheetah-length-reward(D).png ├── HalfCheetah-true-reward.png ├── Hopper-D.png ├── Hopper-length-reward(D).png ├── Hopper-true-reward.png ├── Walker2d-D.png ├── Walker2d-length-reward(D).png ├── Walker2d-true-reward.png ├── exp.md ├── halfcheetah.png ├── hopper.png └── walker2d.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Platform-specific files and text editors 2 | .DS_Store 3 | *.swp 4 | 5 | # Mujoco 6 | MUJOCO_LOG.TXT 7 | 8 | # TensorFlow checkpoints and logs 9 | log/ 10 | checkpoint/ 11 | 12 | # Pickle files 13 | *.pkl 14 | 15 | 16 | ####### https://github.com/github/gitignore/blob/master/Python.gitignore ####### 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | .static_storage/ 72 | .media/ 73 | local_settings.py 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 Andrew Liao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Check out the simpler version at [openai/baselines/gail](https://github.com/openai/baselines/blob/master/baselines/gail/README.md)! 2 | 3 | 4 | 5 | 6 | 7 | 8 | # gail-tf 9 | Tensorflow implementation of Generative Adversarial Imitation Learning (and 10 | behavior cloning) 11 | 12 | **disclaimers**: some code is borrowed from @openai/baselines 13 | 14 | ## What's GAIL? 
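In short: GAIL trains a discriminator to distinguish expert state-action pairs from the policy's own rollouts, and feeds the discriminator's score back to the policy as a surrogate reward for a policy-gradient update (TRPO in this repo), so no hand-designed reward is needed. The toy sketch below is numpy-only and purely illustrative, not this repo's API; the real loop lives in `gailtf/algo/trpo_mpi.py` and `gailtf/network/adversary.py`:

```python
# Toy illustration of the GAIL loop (numpy only; not this repo's code).
# "States/actions" are just 2-D points; the expert's come from a fixed Gaussian,
# the policy is a unit-variance Gaussian whose mean we learn with REINFORCE,
# and the discriminator is a logistic regression trained to tell the two apart.
import numpy as np

rng = np.random.default_rng(0)
expert_mean = np.array([2.0, -1.0])
theta = np.zeros(2)            # policy mean (the only policy parameter here)
w, b = np.zeros(2), 0.0        # discriminator parameters

def discriminator(x):          # estimated P(x came from the expert)
    return 1.0 / (1.0 + np.exp(-(x @ w + b)))

for it in range(2000):
    expert_x = rng.normal(expert_mean, 1.0, size=(64, 2))
    policy_x = rng.normal(theta, 1.0, size=(64, 2))

    # Discriminator step: logistic-regression gradient ascent (expert=1, policy=0)
    x = np.vstack([expert_x, policy_x])
    y = np.concatenate([np.ones(64), np.zeros(64)])
    err = y - discriminator(x)                  # gradient factor of the log-likelihood
    w += 0.05 * (err @ x) / len(x)
    b += 0.05 * err.mean()

    # Policy step: REINFORCE on the surrogate reward r(x) = -log(1 - D(x)),
    # which is large wherever the discriminator thinks samples look "expert".
    r = -np.log(1.0 - discriminator(policy_x) + 1e-8)
    r = r - r.mean()                            # mean baseline for variance reduction
    # grad of log N(x | theta, I) w.r.t. theta is (x - theta)
    theta += 0.05 * ((policy_x - theta) * r[:, None]).mean(axis=0)

print("expert mean:", expert_mean, "learned policy mean:", theta.round(2))
```

A few notes and related work: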
15 | - model-free imitation learning -> low sample efficiency in training time 16 | - model-based GAIL: End-to-End Differentiable Adversarial Imitation Learning 17 | - Directly extract policy from demonstrations 18 | - Remove the RL optimization from the inner loop of inverse RL 19 | - Some work based on GAIL: 20 | - Inferring The Latent Structure of Human Decision-Making from Raw Visual 21 | Inputs 22 | - Multi-Modal Imitation Learning from Unstructured Demonstrations using 23 | Generative Adversarial Nets 24 | - Robust Imitation of Diverse Behaviors 25 | 26 | ## Requirements 27 | - python==3.5.2 28 | - mujoco-py==0.5.7 29 | - tensorflow==1.1.0 30 | - gym==0.9.3 31 | 32 | ## Run the code 33 | I separate the code into two parts: (1) sampling expert data, and (2) imitation 34 | learning with GAIL/BC. 35 | 36 | ### Step 1: Generate expert data 37 | 38 | #### Train the expert policy using PPO/TRPO, from openai/baselines 39 | Ensure that `$GAILTF` is set to the path to your gail-tf repository, and 40 | `$ENV_ID` is any valid OpenAI gym environment (e.g. Hopper-v1, HalfCheetah-v1, 41 | etc.) 42 | 43 | ##### Configuration 44 | ``` bash 45 | export GAILTF=/path/to/your/gail-tf 46 | export ENV_ID="Hopper-v1" 47 | export BASELINES_PATH=$GAILTF/gailtf/baselines/ppo1 # use gailtf/baselines/trpo_mpi for TRPO 48 | export SAMPLE_STOCHASTIC="False" # use True for stochastic sampling 49 | export STOCHASTIC_POLICY="False" # use True for a stochastic policy 50 | export PYTHONPATH=$GAILTF:$PYTHONPATH # as mentioned below 51 | cd $GAILTF 52 | ``` 53 | 54 | ##### Train the expert 55 | ```bash 56 | python3 $BASELINES_PATH/run_mujoco.py --env_id $ENV_ID 57 | ``` 58 | 59 | The trained model will be saved in ```./checkpoint```, and its precise name will 60 | vary based on your optimization method and environment ID. Choose the last 61 | checkpoint in the series. 62 | 63 | ```bash 64 | export PATH_TO_CKPT=./checkpoint/trpo.Hopper.0.00/trpo.Hopper.00-900 65 | ``` 66 | 67 | ##### Sample from the generated expert policy 68 | ```bash 69 | python3 $BASELINES_PATH/run_mujoco.py --env_id $ENV_ID --task sample_trajectory --sample_stochastic $SAMPLE_STOCHASTIC --load_model_path $PATH_TO_CKPT 70 | ``` 71 | 72 | This will generate a pickle file that stores the expert trajectories in 73 | ```./XXX.pkl``` (e.g. deterministic.ppo.Hopper.0.00.pkl)
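To sanity-check what was written, you can unpickle the file and print its structure. The exact layout is defined by the sampling script and `gailtf/dataset/mujoco.py`, so the snippet below is only a generic, format-agnostic inspection sketch (the file name is the example above):

```python
# Generic inspection of the sampled-expert pickle; the exact keys/structure
# depend on the sampling code, so we only print whatever we find.
import pickle

with open("deterministic.ppo.Hopper.0.00.pkl", "rb") as f:   # adjust to your file name
    data = pickle.load(f)

print(type(data))
if isinstance(data, dict):
    for k, v in data.items():
        print(k, getattr(v, "shape", type(v)))
elif isinstance(data, (list, tuple)):
    print("length:", len(data), "first element:", type(data[0]))
```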
74 | 75 | ```bash 76 | export PICKLE_PATH=./stochastic.trpo.Hopper.0.00.pkl 77 | ``` 78 | 79 | ### Step 2: Imitation learning 80 | 81 | #### Imitation learning via GAIL 82 | 83 | ```bash 84 | python3 main.py --env_id $ENV_ID --expert_path $PICKLE_PATH 85 | ``` 86 | 87 | Usage: 88 | ```bash 89 | --env_id: The environment id 90 | --num_cpu: Number of CPUs available during sampling 91 | --expert_path: The path to the pickle file generated in the [previous section]() 92 | --traj_limitation: Limit on the number of expert trajectories used 93 | --g_step: Number of policy optimization steps in each iteration 94 | --d_step: Number of discriminator optimization steps in each iteration 95 | --num_timesteps: Number of timesteps to train (limits the number of timesteps spent interacting with the environment) 96 | ``` 97 | 98 | To view the summary plots in TensorBoard, issue 99 | ```bash 100 | tensorboard --logdir $GAILTF/log 101 | ``` 102 | 103 | ##### Evaluate your GAIL agent 104 | ```bash 105 | python3 main.py --env_id $ENV_ID --task evaluate --stochastic_policy $STOCHASTIC_POLICY --load_model_path $PATH_TO_CKPT --expert_path $PICKLE_PATH 106 | ``` 107 | 108 | #### Imitation learning via Behavioral Cloning 109 | ```bash 110 | python3 main.py --env_id $ENV_ID --algo bc --expert_path $PICKLE_PATH 111 | ``` 112 | 113 | ##### Evaluate your BC agent 114 | ```bash 115 | python3 main.py --env_id $ENV_ID --algo bc --task evaluate --stochastic_policy $STOCHASTIC_POLICY --load_model_path $PATH_TO_CKPT --expert_path $PICKLE_PATH 116 | ``` 117 | 118 | ## Results 119 | 120 | Note: The following hyper-parameter settings are the best that I've tested (a simple 121 | grid search on the setting with 1500 trajectories), just for your information. 122 | 123 | The different curves below correspond to different expert dataset sizes (1000, 100, 10, 5). 124 | 125 | - Hopper-v1 (Average total return of expert policy: 3589) 126 | 127 | ```bash 128 | python3 main.py --env_id Hopper-v1 --expert_path baselines/ppo1/deterministic.ppo.Hopper.0.00.pkl --g_step 3 --adversary_entcoeff 0 129 | ``` 130 | 131 | ![](misc/Hopper-true-reward.png) 132 | 133 | - Walker2d-v1 (Average total return of expert policy: 4392) 134 | 135 | ```bash 136 | python3 main.py --env_id Walker2d-v1 --expert_path baselines/ppo1/deterministic.ppo.Walker2d.0.00.pkl --g_step 3 --adversary_entcoeff 1e-3 137 | ``` 138 | 139 | ![](misc/Walker2d-true-reward.png) 140 | 141 | - HalfCheetah-v1 (Average total return of expert policy: 2110) 142 | 143 | For HalfCheetah-v1 and Ant-v1, pretraining with behavior cloning is needed: 144 | ```bash 145 | python3 main.py --env_id HalfCheetah-v1 --expert_path baselines/ppo1/deterministic.ppo.HalfCheetah.0.00.pkl --pretrained True --BC_max_iter 10000 --g_step 3 --adversary_entcoeff 1e-3 146 | ``` 147 | 148 | ![](misc/HalfCheetah-true-reward.png) 149 | 150 | **You can find more details [here](https://github.com/andrewliao11/gail-tf/blob/master/misc/exp.md), 151 | GAIL policy [here](https://drive.google.com/drive/folders/0B3fKFm-j0RqeRnZMTUJHSmdIdlU?usp=sharing), 152 | and BC policy [here](https://drive.google.com/drive/folders/0B3fKFm-j0RqeVFFmMWpHMk85cUk?usp=sharing)** 153 | 154 | ## Hacking 155 | We don't have a pip package yet, so you'll need to add this repo to your 156 | PYTHONPATH manually, either from inside a Python script (see the sketch just below) or with the shell export that follows.
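A minimal in-script alternative to the shell export below (the path is a placeholder, just like in the export):

```python
# Make the gailtf package importable without touching the shell environment.
import sys
sys.path.insert(0, "/path/to/your/repo/with/gailtf")  # the directory that contains gailtf/

import gailtf  # should now resolve
```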
157 | ```bash 158 | export PYTHONPATH=/path/to/your/repo/with/gailtf:$PYTHONPATH 159 | ``` 160 | 161 | ## TODO 162 | * Create pip package/setup.py 163 | * Make style PEP8 compliant 164 | * Create requirements.txt 165 | * Depend on openai/baselines directly and modularize modifications 166 | * openai/roboschool support 167 | 168 | ## Troubleshooting 169 | 170 | - If you encounter `error: Cannot compile MPI programs. Check your configuration!!!` or the system complains about a missing `mpi.h`: 171 | ```bash 172 | sudo apt install libopenmpi-dev 173 | ``` 174 | 175 | ## Reference 176 | - Jonathan Ho and Stefano Ermon. Generative adversarial imitation learning, [[arxiv](https://arxiv.org/abs/1606.03476)] 177 | - @openai/imitation 178 | - @openai/baselines 179 | -------------------------------------------------------------------------------- /gailtf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewliao11/gail-tf/ad92f41c26c34e8fabc536664fb11b44f25956cf/gailtf/__init__.py -------------------------------------------------------------------------------- /gailtf/algo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewliao11/gail-tf/ad92f41c26c34e8fabc536664fb11b44f25956cf/gailtf/algo/__init__.py -------------------------------------------------------------------------------- /gailtf/algo/behavior_clone.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import gailtf.baselines.common.tf_util as U 3 | from gailtf.baselines import logger 4 | from tqdm import tqdm 5 | from gailtf.baselines.common.mpi_adam import MpiAdam 6 | import tempfile, os 7 | from common.statistics import stats 8 | import ipdb 9 | 10 | def evaluate(env, policy_func, load_model_path, stochastic_policy=False, number_trajs=10): 11 | from algo.trpo_mpi import traj_episode_generator 12 | ob_space = env.observation_space 13 | ac_space = env.action_space 14 | pi = policy_func("pi", ob_space, ac_space) # Construct network for new policy 15 | # placeholder 16 | ob = U.get_placeholder_cached(name="ob") 17 | ac = pi.pdtype.sample_placeholder([None]) 18 | stochastic = U.get_placeholder_cached(name="stochastic") 19 | ep_gen = traj_episode_generator(pi, env, 1024, stochastic=stochastic_policy) 20 | U.load_state(load_model_path) 21 | len_list = [] 22 | ret_list = [] 23 | for _ in tqdm(range(number_trajs)): 24 | traj = ep_gen.__next__() 25 | ep_len, ep_ret = traj['ep_len'], traj['ep_ret'] 26 | len_list.append(ep_len) 27 | ret_list.append(ep_ret) 28 | if stochastic_policy: 29 | print ('stochastic policy:') 30 | else: 31 | print ('deterministic policy:' ) 32 | print ("Average length:", sum(len_list)/len(len_list)) 33 | print ("Average return:", sum(ret_list)/len(ret_list)) 34 | 35 | def learn(env, policy_func, dataset, pretrained, optim_batch_size=128, max_iters=1e4, 36 | adam_epsilon=1e-5, optim_stepsize=3e-4, ckpt_dir=None, log_dir=None, task_name=None): 37 | val_per_iter = int(max_iters/10) 38 | ob_space = env.observation_space 39 | ac_space = env.action_space 40 | pi = policy_func("pi", ob_space, ac_space) # Construct network for new policy 41 | # placeholder 42 | ob = U.get_placeholder_cached(name="ob") 43 | ac = pi.pdtype.sample_placeholder([None]) 44 | stochastic = U.get_placeholder_cached(name="stochastic") 45 | loss = tf.reduce_mean(tf.square(ac-pi.ac)) 46 | var_list = pi.get_trainable_variables() 47 | adam = MpiAdam(var_list,
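                     # MpiAdam (from baselines) behaves like plain Adam but all-reduces and averages
                     # the gradient across MPI workers before each update, so every rank takes the same step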
epsilon=adam_epsilon) 48 | lossandgrad = U.function([ob, ac, stochastic], [loss]+[U.flatgrad(loss, var_list)]) 49 | 50 | if not pretrained: 51 | writer = U.FileWriter(log_dir) 52 | ep_stats = stats(["Loss"]) 53 | U.initialize() 54 | adam.sync() 55 | logger.log("Pretraining with Behavior Cloning...") 56 | for iter_so_far in tqdm(range(int(max_iters))): 57 | ob_expert, ac_expert = dataset.get_next_batch(optim_batch_size, 'train') 58 | loss, g = lossandgrad(ob_expert, ac_expert, True) 59 | adam.update(g, optim_stepsize) 60 | if not pretrained: 61 | ep_stats.add_all_summary(writer, [loss], iter_so_far) 62 | if iter_so_far % val_per_iter == 0: 63 | ob_expert, ac_expert = dataset.get_next_batch(-1, 'val') 64 | loss, g = lossandgrad(ob_expert, ac_expert, False) 65 | logger.log("Validation:") 66 | logger.log("Loss: %f"%loss) 67 | if not pretrained: 68 | U.save_state(os.path.join(ckpt_dir, task_name), counter=iter_so_far) 69 | if pretrained: 70 | savedir_fname = tempfile.TemporaryDirectory().name 71 | U.save_state(savedir_fname, var_list=pi.get_variables()) 72 | return savedir_fname 73 | -------------------------------------------------------------------------------- /gailtf/algo/trpo_mpi.py: -------------------------------------------------------------------------------- 1 | from gailtf.baselines.common import explained_variance, zipsame, dataset, Dataset, fmt_row 2 | from gailtf.baselines import logger 3 | import gailtf.baselines.common.tf_util as U 4 | import tensorflow as tf, numpy as np 5 | import time, os 6 | from gailtf.baselines.common import colorize 7 | from mpi4py import MPI 8 | from collections import deque 9 | from gailtf.baselines.common.mpi_adam import MpiAdam 10 | from gailtf.baselines.common.cg import cg 11 | from contextlib import contextmanager 12 | from gailtf.common.statistics import stats 13 | import ipdb 14 | 15 | def traj_segment_generator(pi, env, discriminator, horizon, stochastic): 16 | # Initialize state variables 17 | t = 0 18 | ac = env.action_space.sample() 19 | new = True 20 | rew = 0.0 21 | true_rew = 0.0 22 | ob = env.reset() 23 | 24 | cur_ep_ret = 0 25 | cur_ep_len = 0 26 | cur_ep_true_ret = 0 27 | ep_true_rets = [] 28 | ep_rets = [] 29 | ep_lens = [] 30 | 31 | # Initialize history arrays 32 | obs = np.array([ob for _ in range(horizon)]) 33 | true_rews = np.zeros(horizon, 'float32') 34 | rews = np.zeros(horizon, 'float32') 35 | vpreds = np.zeros(horizon, 'float32') 36 | news = np.zeros(horizon, 'int32') 37 | acs = np.array([ac for _ in range(horizon)]) 38 | prevacs = acs.copy() 39 | 40 | while True: 41 | prevac = ac 42 | ac, vpred = pi.act(stochastic, ob) 43 | # Slight weirdness here because we need value function at time T 44 | # before returning segment [0, T-1] so we get the correct 45 | # terminal value 46 | if t > 0 and t % horizon == 0: 47 | yield {"ob" : obs, "rew" : rews, "vpred" : vpreds, "new" : news, 48 | "ac" : acs, "prevac" : prevacs, "nextvpred": vpred * (1 - new), 49 | "ep_rets" : ep_rets, "ep_lens" : ep_lens, "ep_true_rets": ep_true_rets} 50 | _, vpred = pi.act(stochastic, ob) 51 | # Be careful!!! 
if you change the downstream algorithm to aggregate 52 | # several of these batches, then be sure to do a deepcopy 53 | ep_rets = [] 54 | ep_true_rets = [] 55 | ep_lens = [] 56 | i = t % horizon 57 | obs[i] = ob 58 | vpreds[i] = vpred 59 | news[i] = new 60 | acs[i] = ac 61 | prevacs[i] = prevac 62 | 63 | rew = discriminator.get_reward(ob, ac) 64 | ob, true_rew, new, _ = env.step(ac) 65 | rews[i] = rew 66 | true_rews[i] = true_rew 67 | 68 | cur_ep_ret += rew 69 | cur_ep_true_ret += true_rew 70 | cur_ep_len += 1 71 | if new: 72 | ep_rets.append(cur_ep_ret) 73 | ep_true_rets.append(cur_ep_true_ret) 74 | ep_lens.append(cur_ep_len) 75 | cur_ep_ret = 0 76 | cur_ep_true_ret = 0 77 | cur_ep_len = 0 78 | ob = env.reset() 79 | t += 1 80 | 81 | def add_vtarg_and_adv(seg, gamma, lam): 82 | new = np.append(seg["new"], 0) # last element is only used for last vtarg, but we already zeroed it if last new = 1 83 | vpred = np.append(seg["vpred"], seg["nextvpred"]) 84 | T = len(seg["rew"]) 85 | seg["adv"] = gaelam = np.empty(T, 'float32') 86 | rew = seg["rew"] 87 | lastgaelam = 0 88 | for t in reversed(range(T)): 89 | nonterminal = 1-new[t+1] 90 | delta = rew[t] + gamma * vpred[t+1] * nonterminal - vpred[t] 91 | gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam 92 | seg["tdlamret"] = seg["adv"] + seg["vpred"] 93 | 94 | def learn(env, policy_func, discriminator, expert_dataset, 95 | pretrained, pretrained_weight, *, 96 | g_step, d_step, 97 | timesteps_per_batch, # what to train on 98 | max_kl, cg_iters, 99 | gamma, lam, # advantage estimation 100 | entcoeff=0.0, 101 | cg_damping=1e-2, 102 | vf_stepsize=3e-4, d_stepsize=3e-4, 103 | vf_iters =3, 104 | max_timesteps=0, max_episodes=0, max_iters=0, # time constraint 105 | callback=None, 106 | save_per_iter=100, ckpt_dir=None, log_dir=None, 107 | load_model_path=None, task_name=None 108 | ): 109 | nworkers = MPI.COMM_WORLD.Get_size() 110 | rank = MPI.COMM_WORLD.Get_rank() 111 | np.set_printoptions(precision=3) 112 | # Setup losses and stuff 113 | # ---------------------------------------- 114 | ob_space = env.observation_space 115 | ac_space = env.action_space 116 | pi = policy_func("pi", ob_space, ac_space, reuse=(pretrained_weight!=None)) 117 | oldpi = policy_func("oldpi", ob_space, ac_space) 118 | atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable) 119 | ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return 120 | 121 | ob = U.get_placeholder_cached(name="ob") 122 | ac = pi.pdtype.sample_placeholder([None]) 123 | 124 | kloldnew = oldpi.pd.kl(pi.pd) 125 | ent = pi.pd.entropy() 126 | meankl = U.mean(kloldnew) 127 | meanent = U.mean(ent) 128 | entbonus = entcoeff * meanent 129 | 130 | vferr = U.mean(tf.square(pi.vpred - ret)) 131 | 132 | ratio = tf.exp(pi.pd.logp(ac) - oldpi.pd.logp(ac)) # advantage * pnew / pold 133 | surrgain = U.mean(ratio * atarg) 134 | 135 | optimgain = surrgain + entbonus 136 | losses = [optimgain, meankl, entbonus, surrgain, meanent] 137 | loss_names = ["optimgain", "meankl", "entloss", "surrgain", "entropy"] 138 | 139 | dist = meankl 140 | 141 | all_var_list = pi.get_trainable_variables() 142 | var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")] 143 | vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")] 144 | d_adam = MpiAdam(discriminator.get_trainable_variables()) 145 | vfadam = MpiAdam(vf_var_list) 146 | 147 | get_flat = U.GetFlat(var_list) 148 | set_from_flat = U.SetFromFlat(var_list) 149 | klgrads 
= tf.gradients(dist, var_list) 150 | flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan") 151 | shapes = [var.get_shape().as_list() for var in var_list] 152 | start = 0 153 | tangents = [] 154 | for shape in shapes: 155 | sz = U.intprod(shape) 156 | tangents.append(tf.reshape(flat_tangent[start:start+sz], shape)) 157 | start += sz 158 | gvp = tf.add_n([U.sum(g*tangent) for (g, tangent) in zipsame(klgrads, tangents)]) #pylint: disable=E1111 159 | fvp = U.flatgrad(gvp, var_list) 160 | 161 | assign_old_eq_new = U.function([],[], updates=[tf.assign(oldv, newv) 162 | for (oldv, newv) in zipsame(oldpi.get_variables(), pi.get_variables())]) 163 | compute_losses = U.function([ob, ac, atarg], losses) 164 | compute_lossandgrad = U.function([ob, ac, atarg], losses + [U.flatgrad(optimgain, var_list)]) 165 | compute_fvp = U.function([flat_tangent, ob, ac, atarg], fvp) 166 | compute_vflossandgrad = U.function([ob, ret], U.flatgrad(vferr, vf_var_list)) 167 | 168 | @contextmanager 169 | def timed(msg): 170 | if rank == 0: 171 | print(colorize(msg, color='magenta')) 172 | tstart = time.time() 173 | yield 174 | print(colorize("done in %.3f seconds"%(time.time() - tstart), color='magenta')) 175 | else: 176 | yield 177 | 178 | def allmean(x): 179 | assert isinstance(x, np.ndarray) 180 | out = np.empty_like(x) 181 | MPI.COMM_WORLD.Allreduce(x, out, op=MPI.SUM) 182 | out /= nworkers 183 | return out 184 | 185 | writer = U.FileWriter(log_dir) 186 | U.initialize() 187 | th_init = get_flat() 188 | MPI.COMM_WORLD.Bcast(th_init, root=0) 189 | set_from_flat(th_init) 190 | d_adam.sync() 191 | vfadam.sync() 192 | print("Init param sum", th_init.sum(), flush=True) 193 | 194 | # Prepare for rollouts 195 | # ---------------------------------------- 196 | seg_gen = traj_segment_generator(pi, env, discriminator, timesteps_per_batch, stochastic=True) 197 | 198 | episodes_so_far = 0 199 | timesteps_so_far = 0 200 | iters_so_far = 0 201 | tstart = time.time() 202 | lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths 203 | rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards 204 | true_rewbuffer = deque(maxlen=40) 205 | 206 | assert sum([max_iters>0, max_timesteps>0, max_episodes>0])==1 207 | 208 | g_loss_stats = stats(loss_names) 209 | d_loss_stats = stats(discriminator.loss_name) 210 | ep_stats = stats(["True_rewards", "Rewards", "Episode_length"]) 211 | # if provide pretrained weight 212 | if pretrained_weight is not None: 213 | U.load_state(pretrained_weight, var_list=pi.get_variables()) 214 | # if provieded model path 215 | if load_model_path is not None: 216 | U.load_state(load_model_path) 217 | 218 | while True: 219 | if callback: callback(locals(), globals()) 220 | if max_timesteps and timesteps_so_far >= max_timesteps: 221 | break 222 | elif max_episodes and episodes_so_far >= max_episodes: 223 | break 224 | elif max_iters and iters_so_far >= max_iters: 225 | break 226 | 227 | # Save model 228 | if iters_so_far % save_per_iter == 0 and ckpt_dir is not None: 229 | U.save_state(os.path.join(ckpt_dir, task_name), counter=iters_so_far) 230 | 231 | logger.log("********** Iteration %i ************"%iters_so_far) 232 | 233 | def fisher_vector_product(p): 234 | return allmean(compute_fvp(p, *fvpargs)) + cg_damping * p 235 | # ------------------ Update G ------------------ 236 | logger.log("Optimizing Policy...") 237 | for _ in range(g_step): 238 | with timed("sampling"): 239 | seg = seg_gen.__next__() 240 | add_vtarg_and_adv(seg, gamma, lam) 241 | # ob, ac, atarg, ret, 
td1ret = map(np.concatenate, (obs, acs, atargs, rets, td1rets)) 242 | ob, ac, atarg, tdlamret = seg["ob"], seg["ac"], seg["adv"], seg["tdlamret"] 243 | vpredbefore = seg["vpred"] # predicted value function before udpate 244 | atarg = (atarg - atarg.mean()) / atarg.std() # standardized advantage function estimate 245 | 246 | if hasattr(pi, "ob_rms"): pi.ob_rms.update(ob) # update running mean/std for policy 247 | 248 | args = seg["ob"], seg["ac"], atarg 249 | fvpargs = [arr[::5] for arr in args] 250 | 251 | assign_old_eq_new() # set old parameter values to new parameter values 252 | with timed("computegrad"): 253 | *lossbefore, g = compute_lossandgrad(*args) 254 | lossbefore = allmean(np.array(lossbefore)) 255 | g = allmean(g) 256 | if np.allclose(g, 0): 257 | logger.log("Got zero gradient. not updating") 258 | else: 259 | with timed("cg"): 260 | stepdir = cg(fisher_vector_product, g, cg_iters=cg_iters, verbose=rank==0) 261 | assert np.isfinite(stepdir).all() 262 | shs = .5*stepdir.dot(fisher_vector_product(stepdir)) 263 | lm = np.sqrt(shs / max_kl) 264 | # logger.log("lagrange multiplier:", lm, "gnorm:", np.linalg.norm(g)) 265 | fullstep = stepdir / lm 266 | expectedimprove = g.dot(fullstep) 267 | surrbefore = lossbefore[0] 268 | stepsize = 1.0 269 | thbefore = get_flat() 270 | for _ in range(10): 271 | thnew = thbefore + fullstep * stepsize 272 | set_from_flat(thnew) 273 | meanlosses = surr, kl, *_ = allmean(np.array(compute_losses(*args))) 274 | improve = surr - surrbefore 275 | logger.log("Expected: %.3f Actual: %.3f"%(expectedimprove, improve)) 276 | if not np.isfinite(meanlosses).all(): 277 | logger.log("Got non-finite value of losses -- bad!") 278 | elif kl > max_kl * 1.5: 279 | logger.log("violated KL constraint. shrinking step.") 280 | elif improve < 0: 281 | logger.log("surrogate didn't improve. 
shrinking step.") 282 | else: 283 | logger.log("Stepsize OK!") 284 | break 285 | stepsize *= .5 286 | else: 287 | logger.log("couldn't compute a good step") 288 | set_from_flat(thbefore) 289 | if nworkers > 1 and iters_so_far % 20 == 0: 290 | paramsums = MPI.COMM_WORLD.allgather((thnew.sum(), vfadam.getflat().sum())) # list of tuples 291 | assert all(np.allclose(ps, paramsums[0]) for ps in paramsums[1:]) 292 | with timed("vf"): 293 | for _ in range(vf_iters): 294 | for (mbob, mbret) in dataset.iterbatches((seg["ob"], seg["tdlamret"]), 295 | include_final_partial_batch=False, batch_size=128): 296 | if hasattr(pi, "ob_rms"): pi.ob_rms.update(mbob) # update running mean/std for policy 297 | g = allmean(compute_vflossandgrad(mbob, mbret)) 298 | vfadam.update(g, vf_stepsize) 299 | 300 | g_losses = meanlosses 301 | for (lossname, lossval) in zip(loss_names, meanlosses): 302 | logger.record_tabular(lossname, lossval) 303 | logger.record_tabular("ev_tdlam_before", explained_variance(vpredbefore, tdlamret)) 304 | # ------------------ Update D ------------------ 305 | logger.log("Optimizing Discriminator...") 306 | logger.log(fmt_row(13, discriminator.loss_name)) 307 | ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob)) 308 | batch_size = len(ob) // d_step 309 | d_losses = [] # list of tuples, each of which gives the loss for a minibatch 310 | for ob_batch, ac_batch in dataset.iterbatches((ob, ac), 311 | include_final_partial_batch=False, batch_size=batch_size): 312 | ob_expert, ac_expert = expert_dataset.get_next_batch(len(ob_batch)) 313 | # update running mean/std for discriminator 314 | if hasattr(discriminator, "obs_rms"): discriminator.obs_rms.update(np.concatenate((ob_batch, ob_expert), 0)) 315 | *newlosses, g = discriminator.lossandgrad(ob_batch, ac_batch, ob_expert, ac_expert) 316 | d_adam.update(allmean(g), d_stepsize) 317 | d_losses.append(newlosses) 318 | logger.log(fmt_row(13, np.mean(d_losses, axis=0))) 319 | 320 | 321 | lrlocal = (seg["ep_lens"], seg["ep_rets"], seg["ep_true_rets"]) # local values 322 | listoflrpairs = MPI.COMM_WORLD.allgather(lrlocal) # list of tuples 323 | lens, rews, true_rets = map(flatten_lists, zip(*listoflrpairs)) 324 | true_rewbuffer.extend(true_rets) 325 | lenbuffer.extend(lens) 326 | rewbuffer.extend(rews) 327 | 328 | logger.record_tabular("EpLenMean", np.mean(lenbuffer)) 329 | logger.record_tabular("EpRewMean", np.mean(rewbuffer)) 330 | logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer)) 331 | logger.record_tabular("EpThisIter", len(lens)) 332 | episodes_so_far += len(lens) 333 | timesteps_so_far += sum(lens) 334 | iters_so_far += 1 335 | 336 | logger.record_tabular("EpisodesSoFar", episodes_so_far) 337 | logger.record_tabular("TimestepsSoFar", timesteps_so_far) 338 | logger.record_tabular("TimeElapsed", time.time() - tstart) 339 | 340 | if rank==0: 341 | logger.dump_tabular() 342 | g_loss_stats.add_all_summary(writer, g_losses, iters_so_far) 343 | d_loss_stats.add_all_summary(writer, np.mean(d_losses, axis=0), iters_so_far) 344 | ep_stats.add_all_summary(writer, [np.mean(true_rewbuffer), np.mean(rewbuffer), 345 | np.mean(lenbuffer)], iters_so_far) 346 | 347 | # Sample one trajectory (until trajectory end) 348 | def traj_episode_generator(pi, env, horizon, stochastic): 349 | t = 0 350 | ac = env.action_space.sample() # not used, just so we have the datatype 351 | new = True # marks if we're on first timestep of an episode 352 | 353 | ob = env.reset() 354 | cur_ep_ret = 0 # return in current episode 355 | cur_ep_len = 0 # len of current 
episode 356 | 357 | # Initialize history arrays 358 | obs = []; rews = []; news = []; acs = [] 359 | 360 | while True: 361 | prevac = ac 362 | ac, vpred = pi.act(stochastic, ob) 363 | obs.append(ob) 364 | news.append(new) 365 | acs.append(ac) 366 | 367 | ob, rew, new, _ = env.step(ac) 368 | rews.append(rew) 369 | 370 | cur_ep_ret += rew 371 | cur_ep_len += 1 372 | if t > 0 and (new or t % horizon == 0): 373 | # convert list into numpy array 374 | obs = np.array(obs) 375 | rews = np.array(rews) 376 | news = np.array(news) 377 | acs = np.array(acs) 378 | yield {"ob":obs, "rew":rews, "new":news, "ac":acs, 379 | "ep_ret":cur_ep_ret, "ep_len":cur_ep_len} 380 | ob = env.reset() 381 | cur_ep_ret = 0; cur_ep_len = 0; t = 0 382 | 383 | # Initialize history arrays 384 | obs = []; rews = []; news = []; acs = [] 385 | t += 1 386 | 387 | def evaluate(env, policy_func, load_model_path, timesteps_per_batch, number_trajs=10, 388 | stochastic_policy=False): 389 | 390 | from tqdm import tqdm 391 | # Setup network 392 | # ---------------------------------------- 393 | ob_space = env.observation_space 394 | ac_space = env.action_space 395 | pi = policy_func("pi", ob_space, ac_space, reuse=False) 396 | U.initialize() 397 | # Prepare for rollouts 398 | # ---------------------------------------- 399 | ep_gen = traj_episode_generator(pi, env, timesteps_per_batch, stochastic=stochastic_policy) 400 | U.load_state(load_model_path) 401 | 402 | len_list = [] 403 | ret_list = [] 404 | for _ in tqdm(range(number_trajs)): 405 | traj = ep_gen.__next__() 406 | ep_len, ep_ret = traj['ep_len'], traj['ep_ret'] 407 | len_list.append(ep_len) 408 | ret_list.append(ep_ret) 409 | if stochastic_policy: 410 | print ('stochastic policy:') 411 | else: 412 | print ('deterministic policy:' ) 413 | print ("Average length:", sum(len_list)/len(len_list)) 414 | print ("Average return:", sum(ret_list)/len(ret_list)) 415 | 416 | def flatten_lists(listoflists): 417 | return [el for list_ in listoflists for el in list_] 418 | -------------------------------------------------------------------------------- /gailtf/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewliao11/gail-tf/ad92f41c26c34e8fabc536664fb11b44f25956cf/gailtf/baselines/__init__.py -------------------------------------------------------------------------------- /gailtf/baselines/bench/__init__.py: -------------------------------------------------------------------------------- 1 | from gailtf.baselines.bench.benchmarks import * 2 | from gailtf.baselines.bench.monitor import * 3 | 4 | -------------------------------------------------------------------------------- /gailtf/baselines/bench/benchmarks.py: -------------------------------------------------------------------------------- 1 | _atari7 = ['BeamRider', 'Breakout', 'Enduro', 'Pong', 'Qbert', 'Seaquest', 'SpaceInvaders'] 2 | _atariexpl7 = ['Freeway', 'Gravitar', 'MontezumaRevenge', 'Pitfall', 'PrivateEye', 'Solaris', 'Venture'] 3 | 4 | _BENCHMARKS = [] 5 | 6 | def register_benchmark(benchmark): 7 | for b in _BENCHMARKS: 8 | if b['name'] == benchmark['name']: 9 | raise ValueError('Benchmark with name %s already registered!'%b['name']) 10 | _BENCHMARKS.append(benchmark) 11 | 12 | def list_benchmarks(): 13 | return [b['name'] for b in _BENCHMARKS] 14 | 15 | def get_benchmark(benchmark_name): 16 | for b in _BENCHMARKS: 17 | if b['name'] == benchmark_name: 18 | return b 19 | raise ValueError('%s not found! 
Known benchmarks: %s' % (benchmark_name, list_benchmarks())) 20 | 21 | def get_task(benchmark, env_id): 22 | """Get a task by env_id. Return None if the benchmark doesn't have the env""" 23 | return next(filter(lambda task: task['env_id'] == env_id, benchmark['tasks']), None) 24 | 25 | def find_task_for_env_id_in_any_benchmark(env_id): 26 | for bm in _BENCHMARKS: 27 | for task in bm["tasks"]: 28 | if task["env_id"]==env_id: 29 | return bm, task 30 | return None, None 31 | 32 | _ATARI_SUFFIX = 'NoFrameskip-v4' 33 | 34 | register_benchmark({ 35 | 'name' : 'Atari200M', 36 | 'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 200M frames', 37 | 'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_timesteps' : int(200e6)} for _game in _atari7] 38 | }) 39 | 40 | register_benchmark({ 41 | 'name' : 'Atari40M', 42 | 'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 40M frames', 43 | 'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_timesteps' : int(40e6)} for _game in _atari7] 44 | }) 45 | 46 | register_benchmark({ 47 | 'name' : 'Atari1Hr', 48 | 'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 1 hour of walltime', 49 | 'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_seconds' : 60*60} for _game in _atari7] 50 | }) 51 | 52 | register_benchmark({ 53 | 'name' : 'AtariExploration40M', 54 | 'description' :'7 Atari games emphasizing exploration, with pixel observations, 40M frames', 55 | 'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 2, 'num_timesteps' : int(40e6)} for _game in _atariexpl7] 56 | }) 57 | 58 | 59 | # MuJoCo 60 | 61 | _mujocosmall = [ 62 | 'InvertedDoublePendulum-v1', 'InvertedPendulum-v1', 63 | 'HalfCheetah-v1', 'Hopper-v1', 'Walker2d-v1', 64 | 'Reacher-v1', 'Swimmer-v1'] 65 | register_benchmark({ 66 | 'name' : 'Mujoco1M', 67 | 'description' : 'Some small 2D MuJoCo tasks, run for 1M timesteps', 68 | 'tasks' : [{'env_id' : _envid, 'trials' : 3, 'num_timesteps' : int(1e6)} for _envid in _mujocosmall] 69 | }) 70 | register_benchmark({ 71 | 'name' : 'MujocoWalkers', 72 | 'description' : 'MuJoCo forward walkers, run for 8M, humanoid 100M', 73 | 'tasks' : [ 74 | {'env_id' : "Hopper-v1", 'trials' : 4, 'num_timesteps' : 8*1000000 }, 75 | {'env_id' : "Walker2d-v1", 'trials' : 4, 'num_timesteps' : 8*1000000 }, 76 | {'env_id' : "Humanoid-v1", 'trials' : 4, 'num_timesteps' : 100*1000000 }, 77 | ] 78 | }) 79 | # To reproduce: 80 | # python3 baselines/baselines/ppo2/ppo2_run_benchmark.py gce MujocoWalkers myrun_ppo2_whiteobs1_cpu8 81 | # (observation input filters necessary) 82 | 83 | 84 | # Roboschool 85 | 86 | register_benchmark({ 87 | 'name' : 'Roboschool8M', 88 | 'description' : 'Small 2D tasks, up to 30 minutes to complete on 8 cores', 89 | 'tasks' : [ 90 | {'env_id' : "RoboschoolReacher-v1", 'trials' : 4, 'num_timesteps' : 2*1000000 }, 91 | {'env_id' : "RoboschoolAnt-v1", 'trials' : 4, 'num_timesteps' : 8*1000000 }, 92 | {'env_id' : "RoboschoolHalfCheetah-v1", 'trials' : 4, 'num_timesteps' : 8*1000000 }, 93 | {'env_id' : "RoboschoolHopper-v1", 'trials' : 4, 'num_timesteps' : 8*1000000 }, 94 | {'env_id' : "RoboschoolWalker2d-v1", 'trials' : 4, 'num_timesteps' : 8*1000000 }, 95 | ] 96 | }) 97 | register_benchmark({ 98 | 'name' : 'RoboschoolHarder', 99 | 'description' : 'Test your might!!! 
Up to 12 hours on 32 cores', 100 | 'tasks' : [ 101 | {'env_id' : "RoboschoolHumanoid-v1", 'trials' : 4, 'num_timesteps' : 100*1000000 }, 102 | {'env_id' : "RoboschoolHumanoidFlagrun-v1", 'trials' : 4, 'num_timesteps' : 200*1000000 }, 103 | {'env_id' : "RoboschoolHumanoidFlagrunHarder-v1", 'trials' : 4, 'num_timesteps' : 400*1000000 }, 104 | ] 105 | }) 106 | # To reproduce: 107 | # python3 baselines/baselines/ppo2/ppo2_run_benchmark.py gce Roboschool8M myrun_ppo2_cpu8 108 | # python3 baselines/baselines/ppo2/ppo2_run_benchmark.py gce RoboschoolHarder myrun_ppo2_cpu32_large_samples65536 109 | # (Large network, train on 65536 samples each iteration. Also, _large is really necessary only for Harder) 110 | 111 | 112 | # Other 113 | 114 | _atari50 = [ # actually 49 115 | 'Alien', 'Amidar', 'Assault', 'Asterix', 'Asteroids', 116 | 'Atlantis', 'BankHeist', 'BattleZone', 'BeamRider', 'Bowling', 117 | 'Boxing', 'Breakout', 'Centipede', 'ChopperCommand', 'CrazyClimber', 118 | 'DemonAttack', 'DoubleDunk', 'Enduro', 'FishingDerby', 'Freeway', 119 | 'Frostbite', 'Gopher', 'Gravitar', 'IceHockey', 'Jamesbond', 120 | 'Kangaroo', 'Krull', 'KungFuMaster', 'MontezumaRevenge', 'MsPacman', 121 | 'NameThisGame', 'Pitfall', 'Pong', 'PrivateEye', 'Qbert', 122 | 'Riverraid', 'RoadRunner', 'Robotank', 'Seaquest', 'SpaceInvaders', 123 | 'StarGunner', 'Tennis', 'TimePilot', 'Tutankham', 'UpNDown', 124 | 'Venture', 'VideoPinball', 'WizardOfWor', 'Zaxxon', 125 | ] 126 | 127 | register_benchmark({ 128 | 'name' : 'Atari50_40M', 129 | 'description' :'7 Atari games from Mnih et al. (2013), with pixel observations, 40M frames', 130 | 'tasks' : [{'env_id' : _game + _ATARI_SUFFIX, 'trials' : 3, 'num_timesteps' : int(40e6)} for _game in _atari50] 131 | }) 132 | 133 | def env_shortname(s): 134 | "Make typical names above shorter, while keeping recognizable" 135 | s = s.replace("NoFrameskip", "") 136 | if s[:10]=="Roboschool": s = s[10:] 137 | i = s.rfind("-v") 138 | if i!=-1: s = s[:i] 139 | 140 | return s.lower() 141 | -------------------------------------------------------------------------------- /gailtf/baselines/bench/monitor.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Monitor', 'get_monitor_files', 'load_results'] 2 | 3 | import gym 4 | from gym.core import Wrapper 5 | from os import path 6 | import time 7 | from glob import glob 8 | 9 | try: 10 | import ujson as json # Not necessary for monitor writing, but very useful for monitor loading 11 | except ImportError: 12 | import json 13 | 14 | class Monitor(Wrapper): 15 | EXT = "monitor.json" 16 | f = None 17 | 18 | def __init__(self, env, filename, allow_early_resets=False): 19 | Wrapper.__init__(self, env=env) 20 | self.tstart = time.time() 21 | if filename is None: 22 | self.f = None 23 | self.logger = None 24 | else: 25 | if not filename.endswith(Monitor.EXT): 26 | filename = filename + "." 
+ Monitor.EXT 27 | self.f = open(filename, "wt") 28 | self.logger = JSONLogger(self.f) 29 | self.logger.writekvs({"t_start": self.tstart, "gym_version": gym.__version__, 30 | "env_id": env.spec.id if env.spec else 'Unknown'}) 31 | self.allow_early_resets = allow_early_resets 32 | self.rewards = None 33 | self.needs_reset = True 34 | self.episode_rewards = [] 35 | self.episode_lengths = [] 36 | self.total_steps = 0 37 | self.current_metadata = {} # extra info that gets injected into each log entry 38 | # Useful for metalearning where we're modifying the environment externally 39 | # But want our logs to know about these modifications 40 | 41 | def __getstate__(self): # XXX 42 | d = self.__dict__.copy() 43 | if self.f: 44 | del d['f'], d['logger'] 45 | d['_filename'] = self.f.name 46 | d['_num_episodes'] = len(self.episode_rewards) 47 | else: 48 | d['_filename'] = None 49 | return d 50 | def __setstate__(self, d): 51 | filename = d.pop('_filename') 52 | self.__dict__ = d 53 | if filename is not None: 54 | nlines = d.pop('_num_episodes') + 1 55 | self.f = open(filename, "r+t") 56 | for _ in range(nlines): 57 | self.f.readline() 58 | self.f.truncate() 59 | self.logger = JSONLogger(self.f) 60 | 61 | 62 | def reset(self): 63 | if not self.allow_early_resets and not self.needs_reset: 64 | raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, wrap your env with Monitor(env, path, allow_early_resets=True)") 65 | self.rewards = [] 66 | self.needs_reset = False 67 | return self.env.reset() 68 | 69 | def step(self, action): 70 | if self.needs_reset: 71 | raise RuntimeError("Tried to step environment that needs reset") 72 | ob, rew, done, info = self.env.step(action) 73 | self.rewards.append(rew) 74 | if done: 75 | self.needs_reset = True 76 | eprew = sum(self.rewards) 77 | eplen = len(self.rewards) 78 | epinfo = {"r": eprew, "l": eplen, "t": round(time.time() - self.tstart, 6)} 79 | epinfo.update(self.current_metadata) 80 | if self.logger: 81 | self.logger.writekvs(epinfo) 82 | self.episode_rewards.append(eprew) 83 | self.episode_lengths.append(eplen) 84 | info['episode'] = epinfo 85 | self.total_steps += 1 86 | return (ob, rew, done, info) 87 | 88 | def close(self): 89 | if self.f is not None: 90 | self.f.close() 91 | 92 | def get_total_steps(self): 93 | return self.total_steps 94 | 95 | def get_episode_rewards(self): 96 | return self.episode_rewards 97 | 98 | def get_episode_lengths(self): 99 | return self.episode_lengths 100 | 101 | class JSONLogger(object): 102 | def __init__(self, file): 103 | self.file = file 104 | 105 | def writekvs(self, kvs): 106 | for k,v in kvs.items(): 107 | if hasattr(v, 'dtype'): 108 | v = v.tolist() 109 | kvs[k] = float(v) 110 | self.file.write(json.dumps(kvs) + '\n') 111 | self.file.flush() 112 | 113 | 114 | class LoadMonitorResultsError(Exception): 115 | pass 116 | 117 | def get_monitor_files(dir): 118 | return glob(path.join(dir, "*" + Monitor.EXT)) 119 | 120 | def load_results(dir, raw_episodes=False): 121 | fnames = get_monitor_files(dir) 122 | if not fnames: 123 | raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, dir)) 124 | episodes = [] 125 | headers = [] 126 | for fname in fnames: 127 | with open(fname, 'rt') as fh: 128 | lines = fh.readlines() 129 | header = json.loads(lines[0]) 130 | headers.append(header) 131 | for line in lines[1:]: 132 | episode = json.loads(line) 133 | episode['abstime'] = header['t_start'] + episode['t'] 134 | del episode['t'] 135 | 
episodes.append(episode) 136 | header0 = headers[0] 137 | for header in headers[1:]: 138 | assert header['env_id'] == header0['env_id'], "mixing data from two envs" 139 | episodes = sorted(episodes, key=lambda e: e['abstime']) 140 | if raw_episodes: 141 | return episodes 142 | else: 143 | return { 144 | 'env_info': {'env_id': header0['env_id'], 'gym_version': header0['gym_version']}, 145 | 'episode_end_times': [e['abstime'] for e in episodes], 146 | 'episode_lengths': [e['l'] for e in episodes], 147 | 'episode_rewards': [e['r'] for e in episodes], 148 | 'initial_reset_time': min([min(header['t_start'] for header in headers)]) 149 | } 150 | -------------------------------------------------------------------------------- /gailtf/baselines/common/__init__.py: -------------------------------------------------------------------------------- 1 | from gailtf.baselines.common.console_util import * 2 | from gailtf.baselines.common.dataset import Dataset 3 | from gailtf.baselines.common.math_util import * 4 | from gailtf.baselines.common.misc_util import * 5 | -------------------------------------------------------------------------------- /gailtf/baselines/common/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | from PIL import Image 4 | import gym 5 | from gym import spaces 6 | 7 | 8 | class NoopResetEnv(gym.Wrapper): 9 | def __init__(self, env, noop_max=30): 10 | """Sample initial states by taking random number of no-ops on reset. 11 | No-op is assumed to be action 0. 12 | """ 13 | gym.Wrapper.__init__(self, env) 14 | self.noop_max = noop_max 15 | self.override_num_noops = None 16 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 17 | 18 | def _reset(self): 19 | """ Do no-op action for a number of steps in [1, noop_max].""" 20 | self.env.reset() 21 | if self.override_num_noops is not None: 22 | noops = self.override_num_noops 23 | else: 24 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101 25 | assert noops > 0 26 | obs = None 27 | for _ in range(noops): 28 | obs, _, done, _ = self.env.step(0) 29 | if done: 30 | obs = self.env.reset() 31 | return obs 32 | 33 | class FireResetEnv(gym.Wrapper): 34 | def __init__(self, env): 35 | """Take action on reset for environments that are fixed until firing.""" 36 | gym.Wrapper.__init__(self, env) 37 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 38 | assert len(env.unwrapped.get_action_meanings()) >= 3 39 | 40 | def _reset(self): 41 | self.env.reset() 42 | obs, _, done, _ = self.env.step(1) 43 | if done: 44 | self.env.reset() 45 | obs, _, done, _ = self.env.step(2) 46 | if done: 47 | self.env.reset() 48 | return obs 49 | 50 | class EpisodicLifeEnv(gym.Wrapper): 51 | def __init__(self, env): 52 | """Make end-of-life == end-of-episode, but only reset on true game over. 53 | Done by DeepMind for the DQN and co. since it helps value estimation. 
54 | """ 55 | gym.Wrapper.__init__(self, env) 56 | self.lives = 0 57 | self.was_real_done = True 58 | 59 | def _step(self, action): 60 | obs, reward, done, info = self.env.step(action) 61 | self.was_real_done = done 62 | # check current lives, make loss of life terminal, 63 | # then update lives to handle bonus lives 64 | lives = self.env.unwrapped.ale.lives() 65 | if lives < self.lives and lives > 0: 66 | # for Qbert somtimes we stay in lives == 0 condtion for a few frames 67 | # so its important to keep lives > 0, so that we only reset once 68 | # the environment advertises done. 69 | done = True 70 | self.lives = lives 71 | return obs, reward, done, info 72 | 73 | def _reset(self): 74 | """Reset only when lives are exhausted. 75 | This way all states are still reachable even though lives are episodic, 76 | and the learner need not know about any of this behind-the-scenes. 77 | """ 78 | if self.was_real_done: 79 | obs = self.env.reset() 80 | else: 81 | # no-op step to advance from terminal/lost life state 82 | obs, _, _, _ = self.env.step(0) 83 | self.lives = self.env.unwrapped.ale.lives() 84 | return obs 85 | 86 | class MaxAndSkipEnv(gym.Wrapper): 87 | def __init__(self, env, skip=4): 88 | """Return only every `skip`-th frame""" 89 | gym.Wrapper.__init__(self, env) 90 | # most recent raw observations (for max pooling across time steps) 91 | self._obs_buffer = deque(maxlen=2) 92 | self._skip = skip 93 | 94 | def _step(self, action): 95 | """Repeat action, sum reward, and max over last observations.""" 96 | total_reward = 0.0 97 | done = None 98 | for _ in range(self._skip): 99 | obs, reward, done, info = self.env.step(action) 100 | self._obs_buffer.append(obs) 101 | total_reward += reward 102 | if done: 103 | break 104 | max_frame = np.max(np.stack(self._obs_buffer), axis=0) 105 | 106 | return max_frame, total_reward, done, info 107 | 108 | def _reset(self): 109 | """Clear past frame buffer and init. to first obs. 
from inner env.""" 110 | self._obs_buffer.clear() 111 | obs = self.env.reset() 112 | self._obs_buffer.append(obs) 113 | return obs 114 | 115 | class ClipRewardEnv(gym.RewardWrapper): 116 | def _reward(self, reward): 117 | """Bin reward to {+1, 0, -1} by its sign.""" 118 | return np.sign(reward) 119 | 120 | class WarpFrame(gym.ObservationWrapper): 121 | def __init__(self, env): 122 | """Warp frames to 84x84 as done in the Nature paper and later work.""" 123 | gym.ObservationWrapper.__init__(self, env) 124 | self.res = 84 125 | self.observation_space = spaces.Box(low=0, high=255, shape=(self.res, self.res, 1)) 126 | 127 | def _observation(self, obs): 128 | frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32')) 129 | frame = np.array(Image.fromarray(frame).resize((self.res, self.res), 130 | resample=Image.BILINEAR), dtype=np.uint8) 131 | return frame.reshape((self.res, self.res, 1)) 132 | 133 | class FrameStack(gym.Wrapper): 134 | def __init__(self, env, k): 135 | """Buffer observations and stack across channels (last axis).""" 136 | gym.Wrapper.__init__(self, env) 137 | self.k = k 138 | self.frames = deque([], maxlen=k) 139 | shp = env.observation_space.shape 140 | assert shp[2] == 1 # can only stack 1-channel frames 141 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k)) 142 | 143 | def _reset(self): 144 | """Clear buffer and re-fill by duplicating the first observation.""" 145 | ob = self.env.reset() 146 | for _ in range(self.k): self.frames.append(ob) 147 | return self._observation() 148 | 149 | def _step(self, action): 150 | ob, reward, done, info = self.env.step(action) 151 | self.frames.append(ob) 152 | return self._observation(), reward, done, info 153 | 154 | def _observation(self): 155 | assert len(self.frames) == self.k 156 | return np.concatenate(self.frames, axis=2) 157 | 158 | def wrap_deepmind(env, episode_life=True, clip_rewards=True): 159 | """Configure environment for DeepMind-style Atari. 160 | 161 | Note: this does not include frame stacking!""" 162 | assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip 163 | if episode_life: 164 | env = EpisodicLifeEnv(env) 165 | env = NoopResetEnv(env, noop_max=30) 166 | env = MaxAndSkipEnv(env, skip=4) 167 | if 'FIRE' in env.unwrapped.get_action_meanings(): 168 | env = FireResetEnv(env) 169 | env = WarpFrame(env) 170 | if clip_rewards: 171 | env = ClipRewardEnv(env) 172 | return env 173 | -------------------------------------------------------------------------------- /gailtf/baselines/common/atari_wrappers_deprecated.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import gym 3 | import numpy as np 4 | 5 | from collections import deque 6 | from gym import spaces 7 | 8 | 9 | class NoopResetEnv(gym.Wrapper): 10 | def __init__(self, env=None, noop_max=30): 11 | """Sample initial states by taking random number of no-ops on reset. 12 | No-op is assumed to be action 0. 
13 | """ 14 | super(NoopResetEnv, self).__init__(env) 15 | self.noop_max = noop_max 16 | self.override_num_noops = None 17 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 18 | 19 | def _reset(self): 20 | """ Do no-op action for a number of steps in [1, noop_max].""" 21 | self.env.reset() 22 | if self.override_num_noops is not None: 23 | noops = self.override_num_noops 24 | else: 25 | noops = np.random.randint(1, self.noop_max + 1) 26 | assert noops > 0 27 | obs = None 28 | for _ in range(noops): 29 | obs, _, done, _ = self.env.step(0) 30 | if done: 31 | obs = self.env.reset() 32 | return obs 33 | 34 | 35 | class FireResetEnv(gym.Wrapper): 36 | def __init__(self, env=None): 37 | """For environments where the user need to press FIRE for the game to start.""" 38 | super(FireResetEnv, self).__init__(env) 39 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 40 | assert len(env.unwrapped.get_action_meanings()) >= 3 41 | 42 | def _reset(self): 43 | self.env.reset() 44 | obs, _, done, _ = self.env.step(1) 45 | if done: 46 | self.env.reset() 47 | obs, _, done, _ = self.env.step(2) 48 | if done: 49 | self.env.reset() 50 | return obs 51 | 52 | 53 | class EpisodicLifeEnv(gym.Wrapper): 54 | def __init__(self, env=None): 55 | """Make end-of-life == end-of-episode, but only reset on true game over. 56 | Done by DeepMind for the DQN and co. since it helps value estimation. 57 | """ 58 | super(EpisodicLifeEnv, self).__init__(env) 59 | self.lives = 0 60 | self.was_real_done = True 61 | self.was_real_reset = False 62 | 63 | def _step(self, action): 64 | obs, reward, done, info = self.env.step(action) 65 | self.was_real_done = done 66 | # check current lives, make loss of life terminal, 67 | # then update lives to handle bonus lives 68 | lives = self.env.unwrapped.ale.lives() 69 | if lives < self.lives and lives > 0: 70 | # for Qbert somtimes we stay in lives == 0 condtion for a few frames 71 | # so its important to keep lives > 0, so that we only reset once 72 | # the environment advertises done. 73 | done = True 74 | self.lives = lives 75 | return obs, reward, done, info 76 | 77 | def _reset(self): 78 | """Reset only when lives are exhausted. 79 | This way all states are still reachable even though lives are episodic, 80 | and the learner need not know about any of this behind-the-scenes. 81 | """ 82 | if self.was_real_done: 83 | obs = self.env.reset() 84 | self.was_real_reset = True 85 | else: 86 | # no-op step to advance from terminal/lost life state 87 | obs, _, _, _ = self.env.step(0) 88 | self.was_real_reset = False 89 | self.lives = self.env.unwrapped.ale.lives() 90 | return obs 91 | 92 | 93 | class MaxAndSkipEnv(gym.Wrapper): 94 | def __init__(self, env=None, skip=4): 95 | """Return only every `skip`-th frame""" 96 | super(MaxAndSkipEnv, self).__init__(env) 97 | # most recent raw observations (for max pooling across time steps) 98 | self._obs_buffer = deque(maxlen=2) 99 | self._skip = skip 100 | 101 | def _step(self, action): 102 | total_reward = 0.0 103 | done = None 104 | for _ in range(self._skip): 105 | obs, reward, done, info = self.env.step(action) 106 | self._obs_buffer.append(obs) 107 | total_reward += reward 108 | if done: 109 | break 110 | 111 | max_frame = np.max(np.stack(self._obs_buffer), axis=0) 112 | 113 | return max_frame, total_reward, done, info 114 | 115 | def _reset(self): 116 | """Clear past frame buffer and init. to first obs. 
from inner env.""" 117 | self._obs_buffer.clear() 118 | obs = self.env.reset() 119 | self._obs_buffer.append(obs) 120 | return obs 121 | 122 | 123 | class ProcessFrame84(gym.ObservationWrapper): 124 | def __init__(self, env=None): 125 | super(ProcessFrame84, self).__init__(env) 126 | self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1)) 127 | 128 | def _observation(self, obs): 129 | return ProcessFrame84.process(obs) 130 | 131 | @staticmethod 132 | def process(frame): 133 | if frame.size == 210 * 160 * 3: 134 | img = np.reshape(frame, [210, 160, 3]).astype(np.float32) 135 | elif frame.size == 250 * 160 * 3: 136 | img = np.reshape(frame, [250, 160, 3]).astype(np.float32) 137 | else: 138 | assert False, "Unknown resolution." 139 | img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114 140 | resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA) 141 | x_t = resized_screen[18:102, :] 142 | x_t = np.reshape(x_t, [84, 84, 1]) 143 | return x_t.astype(np.uint8) 144 | 145 | 146 | class ClippedRewardsWrapper(gym.RewardWrapper): 147 | def _reward(self, reward): 148 | """Change all the positive rewards to 1, negative to -1 and keep zero.""" 149 | return np.sign(reward) 150 | 151 | 152 | class LazyFrames(object): 153 | def __init__(self, frames): 154 | """This object ensures that common frames between the observations are only stored once. 155 | It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay 156 | buffers. 157 | 158 | This object should only be converted to numpy array before being passed to the model. 159 | 160 | You'd not belive how complex the previous solution was.""" 161 | self._frames = frames 162 | 163 | def __array__(self, dtype=None): 164 | out = np.concatenate(self._frames, axis=2) 165 | if dtype is not None: 166 | out = out.astype(dtype) 167 | return out 168 | 169 | 170 | class FrameStack(gym.Wrapper): 171 | def __init__(self, env, k): 172 | """Stack k last frames. 173 | 174 | Returns lazy array, which is much more memory efficient. 175 | 176 | See Also 177 | -------- 178 | baselines.common.atari_wrappers.LazyFrames 179 | """ 180 | gym.Wrapper.__init__(self, env) 181 | self.k = k 182 | self.frames = deque([], maxlen=k) 183 | shp = env.observation_space.shape 184 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k)) 185 | 186 | def _reset(self): 187 | ob = self.env.reset() 188 | for _ in range(self.k): 189 | self.frames.append(ob) 190 | return self._get_ob() 191 | 192 | def _step(self, action): 193 | ob, reward, done, info = self.env.step(action) 194 | self.frames.append(ob) 195 | return self._get_ob(), reward, done, info 196 | 197 | def _get_ob(self): 198 | assert len(self.frames) == self.k 199 | return LazyFrames(list(self.frames)) 200 | 201 | 202 | class ScaledFloatFrame(gym.ObservationWrapper): 203 | def _observation(self, obs): 204 | # careful! This undoes the memory optimization, use 205 | # with smaller replay buffers only. 
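        # convert the (possibly LazyFrames) uint8 observation to float32 scaled to [0, 1]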
206 | return np.array(obs).astype(np.float32) / 255.0 207 | 208 | 209 | def wrap_dqn(env): 210 | """Apply a common set of wrappers for Atari games.""" 211 | assert 'NoFrameskip' in env.spec.id 212 | env = EpisodicLifeEnv(env) 213 | env = NoopResetEnv(env, noop_max=30) 214 | env = MaxAndSkipEnv(env, skip=4) 215 | if 'FIRE' in env.unwrapped.get_action_meanings(): 216 | env = FireResetEnv(env) 217 | env = ProcessFrame84(env) 218 | env = FrameStack(env, 4) 219 | env = ClippedRewardsWrapper(env) 220 | return env 221 | 222 | 223 | class A2cProcessFrame(gym.Wrapper): 224 | def __init__(self, env): 225 | gym.Wrapper.__init__(self, env) 226 | self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1)) 227 | 228 | def _step(self, action): 229 | ob, reward, done, info = self.env.step(action) 230 | return A2cProcessFrame.process(ob), reward, done, info 231 | 232 | def _reset(self): 233 | return A2cProcessFrame.process(self.env.reset()) 234 | 235 | @staticmethod 236 | def process(frame): 237 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 238 | frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA) 239 | return frame.reshape(84, 84, 1) 240 | -------------------------------------------------------------------------------- /gailtf/baselines/common/azure_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import zipfile 4 | 5 | from azure.common import AzureMissingResourceHttpError 6 | try: 7 | from azure.storage.blob import BlobService 8 | except ImportError: 9 | from azure.storage.blob import BlockBlobService as BlobService 10 | from shutil import unpack_archive 11 | from threading import Event 12 | 13 | """TODOS: 14 | - use Azure snapshots instead of hacky backups 15 | """ 16 | 17 | 18 | def fixed_list_blobs(service, *args, **kwargs): 19 | """By defualt list_containers only returns a subset of results. 20 | 21 | This function attempts to fix this. 
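    It repeatedly calls list_blobs, feeding the returned next_marker back in as
    the `marker` argument until the service signals that no results remain.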
22 | """ 23 | res = [] 24 | next_marker = None 25 | while next_marker is None or len(next_marker) > 0: 26 | kwargs['marker'] = next_marker 27 | gen = service.list_blobs(*args, **kwargs) 28 | for b in gen: 29 | res.append(b.name) 30 | next_marker = gen.next_marker 31 | return res 32 | 33 | 34 | def make_archive(source_path, dest_path): 35 | if source_path.endswith(os.path.sep): 36 | source_path = source_path.rstrip(os.path.sep) 37 | prefix_path = os.path.dirname(source_path) 38 | with zipfile.ZipFile(dest_path, "w", compression=zipfile.ZIP_STORED) as zf: 39 | if os.path.isdir(source_path): 40 | for dirname, subdirs, files in os.walk(source_path): 41 | zf.write(dirname, os.path.relpath(dirname, prefix_path)) 42 | for filename in files: 43 | filepath = os.path.join(dirname, filename) 44 | zf.write(filepath, os.path.relpath(filepath, prefix_path)) 45 | else: 46 | zf.write(source_path, os.path.relpath(source_path, prefix_path)) 47 | 48 | 49 | class Container(object): 50 | services = {} 51 | 52 | def __init__(self, account_name, account_key, container_name, maybe_create=False): 53 | self._account_name = account_name 54 | self._container_name = container_name 55 | if account_name not in Container.services: 56 | Container.services[account_name] = BlobService(account_name, account_key) 57 | self._service = Container.services[account_name] 58 | if maybe_create: 59 | self._service.create_container(self._container_name, fail_on_exist=False) 60 | 61 | def put(self, source_path, blob_name, callback=None): 62 | """Upload a file or directory from `source_path` to azure blob `blob_name`. 63 | 64 | Upload progress can be traced by an optional callback. 65 | """ 66 | upload_done = Event() 67 | 68 | def progress_callback(current, total): 69 | if callback: 70 | callback(current, total) 71 | if current >= total: 72 | upload_done.set() 73 | 74 | # Attempt to make backup if an existing version is already available 75 | try: 76 | x_ms_copy_source = "https://{}.blob.core.windows.net/{}/{}".format( 77 | self._account_name, 78 | self._container_name, 79 | blob_name 80 | ) 81 | self._service.copy_blob( 82 | container_name=self._container_name, 83 | blob_name=blob_name + ".backup", 84 | x_ms_copy_source=x_ms_copy_source 85 | ) 86 | except AzureMissingResourceHttpError: 87 | pass 88 | 89 | with tempfile.TemporaryDirectory() as td: 90 | arcpath = os.path.join(td, "archive.zip") 91 | make_archive(source_path, arcpath) 92 | self._service.put_block_blob_from_path( 93 | container_name=self._container_name, 94 | blob_name=blob_name, 95 | file_path=arcpath, 96 | max_connections=4, 97 | progress_callback=progress_callback, 98 | max_retries=10) 99 | upload_done.wait() 100 | 101 | def get(self, dest_path, blob_name, callback=None): 102 | """Download a file or directory to `dest_path` to azure blob `blob_name`. 103 | 104 | Warning! If directory is downloaded the `dest_path` is the parent directory. 105 | 106 | Upload progress can be traced by an optional callback. 
107 | """ 108 | download_done = Event() 109 | 110 | def progress_callback(current, total): 111 | if callback: 112 | callback(current, total) 113 | if current >= total: 114 | download_done.set() 115 | 116 | with tempfile.TemporaryDirectory() as td: 117 | arcpath = os.path.join(td, "archive.zip") 118 | for backup_blob_name in [blob_name, blob_name + '.backup']: 119 | try: 120 | properties = self._service.get_blob_properties( 121 | blob_name=backup_blob_name, 122 | container_name=self._container_name 123 | ) 124 | if hasattr(properties, 'properties'): 125 | # Annoyingly, Azure has changed the API and this now returns a blob 126 | # instead of it's properties with up-to-date azure package. 127 | blob_size = properties.properties.content_length 128 | else: 129 | blob_size = properties['content-length'] 130 | if int(blob_size) > 0: 131 | self._service.get_blob_to_path( 132 | container_name=self._container_name, 133 | blob_name=backup_blob_name, 134 | file_path=arcpath, 135 | max_connections=4, 136 | progress_callback=progress_callback) 137 | unpack_archive(arcpath, dest_path) 138 | download_done.wait() 139 | return True 140 | except AzureMissingResourceHttpError: 141 | pass 142 | return False 143 | 144 | def list(self, prefix=None): 145 | """List all blobs in the container.""" 146 | return fixed_list_blobs(self._service, self._container_name, prefix=prefix) 147 | 148 | def exists(self, blob_name): 149 | """Returns true if `blob_name` exists in container.""" 150 | try: 151 | self._service.get_blob_properties( 152 | blob_name=blob_name, 153 | container_name=self._container_name 154 | ) 155 | return True 156 | except AzureMissingResourceHttpError: 157 | return False 158 | -------------------------------------------------------------------------------- /gailtf/baselines/common/cg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def cg(f_Ax, b, cg_iters=10, callback=None, verbose=False, residual_tol=1e-10): 3 | """ 4 | Demmel p 312 5 | """ 6 | p = b.copy() 7 | r = b.copy() 8 | x = np.zeros_like(b) 9 | rdotr = r.dot(r) 10 | 11 | fmtstr = "%10i %10.3g %10.3g" 12 | titlestr = "%10s %10s %10s" 13 | if verbose: print(titlestr % ("iter", "residual norm", "soln norm")) 14 | 15 | for i in range(cg_iters): 16 | if callback is not None: 17 | callback(x) 18 | if verbose: print(fmtstr % (i, rdotr, np.linalg.norm(x))) 19 | z = f_Ax(p) 20 | v = rdotr / p.dot(z) 21 | x += v*p 22 | r -= v*z 23 | newrdotr = r.dot(r) 24 | mu = newrdotr/rdotr 25 | p = r + mu*p 26 | 27 | rdotr = newrdotr 28 | if rdotr < residual_tol: 29 | break 30 | 31 | if callback is not None: 32 | callback(x) 33 | if verbose: print(fmtstr % (i+1, rdotr, np.linalg.norm(x))) # pylint: disable=W0631 34 | return x -------------------------------------------------------------------------------- /gailtf/baselines/common/console_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from contextlib import contextmanager 3 | import numpy as np 4 | import time 5 | 6 | # ================================================================ 7 | # Misc 8 | # ================================================================ 9 | 10 | def fmt_row(width, row, header=False): 11 | out = " | ".join(fmt_item(x, width) for x in row) 12 | if header: out = out + "\n" + "-"*len(out) 13 | return out 14 | 15 | def fmt_item(x, l): 16 | if isinstance(x, np.ndarray): 17 | assert x.ndim==0 18 | x = x.item() 19 | if isinstance(x, float): rep = 
"%g"%x 20 | else: rep = str(x) 21 | return " "*(l - len(rep)) + rep 22 | 23 | color2num = dict( 24 | gray=30, 25 | red=31, 26 | green=32, 27 | yellow=33, 28 | blue=34, 29 | magenta=35, 30 | cyan=36, 31 | white=37, 32 | crimson=38 33 | ) 34 | 35 | def colorize(string, color, bold=False, highlight=False): 36 | attr = [] 37 | num = color2num[color] 38 | if highlight: num += 10 39 | attr.append(str(num)) 40 | if bold: attr.append('1') 41 | return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string) 42 | 43 | 44 | MESSAGE_DEPTH = 0 45 | 46 | @contextmanager 47 | def timed(msg): 48 | global MESSAGE_DEPTH #pylint: disable=W0603 49 | print(colorize('\t'*MESSAGE_DEPTH + '=: ' + msg, color='magenta')) 50 | tstart = time.time() 51 | MESSAGE_DEPTH += 1 52 | yield 53 | MESSAGE_DEPTH -= 1 54 | print(colorize('\t'*MESSAGE_DEPTH + "done in %.3f seconds"%(time.time() - tstart), color='magenta')) 55 | -------------------------------------------------------------------------------- /gailtf/baselines/common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Dataset(object): 4 | def __init__(self, data_map, deterministic=False, shuffle=True): 5 | self.data_map = data_map 6 | self.deterministic = deterministic 7 | self.enable_shuffle = shuffle 8 | self.n = next(iter(data_map.values())).shape[0] 9 | self._next_id = 0 10 | self.shuffle() 11 | 12 | def shuffle(self): 13 | if self.deterministic: 14 | return 15 | perm = np.arange(self.n) 16 | np.random.shuffle(perm) 17 | 18 | for key in self.data_map: 19 | self.data_map[key] = self.data_map[key][perm] 20 | 21 | self._next_id = 0 22 | 23 | def next_batch(self, batch_size): 24 | if self._next_id >= self.n and self.enable_shuffle: 25 | self.shuffle() 26 | 27 | cur_id = self._next_id 28 | cur_batch_size = min(batch_size, self.n - self._next_id) 29 | self._next_id += cur_batch_size 30 | 31 | data_map = dict() 32 | for key in self.data_map: 33 | data_map[key] = self.data_map[key][cur_id:cur_id+cur_batch_size] 34 | return data_map 35 | 36 | def iterate_once(self, batch_size): 37 | if self.enable_shuffle: self.shuffle() 38 | 39 | while self._next_id <= self.n - batch_size: 40 | yield self.next_batch(batch_size) 41 | self._next_id = 0 42 | 43 | def subset(self, num_elements, deterministic=True): 44 | data_map = dict() 45 | for key in self.data_map: 46 | data_map[key] = self.data_map[key][:num_elements] 47 | return Dataset(data_map, deterministic) 48 | 49 | 50 | def iterbatches(arrays, *, num_batches=None, batch_size=None, shuffle=True, include_final_partial_batch=True): 51 | assert (num_batches is None) != (batch_size is None), 'Provide num_batches or batch_size, but not both' 52 | arrays = tuple(map(np.asarray, arrays)) 53 | n = arrays[0].shape[0] 54 | assert all(a.shape[0] == n for a in arrays[1:]) 55 | inds = np.arange(n) 56 | if shuffle: np.random.shuffle(inds) 57 | sections = np.arange(0, n, batch_size)[1:] if num_batches is None else num_batches 58 | for batch_inds in np.array_split(inds, sections): 59 | if include_final_partial_batch or len(batch_inds) == batch_size: 60 | yield tuple(a[batch_inds] for a in arrays) 61 | -------------------------------------------------------------------------------- /gailtf/baselines/common/distributions.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import gailtf.baselines.common.tf_util as U 4 | from tensorflow.python.ops import math_ops 5 | from tensorflow.python.ops 
import nn 6 | 7 | class Pd(object): 8 | """ 9 | A particular probability distribution 10 | """ 11 | def flatparam(self): 12 | raise NotImplementedError 13 | def mode(self): 14 | raise NotImplementedError 15 | def neglogp(self, x): 16 | # Usually it's easier to define the negative logprob 17 | raise NotImplementedError 18 | def kl(self, other): 19 | raise NotImplementedError 20 | def entropy(self): 21 | raise NotImplementedError 22 | def sample(self): 23 | raise NotImplementedError 24 | def logp(self, x): 25 | return - self.neglogp(x) 26 | 27 | class PdType(object): 28 | """ 29 | Parametrized family of probability distributions 30 | """ 31 | def pdclass(self): 32 | raise NotImplementedError 33 | def pdfromflat(self, flat): 34 | return self.pdclass()(flat) 35 | def param_shape(self): 36 | raise NotImplementedError 37 | def sample_shape(self): 38 | raise NotImplementedError 39 | def sample_dtype(self): 40 | raise NotImplementedError 41 | 42 | def param_placeholder(self, prepend_shape, name=None): 43 | return tf.placeholder(dtype=tf.float32, shape=prepend_shape+self.param_shape(), name=name) 44 | def sample_placeholder(self, prepend_shape, name=None): 45 | return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name) 46 | 47 | class CategoricalPdType(PdType): 48 | def __init__(self, ncat): 49 | self.ncat = ncat 50 | def pdclass(self): 51 | return CategoricalPd 52 | def param_shape(self): 53 | return [self.ncat] 54 | def sample_shape(self): 55 | return [] 56 | def sample_dtype(self): 57 | return tf.int32 58 | 59 | 60 | class MultiCategoricalPdType(PdType): 61 | def __init__(self, low, high): 62 | self.low = low 63 | self.high = high 64 | self.ncats = high - low + 1 65 | def pdclass(self): 66 | return MultiCategoricalPd 67 | def pdfromflat(self, flat): 68 | return MultiCategoricalPd(self.low, self.high, flat) 69 | def param_shape(self): 70 | return [sum(self.ncats)] 71 | def sample_shape(self): 72 | return [len(self.ncats)] 73 | def sample_dtype(self): 74 | return tf.int32 75 | 76 | class DiagGaussianPdType(PdType): 77 | def __init__(self, size): 78 | self.size = size 79 | def pdclass(self): 80 | return DiagGaussianPd 81 | def param_shape(self): 82 | return [2*self.size] 83 | def sample_shape(self): 84 | return [self.size] 85 | def sample_dtype(self): 86 | return tf.float32 87 | 88 | class BernoulliPdType(PdType): 89 | def __init__(self, size): 90 | self.size = size 91 | def pdclass(self): 92 | return BernoulliPd 93 | def param_shape(self): 94 | return [self.size] 95 | def sample_shape(self): 96 | return [self.size] 97 | def sample_dtype(self): 98 | return tf.int32 99 | 100 | # WRONG SECOND DERIVATIVES 101 | # class CategoricalPd(Pd): 102 | # def __init__(self, logits): 103 | # self.logits = logits 104 | # self.ps = tf.nn.softmax(logits) 105 | # @classmethod 106 | # def fromflat(cls, flat): 107 | # return cls(flat) 108 | # def flatparam(self): 109 | # return self.logits 110 | # def mode(self): 111 | # return U.argmax(self.logits, axis=-1) 112 | # def logp(self, x): 113 | # return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x) 114 | # def kl(self, other): 115 | # return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \ 116 | # - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps) 117 | # def entropy(self): 118 | # return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps) 119 | # def sample(self): 120 | # u = tf.random_uniform(tf.shape(self.logits)) 121 | # return U.argmax(self.logits - tf.log(-tf.log(u)), 
axis=-1) 122 | 123 | class CategoricalPd(Pd): 124 | def __init__(self, logits): 125 | self.logits = logits 126 | def flatparam(self): 127 | return self.logits 128 | def mode(self): 129 | return U.argmax(self.logits, axis=-1) 130 | def neglogp(self, x): 131 | # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x) 132 | # Note: we can't use sparse_softmax_cross_entropy_with_logits because 133 | # the implementation does not allow second-order derivatives... 134 | one_hot_actions = tf.one_hot(x, self.logits.get_shape().as_list()[-1]) 135 | return tf.nn.softmax_cross_entropy_with_logits( 136 | logits=self.logits, 137 | labels=one_hot_actions) 138 | def kl(self, other): 139 | a0 = self.logits - U.max(self.logits, axis=-1, keepdims=True) 140 | a1 = other.logits - U.max(other.logits, axis=-1, keepdims=True) 141 | ea0 = tf.exp(a0) 142 | ea1 = tf.exp(a1) 143 | z0 = U.sum(ea0, axis=-1, keepdims=True) 144 | z1 = U.sum(ea1, axis=-1, keepdims=True) 145 | p0 = ea0 / z0 146 | return U.sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1) 147 | def entropy(self): 148 | a0 = self.logits - U.max(self.logits, axis=-1, keepdims=True) 149 | ea0 = tf.exp(a0) 150 | z0 = U.sum(ea0, axis=-1, keepdims=True) 151 | p0 = ea0 / z0 152 | return U.sum(p0 * (tf.log(z0) - a0), axis=-1) 153 | def sample(self): 154 | u = tf.random_uniform(tf.shape(self.logits)) 155 | return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1) 156 | @classmethod 157 | def fromflat(cls, flat): 158 | return cls(flat) 159 | 160 | class MultiCategoricalPd(Pd): 161 | def __init__(self, low, high, flat): 162 | self.flat = flat 163 | self.low = tf.constant(low, dtype=tf.int32) 164 | self.categoricals = list(map(CategoricalPd, tf.split(flat, high - low + 1, axis=len(flat.get_shape()) - 1))) 165 | def flatparam(self): 166 | return self.flat 167 | def mode(self): 168 | return self.low + tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32) 169 | def neglogp(self, x): 170 | return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x - self.low, axis=len(x.get_shape()) - 1))]) 171 | def kl(self, other): 172 | return tf.add_n([ 173 | p.kl(q) for p, q in zip(self.categoricals, other.categoricals) 174 | ]) 175 | def entropy(self): 176 | return tf.add_n([p.entropy() for p in self.categoricals]) 177 | def sample(self): 178 | return self.low + tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32) 179 | @classmethod 180 | def fromflat(cls, flat): 181 | raise NotImplementedError 182 | 183 | class DiagGaussianPd(Pd): 184 | def __init__(self, flat): 185 | self.flat = flat 186 | mean, logstd = tf.split(axis=len(flat.shape)-1, num_or_size_splits=2, value=flat) 187 | self.mean = mean 188 | self.logstd = logstd 189 | self.std = tf.exp(logstd) 190 | def flatparam(self): 191 | return self.flat 192 | def mode(self): 193 | return self.mean 194 | def neglogp(self, x): 195 | return 0.5 * U.sum(tf.square((x - self.mean) / self.std), axis=-1) \ 196 | + 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \ 197 | + U.sum(self.logstd, axis=-1) 198 | def kl(self, other): 199 | assert isinstance(other, DiagGaussianPd) 200 | return U.sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1) 201 | def entropy(self): 202 | return U.sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1) 203 | def sample(self): 204 | return self.mean + self.std * tf.random_normal(tf.shape(self.mean)) 205 | 
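    # For reference, the methods above implement the usual factorized-Gaussian
    # identities, with sums taken over the action dimension of size k:
    #   -log p(x) = 0.5 * sum(((x - mean) / std)^2) + 0.5 * k * log(2*pi) + sum(logstd)
    #   H(p)      = sum(logstd + 0.5 * log(2*pi*e))
    #   KL(p||q)  = sum(logstd_q - logstd_p
    #                   + (std_p^2 + (mean_p - mean_q)^2) / (2 * std_q^2) - 0.5)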
@classmethod 206 | def fromflat(cls, flat): 207 | return cls(flat) 208 | 209 | class BernoulliPd(Pd): 210 | def __init__(self, logits): 211 | self.logits = logits 212 | self.ps = tf.sigmoid(logits) 213 | def flatparam(self): 214 | return self.logits 215 | def mode(self): 216 | return tf.round(self.ps) 217 | def neglogp(self, x): 218 | return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1) 219 | def kl(self, other): 220 | return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1) 221 | def entropy(self): 222 | return U.sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1) 223 | def sample(self): 224 | u = tf.random_uniform(tf.shape(self.ps)) 225 | return tf.to_float(math_ops.less(u, self.ps)) 226 | @classmethod 227 | def fromflat(cls, flat): 228 | return cls(flat) 229 | 230 | def make_pdtype(ac_space): 231 | from gym import spaces 232 | if isinstance(ac_space, spaces.Box): 233 | assert len(ac_space.shape) == 1 234 | return DiagGaussianPdType(ac_space.shape[0]) 235 | elif isinstance(ac_space, spaces.Discrete): 236 | return CategoricalPdType(ac_space.n) 237 | elif isinstance(ac_space, spaces.MultiDiscrete): 238 | return MultiCategoricalPdType(ac_space.low, ac_space.high) 239 | elif isinstance(ac_space, spaces.MultiBinary): 240 | return BernoulliPdType(ac_space.n) 241 | else: 242 | raise NotImplementedError 243 | 244 | def shape_el(v, i): 245 | maybe = v.get_shape()[i] 246 | if maybe is not None: 247 | return maybe 248 | else: 249 | return tf.shape(v)[i] 250 | 251 | @U.in_session 252 | def test_probtypes(): 253 | np.random.seed(0) 254 | 255 | pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8]) 256 | diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101 257 | validate_probtype(diag_gauss, pdparam_diag_gauss) 258 | 259 | pdparam_categorical = np.array([-.2, .3, .5]) 260 | categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101 261 | validate_probtype(categorical, pdparam_categorical) 262 | 263 | pdparam_bernoulli = np.array([-.2, .3, .5]) 264 | bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101 265 | validate_probtype(bernoulli, pdparam_bernoulli) 266 | 267 | 268 | def validate_probtype(probtype, pdparam): 269 | N = 100000 270 | # Check to see if mean negative log likelihood == differential entropy 271 | Mval = np.repeat(pdparam[None, :], N, axis=0) 272 | M = probtype.param_placeholder([N]) 273 | X = probtype.sample_placeholder([N]) 274 | pd = probtype.pdclass()(M) 275 | calcloglik = U.function([X, M], pd.logp(X)) 276 | calcent = U.function([M], pd.entropy()) 277 | Xval = U.eval(pd.sample(), feed_dict={M:Mval}) 278 | logliks = calcloglik(Xval, Mval) 279 | entval_ll = - logliks.mean() #pylint: disable=E1101 280 | entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101 281 | entval = calcent(Mval).mean() #pylint: disable=E1101 282 | assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas 283 | 284 | # Check to see if kldiv[p,q] = - ent[p] - E_p[log q] 285 | M2 = probtype.param_placeholder([N]) 286 | pd2 = probtype.pdclass()(M2) 287 | q = pdparam + np.random.randn(pdparam.size) * 0.1 288 | Mval2 = np.repeat(q[None, :], N, axis=0) 289 | calckl = U.function([M, M2], pd.kl(pd2)) 290 | klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101 291 | logliks = 
calcloglik(Xval, Mval2) 292 | klval_ll = - entval - logliks.mean() #pylint: disable=E1101 293 | klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101 294 | assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas 295 | -------------------------------------------------------------------------------- /gailtf/baselines/common/math_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.signal 3 | 4 | 5 | def discount(x, gamma): 6 | """ 7 | computes discounted sums along 0th dimension of x. 8 | 9 | inputs 10 | ------ 11 | x: ndarray 12 | gamma: float 13 | 14 | outputs 15 | ------- 16 | y: ndarray with same shape as x, satisfying 17 | 18 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k], 19 | where k = len(x) - t - 1 20 | 21 | """ 22 | assert x.ndim >= 1 23 | return scipy.signal.lfilter([1],[1,-gamma],x[::-1], axis=0)[::-1] 24 | 25 | def explained_variance(ypred,y): 26 | """ 27 | Computes fraction of variance that ypred explains about y. 28 | Returns 1 - Var[y-ypred] / Var[y] 29 | 30 | interpretation: 31 | ev=0 => might as well have predicted zero 32 | ev=1 => perfect prediction 33 | ev<0 => worse than just predicting zero 34 | 35 | """ 36 | assert y.ndim == 1 and ypred.ndim == 1 37 | vary = np.var(y) 38 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary 39 | 40 | def explained_variance_2d(ypred, y): 41 | assert y.ndim == 2 and ypred.ndim == 2 42 | vary = np.var(y, axis=0) 43 | out = 1 - np.var(y-ypred)/vary 44 | out[vary < 1e-10] = 0 45 | return out 46 | 47 | def ncc(ypred, y): 48 | return np.corrcoef(ypred, y)[1,0] 49 | 50 | def flatten_arrays(arrs): 51 | return np.concatenate([arr.flat for arr in arrs]) 52 | 53 | def unflatten_vector(vec, shapes): 54 | i=0 55 | arrs = [] 56 | for shape in shapes: 57 | size = np.prod(shape) 58 | arr = vec[i:i+size].reshape(shape) 59 | arrs.append(arr) 60 | i += size 61 | return arrs 62 | 63 | def discount_with_boundaries(X, New, gamma): 64 | """ 65 | X: 2d array of floats, time x features 66 | New: 2d array of bools, indicating when a new episode has started 67 | """ 68 | Y = np.zeros_like(X) 69 | T = X.shape[0] 70 | Y[T-1] = X[T-1] 71 | for t in range(T-2, -1, -1): 72 | Y[t] = X[t] + gamma * Y[t+1] * (1 - New[t+1]) 73 | return Y 74 | 75 | def test_discount_with_boundaries(): 76 | gamma=0.9 77 | x = np.array([1.0, 2.0, 3.0, 4.0], 'float32') 78 | starts = [1.0, 0.0, 0.0, 1.0] 79 | y = discount_with_boundaries(x, starts, gamma) 80 | assert np.allclose(y, [ 81 | 1 + gamma * 2 + gamma**2 * 3, 82 | 2 + gamma * 3, 83 | 3, 84 | 4 85 | ]) -------------------------------------------------------------------------------- /gailtf/baselines/common/misc_util.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import os 4 | import pickle 5 | import random 6 | import tempfile 7 | import time 8 | import zipfile 9 | 10 | 11 | def zipsame(*seqs): 12 | L = len(seqs[0]) 13 | assert all(len(seq) == L for seq in seqs[1:]) 14 | return zip(*seqs) 15 | 16 | 17 | def unpack(seq, sizes): 18 | """ 19 | Unpack 'seq' into a sequence of lists, with lengths specified by 'sizes'. 
20 | None = just one bare element, not a list 21 | 22 | Example: 23 | unpack([1,2,3,4,5,6], [3,None,2]) -> ([1,2,3], 4, [5,6]) 24 | """ 25 | seq = list(seq) 26 | it = iter(seq) 27 | assert sum(1 if s is None else s for s in sizes) == len(seq), "Trying to unpack %s into %s" % (seq, sizes) 28 | for size in sizes: 29 | if size is None: 30 | yield it.__next__() 31 | else: 32 | li = [] 33 | for _ in range(size): 34 | li.append(it.__next__()) 35 | yield li 36 | 37 | 38 | class EzPickle(object): 39 | """Objects that are pickled and unpickled via their constructor 40 | arguments. 41 | 42 | Example usage: 43 | 44 | class Dog(Animal, EzPickle): 45 | def __init__(self, furcolor, tailkind="bushy"): 46 | Animal.__init__() 47 | EzPickle.__init__(furcolor, tailkind) 48 | ... 49 | 50 | When this object is unpickled, a new Dog will be constructed by passing the provided 51 | furcolor and tailkind into the constructor. However, philosophers are still not sure 52 | whether it is still the same dog. 53 | 54 | This is generally needed only for environments which wrap C/C++ code, such as MuJoCo 55 | and Atari. 56 | """ 57 | 58 | def __init__(self, *args, **kwargs): 59 | self._ezpickle_args = args 60 | self._ezpickle_kwargs = kwargs 61 | 62 | def __getstate__(self): 63 | return {"_ezpickle_args": self._ezpickle_args, "_ezpickle_kwargs": self._ezpickle_kwargs} 64 | 65 | def __setstate__(self, d): 66 | out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"]) 67 | self.__dict__.update(out.__dict__) 68 | 69 | 70 | def set_global_seeds(i): 71 | try: 72 | import tensorflow as tf 73 | except ImportError: 74 | pass 75 | else: 76 | tf.set_random_seed(i) 77 | np.random.seed(i) 78 | random.seed(i) 79 | 80 | 81 | def pretty_eta(seconds_left): 82 | """Print the number of seconds in human readable format. 83 | 84 | Examples: 85 | 2 days 86 | 2 hours and 37 minutes 87 | less than a minute 88 | 89 | Paramters 90 | --------- 91 | seconds_left: int 92 | Number of seconds to be converted to the ETA 93 | Returns 94 | ------- 95 | eta: str 96 | String representing the pretty ETA. 97 | """ 98 | minutes_left = seconds_left // 60 99 | seconds_left %= 60 100 | hours_left = minutes_left // 60 101 | minutes_left %= 60 102 | days_left = hours_left // 24 103 | hours_left %= 24 104 | 105 | def helper(cnt, name): 106 | return "{} {}{}".format(str(cnt), name, ('s' if cnt > 1 else '')) 107 | 108 | if days_left > 0: 109 | msg = helper(days_left, 'day') 110 | if hours_left > 0: 111 | msg += ' and ' + helper(hours_left, 'hour') 112 | return msg 113 | if hours_left > 0: 114 | msg = helper(hours_left, 'hour') 115 | if minutes_left > 0: 116 | msg += ' and ' + helper(minutes_left, 'minute') 117 | return msg 118 | if minutes_left > 0: 119 | return helper(minutes_left, 'minute') 120 | return 'less than a minute' 121 | 122 | 123 | class RunningAvg(object): 124 | def __init__(self, gamma, init_value=None): 125 | """Keep a running estimate of a quantity. This is a bit like mean 126 | but more sensitive to recent changes. 127 | 128 | Parameters 129 | ---------- 130 | gamma: float 131 | Must be between 0 and 1, where 0 is the most sensitive to recent 132 | changes. 133 | init_value: float or None 134 | Initial value of the estimate. If None, it will be set on the first update. 135 | """ 136 | self._value = init_value 137 | self._gamma = gamma 138 | 139 | def update(self, new_val): 140 | """Update the estimate. 141 | 142 | Parameters 143 | ---------- 144 | new_val: float 145 | new observated value of estimated quantity. 
146 | """ 147 | if self._value is None: 148 | self._value = new_val 149 | else: 150 | self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val 151 | 152 | def __float__(self): 153 | """Get the current estimate""" 154 | return self._value 155 | 156 | 157 | class SimpleMonitor(gym.Wrapper): 158 | def __init__(self, env): 159 | """Adds two qunatities to info returned by every step: 160 | 161 | num_steps: int 162 | Number of steps takes so far 163 | rewards: [float] 164 | All the cumulative rewards for the episodes completed so far. 165 | """ 166 | super().__init__(env) 167 | # current episode state 168 | self._current_reward = None 169 | self._num_steps = None 170 | # temporary monitor state that we do not save 171 | self._time_offset = None 172 | self._total_steps = None 173 | # monitor state 174 | self._episode_rewards = [] 175 | self._episode_lengths = [] 176 | self._episode_end_times = [] 177 | 178 | def _reset(self): 179 | obs = self.env.reset() 180 | # recompute temporary state if needed 181 | if self._time_offset is None: 182 | self._time_offset = time.time() 183 | if len(self._episode_end_times) > 0: 184 | self._time_offset -= self._episode_end_times[-1] 185 | if self._total_steps is None: 186 | self._total_steps = sum(self._episode_lengths) 187 | # update monitor state 188 | if self._current_reward is not None: 189 | self._episode_rewards.append(self._current_reward) 190 | self._episode_lengths.append(self._num_steps) 191 | self._episode_end_times.append(time.time() - self._time_offset) 192 | # reset episode state 193 | self._current_reward = 0 194 | self._num_steps = 0 195 | 196 | return obs 197 | 198 | def _step(self, action): 199 | obs, rew, done, info = self.env.step(action) 200 | self._current_reward += rew 201 | self._num_steps += 1 202 | self._total_steps += 1 203 | info['steps'] = self._total_steps 204 | info['rewards'] = self._episode_rewards 205 | return (obs, rew, done, info) 206 | 207 | def get_state(self): 208 | return { 209 | 'env_id': self.env.unwrapped.spec.id, 210 | 'episode_data': { 211 | 'episode_rewards': self._episode_rewards, 212 | 'episode_lengths': self._episode_lengths, 213 | 'episode_end_times': self._episode_end_times, 214 | 'initial_reset_time': 0, 215 | } 216 | } 217 | 218 | def set_state(self, state): 219 | assert state['env_id'] == self.env.unwrapped.spec.id 220 | ed = state['episode_data'] 221 | self._episode_rewards = ed['episode_rewards'] 222 | self._episode_lengths = ed['episode_lengths'] 223 | self._episode_end_times = ed['episode_end_times'] 224 | 225 | 226 | def boolean_flag(parser, name, default=False, help=None): 227 | """Add a boolean flag to argparse parser. 
228 | 229 | Parameters 230 | ---------- 231 | parser: argparse.Parser 232 | parser to add the flag to 233 | name: str 234 | -- will enable the flag, while --no- will disable it 235 | default: bool or None 236 | default value of the flag 237 | help: str 238 | help string for the flag 239 | """ 240 | dest = name.replace('-', '_') 241 | parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help) 242 | parser.add_argument("--no-" + name, action="store_false", dest=dest) 243 | 244 | 245 | def get_wrapper_by_name(env, classname): 246 | """Given an a gym environment possibly wrapped multiple times, returns a wrapper 247 | of class named classname or raises ValueError if no such wrapper was applied 248 | 249 | Parameters 250 | ---------- 251 | env: gym.Env of gym.Wrapper 252 | gym environment 253 | classname: str 254 | name of the wrapper 255 | 256 | Returns 257 | ------- 258 | wrapper: gym.Wrapper 259 | wrapper named classname 260 | """ 261 | currentenv = env 262 | while True: 263 | if classname == currentenv.class_name(): 264 | return currentenv 265 | elif isinstance(currentenv, gym.Wrapper): 266 | currentenv = currentenv.env 267 | else: 268 | raise ValueError("Couldn't find wrapper named %s" % classname) 269 | 270 | 271 | def relatively_safe_pickle_dump(obj, path, compression=False): 272 | """This is just like regular pickle dump, except from the fact that failure cases are 273 | different: 274 | 275 | - It's never possible that we end up with a pickle in corrupted state. 276 | - If a there was a different file at the path, that file will remain unchanged in the 277 | even of failure (provided that filesystem rename is atomic). 278 | - it is sometimes possible that we end up with useless temp file which needs to be 279 | deleted manually (it will be removed automatically on the next function call) 280 | 281 | The indended use case is periodic checkpoints of experiment state, such that we never 282 | corrupt previous checkpoints if the current one fails. 283 | 284 | Parameters 285 | ---------- 286 | obj: object 287 | object to pickle 288 | path: str 289 | path to the output file 290 | compression: bool 291 | if true pickle will be compressed 292 | """ 293 | temp_storage = path + ".relatively_safe" 294 | if compression: 295 | # Using gzip here would be simpler, but the size is limited to 2GB 296 | with tempfile.NamedTemporaryFile() as uncompressed_file: 297 | pickle.dump(obj, uncompressed_file) 298 | with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip: 299 | myzip.write(uncompressed_file.name, "data") 300 | else: 301 | with open(temp_storage, "wb") as f: 302 | pickle.dump(obj, f) 303 | os.rename(temp_storage, path) 304 | 305 | 306 | def pickle_load(path, compression=False): 307 | """Unpickle a possible compressed pickle. 308 | 309 | Parameters 310 | ---------- 311 | path: str 312 | path to the output file 313 | compression: bool 314 | if true assumes that pickle was compressed when created and attempts decompression. 
315 | 316 | Returns 317 | ------- 318 | obj: object 319 | the unpickled object 320 | """ 321 | 322 | if compression: 323 | with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as myzip: 324 | with myzip.open("data") as f: 325 | return pickle.load(f) 326 | else: 327 | with open(path, "rb") as f: 328 | return pickle.load(f) 329 | -------------------------------------------------------------------------------- /gailtf/baselines/common/mpi_adam.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import gailtf.baselines.common.tf_util as U 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | class MpiAdam(object): 7 | def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08, scale_grad_by_procs=True, comm=None): 8 | self.var_list = var_list 9 | self.beta1 = beta1 10 | self.beta2 = beta2 11 | self.epsilon = epsilon 12 | self.scale_grad_by_procs = scale_grad_by_procs 13 | size = sum(U.numel(v) for v in var_list) 14 | self.m = np.zeros(size, 'float32') 15 | self.v = np.zeros(size, 'float32') 16 | self.t = 0 17 | self.setfromflat = U.SetFromFlat(var_list) 18 | self.getflat = U.GetFlat(var_list) 19 | self.comm = MPI.COMM_WORLD if comm is None else comm 20 | 21 | def update(self, localg, stepsize): 22 | if self.t % 100 == 0: 23 | self.check_synced() 24 | localg = localg.astype('float32') 25 | globalg = np.zeros_like(localg) 26 | self.comm.Allreduce(localg, globalg, op=MPI.SUM) 27 | if self.scale_grad_by_procs: 28 | globalg /= self.comm.Get_size() 29 | 30 | self.t += 1 31 | a = stepsize * np.sqrt(1 - self.beta2**self.t)/(1 - self.beta1**self.t) 32 | self.m = self.beta1 * self.m + (1 - self.beta1) * globalg 33 | self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg) 34 | step = (- a) * self.m / (np.sqrt(self.v) + self.epsilon) 35 | self.setfromflat(self.getflat() + step) 36 | 37 | def sync(self): 38 | theta = self.getflat() 39 | self.comm.Bcast(theta, root=0) 40 | self.setfromflat(theta) 41 | 42 | def check_synced(self): 43 | if self.comm.Get_rank() == 0: # this is root 44 | theta = self.getflat() 45 | self.comm.Bcast(theta, root=0) 46 | else: 47 | thetalocal = self.getflat() 48 | thetaroot = np.empty_like(thetalocal) 49 | self.comm.Bcast(thetaroot, root=0) 50 | assert (thetaroot == thetalocal).all(), (thetaroot, thetalocal) 51 | 52 | @U.in_session 53 | def test_MpiAdam(): 54 | np.random.seed(0) 55 | tf.set_random_seed(0) 56 | 57 | a = tf.Variable(np.random.randn(3).astype('float32')) 58 | b = tf.Variable(np.random.randn(2,5).astype('float32')) 59 | loss = tf.reduce_sum(tf.square(a)) + tf.reduce_sum(tf.sin(b)) 60 | 61 | stepsize = 1e-2 62 | update_op = tf.train.AdamOptimizer(stepsize).minimize(loss) 63 | do_update = U.function([], loss, updates=[update_op]) 64 | 65 | tf.get_default_session().run(tf.global_variables_initializer()) 66 | for i in range(10): 67 | print(i,do_update()) 68 | 69 | tf.set_random_seed(0) 70 | tf.get_default_session().run(tf.global_variables_initializer()) 71 | 72 | var_list = [a,b] 73 | lossandgrad = U.function([], [loss, U.flatgrad(loss, var_list)], updates=[update_op]) 74 | adam = MpiAdam(var_list) 75 | 76 | for i in range(10): 77 | l,g = lossandgrad() 78 | adam.update(g, stepsize) 79 | print(i,l) -------------------------------------------------------------------------------- /gailtf/baselines/common/mpi_fork.py: -------------------------------------------------------------------------------- 1 | import os, subprocess, sys 2 | 3 | def mpi_fork(n, bind_to_core=False): 
4 | """Re-launches the current script with workers 5 | Returns "parent" for original parent, "child" for MPI children 6 | """ 7 | if n<=1: 8 | return "child" 9 | if os.getenv("IN_MPI") is None: 10 | env = os.environ.copy() 11 | env.update( 12 | MKL_NUM_THREADS="1", 13 | OMP_NUM_THREADS="1", 14 | IN_MPI="1" 15 | ) 16 | args = ["mpirun", "-np", str(n)] 17 | if bind_to_core: 18 | args += ["-bind-to", "core"] 19 | args += [sys.executable] + sys.argv 20 | subprocess.check_call(args, env=env) 21 | return "parent" 22 | else: 23 | return "child" 24 | -------------------------------------------------------------------------------- /gailtf/baselines/common/mpi_moments.py: -------------------------------------------------------------------------------- 1 | from mpi4py import MPI 2 | import numpy as np 3 | from gailtf.baselines.common import zipsame 4 | 5 | def mpi_moments(x, axis=0): 6 | x = np.asarray(x, dtype='float64') 7 | newshape = list(x.shape) 8 | newshape.pop(axis) 9 | n = np.prod(newshape,dtype=int) 10 | totalvec = np.zeros(n*2+1, 'float64') 11 | addvec = np.concatenate([x.sum(axis=axis).ravel(), 12 | np.square(x).sum(axis=axis).ravel(), 13 | np.array([x.shape[axis]],dtype='float64')]) 14 | MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM) 15 | sum = totalvec[:n] 16 | sumsq = totalvec[n:2*n] 17 | count = totalvec[2*n] 18 | if count == 0: 19 | mean = np.empty(newshape); mean[:] = np.nan 20 | std = np.empty(newshape); std[:] = np.nan 21 | else: 22 | mean = sum/count 23 | std = np.sqrt(np.maximum(sumsq/count - np.square(mean),0)) 24 | return mean, std, count 25 | 26 | 27 | def test_runningmeanstd(): 28 | comm = MPI.COMM_WORLD 29 | np.random.seed(0) 30 | for (triple,axis) in [ 31 | ((np.random.randn(3), np.random.randn(4), np.random.randn(5)),0), 32 | ((np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),0), 33 | ((np.random.randn(2,3), np.random.randn(2,4), np.random.randn(2,4)),1), 34 | ]: 35 | 36 | 37 | x = np.concatenate(triple, axis=axis) 38 | ms1 = [x.mean(axis=axis), x.std(axis=axis), x.shape[axis]] 39 | 40 | 41 | ms2 = mpi_moments(triple[comm.Get_rank()],axis=axis) 42 | 43 | for (a1,a2) in zipsame(ms1, ms2): 44 | print(a1, a2) 45 | assert np.allclose(a1, a2) 46 | print("ok!") 47 | 48 | if __name__ == "__main__": 49 | #mpirun -np 3 python