├── .gitignore ├── LICENSE ├── README.md ├── a2c_ppo_acktr ├── __init__.py ├── algo │ ├── __init__.py │ ├── a2c_acktr.py │ ├── gail.py │ ├── kfac.py │ └── ppo.py ├── arguments.py ├── distributions.py ├── envs.py ├── model.py ├── storage.py └── utils.py ├── alphamujoco.sh ├── baselinesbc.sh ├── bc.py ├── bcgail.sh ├── bconly.sh ├── calclength.sh ├── enjoy.py ├── evaluation.py ├── gail.sh ├── gail_experts ├── README.md └── convert_to_pytorch.py ├── generate_tmux_yaml.py ├── imgs ├── a2c_beamrider.png ├── a2c_breakout.png ├── a2c_qbert.png ├── a2c_seaquest.png ├── acktr_beamrider.png ├── acktr_breakout.png ├── acktr_qbert.png ├── acktr_seaquest.png ├── ppo_halfcheetah.png ├── ppo_hopper.png ├── ppo_reacher.png └── ppo_walker.png ├── logs ├── halfcheetah │ ├── halfcheetah-0 │ │ └── 0.monitor.csv │ ├── halfcheetah-1 │ │ └── 0.monitor.csv │ ├── halfcheetah-2 │ │ └── 0.monitor.csv │ ├── halfcheetah-3 │ │ └── 0.monitor.csv │ ├── halfcheetah-4 │ │ └── 0.monitor.csv │ ├── halfcheetah-5 │ │ └── 0.monitor.csv │ ├── halfcheetah-6 │ │ └── 0.monitor.csv │ ├── halfcheetah-7 │ │ └── 0.monitor.csv │ ├── halfcheetah-8 │ │ └── 0.monitor.csv │ └── halfcheetah-9 │ │ └── 0.monitor.csv ├── hopper │ ├── hopper-0 │ │ └── 0.monitor.csv │ ├── hopper-1 │ │ └── 0.monitor.csv │ ├── hopper-2 │ │ └── 0.monitor.csv │ ├── hopper-3 │ │ └── 0.monitor.csv │ ├── hopper-4 │ │ └── 0.monitor.csv │ ├── hopper-5 │ │ └── 0.monitor.csv │ ├── hopper-6 │ │ └── 0.monitor.csv │ ├── hopper-7 │ │ └── 0.monitor.csv │ ├── hopper-8 │ │ └── 0.monitor.csv │ └── hopper-9 │ │ └── 0.monitor.csv ├── reacher │ ├── reacher-0 │ │ └── 0.monitor.csv │ ├── reacher-1 │ │ └── 0.monitor.csv │ ├── reacher-2 │ │ └── 0.monitor.csv │ ├── reacher-3 │ │ └── 0.monitor.csv │ ├── reacher-4 │ │ └── 0.monitor.csv │ ├── reacher-5 │ │ └── 0.monitor.csv │ ├── reacher-6 │ │ └── 0.monitor.csv │ ├── reacher-7 │ │ └── 0.monitor.csv │ ├── reacher-8 │ │ └── 0.monitor.csv │ └── reacher-9 │ │ └── 0.monitor.csv └── walker2d │ ├── walker2d-0 │ └── 0.monitor.csv │ ├── walker2d-1 │ └── 0.monitor.csv │ ├── walker2d-2 │ └── 0.monitor.csv │ ├── walker2d-3 │ └── 0.monitor.csv │ ├── walker2d-4 │ └── 0.monitor.csv │ ├── walker2d-5 │ └── 0.monitor.csv │ ├── walker2d-6 │ └── 0.monitor.csv │ ├── walker2d-7 │ └── 0.monitor.csv │ ├── walker2d-8 │ └── 0.monitor.csv │ └── walker2d-9 │ └── 0.monitor.csv ├── main.py ├── mujoco.sh ├── plot_graphs.py ├── plot_length_bar_graphs.py ├── plots ├── plotalpha.sh ├── plotbcgail.sh └── plotnogail.sh ├── print_experts.py ├── print_results.py ├── redsail.sh ├── requirements.txt ├── run_all.yaml ├── save.sh ├── savemaster.sh ├── scripts.txt ├── setup.py ├── time_limit_logs ├── halfcheetah │ ├── halfcheetah-0 │ │ └── 0.monitor.csv │ ├── halfcheetah-1 │ │ └── 0.monitor.csv │ ├── halfcheetah-2 │ │ └── 0.monitor.csv │ ├── halfcheetah-3 │ │ └── 0.monitor.csv │ ├── halfcheetah-4 │ │ └── 0.monitor.csv │ ├── halfcheetah-5 │ │ └── 0.monitor.csv │ ├── halfcheetah-6 │ │ └── 0.monitor.csv │ ├── halfcheetah-7 │ │ └── 0.monitor.csv │ ├── halfcheetah-8 │ │ └── 0.monitor.csv │ ├── halfcheetah-9 │ │ └── 0.monitor.csv │ ├── unfixhalfcheetah-0 │ │ └── 0.monitor.csv │ ├── unfixhalfcheetah-1 │ │ └── 0.monitor.csv │ ├── unfixhalfcheetah-2 │ │ └── 0.monitor.csv │ ├── unfixhalfcheetah-3 │ │ └── 0.monitor.csv │ ├── unfixhalfcheetah-4 │ │ └── 0.monitor.csv │ ├── unfixhalfcheetah-5 │ │ └── 0.monitor.csv │ ├── unfixhalfcheetah-6 │ │ └── 0.monitor.csv │ ├── unfixhalfcheetah-7 │ │ └── 0.monitor.csv │ ├── unfixhalfcheetah-8 │ │ └── 0.monitor.csv │ └── unfixhalfcheetah-9 │ │ └── 0.monitor.csv 
├── hopper │ ├── hopper-0 │ │ └── 0.monitor.csv │ ├── hopper-1 │ │ └── 0.monitor.csv │ ├── hopper-2 │ │ └── 0.monitor.csv │ ├── hopper-3 │ │ └── 0.monitor.csv │ ├── hopper-4 │ │ └── 0.monitor.csv │ ├── hopper-5 │ │ └── 0.monitor.csv │ ├── hopper-6 │ │ └── 0.monitor.csv │ ├── hopper-7 │ │ └── 0.monitor.csv │ ├── hopper-8 │ │ └── 0.monitor.csv │ ├── hopper-9 │ │ └── 0.monitor.csv │ ├── unfixhopper-0 │ │ └── 0.monitor.csv │ ├── unfixhopper-1 │ │ └── 0.monitor.csv │ ├── unfixhopper-2 │ │ └── 0.monitor.csv │ ├── unfixhopper-3 │ │ └── 0.monitor.csv │ ├── unfixhopper-4 │ │ └── 0.monitor.csv │ ├── unfixhopper-5 │ │ └── 0.monitor.csv │ ├── unfixhopper-6 │ │ └── 0.monitor.csv │ ├── unfixhopper-7 │ │ └── 0.monitor.csv │ ├── unfixhopper-8 │ │ └── 0.monitor.csv │ └── unfixhopper-9 │ │ └── 0.monitor.csv ├── reacher │ ├── reacher-0 │ │ └── 0.monitor.csv │ ├── reacher-1 │ │ └── 0.monitor.csv │ ├── reacher-2 │ │ └── 0.monitor.csv │ ├── reacher-3 │ │ └── 0.monitor.csv │ ├── reacher-4 │ │ └── 0.monitor.csv │ ├── reacher-5 │ │ └── 0.monitor.csv │ ├── reacher-6 │ │ └── 0.monitor.csv │ ├── reacher-7 │ │ └── 0.monitor.csv │ ├── reacher-8 │ │ └── 0.monitor.csv │ ├── reacher-9 │ │ └── 0.monitor.csv │ ├── unfixreacher-0 │ │ └── 0.monitor.csv │ ├── unfixreacher-1 │ │ └── 0.monitor.csv │ ├── unfixreacher-2 │ │ └── 0.monitor.csv │ ├── unfixreacher-3 │ │ └── 0.monitor.csv │ ├── unfixreacher-4 │ │ └── 0.monitor.csv │ ├── unfixreacher-5 │ │ └── 0.monitor.csv │ ├── unfixreacher-6 │ │ └── 0.monitor.csv │ ├── unfixreacher-7 │ │ └── 0.monitor.csv │ ├── unfixreacher-8 │ │ └── 0.monitor.csv │ └── unfixreacher-9 │ │ └── 0.monitor.csv └── walker2d │ ├── unfixwalker2d-0 │ └── 0.monitor.csv │ ├── unfixwalker2d-1 │ └── 0.monitor.csv │ ├── unfixwalker2d-2 │ └── 0.monitor.csv │ ├── unfixwalker2d-3 │ └── 0.monitor.csv │ ├── unfixwalker2d-4 │ └── 0.monitor.csv │ ├── unfixwalker2d-5 │ └── 0.monitor.csv │ ├── unfixwalker2d-6 │ └── 0.monitor.csv │ ├── unfixwalker2d-7 │ └── 0.monitor.csv │ ├── unfixwalker2d-8 │ └── 0.monitor.csv │ ├── unfixwalker2d-9 │ └── 0.monitor.csv │ ├── walker2d-0 │ └── 0.monitor.csv │ ├── walker2d-1 │ └── 0.monitor.csv │ ├── walker2d-2 │ └── 0.monitor.csv │ ├── walker2d-3 │ └── 0.monitor.csv │ ├── walker2d-4 │ └── 0.monitor.csv │ ├── walker2d-5 │ └── 0.monitor.csv │ ├── walker2d-6 │ └── 0.monitor.csv │ ├── walker2d-7 │ └── 0.monitor.csv │ ├── walker2d-8 │ └── 0.monitor.csv │ └── walker2d-9 │ └── 0.monitor.csv └── visualize.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *txt 6 | *pt 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | trained_models/ 106 | .fuse_hidden* 107 | tags 108 | *png 109 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Ilya Kostrikov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Augmenting GAIL with BC for sample efficient imitation learning 2 | 3 | 4 | Official implemention of the paper [Augmenting GAIL with BC for sample efficient imitation learning](https://arxiv.org/abs/2001.07798) in PyTorch. 5 | 6 | It builds upon the PyTorch implementation of popular RL algorithms repository (readme below). 7 | 8 | ## Installation 9 | 1. Install required packages from `requirements.txt` file. 10 | 2. Install this package with `pip install -e`. 11 | 12 | ## Reproducing results 13 | * To reproduce results for GAIL, run the `gail.sh` script. Be sure to change the default log and model paths in `a2c_ppo_acktr/arguments.py` first. 
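For reference, the machine-specific defaults mentioned above live in `a2c_ppo_acktr/arguments.py`; the excerpt below reproduces the two relevant entries (they can also be overridden per run with the `--log-dir` and `--save-dir` flags):

```python
# Excerpt from a2c_ppo_acktr/arguments.py -- point these defaults at your own
# filesystem, or pass --log-dir / --save-dir on the command line instead.
import argparse

parser = argparse.ArgumentParser(description='RL')
parser.add_argument(
    '--log-dir',
    default='/serverdata/rohit/BCGAIL/logs/',
    help='directory to save agent logs (default: /tmp/gym)')
parser.add_argument(
    '--save-dir',
    default='/serverdata/rohit/BCGAIL/',
    help='directory to save agent logs (default: ./trained_models/)')
```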
14 | The general script to run is 15 | ``` 16 | ./.sh 17 | ``` 18 | Where keyword `method` corresponds to the following Experiment/Baseline 19 | |**method**| **Experiment/Baseline** | 20 | |---|---| 21 | |gail| GAIL | 22 | |baselinesbc| BC pretraining + GAIL finetuning | 23 | |bcgail| Our method | 24 | |redsail| RED & SAIL | 25 | |alphamujoco | Ablation on effect of `\alpha` | 26 | |bcnogail | Ablation on effect of BC + untrained GAIL | 27 | 28 | Use the following `steps` for the following mujoco environments: 29 | |**Environment**| **Steps** | 30 | |---|---| 31 | |Ant-v2| 3000000 | 32 | |HalfCheetah-v2| 3000000 | 33 | |Hopper-v2| 1000000 | 34 | |Walker2d-v2| 3000000 | 35 | |Reacher-v2| 2000000 | 36 | 37 | 38 | If you like this work and want to use it in your research, consider citing our paper (and the repository if you use it - bibtex below): 39 | ``` 40 | @misc{jena2020augmenting, 41 | title={Augmenting GAIL with BC for sample efficient imitation learning}, 42 | author={Rohit Jena and Changliu Liu and Katia Sycara}, 43 | year={2020}, 44 | eprint={2001.07798}, 45 | archivePrefix={arXiv}, 46 | primaryClass={cs.LG} 47 | } 48 | ``` 49 | 50 | 51 | # pytorch-a2c-ppo-acktr 52 | 53 | ## Please use hyper parameters from this readme. With other hyper parameters things might not work (it's RL after all)! 54 | 55 | This is a PyTorch implementation of 56 | * Advantage Actor Critic (A2C), a synchronous deterministic version of [A3C](https://arxiv.org/pdf/1602.01783v1.pdf) 57 | * Proximal Policy Optimization [PPO](https://arxiv.org/pdf/1707.06347.pdf) 58 | * Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation [ACKTR](https://arxiv.org/abs/1708.05144) 59 | * Generative Adversarial Imitation Learning [GAIL](https://arxiv.org/abs/1606.03476) 60 | 61 | Also see the OpenAI posts: [A2C/ACKTR](https://blog.openai.com/baselines-acktr-a2c/) and [PPO](https://blog.openai.com/openai-baselines-ppo/) for more information. 62 | 63 | This implementation is inspired by the OpenAI baselines for [A2C](https://github.com/openai/baselines/tree/master/baselines/a2c), [ACKTR](https://github.com/openai/baselines/tree/master/baselines/acktr) and [PPO](https://github.com/openai/baselines/tree/master/baselines/ppo1). It uses the same hyper parameters and the model since they were well tuned for Atari games. 64 | 65 | Please use this bibtex if you want to cite this repository in your publications: 66 | 67 | @misc{pytorchrl, 68 | author = {Kostrikov, Ilya}, 69 | title = {PyTorch Implementations of Reinforcement Learning Algorithms}, 70 | year = {2018}, 71 | publisher = {GitHub}, 72 | journal = {GitHub repository}, 73 | howpublished = {\url{https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail}}, 74 | } 75 | 76 | ## Supported (and tested) environments (via [OpenAI Gym](https://gym.openai.com)) 77 | * [Atari Learning Environment](https://github.com/mgbellemare/Arcade-Learning-Environment) 78 | * [MuJoCo](http://mujoco.org) 79 | * [PyBullet](http://pybullet.org) (including Racecar, Minitaur and Kuka) 80 | * [DeepMind Control Suite](https://github.com/deepmind/dm_control) (via [dm_control2gym](https://github.com/martinseilair/dm_control2gym)) 81 | 82 | I highly recommend PyBullet as a free open source alternative to MuJoCo for continuous control tasks. 83 | 84 | All environments are operated using exactly the same Gym interface. See their documentations for a comprehensive list. 
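As a concrete illustration of that uniform interface, below is a minimal interaction sketch built on this repository's `make_vec_envs` helper; the environment name, process count, and random actions are arbitrary choices for the example, not a prescribed workflow:

```python
import numpy as np
import torch

from a2c_ppo_acktr.envs import make_vec_envs

# Minimal sketch: every supported environment is driven through the same
# vectorized reset()/step() interface and exchanges torch tensors.
device = torch.device("cpu")
envs = make_vec_envs("Reacher-v2", seed=1, num_processes=4, gamma=0.99,
                     log_dir=None, device=device, allow_early_resets=False)
obs = envs.reset()  # FloatTensor of shape (num_processes, obs_dim)
for _ in range(10):
    # Random actions just to exercise the step API; a trained policy goes here.
    action = torch.from_numpy(
        np.stack([envs.action_space.sample() for _ in range(4)])).float()
    obs, reward, done, infos = envs.step(action)
envs.close()
```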
85 | 86 | To use the DeepMind Control Suite environments, set the flag `--env-name dm..`, where `domain_name` and `task_name` are the name of a domain (e.g. `hopper`) and a task within that domain (e.g. `stand`) from the DeepMind Control Suite. Refer to their repo and their [tech report](https://arxiv.org/abs/1801.00690) for a full list of available domains and tasks. Other than setting the task, the API for interacting with the environment is exactly the same as for all the Gym environments thanks to [dm_control2gym](https://github.com/martinseilair/dm_control2gym). 87 | 88 | ## Requirements 89 | 90 | * Python 3 (it might work with Python 2, but I didn't test it) 91 | * [PyTorch](http://pytorch.org/) 92 | * [OpenAI baselines](https://github.com/openai/baselines) 93 | 94 | In order to install requirements, follow: 95 | 96 | ```bash 97 | # PyTorch 98 | conda install pytorch torchvision -c soumith 99 | 100 | # Baselines for Atari preprocessing 101 | git clone https://github.com/openai/baselines.git 102 | cd baselines 103 | pip install -e . 104 | 105 | # Other requirements 106 | pip install -r requirements.txt 107 | ``` 108 | 109 | ## Contributions 110 | 111 | Contributions are very welcome. If you know how to make this code better, please open an issue. If you want to submit a pull request, please open an issue first. Also see a todo list below. 112 | 113 | Also I'm searching for volunteers to run all experiments on Atari and MuJoCo (with multiple random seeds). 114 | 115 | ## Disclaimer 116 | 117 | It's extremely difficult to reproduce results for Reinforcement Learning methods. See ["Deep Reinforcement Learning that Matters"](https://arxiv.org/abs/1709.06560) for more information. I tried to reproduce OpenAI results as closely as possible. However, majors differences in performance can be caused even by minor differences in TensorFlow and PyTorch libraries. 118 | 119 | ### TODO 120 | * Improve this README file. Rearrange images. 121 | * Improve performance of KFAC, see kfac.py for more information 122 | * Run evaluation for all games and algorithms 123 | 124 | ## Visualization 125 | 126 | In order to visualize the results use ```visualize.ipynb```. 127 | 128 | 129 | ## Training 130 | 131 | ### Atari 132 | #### A2C 133 | 134 | ```bash 135 | python main.py --env-name "PongNoFrameskip-v4" 136 | ``` 137 | 138 | #### PPO 139 | 140 | ```bash 141 | python main.py --env-name "PongNoFrameskip-v4" --algo ppo --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 142 | ``` 143 | 144 | #### ACKTR 145 | 146 | ```bash 147 | python main.py --env-name "PongNoFrameskip-v4" --algo acktr --num-processes 32 --num-steps 20 148 | ``` 149 | 150 | ### MuJoCo 151 | 152 | Please always try to use ```--use-proper-time-limits``` flag. It properly handles partial trajectories (see https://github.com/sfujim/TD3/blob/master/main.py#L123). 
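To see what the flag changes, here is an illustrative sketch (not the repository's actual `storage.py` code) of how a properly handled timeout differs from a true terminal state; `bad_transition` stands for the flag that `TimeLimitMask` in `a2c_ppo_acktr/envs.py` sets when an episode ends only because the time limit was reached:

```python
# Illustrative only: a time-limit cutoff is not a real terminal state, so the
# return should still bootstrap from the critic's estimate of the next state.
def one_step_return(reward, next_value, done, bad_transition, gamma=0.99):
    if done and not bad_transition:
        return reward                      # genuine terminal state: no bootstrap
    return reward + gamma * next_value     # timeout (or not done yet): bootstrap
```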
153 | 154 | #### A2C 155 | 156 | ```bash 157 | python main.py --env-name "Reacher-v2" --num-env-steps 1000000 158 | ``` 159 | 160 | #### PPO 161 | 162 | ```bash 163 | python main.py --env-name "Reacher-v2" --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps 1000000 --use-linear-lr-decay --use-proper-time-limits 164 | ``` 165 | 166 | #### ACKTR 167 | 168 | ACKTR requires some modifications to be made specifically for MuJoCo. But at the moment, I want to keep this code as unified as possible. Thus, I'm going for better ways to integrate it into the codebase. 169 | 170 | ## Enjoy 171 | 172 | Load a pretrained model from [my Google Drive](https://drive.google.com/open?id=0Bw49qC_cgohKS3k2OWpyMWdzYkk). 173 | 174 | Also pretrained models for other games are available on request. Send me an email or create an issue, and I will upload it. 175 | 176 | Disclaimer: I might have used different hyper-parameters to train these models. 177 | 178 | ### Atari 179 | 180 | ```bash 181 | python enjoy.py --load-dir trained_models/a2c --env-name "PongNoFrameskip-v4" 182 | ``` 183 | 184 | ### MuJoCo 185 | 186 | ```bash 187 | python enjoy.py --load-dir trained_models/ppo --env-name "Reacher-v2" 188 | ``` 189 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/a2c_ppo_acktr/__init__.py -------------------------------------------------------------------------------- /a2c_ppo_acktr/algo/__init__.py: -------------------------------------------------------------------------------- 1 | from .a2c_acktr import A2C_ACKTR 2 | from .ppo import PPO -------------------------------------------------------------------------------- /a2c_ppo_acktr/algo/a2c_acktr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | from a2c_ppo_acktr.algo.kfac import KFACOptimizer 6 | 7 | 8 | class A2C_ACKTR(): 9 | def __init__(self, 10 | actor_critic, 11 | value_loss_coef, 12 | entropy_coef, 13 | lr=None, 14 | eps=None, 15 | alpha=None, 16 | max_grad_norm=None, 17 | acktr=False): 18 | 19 | self.actor_critic = actor_critic 20 | self.acktr = acktr 21 | 22 | self.value_loss_coef = value_loss_coef 23 | self.entropy_coef = entropy_coef 24 | 25 | self.max_grad_norm = max_grad_norm 26 | 27 | if acktr: 28 | self.optimizer = KFACOptimizer(actor_critic) 29 | else: 30 | self.optimizer = optim.RMSprop( 31 | actor_critic.parameters(), lr, eps=eps, alpha=alpha) 32 | 33 | def update(self, rollouts): 34 | obs_shape = rollouts.obs.size()[2:] 35 | action_shape = rollouts.actions.size()[-1] 36 | num_steps, num_processes, _ = rollouts.rewards.size() 37 | 38 | values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions( 39 | rollouts.obs[:-1].view(-1, *obs_shape), 40 | rollouts.recurrent_hidden_states[0].view( 41 | -1, self.actor_critic.recurrent_hidden_state_size), 42 | rollouts.masks[:-1].view(-1, 1), 43 | rollouts.actions.view(-1, action_shape)) 44 | 45 | values = values.view(num_steps, num_processes, 1) 46 | action_log_probs = action_log_probs.view(num_steps, num_processes, 1) 47 | 48 | advantages = rollouts.returns[:-1] - values 49 | 
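# (added note) A2C loss terms computed from the rollout below: the critic is
# regressed onto the empirical returns, so `advantages` doubles as the value
# error, while the actor maximizes log-probability weighted by the detached
# advantage.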
value_loss = advantages.pow(2).mean() 50 | 51 | action_loss = -(advantages.detach() * action_log_probs).mean() 52 | 53 | if self.acktr and self.optimizer.steps % self.optimizer.Ts == 0: 54 | # Compute fisher, see Martens 2014 55 | self.actor_critic.zero_grad() 56 | pg_fisher_loss = -action_log_probs.mean() 57 | 58 | value_noise = torch.randn(values.size()) 59 | if values.is_cuda: 60 | value_noise = value_noise.cuda() 61 | 62 | sample_values = values + value_noise 63 | vf_fisher_loss = -(values - sample_values.detach()).pow(2).mean() 64 | 65 | fisher_loss = pg_fisher_loss + vf_fisher_loss 66 | self.optimizer.acc_stats = True 67 | fisher_loss.backward(retain_graph=True) 68 | self.optimizer.acc_stats = False 69 | 70 | self.optimizer.zero_grad() 71 | (value_loss * self.value_loss_coef + action_loss - 72 | dist_entropy * self.entropy_coef).backward() 73 | 74 | if self.acktr == False: 75 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 76 | self.max_grad_norm) 77 | 78 | self.optimizer.step() 79 | 80 | return value_loss.item(), action_loss.item(), dist_entropy.item() 81 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/algo/gail.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.data 7 | from torch import autograd 8 | import gym 9 | from baselines.common.running_mean_std import RunningMeanStd 10 | 11 | 12 | class RED(nn.Module): 13 | def __init__(self, input_dim, hidden_dim, device, sigma, iters): 14 | super().__init__() 15 | self.device = device 16 | self.sigma = sigma 17 | self.iters = iters 18 | 19 | # This is a random initialization, used to learn 20 | self.dummytrunk = nn.Sequential( 21 | nn.Linear(input_dim, hidden_dim), nn.Tanh(), 22 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), 23 | nn.Linear(hidden_dim, 1) 24 | ).to(device) 25 | 26 | self.trunk = nn.Sequential( 27 | nn.Linear(input_dim, hidden_dim), nn.Tanh(), 28 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), 29 | nn.Linear(hidden_dim, 1) 30 | ).to(device) 31 | 32 | self.trunk.train() 33 | self.optimizer = torch.optim.Adam(self.trunk.parameters()) 34 | 35 | def train_red(self, expert_loader): 36 | # Train the loader 37 | self.train() 38 | for _ in range(self.iters): 39 | for expert_batch in expert_loader: 40 | # Get expert state and actions 41 | expert_state, expert_action = expert_batch 42 | expert_state = torch.FloatTensor(expert_state).to(self.device) 43 | expert_action = expert_action.to(self.device) 44 | 45 | # Given expert state and action 46 | expert_sa = torch.cat([expert_state, expert_action], dim=1) 47 | fsa = self.trunk(expert_sa) 48 | with torch.no_grad(): 49 | fsa_random = self.dummytrunk(expert_sa) 50 | 51 | loss = ((fsa - fsa_random)**2).mean() 52 | 53 | self.optimizer.zero_grad() 54 | loss.backward() 55 | self.optimizer.step() 56 | print("RED loss: {}".format(loss.data.cpu().numpy())) 57 | 58 | def predict_reward(self, state, action, obfilt=None): 59 | with torch.no_grad(): 60 | self.eval() 61 | if obfilt is not None: 62 | s = obfilt(state.cpu().numpy()) 63 | s = torch.FloatTensor(s).to(action.device) 64 | else: 65 | s = state 66 | d = torch.cat([s, action], dim=1) 67 | fsa = self.trunk(d) 68 | fsa_random = self.dummytrunk(d) 69 | rew = torch.exp(-self.sigma * ((fsa - fsa_random)**2).mean(1))[:, None] 70 | return rew 71 | 72 | 73 | 74 | class Discriminator(nn.Module): 75 | def 
__init__(self, input_dim, hidden_dim, device, red=None, sail=False, learn=True): 76 | super(Discriminator, self).__init__() 77 | 78 | self.device = device 79 | 80 | self.red = red 81 | self.sail = sail 82 | self.redtrained = False 83 | if self.sail: 84 | assert self.red is not None, 'Cannot run SAIL without using RED' 85 | 86 | self.trunk = nn.Sequential( 87 | nn.Linear(input_dim, hidden_dim), nn.Tanh(), 88 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), 89 | nn.Linear(hidden_dim, 1)).to(device) 90 | 91 | self.trunk.train() 92 | 93 | self.learn = learn 94 | self.optimizer = torch.optim.Adam(self.trunk.parameters()) 95 | 96 | self.returns = None 97 | self.ret_rms = RunningMeanStd(shape=()) 98 | 99 | def compute_grad_pen(self, 100 | expert_state, 101 | expert_action, 102 | policy_state, 103 | policy_action, 104 | lambda_=10): 105 | alpha = torch.rand(expert_state.size(0), 1) 106 | expert_data = torch.cat([expert_state, expert_action], dim=1) 107 | policy_data = torch.cat([policy_state, policy_action], dim=1) 108 | 109 | alpha = alpha.expand_as(expert_data).to(expert_data.device) 110 | 111 | mixup_data = alpha * expert_data + (1 - alpha) * policy_data 112 | mixup_data.requires_grad = True 113 | 114 | disc = self.trunk(mixup_data) 115 | ones = torch.ones(disc.size()).to(disc.device) 116 | grad = autograd.grad( 117 | outputs=disc, 118 | inputs=mixup_data, 119 | grad_outputs=ones, 120 | create_graph=True, 121 | retain_graph=True, 122 | only_inputs=True)[0] 123 | 124 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean() 125 | return grad_pen 126 | 127 | def update(self, expert_loader, rollouts, obsfilt=None): 128 | self.train() 129 | if obsfilt is None: 130 | obsfilt = lambda x,y : x 131 | 132 | # If RED is untrained, then train it 133 | if self.red is not None and not self.redtrained: 134 | print("Training RED...") 135 | self.red.train_red(expert_loader) # obsfilt keeps changing after that, Pass the obsfilt to reverse normalized states 136 | self.redtrained = True 137 | print("Trained RED.") 138 | 139 | # If there is no SAIL but RED is present, 140 | # then GAIL doesn't need to be updated 141 | if self.red is not None and not self.sail: 142 | return 0 143 | 144 | policy_data_generator = rollouts.feed_forward_generator( 145 | None, mini_batch_size=expert_loader.batch_size) 146 | 147 | loss = 0 148 | n = 0 149 | for expert_batch, policy_batch in zip(expert_loader, 150 | policy_data_generator): 151 | policy_state, policy_action = policy_batch[0], policy_batch[2] 152 | policy_d = self.trunk( 153 | torch.cat([policy_state, policy_action], dim=1)) 154 | 155 | expert_state, expert_action = expert_batch 156 | expert_state = obsfilt(expert_state.numpy(), update=False) 157 | expert_state = torch.FloatTensor(expert_state).to(self.device) 158 | expert_action = expert_action.to(self.device) 159 | expert_d = self.trunk( 160 | torch.cat([expert_state, expert_action], dim=1)) 161 | 162 | expert_loss = F.binary_cross_entropy_with_logits( 163 | expert_d, 164 | torch.ones(expert_d.size()).to(self.device)) 165 | policy_loss = F.binary_cross_entropy_with_logits( 166 | policy_d, 167 | torch.zeros(policy_d.size()).to(self.device)) 168 | 169 | gail_loss = expert_loss + policy_loss 170 | grad_pen = self.compute_grad_pen(expert_state, expert_action, 171 | policy_state, policy_action) 172 | 173 | loss += (gail_loss + grad_pen).item() 174 | n += 1 175 | 176 | if self.learn: 177 | self.optimizer.zero_grad() 178 | (gail_loss + grad_pen).backward() 179 | self.optimizer.step() 180 | return loss / n 181 | 182 | def 
predict_reward(self, state, action, gamma, masks, update_rms=True, obsfilt=None): 183 | with torch.no_grad(): 184 | self.eval() 185 | d = self.trunk(torch.cat([state, action], dim=1)) 186 | s = torch.sigmoid(d) 187 | # Get RED reward 188 | if self.red is not None: 189 | assert self.redtrained 190 | red_rew = self.red.predict_reward(state, action, obsfilt) 191 | 192 | # Check if SAIL is present or not 193 | if self.sail: 194 | reward = s * red_rew 195 | else: 196 | reward = red_rew 197 | else: 198 | # If traditional GAIL 199 | #reward = s.log() - (1 - s).log() 200 | reward = - (1 - s).log() 201 | 202 | if self.returns is None: 203 | self.returns = reward.clone() 204 | 205 | if update_rms: 206 | self.returns = self.returns * masks * gamma + reward 207 | self.ret_rms.update(self.returns.cpu().numpy()) 208 | 209 | return reward / np.sqrt(self.ret_rms.var[0] + 1e-8) 210 | 211 | 212 | class Flatten(nn.Module): 213 | def forward(self, x): 214 | return x.view(x.size(0), -1) 215 | 216 | class CNNDiscriminator(nn.Module): 217 | def __init__(self, input_shape, action_space, hidden_dim, device, clip=0.01): 218 | super(CNNDiscriminator, self).__init__() 219 | self.device = device 220 | C, H, W = input_shape 221 | self.n = 0 222 | if type(action_space) == gym.spaces.box.Box: 223 | A = action_space.shape[0] 224 | else: 225 | A = action_space.n 226 | self.n = A 227 | 228 | self.main = nn.Sequential( 229 | nn.Conv2d(C, 32, 4, stride=2), nn.ReLU(), 230 | nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(), 231 | nn.Conv2d(64, 128, 4, stride=2), nn.ReLU(), 232 | nn.Conv2d(128, 256, 4, stride=2), nn.ReLU(), Flatten(), 233 | ).to(device) 234 | self.clip = clip 235 | print("Using clip {}".format(self.clip)) 236 | 237 | for i in range(4): 238 | H = (H - 4)//2 + 1 239 | W = (W - 4)//2 + 1 240 | # Get image dim 241 | img_dim = 256*H*W 242 | 243 | self.trunk = nn.Sequential( 244 | nn.Linear(A + img_dim, hidden_dim), nn.Tanh(), 245 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), 246 | nn.Linear(hidden_dim, 1)).to(device) 247 | 248 | self.main.train() 249 | self.trunk.train() 250 | 251 | self.optimizer = torch.optim.Adam(list(self.main.parameters()) + list(self.trunk.parameters())) 252 | self.returns = None 253 | self.ret_rms = RunningMeanStd(shape=()) 254 | 255 | def compute_grad_pen(self, 256 | expert_state, 257 | expert_action, 258 | policy_state, 259 | policy_action, 260 | lambda_=10): 261 | grad_pen = 0 262 | if True: 263 | alpha = torch.rand(expert_state.size(0), 1) 264 | 265 | # Change state values 266 | exp_state = self.main(expert_state) 267 | pol_state = self.main(policy_state) 268 | 269 | expert_data = torch.cat([exp_state, expert_action], dim=1) 270 | policy_data = torch.cat([pol_state, policy_action], dim=1) 271 | 272 | alpha = alpha.expand_as(expert_data).to(expert_data.device) 273 | 274 | mixup_data = alpha * expert_data + (1 - alpha) * policy_data 275 | 276 | disc = self.trunk(mixup_data) 277 | ones = torch.ones(disc.size()).to(disc.device) 278 | grad = autograd.grad( 279 | outputs=disc, 280 | inputs=mixup_data, 281 | grad_outputs=ones, 282 | create_graph=True, 283 | retain_graph=True, 284 | only_inputs=True)[0] 285 | 286 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean() 287 | return grad_pen 288 | 289 | def update(self, expert_loader, rollouts, obsfilt=None): 290 | self.train() 291 | assert obsfilt is None 292 | 293 | policy_data_generator = rollouts.feed_forward_generator( 294 | None, mini_batch_size=expert_loader.batch_size) 295 | 296 | loss = 0 297 | n = 0 298 | for expert_batch, policy_batch 
in zip(expert_loader, 299 | policy_data_generator): 300 | policy_state, policy_action = policy_batch[0], policy_batch[2] 301 | 302 | if self.n > 0: 303 | act = torch.zeros(policy_action.shape[0], self.n) 304 | polact = policy_action.squeeze() 305 | act[np.arange(polact.shape[0]), polact] = 1 306 | policy_action = torch.FloatTensor(act).to(policy_action.device) 307 | #print('policy', policy_action.shape) 308 | 309 | pol_state = self.main(policy_state) 310 | policy_d = self.trunk( 311 | torch.cat([pol_state, policy_action], dim=1)) 312 | 313 | expert_state, expert_action = expert_batch 314 | #print('expert', expert_action.shape) 315 | expert_state = torch.FloatTensor(expert_state).to(self.device) 316 | expert_action = expert_action.to(self.device) 317 | exp_state = self.main(expert_state) 318 | expert_d = self.trunk( 319 | torch.cat([exp_state, expert_action], dim=1)) 320 | 321 | expert_loss = F.binary_cross_entropy_with_logits( 322 | expert_d, 323 | torch.ones(expert_d.size()).to(self.device)) 324 | policy_loss = F.binary_cross_entropy_with_logits( 325 | policy_d, 326 | torch.zeros(policy_d.size()).to(self.device)) 327 | #expert_loss = -expert_d.mean().to(self.device) 328 | #policy_loss = policy_d.mean().to(self.device) 329 | 330 | gail_loss = expert_loss + policy_loss 331 | grad_pen = self.compute_grad_pen(expert_state, expert_action, 332 | policy_state, policy_action) 333 | 334 | loss += (gail_loss + grad_pen).item() 335 | n += 1 336 | 337 | self.optimizer.zero_grad() 338 | (gail_loss + grad_pen).backward() 339 | self.optimizer.step() 340 | 341 | # Clip params here 342 | #for p in self.parameters(): 343 | #p = p.clamp(-self.clip, self.clip) 344 | 345 | return loss / n 346 | 347 | def predict_reward(self, state, action, gamma, masks, update_rms=True): 348 | with torch.no_grad(): 349 | self.eval() 350 | if self.n > 0: 351 | acts = torch.zeros((action.shape[0], self.n)) 352 | acts[np.arange(action.shape[0]), action.squeeze()] = 1 353 | acts = torch.FloatTensor(acts).to(action.device) 354 | else: 355 | acts = action 356 | 357 | stat = self.main(state) 358 | d = self.trunk(torch.cat([stat, acts], dim=1)) 359 | s = torch.sigmoid(d) 360 | reward = -(1 - s).log() 361 | #reward = d / self.clip 362 | if self.returns is None: 363 | self.returns = reward.clone() 364 | 365 | if update_rms: 366 | self.returns = self.returns * masks * gamma + reward 367 | self.ret_rms.update(self.returns.cpu().numpy()) 368 | 369 | return reward / np.sqrt(self.ret_rms.var[0] + 1e-8) 370 | 371 | 372 | class ExpertImageDataset(torch.utils.data.Dataset): 373 | def __init__(self, file_name, train=None, act=None): 374 | trajs = torch.load(file_name) 375 | self.observations = trajs['obs'] 376 | self.actions = trajs['actions'] 377 | self.train = train 378 | self.act = None 379 | if isinstance(act, gym.spaces.Discrete): 380 | self.act = act.n 381 | 382 | self.actual_obs = [None for _ in range(len(self.actions))] 383 | self.lenn = 0 384 | if train is not None: 385 | lenn = int(0.8*len(self.actions)) 386 | self.lenn = lenn 387 | if train: 388 | self.actions = self.actions[:lenn] 389 | else: 390 | self.actions = self.actions[lenn:] 391 | 392 | def __len__(self, ): 393 | return len(self.actions) 394 | 395 | def __getitem__(self, idx): 396 | action = self.actions[idx] 397 | if self.act: 398 | act = np.zeros((self.act, )) 399 | act[action[0]] = 1 400 | action = act 401 | # Load only the first time, images in uint8 are supposed to be light 402 | if self.actual_obs[idx] is None: 403 | if self.train == False: 404 | image = 
np.load(self.observations[idx + self.lenn] + '.npy') 405 | else: 406 | image = np.load(self.observations[idx] + '.npy') 407 | self.actual_obs[idx] = image 408 | else: 409 | image = self.actual_obs[idx] 410 | # rescale image and pass it 411 | img = image / 255.0 412 | img = img.transpose(2, 0, 1) 413 | # [C, H, W ] image and [A] actions 414 | return torch.FloatTensor(img), torch.FloatTensor(action) 415 | 416 | 417 | class ExpertDataset(torch.utils.data.Dataset): 418 | def __init__(self, file_name, num_trajectories=4, subsample_frequency=20, train=True, start=0): 419 | all_trajectories = torch.load(file_name) 420 | 421 | perm = torch.randperm(all_trajectories['states'].size(0)) 422 | #idx = perm[:num_trajectories] 423 | idx = np.arange(num_trajectories) + start 424 | if not train: 425 | assert start > 0 426 | 427 | self.trajectories = {} 428 | 429 | # See https://github.com/pytorch/pytorch/issues/14886 430 | # .long() for fixing bug in torch v0.4.1 431 | start_idx = torch.randint( 432 | 0, subsample_frequency, size=(num_trajectories, )).long() 433 | 434 | for k, v in all_trajectories.items(): 435 | data = v[idx] 436 | 437 | if k != 'lengths': 438 | samples = [] 439 | for i in range(num_trajectories): 440 | samples.append(data[i, start_idx[i]::subsample_frequency]) 441 | self.trajectories[k] = torch.stack(samples) 442 | else: 443 | self.trajectories[k] = data // subsample_frequency 444 | 445 | self.i2traj_idx = {} 446 | self.i2i = {} 447 | 448 | self.length = self.trajectories['lengths'].sum().item() 449 | 450 | traj_idx = 0 451 | i = 0 452 | 453 | self.get_idx = [] 454 | 455 | for j in range(self.length): 456 | 457 | while self.trajectories['lengths'][traj_idx].item() <= i: 458 | i -= self.trajectories['lengths'][traj_idx].item() 459 | traj_idx += 1 460 | 461 | self.get_idx.append((traj_idx, i)) 462 | 463 | i += 1 464 | 465 | 466 | def __len__(self): 467 | return self.length 468 | 469 | def __getitem__(self, i): 470 | traj_idx, i = self.get_idx[i] 471 | 472 | return self.trajectories['states'][traj_idx][i], self.trajectories[ 473 | 'actions'][traj_idx][i] 474 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/algo/kfac.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | from a2c_ppo_acktr.utils import AddBias 9 | 10 | # TODO: In order to make this code faster: 11 | # 1) Implement _extract_patches as a single cuda kernel 12 | # 2) Compute QR decomposition in a separate process 13 | # 3) Actually make a general KFAC optimizer so it fits PyTorch 14 | 15 | 16 | def _extract_patches(x, kernel_size, stride, padding): 17 | if padding[0] + padding[1] > 0: 18 | x = F.pad(x, (padding[1], padding[1], padding[0], 19 | padding[0])).data # Actually check dims 20 | x = x.unfold(2, kernel_size[0], stride[0]) 21 | x = x.unfold(3, kernel_size[1], stride[1]) 22 | x = x.transpose_(1, 2).transpose_(2, 3).contiguous() 23 | x = x.view( 24 | x.size(0), x.size(1), x.size(2), 25 | x.size(3) * x.size(4) * x.size(5)) 26 | return x 27 | 28 | 29 | def compute_cov_a(a, classname, layer_info, fast_cnn): 30 | batch_size = a.size(0) 31 | 32 | if classname == 'Conv2d': 33 | if fast_cnn: 34 | a = _extract_patches(a, *layer_info) 35 | a = a.view(a.size(0), -1, a.size(-1)) 36 | a = a.mean(1) 37 | else: 38 | a = _extract_patches(a, *layer_info) 39 | a = a.view(-1, a.size(-1)).div_(a.size(1)).div_(a.size(2)) 40 | 
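# (added note) K-FAC treats a bias as a weight on a constant input of 1, which
# is why the AddBias branch below uses a vector of ones: the input-covariance
# factor A = a^T a / batch_size then reduces to a 1x1 entry.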
elif classname == 'AddBias': 41 | is_cuda = a.is_cuda 42 | a = torch.ones(a.size(0), 1) 43 | if is_cuda: 44 | a = a.cuda() 45 | 46 | return a.t() @ (a / batch_size) 47 | 48 | 49 | def compute_cov_g(g, classname, layer_info, fast_cnn): 50 | batch_size = g.size(0) 51 | 52 | if classname == 'Conv2d': 53 | if fast_cnn: 54 | g = g.view(g.size(0), g.size(1), -1) 55 | g = g.sum(-1) 56 | else: 57 | g = g.transpose(1, 2).transpose(2, 3).contiguous() 58 | g = g.view(-1, g.size(-1)).mul_(g.size(1)).mul_(g.size(2)) 59 | elif classname == 'AddBias': 60 | g = g.view(g.size(0), g.size(1), -1) 61 | g = g.sum(-1) 62 | 63 | g_ = g * batch_size 64 | return g_.t() @ (g_ / g.size(0)) 65 | 66 | 67 | def update_running_stat(aa, m_aa, momentum): 68 | # Do the trick to keep aa unchanged and not create any additional tensors 69 | m_aa *= momentum / (1 - momentum) 70 | m_aa += aa 71 | m_aa *= (1 - momentum) 72 | 73 | 74 | class SplitBias(nn.Module): 75 | def __init__(self, module): 76 | super(SplitBias, self).__init__() 77 | self.module = module 78 | self.add_bias = AddBias(module.bias.data) 79 | self.module.bias = None 80 | 81 | def forward(self, input): 82 | x = self.module(input) 83 | x = self.add_bias(x) 84 | return x 85 | 86 | 87 | class KFACOptimizer(optim.Optimizer): 88 | def __init__(self, 89 | model, 90 | lr=0.25, 91 | momentum=0.9, 92 | stat_decay=0.99, 93 | kl_clip=0.001, 94 | damping=1e-2, 95 | weight_decay=0, 96 | fast_cnn=False, 97 | Ts=1, 98 | Tf=10): 99 | defaults = dict() 100 | 101 | def split_bias(module): 102 | for mname, child in module.named_children(): 103 | if hasattr(child, 'bias') and child.bias is not None: 104 | module._modules[mname] = SplitBias(child) 105 | else: 106 | split_bias(child) 107 | 108 | split_bias(model) 109 | 110 | super(KFACOptimizer, self).__init__(model.parameters(), defaults) 111 | 112 | self.known_modules = {'Linear', 'Conv2d', 'AddBias'} 113 | 114 | self.modules = [] 115 | self.grad_outputs = {} 116 | 117 | self.model = model 118 | self._prepare_model() 119 | 120 | self.steps = 0 121 | 122 | self.m_aa, self.m_gg = {}, {} 123 | self.Q_a, self.Q_g = {}, {} 124 | self.d_a, self.d_g = {}, {} 125 | 126 | self.momentum = momentum 127 | self.stat_decay = stat_decay 128 | 129 | self.lr = lr 130 | self.kl_clip = kl_clip 131 | self.damping = damping 132 | self.weight_decay = weight_decay 133 | 134 | self.fast_cnn = fast_cnn 135 | 136 | self.Ts = Ts 137 | self.Tf = Tf 138 | 139 | self.optim = optim.SGD( 140 | model.parameters(), 141 | lr=self.lr * (1 - self.momentum), 142 | momentum=self.momentum) 143 | 144 | def _save_input(self, module, input): 145 | if torch.is_grad_enabled() and self.steps % self.Ts == 0: 146 | classname = module.__class__.__name__ 147 | layer_info = None 148 | if classname == 'Conv2d': 149 | layer_info = (module.kernel_size, module.stride, 150 | module.padding) 151 | 152 | aa = compute_cov_a(input[0].data, classname, layer_info, 153 | self.fast_cnn) 154 | 155 | # Initialize buffers 156 | if self.steps == 0: 157 | self.m_aa[module] = aa.clone() 158 | 159 | update_running_stat(aa, self.m_aa[module], self.stat_decay) 160 | 161 | def _save_grad_output(self, module, grad_input, grad_output): 162 | # Accumulate statistics for Fisher matrices 163 | if self.acc_stats: 164 | classname = module.__class__.__name__ 165 | layer_info = None 166 | if classname == 'Conv2d': 167 | layer_info = (module.kernel_size, module.stride, 168 | module.padding) 169 | 170 | gg = compute_cov_g(grad_output[0].data, classname, layer_info, 171 | self.fast_cnn) 172 | 173 | # Initialize 
buffers 174 | if self.steps == 0: 175 | self.m_gg[module] = gg.clone() 176 | 177 | update_running_stat(gg, self.m_gg[module], self.stat_decay) 178 | 179 | def _prepare_model(self): 180 | for module in self.model.modules(): 181 | classname = module.__class__.__name__ 182 | if classname in self.known_modules: 183 | assert not ((classname in ['Linear', 'Conv2d']) and module.bias is not None), \ 184 | "You must have a bias as a separate layer" 185 | 186 | self.modules.append(module) 187 | module.register_forward_pre_hook(self._save_input) 188 | module.register_backward_hook(self._save_grad_output) 189 | 190 | def step(self): 191 | # Add weight decay 192 | if self.weight_decay > 0: 193 | for p in self.model.parameters(): 194 | p.grad.data.add_(self.weight_decay, p.data) 195 | 196 | updates = {} 197 | for i, m in enumerate(self.modules): 198 | assert len(list(m.parameters()) 199 | ) == 1, "Can handle only one parameter at the moment" 200 | classname = m.__class__.__name__ 201 | p = next(m.parameters()) 202 | 203 | la = self.damping + self.weight_decay 204 | 205 | if self.steps % self.Tf == 0: 206 | # My asynchronous implementation exists, I will add it later. 207 | # Experimenting with different ways to this in PyTorch. 208 | self.d_a[m], self.Q_a[m] = torch.symeig( 209 | self.m_aa[m], eigenvectors=True) 210 | self.d_g[m], self.Q_g[m] = torch.symeig( 211 | self.m_gg[m], eigenvectors=True) 212 | 213 | self.d_a[m].mul_((self.d_a[m] > 1e-6).float()) 214 | self.d_g[m].mul_((self.d_g[m] > 1e-6).float()) 215 | 216 | if classname == 'Conv2d': 217 | p_grad_mat = p.grad.data.view(p.grad.data.size(0), -1) 218 | else: 219 | p_grad_mat = p.grad.data 220 | 221 | v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m] 222 | v2 = v1 / ( 223 | self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + la) 224 | v = self.Q_g[m] @ v2 @ self.Q_a[m].t() 225 | 226 | v = v.view(p.grad.data.size()) 227 | updates[p] = v 228 | 229 | vg_sum = 0 230 | for p in self.model.parameters(): 231 | v = updates[p] 232 | vg_sum += (v * p.grad.data * self.lr * self.lr).sum() 233 | 234 | nu = min(1, math.sqrt(self.kl_clip / vg_sum)) 235 | 236 | for p in self.model.parameters(): 237 | v = updates[p] 238 | p.grad.data.copy_(v) 239 | p.grad.data.mul_(nu) 240 | 241 | self.optim.step() 242 | self.steps += 1 243 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/algo/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | 7 | import gym 8 | 9 | 10 | class PPO(): 11 | def __init__(self, 12 | actor_critic, 13 | clip_param, 14 | ppo_epoch, 15 | num_mini_batch, 16 | value_loss_coef, 17 | entropy_coef, 18 | lr=None, 19 | eps=None, 20 | max_grad_norm=None, 21 | use_clipped_value_loss=True, 22 | gamma=None, 23 | decay=None, 24 | act_space=None, 25 | ): 26 | 27 | self.actor_critic = actor_critic 28 | 29 | self.clip_param = clip_param 30 | self.ppo_epoch = ppo_epoch 31 | self.num_mini_batch = num_mini_batch 32 | self.act_space = act_space 33 | print(self.act_space) 34 | 35 | self.value_loss_coef = value_loss_coef 36 | self.entropy_coef = entropy_coef 37 | 38 | self.gamma = gamma 39 | self.decay = decay 40 | 41 | self.max_grad_norm = max_grad_norm 42 | self.use_clipped_value_loss = use_clipped_value_loss 43 | 44 | self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps) 45 | 46 | def update_bc(self, expert_state, 
expert_actions, obfilt=None): 47 | if obfilt: 48 | expert_state = obfilt(expert_state.cpu().numpy(), update=False) 49 | expert_state = torch.FloatTensor(expert_state).to(expert_actions.device) 50 | expert_state = Variable(expert_state) 51 | if isinstance(self.act_space, gym.spaces.Discrete): 52 | _expert_actions = torch.argmax(expert_actions, 1) 53 | else: 54 | _expert_actions = expert_actions 55 | values, actions_log_probs, _, _ = self.actor_critic.evaluate_actions(expert_state, None, None, \ 56 | _expert_actions) 57 | loss = -actions_log_probs.mean() 58 | self.optimizer.zero_grad() 59 | loss.backward() 60 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 61 | self.max_grad_norm) 62 | self.optimizer.step() 63 | return loss 64 | 65 | def get_action_loss(self, expert_state, expert_actions): 66 | if isinstance(self.act_space, gym.spaces.Discrete): 67 | _expert_actions = torch.argmax(expert_actions, 1) 68 | else: 69 | _expert_actions = expert_actions 70 | values, actions_log_probs, _, _ = self.actor_critic.evaluate_actions(expert_state, None, None, \ 71 | _expert_actions) 72 | loss = -actions_log_probs.mean() 73 | return loss 74 | 75 | 76 | def update(self, rollouts, expert_dataset=None, obfilt=None): 77 | # Expert dataset in case the BC update is required 78 | advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1] 79 | advantages = (advantages - advantages.mean()) / ( 80 | advantages.std() + 1e-5) 81 | 82 | value_loss_epoch = 0 83 | action_loss_epoch = 0 84 | dist_entropy_epoch = 0 85 | 86 | for e in range(self.ppo_epoch): 87 | if self.actor_critic.is_recurrent: 88 | data_generator = rollouts.recurrent_generator( 89 | advantages, self.num_mini_batch) 90 | else: 91 | data_generator = rollouts.feed_forward_generator( 92 | advantages, self.num_mini_batch) 93 | 94 | for sample in data_generator: 95 | obs_batch, recurrent_hidden_states_batch, actions_batch, \ 96 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, \ 97 | adv_targ = sample 98 | 99 | # Reshape to do in a single forward pass for all steps 100 | values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions( 101 | obs_batch, recurrent_hidden_states_batch, masks_batch, 102 | actions_batch) 103 | 104 | ratio = torch.exp(action_log_probs - 105 | old_action_log_probs_batch) 106 | surr1 = ratio * adv_targ 107 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 108 | 1.0 + self.clip_param) * adv_targ 109 | action_loss = -torch.min(surr1, surr2).mean() 110 | 111 | # Expert dataset 112 | if expert_dataset: 113 | for exp_state, exp_action in expert_dataset: 114 | if obfilt: 115 | exp_state = obfilt(exp_state.numpy(), update=False) 116 | exp_state = torch.FloatTensor(exp_state) 117 | exp_state = Variable(exp_state).to(action_loss.device) 118 | exp_action = Variable(exp_action).to(action_loss.device) 119 | # Get BC loss 120 | if isinstance(self.act_space, gym.spaces.Discrete): 121 | _exp_action = torch.argmax(exp_action, 1) 122 | else: 123 | _exp_action = exp_action 124 | _, alogprobs, _, _ = self.actor_critic.evaluate_actions(exp_state, None, None, _exp_action) 125 | bcloss = -alogprobs.mean() 126 | # action loss is weighted sum 127 | action_loss = self.gamma * bcloss + (1 - self.gamma) * action_loss 128 | # Multiply this coeff with decay factor 129 | break 130 | 131 | if self.use_clipped_value_loss: 132 | value_pred_clipped = value_preds_batch + \ 133 | (values - value_preds_batch).clamp(-self.clip_param, self.clip_param) 134 | value_losses = (values - return_batch).pow(2) 135 | 
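# (added note) PPO-style clipped value loss: `value_pred_clipped` keeps the new
# value estimate within clip_param of the old prediction, and the element-wise
# max of the two squared errors below penalizes whichever is worse.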
value_losses_clipped = ( 136 | value_pred_clipped - return_batch).pow(2) 137 | value_loss = 0.5 * torch.max(value_losses, 138 | value_losses_clipped).mean() 139 | else: 140 | value_loss = 0.5 * (return_batch - values).pow(2).mean() 141 | 142 | self.optimizer.zero_grad() 143 | (value_loss * self.value_loss_coef + action_loss - 144 | dist_entropy * self.entropy_coef).backward() 145 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 146 | self.max_grad_norm) 147 | self.optimizer.step() 148 | 149 | value_loss_epoch += value_loss.item() 150 | action_loss_epoch += action_loss.item() 151 | dist_entropy_epoch += dist_entropy.item() 152 | 153 | num_updates = self.ppo_epoch * self.num_mini_batch 154 | 155 | value_loss_epoch /= num_updates 156 | action_loss_epoch /= num_updates 157 | dist_entropy_epoch /= num_updates 158 | 159 | if self.gamma is not None: 160 | self.gamma *= self.decay 161 | print("gamma {}".format(self.gamma)) 162 | 163 | return value_loss_epoch, action_loss_epoch, dist_entropy_epoch 164 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser(description='RL') 8 | parser.add_argument( 9 | '--algo', default='a2c', help='algorithm to use: a2c | ppo | acktr') 10 | parser.add_argument( 11 | '--gail', 12 | action='store_true', 13 | default=False, 14 | help='do imitation learning with gail') 15 | parser.add_argument( 16 | '--gail-experts-dir', 17 | default='./gail_experts', 18 | help='directory that contains expert demonstrations for gail') 19 | parser.add_argument( 20 | '--gail-batch-size', 21 | type=int, 22 | default=128, 23 | help='gail batch size (default: 128)') 24 | parser.add_argument( 25 | '--gail-epoch', type=int, default=5, help='gail epochs (default: 5)') 26 | parser.add_argument('--redsigma', type=float, default=1.0) 27 | parser.add_argument('--rediters', type=int, default=100) 28 | parser.add_argument('--red', type=int, default=0) 29 | parser.add_argument('--sail', type=int, default=0) 30 | parser.add_argument( 31 | '--lr', type=float, default=7e-4, help='learning rate (default: 7e-4)') 32 | parser.add_argument('--learn', type=int, default=1) 33 | parser.add_argument( 34 | '--eps', 35 | type=float, 36 | default=1e-5, 37 | help='RMSprop optimizer epsilon (default: 1e-5)') 38 | parser.add_argument( 39 | '--alpha', 40 | type=float, 41 | default=0.99, 42 | help='RMSprop optimizer apha (default: 0.99)') 43 | parser.add_argument( 44 | '--gamma', 45 | type=float, 46 | default=0.99, 47 | help='discount factor for rewards (default: 0.99)') 48 | parser.add_argument( 49 | '--use-gae', 50 | action='store_true', 51 | default=False, 52 | help='use generalized advantage estimation') 53 | parser.add_argument( 54 | '--gae-lambda', 55 | type=float, 56 | default=0.95, 57 | help='gae lambda parameter (default: 0.95)') 58 | parser.add_argument( 59 | '--entropy-coef', 60 | type=float, 61 | default=0.01, 62 | help='entropy term coefficient (default: 0.01)') 63 | parser.add_argument( 64 | '--value-loss-coef', 65 | type=float, 66 | default=0.5, 67 | help='value loss coefficient (default: 0.5)') 68 | parser.add_argument( 69 | '--max-grad-norm', 70 | type=float, 71 | default=0.5, 72 | help='max norm of gradients (default: 0.5)') 73 | parser.add_argument( 74 | '--record_trajectories', 75 | type=int, 76 | default=0, 77 | help='Save trajectories?') 78 | 
parser.add_argument('--num_episodes', type=int, default=20) 79 | parser.add_argument('--num-traj', type=int, default=4) 80 | parser.add_argument( 81 | '--seed', type=int, default=1, help='random seed (default: 1)') 82 | parser.add_argument( 83 | '--cuda-deterministic', 84 | action='store_true', 85 | default=False, 86 | help="sets flags for determinism when using CUDA (potentially slow!)") 87 | parser.add_argument( 88 | '--num-processes', 89 | type=int, 90 | default=16, 91 | help='how many training CPU processes to use (default: 16)') 92 | parser.add_argument( 93 | '--num-steps', 94 | type=int, 95 | default=5, 96 | help='number of forward steps in A2C (default: 5)') 97 | parser.add_argument( 98 | '--ppo-epoch', 99 | type=int, 100 | default=4, 101 | help='number of ppo epochs (default: 4)') 102 | parser.add_argument( 103 | '--num-mini-batch', 104 | type=int, 105 | default=32, 106 | help='number of batches for ppo (default: 32)') 107 | parser.add_argument('--use_activation', default=0, type=int, 108 | help='Use final activation? (Useful for certain scenarios)') 109 | parser.add_argument( 110 | '--clip-param', 111 | type=float, 112 | default=0.2, 113 | help='ppo clip parameter (default: 0.2)') 114 | parser.add_argument( 115 | '--log-interval', 116 | type=int, 117 | default=10, 118 | help='log interval, one log per n updates (default: 10)') 119 | parser.add_argument( 120 | '--save-interval', 121 | type=int, 122 | default=100, 123 | help='save interval, one save per n updates (default: 100)') 124 | parser.add_argument( 125 | '--eval-interval', 126 | type=int, 127 | default=None, 128 | help='eval interval, one eval per n updates (default: None)') 129 | parser.add_argument( 130 | '--num-env-steps', 131 | type=int, 132 | default=10e6, 133 | help='number of environment steps to train (default: 10e6)') 134 | parser.add_argument( 135 | '--env-name', 136 | default='PongNoFrameskip-v4', 137 | help='environment to train on (default: PongNoFrameskip-v4)') 138 | parser.add_argument( 139 | '--log-dir', 140 | default='/serverdata/rohit/BCGAIL/logs/', 141 | help='directory to save agent logs (default: /tmp/gym)') 142 | parser.add_argument('--model_name', type=str, required=True) 143 | parser.add_argument('--load_model_name', type=str, default=None) 144 | parser.add_argument( 145 | '--save-dir', 146 | default='/serverdata/rohit/BCGAIL/', 147 | help='directory to save agent logs (default: ./trained_models/)') 148 | parser.add_argument('--bcgail', type=int, default=0) 149 | parser.add_argument('--decay', type=float, default=None) 150 | parser.add_argument('--gailgamma', type=float, default=None) 151 | parser.add_argument( 152 | '--no-cuda', 153 | action='store_true', 154 | default=False, 155 | help='disables CUDA training') 156 | parser.add_argument( 157 | '--use-proper-time-limits', 158 | action='store_true', 159 | default=False, 160 | help='compute returns taking into account time limits') 161 | parser.add_argument( 162 | '--recurrent-policy', 163 | action='store_true', 164 | default=False, 165 | help='use a recurrent policy') 166 | parser.add_argument('--savelength', type=int, default=0, help='Save average lengths or rewards') 167 | parser.add_argument( 168 | '--use-linear-lr-decay', 169 | action='store_true', 170 | default=False, 171 | help='use a linear schedule on the learning rate') 172 | args = parser.parse_args() 173 | 174 | args.cuda = not args.no_cuda and torch.cuda.is_available() 175 | 176 | assert args.algo in ['a2c', 'ppo', 'acktr'] 177 | if args.recurrent_policy: 178 | assert args.algo in ['a2c', 
'ppo'], \ 179 | 'Recurrent policy is not implemented for ACKTR' 180 | 181 | return args 182 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/distributions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from a2c_ppo_acktr.utils import AddBias, init 8 | 9 | """ 10 | Modify standard PyTorch distributions so they are compatible with this code. 11 | """ 12 | 13 | # 14 | # Standardize distribution interfaces 15 | # 16 | 17 | # Categorical 18 | class FixedCategorical(torch.distributions.Categorical): 19 | def sample(self): 20 | return super().sample().unsqueeze(-1) 21 | 22 | def log_probs(self, actions): 23 | return ( 24 | super() 25 | .log_prob(actions.squeeze(-1)) 26 | .view(actions.size(0), -1) 27 | .sum(-1) 28 | .unsqueeze(-1) 29 | ) 30 | 31 | def mode(self): 32 | return self.probs.argmax(dim=-1, keepdim=True) 33 | 34 | 35 | # Normal 36 | class FixedNormal(torch.distributions.Normal): 37 | def log_probs(self, actions): 38 | return super().log_prob(actions).sum(-1, keepdim=True) 39 | 40 | def entrop(self): 41 | return super.entropy().sum(-1) 42 | 43 | def mode(self): 44 | return self.mean 45 | 46 | 47 | # Bernoulli 48 | class FixedBernoulli(torch.distributions.Bernoulli): 49 | def log_probs(self, actions): 50 | return super.log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 51 | 52 | def entropy(self): 53 | return super().entropy().sum(-1) 54 | 55 | def mode(self): 56 | return torch.gt(self.probs, 0.5).float() 57 | 58 | 59 | class Categorical(nn.Module): 60 | def __init__(self, num_inputs, num_outputs): 61 | super(Categorical, self).__init__() 62 | 63 | init_ = lambda m: init( 64 | m, 65 | nn.init.orthogonal_, 66 | lambda x: nn.init.constant_(x, 0), 67 | gain=0.01) 68 | 69 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 70 | 71 | def forward(self, x): 72 | x = self.linear(x) 73 | return FixedCategorical(logits=x) 74 | 75 | 76 | class DiagGaussian(nn.Module): 77 | def __init__(self, num_inputs, num_outputs, activation=None): 78 | super(DiagGaussian, self).__init__() 79 | 80 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 81 | constant_(x, 0)) 82 | 83 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 84 | self.logstd = AddBias(torch.zeros(num_outputs)) 85 | self.activation = activation 86 | 87 | def forward(self, x): 88 | action_mean = self.fc_mean(x) 89 | if self.activation is not None: 90 | action_mean = self.activation(action_mean) 91 | 92 | # An ugly hack for my KFAC implementation. 93 | zeros = torch.zeros(action_mean.size()) 94 | if x.is_cuda: 95 | zeros = zeros.cuda() 96 | 97 | action_logstd = self.logstd(zeros) 98 | return FixedNormal(action_mean, action_logstd.exp()) 99 | 100 | 101 | class Bernoulli(nn.Module): 102 | def __init__(self, num_inputs, num_outputs): 103 | super(Bernoulli, self).__init__() 104 | 105 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
106 | constant_(x, 0)) 107 | 108 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 109 | 110 | def forward(self, x): 111 | x = self.linear(x) 112 | return FixedBernoulli(logits=x) 113 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/envs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | from gym.spaces.box import Box 7 | 8 | from baselines import bench 9 | from baselines.common.atari_wrappers import make_atari, wrap_deepmind, ScaledFloatFrame, WarpFrame 10 | from baselines.common.vec_env import VecEnvWrapper 11 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 12 | from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv 13 | from baselines.common.vec_env.vec_normalize import \ 14 | VecNormalize as VecNormalize_ 15 | 16 | import gym_minigrid 17 | from gym_minigrid import wrappers as mgwr 18 | try: 19 | import dm_control2gym 20 | except ImportError: 21 | pass 22 | 23 | try: 24 | import roboschool 25 | except ImportError: 26 | pass 27 | 28 | try: 29 | import pybullet_envs 30 | except ImportError: 31 | pass 32 | 33 | 34 | def make_env(env_id, seed, rank, log_dir, allow_early_resets): 35 | def _thunk(): 36 | if env_id.startswith("dm"): 37 | _, domain, task = env_id.split('.') 38 | env = dm_control2gym.make(domain_name=domain, task_name=task) 39 | else: 40 | env = gym.make(env_id) 41 | 42 | is_atari = hasattr(gym.envs, 'atari') and isinstance( 43 | env.unwrapped, gym.envs.atari.atari_env.AtariEnv) 44 | if is_atari: 45 | env = make_atari(env_id) 46 | 47 | env.seed(seed + rank) 48 | 49 | if str(env.__class__.__name__).find('TimeLimit') >= 0: 50 | env = TimeLimitMask(env) 51 | 52 | # minigrid 53 | keep_classes = ['agent', 'goal', 'wall', 'empty'] 54 | if 'key' in env_id.lower(): 55 | keep_classes.extend(['door', 'key']) 56 | 57 | if env_id.startswith('MiniGrid'): 58 | env = mgwr.FullyObsWrapper(env) 59 | env = mgwr.ImgObsWrapper(env) 60 | env = mgwr.FullyObsOneHotWrapper(env, drop_color=1, keep_classes=keep_classes, flatten=False) 61 | 62 | if log_dir is not None: 63 | env = bench.Monitor( 64 | env, 65 | os.path.join(log_dir, str(rank)), 66 | allow_early_resets=allow_early_resets) 67 | 68 | if is_atari: 69 | if len(env.observation_space.shape) == 3: 70 | env = wrap_deepmind(env) 71 | elif len(env.observation_space.shape) == 3: 72 | if env_id.startswith('CarRacing'): 73 | env = WarpFrame(env, width=96, height=96, grayscale=True) 74 | env = ScaledFloatFrame(env) 75 | else: 76 | raise NotImplementedError 77 | 78 | # If the input has shape (W,H,3), wrap for PyTorch convolutions 79 | obs_shape = env.observation_space.shape 80 | if len(obs_shape) == 3: 81 | env = TransposeImage(env, op=[2, 0, 1]) 82 | 83 | return env 84 | 85 | return _thunk 86 | 87 | 88 | def make_vec_envs(env_name, 89 | seed, 90 | num_processes, 91 | gamma, 92 | log_dir, 93 | device, 94 | allow_early_resets, 95 | training=True, 96 | num_frame_stack=None, 97 | red=False, 98 | ): 99 | envs = [ 100 | make_env(env_name, seed, i, log_dir, allow_early_resets) 101 | for i in range(num_processes) 102 | ] 103 | 104 | if len(envs) > 1: 105 | envs = ShmemVecEnv(envs, context='fork') 106 | else: 107 | envs = DummyVecEnv(envs) 108 | 109 | # Dont filter if RED 110 | obfilt = not red 111 | if len(envs.observation_space.shape) == 1: 112 | if gamma is None: 113 | envs = VecNormalize(envs, ob=obfilt, ret=False) 114 | else: 115 | envs = VecNormalize(envs, ob=obfilt, 
gamma=gamma)
116 |         if not training:
117 |             envs.eval()
118 | 
119 |     elif env_name.startswith('CarRacing'):
120 |         # Car Racing, use a normalizer for rewards
121 |         envs = VecNormalize(envs, ob=False, ret=training, clipob=1e10, cliprew=1.0)
122 |         if not training:
123 |             envs.eval()
124 | 
125 |     envs = VecPyTorch(envs, device)
126 |     # Hack for now
127 |     is_atari = env_name.startswith('MiniGrid') or env_name.startswith('CarRacing')
128 |     is_atari = not is_atari
129 | 
130 |     if num_frame_stack is not None:
131 |         envs = VecPyTorchFrameStack(envs, num_frame_stack, device)
132 |     elif len(envs.observation_space.shape) == 3:
133 |         envs = VecPyTorchFrameStack(envs, 4, device)
134 | 
135 |     return envs
136 | 
137 | 
138 | # Checks whether done was caused by time limits or not
139 | class TimeLimitMask(gym.Wrapper):
140 |     def step(self, action):
141 |         obs, rew, done, info = self.env.step(action)
142 |         if done and self.env._max_episode_steps == self.env._elapsed_steps:
143 |             info['bad_transition'] = True
144 | 
145 |         return obs, rew, done, info
146 | 
147 |     def reset(self, **kwargs):
148 |         return self.env.reset(**kwargs)
149 | 
150 | 
151 | # Can be used to test recurrent policies for Reacher-v2
152 | class MaskGoal(gym.ObservationWrapper):
153 |     def observation(self, observation):
154 |         if self.env._elapsed_steps > 0:
155 |             observation[-2:] = 0
156 |         return observation
157 | 
158 | 
159 | class TransposeObs(gym.ObservationWrapper):
160 |     def __init__(self, env=None):
161 |         """
162 |         Transpose observation space (base class)
163 |         """
164 |         super(TransposeObs, self).__init__(env)
165 | 
166 | 
167 | class TransposeImage(TransposeObs):
168 |     def __init__(self, env=None, op=[2, 0, 1]):
169 |         """
170 |         Transpose observation space for images
171 |         """
172 |         super(TransposeImage, self).__init__(env)
173 |         assert len(op) == 3, "Error: Operation, " + str(op) + ", must be dim3"
174 |         self.op = op
175 |         obs_shape = self.observation_space.shape
176 |         self.observation_space = Box(
177 |             self.observation_space.low[0, 0, 0],
178 |             self.observation_space.high[0, 0, 0], [
179 |                 obs_shape[self.op[0]], obs_shape[self.op[1]],
180 |                 obs_shape[self.op[2]]
181 |             ],
182 |             dtype=self.observation_space.dtype)
183 | 
184 |     def observation(self, ob):
185 |         return ob.transpose(self.op[0], self.op[1], self.op[2])
186 | 
187 | 
188 | class VecPyTorch(VecEnvWrapper):
189 |     def __init__(self, venv, device):
190 |         """Convert observations and rewards from numpy arrays to torch tensors on `device`."""
191 |         super(VecPyTorch, self).__init__(venv)
192 |         self.device = device
193 |         # TODO: Fix data types
194 | 
195 |     def reset(self):
196 |         obs = self.venv.reset()
197 |         obs = torch.from_numpy(obs).float().to(self.device)
198 |         return obs
199 | 
200 |     def step_async(self, actions):
201 |         if isinstance(actions, torch.LongTensor):
202 |             # Squeeze the dimension for discrete actions
203 |             actions = actions.squeeze(1)
204 |         actions = actions.cpu().numpy()
205 |         self.venv.step_async(actions)
206 | 
207 |     def step_wait(self):
208 |         obs, reward, done, info = self.venv.step_wait()
209 |         obs = torch.from_numpy(obs).float().to(self.device)
210 |         reward = torch.from_numpy(reward).unsqueeze(dim=1).float()
211 |         return obs, reward, done, info
212 | 
213 | 
214 | class VecNormalize(VecNormalize_):
215 |     def __init__(self, *args, **kwargs):
216 |         super(VecNormalize, self).__init__(*args, **kwargs)
217 |         self.training = True
218 | 
219 |     def _rev_obfilt(self, n_obs, update=False):
220 |         if self.ob_rms:
221 |             obs = np.sqrt(self.ob_rms.var + self.epsilon) * n_obs + self.ob_rms.mean
222 |         else:
223 |             obs = n_obs
224 |         return obs
225 | 
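    # Note: _rev_obfilt above undoes the running mean/std normalization that _obfilt (below)
    # applies: n = clip((x - mean) / sqrt(var + eps), -clipob, clipob) and
    # x = sqrt(var + eps) * n + mean, which is exact whenever no clipping occurred.
    # A minimal self-contained sketch of that round trip, using a hypothetical stand-in `rms`
    # in place of self.ob_rms (this helper is illustrative only and is not called anywhere):
    @staticmethod
    def _obfilt_roundtrip_example():
        from types import SimpleNamespace

        import numpy as np

        rms = SimpleNamespace(mean=np.array([1.0, -2.0]), var=np.array([4.0, 0.25]))
        eps, clipob = 1e-8, 10.0
        obs = np.array([[3.0, -1.5]])
        # forward filter (what _obfilt does, minus the running-stats update)
        n_obs = np.clip((obs - rms.mean) / np.sqrt(rms.var + eps), -clipob, clipob)
        # inverse filter (what _rev_obfilt does)
        recovered = np.sqrt(rms.var + eps) * n_obs + rms.mean  # matches obs up to float precision
        return obs, n_obs, recovered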
226 | def _obfilt(self, obs, update=True): 227 | if self.ob_rms: 228 | if self.training and update: 229 | self.ob_rms.update(obs) 230 | obs = np.clip((obs - self.ob_rms.mean) / 231 | np.sqrt(self.ob_rms.var + self.epsilon), 232 | -self.clipob, self.clipob) 233 | return obs 234 | else: 235 | return obs 236 | 237 | def train(self): 238 | self.training = True 239 | 240 | def eval(self): 241 | self.training = False 242 | 243 | 244 | # Derived from 245 | # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_frame_stack.py 246 | class VecPyTorchFrameStack(VecEnvWrapper): 247 | def __init__(self, venv, nstack, device=None): 248 | self.venv = venv 249 | self.nstack = nstack 250 | 251 | wos = venv.observation_space # wrapped ob space 252 | self.shape_dim0 = wos.shape[0] 253 | 254 | low = np.repeat(wos.low, self.nstack, axis=0) 255 | high = np.repeat(wos.high, self.nstack, axis=0) 256 | 257 | if device is None: 258 | device = torch.device('cpu') 259 | self.stacked_obs = torch.zeros((venv.num_envs, ) + 260 | low.shape).to(device) 261 | 262 | observation_space = gym.spaces.Box( 263 | low=low, high=high, dtype=venv.observation_space.dtype) 264 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 265 | 266 | def step_wait(self): 267 | obs, rews, news, infos = self.venv.step_wait() 268 | self.stacked_obs[:, :-self.shape_dim0] = \ 269 | self.stacked_obs[:, self.shape_dim0:] 270 | for (i, new) in enumerate(news): 271 | if new: 272 | self.stacked_obs[i] = 0 273 | self.stacked_obs[:, -self.shape_dim0:] = obs 274 | return self.stacked_obs, rews, news, infos 275 | 276 | def reset(self): 277 | obs = self.venv.reset() 278 | if torch.backends.cudnn.deterministic: 279 | self.stacked_obs = torch.zeros(self.stacked_obs.shape) 280 | else: 281 | self.stacked_obs.zero_() 282 | self.stacked_obs[:, -self.shape_dim0:] = obs 283 | return self.stacked_obs 284 | 285 | def close(self): 286 | self.venv.close() 287 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from a2c_ppo_acktr.distributions import Bernoulli, Categorical, DiagGaussian 7 | from a2c_ppo_acktr.utils import init 8 | 9 | 10 | class Flatten(nn.Module): 11 | def forward(self, x): 12 | return x.view(x.size(0), -1) 13 | 14 | 15 | class Policy(nn.Module): 16 | def __init__(self, obs_shape, action_space, base=None, base_kwargs=None, activation=None): 17 | super(Policy, self).__init__() 18 | if base_kwargs is None: 19 | base_kwargs = {} 20 | if base is None: 21 | if len(obs_shape) == 3: 22 | base = CNNBase 23 | elif len(obs_shape) == 1: 24 | base = MLPBase 25 | else: 26 | raise NotImplementedError 27 | 28 | self.base = base(obs_shape[0], **base_kwargs) 29 | 30 | if action_space.__class__.__name__ == "Discrete": 31 | num_outputs = action_space.n 32 | self.dist = Categorical(self.base.output_size, num_outputs) 33 | elif action_space.__class__.__name__ == "Box": 34 | num_outputs = action_space.shape[0] 35 | self.dist = DiagGaussian(self.base.output_size, num_outputs, activation=activation) 36 | elif action_space.__class__.__name__ == "MultiBinary": 37 | num_outputs = action_space.shape[0] 38 | self.dist = Bernoulli(self.base.output_size, num_outputs) 39 | else: 40 | raise NotImplementedError 41 | 42 | @property 43 | def is_recurrent(self): 44 | return self.base.is_recurrent 45 | 46 
| @property 47 | def recurrent_hidden_state_size(self): 48 | """Size of rnn_hx.""" 49 | return self.base.recurrent_hidden_state_size 50 | 51 | def forward(self, inputs, rnn_hxs, masks): 52 | raise NotImplementedError 53 | 54 | def act(self, inputs, rnn_hxs, masks, deterministic=False): 55 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 56 | dist = self.dist(actor_features) 57 | 58 | if deterministic: 59 | action = dist.mode() 60 | else: 61 | action = dist.sample() 62 | 63 | action_log_probs = dist.log_probs(action) 64 | dist_entropy = dist.entropy().mean() 65 | 66 | return value, action, action_log_probs, rnn_hxs 67 | 68 | def get_value(self, inputs, rnn_hxs, masks): 69 | value, _, _ = self.base(inputs, rnn_hxs, masks) 70 | return value 71 | 72 | def evaluate_actions(self, inputs, rnn_hxs, masks, action): 73 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 74 | dist = self.dist(actor_features) 75 | 76 | action_log_probs = dist.log_probs(action) 77 | dist_entropy = dist.entropy().mean() 78 | 79 | return value, action_log_probs, dist_entropy, rnn_hxs 80 | 81 | 82 | class NNBase(nn.Module): 83 | def __init__(self, recurrent, recurrent_input_size, hidden_size): 84 | super(NNBase, self).__init__() 85 | 86 | self._hidden_size = hidden_size 87 | self._recurrent = recurrent 88 | 89 | if recurrent: 90 | self.gru = nn.GRU(recurrent_input_size, hidden_size) 91 | for name, param in self.gru.named_parameters(): 92 | if 'bias' in name: 93 | nn.init.constant_(param, 0) 94 | elif 'weight' in name: 95 | nn.init.orthogonal_(param) 96 | 97 | @property 98 | def is_recurrent(self): 99 | return self._recurrent 100 | 101 | @property 102 | def recurrent_hidden_state_size(self): 103 | if self._recurrent: 104 | return self._hidden_size 105 | return 1 106 | 107 | @property 108 | def output_size(self): 109 | return self._hidden_size 110 | 111 | def _forward_gru(self, x, hxs, masks): 112 | if x.size(0) == hxs.size(0): 113 | x, hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0)) 114 | x = x.squeeze(0) 115 | hxs = hxs.squeeze(0) 116 | else: 117 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 118 | N = hxs.size(0) 119 | T = int(x.size(0) / N) 120 | 121 | # unflatten 122 | x = x.view(T, N, x.size(1)) 123 | 124 | # Same deal with masks 125 | masks = masks.view(T, N) 126 | 127 | # Let's figure out which steps in the sequence have a zero for any agent 128 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 129 | has_zeros = ((masks[1:] == 0.0) \ 130 | .any(dim=-1) 131 | .nonzero() 132 | .squeeze() 133 | .cpu()) 134 | 135 | # +1 to correct the masks[1:] 136 | if has_zeros.dim() == 0: 137 | # Deal with scalar 138 | has_zeros = [has_zeros.item() + 1] 139 | else: 140 | has_zeros = (has_zeros + 1).numpy().tolist() 141 | 142 | # add t=0 and t=T to the list 143 | has_zeros = [0] + has_zeros + [T] 144 | 145 | hxs = hxs.unsqueeze(0) 146 | outputs = [] 147 | for i in range(len(has_zeros) - 1): 148 | # We can now process steps that don't have any zeros in masks together! 
149 | # This is much faster 150 | start_idx = has_zeros[i] 151 | end_idx = has_zeros[i + 1] 152 | 153 | rnn_scores, hxs = self.gru( 154 | x[start_idx:end_idx], 155 | hxs * masks[start_idx].view(1, -1, 1)) 156 | 157 | outputs.append(rnn_scores) 158 | 159 | # assert len(outputs) == T 160 | # x is a (T, N, -1) tensor 161 | x = torch.cat(outputs, dim=0) 162 | # flatten 163 | x = x.view(T * N, -1) 164 | hxs = hxs.squeeze(0) 165 | 166 | return x, hxs 167 | 168 | 169 | class CNNBase(NNBase): 170 | def __init__(self, num_inputs, recurrent=False, hidden_size=512, env=None): 171 | super(CNNBase, self).__init__(recurrent, hidden_size, hidden_size) 172 | 173 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 174 | constant_(x, 0), nn.init.calculate_gain('relu')) 175 | 176 | if env != 'CarRacing-v0': 177 | self.main = nn.Sequential( 178 | init_(nn.Conv2d(num_inputs, 32, 8, stride=4)), nn.ReLU(), 179 | init_(nn.Conv2d(32, 64, 4, stride=2)), nn.ReLU(), 180 | init_(nn.Conv2d(64, 32, 3, stride=1)), nn.ReLU(), Flatten(), 181 | init_(nn.Linear(32 * 7 * 7, hidden_size)), nn.ReLU()) 182 | else: 183 | # For Car Racing 184 | print("Using CarRacing base") 185 | self.main = nn.Sequential( 186 | init_(nn.Conv2d(num_inputs, 32, 4, stride=2)), nn.ReLU(), 187 | init_(nn.Conv2d(32, 64, 4, stride=2)), nn.ReLU(), 188 | init_(nn.Conv2d(64, 128, 4, stride=2)), nn.ReLU(), 189 | init_(nn.Conv2d(128, 256, 4, stride=2)), nn.ReLU(), Flatten(), 190 | init_(nn.Linear(256 * 4 * 4, hidden_size)), nn.ReLU(), 191 | init_(nn.Linear(hidden_size, hidden_size)), nn.ReLU(), 192 | ) 193 | 194 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 195 | constant_(x, 0)) 196 | 197 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 198 | self.train() 199 | 200 | def forward(self, inputs, rnn_hxs, masks): 201 | x = self.main(inputs) 202 | 203 | if self.is_recurrent: 204 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 205 | 206 | return self.critic_linear(x), x, rnn_hxs 207 | 208 | 209 | class MLPBase(NNBase): 210 | def __init__(self, num_inputs, recurrent=False, hidden_size=64, env=None): 211 | super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size) 212 | 213 | if recurrent: 214 | num_inputs = hidden_size 215 | 216 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
217 | constant_(x, 0), np.sqrt(2)) 218 | 219 | self.actor = nn.Sequential( 220 | init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 221 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 222 | 223 | self.critic = nn.Sequential( 224 | init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 225 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 226 | 227 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 228 | 229 | self.train() 230 | 231 | def forward(self, inputs, rnn_hxs, masks): 232 | x = inputs 233 | 234 | if self.is_recurrent: 235 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 236 | 237 | hidden_critic = self.critic(x) 238 | hidden_actor = self.actor(x) 239 | 240 | return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs 241 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 3 | 4 | 5 | def _flatten_helper(T, N, _tensor): 6 | return _tensor.view(T * N, *_tensor.size()[2:]) 7 | 8 | 9 | class RolloutStorage(object): 10 | def __init__(self, num_steps, num_processes, obs_shape, action_space, 11 | recurrent_hidden_state_size): 12 | self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape) 13 | self.recurrent_hidden_states = torch.zeros( 14 | num_steps + 1, num_processes, recurrent_hidden_state_size) 15 | self.rewards = torch.zeros(num_steps, num_processes, 1) 16 | self.value_preds = torch.zeros(num_steps + 1, num_processes, 1) 17 | self.returns = torch.zeros(num_steps + 1, num_processes, 1) 18 | self.action_log_probs = torch.zeros(num_steps, num_processes, 1) 19 | if action_space.__class__.__name__ == 'Discrete': 20 | action_shape = 1 21 | else: 22 | action_shape = action_space.shape[0] 23 | self.actions = torch.zeros(num_steps, num_processes, action_shape) 24 | if action_space.__class__.__name__ == 'Discrete': 25 | self.actions = self.actions.long() 26 | self.masks = torch.ones(num_steps + 1, num_processes, 1) 27 | 28 | # Masks that indicate whether it's a true terminal state 29 | # or time limit end state 30 | self.bad_masks = torch.ones(num_steps + 1, num_processes, 1) 31 | 32 | self.num_steps = num_steps 33 | self.step = 0 34 | 35 | def to(self, device): 36 | self.obs = self.obs.to(device) 37 | self.recurrent_hidden_states = self.recurrent_hidden_states.to(device) 38 | self.rewards = self.rewards.to(device) 39 | self.value_preds = self.value_preds.to(device) 40 | self.returns = self.returns.to(device) 41 | self.action_log_probs = self.action_log_probs.to(device) 42 | self.actions = self.actions.to(device) 43 | self.masks = self.masks.to(device) 44 | self.bad_masks = self.bad_masks.to(device) 45 | 46 | def insert(self, obs, recurrent_hidden_states, actions, action_log_probs, 47 | value_preds, rewards, masks, bad_masks): 48 | self.obs[self.step + 1].copy_(obs) 49 | self.recurrent_hidden_states[self.step + 50 | 1].copy_(recurrent_hidden_states) 51 | self.actions[self.step].copy_(actions) 52 | self.action_log_probs[self.step].copy_(action_log_probs) 53 | self.value_preds[self.step].copy_(value_preds) 54 | self.rewards[self.step].copy_(rewards) 55 | self.masks[self.step + 1].copy_(masks) 56 | self.bad_masks[self.step + 1].copy_(bad_masks) 57 | 58 | self.step = (self.step + 1) % self.num_steps 59 | 60 | def after_update(self): 61 | self.obs[0].copy_(self.obs[-1]) 62 | 
self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) 63 | self.masks[0].copy_(self.masks[-1]) 64 | self.bad_masks[0].copy_(self.bad_masks[-1]) 65 | 66 | def compute_returns(self, 67 | next_value, 68 | use_gae, 69 | gamma, 70 | gae_lambda, 71 | use_proper_time_limits=True): 72 | if use_proper_time_limits: 73 | if use_gae: 74 | self.value_preds[-1] = next_value 75 | gae = 0 76 | for step in reversed(range(self.rewards.size(0))): 77 | delta = self.rewards[step] + gamma * self.value_preds[ 78 | step + 1] * self.masks[step + 79 | 1] - self.value_preds[step] 80 | gae = delta + gamma * gae_lambda * self.masks[step + 81 | 1] * gae 82 | gae = gae * self.bad_masks[step + 1] 83 | self.returns[step] = gae + self.value_preds[step] 84 | else: 85 | self.returns[-1] = next_value 86 | for step in reversed(range(self.rewards.size(0))): 87 | self.returns[step] = (self.returns[step + 1] * \ 88 | gamma * self.masks[step + 1] + self.rewards[step]) * self.bad_masks[step + 1] \ 89 | + (1 - self.bad_masks[step + 1]) * self.value_preds[step] 90 | else: 91 | if use_gae: 92 | self.value_preds[-1] = next_value 93 | gae = 0 94 | for step in reversed(range(self.rewards.size(0))): 95 | delta = self.rewards[step] + gamma * self.value_preds[ 96 | step + 1] * self.masks[step + 97 | 1] - self.value_preds[step] 98 | gae = delta + gamma * gae_lambda * self.masks[step + 99 | 1] * gae 100 | self.returns[step] = gae + self.value_preds[step] 101 | else: 102 | self.returns[-1] = next_value 103 | for step in reversed(range(self.rewards.size(0))): 104 | self.returns[step] = self.returns[step + 1] * \ 105 | gamma * self.masks[step + 1] + self.rewards[step] 106 | 107 | def feed_forward_generator(self, 108 | advantages, 109 | num_mini_batch=None, 110 | mini_batch_size=None): 111 | num_steps, num_processes = self.rewards.size()[0:2] 112 | batch_size = num_processes * num_steps 113 | 114 | if mini_batch_size is None: 115 | assert batch_size >= num_mini_batch, ( 116 | "PPO requires the number of processes ({}) " 117 | "* number of steps ({}) = {} " 118 | "to be greater than or equal to the number of PPO mini batches ({})." 
119 | "".format(num_processes, num_steps, num_processes * num_steps, 120 | num_mini_batch)) 121 | mini_batch_size = batch_size // num_mini_batch 122 | sampler = BatchSampler( 123 | SubsetRandomSampler(range(batch_size)), 124 | mini_batch_size, 125 | drop_last=True) 126 | for indices in sampler: 127 | obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices] 128 | recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view( 129 | -1, self.recurrent_hidden_states.size(-1))[indices] 130 | actions_batch = self.actions.view(-1, 131 | self.actions.size(-1))[indices] 132 | value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices] 133 | return_batch = self.returns[:-1].view(-1, 1)[indices] 134 | masks_batch = self.masks[:-1].view(-1, 1)[indices] 135 | old_action_log_probs_batch = self.action_log_probs.view(-1, 136 | 1)[indices] 137 | if advantages is None: 138 | adv_targ = None 139 | else: 140 | adv_targ = advantages.view(-1, 1)[indices] 141 | 142 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \ 143 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 144 | 145 | def recurrent_generator(self, advantages, num_mini_batch): 146 | num_processes = self.rewards.size(1) 147 | assert num_processes >= num_mini_batch, ( 148 | "PPO requires the number of processes ({}) " 149 | "to be greater than or equal to the number of " 150 | "PPO mini batches ({}).".format(num_processes, num_mini_batch)) 151 | num_envs_per_batch = num_processes // num_mini_batch 152 | perm = torch.randperm(num_processes) 153 | for start_ind in range(0, num_processes, num_envs_per_batch): 154 | obs_batch = [] 155 | recurrent_hidden_states_batch = [] 156 | actions_batch = [] 157 | value_preds_batch = [] 158 | return_batch = [] 159 | masks_batch = [] 160 | old_action_log_probs_batch = [] 161 | adv_targ = [] 162 | 163 | for offset in range(num_envs_per_batch): 164 | ind = perm[start_ind + offset] 165 | obs_batch.append(self.obs[:-1, ind]) 166 | recurrent_hidden_states_batch.append( 167 | self.recurrent_hidden_states[0:1, ind]) 168 | actions_batch.append(self.actions[:, ind]) 169 | value_preds_batch.append(self.value_preds[:-1, ind]) 170 | return_batch.append(self.returns[:-1, ind]) 171 | masks_batch.append(self.masks[:-1, ind]) 172 | old_action_log_probs_batch.append( 173 | self.action_log_probs[:, ind]) 174 | adv_targ.append(advantages[:, ind]) 175 | 176 | T, N = self.num_steps, num_envs_per_batch 177 | # These are all tensors of size (T, N, -1) 178 | obs_batch = torch.stack(obs_batch, 1) 179 | actions_batch = torch.stack(actions_batch, 1) 180 | value_preds_batch = torch.stack(value_preds_batch, 1) 181 | return_batch = torch.stack(return_batch, 1) 182 | masks_batch = torch.stack(masks_batch, 1) 183 | old_action_log_probs_batch = torch.stack( 184 | old_action_log_probs_batch, 1) 185 | adv_targ = torch.stack(adv_targ, 1) 186 | 187 | # States is just a (N, -1) tensor 188 | recurrent_hidden_states_batch = torch.stack( 189 | recurrent_hidden_states_batch, 1).view(N, -1) 190 | 191 | # Flatten the (T, N, ...) tensors to (T * N, ...) 
192 | obs_batch = _flatten_helper(T, N, obs_batch) 193 | actions_batch = _flatten_helper(T, N, actions_batch) 194 | value_preds_batch = _flatten_helper(T, N, value_preds_batch) 195 | return_batch = _flatten_helper(T, N, return_batch) 196 | masks_batch = _flatten_helper(T, N, masks_batch) 197 | old_action_log_probs_batch = _flatten_helper(T, N, \ 198 | old_action_log_probs_batch) 199 | adv_targ = _flatten_helper(T, N, adv_targ) 200 | 201 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \ 202 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 203 | -------------------------------------------------------------------------------- /a2c_ppo_acktr/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from a2c_ppo_acktr.envs import VecNormalize 8 | 9 | 10 | # Get a render function 11 | def get_render_func(venv): 12 | if hasattr(venv, 'envs'): 13 | return venv.envs[0].render 14 | elif hasattr(venv, 'venv'): 15 | return get_render_func(venv.venv) 16 | elif hasattr(venv, 'env'): 17 | return get_render_func(venv.env) 18 | 19 | return None 20 | 21 | 22 | def get_vec_normalize(venv): 23 | if isinstance(venv, VecNormalize): 24 | return venv 25 | elif hasattr(venv, 'venv'): 26 | return get_vec_normalize(venv.venv) 27 | 28 | return None 29 | 30 | 31 | # Necessary for my KFAC implementation. 32 | class AddBias(nn.Module): 33 | def __init__(self, bias): 34 | super(AddBias, self).__init__() 35 | self._bias = nn.Parameter(bias.unsqueeze(1)) 36 | 37 | def forward(self, x): 38 | if x.dim() == 2: 39 | bias = self._bias.t().view(1, -1) 40 | else: 41 | bias = self._bias.t().view(1, -1, 1, 1) 42 | 43 | return x + bias 44 | 45 | 46 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 47 | """Decreases the learning rate linearly""" 48 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 49 | for param_group in optimizer.param_groups: 50 | param_group['lr'] = lr 51 | 52 | 53 | def init(module, weight_init, bias_init, gain=1): 54 | weight_init(module.weight.data, gain=gain) 55 | bias_init(module.bias.data) 56 | return module 57 | 58 | 59 | def cleanup_log_dir(log_dir): 60 | try: 61 | os.makedirs(log_dir) 62 | except OSError: 63 | files = glob.glob(os.path.join(log_dir, '*.monitor.csv')) 64 | for f in files: 65 | os.remove(f) 66 | -------------------------------------------------------------------------------- /alphamujoco.sh: -------------------------------------------------------------------------------- 1 | env=$1 2 | total_timesteps=${2:-3000000} 3 | if [ $# -le 0 ]; then 4 | echo './mujoco.sh ' 5 | exit 6 | fi 7 | 8 | 9 | for seed in {1,2,4,} 10 | do 11 | #for gamma in {0.5,0.25,0.125} 12 | for gamma in {0.75,0.50,0.25} 13 | do 14 | python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps ${total_timesteps} --use-linear-lr-decay --use-proper-time-limits --gail --model_name ${env}alphBCGAIL${gamma} --seed ${seed} --bcgail 1 --gailgamma $gamma --decay 1 --num-traj 1 & 15 | done 16 | python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps 
${total_timesteps} --use-linear-lr-decay --use-proper-time-limits --gail --model_name ${env}alphBCGAIL --seed ${seed} --bcgail 1 --gailgamma 1 --decay 10 --num-traj 1 17 | done 18 | -------------------------------------------------------------------------------- /baselinesbc.sh: -------------------------------------------------------------------------------- 1 | env=$1 2 | timesteps=${2:-3000000} 3 | if [ $# -le 0 ]; then 4 | echo './mujoco.sh ' 5 | exit 6 | fi 7 | 8 | # Run behavior cloning and pretrain 9 | for seed in {1,2,4,} 10 | do 11 | python bc.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 1000 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps 1000000 --use-linear-lr-decay --use-proper-time-limits --model_name ${env}BC --gail --save-interval 1 --seed ${seed} 12 | python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps ${timesteps} --use-linear-lr-decay --use-proper-time-limits --gail --load_model_name ${env}BC --model_name ${env}GAILpretrain --seed ${seed} 13 | done 14 | -------------------------------------------------------------------------------- /bc.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import glob 3 | import os 4 | import time 5 | from collections import deque 6 | 7 | import gym 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | 15 | from a2c_ppo_acktr import algo, utils 16 | from a2c_ppo_acktr.algo import gail 17 | from a2c_ppo_acktr.arguments import get_args 18 | from a2c_ppo_acktr.envs import make_vec_envs 19 | from a2c_ppo_acktr.model import Policy 20 | from a2c_ppo_acktr.storage import RolloutStorage 21 | from evaluation import evaluate 22 | 23 | 24 | def record_trajectories(): 25 | args = get_args() 26 | print(args) 27 | 28 | torch.manual_seed(args.seed) 29 | torch.cuda.manual_seed_all(args.seed) 30 | 31 | if args.cuda and torch.cuda.is_available() and args.cuda_deterministic: 32 | torch.backends.cudnn.benchmark = False 33 | torch.backends.cudnn.deterministic = True 34 | 35 | # Append the model name 36 | log_dir = os.path.expanduser(args.log_dir) 37 | log_dir = os.path.join(log_dir, args.model_name, str(args.seed)) 38 | 39 | eval_log_dir = log_dir + "_eval" 40 | utils.cleanup_log_dir(log_dir) 41 | utils.cleanup_log_dir(eval_log_dir) 42 | 43 | torch.set_num_threads(1) 44 | device = torch.device("cuda:0" if args.cuda else "cpu") 45 | 46 | envs = make_vec_envs(args.env_name, args.seed, 1, 47 | args.gamma, log_dir, device, True, training=False) 48 | 49 | # Take activation for carracing 50 | print("Loaded env...") 51 | activation = None 52 | if args.env_name == 'CarRacing-v0' and args.use_activation: 53 | activation = torch.tanh 54 | print(activation) 55 | 56 | actor_critic = Policy( 57 | envs.observation_space.shape, 58 | envs.action_space, 59 | base_kwargs={'recurrent': args.recurrent_policy, 'env': args.env_name}, 60 | activation=activation, 61 | ) 62 | actor_critic.to(device) 63 | 64 | # Load from previous model 65 | if args.load_model_name: 66 | loaddata = torch.load(os.path.join(args.save_dir, args.load_model_name, args.load_model_name + '_{}.pt'.format(args.seed))) 67 
| state = loaddata[0] 68 | try: 69 | obs_rms, ret_rms = loaddata[1:] 70 | # Feed it into the env 71 | envs.obs_rms = None 72 | envs.ret_rms = None 73 | except: 74 | print("Couldnt load obsrms") 75 | obs_rms = ret_rms = None 76 | try: 77 | actor_critic.load_state_dict(state) 78 | except: 79 | actor_critic = state 80 | else: 81 | raise NotImplementedError 82 | 83 | # Record trajectories 84 | actions = [] 85 | rewards = [] 86 | observations = [] 87 | episode_starts = [] 88 | 89 | for eps in range(args.num_episodes): 90 | obs = envs.reset() 91 | # Init variables for storing 92 | episode_starts.append(True) 93 | reward = 0 94 | while True: 95 | # Take action 96 | act = actor_critic.act(obs, None, None, None)[1] 97 | next_state, rew, done, info = envs.step(act) 98 | #print(obs.shape, act.shape, rew.shape, done) 99 | reward += rew 100 | # Add the current observation and act 101 | observations.append(obs.data.cpu().numpy()[0]) # [C, H, W] 102 | actions.append(act.data.cpu().numpy()[0]) # [A] 103 | rewards.append(rew[0, 0].data.cpu().numpy()) 104 | if done[0]: 105 | break 106 | episode_starts.append(False) 107 | obs = next_state + 0 108 | print("Total reward: {}".format(reward[0, 0].data.cpu().numpy())) 109 | 110 | # Save these values 111 | save_trajectories_images(observations, actions, rewards, episode_starts) 112 | 113 | 114 | def save_trajectories_images(obs, acts, rews, eps): 115 | args = get_args() 116 | obs_path = [] 117 | acts = np.array(acts) 118 | rews = np.array(rews) 119 | eps = np.array(eps) 120 | print(acts.shape, rews.shape, eps.shape) 121 | 122 | # Get image dir to save 123 | save_dir = os.path.join(args.save_dir, args.load_model_name, ) 124 | image_dir = os.path.join(save_dir, 'images') 125 | os.makedirs(save_dir, exist_ok=True) 126 | os.makedirs(image_dir, exist_ok=True) 127 | 128 | image_id = 0 129 | for ob in obs: 130 | # Scaled image from [0, 1] 131 | path = os.path.join(image_dir, str(image_id)) 132 | obimg = (ob * 255).astype(np.uint8).transpose(1, 2, 0) # [H, W, C] 133 | # Save image and record image path 134 | np.save(path, obimg) 135 | obs_path.append(path) 136 | image_id += 1 137 | 138 | expert_dict = { 139 | 'obs': obs_path, 140 | 'actions': acts, 141 | 'rewards': rews, 142 | 'episode_starts': eps, 143 | } 144 | 145 | torch.save(expert_dict, os.path.join(save_dir, 'expert_data.pkl')) 146 | print("Saved") 147 | 148 | def main(): 149 | args = get_args() 150 | 151 | # Record trajectories 152 | if args.record_trajectories: 153 | record_trajectories() 154 | return 155 | 156 | print(args) 157 | torch.manual_seed(args.seed) 158 | torch.cuda.manual_seed_all(args.seed) 159 | 160 | if args.cuda and torch.cuda.is_available() and args.cuda_deterministic: 161 | torch.backends.cudnn.benchmark = False 162 | torch.backends.cudnn.deterministic = True 163 | 164 | # Append the model name 165 | log_dir = os.path.expanduser(args.log_dir) 166 | log_dir = os.path.join(log_dir, args.model_name, str(args.seed)) 167 | 168 | eval_log_dir = log_dir + "_eval" 169 | utils.cleanup_log_dir(log_dir) 170 | utils.cleanup_log_dir(eval_log_dir) 171 | 172 | torch.set_num_threads(1) 173 | device = torch.device("cuda:0" if args.cuda else "cpu") 174 | 175 | envs = make_vec_envs(args.env_name, args.seed, args.num_processes, 176 | args.gamma, log_dir, device, False) 177 | 178 | # Take activation for carracing 179 | print("Loaded env...") 180 | activation = None 181 | if args.env_name == 'CarRacing-v0' and args.use_activation: 182 | activation = torch.tanh 183 | print(activation) 184 | 185 | actor_critic = Policy( 
186 | envs.observation_space.shape, 187 | envs.action_space, 188 | base_kwargs={'recurrent': args.recurrent_policy, 'env': args.env_name}, 189 | activation=activation 190 | ) 191 | actor_critic.to(device) 192 | # Load from previous model 193 | if args.load_model_name: 194 | state = torch.load(os.path.join(args.save_dir, args.load_model_name, args.load_model_name + '_{}.pt'.format(args.seed)))[0] 195 | try: 196 | actor_critic.load_state_dict(state) 197 | except: 198 | actor_critic = state 199 | 200 | if args.algo == 'a2c': 201 | agent = algo.A2C_ACKTR( 202 | actor_critic, 203 | args.value_loss_coef, 204 | args.entropy_coef, 205 | lr=args.lr, 206 | eps=args.eps, 207 | alpha=args.alpha, 208 | max_grad_norm=args.max_grad_norm) 209 | elif args.algo == 'ppo': 210 | agent = algo.PPO( 211 | actor_critic, 212 | args.clip_param, 213 | args.ppo_epoch, 214 | args.num_mini_batch, 215 | args.value_loss_coef, 216 | args.entropy_coef, 217 | lr=args.lr, 218 | eps=args.eps, 219 | max_grad_norm=args.max_grad_norm) 220 | elif args.algo == 'acktr': 221 | agent = algo.A2C_ACKTR( 222 | actor_critic, args.value_loss_coef, args.entropy_coef, acktr=True) 223 | 224 | if args.gail: 225 | if len(envs.observation_space.shape) == 1: 226 | discr = gail.Discriminator( 227 | envs.observation_space.shape[0] + envs.action_space.shape[0], 100, 228 | device) 229 | file_name = os.path.join( 230 | args.gail_experts_dir, "trajs_{}.pt".format( 231 | args.env_name.split('-')[0].lower())) 232 | 233 | expert_dataset = gail.ExpertDataset( 234 | file_name, num_trajectories=3, subsample_frequency=1) 235 | expert_dataset_test = gail.ExpertDataset( 236 | file_name, num_trajectories=1, start=3, subsample_frequency=1) 237 | drop_last = len(expert_dataset) > args.gail_batch_size 238 | gail_train_loader = torch.utils.data.DataLoader( 239 | dataset=expert_dataset, 240 | batch_size=args.gail_batch_size, 241 | shuffle=True, 242 | drop_last=drop_last) 243 | gail_test_loader = torch.utils.data.DataLoader( 244 | dataset=expert_dataset_test, 245 | batch_size=args.gail_batch_size, 246 | shuffle=False, 247 | drop_last=False) 248 | print(len(expert_dataset), len(expert_dataset_test)) 249 | else: 250 | # env observation shape is 3 => its an image 251 | assert len(envs.observation_space.shape) == 3 252 | discr = gail.CNNDiscriminator( 253 | envs.observation_space.shape, envs.action_space, 100, 254 | device) 255 | file_name = os.path.join( 256 | args.gail_experts_dir, 'expert_data.pkl') 257 | 258 | expert_dataset = gail.ExpertImageDataset(file_name, train=True) 259 | test_dataset = gail.ExpertImageDataset(file_name, train=False) 260 | gail_train_loader = torch.utils.data.DataLoader( 261 | dataset=expert_dataset, 262 | batch_size=args.gail_batch_size, 263 | shuffle=True, 264 | drop_last = len(expert_dataset) > args.gail_batch_size, 265 | ) 266 | gail_test_loader = torch.utils.data.DataLoader( 267 | dataset=test_dataset, 268 | batch_size=args.gail_batch_size, 269 | shuffle=False, 270 | drop_last = len(test_dataset) > args.gail_batch_size, 271 | ) 272 | print('Dataloader size', len(gail_train_loader)) 273 | 274 | rollouts = RolloutStorage(args.num_steps, args.num_processes, 275 | envs.observation_space.shape, envs.action_space, 276 | actor_critic.recurrent_hidden_state_size) 277 | 278 | obs = envs.reset() 279 | rollouts.obs[0].copy_(obs) 280 | rollouts.to(device) 281 | 282 | episode_rewards = deque(maxlen=10) 283 | start = time.time() 284 | #num_updates = int( 285 | #args.num_env_steps) // args.num_steps // args.num_processes 286 | num_updates = 
args.num_steps 287 | print(num_updates) 288 | 289 | # count the number of times validation loss increases 290 | val_loss_increase = 0 291 | prev_val_action = np.inf 292 | best_val_loss = np.inf 293 | 294 | for j in range(num_updates): 295 | if args.use_linear_lr_decay: 296 | # decrease learning rate linearly 297 | utils.update_linear_schedule( 298 | agent.optimizer, j, num_updates, 299 | agent.optimizer.lr if args.algo == "acktr" else args.lr) 300 | 301 | 302 | for step in range(args.num_steps): 303 | # Sample actions 304 | with torch.no_grad(): 305 | value, action, action_log_prob, recurrent_hidden_states = actor_critic.act( 306 | rollouts.obs[step], rollouts.recurrent_hidden_states[step], 307 | rollouts.masks[step]) 308 | 309 | # Observe reward and next obs 310 | obs, reward, done, infos = envs.step(action) 311 | for info in infos: 312 | if 'episode' in info.keys(): 313 | episode_rewards.append(info['episode']['r']) 314 | 315 | # If done then clean the history of observations. 316 | masks = torch.FloatTensor( 317 | [[0.0] if done_ else [1.0] for done_ in done]) 318 | bad_masks = torch.FloatTensor( 319 | [[0.0] if 'bad_transition' in info.keys() else [1.0] 320 | for info in infos]) 321 | rollouts.insert(obs, recurrent_hidden_states, action, 322 | action_log_prob, value, reward, masks, bad_masks) 323 | 324 | with torch.no_grad(): 325 | next_value = actor_critic.get_value( 326 | rollouts.obs[-1], rollouts.recurrent_hidden_states[-1], 327 | rollouts.masks[-1]).detach() 328 | 329 | if args.gail: 330 | if j >= 10: 331 | try: 332 | envs.venv.eval() 333 | except: 334 | pass 335 | 336 | gail_epoch = args.gail_epoch 337 | #if j < 10: 338 | #gail_epoch = 100 # Warm up 339 | for _ in range(gail_epoch): 340 | #discr.update(gail_train_loader, rollouts, 341 | #None) 342 | pass 343 | 344 | for step in range(args.num_steps): 345 | rollouts.rewards[step] = discr.predict_reward( 346 | rollouts.obs[step], rollouts.actions[step], args.gamma, 347 | rollouts.masks[step]) 348 | 349 | rollouts.compute_returns(next_value, args.use_gae, args.gamma, 350 | args.gae_lambda, args.use_proper_time_limits) 351 | 352 | #value_loss, action_loss, dist_entropy = agent.update(rollouts) 353 | value_loss = 0 354 | dist_entropy = 0 355 | for data in gail_train_loader: 356 | expert_states, expert_actions = data 357 | expert_states = Variable(expert_states).to(device) 358 | expert_actions = Variable(expert_actions).to(device) 359 | loss = agent.update_bc(expert_states, expert_actions) 360 | action_loss = loss.data.cpu().numpy() 361 | print("Epoch: {}, Loss: {}".format(j, action_loss)) 362 | 363 | with torch.no_grad(): 364 | cnt = 0 365 | val_action_loss = 0 366 | for data in gail_test_loader: 367 | expert_states, expert_actions = data 368 | expert_states = Variable(expert_states).to(device) 369 | expert_actions = Variable(expert_actions).to(device) 370 | loss = agent.get_action_loss(expert_states, expert_actions) 371 | val_action_loss += loss.data.cpu().numpy() 372 | cnt += 1 373 | val_action_loss /= cnt 374 | print("Val Loss: {}".format(val_action_loss)) 375 | 376 | #rollouts.after_update() 377 | 378 | # save for every interval-th episode or for the last epoch 379 | if (j % args.save_interval == 0 380 | or j == num_updates - 1) and args.save_dir != "": 381 | 382 | if val_action_loss < best_val_loss: 383 | val_loss_increase = 0 384 | best_val_loss = val_action_loss 385 | save_path = os.path.join(args.save_dir, args.model_name) 386 | try: 387 | os.makedirs(save_path) 388 | except OSError: 389 | pass 390 | 391 | torch.save([ 392 | 
actor_critic.state_dict(), 393 | getattr(utils.get_vec_normalize(envs), 'ob_rms', None), 394 | getattr(utils.get_vec_normalize(envs), 'ret_rms', None) 395 | ], os.path.join(save_path, args.model_name + "_{}.pt".format(args.seed))) 396 | elif val_action_loss > prev_val_action: 397 | val_loss_increase += 1 398 | if val_loss_increase == 10: 399 | print("Val loss increasing too much, breaking here...") 400 | break 401 | elif val_action_loss < prev_val_action: 402 | val_loss_increase = 0 403 | 404 | # Update prev val action 405 | prev_val_action = val_action_loss 406 | 407 | # log interval 408 | if j % args.log_interval == 0 and len(episode_rewards) > 1: 409 | total_num_steps = (j + 1) * args.num_processes * args.num_steps 410 | end = time.time() 411 | print( 412 | "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n" 413 | .format(j, total_num_steps, 414 | int(total_num_steps / (end - start)), 415 | len(episode_rewards), np.mean(episode_rewards), 416 | np.median(episode_rewards), np.min(episode_rewards), 417 | np.max(episode_rewards), dist_entropy, value_loss, 418 | action_loss)) 419 | 420 | if (args.eval_interval is not None and len(episode_rewards) > 1 421 | and j % args.eval_interval == 0): 422 | ob_rms = utils.get_vec_normalize(envs).ob_rms 423 | evaluate(actor_critic, ob_rms, args.env_name, args.seed, 424 | args.num_processes, eval_log_dir, device) 425 | 426 | 427 | if __name__ == "__main__": 428 | main() 429 | -------------------------------------------------------------------------------- /bcgail.sh: -------------------------------------------------------------------------------- 1 | python main.py --env-name PongNoFrameskip-v4 --algo ppo --gail --gail-experts-dir /serverdata/rohit/BCGAIL/PongPPO/ --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 --model_name PongBCGAIL --gail-batch-size 32 --bcgail 1 --gailgamma 1 --decay 100 --num-env-steps 500000 --seed 1 2 | -------------------------------------------------------------------------------- /bconly.sh: -------------------------------------------------------------------------------- 1 | env=$1 2 | total_timesteps=${2:-3000000} 3 | if [ $# -le 0 ]; then 4 | echo './mujoco.sh ' 5 | exit 6 | fi 7 | 8 | 9 | for seed in {1,2,4,} 10 | do 11 | #python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps ${total_timesteps} --use-linear-lr-decay --use-proper-time-limits --gail --model_name ${env}BCnoGAIL --seed ${seed} --bcgail 1 --gailgamma 0.5 --decay 1 --learn 0 & 12 | python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps ${total_timesteps} --use-linear-lr-decay --use-proper-time-limits --gail --model_name ${env}noGAIL --seed ${seed} --learn 0 & 13 | done 14 | -------------------------------------------------------------------------------- /calclength.sh: -------------------------------------------------------------------------------- 1 | for env in {Ant,Hopper,HalfCheetah,Reacher,Walker2d} 2 | do 3 | #for method in {BC,BCnoGAIL,GAIL,noGAIL} 4 | for method in {BCGAIL,} 5 | do 6 | 
python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --record_trajectories 1 --model_name ${env}${method} --load_model_name ${env}${method} --savelength 1 7 | done 8 | done 9 | -------------------------------------------------------------------------------- /enjoy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | # workaround to unpickle olf model files 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from a2c_ppo_acktr.envs import VecPyTorch, make_vec_envs 10 | from a2c_ppo_acktr.utils import get_render_func, get_vec_normalize 11 | 12 | sys.path.append('a2c_ppo_acktr') 13 | 14 | parser = argparse.ArgumentParser(description='RL') 15 | parser.add_argument( 16 | '--seed', type=int, default=1, help='random seed (default: 1)') 17 | parser.add_argument( 18 | '--log-interval', 19 | type=int, 20 | default=10, 21 | help='log interval, one log per n updates (default: 10)') 22 | parser.add_argument( 23 | '--env-name', 24 | default='PongNoFrameskip-v4', 25 | help='environment to train on (default: PongNoFrameskip-v4)') 26 | parser.add_argument( 27 | '--load-dir', 28 | default='./trained_models/', 29 | help='directory to save agent logs (default: ./trained_models/)') 30 | parser.add_argument( 31 | '--non-det', 32 | action='store_true', 33 | default=False, 34 | help='whether to use a non-deterministic policy') 35 | args = parser.parse_args() 36 | 37 | args.det = not args.non_det 38 | 39 | env = make_vec_envs( 40 | args.env_name, 41 | args.seed + 1000, 42 | 1, 43 | None, 44 | None, 45 | device='cpu', 46 | allow_early_resets=False) 47 | 48 | # Get a render function 49 | render_func = get_render_func(env) 50 | 51 | # We need to use the same statistics for normalization as used in training 52 | actor_critic, ob_rms = \ 53 | torch.load(os.path.join(args.load_dir, args.env_name + ".pt")) 54 | 55 | vec_norm = get_vec_normalize(env) 56 | if vec_norm is not None: 57 | vec_norm.eval() 58 | vec_norm.ob_rms = ob_rms 59 | 60 | recurrent_hidden_states = torch.zeros(1, 61 | actor_critic.recurrent_hidden_state_size) 62 | masks = torch.zeros(1, 1) 63 | 64 | obs = env.reset() 65 | 66 | if render_func is not None: 67 | render_func('human') 68 | 69 | if args.env_name.find('Bullet') > -1: 70 | import pybullet as p 71 | 72 | torsoId = -1 73 | for i in range(p.getNumBodies()): 74 | if (p.getBodyInfo(i)[0].decode() == "torso"): 75 | torsoId = i 76 | 77 | while True: 78 | with torch.no_grad(): 79 | value, action, _, recurrent_hidden_states = actor_critic.act( 80 | obs, recurrent_hidden_states, masks, deterministic=args.det) 81 | 82 | # Obser reward and next obs 83 | obs, reward, done, _ = env.step(action) 84 | 85 | masks.fill_(0.0 if done else 1.0) 86 | 87 | if args.env_name.find('Bullet') > -1: 88 | if torsoId > -1: 89 | distance = 5 90 | yaw = 0 91 | humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId) 92 | p.resetDebugVisualizerCamera(distance, yaw, -20, humanPos) 93 | 94 | if render_func is not None: 95 | render_func('human') 96 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from a2c_ppo_acktr import utils 5 | from a2c_ppo_acktr.envs import make_vec_envs 6 | 7 | 8 | def evaluate(actor_critic, ob_rms, env_name, 
seed, num_processes, eval_log_dir, 9 | device): 10 | eval_envs = make_vec_envs(env_name, seed + num_processes, num_processes, 11 | None, eval_log_dir, device, True) 12 | 13 | vec_norm = utils.get_vec_normalize(eval_envs) 14 | if vec_norm is not None: 15 | vec_norm.eval() 16 | vec_norm.ob_rms = ob_rms 17 | 18 | eval_episode_rewards = [] 19 | 20 | obs = eval_envs.reset() 21 | eval_recurrent_hidden_states = torch.zeros( 22 | num_processes, actor_critic.recurrent_hidden_state_size, device=device) 23 | eval_masks = torch.zeros(num_processes, 1, device=device) 24 | 25 | while len(eval_episode_rewards) < 10: 26 | with torch.no_grad(): 27 | _, action, _, eval_recurrent_hidden_states = actor_critic.act( 28 | obs, 29 | eval_recurrent_hidden_states, 30 | eval_masks, 31 | deterministic=True) 32 | 33 | # Obser reward and next obs 34 | obs, _, done, infos = eval_envs.step(action) 35 | 36 | eval_masks = torch.tensor( 37 | [[0.0] if done_ else [1.0] for done_ in done], 38 | dtype=torch.float32, 39 | device=device) 40 | 41 | for info in infos: 42 | if 'episode' in info.keys(): 43 | eval_episode_rewards.append(info['episode']['r']) 44 | 45 | eval_envs.close() 46 | 47 | print(" Evaluation using {} episodes: mean reward {:.5f}\n".format( 48 | len(eval_episode_rewards), np.mean(eval_episode_rewards))) 49 | -------------------------------------------------------------------------------- /gail.sh: -------------------------------------------------------------------------------- 1 | ## GAIL baseline 2 | CUDA_VISIBLE_DEVICES=6,7 python main.py --env-name PongNoFrameskip-v4 --algo ppo --gail --gail-experts-dir /serverdata/rohit/BCGAIL/PongPPO/ --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 --model_name PongGAIL --gail-batch-size 32 --seed 1 3 | -------------------------------------------------------------------------------- /gail_experts/README.md: -------------------------------------------------------------------------------- 1 | ## Data 2 | 3 | Download from 4 | https://drive.google.com/open?id=1Ipu5k99nwewVDG1yFetUxqtwVlgBg5su 5 | 6 | and store in this folder. 
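For reference, the `.pt` file produced by the conversion step below is a dictionary of tensors keyed by `states`, `actions`, `rewards` and `lengths` (see `convert_to_pytorch.py`). A quick way to sanity-check a converted file, assuming the HalfCheetah output name from the example command:

```python
import torch

# Inspect a converted expert file (adjust the path to your own output).
data = torch.load('trajs_halfcheetah.pt')
for key in ('states', 'actions', 'rewards', 'lengths'):
    print(key, tuple(data[key].shape), data[key].dtype)
# states/actions are (num_trajectories, horizon, obs_dim/act_dim) float tensors;
# rewards is (num_trajectories, horizon) and lengths is (num_trajectories,).
```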
7 | 8 | ## Convert to pytorch 9 | 10 | ```bash 11 | python convert_to_pytorch.py --h5-file trajs_halfcheetah.h5 12 | ``` 13 | 14 | ## Run 15 | 16 | ```bash 17 | python main.py --env-name "HalfCheetah-v2" --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps 10000000 --use-linear-lr-decay --use-proper-time-limits --gail 18 | ``` 19 | -------------------------------------------------------------------------------- /gail_experts/convert_to_pytorch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import h5py 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser( 12 | 'Converts expert trajectories from h5 to pt format.') 13 | parser.add_argument( 14 | '--h5-file', 15 | default='trajs_halfcheetah.h5', 16 | help='input h5 file', 17 | type=str) 18 | parser.add_argument( 19 | '--pt-file', 20 | default=None, 21 | help='output pt file, by default replaces file extension with pt', 22 | type=str) 23 | args = parser.parse_args() 24 | 25 | if args.pt_file is None: 26 | args.pt_file = os.path.splitext(args.h5_file)[0] + '.pt' 27 | 28 | with h5py.File(args.h5_file, 'r') as f: 29 | dataset_size = f['obs_B_T_Do'].shape[0] # full dataset size 30 | 31 | states = f['obs_B_T_Do'][:dataset_size, ...][...] 32 | actions = f['a_B_T_Da'][:dataset_size, ...][...] 33 | rewards = f['r_B_T'][:dataset_size, ...][...] 34 | lens = f['len_B'][:dataset_size, ...][...] 35 | 36 | states = torch.from_numpy(states).float() 37 | actions = torch.from_numpy(actions).float() 38 | rewards = torch.from_numpy(rewards).float() 39 | lens = torch.from_numpy(lens).long() 40 | 41 | data = { 42 | 'states': states, 43 | 'actions': actions, 44 | 'rewards': rewards, 45 | 'lengths': lens 46 | } 47 | 48 | torch.save(data, args.pt_file) 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /generate_tmux_yaml.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import yaml 4 | 5 | parser = argparse.ArgumentParser(description='Process some integers.') 6 | parser.add_argument( 7 | '--num-seeds', 8 | type=int, 9 | default=4, 10 | help='number of random seeds to generate') 11 | parser.add_argument( 12 | '--env-names', 13 | default="PongNoFrameskip-v4", 14 | help='environment name separated by semicolons') 15 | args = parser.parse_args() 16 | 17 | ppo_mujoco_template = "python main.py --env-name {0} --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/{1}/{1}-{2} --seed {2} --use-proper-time-limits" 18 | 19 | ppo_atari_template = "env CUDA_VISIBLE_DEVICES={2} python main.py --env-name {0} --algo ppo --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 --log-dir /tmp/gym/{1}/{1}-{2} --seed {2}" 20 | 21 | template = ppo_atari_template 22 | 23 | config = {"session_name": "run-all", "windows": []} 24 | 25 | for i in range(args.num_seeds): 26 | panes_list = [] 27 | for env_name in args.env_names.split(';'): 28 | 
panes_list.append( 29 | template.format(env_name, 30 | env_name.split('-')[0].lower(), i)) 31 | 32 | config["windows"].append({ 33 | "window_name": "seed-{}".format(i), 34 | "panes": panes_list 35 | }) 36 | 37 | yaml.dump(config, open("run_all.yaml", "w"), default_flow_style=False) 38 | -------------------------------------------------------------------------------- /imgs/a2c_beamrider.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/a2c_beamrider.png -------------------------------------------------------------------------------- /imgs/a2c_breakout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/a2c_breakout.png -------------------------------------------------------------------------------- /imgs/a2c_qbert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/a2c_qbert.png -------------------------------------------------------------------------------- /imgs/a2c_seaquest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/a2c_seaquest.png -------------------------------------------------------------------------------- /imgs/acktr_beamrider.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/acktr_beamrider.png -------------------------------------------------------------------------------- /imgs/acktr_breakout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/acktr_breakout.png -------------------------------------------------------------------------------- /imgs/acktr_qbert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/acktr_qbert.png -------------------------------------------------------------------------------- /imgs/acktr_seaquest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/acktr_seaquest.png -------------------------------------------------------------------------------- /imgs/ppo_halfcheetah.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/ppo_halfcheetah.png -------------------------------------------------------------------------------- /imgs/ppo_hopper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/ppo_hopper.png -------------------------------------------------------------------------------- /imgs/ppo_reacher.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/ppo_reacher.png -------------------------------------------------------------------------------- /imgs/ppo_walker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rohitrango/BC-regularized-GAIL/08e18fb407587e0383cf53dab5b75a33738c9ce0/imgs/ppo_walker.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import glob 3 | import os 4 | import time 5 | from collections import deque 6 | 7 | import gym 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | 14 | from a2c_ppo_acktr import algo, utils 15 | from a2c_ppo_acktr.algo import gail 16 | from a2c_ppo_acktr.arguments import get_args 17 | from a2c_ppo_acktr.envs import make_vec_envs 18 | from a2c_ppo_acktr.model import Policy 19 | from a2c_ppo_acktr.storage import RolloutStorage 20 | from evaluation import evaluate 21 | 22 | 23 | def record_trajectories(): 24 | args = get_args() 25 | print(args) 26 | 27 | torch.manual_seed(args.seed) 28 | torch.cuda.manual_seed_all(args.seed) 29 | 30 | if args.cuda and torch.cuda.is_available() and args.cuda_deterministic: 31 | torch.backends.cudnn.benchmark = False 32 | torch.backends.cudnn.deterministic = True 33 | 34 | # Append the model name 35 | log_dir = os.path.expanduser(args.log_dir) 36 | log_dir = os.path.join(log_dir, args.model_name, str(args.seed)) 37 | 38 | eval_log_dir = log_dir + "_eval" 39 | utils.cleanup_log_dir(log_dir) 40 | utils.cleanup_log_dir(eval_log_dir) 41 | 42 | torch.set_num_threads(1) 43 | device = torch.device("cuda:0" if args.cuda else "cpu") 44 | 45 | envs = make_vec_envs(args.env_name, args.seed, 1, 46 | args.gamma, log_dir, device, True, training=False, red=True) 47 | 48 | # Take activation for carracing 49 | print("Loaded env...") 50 | activation = None 51 | if args.env_name == 'CarRacing-v0' and args.use_activation: 52 | activation = torch.tanh 53 | print(activation) 54 | 55 | actor_critic = Policy( 56 | envs.observation_space.shape, 57 | envs.action_space, 58 | base_kwargs={'recurrent': args.recurrent_policy, 'env': args.env_name}, 59 | activation=activation, 60 | ) 61 | actor_critic.to(device) 62 | 63 | # Load from previous model 64 | if args.load_model_name and args.model_name != 'random': 65 | loaddata = torch.load(os.path.join(args.save_dir, args.load_model_name, args.load_model_name + '_{}.pt'.format(args.seed))) 66 | state = loaddata[0] 67 | try: 68 | #print(envs.venv, loaddata[2]) 69 | if not args.load_model_name.endswith('BC'): 70 | obs_rms, ret_rms = loaddata[1:] 71 | envs.venv.ob_rms = obs_rms 72 | envs.venv.ret_rms = None 73 | #envs.venv.ret_rms = ret_rms 74 | print(obs_rms) 75 | print("Loaded obs rms") 76 | else: 77 | envs.venv.ob_rms = None 78 | envs.venv.ret_rms = None 79 | print("Its BC, no normalization from env") 80 | except: 81 | print("Couldnt load obsrms") 82 | obs_rms = ret_rms = None 83 | try: 84 | actor_critic.load_state_dict(state) 85 | except: 86 | actor_critic = state 87 | elif args.load_model_name and args.model_name == 'random': 88 | print("Using random policy...") 89 | envs.venv.ret_rms = None 90 | envs.venv.ob_rms = None 91 | else: 92 | raise 
NotImplementedError 93 | 94 | # Record trajectories 95 | actions = [] 96 | rewards = [] 97 | observations = [] 98 | episode_starts = [] 99 | total_rewards = [] 100 | 101 | last_length = 0 102 | 103 | for eps in range(args.num_episodes): 104 | obs = envs.reset() 105 | # Init variables for storing 106 | episode_starts.append(True) 107 | reward = 0 108 | while True: 109 | # Take action 110 | act = actor_critic.act(obs, None, None, None)[1] 111 | next_state, rew, done, info = envs.step(act) 112 | #print(obs.shape, act.shape, rew.shape, done) 113 | reward += rew 114 | # Add the current observation and act 115 | observations.append(obs.data.cpu().numpy()[0]) # [C, H, W] 116 | actions.append(act.data.cpu().numpy()[0]) # [A] 117 | rewards.append(rew[0, 0].data.cpu().numpy()) 118 | if done[0]: 119 | break 120 | episode_starts.append(False) 121 | obs = next_state + 0 122 | print("Total reward: {}".format(reward[0, 0].data.cpu().numpy())) 123 | print("Total length: {}".format(len(observations) - last_length)) 124 | last_length = len(observations) 125 | total_rewards.append(reward[0, 0].data.cpu().numpy()) 126 | 127 | # Save these values 128 | ''' 129 | if len(envs.observation_space.shape) == 3: 130 | save_trajectories_images(observations, actions, rewards, episode_starts) 131 | else: 132 | save_trajectories(observations, actions, rewards, episode_starts) 133 | ''' 134 | pathname = args.load_model_name if args.model_name != 'random' else 'random' 135 | prefix = 'length' if args.savelength else '' 136 | with open(os.path.join(args.save_dir, args.load_model_name, pathname + '_{}{}.txt'.format(prefix, args.seed)), 'wt') as f: 137 | if not args.savelength: 138 | for rew in total_rewards: 139 | f.write('{}\n'.format(rew)) 140 | else: 141 | avg_length = len(observations) / args.num_episodes 142 | f.write("{}\n".format(avg_length)) 143 | 144 | 145 | def save_trajectories(obs, acts, rews, eps): 146 | raise NotImplementedError 147 | 148 | 149 | def save_trajectories_images(obs, acts, rews, eps): 150 | # Save images only 151 | args = get_args() 152 | obs_path = [] 153 | acts = np.array(acts) 154 | rews = np.array(rews) 155 | eps = np.array(eps) 156 | print(acts.shape, rews.shape, eps.shape) 157 | 158 | # Get image dir to save 159 | save_dir = os.path.join(args.save_dir, args.load_model_name, ) 160 | image_dir = os.path.join(save_dir, 'images') 161 | os.makedirs(save_dir, exist_ok=True) 162 | os.makedirs(image_dir, exist_ok=True) 163 | 164 | image_id = 0 165 | for ob in obs: 166 | # Scaled image from [0, 1] 167 | path = os.path.join(image_dir, str(image_id)) 168 | obimg = (ob * 255).astype(np.uint8).transpose(1, 2, 0) # [H, W, C] 169 | # Save image and record image path 170 | np.save(path, obimg) 171 | obs_path.append(path) 172 | image_id += 1 173 | 174 | expert_dict = { 175 | 'obs': obs_path, 176 | 'actions': acts, 177 | 'rewards': rews, 178 | 'episode_starts': eps, 179 | } 180 | 181 | torch.save(expert_dict, os.path.join(save_dir, 'expert_data.pkl')) 182 | print("Saved") 183 | 184 | def main(): 185 | args = get_args() 186 | 187 | # Record trajectories 188 | if args.record_trajectories: 189 | record_trajectories() 190 | return 191 | 192 | print(args) 193 | torch.manual_seed(args.seed) 194 | torch.cuda.manual_seed_all(args.seed) 195 | 196 | if args.cuda and torch.cuda.is_available() and args.cuda_deterministic: 197 | torch.backends.cudnn.benchmark = False 198 | torch.backends.cudnn.deterministic = True 199 | 200 | # Append the model name 201 | log_dir = os.path.expanduser(args.log_dir) 202 | log_dir = 
os.path.join(log_dir, args.model_name, str(args.seed)) 203 | 204 | eval_log_dir = log_dir + "_eval" 205 | utils.cleanup_log_dir(log_dir) 206 | utils.cleanup_log_dir(eval_log_dir) 207 | 208 | torch.set_num_threads(1) 209 | device = torch.device("cuda:0" if args.cuda else "cpu") 210 | 211 | envs = make_vec_envs(args.env_name, args.seed, args.num_processes, 212 | args.gamma, log_dir, device, False) 213 | obs_shape = len(envs.observation_space.shape) 214 | 215 | # Take activation for carracing 216 | print("Loaded env...") 217 | activation = None 218 | if args.env_name == 'CarRacing-v0' and args.use_activation: 219 | activation = torch.tanh 220 | print(activation) 221 | 222 | actor_critic = Policy( 223 | envs.observation_space.shape, 224 | envs.action_space, 225 | base_kwargs={'recurrent': args.recurrent_policy, 'env': args.env_name}, 226 | activation=activation 227 | ) 228 | actor_critic.to(device) 229 | # Load from previous model 230 | if args.load_model_name: 231 | state = torch.load(os.path.join(args.save_dir, args.load_model_name, args.load_model_name + '_{}.pt'.format(args.seed)))[0] 232 | try: 233 | actor_critic.load_state_dict(state) 234 | except: 235 | actor_critic = state 236 | 237 | # If BCGAIL, then decay factor and gamma should be float 238 | if args.bcgail: 239 | assert type(args.decay) == float 240 | assert type(args.gailgamma) == float 241 | if args.decay < 0: 242 | args.decay = 1 243 | elif args.decay > 1: 244 | args.decay = 0.5**(1./args.decay) 245 | 246 | print('Gamma: {}, decay: {}'.format(args.gailgamma, args.decay)) 247 | print('BCGAIL used') 248 | 249 | if args.algo == 'a2c': 250 | agent = algo.A2C_ACKTR( 251 | actor_critic, 252 | args.value_loss_coef, 253 | args.entropy_coef, 254 | lr=args.lr, 255 | eps=args.eps, 256 | alpha=args.alpha, 257 | max_grad_norm=args.max_grad_norm) 258 | elif args.algo == 'ppo': 259 | agent = algo.PPO( 260 | actor_critic, 261 | args.clip_param, 262 | args.ppo_epoch, 263 | args.num_mini_batch, 264 | args.value_loss_coef, 265 | args.entropy_coef, 266 | lr=args.lr, 267 | eps=args.eps, 268 | gamma=args.gailgamma, 269 | decay=args.decay, 270 | act_space=envs.action_space, 271 | max_grad_norm=args.max_grad_norm) 272 | elif args.algo == 'acktr': 273 | agent = algo.A2C_ACKTR( 274 | actor_critic, args.value_loss_coef, args.entropy_coef, acktr=True) 275 | 276 | if args.gail: 277 | if len(envs.observation_space.shape) == 1: 278 | 279 | # Load RED here 280 | red = None 281 | if args.red: 282 | red = gail.RED(envs.observation_space.shape[0] + envs.action_space.shape[0], 283 | 100, device, args.redsigma, args.rediters) 284 | 285 | discr = gail.Discriminator( 286 | envs.observation_space.shape[0] + envs.action_space.shape[0], 100, 287 | device, red=red, sail=args.sail, learn=args.learn) 288 | file_name = os.path.join( 289 | args.gail_experts_dir, "trajs_{}.pt".format( 290 | args.env_name.split('-')[0].lower())) 291 | 292 | expert_dataset = gail.ExpertDataset( 293 | file_name, num_trajectories=args.num_traj, subsample_frequency=1) 294 | args.gail_batch_size = min(args.gail_batch_size, len(expert_dataset)) 295 | drop_last = len(expert_dataset) > args.gail_batch_size 296 | gail_train_loader = torch.utils.data.DataLoader( 297 | dataset=expert_dataset, 298 | batch_size=args.gail_batch_size, 299 | shuffle=True, 300 | drop_last=drop_last) 301 | print("Data loader size", len(expert_dataset)) 302 | else: 303 | # env observation shape is 3 => its an image 304 | assert len(envs.observation_space.shape) == 3 305 | discr = gail.CNNDiscriminator( 306 | 
envs.observation_space.shape, envs.action_space, 100, 307 | device) 308 | file_name = os.path.join( 309 | args.gail_experts_dir, 'expert_data.pkl') 310 | 311 | expert_dataset = gail.ExpertImageDataset(file_name, act=envs.action_space) 312 | gail_train_loader = torch.utils.data.DataLoader( 313 | dataset=expert_dataset, 314 | batch_size=args.gail_batch_size, 315 | shuffle=True, 316 | drop_last = len(expert_dataset) > args.gail_batch_size, 317 | ) 318 | print('Dataloader size', len(gail_train_loader)) 319 | 320 | rollouts = RolloutStorage(args.num_steps, args.num_processes, 321 | envs.observation_space.shape, envs.action_space, 322 | actor_critic.recurrent_hidden_state_size) 323 | 324 | obs = envs.reset() 325 | rollouts.obs[0].copy_(obs) 326 | rollouts.to(device) 327 | 328 | episode_rewards = deque(maxlen=10) 329 | start = time.time() 330 | num_updates = int( 331 | args.num_env_steps) // args.num_steps // args.num_processes 332 | print(num_updates) 333 | for j in range(num_updates): 334 | if args.use_linear_lr_decay: 335 | # decrease learning rate linearly 336 | utils.update_linear_schedule( 337 | agent.optimizer, j, num_updates, 338 | agent.optimizer.lr if args.algo == "acktr" else args.lr) 339 | 340 | for step in range(args.num_steps): 341 | # Sample actions 342 | with torch.no_grad(): 343 | value, action, action_log_prob, recurrent_hidden_states = actor_critic.act( 344 | rollouts.obs[step], rollouts.recurrent_hidden_states[step], 345 | rollouts.masks[step]) 346 | 347 | # Observe reward and next obs 348 | obs, reward, done, infos = envs.step(action) 349 | for info in infos: 350 | if 'episode' in info.keys(): 351 | episode_rewards.append(info['episode']['r']) 352 | 353 | # If done then clean the history of observations. 354 | masks = torch.FloatTensor( 355 | [[0.0] if done_ else [1.0] for done_ in done]) 356 | bad_masks = torch.FloatTensor( 357 | [[0.0] if 'bad_transition' in info.keys() else [1.0] 358 | for info in infos]) 359 | rollouts.insert(obs, recurrent_hidden_states, action, 360 | action_log_prob, value, reward, masks, bad_masks) 361 | 362 | with torch.no_grad(): 363 | next_value = actor_critic.get_value( 364 | rollouts.obs[-1], rollouts.recurrent_hidden_states[-1], 365 | rollouts.masks[-1]).detach() 366 | 367 | if args.gail: 368 | if j >= 10: 369 | try: 370 | envs.venv.eval() 371 | except: 372 | pass 373 | 374 | gail_epoch = args.gail_epoch 375 | if j < 10 and obs_shape == 1: 376 | gail_epoch = 100 # Warm up 377 | for _ in range(gail_epoch): 378 | if obs_shape == 1: 379 | discr.update(gail_train_loader, rollouts, 380 | utils.get_vec_normalize(envs)._obfilt) 381 | else: 382 | discr.update(gail_train_loader, rollouts, 383 | None) 384 | if obs_shape == 3: 385 | obfilt = None 386 | else: 387 | obfilt = utils.get_vec_normalize(envs)._rev_obfilt 388 | 389 | for step in range(args.num_steps): 390 | rollouts.rewards[step] = discr.predict_reward( 391 | rollouts.obs[step], rollouts.actions[step], args.gamma, 392 | rollouts.masks[step], obfilt) # The reverse function is passed down for RED to receive unnormalized obs which it is trained on 393 | 394 | rollouts.compute_returns(next_value, args.use_gae, args.gamma, 395 | args.gae_lambda, args.use_proper_time_limits) 396 | 397 | if args.bcgail: 398 | if obs_shape == 3: 399 | obfilt = None 400 | else: 401 | obfilt = utils.get_vec_normalize(envs)._obfilt 402 | value_loss, action_loss, dist_entropy = agent.update(rollouts, gail_train_loader, obfilt) 403 | else: 404 | value_loss, action_loss, dist_entropy = agent.update(rollouts) 405 | 406 | 
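# With --bcgail, agent.update() above is given the expert gail_train_loader and the observation
# filter in addition to the on-policy rollouts, so the PPO step can also fit the expert
# state-action pairs (the behavior-cloning regularizer); args.gailgamma and args.decay, passed to
# algo.PPO at construction, set the weight of that term and its annealing (the exact schedule is
# implemented in a2c_ppo_acktr/algo/ppo.py, which is not shown in this excerpt).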
rollouts.after_update() 407 | 408 | # save for every interval-th episode or for the last epoch 409 | if (j % args.save_interval == 0 410 | or j == num_updates - 1) and args.save_dir != "": 411 | save_path = os.path.join(args.save_dir, args.model_name) 412 | try: 413 | os.makedirs(save_path) 414 | except OSError: 415 | pass 416 | 417 | torch.save([ 418 | actor_critic.state_dict(), 419 | getattr(utils.get_vec_normalize(envs), 'ob_rms', None), 420 | getattr(utils.get_vec_normalize(envs), 'ret_rms', None) 421 | ], os.path.join(save_path, args.model_name + "_{}.pt".format(args.seed))) 422 | 423 | if j % args.log_interval == 0 and len(episode_rewards) > 1: 424 | total_num_steps = (j + 1) * args.num_processes * args.num_steps 425 | end = time.time() 426 | print( 427 | "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n" 428 | .format(j, total_num_steps, 429 | int(total_num_steps / (end - start)), 430 | len(episode_rewards), np.mean(episode_rewards), 431 | np.median(episode_rewards), np.min(episode_rewards), 432 | np.max(episode_rewards), dist_entropy, value_loss, 433 | action_loss)) 434 | 435 | if (args.eval_interval is not None and len(episode_rewards) > 1 436 | and j % args.eval_interval == 0): 437 | ob_rms = utils.get_vec_normalize(envs).ob_rms 438 | evaluate(actor_critic, ob_rms, args.env_name, args.seed, 439 | args.num_processes, eval_log_dir, device) 440 | 441 | 442 | if __name__ == "__main__": 443 | main() 444 | -------------------------------------------------------------------------------- /mujoco.sh: -------------------------------------------------------------------------------- 1 | env=$1 2 | total_timesteps=${2:-3000000} 3 | tt=3000000 # Change this for specific envs 4 | if [ $# -le 0 ]; then 5 | echo './mujoco.sh ' 6 | exit 7 | fi 8 | 9 | 10 | for seed in {1,2,4,} 11 | do 12 | python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps ${tt} --use-linear-lr-decay --use-proper-time-limits --gail --model_name ${env}BCGAIL --seed ${seed} --bcgail 1 --gailgamma 1 --decay 10 & 13 | #python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps $total_timesteps --use-linear-lr-decay --use-proper-time-limits --gail --model_name ${env}GAIL --seed ${seed} 14 | done 15 | -------------------------------------------------------------------------------- /plot_graphs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | from matplotlib import pyplot as plt 7 | import argparse 8 | import torch 9 | import os 10 | import sys 11 | from baselines.common import plot_util as pu 12 | from baselines.common.plot_util import COLORS 13 | import warnings 14 | import matplotlib as mplot 15 | mplot.rcParams.update({'font.size': 15}) 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--files', nargs='+', required=True, 20 | help='Get filenames') 21 | parser.add_argument('--legend', nargs='+', default=[], 22 | help='Legend values') 23 | parser.add_argument('--max_steps', default=0, type=int) 24 | 
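# --max_steps (above) sets the x-extent of the flat reference lines drawn in main(); --bcpath
# (defined just below) is walked for reward files whose names contain 'BC_' or 'random_', and
# their mean (plus a std band for BC) is drawn as a horizontal baseline over the learning curves
# loaded via baselines' plot_util.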
parser.add_argument('--yscale', default='linear', type=str) 25 | parser.add_argument('--bcpath', default='', type=str) 26 | args = parser.parse_args() 27 | 28 | def check_last_name(result): 29 | path = result.dirname 30 | splits = path.split('/') 31 | for sp in splits[::-1]: 32 | if sp == '': 33 | continue 34 | try: 35 | sp = int(sp) 36 | except: 37 | return sp 38 | return '' 39 | 40 | 41 | def main(): 42 | """ 43 | Plot the plots inside the folder given 44 | """ 45 | # Now plot the common things 46 | splits = args.files[0].split('/') 47 | if splits[-1] == '': 48 | splits = splits[-2] 49 | else: 50 | splits = splits[-1] 51 | env = splits 52 | results = pu.load_results(args.files, ) 53 | fig = pu.plot_results(results, average_group=True, 54 | shaded_err=False, 55 | shaded_std=True, 56 | group_fn=lambda _: check_last_name(_), 57 | split_fn=lambda _: '', figsize=(10, 10)) 58 | 59 | # Add results for behaviour cloning if present 60 | allbcfiles = [args.bcpath] 61 | allfiles = [] 62 | allrandomfiles = [] # For random agent behavior 63 | 64 | for file in allbcfiles: 65 | for r, dirs, files in os.walk(file): 66 | print(files) 67 | txtfiles = list(filter(lambda x: 'BC_' in x and '.txt' in x, files)) 68 | rndfiles = list(filter(lambda x: 'random_' in x and '.txt' in x, files)) 69 | allfiles.extend(list(map(lambda x: os.path.join(r, x), txtfiles))) 70 | allrandomfiles.extend(list(map(lambda x: os.path.join(r, x), rndfiles))) 71 | 72 | ## Show all files for BC and plot 73 | print(allfiles) 74 | if allfiles != []: 75 | bcreward = [] 76 | for file in allfiles: 77 | with open(file, 'r') as fi: 78 | rews = fi.read().split('\n') 79 | rews = filter(lambda x: x != '', rews) 80 | rews = list(map(lambda x: float(x), rews)) 81 | bcreward.extend(rews) 82 | 83 | # Get mean and std 84 | #print(bcreward) 85 | mean = np.mean(bcreward) 86 | std = np.std(bcreward) 87 | idxcolor=10 88 | plt.plot([0, args.max_steps], [mean, mean], label='BC', color=COLORS[idxcolor]) 89 | plt.fill_between([0, args.max_steps], [mean - std, mean - std], [mean + std, mean + std], alpha=0.2, color=COLORS[idxcolor]) 90 | 91 | ## Get random policy 92 | if allrandomfiles != []: 93 | rndreward = [] 94 | for file in allrandomfiles: 95 | with open(file, 'r') as fi: 96 | rews = fi.read().split('\n') 97 | rews = filter(lambda x: x != '', rews) 98 | rews = list(map(lambda x: float(x), rews)) 99 | rndreward.extend(rews) 100 | 101 | # Get mean and std 102 | #print(bcreward) 103 | mean = np.mean(rndreward) 104 | plt.plot([0, args.max_steps], [mean, mean], label='random', color='gray', linestyle='dashed') 105 | 106 | plt.xlabel('# environment interactions', fontsize=20) 107 | envnamehere = 'ant' 108 | if env.lower().startswith(envnamehere): 109 | plt.ylim(ymin=-5000, ymax=5000) 110 | if env.lower().startswith(''): 111 | plt.ylabel('Reward', fontsize=30) 112 | plt.yscale(args.yscale) 113 | plt.title(env.replace('BC','').replace('GAIL', '').replace('no', '').replace('alph', ''), \ 114 | fontsize=50) 115 | 116 | if env.lower().startswith(envnamehere): 117 | if args.legend != []: 118 | if allfiles != []: 119 | args.legend.append('BC') 120 | plt.legend(args.legend, fontsize=30, loc='bottom right') 121 | else: 122 | plt.legend().set_visible(False) 123 | #plt.ticklabel_format(useOffset=1) 124 | plt.savefig('{}.png'.format(env), bbox_inches='tight', ) 125 | print("saved ", env) 126 | 127 | 128 | if __name__ == "__main__": 129 | with warnings.catch_warnings(): 130 | warnings.simplefilter('ignore') 131 | main() 132 | 
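A note on how plot_graphs.py groups runs: pu.plot_results is called with group_fn=check_last_name, which walks a result's directory path from the right and returns the first component that is not an integer, so per-seed subfolders such as .../HopperBCGAIL/1 and .../HopperBCGAIL/2 collapse into one averaged curve labelled HopperBCGAIL. The snippet below is a minimal standalone sketch of that grouping rule; the name group_label and the example paths are illustrative only, not part of the repository.

```python
# Minimal sketch of the grouping rule used by check_last_name in plot_graphs.py.
def group_label(dirname: str) -> str:
    """Return the right-most path component that is not an integer.

    Integer components (per-seed folders) are skipped, so every seed of one
    method directory lands in the same group / legend entry.
    """
    for part in reversed(dirname.split('/')):
        if part == '':
            continue
        try:
            int(part)        # seed folder such as '1' -> keep walking up
        except ValueError:
            return part      # first non-numeric component -> group name
    return ''


if __name__ == '__main__':
    print(group_label('/serverdata/rohit/BCGAIL/logs/HopperBCGAIL/1'))  # HopperBCGAIL
    print(group_label('logs/hopper/hopper-3/'))                         # hopper-3
```

This matches the layout main.py writes to (log_dir/model_name/seed) and the directories passed in plots/plotbcgail.sh and plots/plotnogail.sh.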
-------------------------------------------------------------------------------- /plot_length_bar_graphs.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | from matplotlib import pyplot as plt 7 | import argparse 8 | import torch 9 | import os 10 | import sys 11 | from baselines.common import plot_util as pu 12 | from baselines.common.plot_util import COLORS 13 | import warnings 14 | import matplotlib as mplot 15 | mplot.rcParams.update({'font.size': 15}) 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--dirs', type=str, required=True, 19 | help='Get directories') 20 | args = parser.parse_args() 21 | 22 | 23 | def main(): 24 | """ 25 | Plot the plots inside the folder given 26 | """ 27 | # Now plot the common things 28 | print(args.dirs) 29 | envname = args.dirs.split('/')[-1].split('*')[0] 30 | dirs = glob.glob(args.dirs) 31 | allfiles = [] 32 | for d in dirs: 33 | for r, _, files in os.walk(d): 34 | files = list(filter(lambda x: 'txt' in x and 'length' in x, files)) 35 | files = list(map(lambda x: os.path.join(r, x), files)) 36 | allfiles.extend(files) 37 | allfiles = sorted(allfiles) 38 | # given all the files, plot them 39 | vals = [] 40 | for fil in allfiles: 41 | with open(fil, 'r') as fi: 42 | v = float(fi.read().replace('\n', '')) 43 | vals.append(v) 44 | # Make a bar graph 45 | plt.bar(np.arange(len(vals)), vals, color=['red', 'green', 'lightgreen', 'blue', 'lightblue'], alpha=0.6) 46 | print([x.split('/')[-1] for x in allfiles]) 47 | plt.xticks(range(5), ['']*5) 48 | #plt.xticks(np.arange(5), ['BC', 'Ours', 'Ours (noDisTr)', 'GAIL', 'GAIL (noDisTr)'], rotation=30) 49 | method = ['BC', 'Ours', 'Ours (no disc. training)', 'GAIL', 'GAIL (no disc. 
training)'] 50 | for i in range(5): 51 | plt.text(i, 0, str(" " + method[i]), rotation=90, ha='center', va='bottom', fontsize=18) 52 | plt.xlabel('Method') 53 | plt.ylabel('Avg episode length') 54 | plt.title(envname) 55 | plt.tight_layout() 56 | plt.savefig('{}_barplot.png'.format(envname)) 57 | 58 | 59 | 60 | if __name__ == "__main__": 61 | with warnings.catch_warnings(): 62 | warnings.simplefilter('ignore') 63 | main() 64 | -------------------------------------------------------------------------------- /plots/plotalpha.sh: -------------------------------------------------------------------------------- 1 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/HopperalphBCGAIL /serverdata/rohit/BCGAIL/logs/HopperalphBCGAIL0.25 /serverdata/rohit/BCGAIL/logs/HopperalphBCGAIL0.50 /serverdata/rohit/BCGAIL/logs/HopperalphBCGAIL0.75 --legend $\\alpha=\\alpha_0^T$ $\\alpha=0.25$ $\\alpha=0.50$ $\\alpha=0.75$ 2 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/AntalphBCGAIL /serverdata/rohit/BCGAIL/logs/AntalphBCGAIL0.25 /serverdata/rohit/BCGAIL/logs/AntalphBCGAIL0.50 /serverdata/rohit/BCGAIL/logs/AntalphBCGAIL0.75 --legend $\\alpha=\\alpha_0^T$ $\\alpha=0.25$ $\\alpha=0.50$ $\\alpha=0.75$ 3 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/HalfCheetahalphBCGAIL /serverdata/rohit/BCGAIL/logs/HalfCheetahalphBCGAIL0.25 /serverdata/rohit/BCGAIL/logs/HalfCheetahalphBCGAIL0.50 /serverdata/rohit/BCGAIL/logs/HalfCheetahalphBCGAIL0.75 --legend $\\alpha=\\alpha_0^T$ $\\alpha=0.25$ $\\alpha=0.50$ $\\alpha=0.75$ 4 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/Walker2dalphBCGAIL /serverdata/rohit/BCGAIL/logs/Walker2dalphBCGAIL0.25 /serverdata/rohit/BCGAIL/logs/Walker2dalphBCGAIL0.50 /serverdata/rohit/BCGAIL/logs/Walker2dalphBCGAIL0.75 --legend $\\alpha=\\alpha_0^T$ $\\alpha=0.25$ $\\alpha=0.50$ $\\alpha=0.75$ 5 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/ReacheralphBCGAIL /serverdata/rohit/BCGAIL/logs/ReacheralphBCGAIL0.25 /serverdata/rohit/BCGAIL/logs/ReacheralphBCGAIL0.50 /serverdata/rohit/BCGAIL/logs/ReacheralphBCGAIL0.75 --legend $\\alpha=\\alpha_0^T$ $\\alpha=0.25$ $\\alpha=0.50$ $\\alpha=0.75$ 6 | -------------------------------------------------------------------------------- /plots/plotbcgail.sh: -------------------------------------------------------------------------------- 1 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/HopperBCGAIL/ /serverdata/rohit/BCGAIL/logs/HopperGAIL/ /serverdata/rohit/BCGAIL/logs/HopperGAILpretrain/ /serverdata/rohit/BCGAIL/logs/HopperRED/ /serverdata/rohit/BCGAIL/logs/HopperSAIL/ --bcpath /serverdata/rohit/BCGAIL/HopperBC --max_steps 1000000 --legend Ours GAIL BC+GAIL RED SAIL 2 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/AntBCGAIL/ /serverdata/rohit/BCGAIL/logs/AntGAIL/ /serverdata/rohit/BCGAIL/logs/AntGAILpretrain/ /serverdata/rohit/BCGAIL/logs/AntRED/ /serverdata/rohit/BCGAIL/logs/AntSAIL/ --bcpath /serverdata/rohit/BCGAIL/AntBC --max_steps 3000000 --legend Ours GAIL BC+GAIL RED SAIL 3 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/ReacherBCGAIL/ /serverdata/rohit/BCGAIL/logs/ReacherGAIL/ /serverdata/rohit/BCGAIL/logs/ReacherGAILpretrain/ /serverdata/rohit/BCGAIL/logs/ReacherRED/ /serverdata/rohit/BCGAIL/logs/ReacherSAIL/ --bcpath /serverdata/rohit/BCGAIL/ReacherBC --max_steps 2000000 --legend Ours GAIL BC+GAIL RED SAIL 4 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/Walker2dBCGAIL/ /serverdata/rohit/BCGAIL/logs/Walker2dGAIL/ 
/serverdata/rohit/BCGAIL/logs/Walker2dGAILpretrain/ /serverdata/rohit/BCGAIL/logs/Walker2dRED/ /serverdata/rohit/BCGAIL/logs/Walker2dSAIL/ --bcpath /serverdata/rohit/BCGAIL/Walker2dBC --max_steps 3000000 --legend Ours GAIL BC+GAIL RED SAIL 5 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/HalfCheetahBCGAIL/ /serverdata/rohit/BCGAIL/logs/HalfCheetahGAIL/ /serverdata/rohit/BCGAIL/logs/HalfCheetahGAILpretrain/ /serverdata/rohit/BCGAIL/logs/HalfCheetahRED/ /serverdata/rohit/BCGAIL/logs/HalfCheetahSAIL/ --bcpath /serverdata/rohit/BCGAIL/HalfCheetahBC --max_steps 3000000 --legend Ours GAIL BC+GAIL RED SAIL 6 | 7 | 8 | 9 | 10 | # python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/HalfCheetahBCnoGAIL/ /serverdata/rohit/BCGAIL/logs/HalfCheetahBCGAIL/ /serverdata/rohit/BCGAIL/logs/HalfCheetahGAIL/ --bcpath /serverdata/rohit/BCGAIL/HalfCheetahBC --max_steps 3000000 --legend Ours Ours(no GAIL training) GAIL 11 | # python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/AntalphBCGAIL /serverdata/rohit/BCGAIL/logs/AntalphBCGAIL0.25 /serverdata/rohit/BCGAIL/logs/AntalphBCGAIL0.50 /serverdata/rohit/BCGAIL/logs/AntalphBCGAIL0.75 --legend $\alpha=\alpha_0^T$ $\alpha=0.25$ $\alpha=0.50$ $\alpha=0.75$ 12 | -------------------------------------------------------------------------------- /plots/plotnogail.sh: -------------------------------------------------------------------------------- 1 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/HalfCheetahBCnoGAIL/ /serverdata/rohit/BCGAIL/logs/HalfCheetahnoGAIL/ /serverdata/rohit/BCGAIL/logs/HalfCheetahBCGAIL/ /serverdata/rohit/BCGAIL/logs/HalfCheetahGAIL/ --bcpath /serverdata/rohit/BCGAIL/HalfCheetahBC --max_steps 3000000 --legend Ours "Ours(no GAIL training)" GAIL "GAIL(no GAIL training)" 2 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/AntBCnoGAIL/ /serverdata/rohit/BCGAIL/logs/AntnoGAIL/ /serverdata/rohit/BCGAIL/logs/AntBCGAIL/ /serverdata/rohit/BCGAIL/logs/AntGAIL/ --bcpath /serverdata/rohit/BCGAIL/AntBC --max_steps 3000000 --legend Ours "Ours(no GAIL training)" GAIL "GAIL(no GAIL training)" 3 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/HopperBCnoGAIL/ /serverdata/rohit/BCGAIL/logs/HoppernoGAIL/ /serverdata/rohit/BCGAIL/logs/HopperBCGAIL/ /serverdata/rohit/BCGAIL/logs/HopperGAIL/ --bcpath /serverdata/rohit/BCGAIL/HopperBC --max_steps 1000000 --legend Ours "Ours(no GAIL training)" GAIL "GAIL(no GAIL training)" 4 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/ReacherBCnoGAIL/ /serverdata/rohit/BCGAIL/logs/ReachernoGAIL/ /serverdata/rohit/BCGAIL/logs/ReacherBCGAIL/ /serverdata/rohit/BCGAIL/logs/ReacherGAIL/ --bcpath /serverdata/rohit/BCGAIL/ReacherBC --max_steps 2000000 --legend Ours "Ours(no GAIL training)" GAIL "GAIL(no GAIL training)" 5 | python plot_graphs.py --files /serverdata/rohit/BCGAIL/logs/Walker2dBCnoGAIL/ /serverdata/rohit/BCGAIL/logs/Walker2dnoGAIL/ /serverdata/rohit/BCGAIL/logs/Walker2dBCGAIL/ /serverdata/rohit/BCGAIL/logs/Walker2dGAIL/ --bcpath /serverdata/rohit/BCGAIL/Walker2dBC --max_steps 3000000 --legend Ours "Ours(no GAIL training)" GAIL "GAIL(no GAIL training)" 6 | -------------------------------------------------------------------------------- /print_experts.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import argparse 6 | import torch 7 | import os 8 | import sys 9 | from baselines.common import plot_util as 
pu 10 | from baselines.common.plot_util import COLORS 11 | import warnings 12 | import matplotlib as mplot 13 | mplot.rcParams.update({'font.size': 12}) 14 | 15 | def main(): 16 | """ 17 | Plot the plots inside the folder given 18 | """ 19 | res = dict() 20 | filelist = [] 21 | for r, d, f in os.walk('./gail_experts'): 22 | f = list(filter(lambda x: x.endswith('pt'), f)) 23 | f = list(map(lambda x: os.path.join(r, x), f)) 24 | filelist.extend(f) 25 | 26 | for file in filelist: 27 | a = torch.load(file)['rewards'] 28 | rew = a.sum(1) 29 | res[file] = rew 30 | 31 | for k in sorted(res.keys()): 32 | v = res[k] 33 | print('{} {:.2f} \\pm {:.2f}'.format(k, v.mean(), v.std())) 34 | 35 | 36 | if __name__ == "__main__": 37 | with warnings.catch_warnings(): 38 | warnings.simplefilter('ignore') 39 | main() 40 | -------------------------------------------------------------------------------- /print_results.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import argparse 6 | import torch 7 | import os 8 | import sys 9 | from baselines.common import plot_util as pu 10 | from baselines.common.plot_util import COLORS 11 | import warnings 12 | import matplotlib as mplot 13 | mplot.rcParams.update({'font.size': 12}) 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--files', nargs='+', required=True, 18 | help='Get filenames') 19 | args = parser.parse_args() 20 | 21 | def main(): 22 | """ 23 | Plot the plots inside the folder given 24 | """ 25 | res = dict() 26 | filelist = [] 27 | for dirs in args.files: 28 | for r, d, f in os.walk(dirs): 29 | f = list(filter(lambda x: x.endswith('txt'), f)) 30 | f = list(map(lambda x: os.path.join(r, x), f)) 31 | filelist.extend(f) 32 | 33 | for f in filelist: 34 | key = f.split('/')[-1].split('_')[0] 35 | if key == 'random': 36 | key = f.split('_')[0] 37 | with open(f) as fi: 38 | rews = fi.read().split('\n') 39 | rews = filter(lambda x: x != '', rews) 40 | rews = list(map(lambda x: float(x), rews)) 41 | lis = res.get(key, []) 42 | lis.extend(rews) 43 | res[key] = lis 44 | 45 | #for k, v in res.items(): 46 | keys = sorted(list(res.keys())) 47 | for k in keys: 48 | v = res[k] 49 | print('{} {:.2f} \\pm {:.2f}'.format(k, np.mean(v), np.std(v))) 50 | 51 | 52 | if __name__ == "__main__": 53 | with warnings.catch_warnings(): 54 | warnings.simplefilter('ignore') 55 | main() 56 | -------------------------------------------------------------------------------- /redsail.sh: -------------------------------------------------------------------------------- 1 | env=$1 2 | tt=${2:-3000000} 3 | if [ $# -le 0 ]; then 4 | echo './mujoco.sh ' 5 | exit 6 | fi 7 | 8 | 9 | for seed in {1,2,4,} 10 | do 11 | python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps ${tt} --use-linear-lr-decay --use-proper-time-limits --gail --model_name ${env}RED --seed ${seed} --red 1 --redsigma 1000.0 --rediters 100 12 | python main.py --env-name ${env}-v2 --algo ppo --use-gae --log-interval 1 --num-steps 2048 --num-processes 8 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --gae-lambda 0.95 --num-env-steps ${tt} --use-linear-lr-decay --use-proper-time-limits --gail --model_name ${env}SAIL --seed ${seed} --red 1 
--redsigma 1000.0 --rediters 100 --sail 1 13 | done 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym 2 | matplotlib 3 | pybullet 4 | -------------------------------------------------------------------------------- /run_all.yaml: -------------------------------------------------------------------------------- 1 | session_name: run-all 2 | windows: 3 | - panes: 4 | - python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 5 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 6 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 7 | --no-cuda --log-dir /tmp/gym/reacher/reacher-0 --seed 0 --use-proper-time-limits 8 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 9 | --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 10 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 11 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-0 12 | --seed 0 --use-proper-time-limits 13 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 14 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 15 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 16 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-0 --seed 0 --use-proper-time-limits 17 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 18 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 19 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 20 | --no-cuda --log-dir /tmp/gym/hopper/hopper-0 --seed 0 --use-proper-time-limits 21 | window_name: seed-0 22 | - panes: 23 | - python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 24 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 25 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 26 | --no-cuda --log-dir /tmp/gym/reacher/reacher-1 --seed 1 --use-proper-time-limits 27 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 28 | --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 29 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 30 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-1 31 | --seed 1 --use-proper-time-limits 32 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 33 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 34 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 35 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-1 --seed 1 --use-proper-time-limits 36 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 37 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 38 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 39 | --no-cuda --log-dir /tmp/gym/hopper/hopper-1 --seed 1 --use-proper-time-limits 40 | window_name: seed-1 41 | - panes: 42 | 
- python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 43 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 44 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 45 | --no-cuda --log-dir /tmp/gym/reacher/reacher-2 --seed 2 --use-proper-time-limits 46 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 47 | --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 48 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 49 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-2 50 | --seed 2 --use-proper-time-limits 51 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 52 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 53 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 54 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-2 --seed 2 --use-proper-time-limits 55 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 56 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 57 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 58 | --no-cuda --log-dir /tmp/gym/hopper/hopper-2 --seed 2 --use-proper-time-limits 59 | window_name: seed-2 60 | - panes: 61 | - python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 62 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 63 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 64 | --no-cuda --log-dir /tmp/gym/reacher/reacher-3 --seed 3 --use-proper-time-limits 65 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 66 | --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 67 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 68 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-3 69 | --seed 3 --use-proper-time-limits 70 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 71 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 72 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 73 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-3 --seed 3 --use-proper-time-limits 74 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 75 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 76 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 77 | --no-cuda --log-dir /tmp/gym/hopper/hopper-3 --seed 3 --use-proper-time-limits 78 | window_name: seed-3 79 | - panes: 80 | - python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 81 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 82 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 83 | --no-cuda --log-dir /tmp/gym/reacher/reacher-4 --seed 4 --use-proper-time-limits 84 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 85 | --num-steps 2048 --num-processes 1 --lr 3e-4 
--entropy-coef 0 --value-loss-coef 86 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 87 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-4 88 | --seed 4 --use-proper-time-limits 89 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 90 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 91 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 92 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-4 --seed 4 --use-proper-time-limits 93 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 94 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 95 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 96 | --no-cuda --log-dir /tmp/gym/hopper/hopper-4 --seed 4 --use-proper-time-limits 97 | window_name: seed-4 98 | - panes: 99 | - python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 100 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 101 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 102 | --no-cuda --log-dir /tmp/gym/reacher/reacher-5 --seed 5 --use-proper-time-limits 103 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 104 | --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 105 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 106 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-5 107 | --seed 5 --use-proper-time-limits 108 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 109 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 110 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 111 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-5 --seed 5 --use-proper-time-limits 112 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 113 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 114 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 115 | --no-cuda --log-dir /tmp/gym/hopper/hopper-5 --seed 5 --use-proper-time-limits 116 | window_name: seed-5 117 | - panes: 118 | - python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 119 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 120 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 121 | --no-cuda --log-dir /tmp/gym/reacher/reacher-6 --seed 6 --use-proper-time-limits 122 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 123 | --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 124 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 125 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-6 126 | --seed 6 --use-proper-time-limits 127 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 128 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 129 | 10 --num-mini-batch 32 --gamma 0.99 
--tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 130 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-6 --seed 6 --use-proper-time-limits 131 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 132 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 133 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 134 | --no-cuda --log-dir /tmp/gym/hopper/hopper-6 --seed 6 --use-proper-time-limits 135 | window_name: seed-6 136 | - panes: 137 | - python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 138 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 139 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 140 | --no-cuda --log-dir /tmp/gym/reacher/reacher-7 --seed 7 --use-proper-time-limits 141 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 142 | --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 143 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 144 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-7 145 | --seed 7 --use-proper-time-limits 146 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 147 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 148 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 149 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-7 --seed 7 --use-proper-time-limits 150 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 151 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 152 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 153 | --no-cuda --log-dir /tmp/gym/hopper/hopper-7 --seed 7 --use-proper-time-limits 154 | window_name: seed-7 155 | - panes: 156 | - python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 157 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 158 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 159 | --no-cuda --log-dir /tmp/gym/reacher/reacher-8 --seed 8 --use-proper-time-limits 160 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 161 | --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 162 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 163 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-8 164 | --seed 8 --use-proper-time-limits 165 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 166 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 167 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 168 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-8 --seed 8 --use-proper-time-limits 169 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 170 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 171 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 172 | --no-cuda --log-dir 
/tmp/gym/hopper/hopper-8 --seed 8 --use-proper-time-limits 173 | window_name: seed-8 174 | - panes: 175 | - python main.py --env-name Reacher-v2 --algo ppo --use-gae --log-interval 1 --num-steps 176 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 177 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 178 | --no-cuda --log-dir /tmp/gym/reacher/reacher-9 --seed 9 --use-proper-time-limits 179 | - python main.py --env-name HalfCheetah-v2 --algo ppo --use-gae --log-interval 1 180 | --num-steps 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 181 | 0.5 --ppo-epoch 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 182 | 1000000 --use-linear-lr-decay --no-cuda --log-dir /tmp/gym/halfcheetah/halfcheetah-9 183 | --seed 9 --use-proper-time-limits 184 | - python main.py --env-name Walker2d-v2 --algo ppo --use-gae --log-interval 1 --num-steps 185 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 186 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 187 | --no-cuda --log-dir /tmp/gym/walker2d/walker2d-9 --seed 9 --use-proper-time-limits 188 | - python main.py --env-name Hopper-v2 --algo ppo --use-gae --log-interval 1 --num-steps 189 | 2048 --num-processes 1 --lr 3e-4 --entropy-coef 0 --value-loss-coef 0.5 --ppo-epoch 190 | 10 --num-mini-batch 32 --gamma 0.99 --tau 0.95 --num-env-steps 1000000 --use-linear-lr-decay 191 | --no-cuda --log-dir /tmp/gym/hopper/hopper-9 --seed 9 --use-proper-time-limits 192 | window_name: seed-9 193 | -------------------------------------------------------------------------------- /save.sh: -------------------------------------------------------------------------------- 1 | env=$1 2 | savedir=$2 3 | if [ $# -ne 2 ]; then 4 | echo './save.sh ' 5 | exit 6 | fi 7 | 8 | for seed in {1,2,4} 9 | do 10 | python main.py --env-name ${env} --algo ppo --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 --model_name demo --load_model_name ${savedir} --record_trajectories 1 --seed ${seed} 11 | ## Save random policy 12 | #python main.py --env-name ${env} --algo ppo --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 --model_name random --load_model_name ${savedir} --record_trajectories 1 --seed ${seed} 13 | done 14 | -------------------------------------------------------------------------------- /savemaster.sh: -------------------------------------------------------------------------------- 1 | for env in {Hopper,Ant,HalfCheetah,Walker2d,Reacher} 2 | do 3 | #for method in {BCGAIL,BC,GAIL} 4 | for method in {RED,SAIL,GAILpretrain} 5 | do 6 | ./save.sh ${env}-v2 ${env}${method} 7 | done 8 | done 9 | -------------------------------------------------------------------------------- /scripts.txt: -------------------------------------------------------------------------------- 1 | # Car Racing baseline 2 | python main.py --env-name CarRacing-v0 --algo ppo --use-gae --lr 2.5e-4 3 | --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 4 | --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 5 | --seed 42 --num-mini-batch 128 --num-env-steps 10000000 6 | 7 | CUDA_VISIBLE_DEVICES=6,7 python main.py --env-name CarRacing-v0 --algo ppo --use-gae 
--lr 2.5e-4 --clip-param 0.2 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --entropy-coef 0.01 --model_name CarRacingPPOtanh --load_model_name CarRacingPPO --use_activation 1 8 | python main.py --env-name CarRacing-v0 --algo ppo --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 --model_name CarRacingPPO --load_model_name CarRacingPPO 9 | python main.py --env-name CarRacing-v0 --algo ppo --gail --gail-experts-dir /serverdata/rohit/BCGAIL/CarRacingPPO/ --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 --model_name CarRacingGAIL --gail-batch-size 32 10 | 11 | ## GAIL baseline 12 | python main.py --env-name CarRacing-v0 --algo ppo --gail --gail-experts-dir /serverdata/rohit/BCGAIL/CarRacingPPO/ --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 --model_name CarRacingGAIL --gail-batch-size 32 13 | python main.py --env-name CarRacing-v0 --algo ppo --gail --gail-experts-dir /serverdata/rohit/BCGAIL/CarRacingPPO/ --use-gae --lr 2.5e-4 --clip-param 0.1 --value-loss-coef 0.5 --num-processes 8 --num-steps 128 --num-mini-batch 4 --log-interval 1 --use-linear-lr-decay --entropy-coef 0.01 --model_name CarRacingBCGAIL0.125 --gail-batch-size 32 --bcgail 1 --gailgamma 0.125 --decay 1 --num-env-steps 1000000 --seed 1 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='a2c-ppo-acktr', 5 | packages=find_packages(), 6 | version='0.0.1', 7 | install_requires=['gym', 'matplotlib', 'pybullet']) 8 | -------------------------------------------------------------------------------- /visualize.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 64, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from baselines.common import plot_util as pu" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "If you want to average results for multiple seeds, LOG_DIRS must contain subfolders in the following format: ```-0```, ```-1```, ```-0```, ```-1```. Where names correspond to experiments you want to compare separated with random seeds by dash." 
17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 80, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "LOG_DIRS = 'logs/reacher/'\n", 26 | "# Uncomment below to see the effect of the timit limits flag\n", 27 | "# LOG_DIRS = 'time_limit_logs/reacher'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 81, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stderr", 37 | "output_type": "stream", 38 | "text": [ 39 | "/home/kostrikov/GitHub/baselines/baselines/bench/monitor.py:163: UserWarning: Pandas doesn't allow columns to be created via a new attribute name - see https://pandas.pydata.org/pandas-docs/stable/indexing.html#attribute-access\n", 40 | " df.headers = headers # HACK to preserve backwards compatibility\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "results = pu.load_results(LOG_DIRS)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 82, 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "data": { 55 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAacAAAGoCAYAAADiuSpNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzs3XmcnWV9///XdfYzZ+bMmplMMtkgIWQhBLIQsCgKMUARKkoFqxWtoIhtre2j2mJxKfZr+Vm1LrVFqVSrUpcKKChIwaIsJQkQTYBsJCSTbfbl7Nv9++OaMzMJk8lMZjknud/PxyOPmbnPfe77c85Mrve5rvu679s4joOIiEg58ZS6ABERkWMpnEREpOwonEREpOwonEREpOwonEREpOwonEREpOwonEREpOwonEREpOwonEREpOz4Sl3AWDU0NDjz588vdRkiInISNm/e3OE4zoyxrl+ycDLGXA78M+AFvuk4zudGW3/+/Pls2rRpWmoTEZHJZYx5dTzrl2RYzxjjBb4GXAEsBW4wxiwtRS0iIlJ+SnXMaS2wy3GcVxzHyQD3AteUqBYRESkzpQqn2cD+YT+3Diw7ijHmZmPMJmPMpvb29mkrTkRESqusZ+s5jnOX4zirHcdZPWPGmI+jiYjIKa5U4XQAmDPs55aBZSIiIiULp43AImPMAmNMALgeeKBEtYiISJkpyVRyx3FyxpgPAw9jp5L/u+M420pRi4iIlJ+SnefkOM5DwEOl2r+IiJSvsp4QISIi7qRwEhGRsqNwEhGRsqNwEhGRsqNwEhGRsqNwEhGRsqNwEhGRsnPK3GxQRMpXJp/Ba7zkCjn8Xj8ec/zPvZ2JTnpSPcyIzKAqUIUxBoBcIUd3shuA6lA1AW+ATD7DkdgRosEo0WB0cF3Hccg7eXye0Zswx3HsVxx2d+1mZ9dOMvkMTZEmKgOVVPgr2NOzh8OxwxyOHSadS7Nm9hpCvhCpXIpYJkZnopOCUyBbyFJwCvg8Ps5vPp+qQBWHY4fZcmQL5808j329+9jatpVsIcva2WtZ1byKfb378Hq8JLNJCk4Bv9dPZ6KTdD6Nx3gI+UIsqls0uP1MPkM2b/dzJH6Eb73wLXKFHOta1lEbqiWbz9Kd6qYuXEdNqIbtHdvJOTkO9x+mtb+VdC5NTaiG7lQ3+UKe6lA1iUyCvJNneeNyKgOVbG3bioNDoVDA5/Xh8/iIZ+IsbljMmllraO1r5XDsMAf7D5Iv5GmqbOL1817PR9Z9ZDL+VMbMFH955W716tWObjYo5c5xHHKFHMYYXu15lWQuSUNFA3XhOpLZJMlckmQ2SXNVM36PH2PMUQ15KpdiZ+dOZlXNIhqM8mL7i+zu3k17vJ1kLkl/pp+gN0h7op0KfwUhr21Ew/4wjuPwTOszZAtZ5lTPoT/dT0eigxmRGXjwkC1k8RgPeSePwZAv5OlMdtKV7CLgDVBwCvSl+4hn4iRztsZkNklvqpdoKEpNqAbHcXAch7ZEG9FglHQuTXuinUQ2MfgaqgJVXDz3YqKhKIf6D9GR6CCZS1IdrMYYw3OHnhtcN+gNUheuI+gLcqDvANlCFoCwL0xloJL2xNDdCLzGSzQYJewL05m0DXzEHyESiJDIJkjlUoPrGgxhf5j+dD8O09fGeY0XgLyTn5Tt1YZqCXgDHIkfGVxWDG1gMJwrA5VEg1E8xkM6lyYSiOAxHlK5FEFvEGMMR2JHyBayNFQ04DgOxhgKToGCUyDgCXA4fphULoXP4yPsC1MVrMJgiGViLKpfxMabNk7otRhjNjuOs3rM6yucZDIU/46Kn2yP50jsCB2JDnKFHF6Pl8X1i/F5fLTF2zjYf5DqUDUzK2dS4a8YfE6+kKc90c7urt3s7t6Ng0PYFyboDRL0BelIdLDp4CaWzVjG1YuvZlv7NvKFPA0VDSSyCeor6vn1q79mR+cOmquaCfvD7Oneg8/jI+QL8UzrM2w+tJlkNklduG4wTMK+MPv79mMwYMBjPCysW8jvjvyO/X37qQpUURWsIuwLE8vE6En10JPqGWxgR1NsVBzHIRqMUh+uZ0ZkBts7t9OV7ALs/gpOYVy/h2gwiuM4xLNxQr4QFf6K1zTaxW0XP7mHfCFyhRwe4yHgDRDyhQDoTfcOvsepXIpkNonHePAaL2F/mFQuRcAbIBKIUOGrGOzJ9KZ62d+3n1whR4W/ggp/hf10no2TL+RZWLeQhooG4pk4fek+EtkEOSc3+D44OByJHSGdT1Mfrqc2XEsmn6Ev1Udfuo9MPkNloJKwP0w8EyedTxPyhfB7/YOvz8EhnUtT4a/A6/HiOA5VwSrOrD2TulAd+/v2k8wliWfiNFc1UxOsYVbVLJbMWMJjex6jM9FJ0B/Ea7wsqltEU6SJmlANbYk2YpkYmw9upivZRSQQ4YyaM2hLtIEDb13yVqpD1Ty+53Gebn2aaDBKxB/B6/ES8Udsj8UpML96PgUKpLIpftf2O4wxVAWq8Hq8+Dw+HMeud/mZl7N
55 |       "image/png": "<base64-encoded PNG data omitted>",
56 |       "text/plain": [
57 |        "<Figure size ... with ... Axes>"
" 58 | ] 59 | }, 60 | "metadata": { 61 | "needs_background": "light" 62 | }, 63 | "output_type": "display_data" 64 | } 65 | ], 66 | "source": [ 67 | "fig = pu.plot_results(results, average_group=True, split_fn=lambda _: '', shaded_std=False)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [] 76 | } 77 | ], 78 | "metadata": { 79 | "kernelspec": { 80 | "display_name": "Python 3", 81 | "language": "python", 82 | "name": "python3" 83 | }, 84 | "language_info": { 85 | "codemirror_mode": { 86 | "name": "ipython", 87 | "version": 3 88 | }, 89 | "file_extension": ".py", 90 | "mimetype": "text/x-python", 91 | "name": "python", 92 | "nbconvert_exporter": "python", 93 | "pygments_lexer": "ipython3", 94 | "version": "3.6.8" 95 | } 96 | }, 97 | "nbformat": 4, 98 | "nbformat_minor": 2 99 | } 100 | --------------------------------------------------------------------------------