├── .gitignore ├── LICENSE ├── README.md ├── data ├── images │ ├── humanoid_perturbation.jpg │ ├── logo.png │ └── selection.jpg └── logs │ └── tensorflow_logs.pkl ├── setup.py └── tonic ├── __init__.py ├── agents ├── __init__.py ├── agent.py └── basic.py ├── environments ├── __init__.py ├── builders.py ├── distributed.py └── wrappers.py ├── explorations ├── __init__.py └── noisy.py ├── play.py ├── plot.py ├── replays ├── __init__.py ├── buffers.py ├── segments.py └── utils.py ├── tensorflow ├── __init__.py ├── agents │ ├── __init__.py │ ├── a2c.py │ ├── agent.py │ ├── d4pg.py │ ├── ddpg.py │ ├── mpo.py │ ├── ppo.py │ ├── sac.py │ ├── td3.py │ ├── td4.py │ └── trpo.py ├── models │ ├── __init__.py │ ├── actor_critics.py │ ├── actors.py │ ├── critics.py │ ├── encoders.py │ └── utils.py ├── normalizers │ ├── __init__.py │ ├── mean_stds.py │ └── returns.py └── updaters │ ├── __init__.py │ ├── actors.py │ ├── critics.py │ ├── optimizers.py │ └── utils.py ├── torch ├── __init__.py ├── agents │ ├── __init__.py │ ├── a2c.py │ ├── agent.py │ ├── d4pg.py │ ├── ddpg.py │ ├── mpo.py │ ├── ppo.py │ ├── sac.py │ ├── td3.py │ └── trpo.py ├── models │ ├── __init__.py │ ├── actor_critics.py │ ├── actors.py │ ├── critics.py │ ├── encoders.py │ └── utils.py ├── normalizers │ ├── __init__.py │ ├── mean_stds.py │ └── returns.py └── updaters │ ├── __init__.py │ ├── actors.py │ ├── critics.py │ ├── optimizers.py │ └── utils.py ├── train.py └── utils ├── logger.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.py~ 3 | .ipynb_checkpoints 4 | *.egg-info 5 | __pycache__/ 6 | MUJOCO_LOG.TXT 7 | *DS_Store 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2020 Fabio Pardo (https://github.com/fabiopardo/tonic) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tonic 2 | 3 |
4 | <img src="data/images/logo.png" alt="Tonic logo">
5 | </div>
6 | 7 | Welcome to the Tonic RL library! 8 | 9 | Please take a look at [the paper](https://arxiv.org/abs/2011.07537) for details 10 | and results. 11 | 12 | The main design principles are: 13 | 14 | * **Modularity:** Building blocks for creating RL agents, such as models, 15 | replays, or exploration strategies, are implemented as configurable modules. 16 | 17 | * **Readability:** Agents are written in a simple way with an identical API and 18 | logs are nicely displayed on the terminal with a progress bar. 19 | 20 | * **Fair comparison:** The training pipeline is unique and compatible with all 21 | Tonic agents and environments. Agents are defined by their core ideas while 22 | general tricks/improvements like 23 | [non-terminal timeouts](https://arxiv.org/pdf/1712.00378.pdf), 24 | observation normalization and action scaling are shared. 25 | 26 | * **Benchmarking:** Benchmark data of the provided agents trained on 27 | [70 continuous control environments](https://github.com/fabiopardo/tonic_data/blob/master/images/benchmark.pdf) 28 | are provided for direct comparison. 29 | 30 | * **Wrapped popular environments:** Environments from 31 | [OpenAI Gym](https://github.com/openai/gym), 32 | [PyBullet](https://github.com/bulletphysics/bullet3) and 33 | [DeepMind Control Suite](https://github.com/deepmind/dm_control) are made 34 | compatible with 35 | [non-terminal timeouts](https://arxiv.org/pdf/1712.00378.pdf) and synchronous 36 | distributed training. 37 | 38 | * **Compatibility with different ML frameworks:** Both TensorFlow 2 and PyTorch 39 | are currently supported. Simply import `tonic.tensorflow` or `tonic.torch`. 40 | 41 | * **Experimenting from the console:** While launch scripts can be used, 42 | iterating over various configurations from a console is made possible using 43 | snippets of Python code directly. 44 | 45 | * **Visualization of trained agents:** Experiment configurations and 46 | checkpoints can be loaded to play with trained agents. 47 | 48 | * **Collection of trained models:** To keep the main Tonic repository light, 49 | the full logs and trained models from the benchmark are available in the 50 | [tonic_data repository](https://github.com/fabiopardo/tonic_data). 51 | 52 | # Instructions 53 | 54 | ## Install from source 55 | 56 | Download and install Tonic: 57 | ```bash 58 | git clone https://github.com/fabiopardo/tonic.git 59 | pip install -e tonic/ 60 | ``` 61 | 62 | Install TensorFlow or PyTorch, for example using: 63 | ```bash 64 | pip install tensorflow torch 65 | ``` 66 | 67 | ## Launch experiments 68 | 69 | Use TensorFlow or PyTorch to train an agent, for example using: 70 | ```bash 71 | python3 -m tonic.train \ 72 | --header 'import tonic.torch' \ 73 | --agent 'tonic.torch.agents.PPO()' \ 74 | --environment 'tonic.environments.Gym("BipedalWalker-v3")' \ 75 | --name PPO-X \ 76 | --seed 0 77 | ``` 78 | 79 | Snippets of Python code are used to directly configure the experiment. This is 80 | a very powerful feature allowing to configure agents and environments with 81 | various arguments or even load custom modules without adding them to the 82 | library. For example: 83 | ```bash 84 | python3 -m tonic.train \ 85 | --header "import sys; sys.path.append('path/to/custom'); from custom import CustomAgent" \ 86 | --agent "CustomAgent()" \ 87 | --environment "tonic.environments.Bullet('AntBulletEnv-v0')" \ 88 | --seed 0 89 | ``` 90 | 91 | By default, environments use non-terminal timeouts, which is particularly 92 | important for locomotion tasks. 
But a time feature can be added to the
93 | observations to keep the MDP Markovian. See the
94 | [Time Limits in RL](https://arxiv.org/pdf/1712.00378.pdf) paper for more
95 | details. For example:
96 | ```bash
97 | python3 -m tonic.train \
98 | --header 'import tonic.tensorflow' \
99 | --agent 'tonic.tensorflow.agents.PPO()' \
100 | --environment 'tonic.environments.Gym("Reacher-v2", terminal_timeouts=True, time_feature=True)' \
101 | --seed 0
102 | ```
103 |
104 | Distributed training can be used to accelerate learning. In Tonic, groups of
105 | sequential workers can be launched in parallel processes, for example with:
106 | ```bash
107 | python3 -m tonic.train \
108 | --header "import tonic.tensorflow" \
109 | --agent "tonic.tensorflow.agents.PPO()" \
110 | --environment "tonic.environments.Gym('HalfCheetah-v3')" \
111 | --parallel 10 --sequential 100 \
112 | --seed 0
113 | ```
114 |
115 | ## Plot results
116 |
117 | During training, the experiment configuration, logs and checkpoints are
118 | saved in `environment/agent/seed/`.
119 |
120 | Results can be plotted with:
121 | ```bash
122 | python3 -m tonic.plot --path BipedalWalker-v3/ --baselines all
123 | ```
124 | Path patterns like `BipedalWalker-v3/PPO-X/0`,
125 | `BipedalWalker-v3/{PPO*,DDPG*}` or `*Bullet*` can be used to point to different
126 | sets of logs.
127 | The `--baselines` argument can be used to load logs from the benchmark. For
128 | example, `--baselines all` uses all agents while `--baselines A2C PPO TRPO` will
129 | use logs from A2C, PPO and TRPO.
130 |
131 | Different headers can be used for the x and y axes. For example, to compare the
132 | gain in wall-clock time from distributed training, replace `--parallel 10`
133 | with `--parallel 5` in the last training example and plot the results with:
134 | ```bash
135 | python3 -m tonic.plot --path HalfCheetah-v3/ --x_axis train/seconds --x_label Seconds
136 | ```
137 |
138 | ## Play with trained models
139 |
140 | After some training time, checkpoints are generated and can be used to play
141 | with the trained agent:
142 | ```bash
143 | python3 -m tonic.play --path BipedalWalker-v3/PPO-X/0
144 | ```
145 |
146 | Environments are rendered using the appropriate framework. For example, when
147 | playing with DeepMind Control Suite environments, policies are loaded in a
148 | `dm_control.viewer` where `Space` starts the interaction, `Backspace`
149 | starts a new episode, `[` and `]` switch cameras, and
150 | double-clicking a body part followed by `Ctrl + mouse clicks` adds
151 | perturbations.
152 |
153 | ## Play with models from tonic_data
154 |
155 | The `tonic_data` repository can be downloaded with:
156 | ```bash
157 | git clone https://github.com/fabiopardo/tonic_data.git
158 | ```
159 |
160 | The best seed for each agent is stored in `environment/agent` and can be
161 | reloaded, for example with:
162 | ```bash
163 | python3 -m tonic.play --path tonic_data/tensorflow/humanoid-stand/TD3
164 | ```
165 |
166 | <div align="center">
167 | <img src="data/images/humanoid_perturbation.jpg" alt="Humanoid perturbation in the dm_control viewer">
168 | </div>
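
The same checkpoint can also be reloaded from a Python script or console by calling the `play` function defined in `tonic/play.py` directly. A minimal sketch (the keyword values mirror the command-line defaults):
```python
from tonic.play import play

# Equivalent to the command line above: reload the experiment configuration
# and the last checkpoint, then launch the appropriate viewer.
play(
    path='tonic_data/tensorflow/humanoid-stand/TD3',
    checkpoint='last', seed=0, header=None, agent=None, environment=None)
```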
169 | 170 | The full benchmark plots are available 171 | [here](https://github.com/fabiopardo/tonic_data/blob/master/images/benchmark.pdf). 172 | 173 | They can be generated with: 174 | ```bash 175 | python3 -m tonic.plot \ 176 | --baselines all \ 177 | --backend agg --columns 7 --font_size 17 --legend_font_size 30 --legend_marker_size 20 \ 178 | --name benchmark 179 | ``` 180 | 181 | Or: 182 | ```bash 183 | python3 -m tonic.plot \ 184 | --path tonic_data/tensorflow \ 185 | --backend agg --columns 7 --font_size 17 --legend_font_size 30 --legend_marker_size 20 \ 186 | --name benchmark 187 | ``` 188 | 189 | And a selection can be generated with: 190 | ```bash 191 | python3 -m tonic.plot \ 192 | --path tonic_data/tensorflow/{AntBulletEnv-v0,BipedalWalker-v3,finger-turn_hard,fish-swim,HalfCheetah-v3,HopperBulletEnv-v0,Humanoid-v3,quadruped-walk,swimmer-swimmer15,Walker2d-v3} \ 193 | --backend agg --columns 5 --font_size 20 --legend_font_size 30 --legend_marker_size 20 \ 194 | --name selection 195 | ``` 196 | 197 |
198 |

199 |
200 | 201 | # Credit 202 | 203 | ## Other code bases 204 | 205 | Tonic was inspired by a number of other deep RL code bases. In particular, 206 | [OpenAI Baselines](https://github.com/openai/baselines), 207 | [Spinning Up in Deep RL](https://github.com/openai/spinningup) 208 | and [Acme](https://github.com/deepmind/acme). 209 | 210 | ## Citing Tonic 211 | 212 | If you use Tonic in your research, please cite the [paper](https://arxiv.org/abs/2011.07537): 213 | 214 | ``` 215 | @article{pardo2020tonic, 216 | title={Tonic: A Deep Reinforcement Learning Library for Fast Prototyping and Benchmarking}, 217 | author={Pardo, Fabio}, 218 | journal={arXiv preprint arXiv:2011.07537}, 219 | year={2020} 220 | } 221 | ``` 222 | -------------------------------------------------------------------------------- /data/images/humanoid_perturbation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabiopardo/tonic/0e20c894ee68278ab68322de61bb2c7204a11d5f/data/images/humanoid_perturbation.jpg -------------------------------------------------------------------------------- /data/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabiopardo/tonic/0e20c894ee68278ab68322de61bb2c7204a11d5f/data/images/logo.png -------------------------------------------------------------------------------- /data/images/selection.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabiopardo/tonic/0e20c894ee68278ab68322de61bb2c7204a11d5f/data/images/selection.jpg -------------------------------------------------------------------------------- /data/logs/tensorflow_logs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fabiopardo/tonic/0e20c894ee68278ab68322de61bb2c7204a11d5f/data/logs/tensorflow_logs.pkl -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | setuptools.setup( 5 | name='tonic', 6 | description='Tonic RL Library', 7 | url='https://github.com/fabiopardo/tonic', 8 | version='0.3.0', 9 | author='Fabio Pardo', 10 | author_email='f.pardo@imperial.ac.uk', 11 | install_requires=[ 12 | 'gym', 'matplotlib', 'numpy', 'pandas', 'pyyaml', 'termcolor'], 13 | license='MIT', 14 | python_requires='>=3.6', 15 | keywords=['tonic', 'deep learning', 'reinforcement learning']) 16 | -------------------------------------------------------------------------------- /tonic/__init__.py: -------------------------------------------------------------------------------- 1 | from . import agents 2 | from . import environments 3 | from . import explorations 4 | from . 
import replays 5 | from .utils import logger 6 | from .utils.trainer import Trainer 7 | 8 | 9 | __all__ = [agents, environments, explorations, logger, replays, Trainer] 10 | -------------------------------------------------------------------------------- /tonic/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import Agent 2 | from .basic import Constant, NormalRandom, OrnsteinUhlenbeck, UniformRandom 3 | 4 | 5 | __all__ = [Agent, Constant, NormalRandom, OrnsteinUhlenbeck, UniformRandom] 6 | -------------------------------------------------------------------------------- /tonic/agents/agent.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Agent(abc.ABC): 5 | '''Abstract class used to build agents.''' 6 | 7 | def initialize(self, observation_space, action_space, seed=None): 8 | pass 9 | 10 | @abc.abstractmethod 11 | def step(self, observations, steps): 12 | '''Returns actions during training.''' 13 | pass 14 | 15 | def update(self, observations, rewards, resets, terminations, steps): 16 | '''Informs the agent of the latest transitions during training.''' 17 | pass 18 | 19 | @abc.abstractmethod 20 | def test_step(self, observations, steps): 21 | '''Returns actions during testing.''' 22 | pass 23 | 24 | def test_update(self, observations, rewards, resets, terminations, steps): 25 | '''Informs the agent of the latest transitions during testing.''' 26 | pass 27 | 28 | def save(self, path): 29 | '''Saves the agent weights during training.''' 30 | pass 31 | 32 | def load(self, path): 33 | '''Reloads the agent weights from a checkpoint.''' 34 | pass 35 | -------------------------------------------------------------------------------- /tonic/agents/basic.py: -------------------------------------------------------------------------------- 1 | '''Some basic non-learning agents used for example for debugging.''' 2 | 3 | import numpy as np 4 | 5 | from tonic import agents 6 | 7 | 8 | class NormalRandom(agents.Agent): 9 | '''Random agent producing actions from normal distributions.''' 10 | 11 | def __init__(self, loc=0, scale=1): 12 | self.loc = loc 13 | self.scale = scale 14 | 15 | def initialize(self, observation_space, action_space, seed=None): 16 | self.action_size = action_space.shape[0] 17 | self.np_random = np.random.RandomState(seed) 18 | 19 | def step(self, observations, steps): 20 | return self._policy(observations) 21 | 22 | def test_step(self, observations, steps): 23 | return self._policy(observations) 24 | 25 | def _policy(self, observations): 26 | batch_size = len(observations) 27 | shape = (batch_size, self.action_size) 28 | return self.np_random.normal(self.loc, self.scale, shape) 29 | 30 | 31 | class UniformRandom(agents.Agent): 32 | '''Random agent producing actions from uniform distributions.''' 33 | 34 | def initialize(self, observation_space, action_space, seed=None): 35 | self.action_size = action_space.shape[0] 36 | self.np_random = np.random.RandomState(seed) 37 | 38 | def step(self, observations, steps): 39 | return self._policy(observations) 40 | 41 | def test_step(self, observations, steps): 42 | return self._policy(observations) 43 | 44 | def _policy(self, observations): 45 | batch_size = len(observations) 46 | shape = (batch_size, self.action_size) 47 | return self.np_random.uniform(-1, 1, shape) 48 | 49 | 50 | class OrnsteinUhlenbeck(agents.Agent): 51 | '''Random agent producing correlated actions from an OU process.''' 52 | 53 | def 
__init__(self, scale=0.2, clip=2, theta=.15, dt=1e-2): 54 | self.scale = scale 55 | self.clip = clip 56 | self.theta = theta 57 | self.dt = dt 58 | 59 | def initialize(self, observation_space, action_space, seed=None): 60 | self.action_size = action_space.shape[0] 61 | self.np_random = np.random.RandomState(seed) 62 | self.train_actions = None 63 | self.test_actions = None 64 | 65 | def step(self, observations, steps): 66 | return self._train_policy(observations) 67 | 68 | def test_step(self, observations, steps): 69 | return self._test_policy(observations) 70 | 71 | def _train_policy(self, observations): 72 | if self.train_actions is None: 73 | shape = (len(observations), self.action_size) 74 | self.train_actions = np.zeros(shape) 75 | self.train_actions = self._next_actions(self.train_actions) 76 | return self.train_actions 77 | 78 | def _test_policy(self, observations): 79 | if self.test_actions is None: 80 | shape = (len(observations), self.action_size) 81 | self.test_actions = np.zeros(shape) 82 | self.test_actions = self._next_actions(self.test_actions) 83 | return self.test_actions 84 | 85 | def _next_actions(self, actions): 86 | noises = self.np_random.normal(size=actions.shape) 87 | noises = np.clip(noises, -self.clip, self.clip) 88 | next_actions = (1 - self.theta * self.dt) * actions 89 | next_actions += self.scale * np.sqrt(self.dt) * noises 90 | next_actions = np.clip(next_actions, -1, 1) 91 | return next_actions 92 | 93 | def update(self, observations, rewards, resets, terminations, steps): 94 | self.train_actions *= (1. - resets)[:, None] 95 | 96 | def test_update(self, observations, rewards, resets, terminations, steps): 97 | self.test_actions *= (1. - resets)[:, None] 98 | 99 | 100 | class Constant(agents.Agent): 101 | '''Agent producing a unique constant action.''' 102 | 103 | def __init__(self, constant=0): 104 | self.constant = constant 105 | 106 | def initialize(self, observation_space, action_space, seed=None): 107 | self.action_size = action_space.shape[0] 108 | 109 | def step(self, observations, steps): 110 | return self._policy(observations) 111 | 112 | def test_step(self, observations, steps): 113 | return self._policy(observations) 114 | 115 | def _policy(self, observations): 116 | shape = (len(observations), self.action_size) 117 | return np.full(shape, self.constant) 118 | -------------------------------------------------------------------------------- /tonic/environments/__init__.py: -------------------------------------------------------------------------------- 1 | from .builders import Bullet, ControlSuite, Gym 2 | from .distributed import distribute, Parallel, Sequential 3 | from .wrappers import ActionRescaler, TimeFeature 4 | 5 | 6 | __all__ = [ 7 | Bullet, ControlSuite, Gym, distribute, Parallel, Sequential, 8 | ActionRescaler, TimeFeature] 9 | -------------------------------------------------------------------------------- /tonic/environments/builders.py: -------------------------------------------------------------------------------- 1 | '''Environment builders for popular domains.''' 2 | 3 | import os 4 | 5 | import gym.wrappers 6 | import numpy as np 7 | 8 | from tonic import environments 9 | from tonic.utils import logger 10 | 11 | 12 | def gym_environment(*args, **kwargs): 13 | '''Returns a wrapped Gym environment.''' 14 | 15 | def _builder(*args, **kwargs): 16 | return gym.make(*args, **kwargs) 17 | 18 | return build_environment(_builder, *args, **kwargs) 19 | 20 | 21 | def bullet_environment(*args, **kwargs): 22 | '''Returns a wrapped PyBullet 
environment.''' 23 | 24 | def _builder(*args, **kwargs): 25 | import pybullet_envs # noqa 26 | return gym.make(*args, **kwargs) 27 | 28 | return build_environment(_builder, *args, **kwargs) 29 | 30 | 31 | def control_suite_environment(*args, **kwargs): 32 | '''Returns a wrapped Control Suite environment.''' 33 | 34 | def _builder(name, *args, **kwargs): 35 | domain, task = name.split('-') 36 | environment = ControlSuiteEnvironment( 37 | domain_name=domain, task_name=task, *args, **kwargs) 38 | time_limit = int(environment.environment._step_limit) 39 | return gym.wrappers.TimeLimit(environment, time_limit) 40 | 41 | return build_environment(_builder, *args, **kwargs) 42 | 43 | 44 | def build_environment( 45 | builder, name, terminal_timeouts=False, time_feature=False, 46 | max_episode_steps='default', scaled_actions=True, *args, **kwargs 47 | ): 48 | '''Builds and wrap an environment. 49 | Time limits can be properly handled with terminal_timeouts=False or 50 | time_feature=True, see https://arxiv.org/pdf/1712.00378.pdf for more 51 | details. 52 | ''' 53 | 54 | # Build the environment. 55 | environment = builder(name, *args, **kwargs) 56 | 57 | # Get the default time limit. 58 | if max_episode_steps == 'default': 59 | max_episode_steps = environment._max_episode_steps 60 | 61 | # Remove the TimeLimit wrapper if needed. 62 | if not terminal_timeouts: 63 | assert type(environment) == gym.wrappers.TimeLimit, environment 64 | environment = environment.env 65 | 66 | # Add time as a feature if needed. 67 | if time_feature: 68 | environment = environments.wrappers.TimeFeature( 69 | environment, max_episode_steps) 70 | 71 | # Scale actions from [-1, 1]^n to the true action space if needed. 72 | if scaled_actions: 73 | environment = environments.wrappers.ActionRescaler(environment) 74 | 75 | environment.name = name 76 | environment.max_episode_steps = max_episode_steps 77 | 78 | return environment 79 | 80 | 81 | def _flatten_observation(observation): 82 | '''Turns OrderedDict observations into vectors.''' 83 | observation = [np.array([o]) if np.isscalar(o) else o.ravel() 84 | for o in observation.values()] 85 | return np.concatenate(observation, axis=0) 86 | 87 | 88 | class ControlSuiteEnvironment(gym.core.Env): 89 | '''Turns a Control Suite environment into a Gym environment.''' 90 | 91 | def __init__( 92 | self, domain_name, task_name, task_kwargs=None, visualize_reward=True, 93 | environment_kwargs=None 94 | ): 95 | from dm_control import suite 96 | self.environment = suite.load( 97 | domain_name=domain_name, task_name=task_name, 98 | task_kwargs=task_kwargs, visualize_reward=visualize_reward, 99 | environment_kwargs=environment_kwargs) 100 | 101 | # Create the observation space. 102 | observation_spec = self.environment.observation_spec() 103 | dim = sum([np.int(np.prod(spec.shape)) 104 | for spec in observation_spec.values()]) 105 | high = np.full(dim, np.inf, np.float32) 106 | self.observation_space = gym.spaces.Box(-high, high, dtype=np.float32) 107 | 108 | # Create the action space. 109 | action_spec = self.environment.action_spec() 110 | self.action_space = gym.spaces.Box( 111 | action_spec.minimum, action_spec.maximum, dtype=np.float32) 112 | 113 | def seed(self, seed): 114 | self.environment.task._random = np.random.RandomState(seed) 115 | 116 | def step(self, action): 117 | try: 118 | time_step = self.environment.step(action) 119 | observation = _flatten_observation(time_step.observation) 120 | reward = time_step.reward 121 | 122 | # Remove terminations from timeouts. 
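# time_step.last() is True for both real terminations and step-limit timeouts;
# get_termination() returns None on a timeout, so `done` only marks real terminations.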
123 | done = time_step.last() 124 | if done: 125 | done = self.environment.task.get_termination( 126 | self.environment.physics) 127 | done = done is not None 128 | 129 | self.last_time_step = time_step 130 | 131 | # In case MuJoCo crashed. 132 | except Exception as e: 133 | path = logger.get_path() 134 | os.makedirs(path, exist_ok=True) 135 | save_path = os.path.join(path, 'crashes.txt') 136 | error = str(e) 137 | with open(save_path, 'a') as file: 138 | file.write(error + '\n') 139 | logger.error(error) 140 | observation = _flatten_observation(self.last_time_step.observation) 141 | observation = np.zeros_like(observation) 142 | reward = 0. 143 | done = True 144 | 145 | return observation, reward, done, {} 146 | 147 | def reset(self): 148 | time_step = self.environment.reset() 149 | self.last_time_step = time_step 150 | return _flatten_observation(time_step.observation) 151 | 152 | def render(self, mode='rgb_array', height=None, width=None, camera_id=0): 153 | '''Returns RGB frames from a camera.''' 154 | assert mode == 'rgb_array' 155 | return self.environment.physics.render( 156 | height=height, width=width, camera_id=camera_id) 157 | 158 | 159 | # Aliases. 160 | Gym = gym_environment 161 | Bullet = bullet_environment 162 | ControlSuite = control_suite_environment 163 | -------------------------------------------------------------------------------- /tonic/environments/distributed.py: -------------------------------------------------------------------------------- 1 | '''Builders for distributed training.''' 2 | 3 | import multiprocessing 4 | 5 | import numpy as np 6 | 7 | 8 | class Sequential: 9 | '''A group of environments used in sequence.''' 10 | 11 | def __init__(self, environment_builder, max_episode_steps, workers): 12 | self.environments = [environment_builder() for _ in range(workers)] 13 | self.max_episode_steps = max_episode_steps 14 | self.observation_space = self.environments[0].observation_space 15 | self.action_space = self.environments[0].action_space 16 | self.name = self.environments[0].name 17 | 18 | def initialize(self, seed): 19 | for i, environment in enumerate(self.environments): 20 | environment.seed(seed + i) 21 | 22 | def start(self): 23 | '''Used once to get the initial observations.''' 24 | observations = [env.reset() for env in self.environments] 25 | self.lengths = np.zeros(len(self.environments), int) 26 | return np.array(observations, np.float32) 27 | 28 | def step(self, actions): 29 | next_observations = [] # Observations for the transitions. 30 | rewards = [] 31 | resets = [] 32 | terminations = [] 33 | observations = [] # Observations for the actions selection. 34 | 35 | for i in range(len(self.environments)): 36 | ob, rew, term, _ = self.environments[i].step(actions[i]) 37 | 38 | self.lengths[i] += 1 39 | # Timeouts trigger resets but are not true terminations. 
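# `reset` restarts a worker on termination or timeout, while only `term` is
# reported as a true termination so that value bootstrapping is not cut at timeouts.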
40 | reset = term or self.lengths[i] == self.max_episode_steps 41 | next_observations.append(ob) 42 | rewards.append(rew) 43 | resets.append(reset) 44 | terminations.append(term) 45 | 46 | if reset: 47 | ob = self.environments[i].reset() 48 | self.lengths[i] = 0 49 | 50 | observations.append(ob) 51 | 52 | observations = np.array(observations, np.float32) 53 | infos = dict( 54 | observations=np.array(next_observations, np.float32), 55 | rewards=np.array(rewards, np.float32), 56 | resets=np.array(resets, np.bool), 57 | terminations=np.array(terminations, np.bool)) 58 | return observations, infos 59 | 60 | def render(self, mode='human', *args, **kwargs): 61 | outs = [] 62 | for env in self.environments: 63 | out = env.render(mode=mode, *args, **kwargs) 64 | outs.append(out) 65 | if mode != 'human': 66 | return np.array(outs) 67 | 68 | 69 | class Parallel: 70 | '''A group of sequential environments used in parallel.''' 71 | 72 | def __init__( 73 | self, environment_builder, worker_groups, workers_per_group, 74 | max_episode_steps 75 | ): 76 | self.environment_builder = environment_builder 77 | self.worker_groups = worker_groups 78 | self.workers_per_group = workers_per_group 79 | self.max_episode_steps = max_episode_steps 80 | 81 | def initialize(self, seed): 82 | def proc(action_pipe, index, seed): 83 | '''Process holding a sequential group of environments.''' 84 | envs = Sequential( 85 | self.environment_builder, self.max_episode_steps, 86 | self.workers_per_group) 87 | envs.initialize(seed) 88 | 89 | observations = envs.start() 90 | self.output_queue.put((index, observations)) 91 | 92 | while True: 93 | actions = action_pipe.recv() 94 | out = envs.step(actions) 95 | self.output_queue.put((index, out)) 96 | 97 | dummy_environment = self.environment_builder() 98 | self.observation_space = dummy_environment.observation_space 99 | self.action_space = dummy_environment.action_space 100 | del dummy_environment 101 | self.started = False 102 | 103 | self.output_queue = multiprocessing.Queue() 104 | self.action_pipes = [] 105 | 106 | for i in range(self.worker_groups): 107 | pipe, worker_end = multiprocessing.Pipe() 108 | self.action_pipes.append(pipe) 109 | group_seed = seed + i * self.workers_per_group 110 | process = multiprocessing.Process( 111 | target=proc, args=(worker_end, i, group_seed)) 112 | process.daemon = True 113 | process.start() 114 | 115 | def start(self): 116 | '''Used once to get the initial observations.''' 117 | assert not self.started 118 | self.started = True 119 | observations_list = [None for _ in range(self.worker_groups)] 120 | 121 | for _ in range(self.worker_groups): 122 | index, observations = self.output_queue.get() 123 | observations_list[index] = observations 124 | 125 | self.observations_list = np.array(observations_list) 126 | self.next_observations_list = np.zeros_like(self.observations_list) 127 | self.rewards_list = np.zeros( 128 | (self.worker_groups, self.workers_per_group), np.float32) 129 | self.resets_list = np.zeros( 130 | (self.worker_groups, self.workers_per_group), np.bool) 131 | self.terminations_list = np.zeros( 132 | (self.worker_groups, self.workers_per_group), np.bool) 133 | 134 | return np.concatenate(self.observations_list) 135 | 136 | def step(self, actions): 137 | actions_list = np.split(actions, self.worker_groups) 138 | for actions, pipe in zip(actions_list, self.action_pipes): 139 | pipe.send(actions) 140 | 141 | for _ in range(self.worker_groups): 142 | index, (observations, infos) = self.output_queue.get() 143 | 
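# Groups can return their results in any order; `index` maps each result back to its slot.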
self.observations_list[index] = observations 144 | self.next_observations_list[index] = infos['observations'] 145 | self.rewards_list[index] = infos['rewards'] 146 | self.resets_list[index] = infos['resets'] 147 | self.terminations_list[index] = infos['terminations'] 148 | 149 | observations = np.concatenate(self.observations_list) 150 | infos = dict( 151 | observations=np.concatenate(self.next_observations_list), 152 | rewards=np.concatenate(self.rewards_list), 153 | resets=np.concatenate(self.resets_list), 154 | terminations=np.concatenate(self.terminations_list)) 155 | return observations, infos 156 | 157 | 158 | def distribute(environment_builder, worker_groups=1, workers_per_group=1): 159 | '''Distributes workers over parallel and sequential groups.''' 160 | dummy_environment = environment_builder() 161 | max_episode_steps = dummy_environment.max_episode_steps 162 | del dummy_environment 163 | 164 | if worker_groups < 2: 165 | return Sequential( 166 | environment_builder, max_episode_steps=max_episode_steps, 167 | workers=workers_per_group) 168 | 169 | return Parallel( 170 | environment_builder, worker_groups=worker_groups, 171 | workers_per_group=workers_per_group, 172 | max_episode_steps=max_episode_steps) 173 | -------------------------------------------------------------------------------- /tonic/environments/wrappers.py: -------------------------------------------------------------------------------- 1 | '''Environment wrappers.''' 2 | 3 | import gym 4 | import numpy as np 5 | 6 | 7 | class ActionRescaler(gym.ActionWrapper): 8 | '''Rescales actions from [-1, 1]^n to the true action space. 9 | The baseline agents return actions in [-1, 1]^n.''' 10 | 11 | def __init__(self, env): 12 | assert isinstance(env.action_space, gym.spaces.Box) 13 | super().__init__(env) 14 | high = np.ones(env.action_space.shape, dtype=np.float32) 15 | self.action_space = gym.spaces.Box(low=-high, high=high) 16 | true_low = env.action_space.low 17 | true_high = env.action_space.high 18 | self.bias = (true_high + true_low) / 2 19 | self.scale = (true_high - true_low) / 2 20 | 21 | def action(self, action): 22 | return self.bias + self.scale * np.clip(action, -1, 1) 23 | 24 | 25 | class TimeFeature(gym.Wrapper): 26 | '''Adds a notion of time in the observations. 27 | It can be used in terminal timeout settings to get Markovian MDPs. 
28 | ''' 29 | 30 | def __init__(self, env, max_steps, low=-1, high=1): 31 | super().__init__(env) 32 | dtype = self.observation_space.dtype 33 | self.observation_space = gym.spaces.Box( 34 | low=np.append(self.observation_space.low, low).astype(dtype), 35 | high=np.append(self.observation_space.high, high).astype(dtype)) 36 | self.max_episode_steps = max_steps 37 | self.steps = 0 38 | self.low = low 39 | self.high = high 40 | 41 | def reset(self, **kwargs): 42 | self.steps = 0 43 | observation = self.env.reset(**kwargs) 44 | observation = np.append(observation, self.low) 45 | return observation 46 | 47 | def step(self, action): 48 | assert self.steps < self.max_episode_steps 49 | observation, reward, done, info = self.env.step(action) 50 | self.steps += 1 51 | prop = self.steps / self.max_episode_steps 52 | v = self.low + (self.high - self.low) * prop 53 | observation = np.append(observation, v) 54 | return observation, reward, done, info 55 | -------------------------------------------------------------------------------- /tonic/explorations/__init__.py: -------------------------------------------------------------------------------- 1 | from .noisy import NoActionNoise 2 | from .noisy import NormalActionNoise 3 | from .noisy import OrnsteinUhlenbeckActionNoise 4 | 5 | 6 | __all__ = [NoActionNoise, NormalActionNoise, OrnsteinUhlenbeckActionNoise] 7 | -------------------------------------------------------------------------------- /tonic/explorations/noisy.py: -------------------------------------------------------------------------------- 1 | '''Non-differentiable noisy exploration methods.''' 2 | 3 | import numpy as np 4 | 5 | 6 | class NoActionNoise: 7 | def __init__(self, start_steps=20000): 8 | self.start_steps = start_steps 9 | 10 | def initialize(self, policy, action_space, seed=None): 11 | self.policy = policy 12 | self.action_size = action_space.shape[0] 13 | self.np_random = np.random.RandomState(seed) 14 | 15 | def __call__(self, observations, steps): 16 | if steps > self.start_steps: 17 | actions = self.policy(observations) 18 | actions = np.clip(actions, -1, 1) 19 | else: 20 | shape = (len(observations), self.action_size) 21 | actions = self.np_random.uniform(-1, 1, shape) 22 | return actions 23 | 24 | def update(self, resets): 25 | pass 26 | 27 | 28 | class NormalActionNoise: 29 | def __init__(self, scale=0.1, start_steps=20000): 30 | self.scale = scale 31 | self.start_steps = start_steps 32 | 33 | def initialize(self, policy, action_space, seed=None): 34 | self.policy = policy 35 | self.action_size = action_space.shape[0] 36 | self.np_random = np.random.RandomState(seed) 37 | 38 | def __call__(self, observations, steps): 39 | if steps > self.start_steps: 40 | actions = self.policy(observations) 41 | noises = self.scale * self.np_random.normal(size=actions.shape) 42 | actions = (actions + noises).astype(np.float32) 43 | actions = np.clip(actions, -1, 1) 44 | else: 45 | shape = (len(observations), self.action_size) 46 | actions = self.np_random.uniform(-1, 1, shape) 47 | return actions 48 | 49 | def update(self, resets): 50 | pass 51 | 52 | 53 | class OrnsteinUhlenbeckActionNoise: 54 | def __init__( 55 | self, scale=0.1, clip=2, theta=.15, dt=1e-2, start_steps=20000 56 | ): 57 | self.scale = scale 58 | self.clip = clip 59 | self.theta = theta 60 | self.dt = dt 61 | self.start_steps = start_steps 62 | 63 | def initialize(self, policy, action_space, seed=None): 64 | self.policy = policy 65 | self.action_size = action_space.shape[0] 66 | self.np_random = 
np.random.RandomState(seed) 67 | self.noises = None 68 | 69 | def __call__(self, observations, steps): 70 | if steps > self.start_steps: 71 | actions = self.policy(observations) 72 | 73 | if self.noises is None: 74 | self.noises = np.zeros_like(actions) 75 | noises = self.np_random.normal(size=actions.shape) 76 | noises = np.clip(noises, -self.clip, self.clip) 77 | self.noises -= self.theta * self.noises * self.dt 78 | self.noises += self.scale * np.sqrt(self.dt) * noises 79 | actions = (actions + self.noises).astype(np.float32) 80 | actions = np.clip(actions, -1, 1) 81 | else: 82 | shape = (len(observations), self.action_size) 83 | actions = self.np_random.uniform(-1, 1, shape) 84 | return actions 85 | 86 | def update(self, resets): 87 | if self.noises is not None: 88 | self.noises *= (1. - resets)[:, None] 89 | -------------------------------------------------------------------------------- /tonic/play.py: -------------------------------------------------------------------------------- 1 | '''Script used to play with trained agents.''' 2 | 3 | import argparse 4 | import os 5 | 6 | import numpy as np 7 | import yaml 8 | 9 | import tonic # noqa 10 | 11 | 12 | def play_gym(agent, environment): 13 | '''Launches an agent in a Gym-based environment.''' 14 | 15 | environment = tonic.environments.distribute(lambda: environment) 16 | 17 | observations = environment.start() 18 | environment.render() 19 | 20 | score = 0 21 | length = 0 22 | min_reward = float('inf') 23 | max_reward = -float('inf') 24 | global_min_reward = float('inf') 25 | global_max_reward = -float('inf') 26 | steps = 0 27 | episodes = 0 28 | 29 | while True: 30 | actions = agent.test_step(observations, steps) 31 | observations, infos = environment.step(actions) 32 | agent.test_update(**infos, steps=steps) 33 | environment.render() 34 | 35 | steps += 1 36 | reward = infos['rewards'][0] 37 | score += reward 38 | min_reward = min(min_reward, reward) 39 | max_reward = max(max_reward, reward) 40 | global_min_reward = min(global_min_reward, reward) 41 | global_max_reward = max(global_max_reward, reward) 42 | length += 1 43 | 44 | if infos['resets'][0]: 45 | term = infos['terminations'][0] 46 | episodes += 1 47 | 48 | print() 49 | print(f'Episodes: {episodes:,}') 50 | print(f'Score: {score:,.3f}') 51 | print(f'Length: {length:,}') 52 | print(f'Terminal: {term:}') 53 | print(f'Min reward: {min_reward:,.3f}') 54 | print(f'Max reward: {max_reward:,.3f}') 55 | print(f'Global min reward: {min_reward:,.3f}') 56 | print(f'Global max reward: {max_reward:,.3f}') 57 | 58 | score = 0 59 | length = 0 60 | min_reward = float('inf') 61 | max_reward = -float('inf') 62 | 63 | 64 | def play_control_suite(agent, environment): 65 | '''Launches an agent in a DeepMind Control Suite-based environment.''' 66 | 67 | from dm_control import viewer 68 | 69 | class Wrapper: 70 | '''Wrapper used to plug a Tonic environment in a dm_control viewer.''' 71 | 72 | def __init__(self, environment): 73 | self.environment = environment 74 | self.unwrapped = environment.unwrapped 75 | self.action_spec = self.unwrapped.environment.action_spec 76 | self.physics = self.unwrapped.environment.physics 77 | self.infos = None 78 | self.steps = 0 79 | self.episodes = 0 80 | self.min_reward = float('inf') 81 | self.max_reward = -float('inf') 82 | self.global_min_reward = float('inf') 83 | self.global_max_reward = -float('inf') 84 | 85 | def reset(self): 86 | '''Mimics a dm_control reset for the viewer.''' 87 | 88 | self.observations = self.environment.reset()[None] 89 | 90 | 
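# Reset the per-episode statistics; the global extremes persist across episodes.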
self.score = 0 91 | self.length = 0 92 | self.min_reward = float('inf') 93 | self.max_reward = -float('inf') 94 | 95 | return self.unwrapped.last_time_step 96 | 97 | def step(self, actions): 98 | '''Mimics a dm_control step for the viewer.''' 99 | 100 | assert not np.isnan(actions.sum()) 101 | ob, rew, term, _ = self.environment.step(actions[0]) 102 | 103 | self.score += rew 104 | self.length += 1 105 | self.min_reward = min(self.min_reward, rew) 106 | self.max_reward = max(self.max_reward, rew) 107 | self.global_min_reward = min(self.global_min_reward, rew) 108 | self.global_max_reward = max(self.global_max_reward, rew) 109 | timeout = self.length == self.environment.max_episode_steps 110 | done = term or timeout 111 | 112 | if done: 113 | self.episodes += 1 114 | print() 115 | print(f'Episodes: {self.episodes:,}') 116 | print(f'Score: {self.score:,.3f}') 117 | print(f'Length: {self.length:,}') 118 | print(f'Terminal: {term:}') 119 | print(f'Min reward: {self.min_reward:,.3f}') 120 | print(f'Max reward: {self.max_reward:,.3f}') 121 | print(f'Global min reward: {self.min_reward:,.3f}') 122 | print(f'Global max reward: {self.max_reward:,.3f}') 123 | 124 | self.observations = ob[None] 125 | self.infos = dict( 126 | observations=ob[None], rewards=np.array([rew]), 127 | resets=np.array([done]), terminations=np.array([term])) 128 | 129 | return self.unwrapped.last_time_step 130 | 131 | # Wrap the environment for the viewer. 132 | environment = Wrapper(environment) 133 | 134 | def policy(timestep): 135 | '''Mimics a dm_control policy for the viewer.''' 136 | 137 | if environment.infos is not None: 138 | agent.test_update(**environment.infos, steps=environment.steps) 139 | environment.steps += 1 140 | return agent.test_step(environment.observations, environment.steps) 141 | 142 | # Launch the viewer with the wrapped environment and policy. 143 | viewer.launch(environment, policy) 144 | 145 | 146 | def play(path, checkpoint, seed, header, agent, environment): 147 | '''Reloads an agent and an environment from a previous experiment.''' 148 | 149 | checkpoint_path = None 150 | 151 | if path: 152 | tonic.logger.log(f'Loading experiment from {path}') 153 | 154 | # Use no checkpoint, the agent is freshly created. 155 | if checkpoint == 'none' or agent is not None: 156 | tonic.logger.log('Not loading any weights') 157 | 158 | else: 159 | checkpoint_path = os.path.join(path, 'checkpoints') 160 | if not os.path.isdir(checkpoint_path): 161 | tonic.logger.error(f'{checkpoint_path} is not a directory') 162 | checkpoint_path = None 163 | 164 | # List all the checkpoints. 165 | checkpoint_ids = [] 166 | for file in os.listdir(checkpoint_path): 167 | if file[:5] == 'step_': 168 | checkpoint_id = file.split('.')[0] 169 | checkpoint_ids.append(int(checkpoint_id[5:])) 170 | 171 | if checkpoint_ids: 172 | # Use the last checkpoint. 173 | if checkpoint == 'last': 174 | checkpoint_id = max(checkpoint_ids) 175 | checkpoint_path = os.path.join( 176 | checkpoint_path, f'step_{checkpoint_id}') 177 | 178 | # Use the specified checkpoint. 179 | else: 180 | checkpoint_id = int(checkpoint) 181 | if checkpoint_id in checkpoint_ids: 182 | checkpoint_path = os.path.join( 183 | checkpoint_path, f'step_{checkpoint_id}') 184 | else: 185 | tonic.logger.error(f'Checkpoint {checkpoint_id} ' 186 | f'not found in {checkpoint_path}') 187 | checkpoint_path = None 188 | 189 | else: 190 | tonic.logger.error(f'No checkpoint found in {checkpoint_path}') 191 | checkpoint_path = None 192 | 193 | # Load the experiment configuration. 
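# config.yaml is saved during training and stores the header, agent and
# environment snippets used to launch the experiment.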
194 | arguments_path = os.path.join(path, 'config.yaml') 195 | with open(arguments_path, 'r') as config_file: 196 | config = yaml.load(config_file, Loader=yaml.FullLoader) 197 | config = argparse.Namespace(**config) 198 | 199 | header = header or config.header 200 | agent = agent or config.agent 201 | environment = environment or config.test_environment 202 | environment = environment or config.environment 203 | 204 | # Run the header first, e.g. to load an ML framework. 205 | if header: 206 | exec(header) 207 | 208 | # Build the agent. 209 | if not agent: 210 | raise ValueError('No agent specified.') 211 | agent = eval(agent) 212 | 213 | # Build the environment. 214 | environment = eval(environment) 215 | environment.seed(seed) 216 | 217 | # Initialize the agent. 218 | agent.initialize( 219 | observation_space=environment.observation_space, 220 | action_space=environment.action_space, seed=seed) 221 | 222 | # Load the weights of the agent form a checkpoint. 223 | if checkpoint_path: 224 | agent.load(checkpoint_path) 225 | 226 | # Play with the agent in the environment. 227 | if isinstance(environment, tonic.environments.wrappers.ActionRescaler): 228 | environment_type = environment.env.__class__.__name__ 229 | else: 230 | environment_type = environment.__class__.__name__ 231 | 232 | if environment_type == 'ControlSuiteEnvironment': 233 | play_control_suite(agent, environment) 234 | else: 235 | if 'Bullet' in environment_type: 236 | environment.render() 237 | play_gym(agent, environment) 238 | 239 | 240 | if __name__ == '__main__': 241 | # Argument parsing. 242 | parser = argparse.ArgumentParser() 243 | parser.add_argument('--path') 244 | parser.add_argument('--checkpoint', default='last') 245 | parser.add_argument('--seed', type=int, default=0) 246 | parser.add_argument('--header') 247 | parser.add_argument('--agent') 248 | parser.add_argument('--environment', '--env') 249 | args = vars(parser.parse_args()) 250 | play(**args) 251 | -------------------------------------------------------------------------------- /tonic/replays/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffers import Buffer 2 | from .segments import Segment 3 | from .utils import flatten_batch, lambda_returns 4 | 5 | 6 | __all__ = [flatten_batch, lambda_returns, Buffer, Segment] 7 | -------------------------------------------------------------------------------- /tonic/replays/buffers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Buffer: 5 | '''Replay storing a large number of transitions for off-policy learning 6 | and using n-step returns.''' 7 | 8 | def __init__( 9 | self, size=int(1e6), return_steps=1, batch_iterations=50, 10 | batch_size=100, discount_factor=0.99, steps_before_batches=int(1e4), 11 | steps_between_batches=50 12 | ): 13 | self.full_max_size = size 14 | self.return_steps = return_steps 15 | self.batch_iterations = batch_iterations 16 | self.batch_size = batch_size 17 | self.discount_factor = discount_factor 18 | self.steps_before_batches = steps_before_batches 19 | self.steps_between_batches = steps_between_batches 20 | 21 | def initialize(self, seed=None): 22 | self.np_random = np.random.RandomState(seed) 23 | self.buffers = None 24 | self.index = 0 25 | self.size = 0 26 | self.last_steps = 0 27 | 28 | def ready(self, steps): 29 | if steps < self.steps_before_batches: 30 | return False 31 | return (steps - self.last_steps) >= self.steps_between_batches 32 | 33 | 
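# Named buffers are created lazily on the first call, shaped by the incoming values.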
def store(self, **kwargs): 34 | if 'terminations' in kwargs: 35 | continuations = np.float32(1 - kwargs['terminations']) 36 | kwargs['discounts'] = continuations * self.discount_factor 37 | 38 | # Create the named buffers. 39 | if self.buffers is None: 40 | self.num_workers = len(list(kwargs.values())[0]) 41 | self.max_size = self.full_max_size // self.num_workers 42 | self.buffers = {} 43 | for key, val in kwargs.items(): 44 | shape = (self.max_size,) + np.array(val).shape 45 | self.buffers[key] = np.full(shape, np.nan, np.float32) 46 | 47 | # Store the new values. 48 | for key, val in kwargs.items(): 49 | self.buffers[key][self.index] = val 50 | 51 | # Accumulate values for n-step returns. 52 | if self.return_steps > 1: 53 | self.accumulate_n_steps(kwargs) 54 | 55 | self.index = (self.index + 1) % self.max_size 56 | self.size = min(self.size + 1, self.max_size) 57 | 58 | def accumulate_n_steps(self, kwargs): 59 | rewards = kwargs['rewards'] 60 | next_observations = kwargs['next_observations'] 61 | discounts = kwargs['discounts'] 62 | masks = np.ones(self.num_workers, np.float32) 63 | 64 | for i in range(min(self.size, self.return_steps - 1)): 65 | index = (self.index - i - 1) % self.max_size 66 | masks *= (1 - self.buffers['resets'][index]) 67 | new_rewards = (self.buffers['rewards'][index] + 68 | self.buffers['discounts'][index] * rewards) 69 | self.buffers['rewards'][index] = ( 70 | (1 - masks) * self.buffers['rewards'][index] + 71 | masks * new_rewards) 72 | new_discounts = self.buffers['discounts'][index] * discounts 73 | self.buffers['discounts'][index] = ( 74 | (1 - masks) * self.buffers['discounts'][index] + 75 | masks * new_discounts) 76 | self.buffers['next_observations'][index] = ( 77 | (1 - masks)[:, None] * 78 | self.buffers['next_observations'][index] + 79 | masks[:, None] * next_observations) 80 | 81 | def get(self, *keys, steps): 82 | '''Get batches from named buffers.''' 83 | 84 | for _ in range(self.batch_iterations): 85 | total_size = self.size * self.num_workers 86 | indices = self.np_random.randint(total_size, size=self.batch_size) 87 | rows = indices // self.num_workers 88 | columns = indices % self.num_workers 89 | yield {k: self.buffers[k][rows, columns] for k in keys} 90 | 91 | self.last_steps = steps 92 | -------------------------------------------------------------------------------- /tonic/replays/segments.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tonic import replays 4 | 5 | 6 | class Segment: 7 | '''Replay storing recent transitions for on-policy learning.''' 8 | 9 | def __init__( 10 | self, size=4096, batch_iterations=80, batch_size=None, 11 | discount_factor=0.99, trace_decay=0.97 12 | ): 13 | self.max_size = size 14 | self.batch_iterations = batch_iterations 15 | self.batch_size = batch_size 16 | self.discount_factor = discount_factor 17 | self.trace_decay = trace_decay 18 | 19 | def initialize(self, seed=None): 20 | self.np_random = np.random.RandomState(seed) 21 | self.buffers = None 22 | self.index = 0 23 | 24 | def ready(self): 25 | return self.index == self.max_size 26 | 27 | def store(self, **kwargs): 28 | if self.buffers is None: 29 | self.num_workers = len(list(kwargs.values())[0]) 30 | self.buffers = {} 31 | for key, val in kwargs.items(): 32 | shape = (self.max_size,) + np.array(val).shape 33 | self.buffers[key] = np.zeros(shape, np.float32) 34 | for key, val in kwargs.items(): 35 | self.buffers[key][self.index] = val 36 | self.index += 1 37 | 38 | def get_full(self, 
*keys): 39 | self.index = 0 40 | 41 | if 'advantages' in keys: 42 | advs = self.buffers['returns'] - self.buffers['values'] 43 | std = advs.std() 44 | if std != 0: 45 | advs = (advs - advs.mean()) / std 46 | self.buffers['advantages'] = advs 47 | 48 | return {k: replays.flatten_batch(self.buffers[k]) for k in keys} 49 | 50 | def get(self, *keys): 51 | '''Get mini-batches from named buffers.''' 52 | 53 | batch = self.get_full(*keys) 54 | 55 | if self.batch_size is None: 56 | for _ in range(self.batch_iterations): 57 | yield batch 58 | else: 59 | size = self.max_size * self.num_workers 60 | all_indices = np.arange(size) 61 | for _ in range(self.batch_iterations): 62 | self.np_random.shuffle(all_indices) 63 | for i in range(0, size, self.batch_size): 64 | indices = all_indices[i:i + self.batch_size] 65 | yield {k: v[indices] for k, v in batch.items()} 66 | 67 | def compute_returns(self, values, next_values): 68 | shape = self.buffers['rewards'].shape 69 | self.buffers['values'] = values.reshape(shape) 70 | self.buffers['next_values'] = next_values.reshape(shape) 71 | self.buffers['returns'] = replays.lambda_returns( 72 | values=self.buffers['values'], 73 | next_values=self.buffers['next_values'], 74 | rewards=self.buffers['rewards'], 75 | resets=self.buffers['resets'], 76 | terminations=self.buffers['terminations'], 77 | discount_factor=self.discount_factor, 78 | trace_decay=self.trace_decay) 79 | -------------------------------------------------------------------------------- /tonic/replays/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def lambda_returns( 5 | values, next_values, rewards, resets, terminations, discount_factor, 6 | trace_decay 7 | ): 8 | '''Function used to calculate lambda-returns on parallel buffers.''' 9 | 10 | returns = np.zeros_like(values) 11 | last_returns = next_values[-1] 12 | for t in reversed(range(len(rewards))): 13 | bootstrap = ( 14 | (1 - trace_decay) * next_values[t] + trace_decay * last_returns) 15 | bootstrap *= (1 - resets[t]) 16 | bootstrap += resets[t] * next_values[t] 17 | bootstrap *= (1 - terminations[t]) 18 | returns[t] = last_returns = rewards[t] + discount_factor * bootstrap 19 | return returns 20 | 21 | 22 | def flatten_batch(values): 23 | shape = values.shape 24 | new_shape = (np.prod(shape[:2], dtype=int),) + shape[2:] 25 | return values.reshape(new_shape) 26 | -------------------------------------------------------------------------------- /tonic/tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import agents, models, normalizers, updaters 2 | 3 | 4 | __all__ = [agents, models, normalizers, updaters] 5 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import Agent 2 | 3 | from .a2c import A2C # noqa 4 | from .ddpg import DDPG 5 | from .d4pg import D4PG # noqa 6 | from .mpo import MPO 7 | from .ppo import PPO 8 | from .sac import SAC 9 | from .td3 import TD3 10 | from .td4 import TD4 11 | from .trpo import TRPO 12 | 13 | 14 | __all__ = [Agent, A2C, DDPG, D4PG, MPO, PPO, SAC, TD3, TD4, TRPO] 15 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/a2c.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tonic import logger, replays 4 | from tonic.tensorflow import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorCritic( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((64, 64), 'tanh'), 12 | head=models.DetachedScaleGaussianPolicyHead()), 13 | critic=models.Critic( 14 | encoder=models.ObservationEncoder(), 15 | torso=models.MLP((64, 64), 'tanh'), 16 | head=models.ValueHead()), 17 | observation_normalizer=normalizers.MeanStd()) 18 | 19 | 20 | class A2C(agents.Agent): 21 | '''Advantage Actor Critic (aka Vanilla Policy Gradient). 22 | A3C: https://arxiv.org/pdf/1602.01783.pdf 23 | ''' 24 | 25 | def __init__( 26 | self, model=None, replay=None, actor_updater=None, critic_updater=None 27 | ): 28 | self.model = model or default_model() 29 | self.replay = replay or replays.Segment() 30 | self.actor_updater = actor_updater or \ 31 | updaters.StochasticPolicyGradient() 32 | self.critic_updater = critic_updater or updaters.VRegression() 33 | 34 | def initialize(self, observation_space, action_space, seed=None): 35 | super().initialize(seed=seed) 36 | self.model.initialize(observation_space, action_space) 37 | self.replay.initialize(seed) 38 | self.actor_updater.initialize(self.model) 39 | self.critic_updater.initialize(self.model) 40 | 41 | def step(self, observations, steps): 42 | # Sample actions and get their log-probabilities for training. 43 | actions, log_probs = self._step(observations) 44 | actions = actions.numpy() 45 | log_probs = log_probs.numpy() 46 | 47 | # Keep some values for the next update. 48 | self.last_observations = observations.copy() 49 | self.last_actions = actions.copy() 50 | self.last_log_probs = log_probs.copy() 51 | 52 | return actions 53 | 54 | def test_step(self, observations, steps): 55 | # Sample actions for testing. 56 | return self._test_step(observations).numpy() 57 | 58 | def update(self, observations, rewards, resets, terminations, steps): 59 | # Store the last transitions in the replay. 60 | self.replay.store( 61 | observations=self.last_observations, actions=self.last_actions, 62 | next_observations=observations, rewards=rewards, resets=resets, 63 | terminations=terminations, log_probs=self.last_log_probs) 64 | 65 | # Prepare to update the normalizers. 66 | if self.model.observation_normalizer: 67 | self.model.observation_normalizer.record(self.last_observations) 68 | if self.model.return_normalizer: 69 | self.model.return_normalizer.record(rewards) 70 | 71 | # Update the model if the replay is ready. 
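# With the default Segment replay, ready() returns True once `size` steps
# (4096 by default) have been collected.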
72 | if self.replay.ready(): 73 | self._update() 74 | 75 | @tf.function 76 | def _step(self, observations): 77 | distributions = self.model.actor(observations) 78 | if hasattr(distributions, 'sample_with_log_prob'): 79 | actions, log_probs = distributions.sample_with_log_prob() 80 | else: 81 | actions = distributions.sample() 82 | log_probs = distributions.log_prob(actions) 83 | return actions, log_probs 84 | 85 | @tf.function 86 | def _test_step(self, observations): 87 | return self.model.actor(observations).sample() 88 | 89 | @tf.function 90 | def _evaluate(self, observations, next_observations): 91 | values = self.model.critic(observations) 92 | next_values = self.model.critic(next_observations) 93 | return values, next_values 94 | 95 | def _update(self): 96 | # Compute the lambda-returns. 97 | batch = self.replay.get_full('observations', 'next_observations') 98 | values, next_values = self._evaluate(**batch) 99 | values, next_values = values.numpy(), next_values.numpy() 100 | self.replay.compute_returns(values, next_values) 101 | 102 | # Update the actor once. 103 | keys = 'observations', 'actions', 'advantages', 'log_probs' 104 | batch = self.replay.get_full(*keys) 105 | infos = self.actor_updater(**batch) 106 | for k, v in infos.items(): 107 | logger.store('actor/' + k, v.numpy()) 108 | 109 | # Update the critic multiple times. 110 | for batch in self.replay.get('observations', 'returns'): 111 | infos = self.critic_updater(**batch) 112 | for k, v in infos.items(): 113 | logger.store('critic/' + k, v.numpy()) 114 | 115 | # Update the normalizers. 116 | if self.model.observation_normalizer: 117 | self.model.observation_normalizer.update() 118 | if self.model.return_normalizer: 119 | self.model.return_normalizer.update() 120 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/agent.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tonic import agents, logger 7 | 8 | 9 | class Agent(agents.Agent): 10 | def initialize(self, seed=None): 11 | if seed is not None: 12 | np.random.seed(seed) 13 | random.seed(seed) 14 | tf.random.set_seed(seed) 15 | 16 | def save(self, path): 17 | logger.log(f'\nSaving weights to {path}') 18 | self.model.save_weights(path) 19 | 20 | def load(self, path): 21 | logger.log(f'\nLoading weights from {path}') 22 | self.model.load_weights(path) 23 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/d4pg.py: -------------------------------------------------------------------------------- 1 | from tonic import replays 2 | from tonic.tensorflow import agents, models, normalizers, updaters 3 | 4 | 5 | def default_model(): 6 | return models.ActorCriticWithTargets( 7 | actor=models.Actor( 8 | encoder=models.ObservationEncoder(), 9 | torso=models.MLP((256, 256), 'relu'), 10 | head=models.DeterministicPolicyHead()), 11 | critic=models.Critic( 12 | encoder=models.ObservationActionEncoder(), 13 | torso=models.MLP((256, 256), 'relu'), 14 | # These values are for the control suite with 0.99 discount. 15 | head=models.DistributionalValueHead(-150., 150., 51)), 16 | observation_normalizer=normalizers.MeanStd()) 17 | 18 | 19 | class D4PG(agents.DDPG): 20 | '''Distributed Distributional Deterministic Policy Gradients. 
21 | D4PG: https://arxiv.org/pdf/1804.08617.pdf 22 | ''' 23 | 24 | def __init__( 25 | self, model=None, replay=None, exploration=None, actor_updater=None, 26 | critic_updater=None 27 | ): 28 | model = model or default_model() 29 | replay = replay or replays.Buffer(return_steps=5) 30 | actor_updater = actor_updater or \ 31 | updaters.DistributionalDeterministicPolicyGradient() 32 | critic_updater = critic_updater or \ 33 | updaters.DistributionalDeterministicQLearning() 34 | super().__init__( 35 | model, replay, exploration, actor_updater, critic_updater) 36 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/ddpg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tonic import explorations, logger, replays 4 | from tonic.tensorflow import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorCriticWithTargets( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((256, 256), 'relu'), 12 | head=models.DeterministicPolicyHead()), 13 | critic=models.Critic( 14 | encoder=models.ObservationActionEncoder(), 15 | torso=models.MLP((256, 256), 'relu'), 16 | head=models.ValueHead()), 17 | observation_normalizer=normalizers.MeanStd()) 18 | 19 | 20 | class DDPG(agents.Agent): 21 | '''Deep Deterministic Policy Gradient. 22 | DDPG: https://arxiv.org/pdf/1509.02971.pdf 23 | ''' 24 | 25 | def __init__( 26 | self, model=None, replay=None, exploration=None, actor_updater=None, 27 | critic_updater=None 28 | ): 29 | self.model = model or default_model() 30 | self.replay = replay or replays.Buffer() 31 | self.exploration = exploration or explorations.NormalActionNoise() 32 | self.actor_updater = actor_updater or \ 33 | updaters.DeterministicPolicyGradient() 34 | self.critic_updater = critic_updater or \ 35 | updaters.DeterministicQLearning() 36 | 37 | def initialize(self, observation_space, action_space, seed=None): 38 | super().initialize(seed=seed) 39 | self.model.initialize(observation_space, action_space) 40 | self.replay.initialize(seed) 41 | self.exploration.initialize(self._policy, action_space, seed) 42 | self.actor_updater.initialize(self.model) 43 | self.critic_updater.initialize(self.model) 44 | 45 | def step(self, observations, steps): 46 | # Get actions from the actor and exploration method. 47 | actions = self.exploration(observations, steps) 48 | 49 | # Keep some values for the next update. 50 | self.last_observations = observations.copy() 51 | self.last_actions = actions.copy() 52 | 53 | return actions 54 | 55 | def test_step(self, observations, steps): 56 | # Greedy actions for testing. 57 | return self._greedy_actions(observations).numpy() 58 | 59 | def update(self, observations, rewards, resets, terminations, steps): 60 | # Store the last transitions in the replay. 61 | self.replay.store( 62 | observations=self.last_observations, actions=self.last_actions, 63 | next_observations=observations, rewards=rewards, resets=resets, 64 | terminations=terminations) 65 | 66 | # Prepare to update the normalizers. 67 | if self.model.observation_normalizer: 68 | self.model.observation_normalizer.record(self.last_observations) 69 | if self.model.return_normalizer: 70 | self.model.return_normalizer.record(rewards) 71 | 72 | # Update the model if the replay is ready. 
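        # The off-policy Buffer is given the current step count so it can decide
        # when enough transitions are available and how often to train; each call
        # to _update samples mini-batches, updates the critic and the actor, and
        # Polyak-averages the target networks via model.update_targets().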
73 | if self.replay.ready(steps): 74 | self._update(steps) 75 | 76 | self.exploration.update(resets) 77 | 78 | @tf.function 79 | def _greedy_actions(self, observations): 80 | return self.model.actor(observations) 81 | 82 | def _policy(self, observations): 83 | return self._greedy_actions(observations).numpy() 84 | 85 | def _update(self, steps): 86 | keys = ('observations', 'actions', 'next_observations', 'rewards', 87 | 'discounts') 88 | 89 | # Update both the actor and the critic multiple times. 90 | for batch in self.replay.get(*keys, steps=steps): 91 | infos = self._update_actor_critic(**batch) 92 | 93 | for key in infos: 94 | for k, v in infos[key].items(): 95 | logger.store(key + '/' + k, v.numpy()) 96 | 97 | # Update the normalizers. 98 | if self.model.observation_normalizer: 99 | self.model.observation_normalizer.update() 100 | if self.model.return_normalizer: 101 | self.model.return_normalizer.update() 102 | 103 | @tf.function 104 | def _update_actor_critic( 105 | self, observations, actions, next_observations, rewards, discounts 106 | ): 107 | critic_infos = self.critic_updater( 108 | observations, actions, next_observations, rewards, discounts) 109 | actor_infos = self.actor_updater(observations) 110 | self.model.update_targets() 111 | return dict(critic=critic_infos, actor=actor_infos) 112 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/mpo.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tonic import logger, replays 4 | from tonic.tensorflow import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorCriticWithTargets( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((256, 256), 'relu'), 12 | head=models.GaussianPolicyHead()), 13 | critic=models.Critic( 14 | encoder=models.ObservationActionEncoder(), 15 | torso=models.MLP((256, 256), 'relu'), 16 | head=models.ValueHead()), 17 | observation_normalizer=normalizers.MeanStd()) 18 | 19 | 20 | class MPO(agents.Agent): 21 | '''Maximum a Posteriori Policy Optimisation. 22 | MPO: https://arxiv.org/pdf/1806.06920.pdf 23 | MO-MPO: https://arxiv.org/pdf/2005.07513.pdf 24 | ''' 25 | 26 | def __init__( 27 | self, model=None, replay=None, actor_updater=None, critic_updater=None 28 | ): 29 | self.model = model or default_model() 30 | self.replay = replay or replays.Buffer(return_steps=5) 31 | self.actor_updater = actor_updater or \ 32 | updaters.MaximumAPosterioriPolicyOptimization() 33 | self.critic_updater = critic_updater or updaters.ExpectedSARSA() 34 | 35 | def initialize(self, observation_space, action_space, seed=None): 36 | super().initialize(seed=seed) 37 | self.model.initialize(observation_space, action_space) 38 | self.replay.initialize(seed) 39 | self.actor_updater.initialize(self.model, action_space) 40 | self.critic_updater.initialize(self.model) 41 | 42 | def step(self, observations, steps): 43 | actions = self._step(observations) 44 | actions = actions.numpy() 45 | 46 | # Keep some values for the next update. 47 | self.last_observations = observations.copy() 48 | self.last_actions = actions.copy() 49 | 50 | return actions 51 | 52 | def test_step(self, observations, steps): 53 | # Sample actions for testing. 54 | return self._test_step(observations).numpy() 55 | 56 | def update(self, observations, rewards, resets, terminations, steps): 57 | # Store the last transitions in the replay. 
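        # Resets and terminations are recorded separately so that environment
        # timeouts (resets without termination) are not treated as true ends of
        # episodes when returns are computed.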
58 | self.replay.store( 59 | observations=self.last_observations, actions=self.last_actions, 60 | next_observations=observations, rewards=rewards, resets=resets, 61 | terminations=terminations) 62 | 63 | # Prepare to update the normalizers. 64 | if self.model.observation_normalizer: 65 | self.model.observation_normalizer.record(self.last_observations) 66 | if self.model.return_normalizer: 67 | self.model.return_normalizer.record(rewards) 68 | 69 | # Update the model if the replay is ready. 70 | if self.replay.ready(steps): 71 | self._update(steps) 72 | 73 | @tf.function 74 | def _step(self, observations): 75 | return self.model.actor(observations).sample() 76 | 77 | @tf.function 78 | def _test_step(self, observations): 79 | return self.model.actor(observations).mode() 80 | 81 | def _update(self, steps): 82 | keys = ('observations', 'actions', 'next_observations', 'rewards', 83 | 'discounts') 84 | 85 | # Update both the actor and the critic multiple times. 86 | for batch in self.replay.get(*keys, steps=steps): 87 | infos = self._update_actor_critic(**batch) 88 | 89 | for key in infos: 90 | for k, v in infos[key].items(): 91 | logger.store(key + '/' + k, v.numpy()) 92 | 93 | # Update the normalizers. 94 | if self.model.observation_normalizer: 95 | self.model.observation_normalizer.update() 96 | if self.model.return_normalizer: 97 | self.model.return_normalizer.update() 98 | 99 | @tf.function 100 | def _update_actor_critic( 101 | self, observations, actions, next_observations, rewards, discounts 102 | ): 103 | critic_infos = self.critic_updater( 104 | observations, actions, next_observations, rewards, discounts) 105 | actor_infos = self.actor_updater(observations) 106 | self.model.update_targets() 107 | return dict(critic=critic_infos, actor=actor_infos) 108 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/ppo.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tonic import logger 4 | from tonic.tensorflow import agents, updaters 5 | 6 | 7 | class PPO(agents.A2C): 8 | '''Proximal Policy Optimization. 9 | PPO: https://arxiv.org/pdf/1707.06347.pdf 10 | ''' 11 | 12 | def __init__( 13 | self, model=None, replay=None, actor_updater=None, critic_updater=None 14 | ): 15 | actor_updater = actor_updater or updaters.ClippedRatio() 16 | super().__init__( 17 | model=model, replay=replay, actor_updater=actor_updater, 18 | critic_updater=critic_updater) 19 | 20 | def _update(self): 21 | # Compute the lambda-returns. 22 | batch = self.replay.get_full('observations', 'next_observations') 23 | values, next_values = self._evaluate(**batch) 24 | values, next_values = values.numpy(), next_values.numpy() 25 | self.replay.compute_returns(values, next_values) 26 | 27 | train_actor = True 28 | actor_iterations = 0 29 | critic_iterations = 0 30 | keys = 'observations', 'actions', 'advantages', 'log_probs', 'returns' 31 | 32 | # Update both the actor and the critic multiple times. 33 | for batch in self.replay.get(*keys): 34 | if train_actor: 35 | infos = self._update_actor_critic(**batch) 36 | actor_iterations += 1 37 | else: 38 | batch = {k: batch[k] for k in ('observations', 'returns')} 39 | infos = dict(critic=self.critic_updater(**batch)) 40 | critic_iterations += 1 41 | 42 | # Stop earlier the training of the actor. 
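            # Once the actor updater reports a 'stop' signal, the remaining
            # mini-batches in this pass are used to train the critic only; the
            # numbers of actor and critic iterations are logged below.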
43 | if train_actor: 44 | train_actor = not infos['actor']['stop'].numpy() 45 | 46 | for key in infos: 47 | for k, v in infos[key].items(): 48 | logger.store(key + '/' + k, v.numpy()) 49 | 50 | logger.store('actor/iterations', actor_iterations) 51 | logger.store('critic/iterations', critic_iterations) 52 | 53 | # Update the normalizers. 54 | if self.model.observation_normalizer: 55 | self.model.observation_normalizer.update() 56 | if self.model.return_normalizer: 57 | self.model.return_normalizer.update() 58 | 59 | @tf.function 60 | def _update_actor_critic( 61 | self, observations, actions, advantages, log_probs, returns 62 | ): 63 | actor_infos = self.actor_updater( 64 | observations, actions, advantages, log_probs) 65 | critic_infos = self.critic_updater(observations, returns) 66 | return dict(actor=actor_infos, critic=critic_infos) 67 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/sac.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tonic import explorations 4 | from tonic.tensorflow import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorTwinCriticWithTargets( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((256, 256), 'relu'), 12 | head=models.GaussianPolicyHead( 13 | loc_activation=None, 14 | distribution=models.SquashedMultivariateNormalDiag)), 15 | critic=models.Critic( 16 | encoder=models.ObservationActionEncoder(), 17 | torso=models.MLP((256, 256), 'relu'), 18 | head=models.ValueHead()), 19 | observation_normalizer=normalizers.MeanStd()) 20 | 21 | 22 | class SAC(agents.DDPG): 23 | '''Soft Actor-Critic. 24 | SAC: https://arxiv.org/pdf/1801.01290.pdf 25 | ''' 26 | 27 | def __init__( 28 | self, model=None, replay=None, exploration=None, actor_updater=None, 29 | critic_updater=None 30 | ): 31 | model = model or default_model() 32 | exploration = exploration or explorations.NoActionNoise() 33 | actor_updater = actor_updater or \ 34 | updaters.TwinCriticSoftDeterministicPolicyGradient() 35 | critic_updater = critic_updater or updaters.TwinCriticSoftQLearning() 36 | super().__init__( 37 | model=model, replay=replay, exploration=exploration, 38 | actor_updater=actor_updater, critic_updater=critic_updater) 39 | 40 | @tf.function 41 | def _stochastic_actions(self, observations): 42 | return self.model.actor(observations).sample() 43 | 44 | def _policy(self, observations): 45 | return self._stochastic_actions(observations).numpy() 46 | 47 | @tf.function 48 | def _greedy_actions(self, observations): 49 | return self.model.actor(observations).mode() 50 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/td3.py: -------------------------------------------------------------------------------- 1 | from tonic import logger 2 | from tonic.tensorflow import agents, models, normalizers, updaters 3 | 4 | 5 | def default_model(): 6 | return models.ActorTwinCriticWithTargets( 7 | actor=models.Actor( 8 | encoder=models.ObservationEncoder(), 9 | torso=models.MLP((256, 256), 'relu'), 10 | head=models.DeterministicPolicyHead()), 11 | critic=models.Critic( 12 | encoder=models.ObservationActionEncoder(), 13 | torso=models.MLP((256, 256), 'relu'), 14 | head=models.ValueHead()), 15 | observation_normalizer=normalizers.MeanStd()) 16 | 17 | 18 | class TD3(agents.DDPG): 19 | '''Twin Delayed Deep Deterministic Policy Gradient. 
20 | TD3: https://arxiv.org/pdf/1802.09477.pdf 21 | ''' 22 | 23 | def __init__( 24 | self, model=None, replay=None, exploration=None, actor_updater=None, 25 | critic_updater=None, delay_steps=2 26 | ): 27 | model = model or default_model() 28 | critic_updater = critic_updater or \ 29 | updaters.TwinCriticDeterministicQLearning() 30 | super().__init__( 31 | model=model, replay=replay, exploration=exploration, 32 | actor_updater=actor_updater, critic_updater=critic_updater) 33 | self.delay_steps = delay_steps 34 | self.model.critic = self.model.critic_1 35 | 36 | def _update(self, steps): 37 | keys = ('observations', 'actions', 'next_observations', 'rewards', 38 | 'discounts') 39 | for i, batch in enumerate(self.replay.get(*keys, steps=steps)): 40 | if (i + 1) % self.delay_steps == 0: 41 | infos = self._update_actor_critic(**batch) 42 | else: 43 | infos = dict(critic=self.critic_updater(**batch)) 44 | for key in infos: 45 | for k, v in infos[key].items(): 46 | logger.store(key + '/' + k, v.numpy()) 47 | 48 | # Update the normalizers. 49 | if self.model.observation_normalizer: 50 | self.model.observation_normalizer.update() 51 | if self.model.return_normalizer: 52 | self.model.return_normalizer.update() 53 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/td4.py: -------------------------------------------------------------------------------- 1 | from tonic import replays 2 | from tonic.tensorflow import agents, models, normalizers, updaters 3 | 4 | 5 | def default_model(): 6 | return models.ActorTwinCriticWithTargets( 7 | actor=models.Actor( 8 | encoder=models.ObservationEncoder(), 9 | torso=models.MLP((256, 256), 'relu'), 10 | head=models.DeterministicPolicyHead()), 11 | critic=models.Critic( 12 | encoder=models.ObservationActionEncoder(), 13 | torso=models.MLP((256, 256), 'relu'), 14 | head=models.DistributionalValueHead(-150., 150., 51)), 15 | observation_normalizer=normalizers.MeanStd()) 16 | 17 | 18 | class TD4(agents.TD3): 19 | def __init__( 20 | self, model=None, replay=None, exploration=None, actor_updater=None, 21 | critic_updater=None, delay_steps=2 22 | ): 23 | model = model or default_model() 24 | replay = replay or replays.Buffer(return_steps=5) 25 | actor_updater = actor_updater or \ 26 | updaters.DistributionalDeterministicPolicyGradient() 27 | critic_updater = critic_updater or \ 28 | updaters.TwinCriticDistributionalDeterministicQLearning() 29 | super().__init__( 30 | model=model, replay=replay, exploration=exploration, 31 | actor_updater=actor_updater, critic_updater=critic_updater, 32 | delay_steps=delay_steps) 33 | -------------------------------------------------------------------------------- /tonic/tensorflow/agents/trpo.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tonic import logger 4 | from tonic.tensorflow import agents, updaters 5 | 6 | 7 | class TRPO(agents.A2C): 8 | '''Trust Region Policy Optimization. 9 | TRPO: https://arxiv.org/pdf/1502.05477.pdf 10 | ''' 11 | 12 | def __init__( 13 | self, model=None, replay=None, actor_updater=None, critic_updater=None 14 | ): 15 | actor_updater = actor_updater or updaters.TrustRegionPolicyGradient() 16 | super().__init__( 17 | model=model, replay=replay, actor_updater=actor_updater, 18 | critic_updater=critic_updater) 19 | 20 | def step(self, observations, steps): 21 | # Sample actions and get their log-probabilities for training. 
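        # Besides actions and log-probabilities, the policy's means (locs) and
        # standard deviations (scales) are recorded so that the trust-region
        # updater can later measure the divergence from this old policy.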
22 | actions, log_probs, locs, scales = self._step(observations) 23 | actions = actions.numpy() 24 | log_probs = log_probs.numpy() 25 | locs = locs.numpy() 26 | scales = scales.numpy() 27 | 28 | # Keep some values for the next update. 29 | self.last_observations = observations.copy() 30 | self.last_actions = actions.copy() 31 | self.last_log_probs = log_probs.copy() 32 | self.last_locs = locs.copy() 33 | self.last_scales = scales.copy() 34 | 35 | return actions 36 | 37 | def update(self, observations, rewards, resets, terminations, steps): 38 | # Store the last transitions in the replay. 39 | self.replay.store( 40 | observations=self.last_observations, actions=self.last_actions, 41 | next_observations=observations, rewards=rewards, resets=resets, 42 | terminations=terminations, log_probs=self.last_log_probs, 43 | locs=self.last_locs, scales=self.last_scales) 44 | 45 | # Prepare to update the normalizers. 46 | if self.model.observation_normalizer: 47 | self.model.observation_normalizer.record(self.last_observations) 48 | if self.model.return_normalizer: 49 | self.model.return_normalizer.record(rewards) 50 | 51 | # Update the model if the replay is ready. 52 | if self.replay.ready(): 53 | self._update() 54 | 55 | @tf.function 56 | def _step(self, observations): 57 | distributions = self.model.actor(observations) 58 | if hasattr(distributions, 'sample_with_log_prob'): 59 | actions, log_probs = distributions.sample_with_log_prob() 60 | else: 61 | actions = distributions.sample() 62 | log_probs = distributions.log_prob(actions) 63 | locs = distributions.loc 64 | scales = distributions.stddev() 65 | return actions, log_probs, locs, scales 66 | 67 | def _update(self): 68 | # Compute the lambda-returns. 69 | batch = self.replay.get_full('observations', 'next_observations') 70 | values, next_values = self._evaluate(**batch) 71 | values, next_values = values.numpy(), next_values.numpy() 72 | self.replay.compute_returns(values, next_values) 73 | 74 | keys = ('observations', 'actions', 'log_probs', 'locs', 'scales', 75 | 'advantages') 76 | batch = self.replay.get_full(*keys) 77 | infos = self.actor_updater(**batch) 78 | for k, v in infos.items(): 79 | logger.store('actor/' + k, v.numpy()) 80 | 81 | critic_iterations = 0 82 | for batch in self.replay.get('observations', 'returns'): 83 | infos = self.critic_updater(**batch) 84 | critic_iterations += 1 85 | for k, v in infos.items(): 86 | logger.store('critic/' + k, v.numpy()) 87 | logger.store('critic/iterations', critic_iterations) 88 | 89 | # Update the normalizers. 
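        # record() only accumulates statistics; the update() calls below fold
        # them into the normalizers' variables once per training update.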
90 | if self.model.observation_normalizer: 91 | self.model.observation_normalizer.update() 92 | if self.model.return_normalizer: 93 | self.model.return_normalizer.update() 94 | -------------------------------------------------------------------------------- /tonic/tensorflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor_critics import ActorCritic 2 | from .actor_critics import ActorCriticWithTargets 3 | from .actor_critics import ActorTwinCriticWithTargets 4 | 5 | from .actors import Actor 6 | from .actors import DetachedScaleGaussianPolicyHead 7 | from .actors import DeterministicPolicyHead 8 | from .actors import GaussianPolicyHead 9 | from .actors import SquashedMultivariateNormalDiag 10 | 11 | from .critics import Critic, DistributionalValueHead, ValueHead 12 | 13 | from .encoders import ObservationActionEncoder, ObservationEncoder 14 | 15 | from .utils import default_dense_kwargs, MLP 16 | 17 | 18 | __all__ = [ 19 | default_dense_kwargs, MLP, ObservationActionEncoder, 20 | ObservationEncoder, SquashedMultivariateNormalDiag, 21 | DetachedScaleGaussianPolicyHead, GaussianPolicyHead, 22 | DeterministicPolicyHead, Actor, Critic, DistributionalValueHead, 23 | ValueHead, ActorCritic, ActorCriticWithTargets, ActorTwinCriticWithTargets] 24 | -------------------------------------------------------------------------------- /tonic/tensorflow/models/actor_critics.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import tensorflow as tf 4 | 5 | 6 | class ActorCritic(tf.keras.Model): 7 | def __init__( 8 | self, actor, critic, observation_normalizer=None, 9 | return_normalizer=None 10 | ): 11 | super().__init__() 12 | self.actor = actor 13 | self.critic = critic 14 | self.observation_normalizer = observation_normalizer 15 | self.return_normalizer = return_normalizer 16 | 17 | def initialize(self, observation_space, action_space): 18 | if self.observation_normalizer: 19 | self.observation_normalizer.initialize(observation_space.shape) 20 | self.actor.initialize( 21 | observation_space, action_space, self.observation_normalizer) 22 | self.critic.initialize( 23 | observation_space, action_space, self.observation_normalizer, 24 | self.return_normalizer) 25 | dummy_observations = tf.zeros((1,) + observation_space.shape) 26 | self.actor(dummy_observations) 27 | self.critic(dummy_observations) 28 | 29 | 30 | class ActorCriticWithTargets(tf.keras.Model): 31 | def __init__( 32 | self, actor, critic, observation_normalizer=None, 33 | return_normalizer=None, target_coeff=0.005 34 | ): 35 | super().__init__() 36 | self.actor = actor 37 | self.critic = critic 38 | self.target_actor = copy.deepcopy(actor) 39 | self.target_critic = copy.deepcopy(critic) 40 | self.observation_normalizer = observation_normalizer 41 | self.return_normalizer = return_normalizer 42 | self.target_coeff = target_coeff 43 | 44 | def initialize(self, observation_space, action_space): 45 | if self.observation_normalizer: 46 | self.observation_normalizer.initialize(observation_space.shape) 47 | self.actor.initialize( 48 | observation_space, action_space, self.observation_normalizer) 49 | self.critic.initialize( 50 | observation_space, action_space, self.observation_normalizer, 51 | self.return_normalizer) 52 | self.target_actor.initialize( 53 | observation_space, action_space, self.observation_normalizer) 54 | self.target_critic.initialize( 55 | observation_space, action_space, self.observation_normalizer, 56 | 
self.return_normalizer) 57 | dummy_observations = tf.zeros((1,) + observation_space.shape) 58 | dummy_actions = tf.zeros((1,) + action_space.shape) 59 | self.actor(dummy_observations) 60 | self.critic(dummy_observations, dummy_actions) 61 | self.target_actor(dummy_observations) 62 | self.target_critic(dummy_observations, dummy_actions) 63 | self.online_variables = ( 64 | self.actor.trainable_variables + 65 | self.critic.trainable_variables) 66 | self.target_variables = ( 67 | self.target_actor.trainable_variables + 68 | self.target_critic.trainable_variables) 69 | self.assign_targets() 70 | 71 | def assign_targets(self): 72 | for o, t in zip(self.online_variables, self.target_variables): 73 | t.assign(o) 74 | 75 | def update_targets(self): 76 | for o, t in zip(self.online_variables, self.target_variables): 77 | t.assign((1 - self.target_coeff) * t + self.target_coeff * o) 78 | 79 | 80 | class ActorTwinCriticWithTargets(tf.keras.Model): 81 | def __init__( 82 | self, actor, critic, observation_normalizer=None, 83 | return_normalizer=None, target_coeff=0.005 84 | ): 85 | super().__init__() 86 | self.actor = actor 87 | self.critic_1 = critic 88 | self.critic_2 = copy.deepcopy(critic) 89 | self.target_actor = copy.deepcopy(actor) 90 | self.target_critic_1 = copy.deepcopy(critic) 91 | self.target_critic_2 = copy.deepcopy(critic) 92 | self.observation_normalizer = observation_normalizer 93 | self.return_normalizer = return_normalizer 94 | self.target_coeff = target_coeff 95 | 96 | def initialize(self, observation_space, action_space): 97 | if self.observation_normalizer: 98 | self.observation_normalizer.initialize(observation_space.shape) 99 | self.actor.initialize( 100 | observation_space, action_space, self.observation_normalizer) 101 | self.critic_1.initialize( 102 | observation_space, action_space, self.observation_normalizer, 103 | self.return_normalizer) 104 | self.critic_2.initialize( 105 | observation_space, action_space, self.observation_normalizer, 106 | self.return_normalizer) 107 | self.target_actor.initialize( 108 | observation_space, action_space, self.observation_normalizer) 109 | self.target_critic_1.initialize( 110 | observation_space, action_space, self.observation_normalizer, 111 | self.return_normalizer) 112 | self.target_critic_2.initialize( 113 | observation_space, action_space, self.observation_normalizer, 114 | self.return_normalizer) 115 | dummy_observations = tf.zeros((1,) + observation_space.shape) 116 | dummy_actions = tf.zeros((1,) + action_space.shape) 117 | self.actor(dummy_observations) 118 | self.critic_1(dummy_observations, dummy_actions) 119 | self.critic_2(dummy_observations, dummy_actions) 120 | self.target_actor(dummy_observations) 121 | self.target_critic_1(dummy_observations, dummy_actions) 122 | self.target_critic_2(dummy_observations, dummy_actions) 123 | self.online_variables = ( 124 | self.actor.trainable_variables + 125 | self.critic_1.trainable_variables + 126 | self.critic_2.trainable_variables) 127 | self.target_variables = ( 128 | self.target_actor.trainable_variables + 129 | self.target_critic_1.trainable_variables + 130 | self.target_critic_2.trainable_variables) 131 | self.assign_targets() 132 | 133 | def assign_targets(self): 134 | for o, t in zip(self.online_variables, self.target_variables): 135 | t.assign(o) 136 | 137 | def update_targets(self): 138 | for o, t in zip(self.online_variables, self.target_variables): 139 | t.assign((1 - self.target_coeff) * t + self.target_coeff * o) 140 | 
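# Note: the two ...WithTargets models above share the same target-network
# mechanism: assign_targets() copies the online weights once at initialization,
# and update_targets() then moves each target variable a small step towards its
# online counterpart after every update (Polyak averaging with coefficient
# target_coeff=0.005 by default). A minimal, library-independent NumPy sketch of
# that rule, for illustration only (polyak_update is a hypothetical helper, not
# part of Tonic):

import numpy as np

def polyak_update(target, online, coeff=0.005):
    # target <- (1 - coeff) * target + coeff * online, as in update_targets().
    return (1.0 - coeff) * target + coeff * online

online_weights = np.array([1.0, -2.0, 0.5])
target_weights = online_weights.copy()   # assign_targets() copies once.
target_weights = polyak_update(target_weights, online_weights * 1.1)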
-------------------------------------------------------------------------------- /tonic/tensorflow/models/actors.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | 4 | from tonic.tensorflow import models 5 | 6 | 7 | FLOAT_EPSILON = 1e-8 8 | 9 | 10 | class SquashedMultivariateNormalDiag: 11 | def __init__(self, loc, scale): 12 | self._distribution = tfp.distributions.MultivariateNormalDiag( 13 | loc, scale) 14 | 15 | def sample_with_log_prob(self, shape=()): 16 | samples = self._distribution.sample(shape) 17 | squashed_samples = tf.tanh(samples) 18 | log_probs = self._distribution.log_prob(samples) 19 | log_probs -= tf.reduce_sum( 20 | tf.math.log(1 - squashed_samples ** 2 + 1e-6), axis=-1) 21 | return squashed_samples, log_probs 22 | 23 | def sample(self, shape=()): 24 | samples = self._distribution.sample(shape) 25 | return tf.tanh(samples) 26 | 27 | def log_prob(self, samples): 28 | '''Required unsquashed samples cannot be accurately recovered.''' 29 | raise NotImplementedError( 30 | 'Not implemented to avoid approximation errors. ' 31 | 'Use sample_with_log_prob directly.') 32 | 33 | def mode(self): 34 | return tf.tanh(self._distribution.mode()) 35 | 36 | 37 | class DetachedScaleGaussianPolicyHead(tf.keras.Model): 38 | def __init__( 39 | self, loc_activation='tanh', dense_loc_kwargs=None, log_scale_init=0., 40 | scale_min=1e-4, scale_max=1., 41 | distribution=tfp.distributions.MultivariateNormalDiag 42 | ): 43 | super().__init__() 44 | self.loc_activation = loc_activation 45 | if dense_loc_kwargs is None: 46 | dense_loc_kwargs = models.default_dense_kwargs() 47 | self.dense_loc_kwargs = dense_loc_kwargs 48 | self.log_scale_init = log_scale_init 49 | self.scale_min = scale_min 50 | self.scale_max = scale_max 51 | self.distribution = distribution 52 | 53 | def initialize(self, action_size): 54 | self.loc_layer = tf.keras.layers.Dense( 55 | action_size, self.loc_activation, **self.dense_loc_kwargs) 56 | log_scale = [[self.log_scale_init] * action_size] 57 | self.log_scale = tf.Variable(log_scale, dtype=tf.float32) 58 | 59 | def call(self, inputs): 60 | loc = self.loc_layer(inputs) 61 | batch_size = tf.shape(inputs)[0] 62 | scale = tf.math.softplus(self.log_scale) + FLOAT_EPSILON 63 | scale = tf.clip_by_value(scale, self.scale_min, self.scale_max) 64 | scale = tf.tile(scale, (batch_size, 1)) 65 | return self.distribution(loc, scale) 66 | 67 | 68 | class GaussianPolicyHead(tf.keras.Model): 69 | def __init__( 70 | self, loc_activation='tanh', dense_loc_kwargs=None, 71 | scale_activation='softplus', scale_min=1e-4, scale_max=1, 72 | dense_scale_kwargs=None, 73 | distribution=tfp.distributions.MultivariateNormalDiag 74 | ): 75 | super().__init__() 76 | self.loc_activation = loc_activation 77 | if dense_loc_kwargs is None: 78 | dense_loc_kwargs = models.default_dense_kwargs() 79 | self.dense_loc_kwargs = dense_loc_kwargs 80 | self.scale_activation = scale_activation 81 | self.scale_min = scale_min 82 | self.scale_max = scale_max 83 | if dense_scale_kwargs is None: 84 | dense_scale_kwargs = models.default_dense_kwargs() 85 | self.dense_scale_kwargs = dense_scale_kwargs 86 | self.distribution = distribution 87 | 88 | def initialize(self, action_size): 89 | self.loc_layer = tf.keras.layers.Dense( 90 | action_size, self.loc_activation, **self.dense_loc_kwargs) 91 | self.scale_layer = tf.keras.layers.Dense( 92 | action_size, self.scale_activation, **self.dense_scale_kwargs) 93 | 94 | def 
call(self, inputs): 95 | loc = self.loc_layer(inputs) 96 | scale = self.scale_layer(inputs) 97 | scale = tf.clip_by_value(scale, self.scale_min, self.scale_max) 98 | return self.distribution(loc, scale) 99 | 100 | 101 | class DeterministicPolicyHead(tf.keras.Model): 102 | def __init__(self, activation='tanh', dense_kwargs=None): 103 | super().__init__() 104 | self.activation = activation 105 | if dense_kwargs is None: 106 | dense_kwargs = models.default_dense_kwargs() 107 | self.dense_kwargs = dense_kwargs 108 | 109 | def initialize(self, action_size): 110 | self.action_layer = tf.keras.layers.Dense( 111 | action_size, self.activation, **self.dense_kwargs) 112 | 113 | def call(self, inputs): 114 | return self.action_layer(inputs) 115 | 116 | 117 | class Actor(tf.keras.Model): 118 | def __init__(self, encoder, torso, head): 119 | super().__init__() 120 | self.encoder = encoder 121 | self.torso = torso 122 | self.head = head 123 | 124 | def initialize( 125 | self, observation_space, action_space, observation_normalizer=None 126 | ): 127 | self.encoder.initialize(observation_normalizer) 128 | self.head.initialize(action_space.shape[0]) 129 | 130 | def call(self, *inputs): 131 | out = self.encoder(*inputs) 132 | out = self.torso(out) 133 | return self.head(out) 134 | -------------------------------------------------------------------------------- /tonic/tensorflow/models/critics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tonic.tensorflow import models 4 | 5 | 6 | class ValueHead(tf.keras.Model): 7 | def __init__(self, dense_kwargs=None): 8 | super().__init__() 9 | if dense_kwargs is None: 10 | dense_kwargs = models.default_dense_kwargs() 11 | self.v_layer = tf.keras.layers.Dense(1, **dense_kwargs) 12 | 13 | def initialize(self, return_normalizer=None): 14 | self.return_normalizer = return_normalizer 15 | 16 | def call(self, inputs): 17 | out = self.v_layer(inputs) 18 | out = tf.squeeze(out, -1) 19 | if self.return_normalizer: 20 | out = self.return_normalizer(out) 21 | return out 22 | 23 | 24 | class CategoricalWithSupport: 25 | def __init__(self, values, logits): 26 | self.values = values 27 | self.logits = logits 28 | self.probabilities = tf.nn.softmax(logits) 29 | 30 | def mean(self): 31 | return tf.reduce_sum(self.probabilities * self.values, axis=-1) 32 | 33 | def project(self, returns): 34 | vmin, vmax = self.values[0], self.values[-1] 35 | d_pos = tf.concat([self.values, vmin[None]], 0)[1:] 36 | d_pos = (d_pos - self.values)[None, :, None] 37 | d_neg = tf.concat([vmax[None], self.values], 0)[:-1] 38 | d_neg = (self.values - d_neg)[None, :, None] 39 | 40 | clipped_returns = tf.clip_by_value(returns, vmin, vmax) 41 | delta_values = clipped_returns[:, None] - self.values[None, :, None] 42 | delta_sign = tf.cast(delta_values >= 0, tf.float32) 43 | delta_hat = ((delta_sign * delta_values / d_pos) - 44 | ((1 - delta_sign) * delta_values / d_neg)) 45 | delta_clipped = tf.clip_by_value(1 - delta_hat, 0, 1) 46 | 47 | return tf.reduce_sum(delta_clipped * self.probabilities[:, None], 2) 48 | 49 | 50 | class DistributionalValueHead(tf.keras.Model): 51 | def __init__(self, vmin, vmax, num_atoms, dense_kwargs=None): 52 | super().__init__() 53 | if dense_kwargs is None: 54 | dense_kwargs = models.default_dense_kwargs() 55 | self.distributional_layer = tf.keras.layers.Dense( 56 | num_atoms, **dense_kwargs) 57 | self.values = tf.cast(tf.linspace(vmin, vmax, num_atoms), tf.float32) 58 | 59 | def initialize(self, 
return_normalizer=None): 60 | if return_normalizer: 61 | raise ValueError( 62 | 'Return normalizers cannot be used with distributional value' 63 | 'heads.') 64 | 65 | def call(self, inputs): 66 | logits = self.distributional_layer(inputs) 67 | return CategoricalWithSupport(values=self.values, logits=logits) 68 | 69 | 70 | class Critic(tf.keras.Model): 71 | def __init__(self, encoder, torso, head): 72 | super().__init__() 73 | self.encoder = encoder 74 | self.torso = torso 75 | self.head = head 76 | 77 | def initialize( 78 | self, observation_space, action_space, observation_normalizer=None, 79 | return_normalizer=None 80 | ): 81 | self.encoder.initialize(observation_normalizer) 82 | self.head.initialize(return_normalizer) 83 | 84 | def call(self, *inputs): 85 | out = self.encoder(*inputs) 86 | out = self.torso(out) 87 | return self.head(out) 88 | -------------------------------------------------------------------------------- /tonic/tensorflow/models/encoders.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class ObservationEncoder(tf.keras.Model): 5 | def initialize(self, observation_normalizer=None): 6 | self.observation_normalizer = observation_normalizer 7 | 8 | def call(self, observations): 9 | if self.observation_normalizer: 10 | observations = self.observation_normalizer(observations) 11 | return observations 12 | 13 | 14 | class ObservationActionEncoder(tf.keras.Model): 15 | def initialize(self, observation_normalizer=None): 16 | self.observation_normalizer = observation_normalizer 17 | 18 | def call(self, observations, actions): 19 | if self.observation_normalizer: 20 | observations = self.observation_normalizer(observations) 21 | return tf.concat([observations, actions], axis=-1) 22 | -------------------------------------------------------------------------------- /tonic/tensorflow/models/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def default_dense_kwargs(): 5 | return dict( 6 | kernel_initializer=tf.keras.initializers.VarianceScaling( 7 | scale=1 / 3, mode='fan_in', distribution='uniform'), 8 | bias_initializer=tf.keras.initializers.VarianceScaling( 9 | scale=1 / 3, mode='fan_in', distribution='uniform')) 10 | 11 | 12 | def mlp(units, activation, dense_kwargs=None): 13 | if dense_kwargs is None: 14 | dense_kwargs = default_dense_kwargs() 15 | layers = [tf.keras.layers.Dense(u, activation, **dense_kwargs) 16 | for u in units] 17 | return tf.keras.Sequential(layers) 18 | 19 | 20 | MLP = mlp 21 | -------------------------------------------------------------------------------- /tonic/tensorflow/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .mean_stds import MeanStd 2 | from .returns import Return 3 | 4 | 5 | __all__ = [MeanStd, Return] 6 | -------------------------------------------------------------------------------- /tonic/tensorflow/normalizers/mean_stds.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | class MeanStd(tf.keras.Model): 6 | def __init__(self, mean=0, std=1, clip=None, shape=None): 7 | super().__init__(name='global_mean_std') 8 | self.mean = mean 9 | self.std = std 10 | self.clip = clip 11 | self.count = 0 12 | self.new_sum = 0 13 | self.new_sum_sq = 0 14 | self.new_count = 0 15 | self.eps = 1e-2 16 | if shape: 17 | self.initialize(shape) 18 
| 19 | def initialize(self, shape): 20 | if isinstance(self.mean, (int, float)): 21 | self.mean = np.full(shape, self.mean, np.float32) 22 | else: 23 | self.mean = np.array(self.mean, np.float32) 24 | if isinstance(self.std, (int, float)): 25 | self.std = np.full(shape, self.std, np.float32) 26 | else: 27 | self.std = np.array(self.std, np.float32) 28 | self.mean_sq = np.square(self.mean) 29 | self._mean = tf.Variable(self.mean, trainable=False, name='mean') 30 | self._std = tf.Variable(self.std, trainable=False, name='std') 31 | 32 | def call(self, val): 33 | val = (val - self._mean) / self._std 34 | if self.clip is not None: 35 | val = tf.clip_by_value(val, -self.clip, self.clip) 36 | return val 37 | 38 | def unnormalize(self, val): 39 | return val * self._std + self._mean 40 | 41 | def record(self, values): 42 | for val in values: 43 | self.new_sum += val 44 | self.new_sum_sq += np.square(val) 45 | self.new_count += 1 46 | 47 | # Careful: do not use in @tf.function 48 | def update(self): 49 | new_count = self.count + self.new_count 50 | new_mean = self.new_sum / self.new_count 51 | new_mean_sq = self.new_sum_sq / self.new_count 52 | w_old = self.count / new_count 53 | w_new = self.new_count / new_count 54 | self.mean = w_old * self.mean + w_new * new_mean 55 | self.mean_sq = w_old * self.mean_sq + w_new * new_mean_sq 56 | self.std = self._compute_std(self.mean, self.mean_sq) 57 | self.count = new_count 58 | self.new_count = 0 59 | self.new_sum = 0 60 | self.new_sum_sq = 0 61 | self._update(self.mean.astype(np.float32), self.std.astype(np.float32)) 62 | 63 | def _compute_std(self, mean, mean_sq): 64 | var = mean_sq - np.square(mean) 65 | var = np.maximum(var, 0) 66 | std = np.sqrt(var) 67 | std = np.maximum(std, self.eps) 68 | return std 69 | 70 | @tf.function 71 | def _update(self, mean, std): 72 | self._mean.assign(mean) 73 | self._std.assign(std) 74 | -------------------------------------------------------------------------------- /tonic/tensorflow/normalizers/returns.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | class Return(tf.keras.Model): 6 | def __init__(self, discount_factor): 7 | super().__init__(name='reward_normalizer') 8 | assert 0 <= discount_factor < 1 9 | self.coefficient = 1 / (1 - discount_factor) 10 | self.min_reward = np.float32(-1) 11 | self.max_reward = np.float32(1) 12 | self._low = tf.Variable( 13 | self.coefficient * self.min_reward, dtype=tf.float32, 14 | trainable=False, name='low') 15 | self._high = tf.Variable( 16 | self.coefficient * self.max_reward, dtype=np.float32, 17 | trainable=False, name='high') 18 | 19 | def call(self, val): 20 | val = tf.sigmoid(val) 21 | return self._low + val * (self._high - self._low) 22 | 23 | def record(self, values): 24 | for val in values: 25 | if val < self.min_reward: 26 | self.min_reward = np.float32(val) 27 | elif val > self.max_reward: 28 | self.max_reward = np.float32(val) 29 | 30 | # Careful: do not use in @tf.function 31 | def update(self): 32 | self._update(self.min_reward, self.max_reward) 33 | 34 | @tf.function 35 | def _update(self, min_reward, max_reward): 36 | self._low.assign(self.coefficient * min_reward) 37 | self._high.assign(self.coefficient * max_reward) 38 | -------------------------------------------------------------------------------- /tonic/tensorflow/updaters/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import merge_first_two_dims 2 | 
from .utils import tile 3 | 4 | from .actors import ClippedRatio # noqa 5 | from .actors import DeterministicPolicyGradient 6 | from .actors import DistributionalDeterministicPolicyGradient 7 | from .actors import MaximumAPosterioriPolicyOptimization 8 | from .actors import StochasticPolicyGradient 9 | from .actors import TrustRegionPolicyGradient 10 | from .actors import TwinCriticSoftDeterministicPolicyGradient 11 | 12 | from .critics import DeterministicQLearning 13 | from .critics import DistributionalDeterministicQLearning 14 | from .critics import ExpectedSARSA 15 | from .critics import QRegression 16 | from .critics import TargetActionNoise 17 | from .critics import TwinCriticDeterministicQLearning 18 | from .critics import TwinCriticDistributionalDeterministicQLearning 19 | from .critics import TwinCriticSoftQLearning 20 | from .critics import VRegression 21 | 22 | from .optimizers import ConjugateGradient 23 | 24 | 25 | __all__ = [ 26 | merge_first_two_dims, tile, ClippedRatio, DeterministicPolicyGradient, 27 | DistributionalDeterministicPolicyGradient, 28 | MaximumAPosterioriPolicyOptimization, StochasticPolicyGradient, 29 | TrustRegionPolicyGradient, TwinCriticSoftDeterministicPolicyGradient, 30 | DeterministicQLearning, DistributionalDeterministicQLearning, 31 | ExpectedSARSA, QRegression, TargetActionNoise, 32 | TwinCriticDeterministicQLearning, 33 | TwinCriticDistributionalDeterministicQLearning, TwinCriticSoftQLearning, 34 | VRegression, ConjugateGradient] 35 | -------------------------------------------------------------------------------- /tonic/tensorflow/updaters/critics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tonic.tensorflow import updaters 4 | 5 | 6 | class VRegression: 7 | def __init__(self, loss=None, optimizer=None, gradient_clip=0): 8 | self.loss = loss or tf.keras.losses.MeanSquaredError() 9 | self.optimizer = optimizer or \ 10 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8) 11 | self.gradient_clip = gradient_clip 12 | 13 | def initialize(self, model): 14 | self.model = model 15 | self.variables = self.model.critic.trainable_variables 16 | 17 | @tf.function 18 | def __call__(self, observations, returns): 19 | with tf.GradientTape() as tape: 20 | values = self.model.critic(observations) 21 | loss = self.loss(returns, values) 22 | 23 | gradients = tape.gradient(loss, self.variables) 24 | if self.gradient_clip > 0: 25 | gradients = tf.clip_by_global_norm( 26 | gradients, self.gradient_clip)[0] 27 | self.optimizer.apply_gradients(zip(gradients, self.variables)) 28 | 29 | return dict(loss=loss, v=values) 30 | 31 | 32 | class QRegression: 33 | def __init__(self, loss=None, optimizer=None, gradient_clip=0): 34 | self.loss = loss or tf.keras.losses.MeanSquaredError() 35 | self.optimizer = optimizer or \ 36 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8) 37 | self.gradient_clip = gradient_clip 38 | 39 | def initialize(self, model): 40 | self.model = model 41 | self.variables = self.model.critic.trainable_variables 42 | 43 | @tf.function 44 | def __call__(self, observations, actions, returns): 45 | with tf.GradientTape() as tape: 46 | values = self.model.critic(observations, actions) 47 | loss = self.loss(returns, values) 48 | 49 | gradients = tape.gradient(loss, self.variables) 50 | if self.gradient_clip > 0: 51 | gradients = tf.clip_by_global_norm( 52 | gradients, self.gradient_clip)[0] 53 | self.optimizer.apply_gradients(zip(gradients, self.variables)) 54 | 55 | return 
dict(loss=loss, q=values) 56 | 57 | 58 | class DeterministicQLearning: 59 | def __init__(self, loss=None, optimizer=None, gradient_clip=0): 60 | self.loss = loss or tf.keras.losses.MeanSquaredError() 61 | self.optimizer = optimizer or \ 62 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8) 63 | self.gradient_clip = gradient_clip 64 | 65 | def initialize(self, model): 66 | self.model = model 67 | self.variables = self.model.critic.trainable_variables 68 | 69 | @tf.function 70 | def __call__( 71 | self, observations, actions, next_observations, rewards, discounts 72 | ): 73 | next_actions = self.model.target_actor(next_observations) 74 | next_values = self.model.target_critic(next_observations, next_actions) 75 | returns = rewards + discounts * next_values 76 | 77 | with tf.GradientTape() as tape: 78 | values = self.model.critic(observations, actions) 79 | loss = self.loss(returns, values) 80 | 81 | gradients = tape.gradient(loss, self.variables) 82 | if self.gradient_clip > 0: 83 | gradients = tf.clip_by_global_norm( 84 | gradients, self.gradient_clip)[0] 85 | self.optimizer.apply_gradients(zip(gradients, self.variables)) 86 | 87 | return dict(loss=loss, q=values) 88 | 89 | 90 | class DistributionalDeterministicQLearning: 91 | def __init__(self, optimizer=None, gradient_clip=0): 92 | self.optimizer = optimizer or \ 93 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8) 94 | self.gradient_clip = gradient_clip 95 | 96 | def initialize(self, model): 97 | self.model = model 98 | self.variables = self.model.critic.trainable_variables 99 | 100 | @tf.function 101 | def __call__( 102 | self, observations, actions, next_observations, rewards, discounts 103 | ): 104 | next_actions = self.model.target_actor(next_observations) 105 | next_value_distributions = self.model.target_critic( 106 | next_observations, next_actions) 107 | 108 | values = next_value_distributions.values 109 | returns = rewards[:, None] + discounts[:, None] * values 110 | targets = next_value_distributions.project(returns) 111 | 112 | with tf.GradientTape() as tape: 113 | value_distributions = self.model.critic(observations, actions) 114 | losses = tf.nn.softmax_cross_entropy_with_logits( 115 | logits=value_distributions.logits, labels=targets) 116 | loss = tf.reduce_mean(losses) 117 | 118 | gradients = tape.gradient(loss, self.variables) 119 | if self.gradient_clip > 0: 120 | gradients = tf.clip_by_global_norm( 121 | gradients, self.gradient_clip)[0] 122 | self.optimizer.apply_gradients(zip(gradients, self.variables)) 123 | 124 | return dict(loss=loss) 125 | 126 | 127 | class TargetActionNoise: 128 | def __init__(self, scale=0.2, clip=0.5): 129 | self.scale = scale 130 | self.clip = clip 131 | 132 | def __call__(self, actions): 133 | noises = self.scale * tf.random.normal(actions.shape) 134 | noises = tf.clip_by_value(noises, -self.clip, self.clip) 135 | actions = actions + noises 136 | return tf.clip_by_value(actions, -1, 1) 137 | 138 | 139 | class TwinCriticDeterministicQLearning: 140 | def __init__( 141 | self, loss=None, optimizer=None, target_action_noise=None, 142 | gradient_clip=0 143 | ): 144 | self.loss = loss or tf.keras.losses.MeanSquaredError() 145 | self.optimizer = optimizer or \ 146 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8) 147 | self.target_action_noise = target_action_noise or \ 148 | TargetActionNoise(scale=0.2, clip=0.5) 149 | self.gradient_clip = gradient_clip 150 | 151 | def initialize(self, model): 152 | self.model = model 153 | variables_1 = self.model.critic_1.trainable_variables 154 | variables_2 = 
self.model.critic_2.trainable_variables 155 | self.variables = variables_1 + variables_2 156 | 157 | @tf.function 158 | def __call__( 159 | self, observations, actions, next_observations, rewards, discounts 160 | ): 161 | next_actions = self.model.target_actor(next_observations) 162 | next_actions = self.target_action_noise(next_actions) 163 | next_values_1 = self.model.target_critic_1( 164 | next_observations, next_actions) 165 | next_values_2 = self.model.target_critic_2( 166 | next_observations, next_actions) 167 | next_values = tf.minimum(next_values_1, next_values_2) 168 | returns = rewards + discounts * next_values 169 | 170 | with tf.GradientTape() as tape: 171 | values_1 = self.model.critic_1(observations, actions) 172 | values_2 = self.model.critic_2(observations, actions) 173 | loss_1 = self.loss(returns, values_1) 174 | loss_2 = self.loss(returns, values_2) 175 | loss = loss_1 + loss_2 176 | 177 | gradients = tape.gradient(loss, self.variables) 178 | if self.gradient_clip > 0: 179 | gradients = tf.clip_by_global_norm( 180 | gradients, self.gradient_clip)[0] 181 | self.optimizer.apply_gradients(zip(gradients, self.variables)) 182 | 183 | return dict(loss=loss, q1=values_1, q2=values_2) 184 | 185 | 186 | class TwinCriticSoftQLearning: 187 | def __init__( 188 | self, loss=None, optimizer=None, entropy_coeff=0.2, gradient_clip=0 189 | ): 190 | self.loss = loss or tf.keras.losses.MeanSquaredError() 191 | self.optimizer = optimizer or \ 192 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8) 193 | self.entropy_coeff = entropy_coeff 194 | self.gradient_clip = gradient_clip 195 | 196 | def initialize(self, model): 197 | self.model = model 198 | variables_1 = self.model.critic_1.trainable_variables 199 | variables_2 = self.model.critic_2.trainable_variables 200 | self.variables = variables_1 + variables_2 201 | 202 | @tf.function 203 | def __call__( 204 | self, observations, actions, next_observations, rewards, discounts 205 | ): 206 | next_distributions = self.model.actor(next_observations) 207 | if hasattr(next_distributions, 'sample_with_log_prob'): 208 | outs = next_distributions.sample_with_log_prob() 209 | next_actions, next_log_probs = outs 210 | else: 211 | next_actions = next_distributions.sample() 212 | next_log_probs = next_distributions.log_prob(next_actions) 213 | next_values_1 = self.model.target_critic_1( 214 | next_observations, next_actions) 215 | next_values_2 = self.model.target_critic_2( 216 | next_observations, next_actions) 217 | next_values = tf.minimum(next_values_1, next_values_2) 218 | returns = rewards + discounts * ( 219 | next_values - self.entropy_coeff * next_log_probs) 220 | 221 | with tf.GradientTape() as tape: 222 | values_1 = self.model.critic_1(observations, actions) 223 | values_2 = self.model.critic_2(observations, actions) 224 | loss_1 = self.loss(returns, values_1) 225 | loss_2 = self.loss(returns, values_2) 226 | loss = loss_1 + loss_2 227 | 228 | gradients = tape.gradient(loss, self.variables) 229 | if self.gradient_clip > 0: 230 | gradients = tf.clip_by_global_norm( 231 | gradients, self.gradient_clip)[0] 232 | self.optimizer.apply_gradients(zip(gradients, self.variables)) 233 | 234 | return dict(loss=loss, q1=values_1, q2=values_2) 235 | 236 | 237 | class ExpectedSARSA: 238 | def __init__( 239 | self, num_samples=20, loss=None, optimizer=None, gradient_clip=0 240 | ): 241 | self.num_samples = num_samples 242 | self.loss = loss or tf.keras.losses.MeanSquaredError() 243 | self.optimizer = optimizer or \ 244 | tf.keras.optimizers.Adam(lr=1e-3, 
epsilon=1e-8) 245 | self.gradient_clip = gradient_clip 246 | 247 | def initialize(self, model): 248 | self.model = model 249 | self.variables = self.model.critic.trainable_variables 250 | 251 | @tf.function 252 | def __call__( 253 | self, observations, actions, next_observations, rewards, discounts 254 | ): 255 | # Approximate the expected next values. 256 | next_target_distributions = self.model.target_actor(next_observations) 257 | next_actions = next_target_distributions.sample(self.num_samples) 258 | next_actions = updaters.merge_first_two_dims(next_actions) 259 | next_observations = updaters.tile(next_observations, self.num_samples) 260 | next_observations = updaters.merge_first_two_dims(next_observations) 261 | next_values = self.model.target_critic(next_observations, next_actions) 262 | next_values = tf.reshape(next_values, (self.num_samples, -1)) 263 | next_values = tf.reduce_mean(next_values, axis=0) 264 | returns = rewards + discounts * next_values 265 | 266 | with tf.GradientTape() as tape: 267 | values = self.model.critic(observations, actions) 268 | loss = self.loss(returns, values) 269 | 270 | gradients = tape.gradient(loss, self.variables) 271 | if self.gradient_clip > 0: 272 | gradients = tf.clip_by_global_norm( 273 | gradients, self.gradient_clip)[0] 274 | self.optimizer.apply_gradients(zip(gradients, self.variables)) 275 | 276 | return dict(loss=loss, q=values) 277 | 278 | 279 | class TwinCriticDistributionalDeterministicQLearning: 280 | def __init__( 281 | self, optimizer=None, target_action_noise=None, gradient_clip=0 282 | ): 283 | self.optimizer = optimizer or \ 284 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8) 285 | self.target_action_noise = target_action_noise or \ 286 | TargetActionNoise(scale=0.2, clip=0.5) 287 | self.gradient_clip = gradient_clip 288 | 289 | def initialize(self, model): 290 | self.model = model 291 | variables_1 = self.model.critic_1.trainable_variables 292 | variables_2 = self.model.critic_2.trainable_variables 293 | self.variables = variables_1 + variables_2 294 | 295 | @tf.function 296 | def __call__( 297 | self, observations, actions, next_observations, rewards, discounts 298 | ): 299 | next_actions = self.model.target_actor(next_observations) 300 | next_actions = self.target_action_noise(next_actions) 301 | next_value_distributions_1 = self.model.target_critic_1( 302 | next_observations, next_actions) 303 | next_value_distributions_2 = self.model.target_critic_2( 304 | next_observations, next_actions) 305 | 306 | values = next_value_distributions_1.values 307 | returns = rewards[:, None] + discounts[:, None] * values 308 | targets_1 = next_value_distributions_1.project(returns) 309 | targets_2 = next_value_distributions_2.project(returns) 310 | next_values_1 = next_value_distributions_1.mean() 311 | next_values_2 = next_value_distributions_2.mean() 312 | twin_next_values = tf.concat( 313 | [next_values_1[None], next_values_2[None]], axis=0) 314 | indices = tf.argmin(twin_next_values, axis=0, output_type=tf.int32) 315 | twin_targets = tf.concat([targets_1[None], targets_2[None]], axis=0) 316 | batch_size = tf.shape(observations)[0] 317 | indices = tf.stack([indices, tf.range(batch_size)], axis=-1) 318 | targets = tf.gather_nd(twin_targets, indices) 319 | 320 | with tf.GradientTape() as tape: 321 | value_distributions_1 = self.model.critic_1(observations, actions) 322 | losses_1 = tf.nn.softmax_cross_entropy_with_logits( 323 | logits=value_distributions_1.logits, labels=targets) 324 | value_distributions_2 = 
self.model.critic_2(observations, actions) 325 | losses_2 = tf.nn.softmax_cross_entropy_with_logits( 326 | logits=value_distributions_2.logits, labels=targets) 327 | loss = tf.reduce_mean(losses_1) + tf.reduce_mean(losses_2) 328 | 329 | gradients = tape.gradient(loss, self.variables) 330 | if self.gradient_clip > 0: 331 | gradients = tf.clip_by_global_norm( 332 | gradients, self.gradient_clip)[0] 333 | self.optimizer.apply_gradients(zip(gradients, self.variables)) 334 | 335 | return dict(loss=loss, q1=value_distributions_1.mean(), 336 | q2=value_distributions_2.mean()) 337 | -------------------------------------------------------------------------------- /tonic/tensorflow/updaters/optimizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | FLOAT_EPSILON = 1e-8 6 | 7 | 8 | def flat_concat(xs): 9 | return tf.concat([tf.reshape(x, (-1,)) for x in xs], axis=0) 10 | 11 | 12 | def assign_params_from_flat(new_params, params): 13 | def flat_size(p): 14 | return int(np.prod(p.shape.as_list())) 15 | splits = tf.split(new_params, [flat_size(p) for p in params]) 16 | new_params = [tf.reshape(p_new, p.shape) 17 | for p, p_new in zip(params, splits)] 18 | for p, p_new in zip(params, new_params): 19 | p.assign(p_new) 20 | 21 | 22 | class ConjugateGradient: 23 | def __init__( 24 | self, conjugate_gradient_steps=10, damping_coefficient=0.1, 25 | constraint_threshold=0.01, backtrack_steps=10, 26 | backtrack_coefficient=0.8 27 | ): 28 | self.conjugate_gradient_steps = conjugate_gradient_steps 29 | self.damping_coefficient = damping_coefficient 30 | self.constraint_threshold = constraint_threshold 31 | self.backtrack_steps = backtrack_steps 32 | self.backtrack_coefficient = backtrack_coefficient 33 | 34 | def optimize(self, loss_function, constraint_function, variables): 35 | @tf.function 36 | def _hx(x): 37 | with tf.GradientTape() as tape_2: 38 | with tf.GradientTape() as tape_1: 39 | f = constraint_function() 40 | gradient_1 = flat_concat(tape_1.gradient(f, variables)) 41 | y = tf.reduce_sum(gradient_1 * x) 42 | gradient_2 = flat_concat(tape_2.gradient(y, variables)) 43 | 44 | if self.damping_coefficient > 0: 45 | gradient_2 += self.damping_coefficient * x 46 | 47 | return gradient_2 48 | 49 | def _cg(b): 50 | x = np.zeros_like(b) 51 | r = b.copy() 52 | p = r.copy() 53 | r_dot_old = np.dot(r, r) 54 | if r_dot_old == 0: 55 | return None 56 | 57 | for _ in range(self.conjugate_gradient_steps): 58 | z = _hx(p).numpy() 59 | alpha = r_dot_old / (np.dot(p, z) + FLOAT_EPSILON) 60 | x += alpha * p 61 | r -= alpha * z 62 | r_dot_new = np.dot(r, r) 63 | p = r + (r_dot_new / r_dot_old) * p 64 | r_dot_old = r_dot_new 65 | return x 66 | 67 | @tf.function 68 | def _update(alpha, conjugate_gradient, step, start_variables): 69 | new_variables = start_variables - alpha * conjugate_gradient * step 70 | assign_params_from_flat(new_variables, variables) 71 | constraint = constraint_function() 72 | loss = loss_function() 73 | return constraint, loss 74 | 75 | start_variables = flat_concat(variables) 76 | 77 | with tf.GradientTape() as tape: 78 | loss = loss_function() 79 | grad = flat_concat(tape.gradient(loss, variables)).numpy() 80 | start_loss = loss.numpy() 81 | 82 | conjugate_gradient = _cg(grad) 83 | if conjugate_gradient is None: 84 | constraint = tf.convert_to_tensor(0.) 85 | loss = tf.convert_to_tensor(0.) 
86 | steps = tf.convert_to_tensor(0) 87 | return constraint, loss, steps 88 | 89 | alpha = np.sqrt(2 * self.constraint_threshold / np.dot( 90 | conjugate_gradient, _hx(conjugate_gradient)) + FLOAT_EPSILON) 91 | alpha = tf.convert_to_tensor(alpha, tf.float32) 92 | 93 | if self.backtrack_steps is None or self.backtrack_coefficient is None: 94 | constraint, loss = _update( 95 | alpha, conjugate_gradient, 1, start_variables) 96 | return constraint, loss 97 | 98 | for i in range(self.backtrack_steps): 99 | step = tf.convert_to_tensor( 100 | self.backtrack_coefficient ** i, tf.float32) 101 | constraint, loss = _update( 102 | alpha, conjugate_gradient, step, start_variables) 103 | 104 | if constraint <= self.constraint_threshold and loss <= start_loss: 105 | break 106 | 107 | if i == self.backtrack_steps - 1: 108 | step = tf.convert_to_tensor(0., tf.float32) 109 | constraint, loss = _update( 110 | alpha, conjugate_gradient, step, start_variables) 111 | i = self.backtrack_steps 112 | 113 | return constraint, loss, tf.convert_to_tensor(i + 1, dtype=tf.int32) 114 | -------------------------------------------------------------------------------- /tonic/tensorflow/updaters/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def tile(x, n): 5 | return tf.tile(x[None], [n] + [1] * len(x.shape)) 6 | 7 | 8 | def merge_first_two_dims(x): 9 | return tf.reshape(x, [x.shape[0] * x.shape[1]] + x.shape[2:]) 10 | -------------------------------------------------------------------------------- /tonic/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from . import agents, models, normalizers, updaters 2 | 3 | 4 | __all__ = [agents, models, normalizers, updaters] 5 | -------------------------------------------------------------------------------- /tonic/torch/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import Agent 2 | 3 | from .a2c import A2C # noqa 4 | from .ddpg import DDPG 5 | from .d4pg import D4PG # noqa 6 | from .mpo import MPO 7 | from .ppo import PPO 8 | from .sac import SAC 9 | from .td3 import TD3 10 | from .trpo import TRPO 11 | 12 | 13 | __all__ = [Agent, A2C, DDPG, D4PG, MPO, PPO, SAC, TD3, TRPO] 14 | -------------------------------------------------------------------------------- /tonic/torch/agents/a2c.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tonic import logger, replays # noqa 4 | from tonic.torch import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorCritic( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((64, 64), torch.nn.Tanh), 12 | head=models.DetachedScaleGaussianPolicyHead()), 13 | critic=models.Critic( 14 | encoder=models.ObservationEncoder(), 15 | torso=models.MLP((64, 64), torch.nn.Tanh), 16 | head=models.ValueHead()), 17 | observation_normalizer=normalizers.MeanStd()) 18 | 19 | 20 | class A2C(agents.Agent): 21 | '''Advantage Actor Critic (aka Vanilla Policy Gradient). 
22 | A3C: https://arxiv.org/pdf/1602.01783.pdf 23 | ''' 24 | 25 | def __init__( 26 | self, model=None, replay=None, actor_updater=None, critic_updater=None 27 | ): 28 | self.model = model or default_model() 29 | self.replay = replay or replays.Segment() 30 | self.actor_updater = actor_updater or \ 31 | updaters.StochasticPolicyGradient() 32 | self.critic_updater = critic_updater or updaters.VRegression() 33 | 34 | def initialize(self, observation_space, action_space, seed=None): 35 | super().initialize(seed=seed) 36 | self.model.initialize(observation_space, action_space) 37 | self.replay.initialize(seed) 38 | self.actor_updater.initialize(self.model) 39 | self.critic_updater.initialize(self.model) 40 | 41 | def step(self, observations, steps): 42 | # Sample actions and get their log-probabilities for training. 43 | actions, log_probs = self._step(observations) 44 | actions = actions.numpy() 45 | log_probs = log_probs.numpy() 46 | 47 | # Keep some values for the next update. 48 | self.last_observations = observations.copy() 49 | self.last_actions = actions.copy() 50 | self.last_log_probs = log_probs.copy() 51 | 52 | return actions 53 | 54 | def test_step(self, observations, steps): 55 | # Sample actions for testing. 56 | return self._test_step(observations).numpy() 57 | 58 | def update(self, observations, rewards, resets, terminations, steps): 59 | # Store the last transitions in the replay. 60 | self.replay.store( 61 | observations=self.last_observations, actions=self.last_actions, 62 | next_observations=observations, rewards=rewards, resets=resets, 63 | terminations=terminations, log_probs=self.last_log_probs) 64 | 65 | # Prepare to update the normalizers. 66 | if self.model.observation_normalizer: 67 | self.model.observation_normalizer.record(self.last_observations) 68 | if self.model.return_normalizer: 69 | self.model.return_normalizer.record(rewards) 70 | 71 | # Update the model if the replay is ready. 72 | if self.replay.ready(): 73 | self._update() 74 | 75 | def _step(self, observations): 76 | observations = torch.as_tensor(observations, dtype=torch.float32) 77 | with torch.no_grad(): 78 | distributions = self.model.actor(observations) 79 | if hasattr(distributions, 'sample_with_log_prob'): 80 | actions, log_probs = distributions.sample_with_log_prob() 81 | else: 82 | actions = distributions.sample() 83 | log_probs = distributions.log_prob(actions) 84 | log_probs = log_probs.sum(dim=-1) 85 | return actions, log_probs 86 | 87 | def _test_step(self, observations): 88 | observations = torch.as_tensor(observations, dtype=torch.float32) 89 | with torch.no_grad(): 90 | return self.model.actor(observations).sample() 91 | 92 | def _evaluate(self, observations, next_observations): 93 | observations = torch.as_tensor(observations, dtype=torch.float32) 94 | next_observations = torch.as_tensor( 95 | next_observations, dtype=torch.float32) 96 | with torch.no_grad(): 97 | values = self.model.critic(observations) 98 | next_values = self.model.critic(next_observations) 99 | return values, next_values 100 | 101 | def _update(self): 102 | # Compute the lambda-returns. 103 | batch = self.replay.get_full('observations', 'next_observations') 104 | values, next_values = self._evaluate(**batch) 105 | values, next_values = values.numpy(), next_values.numpy() 106 | self.replay.compute_returns(values, next_values) 107 | 108 | # Update the actor once. 
109 | keys = 'observations', 'actions', 'advantages', 'log_probs' 110 | batch = self.replay.get_full(*keys) 111 | batch = {k: torch.as_tensor(v) for k, v in batch.items()} 112 | infos = self.actor_updater(**batch) 113 | for k, v in infos.items(): 114 | logger.store('actor/' + k, v.numpy()) 115 | 116 | # Update the critic multiple times. 117 | for batch in self.replay.get('observations', 'returns'): 118 | batch = {k: torch.as_tensor(v) for k, v in batch.items()} 119 | infos = self.critic_updater(**batch) 120 | for k, v in infos.items(): 121 | logger.store('critic/' + k, v.numpy()) 122 | 123 | # Update the normalizers. 124 | if self.model.observation_normalizer: 125 | self.model.observation_normalizer.update() 126 | if self.model.return_normalizer: 127 | self.model.return_normalizer.update() 128 | -------------------------------------------------------------------------------- /tonic/torch/agents/agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from tonic import agents, logger # noqa 8 | 9 | 10 | class Agent(agents.Agent): 11 | def initialize(self, seed=None): 12 | if seed is not None: 13 | np.random.seed(seed) 14 | random.seed(seed) 15 | torch.manual_seed(seed) 16 | 17 | def save(self, path): 18 | path = path + '.pt' 19 | logger.log(f'\nSaving weights to {path}') 20 | os.makedirs(os.path.dirname(path), exist_ok=True) 21 | torch.save(self.model.state_dict(), path) 22 | 23 | def load(self, path): 24 | path = path + '.pt' 25 | logger.log(f'\nLoading weights from {path}') 26 | self.model.load_state_dict(torch.load(path)) 27 | -------------------------------------------------------------------------------- /tonic/torch/agents/d4pg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tonic import replays # noqa 4 | from tonic.torch import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorCriticWithTargets( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((256, 256), torch.nn.ReLU), 12 | head=models.DeterministicPolicyHead()), 13 | critic=models.Critic( 14 | encoder=models.ObservationActionEncoder(), 15 | torso=models.MLP((256, 256), torch.nn.ReLU), 16 | # These values are for the control suite with 0.99 discount. 17 | head=models.DistributionalValueHead(-150., 150., 51)), 18 | observation_normalizer=normalizers.MeanStd()) 19 | 20 | 21 | class D4PG(agents.DDPG): 22 | '''Distributed Distributional Deterministic Policy Gradients. 
23 | D4PG: https://arxiv.org/pdf/1804.08617.pdf 24 | ''' 25 | 26 | def __init__( 27 | self, model=None, replay=None, exploration=None, actor_updater=None, 28 | critic_updater=None 29 | ): 30 | model = model or default_model() 31 | replay = replay or replays.Buffer(return_steps=5) 32 | actor_updater = actor_updater or \ 33 | updaters.DistributionalDeterministicPolicyGradient() 34 | critic_updater = critic_updater or \ 35 | updaters.DistributionalDeterministicQLearning() 36 | super().__init__( 37 | model, replay, exploration, actor_updater, critic_updater) 38 | -------------------------------------------------------------------------------- /tonic/torch/agents/ddpg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tonic import explorations, logger, replays # noqa 4 | from tonic.torch import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorCriticWithTargets( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((256, 256), torch.nn.ReLU), 12 | head=models.DeterministicPolicyHead()), 13 | critic=models.Critic( 14 | encoder=models.ObservationActionEncoder(), 15 | torso=models.MLP((256, 256), torch.nn.ReLU), 16 | head=models.ValueHead()), 17 | observation_normalizer=normalizers.MeanStd()) 18 | 19 | 20 | class DDPG(agents.Agent): 21 | '''Deep Deterministic Policy Gradient. 22 | DDPG: https://arxiv.org/pdf/1509.02971.pdf 23 | ''' 24 | 25 | def __init__( 26 | self, model=None, replay=None, exploration=None, actor_updater=None, 27 | critic_updater=None 28 | ): 29 | self.model = model or default_model() 30 | self.replay = replay or replays.Buffer() 31 | self.exploration = exploration or explorations.NormalActionNoise() 32 | self.actor_updater = actor_updater or \ 33 | updaters.DeterministicPolicyGradient() 34 | self.critic_updater = critic_updater or \ 35 | updaters.DeterministicQLearning() 36 | 37 | def initialize(self, observation_space, action_space, seed=None): 38 | super().initialize(seed=seed) 39 | self.model.initialize(observation_space, action_space) 40 | self.replay.initialize(seed) 41 | self.exploration.initialize(self._policy, action_space, seed) 42 | self.actor_updater.initialize(self.model) 43 | self.critic_updater.initialize(self.model) 44 | 45 | def step(self, observations, steps): 46 | # Get actions from the actor and exploration method. 47 | actions = self.exploration(observations, steps) 48 | 49 | # Keep some values for the next update. 50 | self.last_observations = observations.copy() 51 | self.last_actions = actions.copy() 52 | 53 | return actions 54 | 55 | def test_step(self, observations, steps): 56 | # Greedy actions for testing. 57 | return self._greedy_actions(observations).numpy() 58 | 59 | def update(self, observations, rewards, resets, terminations, steps): 60 | # Store the last transitions in the replay. 61 | self.replay.store( 62 | observations=self.last_observations, actions=self.last_actions, 63 | next_observations=observations, rewards=rewards, resets=resets, 64 | terminations=terminations) 65 | 66 | # Prepare to update the normalizers. 67 | if self.model.observation_normalizer: 68 | self.model.observation_normalizer.record(self.last_observations) 69 | if self.model.return_normalizer: 70 | self.model.return_normalizer.record(rewards) 71 | 72 | # Update the model if the replay is ready. 
73 | if self.replay.ready(steps): 74 | self._update(steps) 75 | 76 | self.exploration.update(resets) 77 | 78 | def _greedy_actions(self, observations): 79 | observations = torch.as_tensor(observations, dtype=torch.float32) 80 | with torch.no_grad(): 81 | return self.model.actor(observations) 82 | 83 | def _policy(self, observations): 84 | return self._greedy_actions(observations).numpy() 85 | 86 | def _update(self, steps): 87 | keys = ('observations', 'actions', 'next_observations', 'rewards', 88 | 'discounts') 89 | 90 | # Update both the actor and the critic multiple times. 91 | for batch in self.replay.get(*keys, steps=steps): 92 | batch = {k: torch.as_tensor(v) for k, v in batch.items()} 93 | infos = self._update_actor_critic(**batch) 94 | 95 | for key in infos: 96 | for k, v in infos[key].items(): 97 | logger.store(key + '/' + k, v.numpy()) 98 | 99 | # Update the normalizers. 100 | if self.model.observation_normalizer: 101 | self.model.observation_normalizer.update() 102 | if self.model.return_normalizer: 103 | self.model.return_normalizer.update() 104 | 105 | def _update_actor_critic( 106 | self, observations, actions, next_observations, rewards, discounts 107 | ): 108 | critic_infos = self.critic_updater( 109 | observations, actions, next_observations, rewards, discounts) 110 | actor_infos = self.actor_updater(observations) 111 | self.model.update_targets() 112 | return dict(critic=critic_infos, actor=actor_infos) 113 | -------------------------------------------------------------------------------- /tonic/torch/agents/mpo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tonic import logger, replays # noqa 4 | from tonic.torch import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorCriticWithTargets( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((256, 256), torch.nn.ReLU), 12 | head=models.GaussianPolicyHead()), 13 | critic=models.Critic( 14 | encoder=models.ObservationActionEncoder(), 15 | torso=models.MLP((256, 256), torch.nn.ReLU), 16 | head=models.ValueHead()), 17 | observation_normalizer=normalizers.MeanStd()) 18 | 19 | 20 | class MPO(agents.Agent): 21 | '''Maximum a Posteriori Policy Optimisation. 22 | MPO: https://arxiv.org/pdf/1806.06920.pdf 23 | MO-MPO: https://arxiv.org/pdf/2005.07513.pdf 24 | ''' 25 | 26 | def __init__( 27 | self, model=None, replay=None, actor_updater=None, critic_updater=None 28 | ): 29 | self.model = model or default_model() 30 | self.replay = replay or replays.Buffer(return_steps=5) 31 | self.actor_updater = actor_updater or \ 32 | updaters.MaximumAPosterioriPolicyOptimization() 33 | self.critic_updater = critic_updater or updaters.ExpectedSARSA() 34 | 35 | def initialize(self, observation_space, action_space, seed=None): 36 | super().initialize(seed=seed) 37 | self.model.initialize(observation_space, action_space) 38 | self.replay.initialize(seed) 39 | self.actor_updater.initialize(self.model, action_space) 40 | self.critic_updater.initialize(self.model) 41 | 42 | def step(self, observations, steps): 43 | actions = self._step(observations) 44 | actions = actions.numpy() 45 | 46 | # Keep some values for the next update. 47 | self.last_observations = observations.copy() 48 | self.last_actions = actions.copy() 49 | 50 | return actions 51 | 52 | def test_step(self, observations, steps): 53 | # Sample actions for testing. 
54 | return self._test_step(observations).numpy() 55 | 56 | def update(self, observations, rewards, resets, terminations, steps): 57 | # Store the last transitions in the replay. 58 | self.replay.store( 59 | observations=self.last_observations, actions=self.last_actions, 60 | next_observations=observations, rewards=rewards, resets=resets, 61 | terminations=terminations) 62 | 63 | # Prepare to update the normalizers. 64 | if self.model.observation_normalizer: 65 | self.model.observation_normalizer.record(self.last_observations) 66 | if self.model.return_normalizer: 67 | self.model.return_normalizer.record(rewards) 68 | 69 | # Update the model if the replay is ready. 70 | if self.replay.ready(steps): 71 | self._update(steps) 72 | 73 | def _step(self, observations): 74 | observations = torch.as_tensor(observations, dtype=torch.float32) 75 | with torch.no_grad(): 76 | return self.model.actor(observations).sample() 77 | 78 | def _test_step(self, observations): 79 | observations = torch.as_tensor(observations, dtype=torch.float32) 80 | with torch.no_grad(): 81 | return self.model.actor(observations).loc 82 | 83 | def _update(self, steps): 84 | keys = ('observations', 'actions', 'next_observations', 'rewards', 85 | 'discounts') 86 | 87 | # Update both the actor and the critic multiple times. 88 | for batch in self.replay.get(*keys, steps=steps): 89 | batch = {k: torch.as_tensor(v) for k, v in batch.items()} 90 | infos = self._update_actor_critic(**batch) 91 | 92 | for key in infos: 93 | for k, v in infos[key].items(): 94 | logger.store(key + '/' + k, v.numpy()) 95 | 96 | # Update the normalizers. 97 | if self.model.observation_normalizer: 98 | self.model.observation_normalizer.update() 99 | if self.model.return_normalizer: 100 | self.model.return_normalizer.update() 101 | 102 | def _update_actor_critic( 103 | self, observations, actions, next_observations, rewards, discounts 104 | ): 105 | critic_infos = self.critic_updater( 106 | observations, actions, next_observations, rewards, discounts) 107 | actor_infos = self.actor_updater(observations) 108 | self.model.update_targets() 109 | return dict(critic=critic_infos, actor=actor_infos) 110 | -------------------------------------------------------------------------------- /tonic/torch/agents/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tonic import logger # noqa 4 | from tonic.torch import agents, updaters 5 | 6 | 7 | class PPO(agents.A2C): 8 | '''Proximal Policy Optimization. 9 | PPO: https://arxiv.org/pdf/1707.06347.pdf 10 | ''' 11 | 12 | def __init__( 13 | self, model=None, replay=None, actor_updater=None, critic_updater=None 14 | ): 15 | actor_updater = actor_updater or updaters.ClippedRatio() 16 | super().__init__( 17 | model=model, replay=replay, actor_updater=actor_updater, 18 | critic_updater=critic_updater) 19 | 20 | def _update(self): 21 | # Compute the lambda-returns. 22 | batch = self.replay.get_full('observations', 'next_observations') 23 | values, next_values = self._evaluate(**batch) 24 | values, next_values = values.numpy(), next_values.numpy() 25 | self.replay.compute_returns(values, next_values) 26 | 27 | train_actor = True 28 | actor_iterations = 0 29 | critic_iterations = 0 30 | keys = 'observations', 'actions', 'advantages', 'log_probs', 'returns' 31 | 32 | # Update both the actor and the critic multiple times. 
33 | for batch in self.replay.get(*keys): 34 | if train_actor: 35 | batch = {k: torch.as_tensor(v) for k, v in batch.items()} 36 | infos = self._update_actor_critic(**batch) 37 | actor_iterations += 1 38 | else: 39 | batch = {k: torch.as_tensor(batch[k]) 40 | for k in ('observations', 'returns')} 41 | infos = dict(critic=self.critic_updater(**batch)) 42 | critic_iterations += 1 43 | 44 | # Stop training the actor early. 45 | if train_actor: 46 | train_actor = not infos['actor']['stop'].numpy() 47 | 48 | for key in infos: 49 | for k, v in infos[key].items(): 50 | logger.store(key + '/' + k, v.numpy()) 51 | 52 | logger.store('actor/iterations', actor_iterations) 53 | logger.store('critic/iterations', critic_iterations) 54 | 55 | # Update the normalizers. 56 | if self.model.observation_normalizer: 57 | self.model.observation_normalizer.update() 58 | if self.model.return_normalizer: 59 | self.model.return_normalizer.update() 60 | 61 | def _update_actor_critic( 62 | self, observations, actions, advantages, log_probs, returns 63 | ): 64 | actor_infos = self.actor_updater( 65 | observations, actions, advantages, log_probs) 66 | critic_infos = self.critic_updater(observations, returns) 67 | return dict(actor=actor_infos, critic=critic_infos) 68 | -------------------------------------------------------------------------------- /tonic/torch/agents/sac.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tonic import explorations # noqa 4 | from tonic.torch import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorTwinCriticWithTargets( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((256, 256), torch.nn.ReLU), 12 | head=models.GaussianPolicyHead( 13 | loc_activation=torch.nn.Identity, 14 | distribution=models.SquashedMultivariateNormalDiag)), 15 | critic=models.Critic( 16 | encoder=models.ObservationActionEncoder(), 17 | torso=models.MLP((256, 256), torch.nn.ReLU), 18 | head=models.ValueHead()), 19 | observation_normalizer=normalizers.MeanStd()) 20 | 21 | 22 | class SAC(agents.DDPG): 23 | '''Soft Actor-Critic. 
24 | SAC: https://arxiv.org/pdf/1801.01290.pdf 25 | ''' 26 | 27 | def __init__( 28 | self, model=None, replay=None, exploration=None, actor_updater=None, 29 | critic_updater=None 30 | ): 31 | model = model or default_model() 32 | exploration = exploration or explorations.NoActionNoise() 33 | actor_updater = actor_updater or \ 34 | updaters.TwinCriticSoftDeterministicPolicyGradient() 35 | critic_updater = critic_updater or updaters.TwinCriticSoftQLearning() 36 | super().__init__( 37 | model=model, replay=replay, exploration=exploration, 38 | actor_updater=actor_updater, critic_updater=critic_updater) 39 | 40 | def _stochastic_actions(self, observations): 41 | observations = torch.as_tensor(observations, dtype=torch.float32) 42 | with torch.no_grad(): 43 | return self.model.actor(observations).sample() 44 | 45 | def _policy(self, observations): 46 | return self._stochastic_actions(observations).numpy() 47 | 48 | def _greedy_actions(self, observations): 49 | observations = torch.as_tensor(observations, dtype=torch.float32) 50 | with torch.no_grad(): 51 | return self.model.actor(observations).loc 52 | -------------------------------------------------------------------------------- /tonic/torch/agents/td3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tonic import logger # noqa 4 | from tonic.torch import agents, models, normalizers, updaters 5 | 6 | 7 | def default_model(): 8 | return models.ActorTwinCriticWithTargets( 9 | actor=models.Actor( 10 | encoder=models.ObservationEncoder(), 11 | torso=models.MLP((256, 256), torch.nn.ReLU), 12 | head=models.DeterministicPolicyHead()), 13 | critic=models.Critic( 14 | encoder=models.ObservationActionEncoder(), 15 | torso=models.MLP((256, 256), torch.nn.ReLU), 16 | head=models.ValueHead()), 17 | observation_normalizer=normalizers.MeanStd()) 18 | 19 | 20 | class TD3(agents.DDPG): 21 | '''Twin Delayed Deep Deterministic Policy Gradient. 22 | TD3: https://arxiv.org/pdf/1802.09477.pdf 23 | ''' 24 | 25 | def __init__( 26 | self, model=None, replay=None, exploration=None, actor_updater=None, 27 | critic_updater=None, delay_steps=2 28 | ): 29 | model = model or default_model() 30 | critic_updater = critic_updater or \ 31 | updaters.TwinCriticDeterministicQLearning() 32 | super().__init__( 33 | model=model, replay=replay, exploration=exploration, 34 | actor_updater=actor_updater, critic_updater=critic_updater) 35 | self.delay_steps = delay_steps 36 | self.model.critic = self.model.critic_1 37 | 38 | def _update(self, steps): 39 | keys = ('observations', 'actions', 'next_observations', 'rewards', 40 | 'discounts') 41 | for i, batch in enumerate(self.replay.get(*keys, steps=steps)): 42 | batch = {k: torch.as_tensor(v) for k, v in batch.items()} 43 | if (i + 1) % self.delay_steps == 0: 44 | infos = self._update_actor_critic(**batch) 45 | else: 46 | infos = dict(critic=self.critic_updater(**batch)) 47 | for key in infos: 48 | for k, v in infos[key].items(): 49 | logger.store(key + '/' + k, v.numpy()) 50 | 51 | # Update the normalizers. 
52 | if self.model.observation_normalizer: 53 | self.model.observation_normalizer.update() 54 | if self.model.return_normalizer: 55 | self.model.return_normalizer.update() 56 | -------------------------------------------------------------------------------- /tonic/torch/agents/trpo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tonic import logger # noqa 4 | from tonic.torch import agents, updaters 5 | 6 | 7 | class TRPO(agents.A2C): 8 | '''Trust Region Policy Optimization. 9 | TRPO: https://arxiv.org/pdf/1502.05477.pdf 10 | ''' 11 | 12 | def __init__( 13 | self, model=None, replay=None, actor_updater=None, critic_updater=None 14 | ): 15 | actor_updater = actor_updater or updaters.TrustRegionPolicyGradient() 16 | super().__init__( 17 | model=model, replay=replay, actor_updater=actor_updater, 18 | critic_updater=critic_updater) 19 | 20 | def step(self, observations, steps): 21 | # Sample actions and get their log-probabilities for training. 22 | actions, log_probs, locs, scales = self._step(observations) 23 | actions = actions.numpy() 24 | log_probs = log_probs.numpy() 25 | locs = locs.numpy() 26 | scales = scales.numpy() 27 | 28 | # Keep some values for the next update. 29 | self.last_observations = observations.copy() 30 | self.last_actions = actions.copy() 31 | self.last_log_probs = log_probs.copy() 32 | self.last_locs = locs.copy() 33 | self.last_scales = scales.copy() 34 | 35 | return actions 36 | 37 | def update(self, observations, rewards, resets, terminations, steps): 38 | # Store the last transitions in the replay. 39 | self.replay.store( 40 | observations=self.last_observations, actions=self.last_actions, 41 | next_observations=observations, rewards=rewards, resets=resets, 42 | terminations=terminations, log_probs=self.last_log_probs, 43 | locs=self.last_locs, scales=self.last_scales) 44 | 45 | # Prepare to update the normalizers. 46 | if self.model.observation_normalizer: 47 | self.model.observation_normalizer.record(self.last_observations) 48 | if self.model.return_normalizer: 49 | self.model.return_normalizer.record(rewards) 50 | 51 | # Update the model if the replay is ready. 52 | if self.replay.ready(): 53 | self._update() 54 | 55 | def _step(self, observations): 56 | observations = torch.as_tensor(observations, dtype=torch.float32) 57 | with torch.no_grad(): 58 | distributions = self.model.actor(observations) 59 | if hasattr(distributions, 'sample_with_log_prob'): 60 | actions, log_probs = distributions.sample_with_log_prob() 61 | else: 62 | actions = distributions.sample() 63 | log_probs = distributions.log_prob(actions) 64 | log_probs = log_probs.sum(dim=-1) 65 | locs = distributions.loc 66 | scales = distributions.stddev 67 | return actions, log_probs, locs, scales 68 | 69 | def _update(self): 70 | # Compute the lambda-returns. 
71 | batch = self.replay.get_full('observations', 'next_observations') 72 | values, next_values = self._evaluate(**batch) 73 | values, next_values = values.numpy(), next_values.numpy() 74 | self.replay.compute_returns(values, next_values) 75 | 76 | keys = ('observations', 'actions', 'log_probs', 'locs', 'scales', 77 | 'advantages') 78 | batch = self.replay.get_full(*keys) 79 | batch = {k: torch.as_tensor(v) for k, v in batch.items()} 80 | infos = self.actor_updater(**batch) 81 | for k, v in infos.items(): 82 | logger.store('actor/' + k, v.numpy()) 83 | 84 | critic_iterations = 0 85 | for batch in self.replay.get('observations', 'returns'): 86 | batch = {k: torch.as_tensor(v) for k, v in batch.items()} 87 | infos = self.critic_updater(**batch) 88 | critic_iterations += 1 89 | for k, v in infos.items(): 90 | logger.store('critic/' + k, v.numpy()) 91 | logger.store('critic/iterations', critic_iterations) 92 | 93 | # Update the normalizers. 94 | if self.model.observation_normalizer: 95 | self.model.observation_normalizer.update() 96 | if self.model.return_normalizer: 97 | self.model.return_normalizer.update() 98 | -------------------------------------------------------------------------------- /tonic/torch/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor_critics import ActorCritic 2 | from .actor_critics import ActorCriticWithTargets 3 | from .actor_critics import ActorTwinCriticWithTargets 4 | 5 | from .actors import Actor 6 | from .actors import DetachedScaleGaussianPolicyHead 7 | from .actors import DeterministicPolicyHead 8 | from .actors import GaussianPolicyHead 9 | from .actors import SquashedMultivariateNormalDiag 10 | 11 | from .critics import Critic, DistributionalValueHead, ValueHead 12 | 13 | from .encoders import ObservationActionEncoder, ObservationEncoder 14 | 15 | from .utils import MLP, trainable_variables 16 | 17 | 18 | __all__ = [ 19 | MLP, trainable_variables, ObservationActionEncoder, 20 | ObservationEncoder, SquashedMultivariateNormalDiag, 21 | DetachedScaleGaussianPolicyHead, GaussianPolicyHead, 22 | DeterministicPolicyHead, Actor, Critic, DistributionalValueHead, 23 | ValueHead, ActorCritic, ActorCriticWithTargets, ActorTwinCriticWithTargets] 24 | -------------------------------------------------------------------------------- /tonic/torch/models/actor_critics.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | from tonic.torch import models # noqa 6 | 7 | 8 | class ActorCritic(torch.nn.Module): 9 | def __init__( 10 | self, actor, critic, observation_normalizer=None, 11 | return_normalizer=None 12 | ): 13 | super().__init__() 14 | self.actor = actor 15 | self.critic = critic 16 | self.observation_normalizer = observation_normalizer 17 | self.return_normalizer = return_normalizer 18 | 19 | def initialize(self, observation_space, action_space): 20 | if self.observation_normalizer: 21 | self.observation_normalizer.initialize(observation_space.shape) 22 | self.actor.initialize( 23 | observation_space, action_space, self.observation_normalizer) 24 | self.critic.initialize( 25 | observation_space, action_space, self.observation_normalizer, 26 | self.return_normalizer) 27 | 28 | 29 | class ActorCriticWithTargets(torch.nn.Module): 30 | def __init__( 31 | self, actor, critic, observation_normalizer=None, 32 | return_normalizer=None, target_coeff=0.005 33 | ): 34 | super().__init__() 35 | self.actor = actor 36 | self.critic = critic 37 
| self.target_actor = copy.deepcopy(actor) 38 | self.target_critic = copy.deepcopy(critic) 39 | self.observation_normalizer = observation_normalizer 40 | self.return_normalizer = return_normalizer 41 | self.target_coeff = target_coeff 42 | 43 | def initialize(self, observation_space, action_space): 44 | if self.observation_normalizer: 45 | self.observation_normalizer.initialize(observation_space.shape) 46 | self.actor.initialize( 47 | observation_space, action_space, self.observation_normalizer) 48 | self.critic.initialize( 49 | observation_space, action_space, self.observation_normalizer, 50 | self.return_normalizer) 51 | self.target_actor.initialize( 52 | observation_space, action_space, self.observation_normalizer) 53 | self.target_critic.initialize( 54 | observation_space, action_space, self.observation_normalizer, 55 | self.return_normalizer) 56 | self.online_variables = models.trainable_variables(self.actor) 57 | self.online_variables += models.trainable_variables(self.critic) 58 | self.target_variables = models.trainable_variables(self.target_actor) 59 | self.target_variables += models.trainable_variables(self.target_critic) 60 | for target in self.target_variables: 61 | target.requires_grad = False 62 | self.assign_targets() 63 | 64 | def assign_targets(self): 65 | for o, t in zip(self.online_variables, self.target_variables): 66 | t.data.copy_(o.data) 67 | 68 | def update_targets(self): 69 | with torch.no_grad(): 70 | for o, t in zip(self.online_variables, self.target_variables): 71 | t.data.mul_(1 - self.target_coeff) 72 | t.data.add_(self.target_coeff * o.data) 73 | 74 | 75 | class ActorTwinCriticWithTargets(torch.nn.Module): 76 | def __init__( 77 | self, actor, critic, observation_normalizer=None, 78 | return_normalizer=None, target_coeff=0.005 79 | ): 80 | super().__init__() 81 | self.actor = actor 82 | self.critic_1 = critic 83 | self.critic_2 = copy.deepcopy(critic) 84 | self.target_actor = copy.deepcopy(actor) 85 | self.target_critic_1 = copy.deepcopy(critic) 86 | self.target_critic_2 = copy.deepcopy(critic) 87 | self.observation_normalizer = observation_normalizer 88 | self.return_normalizer = return_normalizer 89 | self.target_coeff = target_coeff 90 | 91 | def initialize(self, observation_space, action_space): 92 | if self.observation_normalizer: 93 | self.observation_normalizer.initialize(observation_space.shape) 94 | self.actor.initialize( 95 | observation_space, action_space, self.observation_normalizer) 96 | self.critic_1.initialize( 97 | observation_space, action_space, self.observation_normalizer, 98 | self.return_normalizer) 99 | self.critic_2.initialize( 100 | observation_space, action_space, self.observation_normalizer, 101 | self.return_normalizer) 102 | self.target_actor.initialize( 103 | observation_space, action_space, self.observation_normalizer) 104 | self.target_critic_1.initialize( 105 | observation_space, action_space, self.observation_normalizer, 106 | self.return_normalizer) 107 | self.target_critic_2.initialize( 108 | observation_space, action_space, self.observation_normalizer, 109 | self.return_normalizer) 110 | self.online_variables = models.trainable_variables(self.actor) 111 | self.online_variables += models.trainable_variables(self.critic_1) 112 | self.online_variables += models.trainable_variables(self.critic_2) 113 | self.target_variables = models.trainable_variables(self.target_actor) 114 | self.target_variables += models.trainable_variables( 115 | self.target_critic_1) 116 | self.target_variables += models.trainable_variables( 117 | 
self.target_critic_2) 118 | for target in self.target_variables: 119 | target.requires_grad = False 120 | self.assign_targets() 121 | 122 | def assign_targets(self): 123 | for o, t in zip(self.online_variables, self.target_variables): 124 | t.data.copy_(o.data) 125 | 126 | def update_targets(self): 127 | with torch.no_grad(): 128 | for o, t in zip(self.online_variables, self.target_variables): 129 | t.data.mul_(1 - self.target_coeff) 130 | t.data.add_(self.target_coeff * o.data) 131 | -------------------------------------------------------------------------------- /tonic/torch/models/actors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | FLOAT_EPSILON = 1e-8 5 | 6 | 7 | class SquashedMultivariateNormalDiag: 8 | def __init__(self, loc, scale): 9 | self._distribution = torch.distributions.normal.Normal(loc, scale) 10 | 11 | def rsample_with_log_prob(self, shape=()): 12 | samples = self._distribution.rsample(shape) 13 | squashed_samples = torch.tanh(samples) 14 | log_probs = self._distribution.log_prob(samples) 15 | log_probs -= torch.log(1 - squashed_samples ** 2 + 1e-6) 16 | return squashed_samples, log_probs 17 | 18 | def rsample(self, shape=()): 19 | samples = self._distribution.rsample(shape) 20 | return torch.tanh(samples) 21 | 22 | def sample(self, shape=()): 23 | samples = self._distribution.sample(shape) 24 | return torch.tanh(samples) 25 | 26 | def log_prob(self, samples): 27 | '''The required unsquashed samples cannot be accurately recovered.''' 28 | raise NotImplementedError( 29 | 'Not implemented to avoid approximation errors. ' 30 | 'Use rsample_with_log_prob directly.') 31 | 32 | @property 33 | def loc(self): 34 | return torch.tanh(self._distribution.mean) 35 | 36 | 37 | class DetachedScaleGaussianPolicyHead(torch.nn.Module): 38 | def __init__( 39 | self, loc_activation=torch.nn.Tanh, loc_fn=None, log_scale_init=0., 40 | scale_min=1e-4, scale_max=1., 41 | distribution=torch.distributions.normal.Normal 42 | ): 43 | super().__init__() 44 | self.loc_activation = loc_activation 45 | self.loc_fn = loc_fn 46 | self.log_scale_init = log_scale_init 47 | self.scale_min = scale_min 48 | self.scale_max = scale_max 49 | self.distribution = distribution 50 | 51 | def initialize(self, input_size, action_size): 52 | self.loc_layer = torch.nn.Sequential( 53 | torch.nn.Linear(input_size, action_size), self.loc_activation()) 54 | if self.loc_fn: 55 | self.loc_layer.apply(self.loc_fn) 56 | log_scale = [[self.log_scale_init] * action_size] 57 | self.log_scale = torch.nn.Parameter( 58 | torch.as_tensor(log_scale, dtype=torch.float32)) 59 | 60 | def forward(self, inputs): 61 | loc = self.loc_layer(inputs) 62 | batch_size = inputs.shape[0] 63 | scale = torch.nn.functional.softplus(self.log_scale) + FLOAT_EPSILON 64 | scale = torch.clamp(scale, self.scale_min, self.scale_max) 65 | scale = scale.repeat(batch_size, 1) 66 | return self.distribution(loc, scale) 67 | 68 | 69 | class GaussianPolicyHead(torch.nn.Module): 70 | def __init__( 71 | self, loc_activation=torch.nn.Tanh, loc_fn=None, 72 | scale_activation=torch.nn.Softplus, scale_min=1e-4, scale_max=1, 73 | scale_fn=None, distribution=torch.distributions.normal.Normal 74 | ): 75 | super().__init__() 76 | self.loc_activation = loc_activation 77 | self.loc_fn = loc_fn 78 | self.scale_activation = scale_activation 79 | self.scale_min = scale_min 80 | self.scale_max = scale_max 81 | self.scale_fn = scale_fn 82 | self.distribution = distribution 83 | 84 | def initialize(self, input_size, 
action_size): 85 | self.loc_layer = torch.nn.Sequential( 86 | torch.nn.Linear(input_size, action_size), self.loc_activation()) 87 | if self.loc_fn: 88 | self.loc_layer.apply(self.loc_fn) 89 | self.scale_layer = torch.nn.Sequential( 90 | torch.nn.Linear(input_size, action_size), self.scale_activation()) 91 | if self.scale_fn: 92 | self.scale_layer.apply(self.scale_fn) 93 | 94 | def forward(self, inputs): 95 | loc = self.loc_layer(inputs) 96 | scale = self.scale_layer(inputs) 97 | scale = torch.clamp(scale, self.scale_min, self.scale_max) 98 | return self.distribution(loc, scale) 99 | 100 | 101 | class DeterministicPolicyHead(torch.nn.Module): 102 | def __init__(self, activation=torch.nn.Tanh, fn=None): 103 | super().__init__() 104 | self.activation = activation 105 | self.fn = fn 106 | 107 | def initialize(self, input_size, action_size): 108 | self.action_layer = torch.nn.Sequential( 109 | torch.nn.Linear(input_size, action_size), 110 | self.activation()) 111 | if self.fn is not None: 112 | self.action_layer.apply(self.fn) 113 | 114 | def forward(self, inputs): 115 | return self.action_layer(inputs) 116 | 117 | 118 | class Actor(torch.nn.Module): 119 | def __init__(self, encoder, torso, head): 120 | super().__init__() 121 | self.encoder = encoder 122 | self.torso = torso 123 | self.head = head 124 | 125 | def initialize( 126 | self, observation_space, action_space, observation_normalizer=None 127 | ): 128 | size = self.encoder.initialize( 129 | observation_space, observation_normalizer) 130 | size = self.torso.initialize(size) 131 | action_size = action_space.shape[0] 132 | self.head.initialize(size, action_size) 133 | 134 | def forward(self, *inputs): 135 | out = self.encoder(*inputs) 136 | out = self.torso(out) 137 | return self.head(out) 138 | -------------------------------------------------------------------------------- /tonic/torch/models/critics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ValueHead(torch.nn.Module): 5 | def __init__(self, fn=None): 6 | super().__init__() 7 | self.fn = fn 8 | 9 | def initialize(self, input_size, return_normalizer=None): 10 | self.return_normalizer = return_normalizer 11 | self.v_layer = torch.nn.Linear(input_size, 1) 12 | if self.fn: 13 | self.v_layer.apply(self.fn) 14 | 15 | def forward(self, inputs): 16 | out = self.v_layer(inputs) 17 | out = torch.squeeze(out, -1) 18 | if self.return_normalizer: 19 | out = self.return_normalizer(out) 20 | return out 21 | 22 | 23 | class CategoricalWithSupport: 24 | def __init__(self, values, logits): 25 | self.values = values 26 | self.logits = logits 27 | self.probabilities = torch.nn.functional.softmax(logits, dim=-1) 28 | 29 | def mean(self): 30 | return (self.probabilities * self.values).sum(dim=-1) 31 | 32 | def project(self, returns): 33 | vmin, vmax = self.values[0], self.values[-1] 34 | d_pos = torch.cat([self.values, vmin[None]], 0)[1:] 35 | d_pos = (d_pos - self.values)[None, :, None] 36 | d_neg = torch.cat([vmax[None], self.values], 0)[:-1] 37 | d_neg = (self.values - d_neg)[None, :, None] 38 | 39 | clipped_returns = torch.clamp(returns, vmin, vmax) 40 | delta_values = clipped_returns[:, None] - self.values[None, :, None] 41 | delta_sign = (delta_values >= 0).float() 42 | delta_hat = ((delta_sign * delta_values / d_pos) - 43 | ((1 - delta_sign) * delta_values / d_neg)) 44 | delta_clipped = torch.clamp(1 - delta_hat, 0, 1) 45 | 46 | return (delta_clipped * self.probabilities[:, None]).sum(dim=2) 47 | 48 | 49 | class 
DistributionalValueHead(torch.nn.Module): 50 | def __init__(self, vmin, vmax, num_atoms, fn=None): 51 | super().__init__() 52 | self.num_atoms = num_atoms 53 | self.fn = fn 54 | self.values = torch.linspace(vmin, vmax, num_atoms).float() 55 | 56 | def initialize(self, input_size, return_normalizer=None): 57 | if return_normalizer: 58 | raise ValueError( 59 | 'Return normalizers cannot be used with distributional value ' 60 | 'heads.') 61 | self.distributional_layer = torch.nn.Linear(input_size, self.num_atoms) 62 | if self.fn: 63 | self.distributional_layer.apply(self.fn) 64 | 65 | def forward(self, inputs): 66 | logits = self.distributional_layer(inputs) 67 | return CategoricalWithSupport(values=self.values, logits=logits) 68 | 69 | 70 | class Critic(torch.nn.Module): 71 | def __init__(self, encoder, torso, head): 72 | super().__init__() 73 | self.encoder = encoder 74 | self.torso = torso 75 | self.head = head 76 | 77 | def initialize( 78 | self, observation_space, action_space, observation_normalizer=None, 79 | return_normalizer=None 80 | ): 81 | size = self.encoder.initialize( 82 | observation_space=observation_space, action_space=action_space, 83 | observation_normalizer=observation_normalizer) 84 | size = self.torso.initialize(size) 85 | self.head.initialize(size, return_normalizer) 86 | 87 | def forward(self, *inputs): 88 | out = self.encoder(*inputs) 89 | out = self.torso(out) 90 | return self.head(out) 91 | -------------------------------------------------------------------------------- /tonic/torch/models/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ObservationEncoder(torch.nn.Module): 5 | def initialize( 6 | self, observation_space, action_space=None, 7 | observation_normalizer=None, 8 | ): 9 | self.observation_normalizer = observation_normalizer 10 | observation_size = observation_space.shape[0] 11 | return observation_size 12 | 13 | def forward(self, observations): 14 | if self.observation_normalizer: 15 | observations = self.observation_normalizer(observations) 16 | return observations 17 | 18 | 19 | class ObservationActionEncoder(torch.nn.Module): 20 | def initialize( 21 | self, observation_space, action_space, observation_normalizer=None 22 | ): 23 | self.observation_normalizer = observation_normalizer 24 | observation_size = observation_space.shape[0] 25 | action_size = action_space.shape[0] 26 | return observation_size + action_size 27 | 28 | def forward(self, observations, actions): 29 | if self.observation_normalizer: 30 | observations = self.observation_normalizer(observations) 31 | return torch.cat([observations, actions], dim=-1) 32 | -------------------------------------------------------------------------------- /tonic/torch/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MLP(torch.nn.Module): 5 | def __init__(self, sizes, activation, fn=None): 6 | super().__init__() 7 | self.sizes = sizes 8 | self.activation = activation 9 | self.fn = fn 10 | 11 | def initialize(self, input_size): 12 | sizes = [input_size] + list(self.sizes) 13 | layers = [] 14 | for i in range(len(sizes) - 1): 15 | layers += [torch.nn.Linear(sizes[i], sizes[i + 1]), 16 | self.activation()] 17 | self.model = torch.nn.Sequential(*layers) 18 | if self.fn is not None: 19 | self.model.apply(self.fn) 20 | return sizes[-1] 21 | 22 | def forward(self, inputs): 23 | return self.model(inputs) 24 | 25 | 26 | def trainable_variables(model): 27 
| return [p for p in model.parameters() if p.requires_grad] 28 | -------------------------------------------------------------------------------- /tonic/torch/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .mean_stds import MeanStd 2 | from .returns import Return 3 | 4 | 5 | __all__ = [MeanStd, Return] 6 | -------------------------------------------------------------------------------- /tonic/torch/normalizers/mean_stds.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class MeanStd(torch.nn.Module): 6 | def __init__(self, mean=0, std=1, clip=None, shape=None): 7 | super().__init__() 8 | self.mean = mean 9 | self.std = std 10 | self.clip = clip 11 | self.count = 0 12 | self.new_sum = 0 13 | self.new_sum_sq = 0 14 | self.new_count = 0 15 | self.eps = 1e-2 16 | if shape: 17 | self.initialize(shape) 18 | 19 | def initialize(self, shape): 20 | if isinstance(self.mean, (int, float)): 21 | self.mean = np.full(shape, self.mean, np.float32) 22 | else: 23 | self.mean = np.array(self.mean, np.float32) 24 | if isinstance(self.std, (int, float)): 25 | self.std = np.full(shape, self.std, np.float32) 26 | else: 27 | self.std = np.array(self.std, np.float32) 28 | self.mean_sq = np.square(self.mean) 29 | self._mean = torch.nn.Parameter(torch.as_tensor( 30 | self.mean, dtype=torch.float32), requires_grad=False) 31 | self._std = torch.nn.Parameter(torch.as_tensor( 32 | self.std, dtype=torch.float32), requires_grad=False) 33 | 34 | def forward(self, val): 35 | with torch.no_grad(): 36 | val = (val - self._mean) / self._std 37 | if self.clip is not None: 38 | val = torch.clamp(val, -self.clip, self.clip) 39 | return val 40 | 41 | def unnormalize(self, val): 42 | return val * self._std + self._mean 43 | 44 | def record(self, values): 45 | for val in values: 46 | self.new_sum += val 47 | self.new_sum_sq += np.square(val) 48 | self.new_count += 1 49 | 50 | def update(self): 51 | new_count = self.count + self.new_count 52 | new_mean = self.new_sum / self.new_count 53 | new_mean_sq = self.new_sum_sq / self.new_count 54 | w_old = self.count / new_count 55 | w_new = self.new_count / new_count 56 | self.mean = w_old * self.mean + w_new * new_mean 57 | self.mean_sq = w_old * self.mean_sq + w_new * new_mean_sq 58 | self.std = self._compute_std(self.mean, self.mean_sq) 59 | self.count = new_count 60 | self.new_count = 0 61 | self.new_sum = 0 62 | self.new_sum_sq = 0 63 | self._update(self.mean.astype(np.float32), self.std.astype(np.float32)) 64 | 65 | def _compute_std(self, mean, mean_sq): 66 | var = mean_sq - np.square(mean) 67 | var = np.maximum(var, 0) 68 | std = np.sqrt(var) 69 | std = np.maximum(std, self.eps) 70 | return std 71 | 72 | def _update(self, mean, std): 73 | self._mean.data.copy_(torch.as_tensor(self.mean, dtype=torch.float32)) 74 | self._std.data.copy_(torch.as_tensor(self.std, dtype=torch.float32)) 75 | -------------------------------------------------------------------------------- /tonic/torch/normalizers/returns.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Return(torch.nn.Module): 6 | def __init__(self, discount_factor): 7 | super().__init__() 8 | assert 0 <= discount_factor < 1 9 | self.coefficient = 1 / (1 - discount_factor) 10 | self.min_reward = np.float32(-1) 11 | self.max_reward = np.float32(1) 12 | self._low = 
torch.nn.Parameter(torch.as_tensor( 13 | self.coefficient * self.min_reward, dtype=torch.float32), 14 | requires_grad=False) 15 | self._high = torch.nn.Parameter(torch.as_tensor( 16 | self.coefficient * self.max_reward, dtype=torch.float32), 17 | requires_grad=False) 18 | 19 | def forward(self, val): 20 | val = torch.sigmoid(val) 21 | return self._low + val * (self._high - self._low) 22 | 23 | def record(self, values): 24 | for val in values: 25 | if val < self.min_reward: 26 | self.min_reward = np.float32(val) 27 | elif val > self.max_reward: 28 | self.max_reward = np.float32(val) 29 | 30 | def update(self): 31 | self._update(self.min_reward, self.max_reward) 32 | 33 | def _update(self, min_reward, max_reward): 34 | self._low.data.copy_(torch.as_tensor( 35 | self.coefficient * min_reward, dtype=torch.float32)) 36 | self._high.data.copy_(torch.as_tensor( 37 | self.coefficient * max_reward, dtype=torch.float32)) 38 | -------------------------------------------------------------------------------- /tonic/torch/updaters/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import merge_first_two_dims 2 | from .utils import tile 3 | 4 | from .actors import ClippedRatio # noqa 5 | from .actors import DeterministicPolicyGradient 6 | from .actors import DistributionalDeterministicPolicyGradient 7 | from .actors import MaximumAPosterioriPolicyOptimization 8 | from .actors import StochasticPolicyGradient 9 | from .actors import TrustRegionPolicyGradient 10 | from .actors import TwinCriticSoftDeterministicPolicyGradient 11 | 12 | from .critics import DeterministicQLearning 13 | from .critics import DistributionalDeterministicQLearning 14 | from .critics import ExpectedSARSA 15 | from .critics import QRegression 16 | from .critics import TargetActionNoise 17 | from .critics import TwinCriticDeterministicQLearning 18 | from .critics import TwinCriticSoftQLearning 19 | from .critics import VRegression 20 | 21 | from .optimizers import ConjugateGradient 22 | 23 | 24 | __all__ = [ 25 | merge_first_two_dims, tile, ClippedRatio, DeterministicPolicyGradient, 26 | DistributionalDeterministicPolicyGradient, 27 | MaximumAPosterioriPolicyOptimization, StochasticPolicyGradient, 28 | TrustRegionPolicyGradient, TwinCriticSoftDeterministicPolicyGradient, 29 | DeterministicQLearning, DistributionalDeterministicQLearning, 30 | ExpectedSARSA, QRegression, TargetActionNoise, 31 | TwinCriticDeterministicQLearning, TwinCriticSoftQLearning, VRegression, 32 | ConjugateGradient] 33 | -------------------------------------------------------------------------------- /tonic/torch/updaters/critics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tonic.torch import models, updaters # noqa 4 | 5 | 6 | class VRegression: 7 | def __init__(self, loss=None, optimizer=None, gradient_clip=0): 8 | self.loss = loss or torch.nn.MSELoss() 9 | self.optimizer = optimizer or ( 10 | lambda params: torch.optim.Adam(params, lr=1e-3)) 11 | self.gradient_clip = gradient_clip 12 | 13 | def initialize(self, model): 14 | self.model = model 15 | self.variables = models.trainable_variables(self.model.critic) 16 | self.optimizer = self.optimizer(self.variables) 17 | 18 | def __call__(self, observations, returns): 19 | self.optimizer.zero_grad() 20 | values = self.model.critic(observations) 21 | loss = self.loss(values, returns) 22 | 23 | loss.backward() 24 | if self.gradient_clip > 0: 25 | 
torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip) 26 | self.optimizer.step() 27 | 28 | return dict(loss=loss.detach(), v=values.detach()) 29 | 30 | 31 | class QRegression: 32 | def __init__(self, loss=None, optimizer=None, gradient_clip=0): 33 | self.loss = loss or torch.nn.MSELoss() 34 | self.optimizer = optimizer or ( 35 | lambda params: torch.optim.Adam(params, lr=1e-3)) 36 | self.gradient_clip = gradient_clip 37 | 38 | def initialize(self, model): 39 | self.model = model 40 | self.variables = models.trainable_variables(self.model.critic) 41 | self.optimizer = self.optimizer(self.variables) 42 | 43 | def __call__(self, observations, actions, returns): 44 | self.optimizer.zero_grad() 45 | values = self.model.critic(observations, actions) 46 | loss = self.loss(values, returns) 47 | 48 | loss.backward() 49 | if self.gradient_clip > 0: 50 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip) 51 | self.optimizer.step() 52 | 53 | return dict(loss=loss.detach(), q=values.detach()) 54 | 55 | 56 | class DeterministicQLearning: 57 | def __init__(self, loss=None, optimizer=None, gradient_clip=0): 58 | self.loss = loss or torch.nn.MSELoss() 59 | self.optimizer = optimizer or ( 60 | lambda params: torch.optim.Adam(params, lr=1e-3)) 61 | self.gradient_clip = gradient_clip 62 | 63 | def initialize(self, model): 64 | self.model = model 65 | self.variables = models.trainable_variables(self.model.critic) 66 | self.optimizer = self.optimizer(self.variables) 67 | 68 | def __call__( 69 | self, observations, actions, next_observations, rewards, discounts 70 | ): 71 | with torch.no_grad(): 72 | next_actions = self.model.target_actor(next_observations) 73 | next_values = self.model.target_critic( 74 | next_observations, next_actions) 75 | returns = rewards + discounts * next_values 76 | 77 | self.optimizer.zero_grad() 78 | values = self.model.critic(observations, actions) 79 | loss = self.loss(values, returns) 80 | 81 | loss.backward() 82 | if self.gradient_clip > 0: 83 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip) 84 | self.optimizer.step() 85 | 86 | return dict(loss=loss.detach(), q=values.detach()) 87 | 88 | 89 | class DistributionalDeterministicQLearning: 90 | def __init__(self, optimizer=None, gradient_clip=0): 91 | self.optimizer = optimizer or ( 92 | lambda params: torch.optim.Adam(params, lr=1e-3)) 93 | self.gradient_clip = gradient_clip 94 | 95 | def initialize(self, model): 96 | self.model = model 97 | self.variables = models.trainable_variables(self.model.critic) 98 | self.optimizer = self.optimizer(self.variables) 99 | 100 | def __call__( 101 | self, observations, actions, next_observations, rewards, discounts 102 | ): 103 | with torch.no_grad(): 104 | next_actions = self.model.target_actor(next_observations) 105 | next_value_distributions = self.model.target_critic( 106 | next_observations, next_actions) 107 | values = next_value_distributions.values 108 | returns = rewards[:, None] + discounts[:, None] * values 109 | targets = next_value_distributions.project(returns) 110 | 111 | self.optimizer.zero_grad() 112 | value_distributions = self.model.critic(observations, actions) 113 | log_probabilities = torch.nn.functional.log_softmax( 114 | value_distributions.logits, dim=-1) 115 | loss = -(targets * log_probabilities).sum(dim=-1).mean() 116 | 117 | loss.backward() 118 | if self.gradient_clip > 0: 119 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip) 120 | self.optimizer.step() 121 | 122 | return dict(loss=loss.detach()) 123 | 
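# Target policy smoothing, as used by TD3: clipped Gaussian noise is added to
# the target actor's actions before the target critics evaluate them, and the
# perturbed actions are clipped back to the [-1, 1] action range.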
124 | 125 | class TargetActionNoise: 126 | def __init__(self, scale=0.2, clip=0.5): 127 | self.scale = scale 128 | self.clip = clip 129 | 130 | def __call__(self, actions): 131 | noises = self.scale * torch.randn_like(actions) 132 | noises = torch.clamp(noises, -self.clip, self.clip) 133 | actions = actions + noises 134 | return torch.clamp(actions, -1, 1) 135 | 136 | 137 | class TwinCriticDeterministicQLearning: 138 | def __init__( 139 | self, loss=None, optimizer=None, target_action_noise=None, 140 | gradient_clip=0 141 | ): 142 | self.loss = loss or torch.nn.MSELoss() 143 | self.optimizer = optimizer or ( 144 | lambda params: torch.optim.Adam(params, lr=1e-3)) 145 | self.target_action_noise = target_action_noise or \ 146 | TargetActionNoise(scale=0.2, clip=0.5) 147 | self.gradient_clip = gradient_clip 148 | 149 | def initialize(self, model): 150 | self.model = model 151 | variables_1 = models.trainable_variables(self.model.critic_1) 152 | variables_2 = models.trainable_variables(self.model.critic_2) 153 | self.variables = variables_1 + variables_2 154 | self.optimizer = self.optimizer(self.variables) 155 | 156 | def __call__( 157 | self, observations, actions, next_observations, rewards, discounts 158 | ): 159 | with torch.no_grad(): 160 | next_actions = self.model.target_actor(next_observations) 161 | next_actions = self.target_action_noise(next_actions) 162 | next_values_1 = self.model.target_critic_1( 163 | next_observations, next_actions) 164 | next_values_2 = self.model.target_critic_2( 165 | next_observations, next_actions) 166 | next_values = torch.min(next_values_1, next_values_2) 167 | returns = rewards + discounts * next_values 168 | 169 | self.optimizer.zero_grad() 170 | values_1 = self.model.critic_1(observations, actions) 171 | values_2 = self.model.critic_2(observations, actions) 172 | loss_1 = self.loss(values_1, returns) 173 | loss_2 = self.loss(values_2, returns) 174 | loss = loss_1 + loss_2 175 | 176 | loss.backward() 177 | if self.gradient_clip > 0: 178 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip) 179 | self.optimizer.step() 180 | 181 | return dict( 182 | loss=loss.detach(), q1=values_1.detach(), q2=values_2.detach()) 183 | 184 | 185 | class TwinCriticSoftQLearning: 186 | def __init__( 187 | self, loss=None, optimizer=None, entropy_coeff=0.2, gradient_clip=0 188 | ): 189 | self.loss = loss or torch.nn.MSELoss() 190 | self.optimizer = optimizer or ( 191 | lambda params: torch.optim.Adam(params, lr=3e-4)) 192 | self.entropy_coeff = entropy_coeff 193 | self.gradient_clip = gradient_clip 194 | 195 | def initialize(self, model): 196 | self.model = model 197 | variables_1 = models.trainable_variables(self.model.critic_1) 198 | variables_2 = models.trainable_variables(self.model.critic_2) 199 | self.variables = variables_1 + variables_2 200 | self.optimizer = self.optimizer(self.variables) 201 | 202 | def __call__( 203 | self, observations, actions, next_observations, rewards, discounts 204 | ): 205 | with torch.no_grad(): 206 | next_distributions = self.model.actor(next_observations) 207 | if hasattr(next_distributions, 'rsample_with_log_prob'): 208 | outs = next_distributions.rsample_with_log_prob() 209 | next_actions, next_log_probs = outs 210 | else: 211 | next_actions = next_distributions.rsample() 212 | next_log_probs = next_distributions.log_prob(next_actions) 213 | next_log_probs = next_log_probs.sum(dim=-1) 214 | next_values_1 = self.model.target_critic_1( 215 | next_observations, next_actions) 216 | next_values_2 = 
self.model.target_critic_2( 217 | next_observations, next_actions) 218 | next_values = torch.min(next_values_1, next_values_2) 219 | returns = rewards + discounts * ( 220 | next_values - self.entropy_coeff * next_log_probs) 221 | 222 | self.optimizer.zero_grad() 223 | values_1 = self.model.critic_1(observations, actions) 224 | values_2 = self.model.critic_2(observations, actions) 225 | loss_1 = self.loss(values_1, returns) 226 | loss_2 = self.loss(values_2, returns) 227 | loss = loss_1 + loss_2 228 | 229 | loss.backward() 230 | if self.gradient_clip > 0: 231 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip) 232 | self.optimizer.step() 233 | 234 | return dict( 235 | loss=loss.detach(), q1=values_1.detach(), q2=values_2.detach()) 236 | 237 | 238 | class ExpectedSARSA: 239 | def __init__( 240 | self, num_samples=20, loss=None, optimizer=None, gradient_clip=0 241 | ): 242 | self.num_samples = num_samples 243 | self.loss = loss or torch.nn.MSELoss() 244 | self.optimizer = optimizer or ( 245 | lambda params: torch.optim.Adam(params, lr=3e-4)) 246 | self.gradient_clip = gradient_clip 247 | 248 | def initialize(self, model): 249 | self.model = model 250 | self.variables = models.trainable_variables(self.model.critic) 251 | self.optimizer = self.optimizer(self.variables) 252 | 253 | def __call__( 254 | self, observations, actions, next_observations, rewards, discounts 255 | ): 256 | # Approximate the expected next values. 257 | with torch.no_grad(): 258 | next_target_distributions = self.model.target_actor( 259 | next_observations) 260 | next_actions = next_target_distributions.rsample( 261 | (self.num_samples,)) 262 | next_actions = updaters.merge_first_two_dims(next_actions) 263 | next_observations = updaters.tile( 264 | next_observations, self.num_samples) 265 | next_observations = updaters.merge_first_two_dims( 266 | next_observations) 267 | next_values = self.model.target_critic( 268 | next_observations, next_actions) 269 | next_values = next_values.view(self.num_samples, -1) 270 | next_values = next_values.mean(dim=0) 271 | returns = rewards + discounts * next_values 272 | 273 | self.optimizer.zero_grad() 274 | values = self.model.critic(observations, actions) 275 | loss = self.loss(returns, values) 276 | 277 | loss.backward() 278 | if self.gradient_clip > 0: 279 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip) 280 | self.optimizer.step() 281 | 282 | return dict(loss=loss.detach(), q=values.detach()) 283 | -------------------------------------------------------------------------------- /tonic/torch/updaters/optimizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | FLOAT_EPSILON = 1e-8 6 | 7 | 8 | def flat_concat(xs): 9 | return torch.cat([torch.reshape(x, (-1,)) for x in xs], dim=0) 10 | 11 | 12 | def assign_params_from_flat(new_params, params): 13 | def flat_size(p): 14 | return int(np.prod(p.shape)) 15 | splits = torch.split(new_params, [flat_size(p) for p in params]) 16 | new_params = [torch.reshape(p_new, p.shape) 17 | for p, p_new in zip(params, splits)] 18 | for p, p_new in zip(params, new_params): 19 | p.data.copy_(p_new) 20 | 21 | 22 | class ConjugateGradient: 23 | def __init__( 24 | self, conjugate_gradient_steps=10, damping_coefficient=0.1, 25 | constraint_threshold=0.01, backtrack_steps=10, 26 | backtrack_coefficient=0.8 27 | ): 28 | self.conjugate_gradient_steps = conjugate_gradient_steps 29 | self.damping_coefficient = damping_coefficient 30 | 
self.constraint_threshold = constraint_threshold 31 | self.backtrack_steps = backtrack_steps 32 | self.backtrack_coefficient = backtrack_coefficient 33 | 34 | def optimize(self, loss_function, constraint_function, variables): 35 | def _hx(x): 36 | f = constraint_function() 37 | gradient_1 = torch.autograd.grad(f, variables, create_graph=True) 38 | gradient_1 = flat_concat(gradient_1) 39 | x = torch.as_tensor(x) 40 | y = (gradient_1 * x).sum() 41 | gradient_2 = torch.autograd.grad(y, variables) 42 | gradient_2 = flat_concat(gradient_2) 43 | 44 | if self.damping_coefficient > 0: 45 | gradient_2 += self.damping_coefficient * x 46 | 47 | return gradient_2 48 | 49 | def _cg(b): 50 | x = np.zeros_like(b) 51 | r = b.copy() 52 | p = r.copy() 53 | r_dot_old = np.dot(r, r) 54 | if r_dot_old == 0: 55 | return None 56 | 57 | for _ in range(self.conjugate_gradient_steps): 58 | z = _hx(p).numpy() 59 | alpha = r_dot_old / (np.dot(p, z) + FLOAT_EPSILON) 60 | x += alpha * p 61 | r -= alpha * z 62 | r_dot_new = np.dot(r, r) 63 | p = r + (r_dot_new / r_dot_old) * p 64 | r_dot_old = r_dot_new 65 | return x 66 | 67 | def _update(alpha, conjugate_gradient, step, start_variables): 68 | conjugate_gradient = torch.as_tensor(conjugate_gradient) 69 | new_variables = start_variables - alpha * conjugate_gradient * step 70 | assign_params_from_flat(new_variables, variables) 71 | constraint = constraint_function() 72 | loss = loss_function() 73 | return constraint.detach(), loss.detach() 74 | 75 | start_variables = flat_concat(variables) 76 | 77 | for var in variables: 78 | if var.grad: 79 | var.grad.data.zero_() 80 | 81 | loss = loss_function() 82 | grad = torch.autograd.grad(loss, variables) 83 | grad = flat_concat(grad).numpy() 84 | start_loss = loss.detach().numpy() 85 | 86 | conjugate_gradient = _cg(grad) 87 | if conjugate_gradient is None: 88 | constraint = torch.as_tensor(0., dtype=torch.float32) 89 | loss = torch.as_tensor(0., dtype=torch.float32) 90 | steps = torch.as_tensor(0, dtype=torch.int32) 91 | return constraint, loss, steps 92 | 93 | alpha = np.sqrt(2 * self.constraint_threshold / np.dot( 94 | conjugate_gradient, _hx(conjugate_gradient)) + FLOAT_EPSILON) 95 | 96 | if self.backtrack_steps is None or self.backtrack_coefficient is None: 97 | constraint, loss = _update( 98 | alpha, conjugate_gradient, 1, start_variables) 99 | return constraint, loss 100 | 101 | for i in range(self.backtrack_steps): 102 | constraint, loss = _update( 103 | alpha, conjugate_gradient, self.backtrack_coefficient ** i, 104 | start_variables) 105 | 106 | if (constraint.numpy() <= self.constraint_threshold and 107 | loss.numpy() <= start_loss): 108 | break 109 | 110 | if i == self.backtrack_steps - 1: 111 | constraint, loss = _update( 112 | alpha, conjugate_gradient, 0, start_variables) 113 | i = self.backtrack_steps 114 | 115 | return constraint, loss, torch.as_tensor(i + 1, dtype=torch.int32) 116 | -------------------------------------------------------------------------------- /tonic/torch/updaters/utils.py: -------------------------------------------------------------------------------- 1 | def tile(x, n): 2 | return x[None].repeat([n] + [1] * len(x.shape)) 3 | 4 | 5 | def merge_first_two_dims(x): 6 | return x.view(x.shape[0] * x.shape[1], *x.shape[2:]) 7 | -------------------------------------------------------------------------------- /tonic/train.py: -------------------------------------------------------------------------------- 1 | '''Script used to train agents.''' 2 | 3 | import argparse 4 | import os 5 | 6 | import 
tonic 7 | import yaml 8 | 9 | 10 | def train( 11 | header, agent, environment, test_environment, trainer, before_training, 12 | after_training, parallel, sequential, seed, name, environment_name, 13 | checkpoint, path 14 | ): 15 | '''Trains an agent on an environment.''' 16 | 17 | # Capture the arguments to save them, e.g. to play with the trained agent. 18 | args = dict(locals()) 19 | 20 | checkpoint_path = None 21 | 22 | # Process the checkpoint path the same way as in tonic.play 23 | if path: 24 | tonic.logger.log(f'Loading experiment from {path}') 25 | 26 | # Use no checkpoint; the agent is freshly created. 27 | if checkpoint == 'none' or agent is not None: 28 | tonic.logger.log('Not loading any weights') 29 | 30 | else: 31 | checkpoint_path = os.path.join(path, 'checkpoints') 32 | if not os.path.isdir(checkpoint_path): 33 | tonic.logger.error(f'{checkpoint_path} is not a directory') 34 | checkpoint_path = None 35 | 36 | # List all the checkpoints. 37 | checkpoint_ids = [] 38 | for file in os.listdir(checkpoint_path): 39 | if file[:5] == 'step_': 40 | checkpoint_id = file.split('.')[0] 41 | checkpoint_ids.append(int(checkpoint_id[5:])) 42 | 43 | if checkpoint_ids: 44 | # Use the last checkpoint. 45 | if checkpoint == 'last': 46 | checkpoint_id = max(checkpoint_ids) 47 | checkpoint_path = os.path.join( 48 | checkpoint_path, f'step_{checkpoint_id}') 49 | 50 | # Use the specified checkpoint. 51 | else: 52 | checkpoint_id = int(checkpoint) 53 | if checkpoint_id in checkpoint_ids: 54 | checkpoint_path = os.path.join( 55 | checkpoint_path, f'step_{checkpoint_id}') 56 | else: 57 | tonic.logger.error(f'Checkpoint {checkpoint_id} ' 58 | f'not found in {checkpoint_path}') 59 | checkpoint_path = None 60 | 61 | else: 62 | tonic.logger.error(f'No checkpoint found in {checkpoint_path}') 63 | checkpoint_path = None 64 | 65 | # Load the experiment configuration. 66 | arguments_path = os.path.join(path, 'config.yaml') 67 | with open(arguments_path, 'r') as config_file: 68 | config = yaml.load(config_file, Loader=yaml.FullLoader) 69 | config = argparse.Namespace(**config) 70 | 71 | header = header or config.header 72 | agent = agent or config.agent 73 | environment = environment or config.test_environment 74 | environment = environment or config.environment 75 | trainer = trainer or config.trainer 76 | 77 | # Run the header first, e.g. to load an ML framework. 78 | if header: 79 | exec(header) 80 | 81 | # Build the training environment. 82 | _environment = environment 83 | environment = tonic.environments.distribute( 84 | lambda: eval(_environment), parallel, sequential) 85 | environment.initialize(seed=seed) 86 | 87 | # Build the testing environment. 88 | _test_environment = test_environment if test_environment else _environment 89 | test_environment = tonic.environments.distribute( 90 | lambda: eval(_test_environment)) 91 | test_environment.initialize(seed=seed + 10000) 92 | 93 | # Build the agent. 94 | if not agent: 95 | raise ValueError('No agent specified.') 96 | agent = eval(agent) 97 | agent.initialize( 98 | observation_space=environment.observation_space, 99 | action_space=environment.action_space, seed=seed) 100 | 101 | # Load the weights of the agent from a checkpoint. 102 | if checkpoint_path: 103 | agent.load(checkpoint_path) 104 | 105 | # Initialize the logger to save data to the path environment/name/seed.
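# For example, a default run might log under something like
# 'AntBulletEnv-v0/PPO/0', with a '-{parallel}x{sequential}' suffix appended
# to the agent name for distributed runs.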
106 | if not environment_name: 107 | if hasattr(test_environment, 'name'): 108 | environment_name = test_environment.name 109 | else: 110 | environment_name = test_environment.__class__.__name__ 111 | if not name: 112 | if hasattr(agent, 'name'): 113 | name = agent.name 114 | else: 115 | name = agent.__class__.__name__ 116 | if parallel != 1 or sequential != 1: 117 | name += f'-{parallel}x{sequential}' 118 | path = os.path.join(environment_name, name, str(seed)) 119 | tonic.logger.initialize(path, script_path=__file__, config=args) 120 | 121 | # Build the trainer. 122 | trainer = trainer or 'tonic.Trainer()' 123 | trainer = eval(trainer) 124 | trainer.initialize( 125 | agent=agent, environment=environment, 126 | test_environment=test_environment) 127 | 128 | # Run some code before training. 129 | if before_training: 130 | exec(before_training) 131 | 132 | # Train. 133 | trainer.run() 134 | 135 | # Run some code after training. 136 | if after_training: 137 | exec(after_training) 138 | 139 | 140 | if __name__ == '__main__': 141 | # Argument parsing. 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--header') 144 | parser.add_argument('--agent') 145 | parser.add_argument('--environment', '--env') 146 | parser.add_argument('--test_environment', '--test_env') 147 | parser.add_argument('--trainer') 148 | parser.add_argument('--before_training') 149 | parser.add_argument('--after_training') 150 | parser.add_argument('--parallel', type=int, default=1) 151 | parser.add_argument('--sequential', type=int, default=1) 152 | parser.add_argument('--seed', type=int, default=0) 153 | parser.add_argument('--name') 154 | parser.add_argument('--environment_name') 155 | parser.add_argument('--checkpoint', default='last') 156 | parser.add_argument('--path') 157 | 158 | args = vars(parser.parse_args()) 159 | train(**args) 160 | -------------------------------------------------------------------------------- /tonic/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import termcolor 7 | import yaml 8 | 9 | 10 | current_logger = None 11 | 12 | 13 | class Logger: 14 | '''Logger used to display and save logs, and save experiment configs.''' 15 | 16 | def __init__(self, path=None, width=60, script_path=None, config=None): 17 | self.path = path or str(time.time()) 18 | self.log_file_path = os.path.join(self.path, 'log.csv') 19 | 20 | # Save the launch script. 21 | if script_path: 22 | with open(script_path, 'r') as script_file: 23 | script = script_file.read() 24 | try: 25 | os.makedirs(self.path, exist_ok=True) 26 | except Exception: 27 | pass 28 | script_path = os.path.join(self.path, 'script.py') 29 | with open(script_path, 'w') as config_file: 30 | config_file.write(script) 31 | log(f'Script file saved to {script_path}') 32 | 33 | # Save the configuration. 
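# The config passed here is typically the dictionary of launch arguments
# captured in tonic.train; saving it as config.yaml is what later allows a
# run launched with --path to rebuild the same experiment.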
34 | if config: 35 | try: 36 | os.makedirs(self.path, exist_ok=True) 37 | except Exception: 38 | pass 39 | config_path = os.path.join(self.path, 'config.yaml') 40 | with open(config_path, 'w') as config_file: 41 | yaml.dump(config, config_file) 42 | log(f'Config file saved to {config_path}') 43 | 44 | self.known_keys = set() 45 | self.stat_keys = set() 46 | self.epoch_dict = {} 47 | self.width = width 48 | self.last_epoch_progress = None 49 | self.start_time = time.time() 50 | 51 | def store(self, key, value, stats=False): 52 | '''Keeps named values during an epoch.''' 53 | 54 | if key not in self.epoch_dict: 55 | self.epoch_dict[key] = [value] 56 | if stats: 57 | self.stat_keys.add(key) 58 | else: 59 | self.epoch_dict[key].append(value) 60 | 61 | def dump(self): 62 | '''Displays and saves the values at the end of an epoch.''' 63 | 64 | # Compute statistics if needed. 65 | keys = list(self.epoch_dict.keys()) 66 | for key in keys: 67 | values = self.epoch_dict[key] 68 | if key in self.stat_keys: 69 | self.epoch_dict[key + '/mean'] = np.mean(values) 70 | self.epoch_dict[key + '/std'] = np.std(values) 71 | self.epoch_dict[key + '/min'] = np.min(values) 72 | self.epoch_dict[key + '/max'] = np.max(values) 73 | self.epoch_dict[key + '/size'] = len(values) 74 | del self.epoch_dict[key] 75 | else: 76 | self.epoch_dict[key] = np.mean(values) 77 | 78 | # Check if new keys were added. 79 | new_keys = [key for key in self.epoch_dict.keys() 80 | if key not in self.known_keys] 81 | if new_keys: 82 | first_row = len(self.known_keys) == 0 83 | if not first_row: 84 | print() 85 | warning(f'Logging new keys {new_keys}') 86 | # List the keys and prepare the display layout. 87 | for key in new_keys: 88 | self.known_keys.add(key) 89 | self.final_keys = list(sorted(self.known_keys)) 90 | self.console_formats = [] 91 | known_keys = set() 92 | for key in self.final_keys: 93 | *left_keys, right_key = key.split('/') 94 | for i, k in enumerate(left_keys): 95 | left_key = '/'.join(left_keys[:i + 1]) 96 | if left_key not in known_keys: 97 | left = ' ' * i + k.replace('_', ' ') 98 | self.console_formats.append((left, None)) 99 | known_keys.add(left_key) 100 | indent = ' ' * len(left_keys) 101 | right_key = right_key.replace('_', ' ') 102 | self.console_formats.append((indent + right_key, key)) 103 | 104 | # Display the values following the layout. 
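# Keys such as 'test/episode_score/mean' are grouped by their '/' prefixes:
# each prefix is printed once as an indented header and each value is
# right-aligned to the logger width.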
105 | print() 106 | for left, key in self.console_formats: 107 | if key: 108 | val = self.epoch_dict.get(key) 109 | str_type = str(type(val)) 110 | if 'tensorflow' in str_type: 111 | warning(f'Logging TensorFlow tensor {key}') 112 | elif 'torch' in str_type: 113 | warning(f'Logging Torch tensor {key}') 114 | if np.issubdtype(type(val), np.floating): 115 | right = f'{val:8.3g}' 116 | elif np.issubdtype(type(val), np.integer): 117 | right = f'{val:,}' 118 | else: 119 | right = str(val) 120 | spaces = ' ' * (self.width - len(left) - len(right)) 121 | print(left + spaces + right) 122 | else: 123 | spaces = ' ' * (self.width - len(left)) 124 | print(left + spaces) 125 | print() 126 | 127 | # Save the data to the log file. 128 | vals = [self.epoch_dict.get(key) for key in self.final_keys] 129 | if new_keys: 130 | if first_row: 131 | log(f'Logging data to {self.log_file_path}') 132 | try: 133 | os.makedirs(self.path, exist_ok=True) 134 | except Exception: 135 | pass 136 | with open(self.log_file_path, 'w') as file: 137 | file.write(','.join(self.final_keys) + '\n') 138 | file.write(','.join(map(str, vals)) + '\n') 139 | else: 140 | with open(self.log_file_path, 'r') as file: 141 | lines = file.read().splitlines() 142 | old_keys = lines[0].split(',') 143 | old_lines = [line.split(',') for line in lines[1:]] 144 | new_indices = [] 145 | j = 0 146 | for i, key in enumerate(self.final_keys): 147 | if j < len(old_keys) and key == old_keys[j]: 148 | j += 1 149 | else: 150 | new_indices.append(i) 151 | assert j == len(old_keys) 152 | for line in old_lines: 153 | for i in new_indices: 154 | line.insert(i, 'None') 155 | with open(self.log_file_path, 'w') as file: 156 | file.write(','.join(self.final_keys) + '\n') 157 | for line in old_lines: 158 | file.write(','.join(line) + '\n') 159 | file.write(','.join(map(str, vals)) + '\n') 160 | else: 161 | with open(self.log_file_path, 'a') as file: 162 | file.write(','.join(map(str, vals)) + '\n') 163 | 164 | self.epoch_dict.clear() 165 | self.last_epoch_progress = None 166 | self.last_epoch_time = time.time() 167 | 168 | def show_progress( 169 | self, steps, num_epoch_steps, num_steps, color='white', 170 | on_color='on_blue' 171 | ): 172 | '''Shows a progress bar for the current epoch and total training.''' 173 | 174 | epoch_steps = (steps - 1) % num_epoch_steps + 1 175 | epoch_progress = int(self.width * epoch_steps / num_epoch_steps) 176 | if epoch_progress != self.last_epoch_progress: 177 | current_time = time.time() 178 | seconds = current_time - self.start_time 179 | seconds_per_step = seconds / steps 180 | epoch_rem_steps = num_epoch_steps - epoch_steps 181 | epoch_rem_secs = max(epoch_rem_steps * seconds_per_step, 0) 182 | epoch_rem_secs = datetime.timedelta(seconds=epoch_rem_secs + 1e-6) 183 | epoch_rem_secs = str(epoch_rem_secs)[:-7] 184 | total_rem_steps = num_steps - steps 185 | total_rem_secs = max(total_rem_steps * seconds_per_step, 0) 186 | total_rem_secs = datetime.timedelta(seconds=total_rem_secs) 187 | total_rem_secs = str(total_rem_secs)[:-7] 188 | msg = f'Time left: epoch {epoch_rem_secs} total {total_rem_secs}' 189 | msg = msg.center(self.width) 190 | print(termcolor.colored( 191 | '\r' + msg[:epoch_progress], color, on_color), end='') 192 | print(msg[epoch_progress:], sep='', end='') 193 | self.last_epoch_progress = epoch_progress 194 | 195 | 196 | def initialize(*args, **kwargs): 197 | global current_logger 198 | current_logger = Logger(*args, **kwargs) 199 | return current_logger 200 | 201 | 202 | def get_current_logger(): 203 | global current_logger 204 | if 
current_logger is None: 205 | current_logger = Logger() 206 | return current_logger 207 | 208 | 209 | def store(*args, **kwargs): 210 | logger = get_current_logger() 211 | return logger.store(*args, **kwargs) 212 | 213 | 214 | def dump(*args, **kwargs): 215 | logger = get_current_logger() 216 | return logger.dump(*args, **kwargs) 217 | 218 | 219 | def show_progress(*args, **kwargs): 220 | logger = get_current_logger() 221 | return logger.show_progress(*args, **kwargs) 222 | 223 | 224 | def get_path(): 225 | logger = get_current_logger() 226 | return logger.path 227 | 228 | 229 | def log(msg, color='green'): 230 | print(termcolor.colored(msg, color, attrs=['bold'])) 231 | 232 | 233 | def warning(msg, color='yellow'): 234 | print(termcolor.colored('Warning: ' + msg, color, attrs=['bold'])) 235 | 236 | 237 | def error(msg, color='red'): 238 | print(termcolor.colored('Error: ' + msg, color, attrs=['bold'])) 239 | -------------------------------------------------------------------------------- /tonic/utils/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | 6 | from tonic import logger 7 | 8 | 9 | class Trainer: 10 | '''Trainer used to train and evaluate an agent on an environment.''' 11 | 12 | def __init__( 13 | self, steps=int(1e7), epoch_steps=int(2e4), save_steps=int(5e5), 14 | test_episodes=5, show_progress=True, replace_checkpoint=False, 15 | ): 16 | self.max_steps = steps 17 | self.epoch_steps = epoch_steps 18 | self.save_steps = save_steps 19 | self.test_episodes = test_episodes 20 | self.show_progress = show_progress 21 | self.replace_checkpoint = replace_checkpoint 22 | 23 | def initialize(self, agent, environment, test_environment=None): 24 | self.agent = agent 25 | self.environment = environment 26 | self.test_environment = test_environment 27 | 28 | def run(self): 29 | '''Runs the main training loop.''' 30 | 31 | start_time = last_epoch_time = time.time() 32 | 33 | # Start the environments. 34 | observations = self.environment.start() 35 | 36 | num_workers = len(observations) 37 | scores = np.zeros(num_workers) 38 | lengths = np.zeros(num_workers, int) 39 | self.steps, epoch_steps, epochs, episodes = 0, 0, 0, 0 40 | steps_since_save = 0 41 | 42 | while True: 43 | # Select actions. 44 | actions = self.agent.step(observations, self.steps) 45 | assert not np.isnan(actions.sum()) 46 | logger.store('train/action', actions, stats=True) 47 | 48 | # Take a step in the environments. 49 | observations, infos = self.environment.step(actions) 50 | self.agent.update(**infos, steps=self.steps) 51 | 52 | scores += infos['rewards'] 53 | lengths += 1 54 | self.steps += num_workers 55 | epoch_steps += num_workers 56 | steps_since_save += num_workers 57 | 58 | # Show the progress bar. 59 | if self.show_progress: 60 | logger.show_progress( 61 | self.steps, self.epoch_steps, self.max_steps) 62 | 63 | # Check the finished episodes. 64 | for i in range(num_workers): 65 | if infos['resets'][i]: 66 | logger.store('train/episode_score', scores[i], stats=True) 67 | logger.store( 68 | 'train/episode_length', lengths[i], stats=True) 69 | scores[i] = 0 70 | lengths[i] = 0 71 | episodes += 1 72 | 73 | # End of the epoch. 74 | if epoch_steps >= self.epoch_steps: 75 | # Evaluate the agent on the test environment. 76 | if self.test_environment: 77 | self._test() 78 | 79 | # Log the data. 
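# Store the epoch statistics; logger.dump() below then prints the formatted
# table and appends a row to log.csv in the experiment directory.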
80 | epochs += 1 81 | current_time = time.time() 82 | epoch_time = current_time - last_epoch_time 83 | sps = epoch_steps / epoch_time 84 | logger.store('train/episodes', episodes) 85 | logger.store('train/epochs', epochs) 86 | logger.store('train/seconds', current_time - start_time) 87 | logger.store('train/epoch_seconds', epoch_time) 88 | logger.store('train/epoch_steps', epoch_steps) 89 | logger.store('train/steps', self.steps) 90 | logger.store('train/worker_steps', self.steps // num_workers) 91 | logger.store('train/steps_per_second', sps) 92 | logger.dump() 93 | last_epoch_time = time.time() 94 | epoch_steps = 0 95 | 96 | # End of training. 97 | stop_training = self.steps >= self.max_steps 98 | 99 | # Save a checkpoint. 100 | if stop_training or steps_since_save >= self.save_steps: 101 | path = os.path.join(logger.get_path(), 'checkpoints') 102 | if os.path.isdir(path) and self.replace_checkpoint: 103 | for file in os.listdir(path): 104 | if file.startswith('step_'): 105 | os.remove(os.path.join(path, file)) 106 | checkpoint_name = f'step_{self.steps}' 107 | save_path = os.path.join(path, checkpoint_name) 108 | self.agent.save(save_path) 109 | steps_since_save = self.steps % self.save_steps 110 | 111 | if stop_training: 112 | break 113 | 114 | def _test(self): 115 | '''Tests the agent on the test environment.''' 116 | 117 | # Start the environment. 118 | if not hasattr(self, 'test_observations'): 119 | self.test_observations = self.test_environment.start() 120 | assert len(self.test_observations) == 1 121 | 122 | # Test loop. 123 | for _ in range(self.test_episodes): 124 | score, length = 0, 0 125 | 126 | while True: 127 | # Select an action. 128 | actions = self.agent.test_step( 129 | self.test_observations, self.steps) 130 | assert not np.isnan(actions.sum()) 131 | logger.store('test/action', actions, stats=True) 132 | 133 | # Take a step in the environment. 134 | self.test_observations, infos = self.test_environment.step( 135 | actions) 136 | self.agent.test_update(**infos, steps=self.steps) 137 | 138 | score += infos['rewards'][0] 139 | length += 1 140 | 141 | if infos['resets'][0]: 142 | break 143 | 144 | # Log the data. 145 | logger.store('test/episode_score', score, stats=True) 146 | logger.store('test/episode_length', length, stats=True) 147 | --------------------------------------------------------------------------------
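As a closing illustration, the pieces above (`tonic.train`, the logger and the `Trainer`) can also be wired together directly from Python rather than through the command line. The following sketch is hypothetical and simply mirrors what `tonic/train.py` does; it assumes the `Gym` environment builder and the `PPO` agent exposed by `tonic.environments` and `tonic.torch.agents`, and uses a deliberately small step budget.
```python
import os

import tonic
import tonic.torch  # makes the PyTorch agents available as tonic.torch.agents

seed = 0

# Build single-worker training and test environments, as tonic/train.py does.
environment = tonic.environments.distribute(
    lambda: tonic.environments.Gym('BipedalWalker-v3'))
environment.initialize(seed=seed)
test_environment = tonic.environments.distribute(
    lambda: tonic.environments.Gym('BipedalWalker-v3'))
test_environment.initialize(seed=seed + 10000)

# Build the agent from the spaces of the training environment.
agent = tonic.torch.agents.PPO()
agent.initialize(
    observation_space=environment.observation_space,
    action_space=environment.action_space, seed=seed)

# Point the logger at an explicit experiment directory.
tonic.logger.initialize(
    os.path.join('BipedalWalker-v3', 'PPO-sketch', str(seed)))

# Train with a reduced step budget.
trainer = tonic.Trainer(steps=int(1e5))
trainer.initialize(
    agent=agent, environment=environment, test_environment=test_environment)
trainer.run()
```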