├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── install_mujoco.sh ├── scripts └── main.py ├── setup.py └── src └── atac ├── atac.py └── garage_tools ├── rl_utils.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # ATAC 132 | exp_data -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ATAC: Adversarially Trained Actor Critic 2 | 3 | This repository contains the code to reproduce the experimental results of the ATAC algorithm in the paper Adversarially Trained Actor Critic for Offline Reinforcement Learning by Ching-An Cheng*, Tengyang Xie*, Nan Jiang, and Alekh Agarwal (https://arxiv.org/abs/2202.02446). 4 | 5 | **Please also see https://github.com/microsoft/lightATAC for a lightweight reimplementation of ATAC, which gives a 1.5-2X speed-up compared with the original code here.** 6 | 7 | ### Setup 8 | 9 | #### Clone the repository and create a conda environment. 10 | ``` 11 | git clone https://github.com/microsoft/ATAC.git 12 | conda create -n atac python=3.8 13 | cd ATAC 14 | ``` 15 | #### Prerequisite: Install Mujoco 16 | (Optional) Install the free mujoco210 (for mujoco_py) and mujoco 2.1.1 (for dm_control). 17 | ``` 18 | bash install_mujoco.sh 19 | echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco210/bin:/usr/lib/nvidia" >> ~/.bashrc 20 | source ~/.bashrc 21 | ``` 22 | #### Install ATAC 23 | ``` 24 | conda activate atac 25 | pip install -e .[mujoco210] 26 | # or below, if the original paid mujoco is used. 27 | pip install -e .[mujoco200] 28 | ``` 29 | #### Run ATAC 30 | ``` 31 | python scripts/main.py 32 | ``` 33 | 34 | ### Contributing 35 | 36 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 37 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 38 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
39 | 40 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 41 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 42 | provided by the bot. You will only need to do this once across all repos using our CLA. 43 | 44 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 45 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 46 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 47 | 48 | ### Trademarks 49 | 50 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 51 | trademarks or logos is subject to and must follow 52 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 53 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 54 | Any use of third-party trademarks or logos are subject to those third-party's policies. 55 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /install_mujoco.sh: -------------------------------------------------------------------------------- 1 | sudo apt-get install libglew-dev patchelf 2 | wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz 3 | wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz 4 | tar -xf mujoco-2.1.1-linux-x86_64.tar.gz 5 | tar -xf mujoco210-linux-x86_64.tar.gz 6 | rm mujoco-2.1.1-linux-x86_64.tar.gz mujoco210-linux-x86_64.tar.gz 7 | mkdir ~/.mujoco 8 | mv mujoco210 mujoco-2.1.1 ~/.mujoco -------------------------------------------------------------------------------- /scripts/main.py: -------------------------------------------------------------------------------- 1 | import gym, d4rl, torch, os 2 | 3 | import numpy as np 4 | from urllib.error import HTTPError 5 | from garage.envs import GymEnv 6 | from garage.experiment.deterministic import set_seed 7 | from garage.replay_buffer import PathBuffer 8 | from garage.torch.algos import SAC 9 | from garage.torch.policies import TanhGaussianMLPPolicy 10 | from garage.torch.q_functions import ContinuousMLPQFunction 11 | 12 | 13 | from atac.atac import ATAC 14 | from atac.garage_tools.rl_utils import train_agent, get_sampler, setup_gpu, get_algo, get_log_dir_name, load_algo 15 | from atac.garage_tools.trainer import Trainer 16 | 17 | 18 | def load_d4rl_data_as_buffer(dataset, replay_buffer): 19 | assert isinstance(replay_buffer, PathBuffer) 20 | replay_buffer.add_path( 21 | dict(observation=dataset['observations'], 22 | action=dataset['actions'], 23 | reward=dataset['rewards'].reshape(-1, 1), 24 | next_observation=dataset['next_observations'], 25 | terminal=dataset['terminals'].reshape(-1,1), 26 | )) 27 | 28 | def train_func(ctxt=None, 29 | *, 30 | algo='ATAC', 31 | # Environment parameters 32 | env_name, 33 | # Evaluation mode 34 | evaluation_mode=False, 35 | policy_path=None, 36 | # Trainer parameters 37 | n_epochs=3000, # number of training epochs 38 | batch_size=0, # number of samples collected per update 39 | replay_buffer_size=int(2e6), 40 | # Network parameters 41 | policy_hidden_sizes=(256, 256, 256), 42 | policy_activation='ReLU', 43 | policy_init_std=1.0, 44 | value_hidden_sizes=(256, 256, 256), 45 | value_activation='ReLU', 46 | min_std=1e-5, 47 | # Algorithm parameters 48 | discount=0.99, 49 | policy_lr=5e-7, # optimization stepsize for policy update 50 | value_lr=5e-4, # optimization stepsize for value regression 51 | 
target_update_tau=5e-3, # for target network 52 | minibatch_size=256, # optimization/replaybuffer minibatch size 53 | n_grad_steps=2000, # number of gradient updates per epoch 54 | n_warmstart_steps=200000, # number of warm-up steps 55 | fixed_alpha=None, # fixed temperature parameter; None enables automatic entropy tuning 56 | use_deterministic_evaluation=True, # do evaluation based on the deterministic policy 57 | num_evaluation_episodes=5, # number of episodes to evaluate (only affects off-policy algorithms) 58 | # ATAC parameters 59 | beta=1.0, # weight on the Bellman error 60 | norm_constraint=100, 61 | use_two_qfs=True, # whether to use two q functions 62 | q_eval_mode='0.5_0.5', 63 | init_pess=False, 64 | # Compute parameters 65 | seed=0, 66 | n_workers=1, # number of workers for data collection 67 | gpu_id=-1, # try to use gpu, if implemented 68 | force_cpu_data_collection=True, # use cpu for data collection. 69 | # Logging parameters 70 | save_mode='light', 71 | ignore_shutdown=False, # do not shutdown workers after training 72 | return_mode='average', # 'full', 'average', 'last' 73 | return_attr='Evaluation/AverageReturn', # the log attribute 74 | ): 75 | 76 | """ Train an agent in batch mode. """ 77 | 78 | # Set the random seed 79 | set_seed(seed) 80 | 81 | # Initialize gym env 82 | dataset = None 83 | d4rl_env = gym.make(env_name) # d4rl env 84 | while dataset is None: 85 | try: 86 | dataset = d4rl.qlearning_dataset(d4rl_env) 87 | except (HTTPError, OSError): 88 | print('Unable to download dataset. Retry.') 89 | pass 90 | 91 | if init_pess: # for ATAC0 92 | dataset_raw = d4rl_env.get_dataset() 93 | ends = dataset_raw['terminals']+ dataset_raw['timeouts'] 94 | starts = np.concatenate([[True], ends[:-1]]) 95 | init_observations = dataset_raw['observations'][starts] 96 | else: 97 | init_observations = None 98 | 99 | # Initialize replay buffer and gymenv 100 | env = GymEnv(d4rl_env) 101 | replay_buffer = PathBuffer(capacity_in_transitions=int(replay_buffer_size)) 102 | load_d4rl_data_as_buffer(dataset, replay_buffer) 103 | reward_scale = 1.0 104 | 105 | # Initialize the algorithm 106 | env_spec = env.spec 107 | 108 | policy = TanhGaussianMLPPolicy( 109 | env_spec=env_spec, 110 | hidden_sizes=policy_hidden_sizes, 111 | hidden_nonlinearity=eval('torch.nn.'+policy_activation), 112 | init_std=policy_init_std, 113 | min_std=min_std) 114 | 115 | qf1 = ContinuousMLPQFunction( 116 | env_spec=env_spec, 117 | hidden_sizes=value_hidden_sizes, 118 | hidden_nonlinearity=eval('torch.nn.'+value_activation), 119 | output_nonlinearity=None) 120 | 121 | qf2 = ContinuousMLPQFunction( 122 | env_spec=env_spec, 123 | hidden_sizes=value_hidden_sizes, 124 | hidden_nonlinearity=eval('torch.nn.'+value_activation), 125 | output_nonlinearity=None) 126 | 127 | sampler = get_sampler(policy, env, n_workers=n_workers) 128 | 129 | Algo = globals()[algo] 130 | 131 | algo_config = dict( # union of all algorithm configs 132 | env_spec=env_spec, 133 | policy=policy, 134 | qf1=qf1, 135 | qf2=qf2, 136 | sampler=sampler, 137 | replay_buffer=replay_buffer, 138 | discount=discount, 139 | policy_lr=policy_lr, 140 | qf_lr=value_lr, 141 | target_update_tau=target_update_tau, 142 | buffer_batch_size=minibatch_size, 143 | gradient_steps_per_itr=n_grad_steps, 144 | use_deterministic_evaluation=use_deterministic_evaluation, 145 | min_buffer_size=int(0), 146 | num_evaluation_episodes=num_evaluation_episodes, 147 | fixed_alpha=fixed_alpha, 148 | reward_scale=reward_scale, 149 | ) 150 | 151 | # ATAC 152 | extra_algo_config = dict( 153 | beta=beta, 154 |
norm_constraint=norm_constraint, 155 | use_two_qfs=use_two_qfs, 156 | n_warmstart_steps=n_warmstart_steps, 157 | q_eval_mode=q_eval_mode, 158 | init_observations=init_observations, 159 | ) 160 | 161 | algo_config.update(extra_algo_config) 162 | 163 | algo = Algo(**algo_config) 164 | 165 | setup_gpu(algo, gpu_id=gpu_id) 166 | 167 | # Initialize the trainer 168 | from atac.garage_tools.trainer import BatchTrainer as Trainer 169 | trainer = Trainer(ctxt) 170 | trainer.setup(algo=algo, 171 | env=env, 172 | force_cpu_data_collection=force_cpu_data_collection, 173 | save_mode=save_mode, 174 | return_mode=return_mode, 175 | return_attr=return_attr) 176 | 177 | return trainer.train(n_epochs=n_epochs, 178 | batch_size=batch_size, 179 | ignore_shutdown=ignore_shutdown) 180 | 181 | 182 | def run(log_root='.', 183 | torch_n_threads=2, 184 | snapshot_frequency=0, 185 | **train_kwargs): 186 | torch.set_num_threads(torch_n_threads) 187 | log_dir = get_log_dir_name(train_kwargs, ['beta', 'discount', 'norm_constraint', 188 | 'policy_lr', 'value_lr', 189 | 'use_two_qfs', 190 | 'fixed_alpha', 191 | 'q_eval_mode', 192 | 'n_warmstart_steps', 'seed']) 193 | train_kwargs['return_mode'] = 'full' 194 | 195 | # Offline training 196 | log_dir_path = os.path.join(log_root,'exp_data','Offline'+train_kwargs['algo']+'_'+train_kwargs['env_name'], log_dir) 197 | full_score = train_agent(train_func, 198 | log_dir=log_dir_path, 199 | train_kwargs=train_kwargs, 200 | snapshot_frequency=snapshot_frequency, 201 | x_axis='Epoch') 202 | 203 | window = 50 204 | score = np.median(full_score[-min(len(full_score),window):]) 205 | print('Median of performance of last {} epochs'.format(window), score) 206 | return {'score': score, # last 50 epochs 207 | 'mean': np.mean(full_score)} 208 | 209 | if __name__=='__main__': 210 | import argparse 211 | from atac.garage_tools.utils import str2bool 212 | parser = argparse.ArgumentParser() 213 | parser.add_argument('-a', '--algo', type=str, default='ATAC') 214 | parser.add_argument('-e', '--env_name', type=str, default='hopper-medium-replay-v2') 215 | parser.add_argument('--n_epochs', type=int, default=3000) 216 | parser.add_argument('--discount', type=float, default=0.99) 217 | parser.add_argument('--gpu_id', type=int, default=-1) # use cpu by default 218 | parser.add_argument('--n_workers', type=int, default=1) 219 | parser.add_argument('--force_cpu_data_collection', type=str2bool, default=True) 220 | parser.add_argument('--seed', type=int, default=0) 221 | parser.add_argument('--n_warmstart_steps', type=int, default=100000) 222 | parser.add_argument('--fixed_alpha', type=float, default=None) 223 | parser.add_argument('--beta', type=float, default=16) 224 | parser.add_argument('--norm_constraint', type=float, default=100) 225 | parser.add_argument('--policy_lr', type=float, default=5e-7) 226 | parser.add_argument('--value_lr', type=float, default=5e-4) 227 | parser.add_argument('--target_update_tau', type=float, default=5e-3) 228 | parser.add_argument('--use_deterministic_evaluation', type=str2bool, default=True) 229 | parser.add_argument('--use_two_qfs', type=str2bool, default=True) 230 | parser.add_argument('--q_eval_mode', type=str, default='0.5_0.5') 231 | parser.add_argument('--init_pess', type=str2bool, default=False) # turn this on for ATAC0 232 | 233 | train_kwargs = vars(parser.parse_args()) 234 | run(**train_kwargs) -------------------------------------------------------------------------------- /setup.py:
-------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='ATAC', 5 | version='0.1.0', 6 | author='Ching-An Cheng', 7 | author_email='chinganc@microsoft.com', 8 | package_dir={'':'src'}, 9 | python_requires='>=3.8', 10 | packages=['atac'], 11 | url='https://github.com/microsoft/ATAC', 12 | license='MIT LICENSE', 13 | description='ATAC code', 14 | long_description=open('README.md').read(), 15 | install_requires=[ 16 | "garage==2021.3.0", 17 | "gym==0.17.2",], 18 | extras_require={ 19 | 'mujoco200': ["mujoco_py==2.0.2.8", "d4rl @ git+https://github.com/chinganc/d4rl@master#egg=d4rl"], 20 | 'mujoco210': ["d4rl @ git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl"]} 21 | ) 22 | -------------------------------------------------------------------------------- /src/atac/atac.py: -------------------------------------------------------------------------------- 1 | # yapf: disable 2 | from collections import deque 3 | import copy 4 | 5 | from dowel import tabular 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from garage import log_performance, obtain_evaluation_episodes, StepType 11 | from garage.np.algos import RLAlgorithm 12 | from garage.torch import as_torch_dict, global_device 13 | # yapf: enable 14 | 15 | torch.set_flush_denormal(True) 16 | 17 | def normalized_sum(loss, reg, w): 18 | return loss/w + reg if w>1 else loss + w*reg 19 | 20 | def l2_projection(constraint): 21 | @torch.no_grad() 22 | def fn(module): 23 | if hasattr(module, 'weight') and constraint>0: 24 | w = module.weight 25 | norm = torch.norm(w) 26 | w.mul_(torch.clip(constraint/norm, max=1)) 27 | return fn 28 | 29 | def weight_l2(model): 30 | l2 = 0. 31 | for name, param in model.named_parameters(): 32 | if 'weight' in name: 33 | l2 += torch.norm(param)**2 34 | return l2 35 | 36 | class ATAC(RLAlgorithm): 37 | """ Adversarially Trained Actor Critic """ 38 | def __init__( 39 | self, 40 | env_spec, 41 | policy, 42 | qf1, 43 | qf2, 44 | replay_buffer, 45 | sampler, 46 | *, # Everything after this is keyword-only.
47 | max_episode_length_eval=None, 48 | gradient_steps_per_itr, 49 | fixed_alpha=None, 50 | target_entropy=None, 51 | initial_log_entropy=0., 52 | discount=0.99, 53 | buffer_batch_size=256, 54 | min_buffer_size=int(1e4), 55 | target_update_tau=5e-3, 56 | policy_lr=5e-7, 57 | qf_lr=5e-4, 58 | reward_scale=1.0, 59 | optimizer='Adam', 60 | steps_per_epoch=1, 61 | num_evaluation_episodes=10, 62 | eval_env=None, 63 | use_deterministic_evaluation=True, 64 | # ATAC parameters 65 | beta=1.0, # the regularization coefficient in front of the Bellman error 66 | lambd=0., # coeff for global pessimism 67 | init_observations=None, # for ATAC0 (None or np.ndarray) 68 | n_warmstart_steps=200000, 69 | norm_constraint=100, 70 | q_eval_mode='0.5_0.5', # 'max' 'w1_w2', 'adaptive' 71 | q_eval_loss='MSELoss', # 'MSELoss', 'SmoothL1Loss' 72 | use_two_qfs=True, 73 | terminal_value=None, 74 | Vmin=-float('inf'), # min value of Q (used in target backup) 75 | Vmax=float('inf'), # max value of Q (used in target backup) 76 | debug=True, 77 | stats_avg_rate=0.99, # for logging 78 | bellman_surrogate='td', #'td', None, 'target' 79 | ): 80 | 81 | ############################################################################################# 82 | 83 | assert beta>=0 84 | assert norm_constraint>=0 85 | # Parsing 86 | optimizer = eval('torch.optim.'+optimizer) 87 | policy_lr = qf_lr if policy_lr is None or policy_lr < 0 else policy_lr # use shared lr if not provided. 88 | 89 | ## ATAC parameters 90 | self.beta = torch.Tensor([beta]) # regularization constant on the Bellman surrogate 91 | self._lambd = torch.Tensor([lambd]) # global pessimism coefficient 92 | self._init_observations = torch.Tensor(init_observations) if init_observations is not None else init_observations # if provided, it runs ATAC0 93 | self._n_warmstart_steps = n_warmstart_steps # during which, it performs independent C and Bellman minimization 94 | # q update parameters 95 | self._norm_constraint = norm_constraint # l2 norm constraint on the qf's weight; if negative, it gives the weight decay coefficient. 96 | self._q_eval_mode = [float(w) for w in q_eval_mode.split('_')] if '_' in q_eval_mode else q_eval_mode # residual algorithm 97 | self._q_eval_loss = eval('torch.nn.'+q_eval_loss)(reduction='none') 98 | self._use_two_qfs = use_two_qfs 99 | self._Vmin = Vmin # lower bound on the target 100 | self._Vmax = Vmax # upper bound on the target 101 | self._terminal_value = terminal_value if terminal_value is not None else lambda r, gamma: 0. 102 | 103 | # Stepsizes 104 | self._alpha_lr = qf_lr # potentially a larger stepsize, for the most inner optimization. 105 | self._bc_policy_lr = qf_lr # potentially a larger stepsize 106 | 107 | # Logging and algorithm state 108 | self._debug = debug 109 | self._n_updates_performed = 0 # Counter of number of grad steps performed 110 | self._cac_learning=False 111 | self._stats_avg_rate = stats_avg_rate 112 | self._bellman_surrogate = bellman_surrogate 113 | self._avg_bellman_error = 1. 
# for logging; so this works with zero warm-start 114 | self._avg_terminal_td_error = 1 115 | 116 | ############################################################################################# 117 | # Original SAC parameters 118 | self._qf1 = qf1 119 | self._qf2 = qf2 120 | self.replay_buffer = replay_buffer 121 | self._tau = target_update_tau 122 | self._policy_lr = policy_lr 123 | self._qf_lr = qf_lr 124 | self._initial_log_entropy = initial_log_entropy 125 | self._gradient_steps = gradient_steps_per_itr 126 | self._optimizer = optimizer 127 | self._num_evaluation_episodes = num_evaluation_episodes 128 | self._eval_env = eval_env 129 | 130 | self._min_buffer_size = min_buffer_size 131 | self._steps_per_epoch = steps_per_epoch 132 | self._buffer_batch_size = buffer_batch_size 133 | self._discount = discount 134 | self._reward_scale = reward_scale 135 | self.max_episode_length = env_spec.max_episode_length 136 | self._max_episode_length_eval = env_spec.max_episode_length 137 | 138 | if max_episode_length_eval is not None: 139 | self._max_episode_length_eval = max_episode_length_eval 140 | self._use_deterministic_evaluation = use_deterministic_evaluation 141 | 142 | self.policy = policy 143 | self.env_spec = env_spec 144 | self.replay_buffer = replay_buffer 145 | 146 | self._sampler = sampler 147 | 148 | # use 2 target q networks 149 | self._target_qf1 = copy.deepcopy(self._qf1) 150 | self._target_qf2 = copy.deepcopy(self._qf2) 151 | self._policy_optimizer = self._optimizer(self.policy.parameters(), 152 | lr=self._bc_policy_lr) # lr for warmstart 153 | self._qf1_optimizer = self._optimizer(self._qf1.parameters(), 154 | lr=self._qf_lr) 155 | self._qf2_optimizer = self._optimizer(self._qf2.parameters(), 156 | lr=self._qf_lr) 157 | 158 | # automatic entropy coefficient tuning 159 | self._use_automatic_entropy_tuning = fixed_alpha is None 160 | self._fixed_alpha = fixed_alpha 161 | if self._use_automatic_entropy_tuning: 162 | if target_entropy: 163 | self._target_entropy = target_entropy 164 | else: 165 | self._target_entropy = -np.prod( 166 | self.env_spec.action_space.shape).item() 167 | self._log_alpha = torch.Tensor([self._initial_log_entropy 168 | ]).requires_grad_() 169 | self._alpha_optimizer = optimizer([self._log_alpha], 170 | lr=self._alpha_lr) 171 | else: 172 | self._log_alpha = torch.Tensor([self._fixed_alpha]).log() 173 | self.episode_rewards = deque(maxlen=30) 174 | 175 | 176 | def optimize_policy(self, 177 | samples_data, 178 | warmstart=False): 179 | """Optimize the policy q_functions, and temperature coefficient. 180 | 181 | Args: 182 | samples_data (dict): Transitions(S,A,R,S') that are sampled from 183 | the replay buffer. It should have the keys 'observation', 184 | 'action', 'reward', 'terminal', and 'next_observations'. 185 | 186 | Note: 187 | samples_data's entries should be torch.Tensor's with the following 188 | shapes: 189 | observation: :math:`(N, O^*)` 190 | action: :math:`(N, A^*)` 191 | reward: :math:`(N, 1)` 192 | terminal: :math:`(N, 1)` 193 | next_observation: :math:`(N, O^*)` 194 | 195 | Returns: 196 | torch.Tensor: loss from actor/policy network after optimization. 197 | torch.Tensor: loss from 1st q-function after optimization. 198 | torch.Tensor: loss from 2nd q-function after optimization. 
199 | 200 | """ 201 | 202 | obs = samples_data['observation'] 203 | next_obs = samples_data['next_observation'] 204 | actions = samples_data['action'] 205 | rewards = samples_data['reward'].flatten() * self._reward_scale 206 | terminals = samples_data['terminal'].flatten() 207 | 208 | ##### Update Critic ##### 209 | def compute_bellman_backup(q_pred_next): 210 | assert rewards.shape == q_pred_next.shape 211 | return rewards + (1.-terminals) * self._discount * q_pred_next + terminals * self._terminal_value(rewards, self._discount) 212 | 213 | def compute_bellman_loss(q_pred, q_pred_next, q_target): 214 | assert q_pred.shape == q_pred_next.shape == q_target.shape 215 | target_error = self._q_eval_loss(q_pred, q_target) 216 | q_target_pred = compute_bellman_backup(q_pred_next) 217 | td_error = self._q_eval_loss(q_pred, q_target_pred) 218 | w1, w2 = self._q_eval_mode 219 | bellman_loss = w1*target_error+ w2*td_error 220 | return bellman_loss, target_error, td_error 221 | 222 | ## Compute Bellman error 223 | with torch.no_grad(): 224 | new_next_actions_dist = self.policy(next_obs)[0] 225 | _, new_next_actions = new_next_actions_dist.rsample_with_pre_tanh_value() 226 | target_q_values = self._target_qf1(next_obs, new_next_actions) 227 | if self._use_two_qfs: 228 | target_q_values = torch.min(target_q_values, self._target_qf2(next_obs, new_next_actions)) 229 | target_q_values = torch.clip(target_q_values, min=self._Vmin, max=self._Vmax) # projection 230 | q_target = compute_bellman_backup(target_q_values.flatten()) 231 | 232 | qf1_pred = self._qf1(obs, actions).flatten() 233 | qf1_pred_next = self._qf1(next_obs, new_next_actions).flatten() 234 | qf1_bellman_losses, qf1_target_errors, qf1_td_errors = compute_bellman_loss(qf1_pred, qf1_pred_next, q_target) 235 | qf1_bellman_loss = qf1_bellman_losses.mean() 236 | 237 | qf2_bellman_loss = qf2_target_error = qf2_td_error = torch.Tensor([0.]) 238 | if self._use_two_qfs: 239 | qf2_pred = self._qf2(obs, actions).flatten() 240 | qf2_pred_next = self._qf2(next_obs, new_next_actions).flatten() 241 | qf2_bellman_losses, qf2_target_errors, qf2_td_errors = compute_bellman_loss(qf2_pred, qf2_pred_next, q_target) 242 | qf2_bellman_loss = qf2_bellman_losses.mean() 243 | 244 | # Compute GAN error 245 | # These samples will be used for the actor update too, so they need to be traced. 
246 | new_actions_dist = self.policy(obs)[0] 247 | new_actions_pre_tanh, new_actions = new_actions_dist.rsample_with_pre_tanh_value() 248 | 249 | gan_qf1_loss = gan_qf2_loss = 0 250 | if not warmstart: # Compute gan_qf1_loss, gan_qf2_loss 251 | if self._init_observations is None: 252 | # Compute value difference 253 | qf1_new_actions = self._qf1(obs, new_actions.detach()) 254 | gan_qf1_loss = (qf1_new_actions*(1+self._lambd) - qf1_pred).mean() 255 | if self._use_two_qfs: 256 | qf2_new_actions = self._qf2(obs, new_actions.detach()) 257 | gan_qf2_loss = (qf2_new_actions*(1+self._lambd) - qf2_pred).mean() 258 | else: # initial state pessimism 259 | idx_ = np.random.choice(len(self._init_observations), self._buffer_batch_size) 260 | init_observations = self._init_observations[idx_] 261 | init_actions_dist = self.policy(init_observations)[0] 262 | init_actions_pre_tanh, init_actions = init_actions_dist.rsample_with_pre_tanh_value() 263 | qf1_new_actions = self._qf1(init_observations, init_actions.detach()) 264 | gan_qf1_loss = qf1_new_actions.mean() 265 | if self._use_two_qfs: 266 | qf2_new_actions = self._qf2(init_observations, init_actions.detach()) 267 | gan_qf2_loss = qf2_new_actions.mean() 268 | 269 | 270 | ## Compute full q loss 271 | # We normalized the objective to prevent exploding gradients 272 | # qf1_loss = gan_qf1_loss + beta * qf1_bellman_loss 273 | # qf2_loss = gan_qf2_loss + beta * qf2_bellman_loss 274 | with torch.no_grad(): 275 | beta = self.beta 276 | qf1_loss = normalized_sum(gan_qf1_loss, qf1_bellman_loss, beta) 277 | qf2_loss = normalized_sum(gan_qf2_loss, qf2_bellman_loss, beta) 278 | 279 | if beta>0 or not warmstart: 280 | self._qf1_optimizer.zero_grad() 281 | qf1_loss.backward() 282 | self._qf1_optimizer.step() 283 | self._qf1.apply(l2_projection(self._norm_constraint)) 284 | 285 | if self._use_two_qfs: 286 | self._qf2_optimizer.zero_grad() 287 | qf2_loss.backward() 288 | self._qf2_optimizer.step() 289 | self._qf2.apply(l2_projection(self._norm_constraint)) 290 | 291 | ##### Update Actor ##### 292 | 293 | # Compuate entropy 294 | log_pi_new_actions = new_actions_dist.log_prob(value=new_actions, pre_tanh_value=new_actions_pre_tanh) 295 | policy_entropy = -log_pi_new_actions.mean() 296 | 297 | alpha_loss = 0 298 | if self._use_automatic_entropy_tuning: # it comes first; seems to work also when put after policy update 299 | alpha_loss = self._log_alpha * (policy_entropy.detach() - self._target_entropy) # entropy - target 300 | self._alpha_optimizer.zero_grad() 301 | alpha_loss.backward() 302 | self._alpha_optimizer.step() 303 | 304 | with torch.no_grad(): 305 | alpha = self._log_alpha.exp() 306 | 307 | lower_bound = 0 308 | if warmstart: # BC warmstart 309 | policy_log_prob = new_actions_dist.log_prob(samples_data['action']) 310 | # policy_loss = - policy_log_prob.mean() - alpha * policy_entropy 311 | policy_loss = normalized_sum(-policy_log_prob.mean(), -policy_entropy, alpha) 312 | else: 313 | # Compute performance difference lower bound 314 | min_q_new_actions = self._qf1(obs, new_actions) 315 | lower_bound = min_q_new_actions.mean() 316 | # policy_loss = - lower_bound - alpha * policy_kl 317 | policy_loss = normalized_sum(-lower_bound, -policy_entropy, alpha) 318 | 319 | self._policy_optimizer.zero_grad() 320 | policy_loss.backward() 321 | self._policy_optimizer.step() 322 | 323 | log_info = dict( 324 | policy_loss=policy_loss, 325 | qf1_loss=qf1_loss, 326 | qf2_loss=qf2_loss, 327 | qf1_bellman_loss=qf1_bellman_loss, 328 | gan_qf1_loss=gan_qf1_loss, 329 | 
qf2_bellman_loss=qf2_bellman_loss, 330 | gan_qf2_loss=gan_qf2_loss, 331 | beta=beta, 332 | alpha_loss=alpha_loss, 333 | policy_entropy=policy_entropy, 334 | alpha=alpha, 335 | lower_bound=lower_bound, 336 | ) 337 | 338 | # For logging 339 | if self._debug: 340 | with torch.no_grad(): 341 | if self._bellman_surrogate=='td': 342 | qf1_bellman_surrogate = qf1_td_errors.mean() 343 | qf2_bellman_surrogate = qf2_td_errors.mean() 344 | elif self._bellman_surrogate=='target': 345 | qf1_bellman_surrogate = qf1_target_errors.mean() 346 | qf2_bellman_surrogate = qf2_target_errors.mean() 347 | elif self._bellman_surrogate is None: 348 | qf1_bellman_surrogate = qf1_bellman_loss 349 | qf2_bellman_surrogate = qf2_bellman_loss 350 | 351 | bellman_surrogate = torch.max(qf1_bellman_surrogate, qf2_bellman_surrogate) # measure the TD error 352 | self._avg_bellman_error = self._avg_bellman_error*self._stats_avg_rate + bellman_surrogate*(1-self._stats_avg_rate) 353 | 354 | if terminals.sum()>0: 355 | terminal_td_error = (qf1_td_errors * terminals).sum() / terminals.sum() 356 | self._avg_terminal_td_error = self._avg_terminal_td_error*self._stats_avg_rate + terminal_td_error*(1-self._stats_avg_rate) 357 | 358 | qf1_pred_mean = qf1_pred.mean() 359 | qf2_pred_mean = qf2_pred.mean() if self._use_two_qfs else 0. 360 | q_target_mean = q_target.mean() 361 | target_q_values_mean = target_q_values.mean() 362 | qf1_new_actions_mean = qf1_new_actions.mean() if not warmstart else 0. 363 | qf2_new_actions_mean = qf2_new_actions.mean() if not warmstart and self._use_two_qfs else 0. 364 | action_diff = torch.mean(torch.norm(samples_data['action'] - new_actions, dim=1)) 365 | 366 | 367 | debug_log_info = dict( 368 | avg_bellman_error=self._avg_bellman_error, 369 | avg_terminal_td_error=self._avg_terminal_td_error, 370 | qf1_pred_mean=qf1_pred_mean, 371 | qf2_pred_mean=qf2_pred_mean, 372 | q_target_mean=q_target_mean, 373 | target_q_values_mean=target_q_values_mean, 374 | qf1_new_actions_mean=qf1_new_actions_mean, 375 | qf2_new_actions_mean=qf2_new_actions_mean, 376 | action_diff=action_diff, 377 | qf1_target_error=qf1_target_errors.mean(), 378 | qf1_td_error=qf1_td_errors.mean(), 379 | qf2_target_error=qf2_target_errors.mean(), 380 | qf2_td_error=qf2_td_errors.mean() 381 | ) 382 | log_info.update(debug_log_info) 383 | 384 | return log_info 385 | 386 | # Below is overwritten for general logging with log_info 387 | def train(self, trainer): 388 | """Obtain samplers and start actual training for each epoch. 389 | 390 | Args: 391 | trainer (Trainer): Gives the algorithm the access to 392 | :method:`~Trainer.step_epochs()`, which provides services 393 | such as snapshotting and sampler control. 394 | 395 | Returns: 396 | float: The average return in last epoch cycle. 
397 | 398 | """ 399 | if not self._eval_env: 400 | self._eval_env = trainer.get_env_copy() 401 | last_return = None 402 | for _ in trainer.step_epochs(): 403 | for _ in range(self._steps_per_epoch): 404 | if not (self.replay_buffer.n_transitions_stored >= 405 | self._min_buffer_size): 406 | batch_size = int(self._min_buffer_size) 407 | else: 408 | batch_size = None 409 | trainer.step_episode = trainer.obtain_samples( 410 | trainer.step_itr, batch_size) 411 | path_returns = [] 412 | for path in trainer.step_episode: 413 | self.replay_buffer.add_path( 414 | dict(observation=path['observations'], 415 | action=path['actions'], 416 | reward=path['rewards'].reshape(-1, 1), 417 | next_observation=path['next_observations'], 418 | terminal=np.array([ 419 | step_type == StepType.TERMINAL 420 | for step_type in path['step_types'] 421 | ]).reshape(-1, 1))) 422 | path_returns.append(sum(path['rewards'])) 423 | assert len(path_returns) == len(trainer.step_episode) 424 | self.episode_rewards.append(np.mean(path_returns)) 425 | 426 | for _ in range(self._gradient_steps): 427 | log_info = self.train_once() 428 | 429 | if self._num_evaluation_episodes>0: 430 | last_return = self._evaluate_policy(trainer.step_itr) 431 | self._log_statistics(log_info) 432 | tabular.record('TotalEnvSteps', trainer.total_env_steps) 433 | trainer.step_itr += 1 434 | 435 | return np.mean(last_return) if last_return is not None else 0 436 | 437 | def train_once(self, itr=None, paths=None): 438 | """Complete 1 training iteration of ATAC. 439 | 440 | Args: 441 | itr (int): Iteration number. This argument is deprecated. 442 | paths (list[dict]): A list of collected paths. 443 | This argument is deprecated. 444 | 445 | Returns: 446 | dict: Training statistics (losses and other diagnostics) returned 447 | by optimize_policy. 448 | 449 | 450 | """ 451 | del itr 452 | del paths 453 | if self.replay_buffer.n_transitions_stored >= self._min_buffer_size: 454 | warmstart = self._n_updates_performed < self._n_warmstart_steps 455 | # NOTE: the remainder of atac.py was garbled in this dump. The lines below are a minimal 456 | # reconstruction from context: sample a minibatch, run optimize_policy, soft-update the 457 | # target q-functions with rate self._tau, and return the logged statistics. The original 458 | # helper methods (e.g. _evaluate_policy, _log_statistics, networks, to) are not shown here. 459 | samples_data = as_torch_dict(self.replay_buffer.sample_transitions(self._buffer_batch_size)) 460 | log_info = self.optimize_policy(samples_data, warmstart=warmstart) 461 | for target_qf, qf in zip((self._target_qf1, self._target_qf2), (self._qf1, self._qf2)): 462 | for t_param, param in zip(target_qf.parameters(), qf.parameters()): 463 | t_param.data.mul_(1.0 - self._tau).add_(param.data, alpha=self._tau) 464 | self._n_updates_performed += 1 465 | return log_info -------------------------------------------------------------------------------- /src/atac/garage_tools/rl_utils.py: -------------------------------------------------------------------------------- 1 | # NOTE: the import block of this file was lost in the dump; the imports below are inferred from usage. 2 | import numpy as np 3 | import torch 4 | from garage import wrap_experiment 5 | from garage.sampler import FragmentWorker, LocalSampler, RaySampler 6 | from garage.torch import set_gpu_mode 7 | 24 | def get_snapshot_info(snapshot_frequency): 25 | snapshot_gap = snapshot_frequency if snapshot_frequency>0 else 1 26 | snapshot_mode = 'gap_and_last' if snapshot_frequency>0 else 'last' 27 | return snapshot_gap, snapshot_mode 28 | 29 | def train_agent(train_func, 30 | *, 31 | train_kwargs, 32 | log_dir=None, 33 | snapshot_frequency=0, 34 | use_existing_dir=True, 35 | x_axis='TotalEnvSteps', 36 | ): 37 | """ A helper method to run experiments in garage.
""" 38 | snapshot_gap, snapshot_mode = get_snapshot_info(snapshot_frequency) 39 | save_mode = train_kwargs.get('save_mode', 'light') 40 | wrapped_train_func = wrap_experiment(train_func, 41 | log_dir=log_dir, # overwrite 42 | snapshot_mode=snapshot_mode, 43 | snapshot_gap=snapshot_gap, 44 | archive_launch_repo=save_mode!='light', 45 | use_existing_dir=use_existing_dir, 46 | x_axis=x_axis) # overwrites existing directory 47 | score = wrapped_train_func(**train_kwargs) 48 | return score 49 | 50 | def load_algo(path, itr='last'): 51 | from garage.experiment import Snapshotter 52 | snapshotter = Snapshotter() 53 | data = snapshotter.load(path, itr=itr) 54 | return data['algo'] 55 | 56 | def setup_gpu(algo, gpu_id=-1): 57 | if gpu_id>=0: 58 | set_gpu_mode(torch.cuda.is_available(), gpu_id=gpu_id) 59 | if callable(getattr(algo, 'to', None)): 60 | algo.to() 61 | 62 | 63 | def collect_episode_batch(policy, *, 64 | env, 65 | batch_size, 66 | sampler_mode='ray', 67 | n_workers=4): 68 | """Obtain one batch of episodes.""" 69 | sampler = get_sampler(policy, env=env, sampler_mode=sampler_mode, n_workers=n_workers) 70 | agent_update = policy.get_param_values() 71 | episodes = sampler.obtain_samples(0, batch_size, agent_update) 72 | return episodes 73 | 74 | from garage.sampler import Sampler 75 | import copy 76 | from garage._dtypes import EpisodeBatch 77 | class BatchSampler(Sampler): 78 | 79 | def __init__(self, episode_batch, randomize=True): 80 | self.episode_batch = episode_batch 81 | self.randomize = randomize 82 | self._counter = 0 83 | 84 | def obtain_samples(self, itr, num_samples, agent_update, env_update=None): 85 | 86 | ns = self.episode_batch.lengths 87 | if num_samples=num_samples)[0] 94 | if len(itemindex)>0: 95 | ld = self.episode_batch.to_list() 96 | j_max = min(len(ld), itemindex[0]+1) 97 | ld = [ld[i] for i in ind[:j_max].tolist()] 98 | sampled_eb = EpisodeBatch.from_list(self.episode_batch.env_spec,ld) 99 | else: 100 | sampled_eb = None 101 | else: 102 | ns = self.episode_batch.lengths 103 | ind = np.arange(len(ns)) 104 | cumsum_permuted_ns = np.cumsum(ns[ind]) 105 | counter = int(self._counter) 106 | itemindex = np.where(cumsum_permuted_ns>=num_samples*(counter+1))[0] 107 | itemindex0 = np.where(cumsum_permuted_ns>num_samples*counter)[0] 108 | if len(itemindex)>0: 109 | ld = self.episode_batch.to_list() 110 | j_max = min(len(ld), itemindex[0]+1) 111 | j_min = itemindex0[0] 112 | ld = [ld[i] for i in ind[j_min:j_max].tolist()] 113 | sampled_eb = EpisodeBatch.from_list(self.episode_batch.env_spec,ld) 114 | self._counter+=1 115 | else: 116 | sampled_eb = None 117 | else: 118 | sampled_eb = self.episode_batch 119 | 120 | return sampled_eb 121 | 122 | def shutdown_worker(self): 123 | pass 124 | 125 | 126 | from garage.sampler import DefaultWorker, VecWorker 127 | def get_sampler(policy, env, 128 | n_workers=4): 129 | if n_workers>1: 130 | return RaySampler(agents=policy, 131 | envs=env, 132 | max_episode_length=env.spec.max_episode_length, 133 | n_workers=n_workers) 134 | else: 135 | return LocalSampler(agents=policy, 136 | envs=env, 137 | max_episode_length=env.spec.max_episode_length, 138 | worker_class=FragmentWorker, 139 | n_workers=n_workers) 140 | 141 | 142 | def get_algo(Algo, algo_config): 143 | import inspect 144 | algospec = inspect.getfullargspec(Algo) 145 | allowed_args = algospec.args + algospec.kwonlyargs 146 | for k in list(algo_config.keys()): 147 | if k not in allowed_args: 148 | del algo_config[k] 149 | algo = Algo(**algo_config) 150 | return algo 
-------------------------------------------------------------------------------- /src/atac/garage_tools/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import namedtuple 4 | from dowel import logger, tabular 5 | import numpy as np 6 | 7 | from garage.trainer import Trainer as garageTrainer 8 | from garage.trainer import TrainArgs, NotSetupError 9 | from garage.experiment.experiment import dump_json 10 | from .utils import read_attr_from_csv 11 | 12 | 13 | class Trainer(garageTrainer): 14 | """ A modified version of the Garage Trainer. 15 | 16 | This subclass adds 17 | 1) a light saving mode to minimize the storage usage (only saving the 18 | networks, not the trainer and the full algo.) 19 | 2) an ignore_shutdown flag for running multiple experiments. 20 | 3) a return_attr option. 21 | 4) a cpu data collection mode. 22 | 5) logging of sampling time. 23 | 6) logging of current epoch index. 24 | """ 25 | 26 | # Add a light saving mode to minimize the storage usage. 27 | # Add return_attr, return_mode options. 28 | # Add a cpu data collection mode. 29 | def setup(self, algo, env, 30 | force_cpu_data_collection=False, 31 | save_mode='light', 32 | return_mode='average', 33 | return_attr='Evaluation/AverageReturn'): 34 | """Set up trainer for algorithm and environment. 35 | 36 | This method saves algo and env within trainer and creates a sampler. 37 | 38 | Note: 39 | After setup() is called all variables in session should have been 40 | initialized. setup() respects existing values in session so 41 | policy weights can be loaded before setup(). 42 | 43 | Args: 44 | algo (RLAlgorithm): An algorithm instance. If this algo wants to use 45 | samplers, it should have a `_sampler` field. 46 | env (Environment): An environment instance. 47 | save_mode (str): 'light' or 'full' 48 | return_mode (str): 'full', 'average', or 'last' 49 | return_attr (str): the name of the logged attribute 50 | 51 | """ 52 | super().setup(algo, env) 53 | assert save_mode in ('light', 'full') 54 | assert return_mode in ('full', 'average', 'last') 55 | self.save_mode = save_mode 56 | self.return_mode = return_mode 57 | self.return_attr = return_attr 58 | self.force_cpu_data_collection = force_cpu_data_collection 59 | self._sampling_time = 0. 60 | 61 | # Add a light saving mode (which saves only policy and value functions of an algorithm) 62 | def save(self, epoch): 63 | """Save snapshot of current batch. 64 | 65 | Args: 66 | epoch (int): Epoch. 67 | 68 | Raises: 69 | NotSetupError: if save() is called before the trainer is set up.
70 | 71 | """ 72 | if not self._has_setup: 73 | raise NotSetupError('Use setup() to setup trainer before saving.') 74 | 75 | logger.log('Saving snapshot...') 76 | 77 | params = dict() 78 | # Save arguments 79 | params['seed'] = self._seed 80 | params['train_args'] = self._train_args 81 | params['stats'] = self._stats 82 | 83 | if self.save_mode=='light': 84 | # Only save networks 85 | networks = self._algo.networks 86 | keys = [] 87 | values = [] 88 | for k, v in self._algo.__dict__.items(): 89 | if v in networks: 90 | keys.append(k) 91 | values.append(v) 92 | 93 | AlgoData = namedtuple(type(self._algo).__name__+'Networks', 94 | field_names=keys, 95 | defaults=values, 96 | rename=True) 97 | params['algo'] = AlgoData() 98 | 99 | elif self.save_mode=='full': 100 | # Default behavior: save everything 101 | # Save states 102 | params['env'] = self._env 103 | params['algo'] = self._algo 104 | params['n_workers'] = self._n_workers 105 | params['worker_class'] = self._worker_class 106 | params['worker_args'] = self._worker_args 107 | else: 108 | raise ValueError('Unknown save_mode.') 109 | 110 | self._snapshotter.save_itr_params(epoch, params) 111 | 112 | logger.log('Saved') 113 | 114 | # Include ignore_shutdown flag 115 | def train(self, 116 | n_epochs, 117 | batch_size=None, 118 | plot=False, 119 | store_episodes=False, 120 | pause_for_plot=False, 121 | ignore_shutdown=False): 122 | """Start training. 123 | 124 | Args: 125 | n_epochs (int): Number of epochs. 126 | batch_size (int or None): Number of environment steps in one batch. 127 | plot (bool): Visualize an episode from the policy after each epoch. 128 | store_episodes (bool): Save episodes in snapshot. 129 | pause_for_plot (bool): Pause for plot. 130 | 131 | Raises: 132 | NotSetupError: If train() is called before setup(). 133 | 134 | Returns: 135 | float: The average return in last epoch cycle. 136 | 137 | """ 138 | if not self._has_setup: 139 | raise NotSetupError( 140 | 'Use setup() to setup trainer before training.') 141 | 142 | # Save arguments for restore 143 | self._train_args = TrainArgs(n_epochs=n_epochs, 144 | batch_size=batch_size, 145 | plot=plot, 146 | store_episodes=store_episodes, 147 | pause_for_plot=pause_for_plot, 148 | start_epoch=0) 149 | 150 | self._plot = plot 151 | self._start_worker() 152 | 153 | log_dir = self._snapshotter.snapshot_dir 154 | if self.save_mode !='light': 155 | summary_file = os.path.join(log_dir, 'experiment.json') 156 | dump_json(summary_file, self) 157 | 158 | # Train the agent 159 | last_return = self._algo.train(self) 160 | 161 | # XXX Ignore shutdown, if needed 162 | if not ignore_shutdown: 163 | self._shutdown_worker() 164 | 165 | # XXX Return other statistics from logged data 166 | csv_file = os.path.join(log_dir,'progress.csv') 167 | progress = read_attr_from_csv(csv_file, self.return_attr) 168 | progress = progress if progress is not None else 0 169 | if self.return_mode == 'average': 170 | score = np.mean(progress) 171 | elif self.return_mode == 'full': 172 | score = progress 173 | elif self.return_mode == 'last': 174 | score = last_return 175 | else: 176 | NotImplementedError 177 | return score 178 | 179 | # Add a cpu data collection mode 180 | def obtain_episodes(self, 181 | itr, 182 | batch_size=None, 183 | agent_update=None, 184 | env_update=None): 185 | """Obtain one batch of episodes. 186 | 187 | Args: 188 | itr (int): Index of iteration (epoch). 189 | batch_size (int): Number of steps in batch. This is a hint that the 190 | sampler may or may not respect. 
191 | agent_update (object): Value which will be passed into the 192 | `agent_update_fn` before doing sampling episodes. If a list is 193 | passed in, it must have length exactly `factory.n_workers`, and 194 | will be spread across the workers. 195 | env_update (object): Value which will be passed into the 196 | `env_update_fn` before sampling episodes. If a list is passed 197 | in, it must have length exactly `factory.n_workers`, and will 198 | be spread across the workers. 199 | 200 | Raises: 201 | ValueError: If the trainer was initialized without a sampler, or 202 | batch_size wasn't provided here or to train. 203 | 204 | Returns: 205 | EpisodeBatch: Batch of episodes. 206 | 207 | """ 208 | if self._sampler is None: 209 | raise ValueError('trainer was not initialized with `sampler`. ' 210 | 'the algo should have a `_sampler` field when' 211 | '`setup()` is called') 212 | if batch_size is None and self._train_args.batch_size is None: 213 | raise ValueError( 214 | 'trainer was not initialized with `batch_size`. ' 215 | 'Either provide `batch_size` to trainer.train, ' 216 | ' or pass `batch_size` to trainer.obtain_samples.') 217 | episodes = None 218 | if agent_update is None: 219 | policy = getattr(self._algo, 'exploration_policy', None) 220 | if policy is None: 221 | # This field should exist, since self.make_sampler would have 222 | # failed otherwise. 223 | policy = self._algo.policy 224 | agent_update = policy.get_param_values() 225 | 226 | # XXX Move the tensor to cpu. 227 | if self.force_cpu_data_collection: 228 | for k,v in agent_update.items(): 229 | if v.device.type != 'cpu': 230 | agent_update[k] = v.to('cpu') 231 | 232 | # XXX Time data collection. 233 | _start_sampling_time = time.time() 234 | episodes = self._sampler.obtain_samples( 235 | itr, (batch_size or self._train_args.batch_size), 236 | agent_update=agent_update, 237 | env_update=env_update) 238 | self._sampling_time = time.time() - _start_sampling_time 239 | self._stats.total_env_steps += sum(episodes.lengths) 240 | return episodes 241 | 242 | # Log sampling time and Epoch 243 | def log_diagnostics(self, pause_for_plot=False): 244 | """Log diagnostics. 245 | 246 | Args: 247 | pause_for_plot (bool): Pause for plot. 248 | 249 | """ 250 | logger.log('Time %.2f s' % (time.time() - self._start_time)) 251 | logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time)) # XXX 252 | logger.log('SamplingTime %.2f s' % (self._sampling_time)) 253 | tabular.record('TotalEnvSteps', self._stats.total_env_steps) 254 | tabular.record('Epoch', self.step_itr) 255 | logger.log(tabular) 256 | 257 | if self._plot: 258 | self._plotter.update_plot(self._algo.policy, 259 | self._algo.max_episode_length) 260 | if pause_for_plot: 261 | input('Plotting evaluation run: Press Enter to " "continue...') 262 | 263 | class BatchTrainer(Trainer): 264 | """ A batch version of Trainer that disables environment sampling. """ 265 | 266 | def obtain_samples(self, 267 | itr, 268 | batch_size=None, 269 | agent_update=None, 270 | env_update=None): 271 | """ Return an empty list. 
""" 272 | return [] -------------------------------------------------------------------------------- /src/atac/garage_tools/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def str2bool(v): 4 | if isinstance(v, bool): 5 | return v 6 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 7 | return True 8 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 9 | return False 10 | else: 11 | raise TypeError('Boolean value expected.') 12 | 13 | 14 | def torch_method(torch_fun): 15 | def wrapped_fun(x): 16 | with torch.no_grad(): 17 | return torch_fun(torch.Tensor(x)).numpy() 18 | return wrapped_fun 19 | 20 | 21 | import numpy as np 22 | import csv 23 | def read_attr_from_csv(csv_path, attr, delimiter=','): 24 | with open(csv_path) as csv_file: 25 | reader = csv.reader(csv_file, delimiter=delimiter) 26 | try: 27 | row = next(reader) 28 | except Exception: 29 | return None 30 | if attr not in row: 31 | return None 32 | idx = row.index(attr) # the column number for this attribute 33 | vals = [] 34 | for row in reader: 35 | vals.append(row[idx]) 36 | 37 | vals = [np.nan if v=='' else v for v in vals] 38 | return np.array(vals, dtype=np.float64) 39 | --------------------------------------------------------------------------------