├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── install_mujoco.sh
├── scripts
│   └── main.py
├── setup.py
└── src
    └── atac
        ├── atac.py
        └── garage_tools
            ├── rl_utils.py
            ├── trainer.py
            └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # ATAC
132 | exp_data
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ATAC: Adversarially Trained Actor Critic
2 |
3 | This repository contains the code to reproduce the experimental results of ATAC algorithm in the paper Adversarially Trained Actor Critic for Offline Reinforcement Learning by Ching-An Cheng*, Tengyang Xie*, Nan Jiang, and Alekh Agarwal (https://arxiv.org/abs/2202.02446).
4 |
5 | **Please also see https://github.com/microsoft/lightATAC for a lightweight reimplementation of ATAC, which gives a 1.5-2X speedup compared with the original code here.**
6 |
7 | ### Setup
8 |
9 | #### Clone the repository and create a conda environment.
10 | ```
11 | git clone https://github.com/microsoft/ATAC.git
12 | conda create -n atac python=3.8
13 | cd ATAC
14 | ```
15 | #### Prerequisite: Install Mujoco
16 | (Optional) Install the free MuJoCo 2.1.0 (mujoco210) for mujoco_py and MuJoCo 2.1.1 for dm_control.
17 | ```
18 | bash install_mujoco.sh
19 | echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco210/bin:/usr/lib/nvidia" >> ~/.bashrc
20 | source ~/.bashrc
21 | ```
22 | #### Install ATAC
23 | ```
24 | conda activate atac
25 | pip install -e .[mujoco210]
26 | # or run the line below instead, if the original paid MuJoCo 2.0 is used.
27 | pip install -e .[mujoco200]
28 | ```
29 | #### Run ATAC
30 | ```
31 | python scripts/main.py
32 | ```
33 |
34 | ### Contributing
35 |
36 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
37 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
38 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
39 |
40 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
41 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
42 | provided by the bot. You will only need to do this once across all repos using our CLA.
43 |
44 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
45 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
46 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
47 |
48 | ### Trademarks
49 |
50 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
51 | trademarks or logos is subject to and must follow
52 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
53 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
54 | Any use of third-party trademarks or logos is subject to those third parties' policies.
55 |
--------------------------------------------------------------------------------
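For reference, here is a minimal sketch (an illustration, not part of the original repo) of the data-loading step that `scripts/main.py` performs under the hood: it builds a D4RL environment, pulls its Q-learning dataset (downloaded on first use), and later loads that dictionary into a garage replay buffer.

```
import gym, d4rl  # importing d4rl registers the offline environments

env = gym.make('hopper-medium-replay-v2')       # same default environment as scripts/main.py
dataset = d4rl.qlearning_dataset(env)           # dict of numpy arrays
print({k: v.shape for k, v in dataset.items()})
# expected keys: observations, actions, rewards, terminals, next_observations
```
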
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/install_mujoco.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Install system dependencies and the free MuJoCo binaries (2.1.0 for mujoco_py, 2.1.1 for dm_control).
3 | sudo apt-get install libglew-dev patchelf
4 | wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz
5 | wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz
6 | tar -xf mujoco-2.1.1-linux-x86_64.tar.gz
7 | tar -xf mujoco210-linux-x86_64.tar.gz
8 | rm mujoco-2.1.1-linux-x86_64.tar.gz mujoco210-linux-x86_64.tar.gz
9 | mkdir -p ~/.mujoco
10 | mv mujoco210 mujoco-2.1.1 ~/.mujoco
--------------------------------------------------------------------------------
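After running this script and exporting LD_LIBRARY_PATH as the README describes, a quick sanity check like the following (an optional suggestion, not part of the original setup) should confirm that mujoco_py can find the free MuJoCo 2.1.0 install and print its version:

```
import os

assert os.path.isdir(os.path.expanduser('~/.mujoco/mujoco210')), 'MuJoCo 2.1.0 not found'
import mujoco_py  # builds its Cython bindings on first import; this can take a while
print(mujoco_py.__version__)
```
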
/scripts/main.py:
--------------------------------------------------------------------------------
1 | import gym, d4rl, torch, os
2 |
3 | import numpy as np
4 | from urllib.error import HTTPError
5 | from garage.envs import GymEnv
6 | from garage.experiment.deterministic import set_seed
7 | from garage.replay_buffer import PathBuffer
8 | from garage.torch.algos import SAC
9 | from garage.torch.policies import TanhGaussianMLPPolicy
10 | from garage.torch.q_functions import ContinuousMLPQFunction
11 |
12 |
13 | from atac.atac import ATAC
14 | from atac.garage_tools.rl_utils import train_agent, get_sampler, setup_gpu, get_algo, get_log_dir_name, load_algo
15 | from atac.garage_tools.trainer import Trainer
16 |
17 |
18 | def load_d4rl_data_as_buffer(dataset, replay_buffer):
19 | assert isinstance(replay_buffer, PathBuffer)
20 | replay_buffer.add_path(
21 | dict(observation=dataset['observations'],
22 | action=dataset['actions'],
23 | reward=dataset['rewards'].reshape(-1, 1),
24 | next_observation=dataset['next_observations'],
25 | terminal=dataset['terminals'].reshape(-1,1),
26 | ))
27 |
28 | def train_func(ctxt=None,
29 | *,
30 | algo='ATAC',
31 | # Environment parameters
32 | env_name,
33 | # Evaluation mode
34 | evaluation_mode=False,
35 | policy_path=None,
36 | # Trainer parameters
37 | n_epochs=3000, # number of training epochs
38 | batch_size=0, # number of samples collected per update
39 | replay_buffer_size=int(2e6),
40 | # Network parameters
41 | policy_hidden_sizes=(256, 256, 256),
42 | policy_activation='ReLU',
43 | policy_init_std=1.0,
44 | value_hidden_sizes=(256, 256, 256),
45 | value_activation='ReLU',
46 | min_std=1e-5,
47 | # Algorithm parameters
48 | discount=0.99,
49 | policy_lr=5e-7, # optimization stepsize for policy update
50 | value_lr=5e-4, # optimization stepsize for value regression
51 | target_update_tau=5e-3, # for target network
52 | minibatch_size=256, # optimization/replaybuffer minibatch size
53 | n_grad_steps=2000, # number of gradient updates per epoch
54 | n_warmstart_steps=200000, # number of warm-up steps
55 | fixed_alpha=None, # whether to fix the temperature parameter
56 | use_deterministic_evaluation=True, # do evaluation based on the deterministic policy
57 | num_evaluation_episodes=5, # number of episodes to evaluate (only affects off-policy algorithms)
58 | # ATAC parameters
59 | beta=1.0, # weight on the Bellman error
60 | norm_constraint=100,
61 | use_two_qfs=True, # whether to use two q functions
62 | q_eval_mode='0.5_0.5',
63 | init_pess=False,
64 | # Compute parameters
65 | seed=0,
66 | n_workers=1, # number of workers for data collection
67 | gpu_id=-1, # gpu id; use cpu when negative
68 | force_cpu_data_collection=True, # use cpu for data collection.
69 | # Logging parameters
70 | save_mode='light',
71 | ignore_shutdown=False, # whether to skip shutting down workers after training
72 | return_mode='average', # 'full', 'average', 'last'
73 | return_attr='Evaluation/AverageReturn', # the log attribute
74 | ):
75 |
76 | """ Train an agent in batch mode. """
77 |
78 | # Set the random seed
79 | set_seed(seed)
80 |
81 | # Initialize gym env
82 | dataset = None
83 | d4rl_env = gym.make(env_name) # d4rl env
84 | while dataset is None:
85 | try:
86 | dataset = d4rl.qlearning_dataset(d4rl_env)
87 | except (HTTPError, OSError):
88 | print('Unable to download dataset. Retrying.')
89 | pass
90 |
91 | if init_pess: # for ATAC0
92 | dataset_raw = d4rl_env.get_dataset()
93 | ends = dataset_raw['terminals']+ dataset_raw['timeouts']
94 | starts = np.concatenate([[True], ends[:-1]])
95 | init_observations = dataset_raw['observations'][starts]
96 | else:
97 | init_observations = None
98 |
99 | # Initialize replay buffer and gymenv
100 | env = GymEnv(d4rl_env)
101 | replay_buffer = PathBuffer(capacity_in_transitions=int(replay_buffer_size))
102 | load_d4rl_data_as_buffer(dataset, replay_buffer)
103 | reward_scale = 1.0
104 |
105 | # Initialize the algorithm
106 | env_spec = env.spec
107 |
108 | policy = TanhGaussianMLPPolicy(
109 | env_spec=env_spec,
110 | hidden_sizes=policy_hidden_sizes,
111 | hidden_nonlinearity=eval('torch.nn.'+policy_activation),
112 | init_std=policy_init_std,
113 | min_std=min_std)
114 |
115 | qf1 = ContinuousMLPQFunction(
116 | env_spec=env_spec,
117 | hidden_sizes=value_hidden_sizes,
118 | hidden_nonlinearity=eval('torch.nn.'+value_activation),
119 | output_nonlinearity=None)
120 |
121 | qf2 = ContinuousMLPQFunction(
122 | env_spec=env_spec,
123 | hidden_sizes=value_hidden_sizes,
124 | hidden_nonlinearity=eval('torch.nn.'+value_activation),
125 | output_nonlinearity=None)
126 |
127 | sampler = get_sampler(policy, env, n_workers=n_workers)
128 |
129 | Algo = globals()[algo]
130 |
131 | algo_config = dict( # union of all algorithm configs
132 | env_spec=env_spec,
133 | policy=policy,
134 | qf1=qf1,
135 | qf2=qf2,
136 | sampler=sampler,
137 | replay_buffer=replay_buffer,
138 | discount=discount,
139 | policy_lr=policy_lr,
140 | qf_lr=value_lr,
141 | target_update_tau=target_update_tau,
142 | buffer_batch_size=minibatch_size,
143 | gradient_steps_per_itr=n_grad_steps,
144 | use_deterministic_evaluation=use_deterministic_evaluation,
145 | min_buffer_size=int(0),
146 | num_evaluation_episodes=num_evaluation_episodes,
147 | fixed_alpha=fixed_alpha,
148 | reward_scale=reward_scale,
149 | )
150 |
151 | # ATAC
152 | extra_algo_config = dict(
153 | beta=beta,
154 | norm_constraint=norm_constraint,
155 | use_two_qfs=use_two_qfs,
156 | n_warmstart_steps=n_warmstart_steps,
157 | q_eval_mode=q_eval_mode,
158 | init_observations=init_observations,
159 | )
160 |
161 | algo_config.update(extra_algo_config)
162 |
163 | algo = Algo(**algo_config)
164 |
165 | setup_gpu(algo, gpu_id=gpu_id)
166 |
167 | # Initialize the trainer
168 | from atac.garage_tools.trainer import BatchTrainer as Trainer
169 | trainer = Trainer(ctxt)
170 | trainer.setup(algo=algo,
171 | env=env,
172 | force_cpu_data_collection=force_cpu_data_collection,
173 | save_mode=save_mode,
174 | return_mode=return_mode,
175 | return_attr=return_attr)
176 |
177 | return trainer.train(n_epochs=n_epochs,
178 | batch_size=batch_size,
179 | ignore_shutdown=ignore_shutdown)
180 |
181 |
182 | def run(log_root='.',
183 | torch_n_threads=2,
184 | snapshot_frequency=0,
185 | **train_kwargs):
186 | torch.set_num_threads(torch_n_threads)
187 | log_dir = get_log_dir_name(train_kwargs, ['beta', 'discount', 'norm_constraint',
188 | 'policy_lr', 'value_lr',
189 | 'use_two_qfs',
190 | 'fixed_alpha',
191 | 'q_eval_mode',
192 | 'n_warmstart_steps', 'seed'])
193 | train_kwargs['return_mode'] = 'full'
194 |
195 | # Offline training
196 | log_dir_path = os.path.join(log_root,'exp_data','Offline'+train_kwargs['algo']+'_'+train_kwargs['env_name'], log_dir)
197 | full_score = train_agent(train_func,
198 | log_dir=log_dir_path,
199 | train_kwargs=train_kwargs,
200 | snapshot_frequency=snapshot_frequency,
201 | x_axis='Epoch')
202 |
203 | window = 50
204 | score = np.median(full_score[-min(len(full_score),window):])
205 | print('Median of performance of last {} epochs'.format(window), score)
206 | return {'score': score, # last 50 epochs
207 | 'mean': np.mean(full_score)}
208 |
209 | if __name__=='__main__':
210 | import argparse
211 | from atac.garage_tools.utils import str2bool
212 | parser = argparse.ArgumentParser()
213 | parser.add_argument('-a', '--algo', type=str, default='ATAC')
214 | parser.add_argument('-e', '--env_name', type=str, default='hopper-medium-replay-v2')
215 | parser.add_argument('--n_epochs', type=int, default=3000)
216 | parser.add_argument('--discount', type=float, default=0.99)
217 | parser.add_argument('--gpu_id', type=int, default=-1) # use cpu by default
218 | parser.add_argument('--n_workers', type=int, default=1)
219 | parser.add_argument('--force_cpu_data_collection', type=str2bool, default=True)
220 | parser.add_argument('--seed', type=int, default=0)
221 | parser.add_argument('--n_warmstart_steps', type=int, default=100000)
222 | parser.add_argument('--fixed_alpha', type=float, default=None)
223 | parser.add_argument('--beta', type=float, default=16)
224 | parser.add_argument('--norm_constraint', type=float, default=100)
225 | parser.add_argument('--policy_lr', type=float, default=5e-7)
226 | parser.add_argument('--value_lr', type=float, default=5e-4)
227 | parser.add_argument('--target_update_tau', type=float, default=5e-3)
228 | parser.add_argument('--use_deterministic_evaluation', type=str2bool, default=True)
229 | parser.add_argument('--use_two_qfs', type=str2bool, default=True)
230 | parser.add_argument('--q_eval_mode', type=str, default='0.5_0.5')
231 | parser.add_argument('--init_pess', type=str2bool, default=False) # turn this on for ATAC0
232 |
233 | train_kwargs = vars(parser.parse_args())
234 | run(**train_kwargs)
--------------------------------------------------------------------------------
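Beyond the command line shown in the README, this entry point can also be called from Python. The sketch below is only an illustration: it assumes the repository root is the working directory (hence the sys.path tweak, since scripts/ is not a package) and simply mirrors the script's own argparse defaults.

```
import sys
sys.path.insert(0, 'scripts')  # assumption: run from the repo root; scripts/ is not a package
from main import run

result = run(algo='ATAC',
             env_name='hopper-medium-replay-v2',
             n_epochs=3000,
             discount=0.99,
             beta=16.0,
             norm_constraint=100,
             policy_lr=5e-7,
             value_lr=5e-4,
             use_two_qfs=True,
             fixed_alpha=None,
             q_eval_mode='0.5_0.5',
             n_warmstart_steps=100000,
             seed=0)
print(result['score'], result['mean'])  # median of the last 50 epochs, and the overall mean
```
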
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name='ATAC',
5 | version='0.1.0',
6 | author='Ching-An Cheng',
7 | author_email='chinganc@microsoft.com',
8 | package_dir={'':'src'},
9 | python_requires='>=3.8',
10 | packages=['atac', 'atac.garage_tools'],
11 | url='https://github.com/microsoft/ATAC',
12 | license='MIT LICENSE',
13 | description='ATAC code',
14 | long_description=open('README.md').read(),
15 | install_requires=[
16 | "garage==2021.3.0",
17 | "gym==0.17.2",],
18 | extras_require={
19 | 'mujoco200': ["mujoco_py==2.0.2.8", "d4rl @ git+https://github.com/chinganc/d4rl@master#egg=d4rl"],
20 | 'mujoco210': ["d4rl @ git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl"]}
21 | )
22 |
--------------------------------------------------------------------------------
/src/atac/atac.py:
--------------------------------------------------------------------------------
1 | # yapf: disable
2 | from collections import deque
3 | import copy
4 |
5 | from dowel import tabular
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from garage import log_performance, obtain_evaluation_episodes, StepType
11 | from garage.np.algos import RLAlgorithm
12 | from garage.torch import as_torch_dict, global_device
13 | # yapf: enable
14 |
15 | torch.set_flush_denormal(True)
16 |
17 | def normalized_sum(loss, reg, w):
18 | return loss/w + reg if w>1 else loss + w*reg
19 |
20 | def l2_projection(constraint):
21 | @torch.no_grad()
22 | def fn(module):
23 | if hasattr(module, 'weight') and constraint>0:
24 | w = module.weight
25 | norm = torch.norm(w)
26 | w.mul_(torch.clip(constraint/norm, max=1))
27 | return fn
28 |
29 | def weight_l2(model):
30 | l2 = 0.
31 | for name, param in model.named_parameters():
32 | if 'weight' in name:
33 | l2 += torch.norm(param)**2
34 | return l2
35 |
36 | class ATAC(RLAlgorithm):
37 | """ Adversarilly Trained Actor Critic """
38 | def __init__(
39 | self,
40 | env_spec,
41 | policy,
42 | qf1,
43 | qf2,
44 | replay_buffer,
45 | sampler,
46 | *, # Everything after this is keyword-only.
47 | max_episode_length_eval=None,
48 | gradient_steps_per_itr,
49 | fixed_alpha=None,
50 | target_entropy=None,
51 | initial_log_entropy=0.,
52 | discount=0.99,
53 | buffer_batch_size=256,
54 | min_buffer_size=int(1e4),
55 | target_update_tau=5e-3,
56 | policy_lr=5e-7,
57 | qf_lr=5e-4,
58 | reward_scale=1.0,
59 | optimizer='Adam',
60 | steps_per_epoch=1,
61 | num_evaluation_episodes=10,
62 | eval_env=None,
63 | use_deterministic_evaluation=True,
64 | # ATAC parameters
65 | beta=1.0, # the regularization coefficient in front of the Bellman error
66 | lambd=0., # coeff for global pessimism
67 | init_observations=None, # for ATAC0 (None or np.ndarray)
68 | n_warmstart_steps=200000,
69 | norm_constraint=100,
70 | q_eval_mode='0.5_0.5', # 'max' 'w1_w2', 'adaptive'
71 | q_eval_loss='MSELoss', # 'MSELoss', 'SmoothL1Loss'
72 | use_two_qfs=True,
73 | terminal_value=None,
74 | Vmin=-float('inf'), # min value of Q (used in target backup)
75 | Vmax=float('inf'), # max value of Q (used in target backup)
76 | debug=True,
77 | stats_avg_rate=0.99, # for logging
78 | bellman_surrogate='td', #'td', None, 'target'
79 | ):
80 |
81 | #############################################################################################
82 |
83 | assert beta>=0
84 | assert norm_constraint>=0
85 | # Parsing
86 | optimizer = eval('torch.optim.'+optimizer)
87 | policy_lr = qf_lr if policy_lr is None or policy_lr < 0 else policy_lr # use shared lr if not provided.
88 |
89 | ## ATAC parameters
90 | self.beta = torch.Tensor([beta]) # regularization constant on the Bellman surrogate
91 | self._lambd = torch.Tensor([lambd]) # global pessimism coefficient
92 | self._init_observations = torch.Tensor(init_observations) if init_observations is not None else init_observations # if provided, it runs ATAC0
93 | self._n_warmstart_steps = n_warmstart_steps # during which, it performs independent C and Bellman minimization
94 | # q update parameters
95 | self._norm_constraint = norm_constraint # l2 norm constraint on the qf's weight; if negative, it gives the weight decay coefficient.
96 | self._q_eval_mode = [float(w) for w in q_eval_mode.split('_')] if '_' in q_eval_mode else q_eval_mode # residual algorithm
97 | self._q_eval_loss = eval('torch.nn.'+q_eval_loss)(reduction='none')
98 | self._use_two_qfs = use_two_qfs
99 | self._Vmin = Vmin # lower bound on the target
100 | self._Vmax = Vmax # upper bound on the target
101 | self._terminal_value = terminal_value if terminal_value is not None else lambda r, gamma: 0.
102 |
103 | # Stepsizes
104 | self._alpha_lr = qf_lr # potentially a larger stepsize, for the most inner optimization.
105 | self._bc_policy_lr = qf_lr # potentially a larger stepsize
106 |
107 | # Logging and algorithm state
108 | self._debug = debug
109 | self._n_updates_performed = 0 # Counter of number of grad steps performed
110 | self._cac_learning=False
111 | self._stats_avg_rate = stats_avg_rate
112 | self._bellman_surrogate = bellman_surrogate
113 | self._avg_bellman_error = 1. # for logging; so this works with zero warm-start
114 | self._avg_terminal_td_error = 1
115 |
116 | #############################################################################################
117 | # Original SAC parameters
118 | self._qf1 = qf1
119 | self._qf2 = qf2
120 | self.replay_buffer = replay_buffer
121 | self._tau = target_update_tau
122 | self._policy_lr = policy_lr
123 | self._qf_lr = qf_lr
124 | self._initial_log_entropy = initial_log_entropy
125 | self._gradient_steps = gradient_steps_per_itr
126 | self._optimizer = optimizer
127 | self._num_evaluation_episodes = num_evaluation_episodes
128 | self._eval_env = eval_env
129 |
130 | self._min_buffer_size = min_buffer_size
131 | self._steps_per_epoch = steps_per_epoch
132 | self._buffer_batch_size = buffer_batch_size
133 | self._discount = discount
134 | self._reward_scale = reward_scale
135 | self.max_episode_length = env_spec.max_episode_length
136 | self._max_episode_length_eval = env_spec.max_episode_length
137 |
138 | if max_episode_length_eval is not None:
139 | self._max_episode_length_eval = max_episode_length_eval
140 | self._use_deterministic_evaluation = use_deterministic_evaluation
141 |
142 | self.policy = policy
143 | self.env_spec = env_spec
144 | self.replay_buffer = replay_buffer
145 |
146 | self._sampler = sampler
147 |
148 | # use 2 target q networks
149 | self._target_qf1 = copy.deepcopy(self._qf1)
150 | self._target_qf2 = copy.deepcopy(self._qf2)
151 | self._policy_optimizer = self._optimizer(self.policy.parameters(),
152 | lr=self._bc_policy_lr) # lr for warmstart
153 | self._qf1_optimizer = self._optimizer(self._qf1.parameters(),
154 | lr=self._qf_lr)
155 | self._qf2_optimizer = self._optimizer(self._qf2.parameters(),
156 | lr=self._qf_lr)
157 |
158 | # automatic entropy coefficient tuning
159 | self._use_automatic_entropy_tuning = fixed_alpha is None
160 | self._fixed_alpha = fixed_alpha
161 | if self._use_automatic_entropy_tuning:
162 | if target_entropy:
163 | self._target_entropy = target_entropy
164 | else:
165 | self._target_entropy = -np.prod(
166 | self.env_spec.action_space.shape).item()
167 | self._log_alpha = torch.Tensor([self._initial_log_entropy
168 | ]).requires_grad_()
169 | self._alpha_optimizer = optimizer([self._log_alpha],
170 | lr=self._alpha_lr)
171 | else:
172 | self._log_alpha = torch.Tensor([self._fixed_alpha]).log()
173 | self.episode_rewards = deque(maxlen=30)
174 |
175 |
176 | def optimize_policy(self,
177 | samples_data,
178 | warmstart=False):
179 | """Optimize the policy q_functions, and temperature coefficient.
180 |
181 | Args:
182 | samples_data (dict): Transitions(S,A,R,S') that are sampled from
183 | the replay buffer. It should have the keys 'observation',
184 | 'action', 'reward', 'terminal', and 'next_observation'.
185 |
186 | Note:
187 | samples_data's entries should be torch.Tensor's with the following
188 | shapes:
189 | observation: :math:`(N, O^*)`
190 | action: :math:`(N, A^*)`
191 | reward: :math:`(N, 1)`
192 | terminal: :math:`(N, 1)`
193 | next_observation: :math:`(N, O^*)`
194 |
195 | Returns:
196 | torch.Tensor: loss from actor/policy network after optimization.
197 | torch.Tensor: loss from 1st q-function after optimization.
198 | torch.Tensor: loss from 2nd q-function after optimization.
199 |
200 | """
201 |
202 | obs = samples_data['observation']
203 | next_obs = samples_data['next_observation']
204 | actions = samples_data['action']
205 | rewards = samples_data['reward'].flatten() * self._reward_scale
206 | terminals = samples_data['terminal'].flatten()
207 |
208 | ##### Update Critic #####
209 | def compute_bellman_backup(q_pred_next):
210 | assert rewards.shape == q_pred_next.shape
211 | return rewards + (1.-terminals) * self._discount * q_pred_next + terminals * self._terminal_value(rewards, self._discount)
212 |
213 | def compute_bellman_loss(q_pred, q_pred_next, q_target):
214 | assert q_pred.shape == q_pred_next.shape == q_target.shape
215 | target_error = self._q_eval_loss(q_pred, q_target)
216 | q_target_pred = compute_bellman_backup(q_pred_next)
217 | td_error = self._q_eval_loss(q_pred, q_target_pred)
218 | w1, w2 = self._q_eval_mode
219 | bellman_loss = w1*target_error+ w2*td_error
220 | return bellman_loss, target_error, td_error
221 |
222 | ## Compute Bellman error
223 | with torch.no_grad():
224 | new_next_actions_dist = self.policy(next_obs)[0]
225 | _, new_next_actions = new_next_actions_dist.rsample_with_pre_tanh_value()
226 | target_q_values = self._target_qf1(next_obs, new_next_actions)
227 | if self._use_two_qfs:
228 | target_q_values = torch.min(target_q_values, self._target_qf2(next_obs, new_next_actions))
229 | target_q_values = torch.clip(target_q_values, min=self._Vmin, max=self._Vmax) # projection
230 | q_target = compute_bellman_backup(target_q_values.flatten())
231 |
232 | qf1_pred = self._qf1(obs, actions).flatten()
233 | qf1_pred_next = self._qf1(next_obs, new_next_actions).flatten()
234 | qf1_bellman_losses, qf1_target_errors, qf1_td_errors = compute_bellman_loss(qf1_pred, qf1_pred_next, q_target)
235 | qf1_bellman_loss = qf1_bellman_losses.mean()
236 |
237 | qf2_bellman_loss = qf2_target_error = qf2_td_error = torch.Tensor([0.])
238 | if self._use_two_qfs:
239 | qf2_pred = self._qf2(obs, actions).flatten()
240 | qf2_pred_next = self._qf2(next_obs, new_next_actions).flatten()
241 | qf2_bellman_losses, qf2_target_errors, qf2_td_errors = compute_bellman_loss(qf2_pred, qf2_pred_next, q_target)
242 | qf2_bellman_loss = qf2_bellman_losses.mean()
243 |
244 | # Compute GAN error
245 | # These samples will be used for the actor update too, so they need to be traced.
246 | new_actions_dist = self.policy(obs)[0]
247 | new_actions_pre_tanh, new_actions = new_actions_dist.rsample_with_pre_tanh_value()
248 |
249 | gan_qf1_loss = gan_qf2_loss = 0
250 | if not warmstart: # Compute gan_qf1_loss, gan_qf2_loss
251 | if self._init_observations is None:
252 | # Compute value difference
253 | qf1_new_actions = self._qf1(obs, new_actions.detach())
254 | gan_qf1_loss = (qf1_new_actions*(1+self._lambd) - qf1_pred).mean()
255 | if self._use_two_qfs:
256 | qf2_new_actions = self._qf2(obs, new_actions.detach())
257 | gan_qf2_loss = (qf2_new_actions*(1+self._lambd) - qf2_pred).mean()
258 | else: # initial state pessimism
259 | idx_ = np.random.choice(len(self._init_observations), self._buffer_batch_size)
260 | init_observations = self._init_observations[idx_]
261 | init_actions_dist = self.policy(init_observations)[0]
262 | init_actions_pre_tanh, init_actions = init_actions_dist.rsample_with_pre_tanh_value()
263 | qf1_new_actions = self._qf1(init_observations, init_actions.detach())
264 | gan_qf1_loss = qf1_new_actions.mean()
265 | if self._use_two_qfs:
266 | qf2_new_actions = self._qf2(init_observations, init_actions.detach())
267 | gan_qf2_loss = qf2_new_actions.mean()
268 |
269 |
270 | ## Compute full q loss
271 | # We normalize the objective to prevent exploding gradients
272 | # qf1_loss = gan_qf1_loss + beta * qf1_bellman_loss
273 | # qf2_loss = gan_qf2_loss + beta * qf2_bellman_loss
274 | with torch.no_grad():
275 | beta = self.beta
276 | qf1_loss = normalized_sum(gan_qf1_loss, qf1_bellman_loss, beta)
277 | qf2_loss = normalized_sum(gan_qf2_loss, qf2_bellman_loss, beta)
278 |
279 | if beta>0 or not warmstart:
280 | self._qf1_optimizer.zero_grad()
281 | qf1_loss.backward()
282 | self._qf1_optimizer.step()
283 | self._qf1.apply(l2_projection(self._norm_constraint))
284 |
285 | if self._use_two_qfs:
286 | self._qf2_optimizer.zero_grad()
287 | qf2_loss.backward()
288 | self._qf2_optimizer.step()
289 | self._qf2.apply(l2_projection(self._norm_constraint))
290 |
291 | ##### Update Actor #####
292 |
293 | # Compute entropy
294 | log_pi_new_actions = new_actions_dist.log_prob(value=new_actions, pre_tanh_value=new_actions_pre_tanh)
295 | policy_entropy = -log_pi_new_actions.mean()
296 |
297 | alpha_loss = 0
298 | if self._use_automatic_entropy_tuning: # this update comes first; it also seems to work when placed after the policy update
299 | alpha_loss = self._log_alpha * (policy_entropy.detach() - self._target_entropy) # entropy - target
300 | self._alpha_optimizer.zero_grad()
301 | alpha_loss.backward()
302 | self._alpha_optimizer.step()
303 |
304 | with torch.no_grad():
305 | alpha = self._log_alpha.exp()
306 |
307 | lower_bound = 0
308 | if warmstart: # BC warmstart
309 | policy_log_prob = new_actions_dist.log_prob(samples_data['action'])
310 | # policy_loss = - policy_log_prob.mean() - alpha * policy_entropy
311 | policy_loss = normalized_sum(-policy_log_prob.mean(), -policy_entropy, alpha)
312 | else:
313 | # Compute performance difference lower bound
314 | min_q_new_actions = self._qf1(obs, new_actions)
315 | lower_bound = min_q_new_actions.mean()
316 | # policy_loss = - lower_bound - alpha * policy_kl
317 | policy_loss = normalized_sum(-lower_bound, -policy_entropy, alpha)
318 |
319 | self._policy_optimizer.zero_grad()
320 | policy_loss.backward()
321 | self._policy_optimizer.step()
322 |
323 | log_info = dict(
324 | policy_loss=policy_loss,
325 | qf1_loss=qf1_loss,
326 | qf2_loss=qf2_loss,
327 | qf1_bellman_loss=qf1_bellman_loss,
328 | gan_qf1_loss=gan_qf1_loss,
329 | qf2_bellman_loss=qf2_bellman_loss,
330 | gan_qf2_loss=gan_qf2_loss,
331 | beta=beta,
332 | alpha_loss=alpha_loss,
333 | policy_entropy=policy_entropy,
334 | alpha=alpha,
335 | lower_bound=lower_bound,
336 | )
337 |
338 | # For logging
339 | if self._debug:
340 | with torch.no_grad():
341 | if self._bellman_surrogate=='td':
342 | qf1_bellman_surrogate = qf1_td_errors.mean()
343 | qf2_bellman_surrogate = qf2_td_errors.mean()
344 | elif self._bellman_surrogate=='target':
345 | qf1_bellman_surrogate = qf1_target_errors.mean()
346 | qf2_bellman_surrogate = qf2_target_errors.mean()
347 | elif self._bellman_surrogate is None:
348 | qf1_bellman_surrogate = qf1_bellman_loss
349 | qf2_bellman_surrogate = qf2_bellman_loss
350 |
351 | bellman_surrogate = torch.max(qf1_bellman_surrogate, qf2_bellman_surrogate) # measure the TD error
352 | self._avg_bellman_error = self._avg_bellman_error*self._stats_avg_rate + bellman_surrogate*(1-self._stats_avg_rate)
353 |
354 | if terminals.sum()>0:
355 | terminal_td_error = (qf1_td_errors * terminals).sum() / terminals.sum()
356 | self._avg_terminal_td_error = self._avg_terminal_td_error*self._stats_avg_rate + terminal_td_error*(1-self._stats_avg_rate)
357 |
358 | qf1_pred_mean = qf1_pred.mean()
359 | qf2_pred_mean = qf2_pred.mean() if self._use_two_qfs else 0.
360 | q_target_mean = q_target.mean()
361 | target_q_values_mean = target_q_values.mean()
362 | qf1_new_actions_mean = qf1_new_actions.mean() if not warmstart else 0.
363 | qf2_new_actions_mean = qf2_new_actions.mean() if not warmstart and self._use_two_qfs else 0.
364 | action_diff = torch.mean(torch.norm(samples_data['action'] - new_actions, dim=1))
365 |
366 |
367 | debug_log_info = dict(
368 | avg_bellman_error=self._avg_bellman_error,
369 | avg_terminal_td_error=self._avg_terminal_td_error,
370 | qf1_pred_mean=qf1_pred_mean,
371 | qf2_pred_mean=qf2_pred_mean,
372 | q_target_mean=q_target_mean,
373 | target_q_values_mean=target_q_values_mean,
374 | qf1_new_actions_mean=qf1_new_actions_mean,
375 | qf2_new_actions_mean=qf2_new_actions_mean,
376 | action_diff=action_diff,
377 | qf1_target_error=qf1_target_errors.mean(),
378 | qf1_td_error=qf1_td_errors.mean(),
379 | qf2_target_error=qf2_target_errors.mean(),
380 | qf2_td_error=qf2_td_errors.mean()
381 | )
382 | log_info.update(debug_log_info)
383 |
384 | return log_info
385 |
386 | # Below is overwritten for general logging with log_info
387 | def train(self, trainer):
388 | """Obtain samplers and start actual training for each epoch.
389 |
390 | Args:
391 | trainer (Trainer): Gives the algorithm the access to
392 | :method:`~Trainer.step_epochs()`, which provides services
393 | such as snapshotting and sampler control.
394 |
395 | Returns:
396 | float: The average return in last epoch cycle.
397 |
398 | """
399 | if not self._eval_env:
400 | self._eval_env = trainer.get_env_copy()
401 | last_return = None
402 | for _ in trainer.step_epochs():
403 | for _ in range(self._steps_per_epoch):
404 | if not (self.replay_buffer.n_transitions_stored >=
405 | self._min_buffer_size):
406 | batch_size = int(self._min_buffer_size)
407 | else:
408 | batch_size = None
409 | trainer.step_episode = trainer.obtain_samples(
410 | trainer.step_itr, batch_size)
411 | path_returns = []
412 | for path in trainer.step_episode:
413 | self.replay_buffer.add_path(
414 | dict(observation=path['observations'],
415 | action=path['actions'],
416 | reward=path['rewards'].reshape(-1, 1),
417 | next_observation=path['next_observations'],
418 | terminal=np.array([
419 | step_type == StepType.TERMINAL
420 | for step_type in path['step_types']
421 | ]).reshape(-1, 1)))
422 | path_returns.append(sum(path['rewards']))
423 | assert len(path_returns) == len(trainer.step_episode)
424 | self.episode_rewards.append(np.mean(path_returns))
425 |
426 | for _ in range(self._gradient_steps):
427 | log_info = self.train_once()
428 |
429 | if self._num_evaluation_episodes>0:
430 | last_return = self._evaluate_policy(trainer.step_itr)
431 | self._log_statistics(log_info)
432 | tabular.record('TotalEnvSteps', trainer.total_env_steps)
433 | trainer.step_itr += 1
434 |
435 | return np.mean(last_return) if last_return is not None else 0
436 |
437 | def train_once(self, itr=None, paths=None):
438 | """Complete 1 training iteration of ATAC.
439 |
440 | Args:
441 | itr (int): Iteration number. This argument is deprecated.
442 | paths (list[dict]): A list of collected paths.
443 | This argument is deprecated.
444 |
445 | Returns:
446 | torch.Tensor: loss from actor/policy network after optimization.
447 | torch.Tensor: loss from 1st q-function after optimization.
448 | torch.Tensor: loss from 2nd q-function after optimization.
449 |
450 | """
451 | del itr
452 | del paths
453 | if self.replay_buffer.n_transitions_stored >= self._min_buffer_size:
454 | warmstart = self._n_updates_performed < self._n_warmstart_steps
--------------------------------------------------------------------------------
/src/atac/garage_tools/rl_utils.py:
--------------------------------------------------------------------------------
24 | def get_snapshot_info(snapshot_frequency):
25 | snapshot_gap = snapshot_frequency if snapshot_frequency>0 else 1
26 | snapshot_mode = 'gap_and_last' if snapshot_frequency>0 else 'last'
27 | return snapshot_gap, snapshot_mode
28 |
29 | def train_agent(train_func,
30 | *,
31 | train_kwargs,
32 | log_dir=None,
33 | snapshot_frequency=0,
34 | use_existing_dir=True,
35 | x_axis='TotalEnvSteps',
36 | ):
37 | """ A helper method to run experiments in garage. """
38 | snapshot_gap, snapshot_mode = get_snapshot_info(snapshot_frequency)
39 | save_mode = train_kwargs.get('save_mode', 'light')
40 | wrapped_train_func = wrap_experiment(train_func,
41 | log_dir=log_dir, # overwrite
42 | snapshot_mode=snapshot_mode,
43 | snapshot_gap=snapshot_gap,
44 | archive_launch_repo=save_mode!='light',
45 | use_existing_dir=use_existing_dir,
46 | x_axis=x_axis) # overwrites existing directory
47 | score = wrapped_train_func(**train_kwargs)
48 | return score
49 |
50 | def load_algo(path, itr='last'):
51 | from garage.experiment import Snapshotter
52 | snapshotter = Snapshotter()
53 | data = snapshotter.load(path, itr=itr)
54 | return data['algo']
55 |
56 | def setup_gpu(algo, gpu_id=-1):
57 | if gpu_id>=0:
58 | set_gpu_mode(torch.cuda.is_available(), gpu_id=gpu_id)
59 | if callable(getattr(algo, 'to', None)):
60 | algo.to()
61 |
62 |
63 | def collect_episode_batch(policy, *,
64 | env,
65 | batch_size,
66 | sampler_mode='ray',
67 | n_workers=4):
68 | """Obtain one batch of episodes."""
69 | sampler = get_sampler(policy, env=env, n_workers=n_workers)
70 | agent_update = policy.get_param_values()
71 | episodes = sampler.obtain_samples(0, batch_size, agent_update)
72 | return episodes
73 |
74 | from garage.sampler import Sampler
75 | import copy
76 | from garage._dtypes import EpisodeBatch
77 | class BatchSampler(Sampler):
78 |
79 | def __init__(self, episode_batch, randomize=True):
80 | self.episode_batch = episode_batch
81 | self.randomize = randomize
82 | self._counter = 0
83 |
84 | def obtain_samples(self, itr, num_samples, agent_update, env_update=None):
85 |
86 | ns = self.episode_batch.lengths
87 | if num_samples<np.sum(ns):
88 | # Randomized subsampling: permute episodes and take enough of them to cover
89 | # num_samples transitions (reconstructed to mirror the non-randomized branch below).
90 | if self.randomize:
91 | ind = np.random.permutation(len(ns))
92 | cumsum_permuted_ns = np.cumsum(ns[ind])
93 | itemindex = np.where(cumsum_permuted_ns>=num_samples)[0]
94 | if len(itemindex)>0:
95 | ld = self.episode_batch.to_list()
96 | j_max = min(len(ld), itemindex[0]+1)
97 | ld = [ld[i] for i in ind[:j_max].tolist()]
98 | sampled_eb = EpisodeBatch.from_list(self.episode_batch.env_spec,ld)
99 | else:
100 | sampled_eb = None
101 | else:
102 | ns = self.episode_batch.lengths
103 | ind = np.arange(len(ns))
104 | cumsum_permuted_ns = np.cumsum(ns[ind])
105 | counter = int(self._counter)
106 | itemindex = np.where(cumsum_permuted_ns>=num_samples*(counter+1))[0]
107 | itemindex0 = np.where(cumsum_permuted_ns>num_samples*counter)[0]
108 | if len(itemindex)>0:
109 | ld = self.episode_batch.to_list()
110 | j_max = min(len(ld), itemindex[0]+1)
111 | j_min = itemindex0[0]
112 | ld = [ld[i] for i in ind[j_min:j_max].tolist()]
113 | sampled_eb = EpisodeBatch.from_list(self.episode_batch.env_spec,ld)
114 | self._counter+=1
115 | else:
116 | sampled_eb = None
117 | else:
118 | sampled_eb = self.episode_batch
119 |
120 | return sampled_eb
121 |
122 | def shutdown_worker(self):
123 | pass
124 |
125 |
126 | from garage.sampler import DefaultWorker, VecWorker
127 | def get_sampler(policy, env,
128 | n_workers=4):
129 | if n_workers>1:
130 | return RaySampler(agents=policy,
131 | envs=env,
132 | max_episode_length=env.spec.max_episode_length,
133 | n_workers=n_workers)
134 | else:
135 | return LocalSampler(agents=policy,
136 | envs=env,
137 | max_episode_length=env.spec.max_episode_length,
138 | worker_class=FragmentWorker,
139 | n_workers=n_workers)
140 |
141 |
142 | def get_algo(Algo, algo_config):
143 | import inspect
144 | algospec = inspect.getfullargspec(Algo)
145 | allowed_args = algospec.args + algospec.kwonlyargs
146 | for k in list(algo_config.keys()):
147 | if k not in allowed_args:
148 | del algo_config[k]
149 | algo = Algo(**algo_config)
150 | return algo
--------------------------------------------------------------------------------
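A hypothetical post-training sketch of load_algo (the snapshot path, run directory name, and observation size below are purely illustrative): with the default 'light' save mode used by the Trainer further down, the loaded 'algo' object is a namedtuple holding only the networks, so the saved policy can be pulled out and queried directly.

```
import numpy as np
from atac.garage_tools.rl_utils import load_algo

# Path produced by train_agent; replace with an actual run directory.
saved = load_algo('exp_data/OfflineATAC_hopper-medium-replay-v2/my_run', itr='last')
policy = saved.policy                   # TanhGaussianMLPPolicy stored in the snapshot
obs = np.zeros(11, dtype=np.float32)    # Hopper observations are 11-dimensional
action, _ = policy.get_action(obs)
print(action)
```
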
/src/atac/garage_tools/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from collections import namedtuple
4 | from dowel import logger, tabular
5 | import numpy as np
6 |
7 | from garage.trainer import Trainer as garageTrainer
8 | from garage.trainer import TrainArgs, NotSetupError
9 | from garage.experiment.experiment import dump_json
10 | from .utils import read_attr_from_csv
11 |
12 |
13 | class Trainer(garageTrainer):
14 | """ A modifed version of the Garage Trainer.
15 |
16 | This subclass adds
17 | 1) a light saving mode to minimze the stroage usage (only saving the
18 | networks, not the trainer and the full algo.)
19 | 2) a ignore_shutdown flag for running multiple experiments.
20 | 3) a return_attr option.
21 | 4) a cpu data collection mode.
22 | 5) logging of sampling time.
23 | 6) logging of current epoch index.
24 | """
25 |
26 | # Add a light saving mode to minimize the storage usage.
27 | # Add return_attr, return_mode options.
28 | # Add a cpu data collection mode.
29 | def setup(self, algo, env,
30 | force_cpu_data_collection=False,
31 | save_mode='light',
32 | return_mode='average',
33 | return_attr='Evaluation/AverageReturn'):
34 | """Set up trainer for algorithm and environment.
35 |
36 | This method saves algo and env within trainer and creates a sampler.
37 |
38 | Note:
39 | After setup() is called all variables in session should have been
40 | initialized. setup() respects existing values in session so
41 | policy weights can be loaded before setup().
42 |
43 | Args:
44 | algo (RLAlgorithm): An algorithm instance. If this algo wants to use
45 | samplers, it should have a `_sampler` field.
46 | env (Environment): An environment instance.
47 | save_mode (str): 'light' or 'full'
48 | return_mode (str): 'full', 'average', or 'last'
49 | return_attr (str): the name of the logged attribute
50 |
51 | """
52 | super().setup(algo, env)
53 | assert save_mode in ('light', 'full')
54 | assert return_mode in ('full', 'average', 'last')
55 | self.save_mode = save_mode
56 | self.return_mode = return_mode
57 | self.return_attr = return_attr
58 | self.force_cpu_data_collection = force_cpu_data_collection
59 | self._sampling_time = 0.
60 |
61 | # Add a light saving mode (which saves only policy and value functions of an algorithm)
62 | def save(self, epoch):
63 | """Save snapshot of current batch.
64 |
65 | Args:
66 | epoch (int): Epoch.
67 |
68 | Raises:
69 | NotSetupError: if save() is called before the trainer is set up.
70 |
71 | """
72 | if not self._has_setup:
73 | raise NotSetupError('Use setup() to setup trainer before saving.')
74 |
75 | logger.log('Saving snapshot...')
76 |
77 | params = dict()
78 | # Save arguments
79 | params['seed'] = self._seed
80 | params['train_args'] = self._train_args
81 | params['stats'] = self._stats
82 |
83 | if self.save_mode=='light':
84 | # Only save networks
85 | networks = self._algo.networks
86 | keys = []
87 | values = []
88 | for k, v in self._algo.__dict__.items():
89 | if v in networks:
90 | keys.append(k)
91 | values.append(v)
92 |
93 | AlgoData = namedtuple(type(self._algo).__name__+'Networks',
94 | field_names=keys,
95 | defaults=values,
96 | rename=True)
97 | params['algo'] = AlgoData()
98 |
99 | elif self.save_mode=='full':
100 | # Default behavior: save everything
101 | # Save states
102 | params['env'] = self._env
103 | params['algo'] = self._algo
104 | params['n_workers'] = self._n_workers
105 | params['worker_class'] = self._worker_class
106 | params['worker_args'] = self._worker_args
107 | else:
108 | raise ValueError('Unknown save_mode.')
109 |
110 | self._snapshotter.save_itr_params(epoch, params)
111 |
112 | logger.log('Saved')
113 |
114 | # Include ignore_shutdown flag
115 | def train(self,
116 | n_epochs,
117 | batch_size=None,
118 | plot=False,
119 | store_episodes=False,
120 | pause_for_plot=False,
121 | ignore_shutdown=False):
122 | """Start training.
123 |
124 | Args:
125 | n_epochs (int): Number of epochs.
126 | batch_size (int or None): Number of environment steps in one batch.
127 | plot (bool): Visualize an episode from the policy after each epoch.
128 | store_episodes (bool): Save episodes in snapshot.
129 | pause_for_plot (bool): Pause for plot.
130 |
131 | Raises:
132 | NotSetupError: If train() is called before setup().
133 |
134 | Returns:
135 | float: The average return in last epoch cycle.
136 |
137 | """
138 | if not self._has_setup:
139 | raise NotSetupError(
140 | 'Use setup() to setup trainer before training.')
141 |
142 | # Save arguments for restore
143 | self._train_args = TrainArgs(n_epochs=n_epochs,
144 | batch_size=batch_size,
145 | plot=plot,
146 | store_episodes=store_episodes,
147 | pause_for_plot=pause_for_plot,
148 | start_epoch=0)
149 |
150 | self._plot = plot
151 | self._start_worker()
152 |
153 | log_dir = self._snapshotter.snapshot_dir
154 | if self.save_mode !='light':
155 | summary_file = os.path.join(log_dir, 'experiment.json')
156 | dump_json(summary_file, self)
157 |
158 | # Train the agent
159 | last_return = self._algo.train(self)
160 |
161 | # XXX Ignore shutdown, if needed
162 | if not ignore_shutdown:
163 | self._shutdown_worker()
164 |
165 | # XXX Return other statistics from logged data
166 | csv_file = os.path.join(log_dir,'progress.csv')
167 | progress = read_attr_from_csv(csv_file, self.return_attr)
168 | progress = progress if progress is not None else 0
169 | if self.return_mode == 'average':
170 | score = np.mean(progress)
171 | elif self.return_mode == 'full':
172 | score = progress
173 | elif self.return_mode == 'last':
174 | score = last_return
175 | else:
176 | raise NotImplementedError
177 | return score
178 |
179 | # Add a cpu data collection mode
180 | def obtain_episodes(self,
181 | itr,
182 | batch_size=None,
183 | agent_update=None,
184 | env_update=None):
185 | """Obtain one batch of episodes.
186 |
187 | Args:
188 | itr (int): Index of iteration (epoch).
189 | batch_size (int): Number of steps in batch. This is a hint that the
190 | sampler may or may not respect.
191 | agent_update (object): Value which will be passed into the
192 | `agent_update_fn` before doing sampling episodes. If a list is
193 | passed in, it must have length exactly `factory.n_workers`, and
194 | will be spread across the workers.
195 | env_update (object): Value which will be passed into the
196 | `env_update_fn` before sampling episodes. If a list is passed
197 | in, it must have length exactly `factory.n_workers`, and will
198 | be spread across the workers.
199 |
200 | Raises:
201 | ValueError: If the trainer was initialized without a sampler, or
202 | batch_size wasn't provided here or to train.
203 |
204 | Returns:
205 | EpisodeBatch: Batch of episodes.
206 |
207 | """
208 | if self._sampler is None:
209 | raise ValueError('trainer was not initialized with `sampler`. '
210 | 'The algo should have a `_sampler` field when '
211 | '`setup()` is called.')
212 | if batch_size is None and self._train_args.batch_size is None:
213 | raise ValueError(
214 | 'trainer was not initialized with `batch_size`. '
215 | 'Either provide `batch_size` to trainer.train, '
216 | 'or pass `batch_size` to trainer.obtain_samples.')
217 | episodes = None
218 | if agent_update is None:
219 | policy = getattr(self._algo, 'exploration_policy', None)
220 | if policy is None:
221 | # This field should exist, since self.make_sampler would have
222 | # failed otherwise.
223 | policy = self._algo.policy
224 | agent_update = policy.get_param_values()
225 |
226 | # XXX Move the tensor to cpu.
227 | if self.force_cpu_data_collection:
228 | for k,v in agent_update.items():
229 | if v.device.type != 'cpu':
230 | agent_update[k] = v.to('cpu')
231 |
232 | # XXX Time data collection.
233 | _start_sampling_time = time.time()
234 | episodes = self._sampler.obtain_samples(
235 | itr, (batch_size or self._train_args.batch_size),
236 | agent_update=agent_update,
237 | env_update=env_update)
238 | self._sampling_time = time.time() - _start_sampling_time
239 | self._stats.total_env_steps += sum(episodes.lengths)
240 | return episodes
241 |
242 | # Log sampling time and Epoch
243 | def log_diagnostics(self, pause_for_plot=False):
244 | """Log diagnostics.
245 |
246 | Args:
247 | pause_for_plot (bool): Pause for plot.
248 |
249 | """
250 | logger.log('Time %.2f s' % (time.time() - self._start_time))
251 | logger.log('EpochTime %.2f s' % (time.time() - self._itr_start_time)) # XXX
252 | logger.log('SamplingTime %.2f s' % (self._sampling_time))
253 | tabular.record('TotalEnvSteps', self._stats.total_env_steps)
254 | tabular.record('Epoch', self.step_itr)
255 | logger.log(tabular)
256 |
257 | if self._plot:
258 | self._plotter.update_plot(self._algo.policy,
259 | self._algo.max_episode_length)
260 | if pause_for_plot:
261 | input('Plotting evaluation run: Press Enter to continue...')
262 |
263 | class BatchTrainer(Trainer):
264 | """ A batch version of Trainer that disables environment sampling. """
265 |
266 | def obtain_samples(self,
267 | itr,
268 | batch_size=None,
269 | agent_update=None,
270 | env_update=None):
271 | """ Return an empty list. """
272 | return []
--------------------------------------------------------------------------------
/src/atac/garage_tools/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def str2bool(v):
4 | if isinstance(v, bool):
5 | return v
6 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
7 | return True
8 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
9 | return False
10 | else:
11 | raise TypeError('Boolean value expected.')
12 |
13 |
14 | def torch_method(torch_fun):
15 | def wrapped_fun(x):
16 | with torch.no_grad():
17 | return torch_fun(torch.Tensor(x)).numpy()
18 | return wrapped_fun
19 |
20 |
21 | import numpy as np
22 | import csv
23 | def read_attr_from_csv(csv_path, attr, delimiter=','):
24 | with open(csv_path) as csv_file:
25 | reader = csv.reader(csv_file, delimiter=delimiter)
26 | try:
27 | row = next(reader)
28 | except Exception:
29 | return None
30 | if attr not in row:
31 | return None
32 | idx = row.index(attr) # the column number for this attribute
33 | vals = []
34 | for row in reader:
35 | vals.append(row[idx])
36 |
37 | vals = [np.nan if v=='' else v for v in vals]
38 | return np.array(vals, dtype=np.float64)
39 |
--------------------------------------------------------------------------------
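A short usage sketch of these helpers (the CSV filename is hypothetical): str2bool is meant to be passed as an argparse type, and read_attr_from_csv pulls one column of a dowel progress.csv as a float array, mapping empty cells to NaN and returning None when the column is absent or the file is empty.

```
import argparse
from atac.garage_tools.utils import str2bool, read_attr_from_csv

parser = argparse.ArgumentParser()
parser.add_argument('--use_two_qfs', type=str2bool, default=True)
print(parser.parse_args(['--use_two_qfs', 'false']).use_two_qfs)  # -> False

# Read the evaluation returns logged during training (path is illustrative).
returns = read_attr_from_csv('progress.csv', 'Evaluation/AverageReturn')
```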