├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── images
│   │   ├── humanoid_perturbation.jpg
│   │   ├── logo.png
│   │   └── selection.jpg
│   └── logs
│       └── tensorflow_logs.pkl
├── setup.py
└── tonic
    ├── __init__.py
    ├── agents
    │   ├── __init__.py
    │   ├── agent.py
    │   └── basic.py
    ├── environments
    │   ├── __init__.py
    │   ├── builders.py
    │   ├── distributed.py
    │   └── wrappers.py
    ├── explorations
    │   ├── __init__.py
    │   └── noisy.py
    ├── play.py
    ├── plot.py
    ├── replays
    │   ├── __init__.py
    │   ├── buffers.py
    │   ├── segments.py
    │   └── utils.py
    ├── tensorflow
    │   ├── __init__.py
    │   ├── agents
    │   │   ├── __init__.py
    │   │   ├── a2c.py
    │   │   ├── agent.py
    │   │   ├── d4pg.py
    │   │   ├── ddpg.py
    │   │   ├── mpo.py
    │   │   ├── ppo.py
    │   │   ├── sac.py
    │   │   ├── td3.py
    │   │   ├── td4.py
    │   │   └── trpo.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── actor_critics.py
    │   │   ├── actors.py
    │   │   ├── critics.py
    │   │   ├── encoders.py
    │   │   └── utils.py
    │   ├── normalizers
    │   │   ├── __init__.py
    │   │   ├── mean_stds.py
    │   │   └── returns.py
    │   └── updaters
    │       ├── __init__.py
    │       ├── actors.py
    │       ├── critics.py
    │       ├── optimizers.py
    │       └── utils.py
    ├── torch
    │   ├── __init__.py
    │   ├── agents
    │   │   ├── __init__.py
    │   │   ├── a2c.py
    │   │   ├── agent.py
    │   │   ├── d4pg.py
    │   │   ├── ddpg.py
    │   │   ├── mpo.py
    │   │   ├── ppo.py
    │   │   ├── sac.py
    │   │   ├── td3.py
    │   │   └── trpo.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── actor_critics.py
    │   │   ├── actors.py
    │   │   ├── critics.py
    │   │   ├── encoders.py
    │   │   └── utils.py
    │   ├── normalizers
    │   │   ├── __init__.py
    │   │   ├── mean_stds.py
    │   │   └── returns.py
    │   └── updaters
    │       ├── __init__.py
    │       ├── actors.py
    │       ├── critics.py
    │       ├── optimizers.py
    │       └── utils.py
    ├── train.py
    └── utils
        ├── logger.py
        └── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.py~
3 | .ipynb_checkpoints
4 | *.egg-info
5 | __pycache__/
6 | MUJOCO_LOG.TXT
7 | *DS_Store
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2020 Fabio Pardo (https://github.com/fabiopardo/tonic)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tonic
2 |
3 |
4 | [image: Tonic logo]
5 |
6 |
7 | Welcome to the Tonic RL library!
8 |
9 | Please take a look at [the paper](https://arxiv.org/abs/2011.07537) for details
10 | and results.
11 |
12 | The main design principles are:
13 |
14 | * **Modularity:** Building blocks for creating RL agents, such as models,
15 | replays, or exploration strategies, are implemented as configurable modules.
16 |
17 | * **Readability:** Agents are written in a simple way with an identical API,
18 | and logs are nicely displayed in the terminal with a progress bar.
19 |
20 | * **Fair comparison:** A single training pipeline is compatible with all
21 | Tonic agents and environments. Agents are defined by their core ideas, while
22 | general tricks/improvements like
23 | [non-terminal timeouts](https://arxiv.org/pdf/1712.00378.pdf),
24 | observation normalization and action scaling are shared.
25 |
26 | * **Benchmarking:** Benchmark data for the provided agents trained on
27 | [70 continuous control environments](https://github.com/fabiopardo/tonic_data/blob/master/images/benchmark.pdf)
28 | are available for direct comparison.
29 |
30 | * **Wrapped popular environments:** Environments from
31 | [OpenAI Gym](https://github.com/openai/gym),
32 | [PyBullet](https://github.com/bulletphysics/bullet3) and
33 | [DeepMind Control Suite](https://github.com/deepmind/dm_control) are made
34 | compatible with
35 | [non-terminal timeouts](https://arxiv.org/pdf/1712.00378.pdf) and synchronous
36 | distributed training.
37 |
38 | * **Compatibility with different ML frameworks:** Both TensorFlow 2 and PyTorch
39 | are currently supported. Simply import `tonic.tensorflow` or `tonic.torch`.
40 |
41 | * **Experimenting from the console:** While launch scripts can be used,
42 | snippets of Python code make it possible to iterate over various
43 | configurations directly from a console.
44 |
45 | * **Visualization of trained agents:** Experiment configurations and
46 | checkpoints can be loaded to play with trained agents.
47 |
48 | * **Collection of trained models:** To keep the main Tonic repository light,
49 | the full logs and trained models from the benchmark are available in the
50 | [tonic_data repository](https://github.com/fabiopardo/tonic_data).
51 |
52 | # Instructions
53 |
54 | ## Install from source
55 |
56 | Download and install Tonic:
57 | ```bash
58 | git clone https://github.com/fabiopardo/tonic.git
59 | pip install -e tonic/
60 | ```
61 |
62 | Install TensorFlow or PyTorch, for example using:
63 | ```bash
64 | pip install tensorflow torch
65 | ```
66 |
67 | ## Launch experiments
68 |
69 | Use TensorFlow or PyTorch to train an agent, for example using:
70 | ```bash
71 | python3 -m tonic.train \
72 | --header 'import tonic.torch' \
73 | --agent 'tonic.torch.agents.PPO()' \
74 | --environment 'tonic.environments.Gym("BipedalWalker-v3")' \
75 | --name PPO-X \
76 | --seed 0
77 | ```
78 |
79 | Snippets of Python code are used to directly configure the experiment. This is
80 | a very powerful feature that allows configuring agents and environments with
81 | various arguments or even loading custom modules without adding them to the
82 | library. For example:
83 | ```bash
84 | python3 -m tonic.train \
85 | --header "import sys; sys.path.append('path/to/custom'); from custom import CustomAgent" \
86 | --agent "CustomAgent()" \
87 | --environment "tonic.environments.Bullet('AntBulletEnv-v0')" \
88 | --seed 0
89 | ```
90 |
91 | By default, environments use non-terminal timeouts, which is particularly
92 | important for locomotion tasks. But a time feature can be added to the
93 | observations to keep the MDP Markovian. See the
94 | [Time Limits in RL](https://arxiv.org/pdf/1712.00378.pdf) paper for more
95 | details. For example:
96 | ```bash
97 | python3 -m tonic.train \
98 | --header 'import tonic.tensorflow' \
99 | --agent 'tonic.tensorflow.agents.PPO()' \
100 | --environment 'tonic.environments.Gym("Reacher-v2", terminal_timeouts=True, time_feature=True)' \
101 | --seed 0
102 | ```
103 |
104 | Distributed training can be used to accelerate learning. In Tonic, groups of
105 | sequential workers can be launched in parallel processes, for example with:
106 | ```bash
107 | python3 -m tonic.train \
108 | --header "import tonic.tensorflow" \
109 | --agent "tonic.tensorflow.agents.PPO()" \
110 | --environment "tonic.environments.Gym('HalfCheetah-v3')" \
111 | --parallel 10 --sequential 100 \
112 | --seed 0
113 | ```
114 |
115 | ## Plot results
116 |
117 | During training, the experiment configuration, logs and checkpoints are
118 | saved in `environment/agent/seed/`.
119 |
120 | Results can be plotted with:
121 | ```bash
122 | python3 -m tonic.plot --path BipedalWalker-v3/ --baselines all
123 | ```
124 | Glob-style path patterns like `BipedalWalker-v3/PPO-X/0`,
125 | `BipedalWalker-v3/{PPO*,DDPG*}` or `*Bullet*` can be used to point to different
126 | sets of logs.
127 | The `--baselines` argument can be used to load logs from the benchmark. For
128 | example `--baselines all` uses all agents while `--baselines A2C PPO TRPO` will
129 | use logs from A2C, PPO and TRPO.
130 |
131 | Different headers can be used for the x and y axes. For example, to compare
132 | the gain in wall-clock time from distributed training, replace `--parallel 10`
133 | with `--parallel 5` in the last training example and plot the results with:
134 | ```bash
135 | python3 -m tonic.plot --path HalfCheetah-v3/ --x_axis train/seconds --x_label Seconds
136 | ```
137 |
138 | ## Play with trained models
139 |
140 | After some training time, checkpoints are generated and can be used to play
141 | with the trained agent:
142 | ```bash
143 | python3 -m tonic.play --path BipedalWalker-v3/PPO-X/0
144 | ```
145 |
146 | Environments are rendered using the appropriate framework. For example, when
147 | playing with DeepMind Control Suite environments, policies are loaded in a
148 | `dm_control.viewer` where `Space` starts the interaction, `Backspace` starts a
149 | new episode, `[` and `]` switch cameras, and double-clicking a body part
150 | followed by `Ctrl + mouse clicks` adds perturbations.
151 |
152 |
153 | ## Play with models from tonic_data
154 |
155 | The `tonic_data` repository can be downloaded with:
156 | ```bash
157 | git clone https://github.com/fabiopardo/tonic_data.git
158 | ```
159 |
160 | The best seed for each agent is stored in `environment/agent` and can be
161 | reloaded, for example with:
162 | ```bash
163 | python3 -m tonic.play --path tonic_data/tensorflow/humanoid-stand/TD3
164 | ```
165 |
166 |
167 | [image: humanoid perturbation in the viewer]
168 |
169 |
170 | The full benchmark plots are available
171 | [here](https://github.com/fabiopardo/tonic_data/blob/master/images/benchmark.pdf).
172 |
173 | They can be generated with:
174 | ```bash
175 | python3 -m tonic.plot \
176 | --baselines all \
177 | --backend agg --columns 7 --font_size 17 --legend_font_size 30 --legend_marker_size 20 \
178 | --name benchmark
179 | ```
180 |
181 | Or:
182 | ```bash
183 | python3 -m tonic.plot \
184 | --path tonic_data/tensorflow \
185 | --backend agg --columns 7 --font_size 17 --legend_font_size 30 --legend_marker_size 20 \
186 | --name benchmark
187 | ```
188 |
189 | And a selection can be generated with:
190 | ```bash
191 | python3 -m tonic.plot \
192 | --path tonic_data/tensorflow/{AntBulletEnv-v0,BipedalWalker-v3,finger-turn_hard,fish-swim,HalfCheetah-v3,HopperBulletEnv-v0,Humanoid-v3,quadruped-walk,swimmer-swimmer15,Walker2d-v3} \
193 | --backend agg --columns 5 --font_size 20 --legend_font_size 30 --legend_marker_size 20 \
194 | --name selection
195 | ```
196 |
197 |
198 | [image: selection of benchmark plots]
199 |
200 |
201 | # Credit
202 |
203 | ## Other code bases
204 |
205 | Tonic was inspired by a number of other deep RL code bases. In particular,
206 | [OpenAI Baselines](https://github.com/openai/baselines),
207 | [Spinning Up in Deep RL](https://github.com/openai/spinningup)
208 | and [Acme](https://github.com/deepmind/acme).
209 |
210 | ## Citing Tonic
211 |
212 | If you use Tonic in your research, please cite the [paper](https://arxiv.org/abs/2011.07537):
213 |
214 | ```
215 | @article{pardo2020tonic,
216 | title={Tonic: A Deep Reinforcement Learning Library for Fast Prototyping and Benchmarking},
217 | author={Pardo, Fabio},
218 | journal={arXiv preprint arXiv:2011.07537},
219 | year={2020}
220 | }
221 | ```
222 |
--------------------------------------------------------------------------------
/data/images/humanoid_perturbation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fabiopardo/tonic/0e20c894ee68278ab68322de61bb2c7204a11d5f/data/images/humanoid_perturbation.jpg
--------------------------------------------------------------------------------
/data/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fabiopardo/tonic/0e20c894ee68278ab68322de61bb2c7204a11d5f/data/images/logo.png
--------------------------------------------------------------------------------
/data/images/selection.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fabiopardo/tonic/0e20c894ee68278ab68322de61bb2c7204a11d5f/data/images/selection.jpg
--------------------------------------------------------------------------------
/data/logs/tensorflow_logs.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fabiopardo/tonic/0e20c894ee68278ab68322de61bb2c7204a11d5f/data/logs/tensorflow_logs.pkl
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 |
4 | setuptools.setup(
5 | name='tonic',
6 | description='Tonic RL Library',
7 | url='https://github.com/fabiopardo/tonic',
8 | version='0.3.0',
9 | author='Fabio Pardo',
10 | author_email='f.pardo@imperial.ac.uk',
11 | install_requires=[
12 | 'gym', 'matplotlib', 'numpy', 'pandas', 'pyyaml', 'termcolor'],
13 | license='MIT',
14 | python_requires='>=3.6',
15 | keywords=['tonic', 'deep learning', 'reinforcement learning'])
16 |
--------------------------------------------------------------------------------
/tonic/__init__.py:
--------------------------------------------------------------------------------
1 | from . import agents
2 | from . import environments
3 | from . import explorations
4 | from . import replays
5 | from .utils import logger
6 | from .utils.trainer import Trainer
7 |
8 |
9 | __all__ = [agents, environments, explorations, logger, replays, Trainer]
10 |
--------------------------------------------------------------------------------
/tonic/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from .agent import Agent
2 | from .basic import Constant, NormalRandom, OrnsteinUhlenbeck, UniformRandom
3 |
4 |
5 | __all__ = [Agent, Constant, NormalRandom, OrnsteinUhlenbeck, UniformRandom]
6 |
--------------------------------------------------------------------------------
/tonic/agents/agent.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class Agent(abc.ABC):
5 | '''Abstract class used to build agents.'''
6 |
7 | def initialize(self, observation_space, action_space, seed=None):
8 | pass
9 |
10 | @abc.abstractmethod
11 | def step(self, observations, steps):
12 | '''Returns actions during training.'''
13 | pass
14 |
15 | def update(self, observations, rewards, resets, terminations, steps):
16 | '''Informs the agent of the latest transitions during training.'''
17 | pass
18 |
19 | @abc.abstractmethod
20 | def test_step(self, observations, steps):
21 | '''Returns actions during testing.'''
22 | pass
23 |
24 | def test_update(self, observations, rewards, resets, terminations, steps):
25 | '''Informs the agent of the latest transitions during testing.'''
26 | pass
27 |
28 | def save(self, path):
29 | '''Saves the agent weights during training.'''
30 | pass
31 |
32 | def load(self, path):
33 | '''Reloads the agent weights from a checkpoint.'''
34 | pass
35 |
--------------------------------------------------------------------------------
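
The docstrings above define the full `Agent` interface. As a rough illustration only, here is a minimal single-process interaction loop showing when each method is called; the real loop lives in `tonic/utils/trainer.py` (not shown in this section) and also handles testing, logging and checkpointing, and the example assumes Gym's Box2D environments are installed for `BipedalWalker-v3`.

```python
# Sketch only: how an Agent's methods are exercised during training.
import tonic

environment = tonic.environments.distribute(
    lambda: tonic.environments.Gym('BipedalWalker-v3'))
environment.initialize(seed=0)

agent = tonic.agents.UniformRandom()
agent.initialize(
    observation_space=environment.observation_space,
    action_space=environment.action_space, seed=0)

observations = environment.start()
for step in range(1000):
    # step() returns the training actions for a batch of observations.
    actions = agent.step(observations, step)
    observations, infos = environment.step(actions)
    # update() informs the agent of the resulting transitions.
    agent.update(**infos, steps=step)
```
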
/tonic/agents/basic.py:
--------------------------------------------------------------------------------
1 | '''Basic non-learning agents, useful for example for debugging.'''
2 |
3 | import numpy as np
4 |
5 | from tonic import agents
6 |
7 |
8 | class NormalRandom(agents.Agent):
9 | '''Random agent producing actions from normal distributions.'''
10 |
11 | def __init__(self, loc=0, scale=1):
12 | self.loc = loc
13 | self.scale = scale
14 |
15 | def initialize(self, observation_space, action_space, seed=None):
16 | self.action_size = action_space.shape[0]
17 | self.np_random = np.random.RandomState(seed)
18 |
19 | def step(self, observations, steps):
20 | return self._policy(observations)
21 |
22 | def test_step(self, observations, steps):
23 | return self._policy(observations)
24 |
25 | def _policy(self, observations):
26 | batch_size = len(observations)
27 | shape = (batch_size, self.action_size)
28 | return self.np_random.normal(self.loc, self.scale, shape)
29 |
30 |
31 | class UniformRandom(agents.Agent):
32 | '''Random agent producing actions from uniform distributions.'''
33 |
34 | def initialize(self, observation_space, action_space, seed=None):
35 | self.action_size = action_space.shape[0]
36 | self.np_random = np.random.RandomState(seed)
37 |
38 | def step(self, observations, steps):
39 | return self._policy(observations)
40 |
41 | def test_step(self, observations, steps):
42 | return self._policy(observations)
43 |
44 | def _policy(self, observations):
45 | batch_size = len(observations)
46 | shape = (batch_size, self.action_size)
47 | return self.np_random.uniform(-1, 1, shape)
48 |
49 |
50 | class OrnsteinUhlenbeck(agents.Agent):
51 | '''Random agent producing correlated actions from an OU process.'''
52 |
53 | def __init__(self, scale=0.2, clip=2, theta=.15, dt=1e-2):
54 | self.scale = scale
55 | self.clip = clip
56 | self.theta = theta
57 | self.dt = dt
58 |
59 | def initialize(self, observation_space, action_space, seed=None):
60 | self.action_size = action_space.shape[0]
61 | self.np_random = np.random.RandomState(seed)
62 | self.train_actions = None
63 | self.test_actions = None
64 |
65 | def step(self, observations, steps):
66 | return self._train_policy(observations)
67 |
68 | def test_step(self, observations, steps):
69 | return self._test_policy(observations)
70 |
71 | def _train_policy(self, observations):
72 | if self.train_actions is None:
73 | shape = (len(observations), self.action_size)
74 | self.train_actions = np.zeros(shape)
75 | self.train_actions = self._next_actions(self.train_actions)
76 | return self.train_actions
77 |
78 | def _test_policy(self, observations):
79 | if self.test_actions is None:
80 | shape = (len(observations), self.action_size)
81 | self.test_actions = np.zeros(shape)
82 | self.test_actions = self._next_actions(self.test_actions)
83 | return self.test_actions
84 |
85 | def _next_actions(self, actions):
86 | noises = self.np_random.normal(size=actions.shape)
87 | noises = np.clip(noises, -self.clip, self.clip)
88 | next_actions = (1 - self.theta * self.dt) * actions
89 | next_actions += self.scale * np.sqrt(self.dt) * noises
90 | next_actions = np.clip(next_actions, -1, 1)
91 | return next_actions
92 |
93 | def update(self, observations, rewards, resets, terminations, steps):
94 | self.train_actions *= (1. - resets)[:, None]
95 |
96 | def test_update(self, observations, rewards, resets, terminations, steps):
97 | self.test_actions *= (1. - resets)[:, None]
98 |
99 |
100 | class Constant(agents.Agent):
101 | '''Agent producing a unique constant action.'''
102 |
103 | def __init__(self, constant=0):
104 | self.constant = constant
105 |
106 | def initialize(self, observation_space, action_space, seed=None):
107 | self.action_size = action_space.shape[0]
108 |
109 | def step(self, observations, steps):
110 | return self._policy(observations)
111 |
112 | def test_step(self, observations, steps):
113 | return self._policy(observations)
114 |
115 | def _policy(self, observations):
116 | shape = (len(observations), self.action_size)
117 | return np.full(shape, self.constant)
118 |
--------------------------------------------------------------------------------
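
A quick, self-contained way to see what these debugging agents produce; the action space below is a stand-in built with `gym.spaces.Box` and the reset at step 2 is artificial.

```python
# Sketch: the OrnsteinUhlenbeck agent emits temporally correlated actions
# that are zeroed whenever an episode resets.
import gym.spaces
import numpy as np
from tonic import agents

action_space = gym.spaces.Box(low=-1., high=1., shape=(2,), dtype=np.float32)
agent = agents.OrnsteinUhlenbeck(scale=0.2)
agent.initialize(observation_space=None, action_space=action_space, seed=0)

observations = np.zeros((1, 3), np.float32)  # dummy batch of one observation
for step in range(5):
    actions = agent.step(observations, step)
    print(actions)  # drifts smoothly around zero
    resets = np.array([step == 2])  # pretend the episode resets at step 2
    agent.update(observations, rewards=np.zeros(1), resets=resets,
                 terminations=resets, steps=step)
```
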
/tonic/environments/__init__.py:
--------------------------------------------------------------------------------
1 | from .builders import Bullet, ControlSuite, Gym
2 | from .distributed import distribute, Parallel, Sequential
3 | from .wrappers import ActionRescaler, TimeFeature
4 |
5 |
6 | __all__ = [
7 | Bullet, ControlSuite, Gym, distribute, Parallel, Sequential,
8 | ActionRescaler, TimeFeature]
9 |
--------------------------------------------------------------------------------
/tonic/environments/builders.py:
--------------------------------------------------------------------------------
1 | '''Environment builders for popular domains.'''
2 |
3 | import os
4 |
5 | import gym.wrappers
6 | import numpy as np
7 |
8 | from tonic import environments
9 | from tonic.utils import logger
10 |
11 |
12 | def gym_environment(*args, **kwargs):
13 | '''Returns a wrapped Gym environment.'''
14 |
15 | def _builder(*args, **kwargs):
16 | return gym.make(*args, **kwargs)
17 |
18 | return build_environment(_builder, *args, **kwargs)
19 |
20 |
21 | def bullet_environment(*args, **kwargs):
22 | '''Returns a wrapped PyBullet environment.'''
23 |
24 | def _builder(*args, **kwargs):
25 | import pybullet_envs # noqa
26 | return gym.make(*args, **kwargs)
27 |
28 | return build_environment(_builder, *args, **kwargs)
29 |
30 |
31 | def control_suite_environment(*args, **kwargs):
32 | '''Returns a wrapped Control Suite environment.'''
33 |
34 | def _builder(name, *args, **kwargs):
35 | domain, task = name.split('-')
36 | environment = ControlSuiteEnvironment(
37 | domain_name=domain, task_name=task, *args, **kwargs)
38 | time_limit = int(environment.environment._step_limit)
39 | return gym.wrappers.TimeLimit(environment, time_limit)
40 |
41 | return build_environment(_builder, *args, **kwargs)
42 |
43 |
44 | def build_environment(
45 | builder, name, terminal_timeouts=False, time_feature=False,
46 | max_episode_steps='default', scaled_actions=True, *args, **kwargs
47 | ):
48 | '''Builds and wraps an environment.
49 | Time limits can be properly handled with terminal_timeouts=False or
50 | time_feature=True, see https://arxiv.org/pdf/1712.00378.pdf for more
51 | details.
52 | '''
53 |
54 | # Build the environment.
55 | environment = builder(name, *args, **kwargs)
56 |
57 | # Get the default time limit.
58 | if max_episode_steps == 'default':
59 | max_episode_steps = environment._max_episode_steps
60 |
61 | # Remove the TimeLimit wrapper if needed.
62 | if not terminal_timeouts:
63 | assert type(environment) == gym.wrappers.TimeLimit, environment
64 | environment = environment.env
65 |
66 | # Add time as a feature if needed.
67 | if time_feature:
68 | environment = environments.wrappers.TimeFeature(
69 | environment, max_episode_steps)
70 |
71 | # Scale actions from [-1, 1]^n to the true action space if needed.
72 | if scaled_actions:
73 | environment = environments.wrappers.ActionRescaler(environment)
74 |
75 | environment.name = name
76 | environment.max_episode_steps = max_episode_steps
77 |
78 | return environment
79 |
80 |
81 | def _flatten_observation(observation):
82 | '''Turns OrderedDict observations into vectors.'''
83 | observation = [np.array([o]) if np.isscalar(o) else o.ravel()
84 | for o in observation.values()]
85 | return np.concatenate(observation, axis=0)
86 |
87 |
88 | class ControlSuiteEnvironment(gym.core.Env):
89 | '''Turns a Control Suite environment into a Gym environment.'''
90 |
91 | def __init__(
92 | self, domain_name, task_name, task_kwargs=None, visualize_reward=True,
93 | environment_kwargs=None
94 | ):
95 | from dm_control import suite
96 | self.environment = suite.load(
97 | domain_name=domain_name, task_name=task_name,
98 | task_kwargs=task_kwargs, visualize_reward=visualize_reward,
99 | environment_kwargs=environment_kwargs)
100 |
101 | # Create the observation space.
102 | observation_spec = self.environment.observation_spec()
103 | dim = sum([int(np.prod(spec.shape))
104 | for spec in observation_spec.values()])
105 | high = np.full(dim, np.inf, np.float32)
106 | self.observation_space = gym.spaces.Box(-high, high, dtype=np.float32)
107 |
108 | # Create the action space.
109 | action_spec = self.environment.action_spec()
110 | self.action_space = gym.spaces.Box(
111 | action_spec.minimum, action_spec.maximum, dtype=np.float32)
112 |
113 | def seed(self, seed):
114 | self.environment.task._random = np.random.RandomState(seed)
115 |
116 | def step(self, action):
117 | try:
118 | time_step = self.environment.step(action)
119 | observation = _flatten_observation(time_step.observation)
120 | reward = time_step.reward
121 |
122 | # Remove terminations from timeouts.
123 | done = time_step.last()
124 | if done:
125 | done = self.environment.task.get_termination(
126 | self.environment.physics)
127 | done = done is not None
128 |
129 | self.last_time_step = time_step
130 |
131 | # In case MuJoCo crashed.
132 | except Exception as e:
133 | path = logger.get_path()
134 | os.makedirs(path, exist_ok=True)
135 | save_path = os.path.join(path, 'crashes.txt')
136 | error = str(e)
137 | with open(save_path, 'a') as file:
138 | file.write(error + '\n')
139 | logger.error(error)
140 | observation = _flatten_observation(self.last_time_step.observation)
141 | observation = np.zeros_like(observation)
142 | reward = 0.
143 | done = True
144 |
145 | return observation, reward, done, {}
146 |
147 | def reset(self):
148 | time_step = self.environment.reset()
149 | self.last_time_step = time_step
150 | return _flatten_observation(time_step.observation)
151 |
152 | def render(self, mode='rgb_array', height=None, width=None, camera_id=0):
153 | '''Returns RGB frames from a camera.'''
154 | assert mode == 'rgb_array'
155 | return self.environment.physics.render(
156 | height=height, width=width, camera_id=camera_id)
157 |
158 |
159 | # Aliases.
160 | Gym = gym_environment
161 | Bullet = bullet_environment
162 | ControlSuite = control_suite_environment
163 |
--------------------------------------------------------------------------------
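
For illustration, a few builder calls matching the README examples; this sketch assumes `gym` (with Box2D) and `dm_control` are installed, and the environment names are only examples.

```python
from tonic import environments

# Default: the TimeLimit wrapper is removed, so timeouts are non-terminal.
env = environments.Gym('BipedalWalker-v3')

# Keep terminal timeouts but append a time feature to the observations,
# which keeps the MDP Markovian (see the Time Limits in RL paper).
env_markov = environments.Gym(
    'BipedalWalker-v3', terminal_timeouts=True, time_feature=True)

# DeepMind Control Suite tasks are named 'domain-task'.
env_dm = environments.ControlSuite('walker-walk')

print(env.name, env.max_episode_steps)
print(env_markov.observation_space.shape)  # one extra time dimension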
/tonic/environments/distributed.py:
--------------------------------------------------------------------------------
1 | '''Builders for distributed training.'''
2 |
3 | import multiprocessing
4 |
5 | import numpy as np
6 |
7 |
8 | class Sequential:
9 | '''A group of environments used in sequence.'''
10 |
11 | def __init__(self, environment_builder, max_episode_steps, workers):
12 | self.environments = [environment_builder() for _ in range(workers)]
13 | self.max_episode_steps = max_episode_steps
14 | self.observation_space = self.environments[0].observation_space
15 | self.action_space = self.environments[0].action_space
16 | self.name = self.environments[0].name
17 |
18 | def initialize(self, seed):
19 | for i, environment in enumerate(self.environments):
20 | environment.seed(seed + i)
21 |
22 | def start(self):
23 | '''Used once to get the initial observations.'''
24 | observations = [env.reset() for env in self.environments]
25 | self.lengths = np.zeros(len(self.environments), int)
26 | return np.array(observations, np.float32)
27 |
28 | def step(self, actions):
29 | next_observations = [] # Observations for the transitions.
30 | rewards = []
31 | resets = []
32 | terminations = []
33 | observations = [] # Observations for the actions selection.
34 |
35 | for i in range(len(self.environments)):
36 | ob, rew, term, _ = self.environments[i].step(actions[i])
37 |
38 | self.lengths[i] += 1
39 | # Timeouts trigger resets but are not true terminations.
40 | reset = term or self.lengths[i] == self.max_episode_steps
41 | next_observations.append(ob)
42 | rewards.append(rew)
43 | resets.append(reset)
44 | terminations.append(term)
45 |
46 | if reset:
47 | ob = self.environments[i].reset()
48 | self.lengths[i] = 0
49 |
50 | observations.append(ob)
51 |
52 | observations = np.array(observations, np.float32)
53 | infos = dict(
54 | observations=np.array(next_observations, np.float32),
55 | rewards=np.array(rewards, np.float32),
56 | resets=np.array(resets, bool),
57 | terminations=np.array(terminations, bool))
58 | return observations, infos
59 |
60 | def render(self, mode='human', *args, **kwargs):
61 | outs = []
62 | for env in self.environments:
63 | out = env.render(mode=mode, *args, **kwargs)
64 | outs.append(out)
65 | if mode != 'human':
66 | return np.array(outs)
67 |
68 |
69 | class Parallel:
70 | '''A group of sequential environments used in parallel.'''
71 |
72 | def __init__(
73 | self, environment_builder, worker_groups, workers_per_group,
74 | max_episode_steps
75 | ):
76 | self.environment_builder = environment_builder
77 | self.worker_groups = worker_groups
78 | self.workers_per_group = workers_per_group
79 | self.max_episode_steps = max_episode_steps
80 |
81 | def initialize(self, seed):
82 | def proc(action_pipe, index, seed):
83 | '''Process holding a sequential group of environments.'''
84 | envs = Sequential(
85 | self.environment_builder, self.max_episode_steps,
86 | self.workers_per_group)
87 | envs.initialize(seed)
88 |
89 | observations = envs.start()
90 | self.output_queue.put((index, observations))
91 |
92 | while True:
93 | actions = action_pipe.recv()
94 | out = envs.step(actions)
95 | self.output_queue.put((index, out))
96 |
97 | dummy_environment = self.environment_builder()
98 | self.observation_space = dummy_environment.observation_space
99 | self.action_space = dummy_environment.action_space
100 | del dummy_environment
101 | self.started = False
102 |
103 | self.output_queue = multiprocessing.Queue()
104 | self.action_pipes = []
105 |
106 | for i in range(self.worker_groups):
107 | pipe, worker_end = multiprocessing.Pipe()
108 | self.action_pipes.append(pipe)
109 | group_seed = seed + i * self.workers_per_group
110 | process = multiprocessing.Process(
111 | target=proc, args=(worker_end, i, group_seed))
112 | process.daemon = True
113 | process.start()
114 |
115 | def start(self):
116 | '''Used once to get the initial observations.'''
117 | assert not self.started
118 | self.started = True
119 | observations_list = [None for _ in range(self.worker_groups)]
120 |
121 | for _ in range(self.worker_groups):
122 | index, observations = self.output_queue.get()
123 | observations_list[index] = observations
124 |
125 | self.observations_list = np.array(observations_list)
126 | self.next_observations_list = np.zeros_like(self.observations_list)
127 | self.rewards_list = np.zeros(
128 | (self.worker_groups, self.workers_per_group), np.float32)
129 | self.resets_list = np.zeros(
130 | (self.worker_groups, self.workers_per_group), bool)
131 | self.terminations_list = np.zeros(
132 | (self.worker_groups, self.workers_per_group), bool)
133 |
134 | return np.concatenate(self.observations_list)
135 |
136 | def step(self, actions):
137 | actions_list = np.split(actions, self.worker_groups)
138 | for actions, pipe in zip(actions_list, self.action_pipes):
139 | pipe.send(actions)
140 |
141 | for _ in range(self.worker_groups):
142 | index, (observations, infos) = self.output_queue.get()
143 | self.observations_list[index] = observations
144 | self.next_observations_list[index] = infos['observations']
145 | self.rewards_list[index] = infos['rewards']
146 | self.resets_list[index] = infos['resets']
147 | self.terminations_list[index] = infos['terminations']
148 |
149 | observations = np.concatenate(self.observations_list)
150 | infos = dict(
151 | observations=np.concatenate(self.next_observations_list),
152 | rewards=np.concatenate(self.rewards_list),
153 | resets=np.concatenate(self.resets_list),
154 | terminations=np.concatenate(self.terminations_list))
155 | return observations, infos
156 |
157 |
158 | def distribute(environment_builder, worker_groups=1, workers_per_group=1):
159 | '''Distributes workers over parallel and sequential groups.'''
160 | dummy_environment = environment_builder()
161 | max_episode_steps = dummy_environment.max_episode_steps
162 | del dummy_environment
163 |
164 | if worker_groups < 2:
165 | return Sequential(
166 | environment_builder, max_episode_steps=max_episode_steps,
167 | workers=workers_per_group)
168 |
169 | return Parallel(
170 | environment_builder, worker_groups=worker_groups,
171 | workers_per_group=workers_per_group,
172 | max_episode_steps=max_episode_steps)
173 |
--------------------------------------------------------------------------------
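
A sketch of `distribute` with more than one group, which returns a `Parallel` object launching daemon processes (so it should be run as a script); the environment name and worker counts are illustrative.

```python
import numpy as np
from tonic import environments

# 2 parallel processes, each stepping 3 sequential copies of the environment,
# so batches contain 2 * 3 = 6 observations and actions.
env = environments.distribute(
    lambda: environments.Gym('BipedalWalker-v3'),
    worker_groups=2, workers_per_group=3)
env.initialize(seed=0)

observations = env.start()  # shape: (6, observation_size)
actions = np.zeros((6,) + env.action_space.shape, np.float32)
observations, infos = env.step(actions)
print(infos['rewards'].shape, infos['resets'].shape)  # (6,) (6,)
```
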
/tonic/environments/wrappers.py:
--------------------------------------------------------------------------------
1 | '''Environment wrappers.'''
2 |
3 | import gym
4 | import numpy as np
5 |
6 |
7 | class ActionRescaler(gym.ActionWrapper):
8 | '''Rescales actions from [-1, 1]^n to the true action space.
9 | The baseline agents return actions in [-1, 1]^n.'''
10 |
11 | def __init__(self, env):
12 | assert isinstance(env.action_space, gym.spaces.Box)
13 | super().__init__(env)
14 | high = np.ones(env.action_space.shape, dtype=np.float32)
15 | self.action_space = gym.spaces.Box(low=-high, high=high)
16 | true_low = env.action_space.low
17 | true_high = env.action_space.high
18 | self.bias = (true_high + true_low) / 2
19 | self.scale = (true_high - true_low) / 2
20 |
21 | def action(self, action):
22 | return self.bias + self.scale * np.clip(action, -1, 1)
23 |
24 |
25 | class TimeFeature(gym.Wrapper):
26 | '''Adds a notion of time in the observations.
27 | It can be used in terminal timeout settings to get Markovian MDPs.
28 | '''
29 |
30 | def __init__(self, env, max_steps, low=-1, high=1):
31 | super().__init__(env)
32 | dtype = self.observation_space.dtype
33 | self.observation_space = gym.spaces.Box(
34 | low=np.append(self.observation_space.low, low).astype(dtype),
35 | high=np.append(self.observation_space.high, high).astype(dtype))
36 | self.max_episode_steps = max_steps
37 | self.steps = 0
38 | self.low = low
39 | self.high = high
40 |
41 | def reset(self, **kwargs):
42 | self.steps = 0
43 | observation = self.env.reset(**kwargs)
44 | observation = np.append(observation, self.low)
45 | return observation
46 |
47 | def step(self, action):
48 | assert self.steps < self.max_episode_steps
49 | observation, reward, done, info = self.env.step(action)
50 | self.steps += 1
51 | prop = self.steps / self.max_episode_steps
52 | v = self.low + (self.high - self.low) * prop
53 | observation = np.append(observation, v)
54 | return observation, reward, done, info
55 |
--------------------------------------------------------------------------------
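
The builders in `builders.py` apply these wrappers automatically; doing it by hand looks roughly like this sketch, assuming a Gym version that provides `Pendulum-v0`.

```python
import gym
from tonic.environments import wrappers

env = gym.make('Pendulum-v0').env            # strip the TimeLimit wrapper
env = wrappers.TimeFeature(env, max_steps=200)
env = wrappers.ActionRescaler(env)           # actions are now taken in [-1, 1]

observation = env.reset()                    # last feature encodes elapsed time
observation, reward, done, info = env.step([0.5])  # rescaled to Pendulum's [-2, 2]
```
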
/tonic/explorations/__init__.py:
--------------------------------------------------------------------------------
1 | from .noisy import NoActionNoise
2 | from .noisy import NormalActionNoise
3 | from .noisy import OrnsteinUhlenbeckActionNoise
4 |
5 |
6 | __all__ = [NoActionNoise, NormalActionNoise, OrnsteinUhlenbeckActionNoise]
7 |
--------------------------------------------------------------------------------
/tonic/explorations/noisy.py:
--------------------------------------------------------------------------------
1 | '''Non-differentiable noisy exploration methods.'''
2 |
3 | import numpy as np
4 |
5 |
6 | class NoActionNoise:
7 | def __init__(self, start_steps=20000):
8 | self.start_steps = start_steps
9 |
10 | def initialize(self, policy, action_space, seed=None):
11 | self.policy = policy
12 | self.action_size = action_space.shape[0]
13 | self.np_random = np.random.RandomState(seed)
14 |
15 | def __call__(self, observations, steps):
16 | if steps > self.start_steps:
17 | actions = self.policy(observations)
18 | actions = np.clip(actions, -1, 1)
19 | else:
20 | shape = (len(observations), self.action_size)
21 | actions = self.np_random.uniform(-1, 1, shape)
22 | return actions
23 |
24 | def update(self, resets):
25 | pass
26 |
27 |
28 | class NormalActionNoise:
29 | def __init__(self, scale=0.1, start_steps=20000):
30 | self.scale = scale
31 | self.start_steps = start_steps
32 |
33 | def initialize(self, policy, action_space, seed=None):
34 | self.policy = policy
35 | self.action_size = action_space.shape[0]
36 | self.np_random = np.random.RandomState(seed)
37 |
38 | def __call__(self, observations, steps):
39 | if steps > self.start_steps:
40 | actions = self.policy(observations)
41 | noises = self.scale * self.np_random.normal(size=actions.shape)
42 | actions = (actions + noises).astype(np.float32)
43 | actions = np.clip(actions, -1, 1)
44 | else:
45 | shape = (len(observations), self.action_size)
46 | actions = self.np_random.uniform(-1, 1, shape)
47 | return actions
48 |
49 | def update(self, resets):
50 | pass
51 |
52 |
53 | class OrnsteinUhlenbeckActionNoise:
54 | def __init__(
55 | self, scale=0.1, clip=2, theta=.15, dt=1e-2, start_steps=20000
56 | ):
57 | self.scale = scale
58 | self.clip = clip
59 | self.theta = theta
60 | self.dt = dt
61 | self.start_steps = start_steps
62 |
63 | def initialize(self, policy, action_space, seed=None):
64 | self.policy = policy
65 | self.action_size = action_space.shape[0]
66 | self.np_random = np.random.RandomState(seed)
67 | self.noises = None
68 |
69 | def __call__(self, observations, steps):
70 | if steps > self.start_steps:
71 | actions = self.policy(observations)
72 |
73 | if self.noises is None:
74 | self.noises = np.zeros_like(actions)
75 | noises = self.np_random.normal(size=actions.shape)
76 | noises = np.clip(noises, -self.clip, self.clip)
77 | self.noises -= self.theta * self.noises * self.dt
78 | self.noises += self.scale * np.sqrt(self.dt) * noises
79 | actions = (actions + self.noises).astype(np.float32)
80 | actions = np.clip(actions, -1, 1)
81 | else:
82 | shape = (len(observations), self.action_size)
83 | actions = self.np_random.uniform(-1, 1, shape)
84 | return actions
85 |
86 | def update(self, resets):
87 | if self.noises is not None:
88 | self.noises *= (1. - resets)[:, None]
89 |
--------------------------------------------------------------------------------
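
A small sketch of how an exploration module wraps a policy; the `policy` function and action space below are stand-ins for what an agent would provide.

```python
import gym.spaces
import numpy as np
from tonic import explorations

def policy(observations):
    # Hypothetical deterministic policy returning actions in [-1, 1].
    return np.zeros((len(observations), 2), np.float32)

action_space = gym.spaces.Box(low=-1., high=1., shape=(2,), dtype=np.float32)
exploration = explorations.NormalActionNoise(scale=0.1, start_steps=100)
exploration.initialize(policy, action_space, seed=0)

observations = np.zeros((1, 3), np.float32)
actions = exploration(observations, steps=50)   # warm-up: uniform random actions
actions = exploration(observations, steps=500)  # policy actions plus noise
exploration.update(resets=np.array([False]))    # no internal state to reset here
```
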
/tonic/play.py:
--------------------------------------------------------------------------------
1 | '''Script used to play with trained agents.'''
2 |
3 | import argparse
4 | import os
5 |
6 | import numpy as np
7 | import yaml
8 |
9 | import tonic # noqa
10 |
11 |
12 | def play_gym(agent, environment):
13 | '''Launches an agent in a Gym-based environment.'''
14 |
15 | environment = tonic.environments.distribute(lambda: environment)
16 |
17 | observations = environment.start()
18 | environment.render()
19 |
20 | score = 0
21 | length = 0
22 | min_reward = float('inf')
23 | max_reward = -float('inf')
24 | global_min_reward = float('inf')
25 | global_max_reward = -float('inf')
26 | steps = 0
27 | episodes = 0
28 |
29 | while True:
30 | actions = agent.test_step(observations, steps)
31 | observations, infos = environment.step(actions)
32 | agent.test_update(**infos, steps=steps)
33 | environment.render()
34 |
35 | steps += 1
36 | reward = infos['rewards'][0]
37 | score += reward
38 | min_reward = min(min_reward, reward)
39 | max_reward = max(max_reward, reward)
40 | global_min_reward = min(global_min_reward, reward)
41 | global_max_reward = max(global_max_reward, reward)
42 | length += 1
43 |
44 | if infos['resets'][0]:
45 | term = infos['terminations'][0]
46 | episodes += 1
47 |
48 | print()
49 | print(f'Episodes: {episodes:,}')
50 | print(f'Score: {score:,.3f}')
51 | print(f'Length: {length:,}')
52 | print(f'Terminal: {term:}')
53 | print(f'Min reward: {min_reward:,.3f}')
54 | print(f'Max reward: {max_reward:,.3f}')
55 | print(f'Global min reward: {global_min_reward:,.3f}')
56 | print(f'Global max reward: {global_max_reward:,.3f}')
57 |
58 | score = 0
59 | length = 0
60 | min_reward = float('inf')
61 | max_reward = -float('inf')
62 |
63 |
64 | def play_control_suite(agent, environment):
65 | '''Launches an agent in a DeepMind Control Suite-based environment.'''
66 |
67 | from dm_control import viewer
68 |
69 | class Wrapper:
70 | '''Wrapper used to plug a Tonic environment in a dm_control viewer.'''
71 |
72 | def __init__(self, environment):
73 | self.environment = environment
74 | self.unwrapped = environment.unwrapped
75 | self.action_spec = self.unwrapped.environment.action_spec
76 | self.physics = self.unwrapped.environment.physics
77 | self.infos = None
78 | self.steps = 0
79 | self.episodes = 0
80 | self.min_reward = float('inf')
81 | self.max_reward = -float('inf')
82 | self.global_min_reward = float('inf')
83 | self.global_max_reward = -float('inf')
84 |
85 | def reset(self):
86 | '''Mimics a dm_control reset for the viewer.'''
87 |
88 | self.observations = self.environment.reset()[None]
89 |
90 | self.score = 0
91 | self.length = 0
92 | self.min_reward = float('inf')
93 | self.max_reward = -float('inf')
94 |
95 | return self.unwrapped.last_time_step
96 |
97 | def step(self, actions):
98 | '''Mimics a dm_control step for the viewer.'''
99 |
100 | assert not np.isnan(actions.sum())
101 | ob, rew, term, _ = self.environment.step(actions[0])
102 |
103 | self.score += rew
104 | self.length += 1
105 | self.min_reward = min(self.min_reward, rew)
106 | self.max_reward = max(self.max_reward, rew)
107 | self.global_min_reward = min(self.global_min_reward, rew)
108 | self.global_max_reward = max(self.global_max_reward, rew)
109 | timeout = self.length == self.environment.max_episode_steps
110 | done = term or timeout
111 |
112 | if done:
113 | self.episodes += 1
114 | print()
115 | print(f'Episodes: {self.episodes:,}')
116 | print(f'Score: {self.score:,.3f}')
117 | print(f'Length: {self.length:,}')
118 | print(f'Terminal: {term:}')
119 | print(f'Min reward: {self.min_reward:,.3f}')
120 | print(f'Max reward: {self.max_reward:,.3f}')
121 | print(f'Global min reward: {self.global_min_reward:,.3f}')
122 | print(f'Global max reward: {self.global_max_reward:,.3f}')
123 |
124 | self.observations = ob[None]
125 | self.infos = dict(
126 | observations=ob[None], rewards=np.array([rew]),
127 | resets=np.array([done]), terminations=np.array([term]))
128 |
129 | return self.unwrapped.last_time_step
130 |
131 | # Wrap the environment for the viewer.
132 | environment = Wrapper(environment)
133 |
134 | def policy(timestep):
135 | '''Mimics a dm_control policy for the viewer.'''
136 |
137 | if environment.infos is not None:
138 | agent.test_update(**environment.infos, steps=environment.steps)
139 | environment.steps += 1
140 | return agent.test_step(environment.observations, environment.steps)
141 |
142 | # Launch the viewer with the wrapped environment and policy.
143 | viewer.launch(environment, policy)
144 |
145 |
146 | def play(path, checkpoint, seed, header, agent, environment):
147 | '''Reloads an agent and an environment from a previous experiment.'''
148 |
149 | checkpoint_path = None
150 |
151 | if path:
152 | tonic.logger.log(f'Loading experiment from {path}')
153 |
154 | # Use no checkpoint, the agent is freshly created.
155 | if checkpoint == 'none' or agent is not None:
156 | tonic.logger.log('Not loading any weights')
157 |
158 | else:
159 | checkpoint_path = os.path.join(path, 'checkpoints')
160 | if not os.path.isdir(checkpoint_path):
161 | tonic.logger.error(f'{checkpoint_path} is not a directory')
162 | checkpoint_path = None
163 |
164 | # List all the checkpoints.
165 | checkpoint_ids = []
166 | for file in os.listdir(checkpoint_path):
167 | if file[:5] == 'step_':
168 | checkpoint_id = file.split('.')[0]
169 | checkpoint_ids.append(int(checkpoint_id[5:]))
170 |
171 | if checkpoint_ids:
172 | # Use the last checkpoint.
173 | if checkpoint == 'last':
174 | checkpoint_id = max(checkpoint_ids)
175 | checkpoint_path = os.path.join(
176 | checkpoint_path, f'step_{checkpoint_id}')
177 |
178 | # Use the specified checkpoint.
179 | else:
180 | checkpoint_id = int(checkpoint)
181 | if checkpoint_id in checkpoint_ids:
182 | checkpoint_path = os.path.join(
183 | checkpoint_path, f'step_{checkpoint_id}')
184 | else:
185 | tonic.logger.error(f'Checkpoint {checkpoint_id} '
186 | f'not found in {checkpoint_path}')
187 | checkpoint_path = None
188 |
189 | else:
190 | tonic.logger.error(f'No checkpoint found in {checkpoint_path}')
191 | checkpoint_path = None
192 |
193 | # Load the experiment configuration.
194 | arguments_path = os.path.join(path, 'config.yaml')
195 | with open(arguments_path, 'r') as config_file:
196 | config = yaml.load(config_file, Loader=yaml.FullLoader)
197 | config = argparse.Namespace(**config)
198 |
199 | header = header or config.header
200 | agent = agent or config.agent
201 | environment = environment or config.test_environment
202 | environment = environment or config.environment
203 |
204 | # Run the header first, e.g. to load an ML framework.
205 | if header:
206 | exec(header)
207 |
208 | # Build the agent.
209 | if not agent:
210 | raise ValueError('No agent specified.')
211 | agent = eval(agent)
212 |
213 | # Build the environment.
214 | environment = eval(environment)
215 | environment.seed(seed)
216 |
217 | # Initialize the agent.
218 | agent.initialize(
219 | observation_space=environment.observation_space,
220 | action_space=environment.action_space, seed=seed)
221 |
222 | # Load the weights of the agent from a checkpoint.
223 | if checkpoint_path:
224 | agent.load(checkpoint_path)
225 |
226 | # Play with the agent in the environment.
227 | if isinstance(environment, tonic.environments.wrappers.ActionRescaler):
228 | environment_type = environment.env.__class__.__name__
229 | else:
230 | environment_type = environment.__class__.__name__
231 |
232 | if environment_type == 'ControlSuiteEnvironment':
233 | play_control_suite(agent, environment)
234 | else:
235 | if 'Bullet' in environment_type:
236 | environment.render()
237 | play_gym(agent, environment)
238 |
239 |
240 | if __name__ == '__main__':
241 | # Argument parsing.
242 | parser = argparse.ArgumentParser()
243 | parser.add_argument('--path')
244 | parser.add_argument('--checkpoint', default='last')
245 | parser.add_argument('--seed', type=int, default=0)
246 | parser.add_argument('--header')
247 | parser.add_argument('--agent')
248 | parser.add_argument('--environment', '--env')
249 | args = vars(parser.parse_args())
250 | play(**args)
251 |
--------------------------------------------------------------------------------
/tonic/replays/__init__.py:
--------------------------------------------------------------------------------
1 | from .buffers import Buffer
2 | from .segments import Segment
3 | from .utils import flatten_batch, lambda_returns
4 |
5 |
6 | __all__ = [flatten_batch, lambda_returns, Buffer, Segment]
7 |
--------------------------------------------------------------------------------
/tonic/replays/buffers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Buffer:
5 | '''Replay storing a large number of transitions for off-policy learning
6 | and using n-step returns.'''
7 |
8 | def __init__(
9 | self, size=int(1e6), return_steps=1, batch_iterations=50,
10 | batch_size=100, discount_factor=0.99, steps_before_batches=int(1e4),
11 | steps_between_batches=50
12 | ):
13 | self.full_max_size = size
14 | self.return_steps = return_steps
15 | self.batch_iterations = batch_iterations
16 | self.batch_size = batch_size
17 | self.discount_factor = discount_factor
18 | self.steps_before_batches = steps_before_batches
19 | self.steps_between_batches = steps_between_batches
20 |
21 | def initialize(self, seed=None):
22 | self.np_random = np.random.RandomState(seed)
23 | self.buffers = None
24 | self.index = 0
25 | self.size = 0
26 | self.last_steps = 0
27 |
28 | def ready(self, steps):
29 | if steps < self.steps_before_batches:
30 | return False
31 | return (steps - self.last_steps) >= self.steps_between_batches
32 |
33 | def store(self, **kwargs):
34 | if 'terminations' in kwargs:
35 | continuations = np.float32(1 - kwargs['terminations'])
36 | kwargs['discounts'] = continuations * self.discount_factor
37 |
38 | # Create the named buffers.
39 | if self.buffers is None:
40 | self.num_workers = len(list(kwargs.values())[0])
41 | self.max_size = self.full_max_size // self.num_workers
42 | self.buffers = {}
43 | for key, val in kwargs.items():
44 | shape = (self.max_size,) + np.array(val).shape
45 | self.buffers[key] = np.full(shape, np.nan, np.float32)
46 |
47 | # Store the new values.
48 | for key, val in kwargs.items():
49 | self.buffers[key][self.index] = val
50 |
51 | # Accumulate values for n-step returns.
52 | if self.return_steps > 1:
53 | self.accumulate_n_steps(kwargs)
54 |
55 | self.index = (self.index + 1) % self.max_size
56 | self.size = min(self.size + 1, self.max_size)
57 |
58 | def accumulate_n_steps(self, kwargs):
59 | rewards = kwargs['rewards']
60 | next_observations = kwargs['next_observations']
61 | discounts = kwargs['discounts']
62 | masks = np.ones(self.num_workers, np.float32)
63 |
64 | for i in range(min(self.size, self.return_steps - 1)):
65 | index = (self.index - i - 1) % self.max_size
66 | masks *= (1 - self.buffers['resets'][index])
67 | new_rewards = (self.buffers['rewards'][index] +
68 | self.buffers['discounts'][index] * rewards)
69 | self.buffers['rewards'][index] = (
70 | (1 - masks) * self.buffers['rewards'][index] +
71 | masks * new_rewards)
72 | new_discounts = self.buffers['discounts'][index] * discounts
73 | self.buffers['discounts'][index] = (
74 | (1 - masks) * self.buffers['discounts'][index] +
75 | masks * new_discounts)
76 | self.buffers['next_observations'][index] = (
77 | (1 - masks)[:, None] *
78 | self.buffers['next_observations'][index] +
79 | masks[:, None] * next_observations)
80 |
81 | def get(self, *keys, steps):
82 | '''Get batches from named buffers.'''
83 |
84 | for _ in range(self.batch_iterations):
85 | total_size = self.size * self.num_workers
86 | indices = self.np_random.randint(total_size, size=self.batch_size)
87 | rows = indices // self.num_workers
88 | columns = indices % self.num_workers
89 | yield {k: self.buffers[k][rows, columns] for k in keys}
90 |
91 | self.last_steps = steps
92 |
--------------------------------------------------------------------------------
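
A rough usage sketch of `Buffer` outside of an agent, with small illustrative sizes and dummy transitions; agents normally drive it through `store`, `ready` and `get`.

```python
import numpy as np
from tonic import replays

buffer = replays.Buffer(size=int(1e5), return_steps=3, batch_iterations=2,
                        steps_before_batches=10, steps_between_batches=1)
buffer.initialize(seed=0)

num_workers, observation_size, action_size = 2, 4, 2
for step in range(100):
    buffer.store(
        observations=np.zeros((num_workers, observation_size), np.float32),
        actions=np.zeros((num_workers, action_size), np.float32),
        next_observations=np.zeros((num_workers, observation_size), np.float32),
        rewards=np.ones(num_workers, np.float32),
        resets=np.zeros(num_workers, np.float32),
        terminations=np.zeros(num_workers, np.float32))

if buffer.ready(steps=100):
    for batch in buffer.get('observations', 'actions', 'rewards', steps=100):
        print(batch['rewards'].shape)  # (batch_size,)
```
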
/tonic/replays/segments.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from tonic import replays
4 |
5 |
6 | class Segment:
7 | '''Replay storing recent transitions for on-policy learning.'''
8 |
9 | def __init__(
10 | self, size=4096, batch_iterations=80, batch_size=None,
11 | discount_factor=0.99, trace_decay=0.97
12 | ):
13 | self.max_size = size
14 | self.batch_iterations = batch_iterations
15 | self.batch_size = batch_size
16 | self.discount_factor = discount_factor
17 | self.trace_decay = trace_decay
18 |
19 | def initialize(self, seed=None):
20 | self.np_random = np.random.RandomState(seed)
21 | self.buffers = None
22 | self.index = 0
23 |
24 | def ready(self):
25 | return self.index == self.max_size
26 |
27 | def store(self, **kwargs):
28 | if self.buffers is None:
29 | self.num_workers = len(list(kwargs.values())[0])
30 | self.buffers = {}
31 | for key, val in kwargs.items():
32 | shape = (self.max_size,) + np.array(val).shape
33 | self.buffers[key] = np.zeros(shape, np.float32)
34 | for key, val in kwargs.items():
35 | self.buffers[key][self.index] = val
36 | self.index += 1
37 |
38 | def get_full(self, *keys):
39 | self.index = 0
40 |
41 | if 'advantages' in keys:
42 | advs = self.buffers['returns'] - self.buffers['values']
43 | std = advs.std()
44 | if std != 0:
45 | advs = (advs - advs.mean()) / std
46 | self.buffers['advantages'] = advs
47 |
48 | return {k: replays.flatten_batch(self.buffers[k]) for k in keys}
49 |
50 | def get(self, *keys):
51 | '''Get mini-batches from named buffers.'''
52 |
53 | batch = self.get_full(*keys)
54 |
55 | if self.batch_size is None:
56 | for _ in range(self.batch_iterations):
57 | yield batch
58 | else:
59 | size = self.max_size * self.num_workers
60 | all_indices = np.arange(size)
61 | for _ in range(self.batch_iterations):
62 | self.np_random.shuffle(all_indices)
63 | for i in range(0, size, self.batch_size):
64 | indices = all_indices[i:i + self.batch_size]
65 | yield {k: v[indices] for k, v in batch.items()}
66 |
67 | def compute_returns(self, values, next_values):
68 | shape = self.buffers['rewards'].shape
69 | self.buffers['values'] = values.reshape(shape)
70 | self.buffers['next_values'] = next_values.reshape(shape)
71 | self.buffers['returns'] = replays.lambda_returns(
72 | values=self.buffers['values'],
73 | next_values=self.buffers['next_values'],
74 | rewards=self.buffers['rewards'],
75 | resets=self.buffers['resets'],
76 | terminations=self.buffers['terminations'],
77 | discount_factor=self.discount_factor,
78 | trace_decay=self.trace_decay)
79 |
--------------------------------------------------------------------------------
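
A rough sketch of `Segment` as the on-policy agents use it; the stored transitions and critic outputs below are zero placeholders.

```python
import numpy as np
from tonic import replays

segment = replays.Segment(size=8, batch_iterations=2, batch_size=4)
segment.initialize(seed=0)

num_workers = 2
while not segment.ready():
    segment.store(
        observations=np.zeros((num_workers, 3), np.float32),
        actions=np.zeros((num_workers, 1), np.float32),
        rewards=np.ones(num_workers, np.float32),
        resets=np.zeros(num_workers, np.float32),
        terminations=np.zeros(num_workers, np.float32),
        log_probs=np.zeros(num_workers, np.float32))

# The critic normally provides these values; zeros act as placeholders here.
values = np.zeros(8 * num_workers, np.float32)
segment.compute_returns(values, next_values=values)

for batch in segment.get('observations', 'actions', 'advantages', 'log_probs'):
    print(batch['advantages'].shape)  # (batch_size,)
```
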
/tonic/replays/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def lambda_returns(
5 | values, next_values, rewards, resets, terminations, discount_factor,
6 | trace_decay
7 | ):
8 | '''Function used to calculate lambda-returns on parallel buffers.'''
9 |
10 | returns = np.zeros_like(values)
11 | last_returns = next_values[-1]
12 | for t in reversed(range(len(rewards))):
13 | bootstrap = (
14 | (1 - trace_decay) * next_values[t] + trace_decay * last_returns)
15 | bootstrap *= (1 - resets[t])
16 | bootstrap += resets[t] * next_values[t]
17 | bootstrap *= (1 - terminations[t])
18 | returns[t] = last_returns = rewards[t] + discount_factor * bootstrap
19 | return returns
20 |
21 |
22 | def flatten_batch(values):
23 | shape = values.shape
24 | new_shape = (np.prod(shape[:2], dtype=int),) + shape[2:]
25 | return values.reshape(new_shape)
26 |
--------------------------------------------------------------------------------
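
A tiny numeric check of `lambda_returns`: with `trace_decay=0` the returns reduce to one-step TD targets `r_t + discount * V(s_{t+1})`, while with `trace_decay=1` they become full discounted returns bootstrapped from the final `next_value`.

```python
import numpy as np
from tonic.replays import lambda_returns

rewards = np.ones((3, 1), np.float32)        # 3 time steps, 1 worker
values = np.zeros((3, 1), np.float32)
next_values = np.full((3, 1), 10., np.float32)
resets = np.zeros((3, 1), np.float32)
terminations = np.zeros((3, 1), np.float32)

td = lambda_returns(values, next_values, rewards, resets, terminations,
                    discount_factor=0.99, trace_decay=0.)
mc = lambda_returns(values, next_values, rewards, resets, terminations,
                    discount_factor=0.99, trace_decay=1.)
print(td.ravel())  # [10.9, 10.9, 10.9]
print(mc.ravel())  # [1 + 0.99 * 11.791, 1 + 0.99 * 10.9, 10.9]
```
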
/tonic/tensorflow/__init__.py:
--------------------------------------------------------------------------------
1 | from . import agents, models, normalizers, updaters
2 |
3 |
4 | __all__ = [agents, models, normalizers, updaters]
5 |
--------------------------------------------------------------------------------
/tonic/tensorflow/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from .agent import Agent
2 |
3 | from .a2c import A2C # noqa
4 | from .ddpg import DDPG
5 | from .d4pg import D4PG # noqa
6 | from .mpo import MPO
7 | from .ppo import PPO
8 | from .sac import SAC
9 | from .td3 import TD3
10 | from .td4 import TD4
11 | from .trpo import TRPO
12 |
13 |
14 | __all__ = [Agent, A2C, DDPG, D4PG, MPO, PPO, SAC, TD3, TD4, TRPO]
15 |
--------------------------------------------------------------------------------
/tonic/tensorflow/agents/a2c.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tonic import logger, replays
4 | from tonic.tensorflow import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorCritic(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((64, 64), 'tanh'),
12 | head=models.DetachedScaleGaussianPolicyHead()),
13 | critic=models.Critic(
14 | encoder=models.ObservationEncoder(),
15 | torso=models.MLP((64, 64), 'tanh'),
16 | head=models.ValueHead()),
17 | observation_normalizer=normalizers.MeanStd())
18 |
19 |
20 | class A2C(agents.Agent):
21 | '''Advantage Actor Critic (aka Vanilla Policy Gradient).
22 | A3C: https://arxiv.org/pdf/1602.01783.pdf
23 | '''
24 |
25 | def __init__(
26 | self, model=None, replay=None, actor_updater=None, critic_updater=None
27 | ):
28 | self.model = model or default_model()
29 | self.replay = replay or replays.Segment()
30 | self.actor_updater = actor_updater or \
31 | updaters.StochasticPolicyGradient()
32 | self.critic_updater = critic_updater or updaters.VRegression()
33 |
34 | def initialize(self, observation_space, action_space, seed=None):
35 | super().initialize(seed=seed)
36 | self.model.initialize(observation_space, action_space)
37 | self.replay.initialize(seed)
38 | self.actor_updater.initialize(self.model)
39 | self.critic_updater.initialize(self.model)
40 |
41 | def step(self, observations, steps):
42 | # Sample actions and get their log-probabilities for training.
43 | actions, log_probs = self._step(observations)
44 | actions = actions.numpy()
45 | log_probs = log_probs.numpy()
46 |
47 | # Keep some values for the next update.
48 | self.last_observations = observations.copy()
49 | self.last_actions = actions.copy()
50 | self.last_log_probs = log_probs.copy()
51 |
52 | return actions
53 |
54 | def test_step(self, observations, steps):
55 | # Sample actions for testing.
56 | return self._test_step(observations).numpy()
57 |
58 | def update(self, observations, rewards, resets, terminations, steps):
59 | # Store the last transitions in the replay.
60 | self.replay.store(
61 | observations=self.last_observations, actions=self.last_actions,
62 | next_observations=observations, rewards=rewards, resets=resets,
63 | terminations=terminations, log_probs=self.last_log_probs)
64 |
65 | # Prepare to update the normalizers.
66 | if self.model.observation_normalizer:
67 | self.model.observation_normalizer.record(self.last_observations)
68 | if self.model.return_normalizer:
69 | self.model.return_normalizer.record(rewards)
70 |
71 | # Update the model if the replay is ready.
72 | if self.replay.ready():
73 | self._update()
74 |
75 | @tf.function
76 | def _step(self, observations):
77 | distributions = self.model.actor(observations)
78 | if hasattr(distributions, 'sample_with_log_prob'):
79 | actions, log_probs = distributions.sample_with_log_prob()
80 | else:
81 | actions = distributions.sample()
82 | log_probs = distributions.log_prob(actions)
83 | return actions, log_probs
84 |
85 | @tf.function
86 | def _test_step(self, observations):
87 | return self.model.actor(observations).sample()
88 |
89 | @tf.function
90 | def _evaluate(self, observations, next_observations):
91 | values = self.model.critic(observations)
92 | next_values = self.model.critic(next_observations)
93 | return values, next_values
94 |
95 | def _update(self):
96 | # Compute the lambda-returns.
97 | batch = self.replay.get_full('observations', 'next_observations')
98 | values, next_values = self._evaluate(**batch)
99 | values, next_values = values.numpy(), next_values.numpy()
100 | self.replay.compute_returns(values, next_values)
101 |
102 | # Update the actor once.
103 | keys = 'observations', 'actions', 'advantages', 'log_probs'
104 | batch = self.replay.get_full(*keys)
105 | infos = self.actor_updater(**batch)
106 | for k, v in infos.items():
107 | logger.store('actor/' + k, v.numpy())
108 |
109 | # Update the critic multiple times.
110 | for batch in self.replay.get('observations', 'returns'):
111 | infos = self.critic_updater(**batch)
112 | for k, v in infos.items():
113 | logger.store('critic/' + k, v.numpy())
114 |
115 | # Update the normalizers.
116 | if self.model.observation_normalizer:
117 | self.model.observation_normalizer.update()
118 | if self.model.return_normalizer:
119 | self.model.return_normalizer.update()
120 |
--------------------------------------------------------------------------------
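
The agent interface used above (initialize / step / update) is normally driven by tonic's Trainer and its environment wrappers. The following hand-rolled loop is a minimal sketch, not part of the repository: it assumes the classic single-environment Gym API, a hypothetical environment id, and manually batched arrays of size one, and it conflates time-limit resets with true terminations, which tonic's wrappers normally keep separate.

import gym
import numpy as np

from tonic.tensorflow import agents

env = gym.make('Pendulum-v0')  # hypothetical id: any continuous-control task
agent = agents.A2C()
agent.initialize(env.observation_space, env.action_space, seed=0)

observations = env.reset()[None].astype(np.float32)
for step in range(10000):
    # Act with the training policy; the agent remembers what it needs.
    actions = agent.step(observations, step)
    next_observation, reward, done, _ = env.step(actions[0])

    # Report the outcome; the agent stores it and trains when its replay is ready.
    agent.update(
        observations=next_observation[None].astype(np.float32),
        rewards=np.array([reward], np.float32),
        resets=np.array([done]),
        terminations=np.array([done]),
        steps=step)

    if done:
        next_observation = env.reset()
    observations = next_observation[None].astype(np.float32)
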
/tonic/tensorflow/agents/agent.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 | from tonic import agents, logger
7 |
8 |
9 | class Agent(agents.Agent):
10 | def initialize(self, seed=None):
11 | if seed is not None:
12 | np.random.seed(seed)
13 | random.seed(seed)
14 | tf.random.set_seed(seed)
15 |
16 | def save(self, path):
17 | logger.log(f'\nSaving weights to {path}')
18 | self.model.save_weights(path)
19 |
20 | def load(self, path):
21 | logger.log(f'\nLoading weights from {path}')
22 | self.model.load_weights(path)
23 |
--------------------------------------------------------------------------------
/tonic/tensorflow/agents/d4pg.py:
--------------------------------------------------------------------------------
1 | from tonic import replays
2 | from tonic.tensorflow import agents, models, normalizers, updaters
3 |
4 |
5 | def default_model():
6 | return models.ActorCriticWithTargets(
7 | actor=models.Actor(
8 | encoder=models.ObservationEncoder(),
9 | torso=models.MLP((256, 256), 'relu'),
10 | head=models.DeterministicPolicyHead()),
11 | critic=models.Critic(
12 | encoder=models.ObservationActionEncoder(),
13 | torso=models.MLP((256, 256), 'relu'),
14 | # These values are for the control suite with 0.99 discount.
15 | head=models.DistributionalValueHead(-150., 150., 51)),
16 | observation_normalizer=normalizers.MeanStd())
17 |
18 |
19 | class D4PG(agents.DDPG):
20 | '''Distributed Distributional Deterministic Policy Gradients.
21 | D4PG: https://arxiv.org/pdf/1804.08617.pdf
22 | '''
23 |
24 | def __init__(
25 | self, model=None, replay=None, exploration=None, actor_updater=None,
26 | critic_updater=None
27 | ):
28 | model = model or default_model()
29 | replay = replay or replays.Buffer(return_steps=5)
30 | actor_updater = actor_updater or \
31 | updaters.DistributionalDeterministicPolicyGradient()
32 | critic_updater = critic_updater or \
33 | updaters.DistributionalDeterministicQLearning()
34 | super().__init__(
35 | model, replay, exploration, actor_updater, critic_updater)
36 |
--------------------------------------------------------------------------------
/tonic/tensorflow/agents/ddpg.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tonic import explorations, logger, replays
4 | from tonic.tensorflow import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorCriticWithTargets(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((256, 256), 'relu'),
12 | head=models.DeterministicPolicyHead()),
13 | critic=models.Critic(
14 | encoder=models.ObservationActionEncoder(),
15 | torso=models.MLP((256, 256), 'relu'),
16 | head=models.ValueHead()),
17 | observation_normalizer=normalizers.MeanStd())
18 |
19 |
20 | class DDPG(agents.Agent):
21 | '''Deep Deterministic Policy Gradient.
22 | DDPG: https://arxiv.org/pdf/1509.02971.pdf
23 | '''
24 |
25 | def __init__(
26 | self, model=None, replay=None, exploration=None, actor_updater=None,
27 | critic_updater=None
28 | ):
29 | self.model = model or default_model()
30 | self.replay = replay or replays.Buffer()
31 | self.exploration = exploration or explorations.NormalActionNoise()
32 | self.actor_updater = actor_updater or \
33 | updaters.DeterministicPolicyGradient()
34 | self.critic_updater = critic_updater or \
35 | updaters.DeterministicQLearning()
36 |
37 | def initialize(self, observation_space, action_space, seed=None):
38 | super().initialize(seed=seed)
39 | self.model.initialize(observation_space, action_space)
40 | self.replay.initialize(seed)
41 | self.exploration.initialize(self._policy, action_space, seed)
42 | self.actor_updater.initialize(self.model)
43 | self.critic_updater.initialize(self.model)
44 |
45 | def step(self, observations, steps):
46 | # Get actions from the actor and exploration method.
47 | actions = self.exploration(observations, steps)
48 |
49 | # Keep some values for the next update.
50 | self.last_observations = observations.copy()
51 | self.last_actions = actions.copy()
52 |
53 | return actions
54 |
55 | def test_step(self, observations, steps):
56 | # Greedy actions for testing.
57 | return self._greedy_actions(observations).numpy()
58 |
59 | def update(self, observations, rewards, resets, terminations, steps):
60 | # Store the last transitions in the replay.
61 | self.replay.store(
62 | observations=self.last_observations, actions=self.last_actions,
63 | next_observations=observations, rewards=rewards, resets=resets,
64 | terminations=terminations)
65 |
66 | # Prepare to update the normalizers.
67 | if self.model.observation_normalizer:
68 | self.model.observation_normalizer.record(self.last_observations)
69 | if self.model.return_normalizer:
70 | self.model.return_normalizer.record(rewards)
71 |
72 | # Update the model if the replay is ready.
73 | if self.replay.ready(steps):
74 | self._update(steps)
75 |
76 | self.exploration.update(resets)
77 |
78 | @tf.function
79 | def _greedy_actions(self, observations):
80 | return self.model.actor(observations)
81 |
82 | def _policy(self, observations):
83 | return self._greedy_actions(observations).numpy()
84 |
85 | def _update(self, steps):
86 | keys = ('observations', 'actions', 'next_observations', 'rewards',
87 | 'discounts')
88 |
89 | # Update both the actor and the critic multiple times.
90 | for batch in self.replay.get(*keys, steps=steps):
91 | infos = self._update_actor_critic(**batch)
92 |
93 | for key in infos:
94 | for k, v in infos[key].items():
95 | logger.store(key + '/' + k, v.numpy())
96 |
97 | # Update the normalizers.
98 | if self.model.observation_normalizer:
99 | self.model.observation_normalizer.update()
100 | if self.model.return_normalizer:
101 | self.model.return_normalizer.update()
102 |
103 | @tf.function
104 | def _update_actor_critic(
105 | self, observations, actions, next_observations, rewards, discounts
106 | ):
107 | critic_infos = self.critic_updater(
108 | observations, actions, next_observations, rewards, discounts)
109 | actor_infos = self.actor_updater(observations)
110 | self.model.update_targets()
111 | return dict(critic=critic_infos, actor=actor_infos)
112 |
--------------------------------------------------------------------------------
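
Because every sub-module of default_model() is passed in explicitly, a variant model can be assembled from the same building blocks and handed to the agent. A small sketch (not part of the repository) using a smaller tanh torso:

from tonic.tensorflow import agents, models, normalizers

custom_model = models.ActorCriticWithTargets(
    actor=models.Actor(
        encoder=models.ObservationEncoder(),
        torso=models.MLP((64, 64), 'tanh'),
        head=models.DeterministicPolicyHead()),
    critic=models.Critic(
        encoder=models.ObservationActionEncoder(),
        torso=models.MLP((64, 64), 'tanh'),
        head=models.ValueHead()),
    observation_normalizer=normalizers.MeanStd())

agent = agents.DDPG(model=custom_model)
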
/tonic/tensorflow/agents/mpo.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tonic import logger, replays
4 | from tonic.tensorflow import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorCriticWithTargets(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((256, 256), 'relu'),
12 | head=models.GaussianPolicyHead()),
13 | critic=models.Critic(
14 | encoder=models.ObservationActionEncoder(),
15 | torso=models.MLP((256, 256), 'relu'),
16 | head=models.ValueHead()),
17 | observation_normalizer=normalizers.MeanStd())
18 |
19 |
20 | class MPO(agents.Agent):
21 | '''Maximum a Posteriori Policy Optimisation.
22 | MPO: https://arxiv.org/pdf/1806.06920.pdf
23 | MO-MPO: https://arxiv.org/pdf/2005.07513.pdf
24 | '''
25 |
26 | def __init__(
27 | self, model=None, replay=None, actor_updater=None, critic_updater=None
28 | ):
29 | self.model = model or default_model()
30 | self.replay = replay or replays.Buffer(return_steps=5)
31 | self.actor_updater = actor_updater or \
32 | updaters.MaximumAPosterioriPolicyOptimization()
33 | self.critic_updater = critic_updater or updaters.ExpectedSARSA()
34 |
35 | def initialize(self, observation_space, action_space, seed=None):
36 | super().initialize(seed=seed)
37 | self.model.initialize(observation_space, action_space)
38 | self.replay.initialize(seed)
39 | self.actor_updater.initialize(self.model, action_space)
40 | self.critic_updater.initialize(self.model)
41 |
42 | def step(self, observations, steps):
43 | actions = self._step(observations)
44 | actions = actions.numpy()
45 |
46 | # Keep some values for the next update.
47 | self.last_observations = observations.copy()
48 | self.last_actions = actions.copy()
49 |
50 | return actions
51 |
52 | def test_step(self, observations, steps):
53 | # Sample actions for testing.
54 | return self._test_step(observations).numpy()
55 |
56 | def update(self, observations, rewards, resets, terminations, steps):
57 | # Store the last transitions in the replay.
58 | self.replay.store(
59 | observations=self.last_observations, actions=self.last_actions,
60 | next_observations=observations, rewards=rewards, resets=resets,
61 | terminations=terminations)
62 |
63 | # Prepare to update the normalizers.
64 | if self.model.observation_normalizer:
65 | self.model.observation_normalizer.record(self.last_observations)
66 | if self.model.return_normalizer:
67 | self.model.return_normalizer.record(rewards)
68 |
69 | # Update the model if the replay is ready.
70 | if self.replay.ready(steps):
71 | self._update(steps)
72 |
73 | @tf.function
74 | def _step(self, observations):
75 | return self.model.actor(observations).sample()
76 |
77 | @tf.function
78 | def _test_step(self, observations):
79 | return self.model.actor(observations).mode()
80 |
81 | def _update(self, steps):
82 | keys = ('observations', 'actions', 'next_observations', 'rewards',
83 | 'discounts')
84 |
85 | # Update both the actor and the critic multiple times.
86 | for batch in self.replay.get(*keys, steps=steps):
87 | infos = self._update_actor_critic(**batch)
88 |
89 | for key in infos:
90 | for k, v in infos[key].items():
91 | logger.store(key + '/' + k, v.numpy())
92 |
93 | # Update the normalizers.
94 | if self.model.observation_normalizer:
95 | self.model.observation_normalizer.update()
96 | if self.model.return_normalizer:
97 | self.model.return_normalizer.update()
98 |
99 | @tf.function
100 | def _update_actor_critic(
101 | self, observations, actions, next_observations, rewards, discounts
102 | ):
103 | critic_infos = self.critic_updater(
104 | observations, actions, next_observations, rewards, discounts)
105 | actor_infos = self.actor_updater(observations)
106 | self.model.update_targets()
107 | return dict(critic=critic_infos, actor=actor_infos)
108 |
--------------------------------------------------------------------------------
/tonic/tensorflow/agents/ppo.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tonic import logger
4 | from tonic.tensorflow import agents, updaters
5 |
6 |
7 | class PPO(agents.A2C):
8 | '''Proximal Policy Optimization.
9 | PPO: https://arxiv.org/pdf/1707.06347.pdf
10 | '''
11 |
12 | def __init__(
13 | self, model=None, replay=None, actor_updater=None, critic_updater=None
14 | ):
15 | actor_updater = actor_updater or updaters.ClippedRatio()
16 | super().__init__(
17 | model=model, replay=replay, actor_updater=actor_updater,
18 | critic_updater=critic_updater)
19 |
20 | def _update(self):
21 | # Compute the lambda-returns.
22 | batch = self.replay.get_full('observations', 'next_observations')
23 | values, next_values = self._evaluate(**batch)
24 | values, next_values = values.numpy(), next_values.numpy()
25 | self.replay.compute_returns(values, next_values)
26 |
27 | train_actor = True
28 | actor_iterations = 0
29 | critic_iterations = 0
30 | keys = 'observations', 'actions', 'advantages', 'log_probs', 'returns'
31 |
32 | # Update both the actor and the critic multiple times.
33 | for batch in self.replay.get(*keys):
34 | if train_actor:
35 | infos = self._update_actor_critic(**batch)
36 | actor_iterations += 1
37 | else:
38 | batch = {k: batch[k] for k in ('observations', 'returns')}
39 | infos = dict(critic=self.critic_updater(**batch))
40 | critic_iterations += 1
41 |
42 |             # Stop training the actor early when the updater signals it.
43 | if train_actor:
44 | train_actor = not infos['actor']['stop'].numpy()
45 |
46 | for key in infos:
47 | for k, v in infos[key].items():
48 | logger.store(key + '/' + k, v.numpy())
49 |

50 | logger.store('actor/iterations', actor_iterations)
51 | logger.store('critic/iterations', critic_iterations)
52 |
53 | # Update the normalizers.
54 | if self.model.observation_normalizer:
55 | self.model.observation_normalizer.update()
56 | if self.model.return_normalizer:
57 | self.model.return_normalizer.update()
58 |
59 | @tf.function
60 | def _update_actor_critic(
61 | self, observations, actions, advantages, log_probs, returns
62 | ):
63 | actor_infos = self.actor_updater(
64 | observations, actions, advantages, log_probs)
65 | critic_infos = self.critic_updater(observations, returns)
66 | return dict(actor=actor_infos, critic=critic_infos)
67 |
--------------------------------------------------------------------------------
/tonic/tensorflow/agents/sac.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tonic import explorations
4 | from tonic.tensorflow import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorTwinCriticWithTargets(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((256, 256), 'relu'),
12 | head=models.GaussianPolicyHead(
13 | loc_activation=None,
14 | distribution=models.SquashedMultivariateNormalDiag)),
15 | critic=models.Critic(
16 | encoder=models.ObservationActionEncoder(),
17 | torso=models.MLP((256, 256), 'relu'),
18 | head=models.ValueHead()),
19 | observation_normalizer=normalizers.MeanStd())
20 |
21 |
22 | class SAC(agents.DDPG):
23 | '''Soft Actor-Critic.
24 | SAC: https://arxiv.org/pdf/1801.01290.pdf
25 | '''
26 |
27 | def __init__(
28 | self, model=None, replay=None, exploration=None, actor_updater=None,
29 | critic_updater=None
30 | ):
31 | model = model or default_model()
32 | exploration = exploration or explorations.NoActionNoise()
33 | actor_updater = actor_updater or \
34 | updaters.TwinCriticSoftDeterministicPolicyGradient()
35 | critic_updater = critic_updater or updaters.TwinCriticSoftQLearning()
36 | super().__init__(
37 | model=model, replay=replay, exploration=exploration,
38 | actor_updater=actor_updater, critic_updater=critic_updater)
39 |
40 | @tf.function
41 | def _stochastic_actions(self, observations):
42 | return self.model.actor(observations).sample()
43 |
44 | def _policy(self, observations):
45 | return self._stochastic_actions(observations).numpy()
46 |
47 | @tf.function
48 | def _greedy_actions(self, observations):
49 | return self.model.actor(observations).mode()
50 |
--------------------------------------------------------------------------------
/tonic/tensorflow/agents/td3.py:
--------------------------------------------------------------------------------
1 | from tonic import logger
2 | from tonic.tensorflow import agents, models, normalizers, updaters
3 |
4 |
5 | def default_model():
6 | return models.ActorTwinCriticWithTargets(
7 | actor=models.Actor(
8 | encoder=models.ObservationEncoder(),
9 | torso=models.MLP((256, 256), 'relu'),
10 | head=models.DeterministicPolicyHead()),
11 | critic=models.Critic(
12 | encoder=models.ObservationActionEncoder(),
13 | torso=models.MLP((256, 256), 'relu'),
14 | head=models.ValueHead()),
15 | observation_normalizer=normalizers.MeanStd())
16 |
17 |
18 | class TD3(agents.DDPG):
19 | '''Twin Delayed Deep Deterministic Policy Gradient.
20 | TD3: https://arxiv.org/pdf/1802.09477.pdf
21 | '''
22 |
23 | def __init__(
24 | self, model=None, replay=None, exploration=None, actor_updater=None,
25 | critic_updater=None, delay_steps=2
26 | ):
27 | model = model or default_model()
28 | critic_updater = critic_updater or \
29 | updaters.TwinCriticDeterministicQLearning()
30 | super().__init__(
31 | model=model, replay=replay, exploration=exploration,
32 | actor_updater=actor_updater, critic_updater=critic_updater)
33 | self.delay_steps = delay_steps
34 | self.model.critic = self.model.critic_1
35 |
36 | def _update(self, steps):
37 | keys = ('observations', 'actions', 'next_observations', 'rewards',
38 | 'discounts')
39 | for i, batch in enumerate(self.replay.get(*keys, steps=steps)):
40 | if (i + 1) % self.delay_steps == 0:
41 | infos = self._update_actor_critic(**batch)
42 | else:
43 | infos = dict(critic=self.critic_updater(**batch))
44 | for key in infos:
45 | for k, v in infos[key].items():
46 | logger.store(key + '/' + k, v.numpy())
47 |
48 | # Update the normalizers.
49 | if self.model.observation_normalizer:
50 | self.model.observation_normalizer.update()
51 | if self.model.return_normalizer:
52 | self.model.return_normalizer.update()
53 |
--------------------------------------------------------------------------------
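
The `(i + 1) % self.delay_steps` test in _update above means the actor and target networks are refreshed once every delay_steps mini-batches, while the twin critics train on every batch. A toy illustration of the schedule with delay_steps=2:

delay_steps = 2
for i in range(6):
    if (i + 1) % delay_steps == 0:
        print(i, 'critic + actor + target update')
    else:
        print(i, 'critic-only update')
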
/tonic/tensorflow/agents/td4.py:
--------------------------------------------------------------------------------
1 | from tonic import replays
2 | from tonic.tensorflow import agents, models, normalizers, updaters
3 |
4 |
5 | def default_model():
6 | return models.ActorTwinCriticWithTargets(
7 | actor=models.Actor(
8 | encoder=models.ObservationEncoder(),
9 | torso=models.MLP((256, 256), 'relu'),
10 | head=models.DeterministicPolicyHead()),
11 | critic=models.Critic(
12 | encoder=models.ObservationActionEncoder(),
13 | torso=models.MLP((256, 256), 'relu'),
14 | head=models.DistributionalValueHead(-150., 150., 51)),
15 | observation_normalizer=normalizers.MeanStd())
16 |
17 |
18 | class TD4(agents.TD3):
19 | def __init__(
20 | self, model=None, replay=None, exploration=None, actor_updater=None,
21 | critic_updater=None, delay_steps=2
22 | ):
23 | model = model or default_model()
24 | replay = replay or replays.Buffer(return_steps=5)
25 | actor_updater = actor_updater or \
26 | updaters.DistributionalDeterministicPolicyGradient()
27 | critic_updater = critic_updater or \
28 | updaters.TwinCriticDistributionalDeterministicQLearning()
29 | super().__init__(
30 | model=model, replay=replay, exploration=exploration,
31 | actor_updater=actor_updater, critic_updater=critic_updater,
32 | delay_steps=delay_steps)
33 |
--------------------------------------------------------------------------------
/tonic/tensorflow/agents/trpo.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tonic import logger
4 | from tonic.tensorflow import agents, updaters
5 |
6 |
7 | class TRPO(agents.A2C):
8 | '''Trust Region Policy Optimization.
9 | TRPO: https://arxiv.org/pdf/1502.05477.pdf
10 | '''
11 |
12 | def __init__(
13 | self, model=None, replay=None, actor_updater=None, critic_updater=None
14 | ):
15 | actor_updater = actor_updater or updaters.TrustRegionPolicyGradient()
16 | super().__init__(
17 | model=model, replay=replay, actor_updater=actor_updater,
18 | critic_updater=critic_updater)
19 |
20 | def step(self, observations, steps):
21 | # Sample actions and get their log-probabilities for training.
22 | actions, log_probs, locs, scales = self._step(observations)
23 | actions = actions.numpy()
24 | log_probs = log_probs.numpy()
25 | locs = locs.numpy()
26 | scales = scales.numpy()
27 |
28 | # Keep some values for the next update.
29 | self.last_observations = observations.copy()
30 | self.last_actions = actions.copy()
31 | self.last_log_probs = log_probs.copy()
32 | self.last_locs = locs.copy()
33 | self.last_scales = scales.copy()
34 |
35 | return actions
36 |
37 | def update(self, observations, rewards, resets, terminations, steps):
38 | # Store the last transitions in the replay.
39 | self.replay.store(
40 | observations=self.last_observations, actions=self.last_actions,
41 | next_observations=observations, rewards=rewards, resets=resets,
42 | terminations=terminations, log_probs=self.last_log_probs,
43 | locs=self.last_locs, scales=self.last_scales)
44 |
45 | # Prepare to update the normalizers.
46 | if self.model.observation_normalizer:
47 | self.model.observation_normalizer.record(self.last_observations)
48 | if self.model.return_normalizer:
49 | self.model.return_normalizer.record(rewards)
50 |
51 | # Update the model if the replay is ready.
52 | if self.replay.ready():
53 | self._update()
54 |
55 | @tf.function
56 | def _step(self, observations):
57 | distributions = self.model.actor(observations)
58 | if hasattr(distributions, 'sample_with_log_prob'):
59 | actions, log_probs = distributions.sample_with_log_prob()
60 | else:
61 | actions = distributions.sample()
62 | log_probs = distributions.log_prob(actions)
63 | locs = distributions.loc
64 | scales = distributions.stddev()
65 | return actions, log_probs, locs, scales
66 |
67 | def _update(self):
68 | # Compute the lambda-returns.
69 | batch = self.replay.get_full('observations', 'next_observations')
70 | values, next_values = self._evaluate(**batch)
71 | values, next_values = values.numpy(), next_values.numpy()
72 | self.replay.compute_returns(values, next_values)
73 |
74 | keys = ('observations', 'actions', 'log_probs', 'locs', 'scales',
75 | 'advantages')
76 | batch = self.replay.get_full(*keys)
77 | infos = self.actor_updater(**batch)
78 | for k, v in infos.items():
79 | logger.store('actor/' + k, v.numpy())
80 |
81 | critic_iterations = 0
82 | for batch in self.replay.get('observations', 'returns'):
83 | infos = self.critic_updater(**batch)
84 | critic_iterations += 1
85 | for k, v in infos.items():
86 | logger.store('critic/' + k, v.numpy())
87 | logger.store('critic/iterations', critic_iterations)
88 |
89 | # Update the normalizers.
90 | if self.model.observation_normalizer:
91 | self.model.observation_normalizer.update()
92 | if self.model.return_normalizer:
93 | self.model.return_normalizer.update()
94 |
--------------------------------------------------------------------------------
/tonic/tensorflow/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor_critics import ActorCritic
2 | from .actor_critics import ActorCriticWithTargets
3 | from .actor_critics import ActorTwinCriticWithTargets
4 |
5 | from .actors import Actor
6 | from .actors import DetachedScaleGaussianPolicyHead
7 | from .actors import DeterministicPolicyHead
8 | from .actors import GaussianPolicyHead
9 | from .actors import SquashedMultivariateNormalDiag
10 |
11 | from .critics import Critic, DistributionalValueHead, ValueHead
12 |
13 | from .encoders import ObservationActionEncoder, ObservationEncoder
14 |
15 | from .utils import default_dense_kwargs, MLP
16 |
17 |
18 | __all__ = [
19 | default_dense_kwargs, MLP, ObservationActionEncoder,
20 | ObservationEncoder, SquashedMultivariateNormalDiag,
21 | DetachedScaleGaussianPolicyHead, GaussianPolicyHead,
22 | DeterministicPolicyHead, Actor, Critic, DistributionalValueHead,
23 | ValueHead, ActorCritic, ActorCriticWithTargets, ActorTwinCriticWithTargets]
24 |
--------------------------------------------------------------------------------
/tonic/tensorflow/models/actor_critics.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import tensorflow as tf
4 |
5 |
6 | class ActorCritic(tf.keras.Model):
7 | def __init__(
8 | self, actor, critic, observation_normalizer=None,
9 | return_normalizer=None
10 | ):
11 | super().__init__()
12 | self.actor = actor
13 | self.critic = critic
14 | self.observation_normalizer = observation_normalizer
15 | self.return_normalizer = return_normalizer
16 |
17 | def initialize(self, observation_space, action_space):
18 | if self.observation_normalizer:
19 | self.observation_normalizer.initialize(observation_space.shape)
20 | self.actor.initialize(
21 | observation_space, action_space, self.observation_normalizer)
22 | self.critic.initialize(
23 | observation_space, action_space, self.observation_normalizer,
24 | self.return_normalizer)
25 | dummy_observations = tf.zeros((1,) + observation_space.shape)
26 | self.actor(dummy_observations)
27 | self.critic(dummy_observations)
28 |
29 |
30 | class ActorCriticWithTargets(tf.keras.Model):
31 | def __init__(
32 | self, actor, critic, observation_normalizer=None,
33 | return_normalizer=None, target_coeff=0.005
34 | ):
35 | super().__init__()
36 | self.actor = actor
37 | self.critic = critic
38 | self.target_actor = copy.deepcopy(actor)
39 | self.target_critic = copy.deepcopy(critic)
40 | self.observation_normalizer = observation_normalizer
41 | self.return_normalizer = return_normalizer
42 | self.target_coeff = target_coeff
43 |
44 | def initialize(self, observation_space, action_space):
45 | if self.observation_normalizer:
46 | self.observation_normalizer.initialize(observation_space.shape)
47 | self.actor.initialize(
48 | observation_space, action_space, self.observation_normalizer)
49 | self.critic.initialize(
50 | observation_space, action_space, self.observation_normalizer,
51 | self.return_normalizer)
52 | self.target_actor.initialize(
53 | observation_space, action_space, self.observation_normalizer)
54 | self.target_critic.initialize(
55 | observation_space, action_space, self.observation_normalizer,
56 | self.return_normalizer)
57 | dummy_observations = tf.zeros((1,) + observation_space.shape)
58 | dummy_actions = tf.zeros((1,) + action_space.shape)
59 | self.actor(dummy_observations)
60 | self.critic(dummy_observations, dummy_actions)
61 | self.target_actor(dummy_observations)
62 | self.target_critic(dummy_observations, dummy_actions)
63 | self.online_variables = (
64 | self.actor.trainable_variables +
65 | self.critic.trainable_variables)
66 | self.target_variables = (
67 | self.target_actor.trainable_variables +
68 | self.target_critic.trainable_variables)
69 | self.assign_targets()
70 |
71 | def assign_targets(self):
72 | for o, t in zip(self.online_variables, self.target_variables):
73 | t.assign(o)
74 |
75 | def update_targets(self):
76 | for o, t in zip(self.online_variables, self.target_variables):
77 | t.assign((1 - self.target_coeff) * t + self.target_coeff * o)
78 |
79 |
80 | class ActorTwinCriticWithTargets(tf.keras.Model):
81 | def __init__(
82 | self, actor, critic, observation_normalizer=None,
83 | return_normalizer=None, target_coeff=0.005
84 | ):
85 | super().__init__()
86 | self.actor = actor
87 | self.critic_1 = critic
88 | self.critic_2 = copy.deepcopy(critic)
89 | self.target_actor = copy.deepcopy(actor)
90 | self.target_critic_1 = copy.deepcopy(critic)
91 | self.target_critic_2 = copy.deepcopy(critic)
92 | self.observation_normalizer = observation_normalizer
93 | self.return_normalizer = return_normalizer
94 | self.target_coeff = target_coeff
95 |
96 | def initialize(self, observation_space, action_space):
97 | if self.observation_normalizer:
98 | self.observation_normalizer.initialize(observation_space.shape)
99 | self.actor.initialize(
100 | observation_space, action_space, self.observation_normalizer)
101 | self.critic_1.initialize(
102 | observation_space, action_space, self.observation_normalizer,
103 | self.return_normalizer)
104 | self.critic_2.initialize(
105 | observation_space, action_space, self.observation_normalizer,
106 | self.return_normalizer)
107 | self.target_actor.initialize(
108 | observation_space, action_space, self.observation_normalizer)
109 | self.target_critic_1.initialize(
110 | observation_space, action_space, self.observation_normalizer,
111 | self.return_normalizer)
112 | self.target_critic_2.initialize(
113 | observation_space, action_space, self.observation_normalizer,
114 | self.return_normalizer)
115 | dummy_observations = tf.zeros((1,) + observation_space.shape)
116 | dummy_actions = tf.zeros((1,) + action_space.shape)
117 | self.actor(dummy_observations)
118 | self.critic_1(dummy_observations, dummy_actions)
119 | self.critic_2(dummy_observations, dummy_actions)
120 | self.target_actor(dummy_observations)
121 | self.target_critic_1(dummy_observations, dummy_actions)
122 | self.target_critic_2(dummy_observations, dummy_actions)
123 | self.online_variables = (
124 | self.actor.trainable_variables +
125 | self.critic_1.trainable_variables +
126 | self.critic_2.trainable_variables)
127 | self.target_variables = (
128 | self.target_actor.trainable_variables +
129 | self.target_critic_1.trainable_variables +
130 | self.target_critic_2.trainable_variables)
131 | self.assign_targets()
132 |
133 | def assign_targets(self):
134 | for o, t in zip(self.online_variables, self.target_variables):
135 | t.assign(o)
136 |
137 | def update_targets(self):
138 | for o, t in zip(self.online_variables, self.target_variables):
139 | t.assign((1 - self.target_coeff) * t + self.target_coeff * o)
140 |
--------------------------------------------------------------------------------
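
update_targets above performs an exponential moving average of the online parameters (Polyak averaging). A standalone numpy sketch of the same rule, showing how slowly the targets track the online values with the default coefficient:

import numpy as np

target_coeff = 0.005
online = np.array([1.0, 2.0, 3.0])
target = np.zeros_like(online)

for _ in range(1000):
    # t <- (1 - coeff) * t + coeff * o, exactly as in update_targets().
    target = (1 - target_coeff) * target + target_coeff * online

print(target)  # about 99% of the way to the online values after 1000 updates
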
/tonic/tensorflow/models/actors.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow_probability as tfp
3 |
4 | from tonic.tensorflow import models
5 |
6 |
7 | FLOAT_EPSILON = 1e-8
8 |
9 |
10 | class SquashedMultivariateNormalDiag:
11 | def __init__(self, loc, scale):
12 | self._distribution = tfp.distributions.MultivariateNormalDiag(
13 | loc, scale)
14 |
15 | def sample_with_log_prob(self, shape=()):
16 | samples = self._distribution.sample(shape)
17 | squashed_samples = tf.tanh(samples)
18 | log_probs = self._distribution.log_prob(samples)
19 | log_probs -= tf.reduce_sum(
20 | tf.math.log(1 - squashed_samples ** 2 + 1e-6), axis=-1)
21 | return squashed_samples, log_probs
22 |
23 | def sample(self, shape=()):
24 | samples = self._distribution.sample(shape)
25 | return tf.tanh(samples)
26 |
27 | def log_prob(self, samples):
28 |         '''The unsquashed samples required here cannot be accurately recovered.'''
29 | raise NotImplementedError(
30 | 'Not implemented to avoid approximation errors. '
31 | 'Use sample_with_log_prob directly.')
32 |
33 | def mode(self):
34 | return tf.tanh(self._distribution.mode())
35 |
36 |
37 | class DetachedScaleGaussianPolicyHead(tf.keras.Model):
38 | def __init__(
39 | self, loc_activation='tanh', dense_loc_kwargs=None, log_scale_init=0.,
40 | scale_min=1e-4, scale_max=1.,
41 | distribution=tfp.distributions.MultivariateNormalDiag
42 | ):
43 | super().__init__()
44 | self.loc_activation = loc_activation
45 | if dense_loc_kwargs is None:
46 | dense_loc_kwargs = models.default_dense_kwargs()
47 | self.dense_loc_kwargs = dense_loc_kwargs
48 | self.log_scale_init = log_scale_init
49 | self.scale_min = scale_min
50 | self.scale_max = scale_max
51 | self.distribution = distribution
52 |
53 | def initialize(self, action_size):
54 | self.loc_layer = tf.keras.layers.Dense(
55 | action_size, self.loc_activation, **self.dense_loc_kwargs)
56 | log_scale = [[self.log_scale_init] * action_size]
57 | self.log_scale = tf.Variable(log_scale, dtype=tf.float32)
58 |
59 | def call(self, inputs):
60 | loc = self.loc_layer(inputs)
61 | batch_size = tf.shape(inputs)[0]
62 | scale = tf.math.softplus(self.log_scale) + FLOAT_EPSILON
63 | scale = tf.clip_by_value(scale, self.scale_min, self.scale_max)
64 | scale = tf.tile(scale, (batch_size, 1))
65 | return self.distribution(loc, scale)
66 |
67 |
68 | class GaussianPolicyHead(tf.keras.Model):
69 | def __init__(
70 | self, loc_activation='tanh', dense_loc_kwargs=None,
71 | scale_activation='softplus', scale_min=1e-4, scale_max=1,
72 | dense_scale_kwargs=None,
73 | distribution=tfp.distributions.MultivariateNormalDiag
74 | ):
75 | super().__init__()
76 | self.loc_activation = loc_activation
77 | if dense_loc_kwargs is None:
78 | dense_loc_kwargs = models.default_dense_kwargs()
79 | self.dense_loc_kwargs = dense_loc_kwargs
80 | self.scale_activation = scale_activation
81 | self.scale_min = scale_min
82 | self.scale_max = scale_max
83 | if dense_scale_kwargs is None:
84 | dense_scale_kwargs = models.default_dense_kwargs()
85 | self.dense_scale_kwargs = dense_scale_kwargs
86 | self.distribution = distribution
87 |
88 | def initialize(self, action_size):
89 | self.loc_layer = tf.keras.layers.Dense(
90 | action_size, self.loc_activation, **self.dense_loc_kwargs)
91 | self.scale_layer = tf.keras.layers.Dense(
92 | action_size, self.scale_activation, **self.dense_scale_kwargs)
93 |
94 | def call(self, inputs):
95 | loc = self.loc_layer(inputs)
96 | scale = self.scale_layer(inputs)
97 | scale = tf.clip_by_value(scale, self.scale_min, self.scale_max)
98 | return self.distribution(loc, scale)
99 |
100 |
101 | class DeterministicPolicyHead(tf.keras.Model):
102 | def __init__(self, activation='tanh', dense_kwargs=None):
103 | super().__init__()
104 | self.activation = activation
105 | if dense_kwargs is None:
106 | dense_kwargs = models.default_dense_kwargs()
107 | self.dense_kwargs = dense_kwargs
108 |
109 | def initialize(self, action_size):
110 | self.action_layer = tf.keras.layers.Dense(
111 | action_size, self.activation, **self.dense_kwargs)
112 |
113 | def call(self, inputs):
114 | return self.action_layer(inputs)
115 |
116 |
117 | class Actor(tf.keras.Model):
118 | def __init__(self, encoder, torso, head):
119 | super().__init__()
120 | self.encoder = encoder
121 | self.torso = torso
122 | self.head = head
123 |
124 | def initialize(
125 | self, observation_space, action_space, observation_normalizer=None
126 | ):
127 | self.encoder.initialize(observation_normalizer)
128 | self.head.initialize(action_space.shape[0])
129 |
130 | def call(self, *inputs):
131 | out = self.encoder(*inputs)
132 | out = self.torso(out)
133 | return self.head(out)
134 |
--------------------------------------------------------------------------------
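
The log-probability correction in SquashedMultivariateNormalDiag.sample_with_log_prob comes from the change of variables a = tanh(u): log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2). A small numpy sketch of the same arithmetic for a zero-mean, unit-scale diagonal Gaussian, with made-up sample values and without the 1e-6 stabiliser used above:

import numpy as np

u = np.array([0.3, -1.2])                      # pre-squash sample, made up
a = np.tanh(u)                                 # squashed action in (-1, 1)

# Diagonal standard-normal log-density of the unsquashed sample.
log_p_u = np.sum(-0.5 * (u ** 2 + np.log(2 * np.pi)))

# Change-of-variables correction, as in sample_with_log_prob above.
log_p_a = log_p_u - np.sum(np.log(1 - a ** 2))

print(log_p_a)
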
/tonic/tensorflow/models/critics.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tonic.tensorflow import models
4 |
5 |
6 | class ValueHead(tf.keras.Model):
7 | def __init__(self, dense_kwargs=None):
8 | super().__init__()
9 | if dense_kwargs is None:
10 | dense_kwargs = models.default_dense_kwargs()
11 | self.v_layer = tf.keras.layers.Dense(1, **dense_kwargs)
12 |
13 | def initialize(self, return_normalizer=None):
14 | self.return_normalizer = return_normalizer
15 |
16 | def call(self, inputs):
17 | out = self.v_layer(inputs)
18 | out = tf.squeeze(out, -1)
19 | if self.return_normalizer:
20 | out = self.return_normalizer(out)
21 | return out
22 |
23 |
24 | class CategoricalWithSupport:
25 | def __init__(self, values, logits):
26 | self.values = values
27 | self.logits = logits
28 | self.probabilities = tf.nn.softmax(logits)
29 |
30 | def mean(self):
31 | return tf.reduce_sum(self.probabilities * self.values, axis=-1)
32 |
33 | def project(self, returns):
34 | vmin, vmax = self.values[0], self.values[-1]
35 | d_pos = tf.concat([self.values, vmin[None]], 0)[1:]
36 | d_pos = (d_pos - self.values)[None, :, None]
37 | d_neg = tf.concat([vmax[None], self.values], 0)[:-1]
38 | d_neg = (self.values - d_neg)[None, :, None]
39 |
40 | clipped_returns = tf.clip_by_value(returns, vmin, vmax)
41 | delta_values = clipped_returns[:, None] - self.values[None, :, None]
42 | delta_sign = tf.cast(delta_values >= 0, tf.float32)
43 | delta_hat = ((delta_sign * delta_values / d_pos) -
44 | ((1 - delta_sign) * delta_values / d_neg))
45 | delta_clipped = tf.clip_by_value(1 - delta_hat, 0, 1)
46 |
47 | return tf.reduce_sum(delta_clipped * self.probabilities[:, None], 2)
48 |
49 |
50 | class DistributionalValueHead(tf.keras.Model):
51 | def __init__(self, vmin, vmax, num_atoms, dense_kwargs=None):
52 | super().__init__()
53 | if dense_kwargs is None:
54 | dense_kwargs = models.default_dense_kwargs()
55 | self.distributional_layer = tf.keras.layers.Dense(
56 | num_atoms, **dense_kwargs)
57 | self.values = tf.cast(tf.linspace(vmin, vmax, num_atoms), tf.float32)
58 |
59 | def initialize(self, return_normalizer=None):
60 | if return_normalizer:
61 | raise ValueError(
62 |                 'Return normalizers cannot be used with distributional '
63 |                 'value heads.')
64 |
65 | def call(self, inputs):
66 | logits = self.distributional_layer(inputs)
67 | return CategoricalWithSupport(values=self.values, logits=logits)
68 |
69 |
70 | class Critic(tf.keras.Model):
71 | def __init__(self, encoder, torso, head):
72 | super().__init__()
73 | self.encoder = encoder
74 | self.torso = torso
75 | self.head = head
76 |
77 | def initialize(
78 | self, observation_space, action_space, observation_normalizer=None,
79 | return_normalizer=None
80 | ):
81 | self.encoder.initialize(observation_normalizer)
82 | self.head.initialize(return_normalizer)
83 |
84 | def call(self, *inputs):
85 | out = self.encoder(*inputs)
86 | out = self.torso(out)
87 | return self.head(out)
88 |
--------------------------------------------------------------------------------
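
CategoricalWithSupport.project above redistributes the probability mass of a batch of per-atom Bellman targets back onto the fixed support, as in categorical (C51-style) value learning. A shape-oriented sketch with made-up numbers; the class is imported from its module directly since it is not re-exported from the models package:

import tensorflow as tf

from tonic.tensorflow.models.critics import CategoricalWithSupport

values = tf.linspace(-1., 1., 5)               # support of 5 atoms
logits = tf.zeros((2, 5))                      # batch of 2 uniform distributions
distribution = CategoricalWithSupport(values=values, logits=logits)

# One shifted return per atom and per batch element, as in the updaters.
returns = tf.constant([[-1., -.5, 0., .5, 1.],
                       [-.2, -.1, 0., .1, .2]])
targets = distribution.project(returns)

print(targets.shape)                           # (2, 5)
print(tf.reduce_sum(targets, axis=-1))         # each row sums to 1
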
/tonic/tensorflow/models/encoders.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class ObservationEncoder(tf.keras.Model):
5 | def initialize(self, observation_normalizer=None):
6 | self.observation_normalizer = observation_normalizer
7 |
8 | def call(self, observations):
9 | if self.observation_normalizer:
10 | observations = self.observation_normalizer(observations)
11 | return observations
12 |
13 |
14 | class ObservationActionEncoder(tf.keras.Model):
15 | def initialize(self, observation_normalizer=None):
16 | self.observation_normalizer = observation_normalizer
17 |
18 | def call(self, observations, actions):
19 | if self.observation_normalizer:
20 | observations = self.observation_normalizer(observations)
21 | return tf.concat([observations, actions], axis=-1)
22 |
--------------------------------------------------------------------------------
/tonic/tensorflow/models/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def default_dense_kwargs():
5 | return dict(
6 | kernel_initializer=tf.keras.initializers.VarianceScaling(
7 | scale=1 / 3, mode='fan_in', distribution='uniform'),
8 | bias_initializer=tf.keras.initializers.VarianceScaling(
9 | scale=1 / 3, mode='fan_in', distribution='uniform'))
10 |
11 |
12 | def mlp(units, activation, dense_kwargs=None):
13 | if dense_kwargs is None:
14 | dense_kwargs = default_dense_kwargs()
15 | layers = [tf.keras.layers.Dense(u, activation, **dense_kwargs)
16 | for u in units]
17 | return tf.keras.Sequential(layers)
18 |
19 |
20 | MLP = mlp
21 |
--------------------------------------------------------------------------------
/tonic/tensorflow/normalizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .mean_stds import MeanStd
2 | from .returns import Return
3 |
4 |
5 | __all__ = [MeanStd, Return]
6 |
--------------------------------------------------------------------------------
/tonic/tensorflow/normalizers/mean_stds.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 |
5 | class MeanStd(tf.keras.Model):
6 | def __init__(self, mean=0, std=1, clip=None, shape=None):
7 | super().__init__(name='global_mean_std')
8 | self.mean = mean
9 | self.std = std
10 | self.clip = clip
11 | self.count = 0
12 | self.new_sum = 0
13 | self.new_sum_sq = 0
14 | self.new_count = 0
15 | self.eps = 1e-2
16 | if shape:
17 | self.initialize(shape)
18 |
19 | def initialize(self, shape):
20 | if isinstance(self.mean, (int, float)):
21 | self.mean = np.full(shape, self.mean, np.float32)
22 | else:
23 | self.mean = np.array(self.mean, np.float32)
24 | if isinstance(self.std, (int, float)):
25 | self.std = np.full(shape, self.std, np.float32)
26 | else:
27 | self.std = np.array(self.std, np.float32)
28 | self.mean_sq = np.square(self.mean)
29 | self._mean = tf.Variable(self.mean, trainable=False, name='mean')
30 | self._std = tf.Variable(self.std, trainable=False, name='std')
31 |
32 | def call(self, val):
33 | val = (val - self._mean) / self._std
34 | if self.clip is not None:
35 | val = tf.clip_by_value(val, -self.clip, self.clip)
36 | return val
37 |
38 | def unnormalize(self, val):
39 | return val * self._std + self._mean
40 |
41 | def record(self, values):
42 | for val in values:
43 | self.new_sum += val
44 | self.new_sum_sq += np.square(val)
45 | self.new_count += 1
46 |
47 | # Careful: do not use in @tf.function
48 | def update(self):
49 | new_count = self.count + self.new_count
50 | new_mean = self.new_sum / self.new_count
51 | new_mean_sq = self.new_sum_sq / self.new_count
52 | w_old = self.count / new_count
53 | w_new = self.new_count / new_count
54 | self.mean = w_old * self.mean + w_new * new_mean
55 | self.mean_sq = w_old * self.mean_sq + w_new * new_mean_sq
56 | self.std = self._compute_std(self.mean, self.mean_sq)
57 | self.count = new_count
58 | self.new_count = 0
59 | self.new_sum = 0
60 | self.new_sum_sq = 0
61 | self._update(self.mean.astype(np.float32), self.std.astype(np.float32))
62 |
63 | def _compute_std(self, mean, mean_sq):
64 | var = mean_sq - np.square(mean)
65 | var = np.maximum(var, 0)
66 | std = np.sqrt(var)
67 | std = np.maximum(std, self.eps)
68 | return std
69 |
70 | @tf.function
71 | def _update(self, mean, std):
72 | self._mean.assign(mean)
73 | self._std.assign(std)
74 |
--------------------------------------------------------------------------------
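
MeanStd above merges batch statistics into running totals, so after each update the normalizer matches the mean and standard deviation of everything recorded so far. A quick numpy sanity sketch with made-up random data:

import numpy as np

from tonic.tensorflow import normalizers

normalizer = normalizers.MeanStd(shape=(3,))
data = np.random.randn(100, 3).astype(np.float32)

# Record in two chunks and merge after each one.
normalizer.record(data[:40])
normalizer.update()
normalizer.record(data[40:])
normalizer.update()

print(np.allclose(normalizer.mean, data.mean(axis=0), atol=1e-3))  # True
print(np.allclose(normalizer.std, data.std(axis=0), atol=1e-3))    # True
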
/tonic/tensorflow/normalizers/returns.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 |
5 | class Return(tf.keras.Model):
6 | def __init__(self, discount_factor):
7 | super().__init__(name='reward_normalizer')
8 | assert 0 <= discount_factor < 1
9 | self.coefficient = 1 / (1 - discount_factor)
10 | self.min_reward = np.float32(-1)
11 | self.max_reward = np.float32(1)
12 | self._low = tf.Variable(
13 | self.coefficient * self.min_reward, dtype=tf.float32,
14 | trainable=False, name='low')
15 | self._high = tf.Variable(
16 |             self.coefficient * self.max_reward, dtype=tf.float32,
17 | trainable=False, name='high')
18 |
19 | def call(self, val):
20 | val = tf.sigmoid(val)
21 | return self._low + val * (self._high - self._low)
22 |
23 | def record(self, values):
24 | for val in values:
25 | if val < self.min_reward:
26 | self.min_reward = np.float32(val)
27 | elif val > self.max_reward:
28 | self.max_reward = np.float32(val)
29 |
30 | # Careful: do not use in @tf.function
31 | def update(self):
32 | self._update(self.min_reward, self.max_reward)
33 |
34 | @tf.function
35 | def _update(self, min_reward, max_reward):
36 | self._low.assign(self.coefficient * min_reward)
37 | self._high.assign(self.coefficient * max_reward)
38 |
--------------------------------------------------------------------------------
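
Return.call above rescales a sigmoid-squashed prediction into the range of returns achievable under the discount, [min_reward / (1 - gamma), max_reward / (1 - gamma)]. A tiny numeric sketch of the same mapping with made-up numbers:

import numpy as np

discount_factor = 0.99
coefficient = 1 / (1 - discount_factor)        # 100
low = coefficient * -1.0                       # initial min_reward is -1
high = coefficient * 1.0                       # initial max_reward is +1

raw_value = 0.0
squashed = 1 / (1 + np.exp(-raw_value))        # sigmoid
print(low + squashed * (high - low))           # 0.0, the midpoint of [-100, 100]
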
/tonic/tensorflow/updaters/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import merge_first_two_dims
2 | from .utils import tile
3 |
4 | from .actors import ClippedRatio # noqa
5 | from .actors import DeterministicPolicyGradient
6 | from .actors import DistributionalDeterministicPolicyGradient
7 | from .actors import MaximumAPosterioriPolicyOptimization
8 | from .actors import StochasticPolicyGradient
9 | from .actors import TrustRegionPolicyGradient
10 | from .actors import TwinCriticSoftDeterministicPolicyGradient
11 |
12 | from .critics import DeterministicQLearning
13 | from .critics import DistributionalDeterministicQLearning
14 | from .critics import ExpectedSARSA
15 | from .critics import QRegression
16 | from .critics import TargetActionNoise
17 | from .critics import TwinCriticDeterministicQLearning
18 | from .critics import TwinCriticDistributionalDeterministicQLearning
19 | from .critics import TwinCriticSoftQLearning
20 | from .critics import VRegression
21 |
22 | from .optimizers import ConjugateGradient
23 |
24 |
25 | __all__ = [
26 | merge_first_two_dims, tile, ClippedRatio, DeterministicPolicyGradient,
27 | DistributionalDeterministicPolicyGradient,
28 | MaximumAPosterioriPolicyOptimization, StochasticPolicyGradient,
29 | TrustRegionPolicyGradient, TwinCriticSoftDeterministicPolicyGradient,
30 | DeterministicQLearning, DistributionalDeterministicQLearning,
31 | ExpectedSARSA, QRegression, TargetActionNoise,
32 | TwinCriticDeterministicQLearning,
33 | TwinCriticDistributionalDeterministicQLearning, TwinCriticSoftQLearning,
34 | VRegression, ConjugateGradient]
35 |
--------------------------------------------------------------------------------
/tonic/tensorflow/updaters/critics.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tonic.tensorflow import updaters
4 |
5 |
6 | class VRegression:
7 | def __init__(self, loss=None, optimizer=None, gradient_clip=0):
8 | self.loss = loss or tf.keras.losses.MeanSquaredError()
9 | self.optimizer = optimizer or \
10 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8)
11 | self.gradient_clip = gradient_clip
12 |
13 | def initialize(self, model):
14 | self.model = model
15 | self.variables = self.model.critic.trainable_variables
16 |
17 | @tf.function
18 | def __call__(self, observations, returns):
19 | with tf.GradientTape() as tape:
20 | values = self.model.critic(observations)
21 | loss = self.loss(returns, values)
22 |
23 | gradients = tape.gradient(loss, self.variables)
24 | if self.gradient_clip > 0:
25 | gradients = tf.clip_by_global_norm(
26 | gradients, self.gradient_clip)[0]
27 | self.optimizer.apply_gradients(zip(gradients, self.variables))
28 |
29 | return dict(loss=loss, v=values)
30 |
31 |
32 | class QRegression:
33 | def __init__(self, loss=None, optimizer=None, gradient_clip=0):
34 | self.loss = loss or tf.keras.losses.MeanSquaredError()
35 | self.optimizer = optimizer or \
36 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8)
37 | self.gradient_clip = gradient_clip
38 |
39 | def initialize(self, model):
40 | self.model = model
41 | self.variables = self.model.critic.trainable_variables
42 |
43 | @tf.function
44 | def __call__(self, observations, actions, returns):
45 | with tf.GradientTape() as tape:
46 | values = self.model.critic(observations, actions)
47 | loss = self.loss(returns, values)
48 |
49 | gradients = tape.gradient(loss, self.variables)
50 | if self.gradient_clip > 0:
51 | gradients = tf.clip_by_global_norm(
52 | gradients, self.gradient_clip)[0]
53 | self.optimizer.apply_gradients(zip(gradients, self.variables))
54 |
55 | return dict(loss=loss, q=values)
56 |
57 |
58 | class DeterministicQLearning:
59 | def __init__(self, loss=None, optimizer=None, gradient_clip=0):
60 | self.loss = loss or tf.keras.losses.MeanSquaredError()
61 | self.optimizer = optimizer or \
62 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8)
63 | self.gradient_clip = gradient_clip
64 |
65 | def initialize(self, model):
66 | self.model = model
67 | self.variables = self.model.critic.trainable_variables
68 |
69 | @tf.function
70 | def __call__(
71 | self, observations, actions, next_observations, rewards, discounts
72 | ):
73 | next_actions = self.model.target_actor(next_observations)
74 | next_values = self.model.target_critic(next_observations, next_actions)
75 | returns = rewards + discounts * next_values
76 |
77 | with tf.GradientTape() as tape:
78 | values = self.model.critic(observations, actions)
79 | loss = self.loss(returns, values)
80 |
81 | gradients = tape.gradient(loss, self.variables)
82 | if self.gradient_clip > 0:
83 | gradients = tf.clip_by_global_norm(
84 | gradients, self.gradient_clip)[0]
85 | self.optimizer.apply_gradients(zip(gradients, self.variables))
86 |
87 | return dict(loss=loss, q=values)
88 |
89 |
90 | class DistributionalDeterministicQLearning:
91 | def __init__(self, optimizer=None, gradient_clip=0):
92 | self.optimizer = optimizer or \
93 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8)
94 | self.gradient_clip = gradient_clip
95 |
96 | def initialize(self, model):
97 | self.model = model
98 | self.variables = self.model.critic.trainable_variables
99 |
100 | @tf.function
101 | def __call__(
102 | self, observations, actions, next_observations, rewards, discounts
103 | ):
104 | next_actions = self.model.target_actor(next_observations)
105 | next_value_distributions = self.model.target_critic(
106 | next_observations, next_actions)
107 |
108 | values = next_value_distributions.values
109 | returns = rewards[:, None] + discounts[:, None] * values
110 | targets = next_value_distributions.project(returns)
111 |
112 | with tf.GradientTape() as tape:
113 | value_distributions = self.model.critic(observations, actions)
114 | losses = tf.nn.softmax_cross_entropy_with_logits(
115 | logits=value_distributions.logits, labels=targets)
116 | loss = tf.reduce_mean(losses)
117 |
118 | gradients = tape.gradient(loss, self.variables)
119 | if self.gradient_clip > 0:
120 | gradients = tf.clip_by_global_norm(
121 | gradients, self.gradient_clip)[0]
122 | self.optimizer.apply_gradients(zip(gradients, self.variables))
123 |
124 | return dict(loss=loss)
125 |
126 |
127 | class TargetActionNoise:
128 | def __init__(self, scale=0.2, clip=0.5):
129 | self.scale = scale
130 | self.clip = clip
131 |
132 | def __call__(self, actions):
133 | noises = self.scale * tf.random.normal(actions.shape)
134 | noises = tf.clip_by_value(noises, -self.clip, self.clip)
135 | actions = actions + noises
136 | return tf.clip_by_value(actions, -1, 1)
137 |
138 |
139 | class TwinCriticDeterministicQLearning:
140 | def __init__(
141 | self, loss=None, optimizer=None, target_action_noise=None,
142 | gradient_clip=0
143 | ):
144 | self.loss = loss or tf.keras.losses.MeanSquaredError()
145 | self.optimizer = optimizer or \
146 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8)
147 | self.target_action_noise = target_action_noise or \
148 | TargetActionNoise(scale=0.2, clip=0.5)
149 | self.gradient_clip = gradient_clip
150 |
151 | def initialize(self, model):
152 | self.model = model
153 | variables_1 = self.model.critic_1.trainable_variables
154 | variables_2 = self.model.critic_2.trainable_variables
155 | self.variables = variables_1 + variables_2
156 |
157 | @tf.function
158 | def __call__(
159 | self, observations, actions, next_observations, rewards, discounts
160 | ):
161 | next_actions = self.model.target_actor(next_observations)
162 | next_actions = self.target_action_noise(next_actions)
163 | next_values_1 = self.model.target_critic_1(
164 | next_observations, next_actions)
165 | next_values_2 = self.model.target_critic_2(
166 | next_observations, next_actions)
167 | next_values = tf.minimum(next_values_1, next_values_2)
168 | returns = rewards + discounts * next_values
169 |
170 | with tf.GradientTape() as tape:
171 | values_1 = self.model.critic_1(observations, actions)
172 | values_2 = self.model.critic_2(observations, actions)
173 | loss_1 = self.loss(returns, values_1)
174 | loss_2 = self.loss(returns, values_2)
175 | loss = loss_1 + loss_2
176 |
177 | gradients = tape.gradient(loss, self.variables)
178 | if self.gradient_clip > 0:
179 | gradients = tf.clip_by_global_norm(
180 | gradients, self.gradient_clip)[0]
181 | self.optimizer.apply_gradients(zip(gradients, self.variables))
182 |
183 | return dict(loss=loss, q1=values_1, q2=values_2)
184 |
185 |
186 | class TwinCriticSoftQLearning:
187 | def __init__(
188 | self, loss=None, optimizer=None, entropy_coeff=0.2, gradient_clip=0
189 | ):
190 | self.loss = loss or tf.keras.losses.MeanSquaredError()
191 | self.optimizer = optimizer or \
192 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8)
193 | self.entropy_coeff = entropy_coeff
194 | self.gradient_clip = gradient_clip
195 |
196 | def initialize(self, model):
197 | self.model = model
198 | variables_1 = self.model.critic_1.trainable_variables
199 | variables_2 = self.model.critic_2.trainable_variables
200 | self.variables = variables_1 + variables_2
201 |
202 | @tf.function
203 | def __call__(
204 | self, observations, actions, next_observations, rewards, discounts
205 | ):
206 | next_distributions = self.model.actor(next_observations)
207 | if hasattr(next_distributions, 'sample_with_log_prob'):
208 | outs = next_distributions.sample_with_log_prob()
209 | next_actions, next_log_probs = outs
210 | else:
211 | next_actions = next_distributions.sample()
212 | next_log_probs = next_distributions.log_prob(next_actions)
213 | next_values_1 = self.model.target_critic_1(
214 | next_observations, next_actions)
215 | next_values_2 = self.model.target_critic_2(
216 | next_observations, next_actions)
217 | next_values = tf.minimum(next_values_1, next_values_2)
218 | returns = rewards + discounts * (
219 | next_values - self.entropy_coeff * next_log_probs)
220 |
221 | with tf.GradientTape() as tape:
222 | values_1 = self.model.critic_1(observations, actions)
223 | values_2 = self.model.critic_2(observations, actions)
224 | loss_1 = self.loss(returns, values_1)
225 | loss_2 = self.loss(returns, values_2)
226 | loss = loss_1 + loss_2
227 |
228 | gradients = tape.gradient(loss, self.variables)
229 | if self.gradient_clip > 0:
230 | gradients = tf.clip_by_global_norm(
231 | gradients, self.gradient_clip)[0]
232 | self.optimizer.apply_gradients(zip(gradients, self.variables))
233 |
234 | return dict(loss=loss, q1=values_1, q2=values_2)
235 |
236 |
237 | class ExpectedSARSA:
238 | def __init__(
239 | self, num_samples=20, loss=None, optimizer=None, gradient_clip=0
240 | ):
241 | self.num_samples = num_samples
242 | self.loss = loss or tf.keras.losses.MeanSquaredError()
243 | self.optimizer = optimizer or \
244 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8)
245 | self.gradient_clip = gradient_clip
246 |
247 | def initialize(self, model):
248 | self.model = model
249 | self.variables = self.model.critic.trainable_variables
250 |
251 | @tf.function
252 | def __call__(
253 | self, observations, actions, next_observations, rewards, discounts
254 | ):
255 | # Approximate the expected next values.
256 | next_target_distributions = self.model.target_actor(next_observations)
257 | next_actions = next_target_distributions.sample(self.num_samples)
258 | next_actions = updaters.merge_first_two_dims(next_actions)
259 | next_observations = updaters.tile(next_observations, self.num_samples)
260 | next_observations = updaters.merge_first_two_dims(next_observations)
261 | next_values = self.model.target_critic(next_observations, next_actions)
262 | next_values = tf.reshape(next_values, (self.num_samples, -1))
263 | next_values = tf.reduce_mean(next_values, axis=0)
264 | returns = rewards + discounts * next_values
265 |
266 | with tf.GradientTape() as tape:
267 | values = self.model.critic(observations, actions)
268 | loss = self.loss(returns, values)
269 |
270 | gradients = tape.gradient(loss, self.variables)
271 | if self.gradient_clip > 0:
272 | gradients = tf.clip_by_global_norm(
273 | gradients, self.gradient_clip)[0]
274 | self.optimizer.apply_gradients(zip(gradients, self.variables))
275 |
276 | return dict(loss=loss, q=values)
277 |
278 |
279 | class TwinCriticDistributionalDeterministicQLearning:
280 | def __init__(
281 | self, optimizer=None, target_action_noise=None, gradient_clip=0
282 | ):
283 | self.optimizer = optimizer or \
284 | tf.keras.optimizers.Adam(lr=1e-3, epsilon=1e-8)
285 | self.target_action_noise = target_action_noise or \
286 | TargetActionNoise(scale=0.2, clip=0.5)
287 | self.gradient_clip = gradient_clip
288 |
289 | def initialize(self, model):
290 | self.model = model
291 | variables_1 = self.model.critic_1.trainable_variables
292 | variables_2 = self.model.critic_2.trainable_variables
293 | self.variables = variables_1 + variables_2
294 |
295 | @tf.function
296 | def __call__(
297 | self, observations, actions, next_observations, rewards, discounts
298 | ):
299 | next_actions = self.model.target_actor(next_observations)
300 | next_actions = self.target_action_noise(next_actions)
301 | next_value_distributions_1 = self.model.target_critic_1(
302 | next_observations, next_actions)
303 | next_value_distributions_2 = self.model.target_critic_2(
304 | next_observations, next_actions)
305 |
306 | values = next_value_distributions_1.values
307 | returns = rewards[:, None] + discounts[:, None] * values
308 | targets_1 = next_value_distributions_1.project(returns)
309 | targets_2 = next_value_distributions_2.project(returns)
310 | next_values_1 = next_value_distributions_1.mean()
311 | next_values_2 = next_value_distributions_2.mean()
312 | twin_next_values = tf.concat(
313 | [next_values_1[None], next_values_2[None]], axis=0)
314 | indices = tf.argmin(twin_next_values, axis=0, output_type=tf.int32)
315 | twin_targets = tf.concat([targets_1[None], targets_2[None]], axis=0)
316 | batch_size = tf.shape(observations)[0]
317 | indices = tf.stack([indices, tf.range(batch_size)], axis=-1)
318 | targets = tf.gather_nd(twin_targets, indices)
319 |
320 | with tf.GradientTape() as tape:
321 | value_distributions_1 = self.model.critic_1(observations, actions)
322 | losses_1 = tf.nn.softmax_cross_entropy_with_logits(
323 | logits=value_distributions_1.logits, labels=targets)
324 | value_distributions_2 = self.model.critic_2(observations, actions)
325 | losses_2 = tf.nn.softmax_cross_entropy_with_logits(
326 | logits=value_distributions_2.logits, labels=targets)
327 | loss = tf.reduce_mean(losses_1) + tf.reduce_mean(losses_2)
328 |
329 | gradients = tape.gradient(loss, self.variables)
330 | if self.gradient_clip > 0:
331 | gradients = tf.clip_by_global_norm(
332 | gradients, self.gradient_clip)[0]
333 | self.optimizer.apply_gradients(zip(gradients, self.variables))
334 |
335 | return dict(loss=loss, q1=value_distributions_1.mean(),
336 | q2=value_distributions_2.mean())
337 |
--------------------------------------------------------------------------------
/tonic/tensorflow/updaters/optimizers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 |
5 | FLOAT_EPSILON = 1e-8
6 |
7 |
8 | def flat_concat(xs):
9 | return tf.concat([tf.reshape(x, (-1,)) for x in xs], axis=0)
10 |
11 |
12 | def assign_params_from_flat(new_params, params):
13 | def flat_size(p):
14 | return int(np.prod(p.shape.as_list()))
15 | splits = tf.split(new_params, [flat_size(p) for p in params])
16 | new_params = [tf.reshape(p_new, p.shape)
17 | for p, p_new in zip(params, splits)]
18 | for p, p_new in zip(params, new_params):
19 | p.assign(p_new)
20 |
21 |
22 | class ConjugateGradient:
23 | def __init__(
24 | self, conjugate_gradient_steps=10, damping_coefficient=0.1,
25 | constraint_threshold=0.01, backtrack_steps=10,
26 | backtrack_coefficient=0.8
27 | ):
28 | self.conjugate_gradient_steps = conjugate_gradient_steps
29 | self.damping_coefficient = damping_coefficient
30 | self.constraint_threshold = constraint_threshold
31 | self.backtrack_steps = backtrack_steps
32 | self.backtrack_coefficient = backtrack_coefficient
33 |
34 | def optimize(self, loss_function, constraint_function, variables):
35 | @tf.function
36 | def _hx(x):
37 | with tf.GradientTape() as tape_2:
38 | with tf.GradientTape() as tape_1:
39 | f = constraint_function()
40 | gradient_1 = flat_concat(tape_1.gradient(f, variables))
41 | y = tf.reduce_sum(gradient_1 * x)
42 | gradient_2 = flat_concat(tape_2.gradient(y, variables))
43 |
44 | if self.damping_coefficient > 0:
45 | gradient_2 += self.damping_coefficient * x
46 |
47 | return gradient_2
48 |
49 | def _cg(b):
50 | x = np.zeros_like(b)
51 | r = b.copy()
52 | p = r.copy()
53 | r_dot_old = np.dot(r, r)
54 | if r_dot_old == 0:
55 | return None
56 |
57 | for _ in range(self.conjugate_gradient_steps):
58 | z = _hx(p).numpy()
59 | alpha = r_dot_old / (np.dot(p, z) + FLOAT_EPSILON)
60 | x += alpha * p
61 | r -= alpha * z
62 | r_dot_new = np.dot(r, r)
63 | p = r + (r_dot_new / r_dot_old) * p
64 | r_dot_old = r_dot_new
65 | return x
66 |
67 | @tf.function
68 | def _update(alpha, conjugate_gradient, step, start_variables):
69 | new_variables = start_variables - alpha * conjugate_gradient * step
70 | assign_params_from_flat(new_variables, variables)
71 | constraint = constraint_function()
72 | loss = loss_function()
73 | return constraint, loss
74 |
75 | start_variables = flat_concat(variables)
76 |
77 | with tf.GradientTape() as tape:
78 | loss = loss_function()
79 | grad = flat_concat(tape.gradient(loss, variables)).numpy()
80 | start_loss = loss.numpy()
81 |
82 | conjugate_gradient = _cg(grad)
83 | if conjugate_gradient is None:
84 | constraint = tf.convert_to_tensor(0.)
85 | loss = tf.convert_to_tensor(0.)
86 | steps = tf.convert_to_tensor(0)
87 | return constraint, loss, steps
88 |
89 | alpha = np.sqrt(2 * self.constraint_threshold / np.dot(
90 | conjugate_gradient, _hx(conjugate_gradient)) + FLOAT_EPSILON)
91 | alpha = tf.convert_to_tensor(alpha, tf.float32)
92 |
93 | if self.backtrack_steps is None or self.backtrack_coefficient is None:
94 | constraint, loss = _update(
95 | alpha, conjugate_gradient, 1, start_variables)
96 |             return constraint, loss, tf.convert_to_tensor(1, dtype=tf.int32)
97 |
98 | for i in range(self.backtrack_steps):
99 | step = tf.convert_to_tensor(
100 | self.backtrack_coefficient ** i, tf.float32)
101 | constraint, loss = _update(
102 | alpha, conjugate_gradient, step, start_variables)
103 |
104 | if constraint <= self.constraint_threshold and loss <= start_loss:
105 | break
106 |
107 | if i == self.backtrack_steps - 1:
108 | step = tf.convert_to_tensor(0., tf.float32)
109 | constraint, loss = _update(
110 | alpha, conjugate_gradient, step, start_variables)
111 | i = self.backtrack_steps
112 |
113 | return constraint, loss, tf.convert_to_tensor(i + 1, dtype=tf.int32)
114 |
--------------------------------------------------------------------------------
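A minimal sketch of driving the ConjugateGradient optimizer above on its own, with toy quadratic loss and constraint functions standing in for the TRPO surrogate loss and KL divergence (everything below apart from the import path is hypothetical example code):

import tensorflow as tf

from tonic.tensorflow.updaters.optimizers import ConjugateGradient

variables = [tf.Variable([1.0, -2.0]), tf.Variable([0.5])]
anchors = [tf.constant(v.numpy()) for v in variables]


def loss_function():
    # Toy objective: pull every parameter towards zero.
    return tf.add_n([tf.reduce_sum(tf.square(v)) for v in variables])


def constraint_function():
    # Toy constraint: squared distance to the starting parameters, playing
    # the role of the KL divergence in TRPO.
    return tf.add_n([tf.reduce_sum(tf.square(v - a))
                     for v, a in zip(variables, anchors)])


optimizer = ConjugateGradient(constraint_threshold=0.01)
constraint, loss, steps = optimizer.optimize(
    loss_function, constraint_function, variables)
print(constraint.numpy(), loss.numpy(), steps.numpy())

The optimizer takes the largest backtracked step that keeps the constraint below constraint_threshold without increasing the loss, and reports how many backtracking steps it used.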
/tonic/tensorflow/updaters/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def tile(x, n):
5 | return tf.tile(x[None], [n] + [1] * len(x.shape))
6 |
7 |
8 | def merge_first_two_dims(x):
9 | return tf.reshape(x, [x.shape[0] * x.shape[1]] + x.shape[2:])
10 |
--------------------------------------------------------------------------------
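The two helpers above are what ExpectedSARSA uses to evaluate several sampled actions per observation; a quick shape check (illustrative only):

import tensorflow as tf

from tonic.tensorflow.updaters.utils import merge_first_two_dims, tile

observations = tf.zeros([4, 3])                 # batch of 4 observations
tiled = tile(observations, 20)                  # shape (20, 4, 3)
merged = merge_first_two_dims(tiled)            # shape (80, 3)
print(tiled.shape, merged.shape)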
/tonic/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from . import agents, models, normalizers, updaters
2 |
3 |
4 | __all__ = [agents, models, normalizers, updaters]
5 |
--------------------------------------------------------------------------------
/tonic/torch/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from .agent import Agent
2 |
3 | from .a2c import A2C # noqa
4 | from .ddpg import DDPG
5 | from .d4pg import D4PG # noqa
6 | from .mpo import MPO
7 | from .ppo import PPO
8 | from .sac import SAC
9 | from .td3 import TD3
10 | from .trpo import TRPO
11 |
12 |
13 | __all__ = [Agent, A2C, DDPG, D4PG, MPO, PPO, SAC, TD3, TRPO]
14 |
--------------------------------------------------------------------------------
/tonic/torch/agents/a2c.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tonic import logger, replays # noqa
4 | from tonic.torch import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorCritic(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((64, 64), torch.nn.Tanh),
12 | head=models.DetachedScaleGaussianPolicyHead()),
13 | critic=models.Critic(
14 | encoder=models.ObservationEncoder(),
15 | torso=models.MLP((64, 64), torch.nn.Tanh),
16 | head=models.ValueHead()),
17 | observation_normalizer=normalizers.MeanStd())
18 |
19 |
20 | class A2C(agents.Agent):
21 | '''Advantage Actor Critic (aka Vanilla Policy Gradient).
22 | A3C: https://arxiv.org/pdf/1602.01783.pdf
23 | '''
24 |
25 | def __init__(
26 | self, model=None, replay=None, actor_updater=None, critic_updater=None
27 | ):
28 | self.model = model or default_model()
29 | self.replay = replay or replays.Segment()
30 | self.actor_updater = actor_updater or \
31 | updaters.StochasticPolicyGradient()
32 | self.critic_updater = critic_updater or updaters.VRegression()
33 |
34 | def initialize(self, observation_space, action_space, seed=None):
35 | super().initialize(seed=seed)
36 | self.model.initialize(observation_space, action_space)
37 | self.replay.initialize(seed)
38 | self.actor_updater.initialize(self.model)
39 | self.critic_updater.initialize(self.model)
40 |
41 | def step(self, observations, steps):
42 | # Sample actions and get their log-probabilities for training.
43 | actions, log_probs = self._step(observations)
44 | actions = actions.numpy()
45 | log_probs = log_probs.numpy()
46 |
47 | # Keep some values for the next update.
48 | self.last_observations = observations.copy()
49 | self.last_actions = actions.copy()
50 | self.last_log_probs = log_probs.copy()
51 |
52 | return actions
53 |
54 | def test_step(self, observations, steps):
55 | # Sample actions for testing.
56 | return self._test_step(observations).numpy()
57 |
58 | def update(self, observations, rewards, resets, terminations, steps):
59 | # Store the last transitions in the replay.
60 | self.replay.store(
61 | observations=self.last_observations, actions=self.last_actions,
62 | next_observations=observations, rewards=rewards, resets=resets,
63 | terminations=terminations, log_probs=self.last_log_probs)
64 |
65 | # Prepare to update the normalizers.
66 | if self.model.observation_normalizer:
67 | self.model.observation_normalizer.record(self.last_observations)
68 | if self.model.return_normalizer:
69 | self.model.return_normalizer.record(rewards)
70 |
71 | # Update the model if the replay is ready.
72 | if self.replay.ready():
73 | self._update()
74 |
75 | def _step(self, observations):
76 | observations = torch.as_tensor(observations, dtype=torch.float32)
77 | with torch.no_grad():
78 | distributions = self.model.actor(observations)
79 | if hasattr(distributions, 'sample_with_log_prob'):
80 | actions, log_probs = distributions.sample_with_log_prob()
81 | else:
82 | actions = distributions.sample()
83 | log_probs = distributions.log_prob(actions)
84 | log_probs = log_probs.sum(dim=-1)
85 | return actions, log_probs
86 |
87 | def _test_step(self, observations):
88 | observations = torch.as_tensor(observations, dtype=torch.float32)
89 | with torch.no_grad():
90 | return self.model.actor(observations).sample()
91 |
92 | def _evaluate(self, observations, next_observations):
93 | observations = torch.as_tensor(observations, dtype=torch.float32)
94 | next_observations = torch.as_tensor(
95 | next_observations, dtype=torch.float32)
96 | with torch.no_grad():
97 | values = self.model.critic(observations)
98 | next_values = self.model.critic(next_observations)
99 | return values, next_values
100 |
101 | def _update(self):
102 | # Compute the lambda-returns.
103 | batch = self.replay.get_full('observations', 'next_observations')
104 | values, next_values = self._evaluate(**batch)
105 | values, next_values = values.numpy(), next_values.numpy()
106 | self.replay.compute_returns(values, next_values)
107 |
108 | # Update the actor once.
109 | keys = 'observations', 'actions', 'advantages', 'log_probs'
110 | batch = self.replay.get_full(*keys)
111 | batch = {k: torch.as_tensor(v) for k, v in batch.items()}
112 | infos = self.actor_updater(**batch)
113 | for k, v in infos.items():
114 | logger.store('actor/' + k, v.numpy())
115 |
116 | # Update the critic multiple times.
117 | for batch in self.replay.get('observations', 'returns'):
118 | batch = {k: torch.as_tensor(v) for k, v in batch.items()}
119 | infos = self.critic_updater(**batch)
120 | for k, v in infos.items():
121 | logger.store('critic/' + k, v.numpy())
122 |
123 | # Update the normalizers.
124 | if self.model.observation_normalizer:
125 | self.model.observation_normalizer.update()
126 | if self.model.return_normalizer:
127 | self.model.return_normalizer.update()
128 |
--------------------------------------------------------------------------------
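A minimal composition sketch showing how the modules above assemble into an agent outside of tonic.train; the Gym Box spaces are hypothetical placeholders for a real environment's spaces, and the model mirrors default_model() with a plain GaussianPolicyHead swapped in:

import gym.spaces
import numpy as np
import torch

from tonic.torch import agents, models, normalizers

observation_space = gym.spaces.Box(-np.inf, np.inf, (3,), dtype=np.float32)
action_space = gym.spaces.Box(-1., 1., (1,), dtype=np.float32)

agent = agents.A2C(
    model=models.ActorCritic(
        actor=models.Actor(
            encoder=models.ObservationEncoder(),
            torso=models.MLP((64, 64), torch.nn.Tanh),
            head=models.GaussianPolicyHead()),
        critic=models.Critic(
            encoder=models.ObservationEncoder(),
            torso=models.MLP((64, 64), torch.nn.Tanh),
            head=models.ValueHead()),
        observation_normalizer=normalizers.MeanStd()))
agent.initialize(observation_space, action_space, seed=0)

observations = np.zeros((1,) + observation_space.shape, np.float32)
actions = agent.step(observations, steps=0)     # shape (1, 1); log-probs are cached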
/tonic/torch/agents/agent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from tonic import agents, logger # noqa
8 |
9 |
10 | class Agent(agents.Agent):
11 | def initialize(self, seed=None):
12 | if seed is not None:
13 | np.random.seed(seed)
14 | random.seed(seed)
15 | torch.manual_seed(seed)
16 |
17 | def save(self, path):
18 | path = path + '.pt'
19 | logger.log(f'\nSaving weights to {path}')
20 | os.makedirs(os.path.dirname(path), exist_ok=True)
21 | torch.save(self.model.state_dict(), path)
22 |
23 | def load(self, path):
24 | path = path + '.pt'
25 | logger.log(f'\nLoading weights from {path}')
26 | self.model.load_state_dict(torch.load(path))
27 |
--------------------------------------------------------------------------------
/tonic/torch/agents/d4pg.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tonic import replays # noqa
4 | from tonic.torch import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorCriticWithTargets(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((256, 256), torch.nn.ReLU),
12 | head=models.DeterministicPolicyHead()),
13 | critic=models.Critic(
14 | encoder=models.ObservationActionEncoder(),
15 | torso=models.MLP((256, 256), torch.nn.ReLU),
16 | # These values are for the control suite with 0.99 discount.
17 | head=models.DistributionalValueHead(-150., 150., 51)),
18 | observation_normalizer=normalizers.MeanStd())
19 |
20 |
21 | class D4PG(agents.DDPG):
22 | '''Distributed Distributional Deterministic Policy Gradients.
23 | D4PG: https://arxiv.org/pdf/1804.08617.pdf
24 | '''
25 |
26 | def __init__(
27 | self, model=None, replay=None, exploration=None, actor_updater=None,
28 | critic_updater=None
29 | ):
30 | model = model or default_model()
31 | replay = replay or replays.Buffer(return_steps=5)
32 | actor_updater = actor_updater or \
33 | updaters.DistributionalDeterministicPolicyGradient()
34 | critic_updater = critic_updater or \
35 | updaters.DistributionalDeterministicQLearning()
36 | super().__init__(
37 | model, replay, exploration, actor_updater, critic_updater)
38 |
--------------------------------------------------------------------------------
/tonic/torch/agents/ddpg.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tonic import explorations, logger, replays # noqa
4 | from tonic.torch import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorCriticWithTargets(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((256, 256), torch.nn.ReLU),
12 | head=models.DeterministicPolicyHead()),
13 | critic=models.Critic(
14 | encoder=models.ObservationActionEncoder(),
15 | torso=models.MLP((256, 256), torch.nn.ReLU),
16 | head=models.ValueHead()),
17 | observation_normalizer=normalizers.MeanStd())
18 |
19 |
20 | class DDPG(agents.Agent):
21 | '''Deep Deterministic Policy Gradient.
22 | DDPG: https://arxiv.org/pdf/1509.02971.pdf
23 | '''
24 |
25 | def __init__(
26 | self, model=None, replay=None, exploration=None, actor_updater=None,
27 | critic_updater=None
28 | ):
29 | self.model = model or default_model()
30 | self.replay = replay or replays.Buffer()
31 | self.exploration = exploration or explorations.NormalActionNoise()
32 | self.actor_updater = actor_updater or \
33 | updaters.DeterministicPolicyGradient()
34 | self.critic_updater = critic_updater or \
35 | updaters.DeterministicQLearning()
36 |
37 | def initialize(self, observation_space, action_space, seed=None):
38 | super().initialize(seed=seed)
39 | self.model.initialize(observation_space, action_space)
40 | self.replay.initialize(seed)
41 | self.exploration.initialize(self._policy, action_space, seed)
42 | self.actor_updater.initialize(self.model)
43 | self.critic_updater.initialize(self.model)
44 |
45 | def step(self, observations, steps):
46 | # Get actions from the actor and exploration method.
47 | actions = self.exploration(observations, steps)
48 |
49 | # Keep some values for the next update.
50 | self.last_observations = observations.copy()
51 | self.last_actions = actions.copy()
52 |
53 | return actions
54 |
55 | def test_step(self, observations, steps):
56 | # Greedy actions for testing.
57 | return self._greedy_actions(observations).numpy()
58 |
59 | def update(self, observations, rewards, resets, terminations, steps):
60 | # Store the last transitions in the replay.
61 | self.replay.store(
62 | observations=self.last_observations, actions=self.last_actions,
63 | next_observations=observations, rewards=rewards, resets=resets,
64 | terminations=terminations)
65 |
66 | # Prepare to update the normalizers.
67 | if self.model.observation_normalizer:
68 | self.model.observation_normalizer.record(self.last_observations)
69 | if self.model.return_normalizer:
70 | self.model.return_normalizer.record(rewards)
71 |
72 | # Update the model if the replay is ready.
73 | if self.replay.ready(steps):
74 | self._update(steps)
75 |
76 | self.exploration.update(resets)
77 |
78 | def _greedy_actions(self, observations):
79 | observations = torch.as_tensor(observations, dtype=torch.float32)
80 | with torch.no_grad():
81 | return self.model.actor(observations)
82 |
83 | def _policy(self, observations):
84 | return self._greedy_actions(observations).numpy()
85 |
86 | def _update(self, steps):
87 | keys = ('observations', 'actions', 'next_observations', 'rewards',
88 | 'discounts')
89 |
90 | # Update both the actor and the critic multiple times.
91 | for batch in self.replay.get(*keys, steps=steps):
92 | batch = {k: torch.as_tensor(v) for k, v in batch.items()}
93 | infos = self._update_actor_critic(**batch)
94 |
95 | for key in infos:
96 | for k, v in infos[key].items():
97 | logger.store(key + '/' + k, v.numpy())
98 |
99 | # Update the normalizers.
100 | if self.model.observation_normalizer:
101 | self.model.observation_normalizer.update()
102 | if self.model.return_normalizer:
103 | self.model.return_normalizer.update()
104 |
105 | def _update_actor_critic(
106 | self, observations, actions, next_observations, rewards, discounts
107 | ):
108 | critic_infos = self.critic_updater(
109 | observations, actions, next_observations, rewards, discounts)
110 | actor_infos = self.actor_updater(observations)
111 | self.model.update_targets()
112 | return dict(critic=critic_infos, actor=actor_infos)
113 |
--------------------------------------------------------------------------------
/tonic/torch/agents/mpo.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tonic import logger, replays # noqa
4 | from tonic.torch import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorCriticWithTargets(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((256, 256), torch.nn.ReLU),
12 | head=models.GaussianPolicyHead()),
13 | critic=models.Critic(
14 | encoder=models.ObservationActionEncoder(),
15 | torso=models.MLP((256, 256), torch.nn.ReLU),
16 | head=models.ValueHead()),
17 | observation_normalizer=normalizers.MeanStd())
18 |
19 |
20 | class MPO(agents.Agent):
21 | '''Maximum a Posteriori Policy Optimisation.
22 | MPO: https://arxiv.org/pdf/1806.06920.pdf
23 | MO-MPO: https://arxiv.org/pdf/2005.07513.pdf
24 | '''
25 |
26 | def __init__(
27 | self, model=None, replay=None, actor_updater=None, critic_updater=None
28 | ):
29 | self.model = model or default_model()
30 | self.replay = replay or replays.Buffer(return_steps=5)
31 | self.actor_updater = actor_updater or \
32 | updaters.MaximumAPosterioriPolicyOptimization()
33 | self.critic_updater = critic_updater or updaters.ExpectedSARSA()
34 |
35 | def initialize(self, observation_space, action_space, seed=None):
36 | super().initialize(seed=seed)
37 | self.model.initialize(observation_space, action_space)
38 | self.replay.initialize(seed)
39 | self.actor_updater.initialize(self.model, action_space)
40 | self.critic_updater.initialize(self.model)
41 |
42 | def step(self, observations, steps):
43 | actions = self._step(observations)
44 | actions = actions.numpy()
45 |
46 | # Keep some values for the next update.
47 | self.last_observations = observations.copy()
48 | self.last_actions = actions.copy()
49 |
50 | return actions
51 |
52 | def test_step(self, observations, steps):
53 | # Sample actions for testing.
54 | return self._test_step(observations).numpy()
55 |
56 | def update(self, observations, rewards, resets, terminations, steps):
57 | # Store the last transitions in the replay.
58 | self.replay.store(
59 | observations=self.last_observations, actions=self.last_actions,
60 | next_observations=observations, rewards=rewards, resets=resets,
61 | terminations=terminations)
62 |
63 | # Prepare to update the normalizers.
64 | if self.model.observation_normalizer:
65 | self.model.observation_normalizer.record(self.last_observations)
66 | if self.model.return_normalizer:
67 | self.model.return_normalizer.record(rewards)
68 |
69 | # Update the model if the replay is ready.
70 | if self.replay.ready(steps):
71 | self._update(steps)
72 |
73 | def _step(self, observations):
74 | observations = torch.as_tensor(observations, dtype=torch.float32)
75 | with torch.no_grad():
76 | return self.model.actor(observations).sample()
77 |
78 | def _test_step(self, observations):
79 | observations = torch.as_tensor(observations, dtype=torch.float32)
80 | with torch.no_grad():
81 | return self.model.actor(observations).loc
82 |
83 | def _update(self, steps):
84 | keys = ('observations', 'actions', 'next_observations', 'rewards',
85 | 'discounts')
86 |
87 | # Update both the actor and the critic multiple times.
88 | for batch in self.replay.get(*keys, steps=steps):
89 | batch = {k: torch.as_tensor(v) for k, v in batch.items()}
90 | infos = self._update_actor_critic(**batch)
91 |
92 | for key in infos:
93 | for k, v in infos[key].items():
94 | logger.store(key + '/' + k, v.numpy())
95 |
96 | # Update the normalizers.
97 | if self.model.observation_normalizer:
98 | self.model.observation_normalizer.update()
99 | if self.model.return_normalizer:
100 | self.model.return_normalizer.update()
101 |
102 | def _update_actor_critic(
103 | self, observations, actions, next_observations, rewards, discounts
104 | ):
105 | critic_infos = self.critic_updater(
106 | observations, actions, next_observations, rewards, discounts)
107 | actor_infos = self.actor_updater(observations)
108 | self.model.update_targets()
109 | return dict(critic=critic_infos, actor=actor_infos)
110 |
--------------------------------------------------------------------------------
/tonic/torch/agents/ppo.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tonic import logger # noqa
4 | from tonic.torch import agents, updaters
5 |
6 |
7 | class PPO(agents.A2C):
8 | '''Proximal Policy Optimization.
9 | PPO: https://arxiv.org/pdf/1707.06347.pdf
10 | '''
11 |
12 | def __init__(
13 | self, model=None, replay=None, actor_updater=None, critic_updater=None
14 | ):
15 | actor_updater = actor_updater or updaters.ClippedRatio()
16 | super().__init__(
17 | model=model, replay=replay, actor_updater=actor_updater,
18 | critic_updater=critic_updater)
19 |
20 | def _update(self):
21 | # Compute the lambda-returns.
22 | batch = self.replay.get_full('observations', 'next_observations')
23 | values, next_values = self._evaluate(**batch)
24 | values, next_values = values.numpy(), next_values.numpy()
25 | self.replay.compute_returns(values, next_values)
26 |
27 | train_actor = True
28 | actor_iterations = 0
29 | critic_iterations = 0
30 | keys = 'observations', 'actions', 'advantages', 'log_probs', 'returns'
31 |
32 | # Update both the actor and the critic multiple times.
33 | for batch in self.replay.get(*keys):
34 | if train_actor:
35 | batch = {k: torch.as_tensor(v) for k, v in batch.items()}
36 | infos = self._update_actor_critic(**batch)
37 | actor_iterations += 1
38 | else:
39 | batch = {k: torch.as_tensor(batch[k])
40 | for k in ('observations', 'returns')}
41 | infos = dict(critic=self.critic_updater(**batch))
42 | critic_iterations += 1
43 |
44 |             # Stop training the actor early if the updater signals it.
45 | if train_actor:
46 | train_actor = not infos['actor']['stop'].numpy()
47 |
48 | for key in infos:
49 | for k, v in infos[key].items():
50 | logger.store(key + '/' + k, v.numpy())
51 |
52 | logger.store('actor/iterations', actor_iterations)
53 | logger.store('critic/iterations', critic_iterations)
54 |
55 | # Update the normalizers.
56 | if self.model.observation_normalizer:
57 | self.model.observation_normalizer.update()
58 | if self.model.return_normalizer:
59 | self.model.return_normalizer.update()
60 |
61 | def _update_actor_critic(
62 | self, observations, actions, advantages, log_probs, returns
63 | ):
64 | actor_infos = self.actor_updater(
65 | observations, actions, advantages, log_probs)
66 | critic_infos = self.critic_updater(observations, returns)
67 | return dict(actor=actor_infos, critic=critic_infos)
68 |
--------------------------------------------------------------------------------
/tonic/torch/agents/sac.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tonic import explorations # noqa
4 | from tonic.torch import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorTwinCriticWithTargets(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((256, 256), torch.nn.ReLU),
12 | head=models.GaussianPolicyHead(
13 | loc_activation=torch.nn.Identity,
14 | distribution=models.SquashedMultivariateNormalDiag)),
15 | critic=models.Critic(
16 | encoder=models.ObservationActionEncoder(),
17 | torso=models.MLP((256, 256), torch.nn.ReLU),
18 | head=models.ValueHead()),
19 | observation_normalizer=normalizers.MeanStd())
20 |
21 |
22 | class SAC(agents.DDPG):
23 | '''Soft Actor-Critic.
24 | SAC: https://arxiv.org/pdf/1801.01290.pdf
25 | '''
26 |
27 | def __init__(
28 | self, model=None, replay=None, exploration=None, actor_updater=None,
29 | critic_updater=None
30 | ):
31 | model = model or default_model()
32 | exploration = exploration or explorations.NoActionNoise()
33 | actor_updater = actor_updater or \
34 | updaters.TwinCriticSoftDeterministicPolicyGradient()
35 | critic_updater = critic_updater or updaters.TwinCriticSoftQLearning()
36 | super().__init__(
37 | model=model, replay=replay, exploration=exploration,
38 | actor_updater=actor_updater, critic_updater=critic_updater)
39 |
40 | def _stochastic_actions(self, observations):
41 | observations = torch.as_tensor(observations, dtype=torch.float32)
42 | with torch.no_grad():
43 | return self.model.actor(observations).sample()
44 |
45 | def _policy(self, observations):
46 | return self._stochastic_actions(observations).numpy()
47 |
48 | def _greedy_actions(self, observations):
49 | observations = torch.as_tensor(observations, dtype=torch.float32)
50 | with torch.no_grad():
51 | return self.model.actor(observations).loc
52 |
--------------------------------------------------------------------------------
/tonic/torch/agents/td3.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tonic import logger # noqa
4 | from tonic.torch import agents, models, normalizers, updaters
5 |
6 |
7 | def default_model():
8 | return models.ActorTwinCriticWithTargets(
9 | actor=models.Actor(
10 | encoder=models.ObservationEncoder(),
11 | torso=models.MLP((256, 256), torch.nn.ReLU),
12 | head=models.DeterministicPolicyHead()),
13 | critic=models.Critic(
14 | encoder=models.ObservationActionEncoder(),
15 | torso=models.MLP((256, 256), torch.nn.ReLU),
16 | head=models.ValueHead()),
17 | observation_normalizer=normalizers.MeanStd())
18 |
19 |
20 | class TD3(agents.DDPG):
21 | '''Twin Delayed Deep Deterministic Policy Gradient.
22 | TD3: https://arxiv.org/pdf/1802.09477.pdf
23 | '''
24 |
25 | def __init__(
26 | self, model=None, replay=None, exploration=None, actor_updater=None,
27 | critic_updater=None, delay_steps=2
28 | ):
29 | model = model or default_model()
30 | critic_updater = critic_updater or \
31 | updaters.TwinCriticDeterministicQLearning()
32 | super().__init__(
33 | model=model, replay=replay, exploration=exploration,
34 | actor_updater=actor_updater, critic_updater=critic_updater)
35 | self.delay_steps = delay_steps
36 |         self.model.critic = self.model.critic_1  # Used by the actor updater.
37 |
38 | def _update(self, steps):
39 | keys = ('observations', 'actions', 'next_observations', 'rewards',
40 | 'discounts')
41 | for i, batch in enumerate(self.replay.get(*keys, steps=steps)):
42 | batch = {k: torch.as_tensor(v) for k, v in batch.items()}
43 | if (i + 1) % self.delay_steps == 0:
44 | infos = self._update_actor_critic(**batch)
45 | else:
46 | infos = dict(critic=self.critic_updater(**batch))
47 | for key in infos:
48 | for k, v in infos[key].items():
49 | logger.store(key + '/' + k, v.numpy())
50 |
51 | # Update the normalizers.
52 | if self.model.observation_normalizer:
53 | self.model.observation_normalizer.update()
54 | if self.model.return_normalizer:
55 | self.model.return_normalizer.update()
56 |
--------------------------------------------------------------------------------
/tonic/torch/agents/trpo.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tonic import logger # noqa
4 | from tonic.torch import agents, updaters
5 |
6 |
7 | class TRPO(agents.A2C):
8 | '''Trust Region Policy Optimization.
9 | TRPO: https://arxiv.org/pdf/1502.05477.pdf
10 | '''
11 |
12 | def __init__(
13 | self, model=None, replay=None, actor_updater=None, critic_updater=None
14 | ):
15 | actor_updater = actor_updater or updaters.TrustRegionPolicyGradient()
16 | super().__init__(
17 | model=model, replay=replay, actor_updater=actor_updater,
18 | critic_updater=critic_updater)
19 |
20 | def step(self, observations, steps):
21 | # Sample actions and get their log-probabilities for training.
22 | actions, log_probs, locs, scales = self._step(observations)
23 | actions = actions.numpy()
24 | log_probs = log_probs.numpy()
25 | locs = locs.numpy()
26 | scales = scales.numpy()
27 |
28 | # Keep some values for the next update.
29 | self.last_observations = observations.copy()
30 | self.last_actions = actions.copy()
31 | self.last_log_probs = log_probs.copy()
32 | self.last_locs = locs.copy()
33 | self.last_scales = scales.copy()
34 |
35 | return actions
36 |
37 | def update(self, observations, rewards, resets, terminations, steps):
38 | # Store the last transitions in the replay.
39 | self.replay.store(
40 | observations=self.last_observations, actions=self.last_actions,
41 | next_observations=observations, rewards=rewards, resets=resets,
42 | terminations=terminations, log_probs=self.last_log_probs,
43 | locs=self.last_locs, scales=self.last_scales)
44 |
45 | # Prepare to update the normalizers.
46 | if self.model.observation_normalizer:
47 | self.model.observation_normalizer.record(self.last_observations)
48 | if self.model.return_normalizer:
49 | self.model.return_normalizer.record(rewards)
50 |
51 | # Update the model if the replay is ready.
52 | if self.replay.ready():
53 | self._update()
54 |
55 | def _step(self, observations):
56 | observations = torch.as_tensor(observations, dtype=torch.float32)
57 | with torch.no_grad():
58 | distributions = self.model.actor(observations)
59 | if hasattr(distributions, 'sample_with_log_prob'):
60 | actions, log_probs = distributions.sample_with_log_prob()
61 | else:
62 | actions = distributions.sample()
63 | log_probs = distributions.log_prob(actions)
64 | log_probs = log_probs.sum(dim=-1)
65 | locs = distributions.loc
66 | scales = distributions.stddev
67 | return actions, log_probs, locs, scales
68 |
69 | def _update(self):
70 | # Compute the lambda-returns.
71 | batch = self.replay.get_full('observations', 'next_observations')
72 | values, next_values = self._evaluate(**batch)
73 | values, next_values = values.numpy(), next_values.numpy()
74 | self.replay.compute_returns(values, next_values)
75 |
76 | keys = ('observations', 'actions', 'log_probs', 'locs', 'scales',
77 | 'advantages')
78 | batch = self.replay.get_full(*keys)
79 | batch = {k: torch.as_tensor(v) for k, v in batch.items()}
80 | infos = self.actor_updater(**batch)
81 | for k, v in infos.items():
82 | logger.store('actor/' + k, v.numpy())
83 |
84 | critic_iterations = 0
85 | for batch in self.replay.get('observations', 'returns'):
86 | batch = {k: torch.as_tensor(v) for k, v in batch.items()}
87 | infos = self.critic_updater(**batch)
88 | critic_iterations += 1
89 | for k, v in infos.items():
90 | logger.store('critic/' + k, v.numpy())
91 | logger.store('critic/iterations', critic_iterations)
92 |
93 | # Update the normalizers.
94 | if self.model.observation_normalizer:
95 | self.model.observation_normalizer.update()
96 | if self.model.return_normalizer:
97 | self.model.return_normalizer.update()
98 |
--------------------------------------------------------------------------------
/tonic/torch/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor_critics import ActorCritic
2 | from .actor_critics import ActorCriticWithTargets
3 | from .actor_critics import ActorTwinCriticWithTargets
4 |
5 | from .actors import Actor
6 | from .actors import DetachedScaleGaussianPolicyHead
7 | from .actors import DeterministicPolicyHead
8 | from .actors import GaussianPolicyHead
9 | from .actors import SquashedMultivariateNormalDiag
10 |
11 | from .critics import Critic, DistributionalValueHead, ValueHead
12 |
13 | from .encoders import ObservationActionEncoder, ObservationEncoder
14 |
15 | from .utils import MLP, trainable_variables
16 |
17 |
18 | __all__ = [
19 | MLP, trainable_variables, ObservationActionEncoder,
20 | ObservationEncoder, SquashedMultivariateNormalDiag,
21 | DetachedScaleGaussianPolicyHead, GaussianPolicyHead,
22 | DeterministicPolicyHead, Actor, Critic, DistributionalValueHead,
23 | ValueHead, ActorCritic, ActorCriticWithTargets, ActorTwinCriticWithTargets]
24 |
--------------------------------------------------------------------------------
/tonic/torch/models/actor_critics.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 |
5 | from tonic.torch import models # noqa
6 |
7 |
8 | class ActorCritic(torch.nn.Module):
9 | def __init__(
10 | self, actor, critic, observation_normalizer=None,
11 | return_normalizer=None
12 | ):
13 | super().__init__()
14 | self.actor = actor
15 | self.critic = critic
16 | self.observation_normalizer = observation_normalizer
17 | self.return_normalizer = return_normalizer
18 |
19 | def initialize(self, observation_space, action_space):
20 | if self.observation_normalizer:
21 | self.observation_normalizer.initialize(observation_space.shape)
22 | self.actor.initialize(
23 | observation_space, action_space, self.observation_normalizer)
24 | self.critic.initialize(
25 | observation_space, action_space, self.observation_normalizer,
26 | self.return_normalizer)
27 |
28 |
29 | class ActorCriticWithTargets(torch.nn.Module):
30 | def __init__(
31 | self, actor, critic, observation_normalizer=None,
32 | return_normalizer=None, target_coeff=0.005
33 | ):
34 | super().__init__()
35 | self.actor = actor
36 | self.critic = critic
37 | self.target_actor = copy.deepcopy(actor)
38 | self.target_critic = copy.deepcopy(critic)
39 | self.observation_normalizer = observation_normalizer
40 | self.return_normalizer = return_normalizer
41 | self.target_coeff = target_coeff
42 |
43 | def initialize(self, observation_space, action_space):
44 | if self.observation_normalizer:
45 | self.observation_normalizer.initialize(observation_space.shape)
46 | self.actor.initialize(
47 | observation_space, action_space, self.observation_normalizer)
48 | self.critic.initialize(
49 | observation_space, action_space, self.observation_normalizer,
50 | self.return_normalizer)
51 | self.target_actor.initialize(
52 | observation_space, action_space, self.observation_normalizer)
53 | self.target_critic.initialize(
54 | observation_space, action_space, self.observation_normalizer,
55 | self.return_normalizer)
56 | self.online_variables = models.trainable_variables(self.actor)
57 | self.online_variables += models.trainable_variables(self.critic)
58 | self.target_variables = models.trainable_variables(self.target_actor)
59 | self.target_variables += models.trainable_variables(self.target_critic)
60 | for target in self.target_variables:
61 | target.requires_grad = False
62 | self.assign_targets()
63 |
64 | def assign_targets(self):
65 | for o, t in zip(self.online_variables, self.target_variables):
66 | t.data.copy_(o.data)
67 |
68 | def update_targets(self):
69 | with torch.no_grad():
70 | for o, t in zip(self.online_variables, self.target_variables):
71 | t.data.mul_(1 - self.target_coeff)
72 | t.data.add_(self.target_coeff * o.data)
73 |
74 |
75 | class ActorTwinCriticWithTargets(torch.nn.Module):
76 | def __init__(
77 | self, actor, critic, observation_normalizer=None,
78 | return_normalizer=None, target_coeff=0.005
79 | ):
80 | super().__init__()
81 | self.actor = actor
82 | self.critic_1 = critic
83 | self.critic_2 = copy.deepcopy(critic)
84 | self.target_actor = copy.deepcopy(actor)
85 | self.target_critic_1 = copy.deepcopy(critic)
86 | self.target_critic_2 = copy.deepcopy(critic)
87 | self.observation_normalizer = observation_normalizer
88 | self.return_normalizer = return_normalizer
89 | self.target_coeff = target_coeff
90 |
91 | def initialize(self, observation_space, action_space):
92 | if self.observation_normalizer:
93 | self.observation_normalizer.initialize(observation_space.shape)
94 | self.actor.initialize(
95 | observation_space, action_space, self.observation_normalizer)
96 | self.critic_1.initialize(
97 | observation_space, action_space, self.observation_normalizer,
98 | self.return_normalizer)
99 | self.critic_2.initialize(
100 | observation_space, action_space, self.observation_normalizer,
101 | self.return_normalizer)
102 | self.target_actor.initialize(
103 | observation_space, action_space, self.observation_normalizer)
104 | self.target_critic_1.initialize(
105 | observation_space, action_space, self.observation_normalizer,
106 | self.return_normalizer)
107 | self.target_critic_2.initialize(
108 | observation_space, action_space, self.observation_normalizer,
109 | self.return_normalizer)
110 | self.online_variables = models.trainable_variables(self.actor)
111 | self.online_variables += models.trainable_variables(self.critic_1)
112 | self.online_variables += models.trainable_variables(self.critic_2)
113 | self.target_variables = models.trainable_variables(self.target_actor)
114 | self.target_variables += models.trainable_variables(
115 | self.target_critic_1)
116 | self.target_variables += models.trainable_variables(
117 | self.target_critic_2)
118 | for target in self.target_variables:
119 | target.requires_grad = False
120 | self.assign_targets()
121 |
122 | def assign_targets(self):
123 | for o, t in zip(self.online_variables, self.target_variables):
124 | t.data.copy_(o.data)
125 |
126 | def update_targets(self):
127 | with torch.no_grad():
128 | for o, t in zip(self.online_variables, self.target_variables):
129 | t.data.mul_(1 - self.target_coeff)
130 | t.data.add_(self.target_coeff * o.data)
131 |
--------------------------------------------------------------------------------
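update_targets above performs a Polyak (soft) update, target = (1 - target_coeff) * target + target_coeff * online, with target_coeff = 0.005 by default; written out for a single tensor (illustrative only):

import torch

target_coeff = 0.005
online = torch.tensor([1.0, 2.0])
target = torch.zeros(2)
target.mul_(1 - target_coeff)          # scale the old target down
target.add_(target_coeff * online)     # mix in a little of the online value
print(target)                          # tensor([0.0050, 0.0100])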
/tonic/torch/models/actors.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | FLOAT_EPSILON = 1e-8
5 |
6 |
7 | class SquashedMultivariateNormalDiag:
8 | def __init__(self, loc, scale):
9 | self._distribution = torch.distributions.normal.Normal(loc, scale)
10 |
11 | def rsample_with_log_prob(self, shape=()):
12 | samples = self._distribution.rsample(shape)
13 | squashed_samples = torch.tanh(samples)
14 | log_probs = self._distribution.log_prob(samples)
15 | log_probs -= torch.log(1 - squashed_samples ** 2 + 1e-6)
16 | return squashed_samples, log_probs
17 |
18 | def rsample(self, shape=()):
19 | samples = self._distribution.rsample(shape)
20 | return torch.tanh(samples)
21 |
22 | def sample(self, shape=()):
23 | samples = self._distribution.sample(shape)
24 | return torch.tanh(samples)
25 |
26 | def log_prob(self, samples):
27 |         '''The required unsquashed samples cannot be accurately recovered.'''
28 |         raise NotImplementedError(
29 |             'Not implemented to avoid approximation errors. '
30 |             'Use rsample_with_log_prob directly.')
31 |
32 | @property
33 | def loc(self):
34 | return torch.tanh(self._distribution.mean)
35 |
36 |
37 | class DetachedScaleGaussianPolicyHead(torch.nn.Module):
38 | def __init__(
39 | self, loc_activation=torch.nn.Tanh, loc_fn=None, log_scale_init=0.,
40 | scale_min=1e-4, scale_max=1.,
41 | distribution=torch.distributions.normal.Normal
42 | ):
43 | super().__init__()
44 | self.loc_activation = loc_activation
45 | self.loc_fn = loc_fn
46 | self.log_scale_init = log_scale_init
47 | self.scale_min = scale_min
48 | self.scale_max = scale_max
49 | self.distribution = distribution
50 |
51 | def initialize(self, input_size, action_size):
52 | self.loc_layer = torch.nn.Sequential(
53 | torch.nn.Linear(input_size, action_size), self.loc_activation())
54 | if self.loc_fn:
55 | self.loc_layer.apply(self.loc_fn)
56 | log_scale = [[self.log_scale_init] * action_size]
57 | self.log_scale = torch.nn.Parameter(
58 | torch.as_tensor(log_scale, dtype=torch.float32))
59 |
60 | def forward(self, inputs):
61 | loc = self.loc_layer(inputs)
62 | batch_size = inputs.shape[0]
63 | scale = torch.nn.functional.softplus(self.log_scale) + FLOAT_EPSILON
64 | scale = torch.clamp(scale, self.scale_min, self.scale_max)
65 | scale = scale.repeat(batch_size, 1)
66 | return self.distribution(loc, scale)
67 |
68 |
69 | class GaussianPolicyHead(torch.nn.Module):
70 | def __init__(
71 | self, loc_activation=torch.nn.Tanh, loc_fn=None,
72 | scale_activation=torch.nn.Softplus, scale_min=1e-4, scale_max=1,
73 | scale_fn=None, distribution=torch.distributions.normal.Normal
74 | ):
75 | super().__init__()
76 | self.loc_activation = loc_activation
77 | self.loc_fn = loc_fn
78 | self.scale_activation = scale_activation
79 | self.scale_min = scale_min
80 | self.scale_max = scale_max
81 | self.scale_fn = scale_fn
82 | self.distribution = distribution
83 |
84 | def initialize(self, input_size, action_size):
85 | self.loc_layer = torch.nn.Sequential(
86 | torch.nn.Linear(input_size, action_size), self.loc_activation())
87 | if self.loc_fn:
88 | self.loc_layer.apply(self.loc_fn)
89 | self.scale_layer = torch.nn.Sequential(
90 | torch.nn.Linear(input_size, action_size), self.scale_activation())
91 | if self.scale_fn:
92 | self.scale_layer.apply(self.scale_fn)
93 |
94 | def forward(self, inputs):
95 | loc = self.loc_layer(inputs)
96 | scale = self.scale_layer(inputs)
97 | scale = torch.clamp(scale, self.scale_min, self.scale_max)
98 | return self.distribution(loc, scale)
99 |
100 |
101 | class DeterministicPolicyHead(torch.nn.Module):
102 | def __init__(self, activation=torch.nn.Tanh, fn=None):
103 | super().__init__()
104 | self.activation = activation
105 | self.fn = fn
106 |
107 | def initialize(self, input_size, action_size):
108 | self.action_layer = torch.nn.Sequential(
109 | torch.nn.Linear(input_size, action_size),
110 | self.activation())
111 | if self.fn is not None:
112 | self.action_layer.apply(self.fn)
113 |
114 | def forward(self, inputs):
115 | return self.action_layer(inputs)
116 |
117 |
118 | class Actor(torch.nn.Module):
119 | def __init__(self, encoder, torso, head):
120 | super().__init__()
121 | self.encoder = encoder
122 | self.torso = torso
123 | self.head = head
124 |
125 | def initialize(
126 | self, observation_space, action_space, observation_normalizer=None
127 | ):
128 | size = self.encoder.initialize(
129 | observation_space, observation_normalizer)
130 | size = self.torso.initialize(size)
131 | action_size = action_space.shape[0]
132 | self.head.initialize(size, action_size)
133 |
134 | def forward(self, *inputs):
135 | out = self.encoder(*inputs)
136 | out = self.torso(out)
137 | return self.head(out)
138 |
--------------------------------------------------------------------------------
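A quick check (illustrative only) of SquashedMultivariateNormalDiag above: the tanh keeps actions inside (-1, 1), and rsample_with_log_prob returns per-dimension log-probabilities with the change-of-variables correction already subtracted:

import torch

from tonic.torch import models

distribution = models.SquashedMultivariateNormalDiag(
    loc=torch.zeros(2, 3), scale=torch.ones(2, 3))
actions, log_probs = distribution.rsample_with_log_prob()
print(actions.shape, log_probs.shape)   # torch.Size([2, 3]) torch.Size([2, 3])
print(bool(actions.abs().max() < 1))    # True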
/tonic/torch/models/critics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class ValueHead(torch.nn.Module):
5 | def __init__(self, fn=None):
6 | super().__init__()
7 | self.fn = fn
8 |
9 | def initialize(self, input_size, return_normalizer=None):
10 | self.return_normalizer = return_normalizer
11 | self.v_layer = torch.nn.Linear(input_size, 1)
12 | if self.fn:
13 | self.v_layer.apply(self.fn)
14 |
15 | def forward(self, inputs):
16 | out = self.v_layer(inputs)
17 | out = torch.squeeze(out, -1)
18 | if self.return_normalizer:
19 | out = self.return_normalizer(out)
20 | return out
21 |
22 |
23 | class CategoricalWithSupport:
24 | def __init__(self, values, logits):
25 | self.values = values
26 | self.logits = logits
27 | self.probabilities = torch.nn.functional.softmax(logits, dim=-1)
28 |
29 | def mean(self):
30 | return (self.probabilities * self.values).sum(dim=-1)
31 |
32 | def project(self, returns):
33 | vmin, vmax = self.values[0], self.values[-1]
34 | d_pos = torch.cat([self.values, vmin[None]], 0)[1:]
35 | d_pos = (d_pos - self.values)[None, :, None]
36 | d_neg = torch.cat([vmax[None], self.values], 0)[:-1]
37 | d_neg = (self.values - d_neg)[None, :, None]
38 |
39 | clipped_returns = torch.clamp(returns, vmin, vmax)
40 | delta_values = clipped_returns[:, None] - self.values[None, :, None]
41 | delta_sign = (delta_values >= 0).float()
42 | delta_hat = ((delta_sign * delta_values / d_pos) -
43 | ((1 - delta_sign) * delta_values / d_neg))
44 | delta_clipped = torch.clamp(1 - delta_hat, 0, 1)
45 |
46 | return (delta_clipped * self.probabilities[:, None]).sum(dim=2)
47 |
48 |
49 | class DistributionalValueHead(torch.nn.Module):
50 | def __init__(self, vmin, vmax, num_atoms, fn=None):
51 | super().__init__()
52 | self.num_atoms = num_atoms
53 | self.fn = fn
54 | self.values = torch.linspace(vmin, vmax, num_atoms).float()
55 |
56 | def initialize(self, input_size, return_normalizer=None):
57 | if return_normalizer:
58 | raise ValueError(
59 |                 'Return normalizers cannot be used with distributional '
60 |                 'value heads.')
61 | self.distributional_layer = torch.nn.Linear(input_size, self.num_atoms)
62 | if self.fn:
63 | self.distributional_layer.apply(self.fn)
64 |
65 | def forward(self, inputs):
66 | logits = self.distributional_layer(inputs)
67 | return CategoricalWithSupport(values=self.values, logits=logits)
68 |
69 |
70 | class Critic(torch.nn.Module):
71 | def __init__(self, encoder, torso, head):
72 | super().__init__()
73 | self.encoder = encoder
74 | self.torso = torso
75 | self.head = head
76 |
77 | def initialize(
78 | self, observation_space, action_space, observation_normalizer=None,
79 | return_normalizer=None
80 | ):
81 | size = self.encoder.initialize(
82 | observation_space=observation_space, action_space=action_space,
83 | observation_normalizer=observation_normalizer)
84 | size = self.torso.initialize(size)
85 | self.head.initialize(size, return_normalizer)
86 |
87 | def forward(self, *inputs):
88 | out = self.encoder(*inputs)
89 | out = self.torso(out)
90 | return self.head(out)
91 |
--------------------------------------------------------------------------------
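A small numeric check (illustrative only) of the categorical projection above, which redistributes probability mass from Bellman-updated atom positions back onto the fixed support; CategoricalWithSupport is imported from the critics module directly since models/__init__.py does not re-export it:

import torch

from tonic.torch.models.critics import CategoricalWithSupport

values = torch.linspace(-1., 1., 5)            # the fixed support atoms
logits = torch.zeros(2, 5)                     # two uniform distributions
distribution = CategoricalWithSupport(values, logits)

rewards = torch.tensor([0.5, -0.25])
discounts = torch.tensor([0.9, 0.9])
returns = rewards[:, None] + discounts[:, None] * values   # shape (2, 5)
targets = distribution.project(returns)                    # shape (2, 5)
print(targets.sum(dim=1))                      # each row still sums to (about) 1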
/tonic/torch/models/encoders.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class ObservationEncoder(torch.nn.Module):
5 | def initialize(
6 | self, observation_space, action_space=None,
7 | observation_normalizer=None,
8 | ):
9 | self.observation_normalizer = observation_normalizer
10 | observation_size = observation_space.shape[0]
11 | return observation_size
12 |
13 | def forward(self, observations):
14 | if self.observation_normalizer:
15 | observations = self.observation_normalizer(observations)
16 | return observations
17 |
18 |
19 | class ObservationActionEncoder(torch.nn.Module):
20 | def initialize(
21 | self, observation_space, action_space, observation_normalizer=None
22 | ):
23 | self.observation_normalizer = observation_normalizer
24 | observation_size = observation_space.shape[0]
25 | action_size = action_space.shape[0]
26 | return observation_size + action_size
27 |
28 | def forward(self, observations, actions):
29 | if self.observation_normalizer:
30 | observations = self.observation_normalizer(observations)
31 | return torch.cat([observations, actions], dim=-1)
32 |
--------------------------------------------------------------------------------
/tonic/torch/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class MLP(torch.nn.Module):
5 | def __init__(self, sizes, activation, fn=None):
6 | super().__init__()
7 | self.sizes = sizes
8 | self.activation = activation
9 | self.fn = fn
10 |
11 | def initialize(self, input_size):
12 | sizes = [input_size] + list(self.sizes)
13 | layers = []
14 | for i in range(len(sizes) - 1):
15 | layers += [torch.nn.Linear(sizes[i], sizes[i + 1]),
16 | self.activation()]
17 | self.model = torch.nn.Sequential(*layers)
18 | if self.fn is not None:
19 | self.model.apply(self.fn)
20 | return sizes[-1]
21 |
22 | def forward(self, inputs):
23 | return self.model(inputs)
24 |
25 |
26 | def trainable_variables(model):
27 | return [p for p in model.parameters() if p.requires_grad]
28 |
--------------------------------------------------------------------------------
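The MLP above follows the same two-stage pattern as the rest of the models: construct it first, then call initialize() once the input size is known; initialize() returns the output size so the next block can be wired up (illustrative only):

import torch

from tonic.torch import models

torso = models.MLP((64, 64), torch.nn.ReLU)
output_size = torso.initialize(input_size=10)   # builds the layers, returns 64
features = torso(torch.zeros(32, 10))           # shape (32, 64)
print(output_size, features.shape)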
/tonic/torch/normalizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .mean_stds import MeanStd
2 | from .returns import Return
3 |
4 |
5 | __all__ = [MeanStd, Return]
6 |
--------------------------------------------------------------------------------
/tonic/torch/normalizers/mean_stds.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class MeanStd(torch.nn.Module):
6 | def __init__(self, mean=0, std=1, clip=None, shape=None):
7 | super().__init__()
8 | self.mean = mean
9 | self.std = std
10 | self.clip = clip
11 | self.count = 0
12 | self.new_sum = 0
13 | self.new_sum_sq = 0
14 | self.new_count = 0
15 | self.eps = 1e-2
16 | if shape:
17 | self.initialize(shape)
18 |
19 | def initialize(self, shape):
20 | if isinstance(self.mean, (int, float)):
21 | self.mean = np.full(shape, self.mean, np.float32)
22 | else:
23 | self.mean = np.array(self.mean, np.float32)
24 | if isinstance(self.std, (int, float)):
25 | self.std = np.full(shape, self.std, np.float32)
26 | else:
27 | self.std = np.array(self.std, np.float32)
28 | self.mean_sq = np.square(self.mean)
29 | self._mean = torch.nn.Parameter(torch.as_tensor(
30 | self.mean, dtype=torch.float32), requires_grad=False)
31 | self._std = torch.nn.Parameter(torch.as_tensor(
32 | self.std, dtype=torch.float32), requires_grad=False)
33 |
34 | def forward(self, val):
35 | with torch.no_grad():
36 | val = (val - self._mean) / self._std
37 | if self.clip is not None:
38 | val = torch.clamp(val, -self.clip, self.clip)
39 | return val
40 |
41 | def unnormalize(self, val):
42 | return val * self._std + self._mean
43 |
44 | def record(self, values):
45 | for val in values:
46 | self.new_sum += val
47 | self.new_sum_sq += np.square(val)
48 | self.new_count += 1
49 |
50 | def update(self):
51 | new_count = self.count + self.new_count
52 | new_mean = self.new_sum / self.new_count
53 | new_mean_sq = self.new_sum_sq / self.new_count
54 | w_old = self.count / new_count
55 | w_new = self.new_count / new_count
56 | self.mean = w_old * self.mean + w_new * new_mean
57 | self.mean_sq = w_old * self.mean_sq + w_new * new_mean_sq
58 | self.std = self._compute_std(self.mean, self.mean_sq)
59 | self.count = new_count
60 | self.new_count = 0
61 | self.new_sum = 0
62 | self.new_sum_sq = 0
63 | self._update(self.mean.astype(np.float32), self.std.astype(np.float32))
64 |
65 | def _compute_std(self, mean, mean_sq):
66 | var = mean_sq - np.square(mean)
67 | var = np.maximum(var, 0)
68 | std = np.sqrt(var)
69 | std = np.maximum(std, self.eps)
70 | return std
71 |
72 | def _update(self, mean, std):
73 |         self._mean.data.copy_(torch.as_tensor(mean, dtype=torch.float32))
74 |         self._std.data.copy_(torch.as_tensor(std, dtype=torch.float32))
75 |
--------------------------------------------------------------------------------
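A short sketch (illustrative only) of the record/update/forward cycle the agents above rely on: record accumulates sums from observed batches, update folds them into the stored mean and standard deviation, and calling the module normalizes:

import numpy as np
import torch

from tonic.torch import normalizers

normalizer = normalizers.MeanStd(shape=(3,))
observations = np.random.default_rng(0).normal(5., 2., (256, 3)).astype(np.float32)
normalizer.record(observations)                 # accumulate sums and counts
normalizer.update()                             # refresh the stored statistics
normalized = normalizer(torch.as_tensor(observations))
print(normalized.mean(dim=0), normalized.std(dim=0))   # roughly 0 and 1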
/tonic/torch/normalizers/returns.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class Return(torch.nn.Module):
6 | def __init__(self, discount_factor):
7 | super().__init__()
8 | assert 0 <= discount_factor < 1
9 | self.coefficient = 1 / (1 - discount_factor)
10 | self.min_reward = np.float32(-1)
11 | self.max_reward = np.float32(1)
12 | self._low = torch.nn.Parameter(torch.as_tensor(
13 | self.coefficient * self.min_reward, dtype=torch.float32),
14 | requires_grad=False)
15 | self._high = torch.nn.Parameter(torch.as_tensor(
16 | self.coefficient * self.max_reward, dtype=torch.float32),
17 | requires_grad=False)
18 |
19 | def forward(self, val):
20 | val = torch.sigmoid(val)
21 | return self._low + val * (self._high - self._low)
22 |
23 | def record(self, values):
24 | for val in values:
25 | if val < self.min_reward:
26 | self.min_reward = np.float32(val)
27 | elif val > self.max_reward:
28 | self.max_reward = np.float32(val)
29 |
30 | def update(self):
31 | self._update(self.min_reward, self.max_reward)
32 |
33 | def _update(self, min_reward, max_reward):
34 | self._low.data.copy_(torch.as_tensor(
35 | self.coefficient * min_reward, dtype=torch.float32))
36 | self._high.data.copy_(torch.as_tensor(
37 | self.coefficient * max_reward, dtype=torch.float32))
38 |
--------------------------------------------------------------------------------
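The Return normalizer squashes raw critic outputs through a sigmoid into [coefficient * min_reward, coefficient * max_reward], widening the bounds as larger-magnitude rewards are recorded (illustrative only):

import torch

from tonic.torch import normalizers

normalizer = normalizers.Return(discount_factor=0.99)   # coefficient = 100
normalizer.record([-2.0, 3.0])                          # widens the reward bounds
normalizer.update()                                     # bounds become [-200, 300]
print(normalizer(torch.tensor([-10., 0., 10.])))        # approximately [-200, 50, 300]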
/tonic/torch/updaters/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import merge_first_two_dims
2 | from .utils import tile
3 |
4 | from .actors import ClippedRatio # noqa
5 | from .actors import DeterministicPolicyGradient
6 | from .actors import DistributionalDeterministicPolicyGradient
7 | from .actors import MaximumAPosterioriPolicyOptimization
8 | from .actors import StochasticPolicyGradient
9 | from .actors import TrustRegionPolicyGradient
10 | from .actors import TwinCriticSoftDeterministicPolicyGradient
11 |
12 | from .critics import DeterministicQLearning
13 | from .critics import DistributionalDeterministicQLearning
14 | from .critics import ExpectedSARSA
15 | from .critics import QRegression
16 | from .critics import TargetActionNoise
17 | from .critics import TwinCriticDeterministicQLearning
18 | from .critics import TwinCriticSoftQLearning
19 | from .critics import VRegression
20 |
21 | from .optimizers import ConjugateGradient
22 |
23 |
24 | __all__ = [
25 | merge_first_two_dims, tile, ClippedRatio, DeterministicPolicyGradient,
26 | DistributionalDeterministicPolicyGradient,
27 | MaximumAPosterioriPolicyOptimization, StochasticPolicyGradient,
28 | TrustRegionPolicyGradient, TwinCriticSoftDeterministicPolicyGradient,
29 | DeterministicQLearning, DistributionalDeterministicQLearning,
30 | ExpectedSARSA, QRegression, TargetActionNoise,
31 | TwinCriticDeterministicQLearning, TwinCriticSoftQLearning, VRegression,
32 | ConjugateGradient]
33 |
--------------------------------------------------------------------------------
/tonic/torch/updaters/critics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tonic.torch import models, updaters # noqa
4 |
5 |
6 | class VRegression:
7 | def __init__(self, loss=None, optimizer=None, gradient_clip=0):
8 | self.loss = loss or torch.nn.MSELoss()
9 | self.optimizer = optimizer or (
10 | lambda params: torch.optim.Adam(params, lr=1e-3))
11 | self.gradient_clip = gradient_clip
12 |
13 | def initialize(self, model):
14 | self.model = model
15 | self.variables = models.trainable_variables(self.model.critic)
16 | self.optimizer = self.optimizer(self.variables)
17 |
18 | def __call__(self, observations, returns):
19 | self.optimizer.zero_grad()
20 | values = self.model.critic(observations)
21 | loss = self.loss(values, returns)
22 |
23 | loss.backward()
24 | if self.gradient_clip > 0:
25 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip)
26 | self.optimizer.step()
27 |
28 | return dict(loss=loss.detach(), v=values.detach())
29 |
30 |
31 | class QRegression:
32 | def __init__(self, loss=None, optimizer=None, gradient_clip=0):
33 | self.loss = loss or torch.nn.MSELoss()
34 | self.optimizer = optimizer or (
35 | lambda params: torch.optim.Adam(params, lr=1e-3))
36 | self.gradient_clip = gradient_clip
37 |
38 | def initialize(self, model):
39 | self.model = model
40 | self.variables = models.trainable_variables(self.model.critic)
41 | self.optimizer = self.optimizer(self.variables)
42 |
43 | def __call__(self, observations, actions, returns):
44 | self.optimizer.zero_grad()
45 | values = self.model.critic(observations, actions)
46 | loss = self.loss(values, returns)
47 |
48 | loss.backward()
49 | if self.gradient_clip > 0:
50 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip)
51 | self.optimizer.step()
52 |
53 | return dict(loss=loss.detach(), q=values.detach())
54 |
55 |
56 | class DeterministicQLearning:
57 | def __init__(self, loss=None, optimizer=None, gradient_clip=0):
58 | self.loss = loss or torch.nn.MSELoss()
59 | self.optimizer = optimizer or (
60 | lambda params: torch.optim.Adam(params, lr=1e-3))
61 | self.gradient_clip = gradient_clip
62 |
63 | def initialize(self, model):
64 | self.model = model
65 | self.variables = models.trainable_variables(self.model.critic)
66 | self.optimizer = self.optimizer(self.variables)
67 |
68 | def __call__(
69 | self, observations, actions, next_observations, rewards, discounts
70 | ):
71 | with torch.no_grad():
72 | next_actions = self.model.target_actor(next_observations)
73 | next_values = self.model.target_critic(
74 | next_observations, next_actions)
75 | returns = rewards + discounts * next_values
76 |
77 | self.optimizer.zero_grad()
78 | values = self.model.critic(observations, actions)
79 | loss = self.loss(values, returns)
80 |
81 | loss.backward()
82 | if self.gradient_clip > 0:
83 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip)
84 | self.optimizer.step()
85 |
86 | return dict(loss=loss.detach(), q=values.detach())
87 |
88 |
89 | class DistributionalDeterministicQLearning:
90 | def __init__(self, optimizer=None, gradient_clip=0):
91 | self.optimizer = optimizer or (
92 | lambda params: torch.optim.Adam(params, lr=1e-3))
93 | self.gradient_clip = gradient_clip
94 |
95 | def initialize(self, model):
96 | self.model = model
97 | self.variables = models.trainable_variables(self.model.critic)
98 | self.optimizer = self.optimizer(self.variables)
99 |
100 | def __call__(
101 | self, observations, actions, next_observations, rewards, discounts
102 | ):
103 | with torch.no_grad():
104 | next_actions = self.model.target_actor(next_observations)
105 | next_value_distributions = self.model.target_critic(
106 | next_observations, next_actions)
107 | values = next_value_distributions.values
108 | returns = rewards[:, None] + discounts[:, None] * values
109 | targets = next_value_distributions.project(returns)
110 |
111 | self.optimizer.zero_grad()
112 | value_distributions = self.model.critic(observations, actions)
113 | log_probabilities = torch.nn.functional.log_softmax(
114 | value_distributions.logits, dim=-1)
115 | loss = -(targets * log_probabilities).sum(dim=-1).mean()
116 |
117 | loss.backward()
118 | if self.gradient_clip > 0:
119 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip)
120 | self.optimizer.step()
121 |
122 | return dict(loss=loss.detach())
123 |
124 |
125 | class TargetActionNoise:
126 | def __init__(self, scale=0.2, clip=0.5):
127 | self.scale = scale
128 | self.clip = clip
129 |
130 | def __call__(self, actions):
131 | noises = self.scale * torch.randn_like(actions)
132 | noises = torch.clamp(noises, -self.clip, self.clip)
133 | actions = actions + noises
134 | return torch.clamp(actions, -1, 1)
135 |
136 |
137 | class TwinCriticDeterministicQLearning:
138 | def __init__(
139 | self, loss=None, optimizer=None, target_action_noise=None,
140 | gradient_clip=0
141 | ):
142 | self.loss = loss or torch.nn.MSELoss()
143 | self.optimizer = optimizer or (
144 | lambda params: torch.optim.Adam(params, lr=1e-3))
145 | self.target_action_noise = target_action_noise or \
146 | TargetActionNoise(scale=0.2, clip=0.5)
147 | self.gradient_clip = gradient_clip
148 |
149 | def initialize(self, model):
150 | self.model = model
151 | variables_1 = models.trainable_variables(self.model.critic_1)
152 | variables_2 = models.trainable_variables(self.model.critic_2)
153 | self.variables = variables_1 + variables_2
154 | self.optimizer = self.optimizer(self.variables)
155 |
156 | def __call__(
157 | self, observations, actions, next_observations, rewards, discounts
158 | ):
159 | with torch.no_grad():
160 | next_actions = self.model.target_actor(next_observations)
161 | next_actions = self.target_action_noise(next_actions)
162 | next_values_1 = self.model.target_critic_1(
163 | next_observations, next_actions)
164 | next_values_2 = self.model.target_critic_2(
165 | next_observations, next_actions)
166 | next_values = torch.min(next_values_1, next_values_2)
167 | returns = rewards + discounts * next_values
168 |
169 | self.optimizer.zero_grad()
170 | values_1 = self.model.critic_1(observations, actions)
171 | values_2 = self.model.critic_2(observations, actions)
172 | loss_1 = self.loss(values_1, returns)
173 | loss_2 = self.loss(values_2, returns)
174 | loss = loss_1 + loss_2
175 |
176 | loss.backward()
177 | if self.gradient_clip > 0:
178 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip)
179 | self.optimizer.step()
180 |
181 | return dict(
182 | loss=loss.detach(), q1=values_1.detach(), q2=values_2.detach())
183 |
184 |
185 | class TwinCriticSoftQLearning:
186 | def __init__(
187 | self, loss=None, optimizer=None, entropy_coeff=0.2, gradient_clip=0
188 | ):
189 | self.loss = loss or torch.nn.MSELoss()
190 | self.optimizer = optimizer or (
191 | lambda params: torch.optim.Adam(params, lr=3e-4))
192 | self.entropy_coeff = entropy_coeff
193 | self.gradient_clip = gradient_clip
194 |
195 | def initialize(self, model):
196 | self.model = model
197 | variables_1 = models.trainable_variables(self.model.critic_1)
198 | variables_2 = models.trainable_variables(self.model.critic_2)
199 | self.variables = variables_1 + variables_2
200 | self.optimizer = self.optimizer(self.variables)
201 |
202 | def __call__(
203 | self, observations, actions, next_observations, rewards, discounts
204 | ):
205 | with torch.no_grad():
206 | next_distributions = self.model.actor(next_observations)
207 | if hasattr(next_distributions, 'rsample_with_log_prob'):
208 | outs = next_distributions.rsample_with_log_prob()
209 | next_actions, next_log_probs = outs
210 | else:
211 | next_actions = next_distributions.rsample()
212 | next_log_probs = next_distributions.log_prob(next_actions)
213 | next_log_probs = next_log_probs.sum(dim=-1)
214 | next_values_1 = self.model.target_critic_1(
215 | next_observations, next_actions)
216 | next_values_2 = self.model.target_critic_2(
217 | next_observations, next_actions)
218 | next_values = torch.min(next_values_1, next_values_2)
219 | returns = rewards + discounts * (
220 | next_values - self.entropy_coeff * next_log_probs)
221 |
222 | self.optimizer.zero_grad()
223 | values_1 = self.model.critic_1(observations, actions)
224 | values_2 = self.model.critic_2(observations, actions)
225 | loss_1 = self.loss(values_1, returns)
226 | loss_2 = self.loss(values_2, returns)
227 | loss = loss_1 + loss_2
228 |
229 | loss.backward()
230 | if self.gradient_clip > 0:
231 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip)
232 | self.optimizer.step()
233 |
234 | return dict(
235 | loss=loss.detach(), q1=values_1.detach(), q2=values_2.detach())
236 |
237 |
238 | class ExpectedSARSA:
239 | def __init__(
240 | self, num_samples=20, loss=None, optimizer=None, gradient_clip=0
241 | ):
242 | self.num_samples = num_samples
243 | self.loss = loss or torch.nn.MSELoss()
244 | self.optimizer = optimizer or (
245 | lambda params: torch.optim.Adam(params, lr=3e-4))
246 | self.gradient_clip = gradient_clip
247 |
248 | def initialize(self, model):
249 | self.model = model
250 | self.variables = models.trainable_variables(self.model.critic)
251 | self.optimizer = self.optimizer(self.variables)
252 |
253 | def __call__(
254 | self, observations, actions, next_observations, rewards, discounts
255 | ):
256 | # Approximate the expected next values.
257 | with torch.no_grad():
258 | next_target_distributions = self.model.target_actor(
259 | next_observations)
260 | next_actions = next_target_distributions.rsample(
261 | (self.num_samples,))
262 | next_actions = updaters.merge_first_two_dims(next_actions)
263 | next_observations = updaters.tile(
264 | next_observations, self.num_samples)
265 | next_observations = updaters.merge_first_two_dims(
266 | next_observations)
267 | next_values = self.model.target_critic(
268 | next_observations, next_actions)
269 | next_values = next_values.view(self.num_samples, -1)
270 | next_values = next_values.mean(dim=0)
271 | returns = rewards + discounts * next_values
272 |
273 | self.optimizer.zero_grad()
274 | values = self.model.critic(observations, actions)
275 |         loss = self.loss(values, returns)
276 |
277 | loss.backward()
278 | if self.gradient_clip > 0:
279 | torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip)
280 | self.optimizer.step()
281 |
282 | return dict(loss=loss.detach(), q=values.detach())
283 |
--------------------------------------------------------------------------------
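
The critic updaters above share one contract: initialize(model) collects the critic parameters and builds the optimizer, and __call__ consumes a batch, forms bootstrapped targets under torch.no_grad(), and applies a single gradient step. The sketch below exercises DeterministicQLearning with a hypothetical stand-in model (TinyCritic, TinyModel and the layer sizes are invented for illustration); in Tonic the agents build the real actor-critic models instead.

    import torch

    from tonic.torch.updaters.critics import DeterministicQLearning


    class TinyCritic(torch.nn.Module):
        '''Hypothetical critic mapping (observation, action) pairs to values.'''

        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(3 + 1, 1)

        def forward(self, observations, actions):
            inputs = torch.cat([observations, actions], dim=-1)
            return self.net(inputs).squeeze(-1)


    class TinyModel(torch.nn.Module):
        '''Hypothetical model exposing the attributes the updater expects.'''

        def __init__(self):
            super().__init__()
            self.critic = TinyCritic()
            self.target_critic = TinyCritic()
            self.target_actor = torch.nn.Sequential(
                torch.nn.Linear(3, 1), torch.nn.Tanh())


    updater = DeterministicQLearning()
    updater.initialize(TinyModel())
    infos = updater(
        observations=torch.randn(8, 3), actions=torch.randn(8, 1),
        next_observations=torch.randn(8, 3), rewards=torch.randn(8),
        discounts=torch.full((8,), 0.99))
    print(infos['loss'], infos['q'].shape)
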
/tonic/torch/updaters/optimizers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | FLOAT_EPSILON = 1e-8
6 |
7 |
8 | def flat_concat(xs):
9 | return torch.cat([torch.reshape(x, (-1,)) for x in xs], dim=0)
10 |
11 |
12 | def assign_params_from_flat(new_params, params):
13 | def flat_size(p):
14 | return int(np.prod(p.shape))
15 | splits = torch.split(new_params, [flat_size(p) for p in params])
16 | new_params = [torch.reshape(p_new, p.shape)
17 | for p, p_new in zip(params, splits)]
18 | for p, p_new in zip(params, new_params):
19 | p.data.copy_(p_new)
20 |
21 |
22 | class ConjugateGradient:
23 | def __init__(
24 | self, conjugate_gradient_steps=10, damping_coefficient=0.1,
25 | constraint_threshold=0.01, backtrack_steps=10,
26 | backtrack_coefficient=0.8
27 | ):
28 | self.conjugate_gradient_steps = conjugate_gradient_steps
29 | self.damping_coefficient = damping_coefficient
30 | self.constraint_threshold = constraint_threshold
31 | self.backtrack_steps = backtrack_steps
32 | self.backtrack_coefficient = backtrack_coefficient
33 |
34 | def optimize(self, loss_function, constraint_function, variables):
35 | def _hx(x):
36 | f = constraint_function()
37 | gradient_1 = torch.autograd.grad(f, variables, create_graph=True)
38 | gradient_1 = flat_concat(gradient_1)
39 | x = torch.as_tensor(x)
40 | y = (gradient_1 * x).sum()
41 | gradient_2 = torch.autograd.grad(y, variables)
42 | gradient_2 = flat_concat(gradient_2)
43 |
44 | if self.damping_coefficient > 0:
45 | gradient_2 += self.damping_coefficient * x
46 |
47 | return gradient_2
48 |
49 | def _cg(b):
50 | x = np.zeros_like(b)
51 | r = b.copy()
52 | p = r.copy()
53 | r_dot_old = np.dot(r, r)
54 | if r_dot_old == 0:
55 | return None
56 |
57 | for _ in range(self.conjugate_gradient_steps):
58 | z = _hx(p).numpy()
59 | alpha = r_dot_old / (np.dot(p, z) + FLOAT_EPSILON)
60 | x += alpha * p
61 | r -= alpha * z
62 | r_dot_new = np.dot(r, r)
63 | p = r + (r_dot_new / r_dot_old) * p
64 | r_dot_old = r_dot_new
65 | return x
66 |
67 | def _update(alpha, conjugate_gradient, step, start_variables):
68 | conjugate_gradient = torch.as_tensor(conjugate_gradient)
69 | new_variables = start_variables - alpha * conjugate_gradient * step
70 | assign_params_from_flat(new_variables, variables)
71 | constraint = constraint_function()
72 | loss = loss_function()
73 | return constraint.detach(), loss.detach()
74 |
75 | start_variables = flat_concat(variables)
76 |
77 | for var in variables:
78 |             if var.grad is not None:
79 | var.grad.data.zero_()
80 |
81 | loss = loss_function()
82 | grad = torch.autograd.grad(loss, variables)
83 | grad = flat_concat(grad).numpy()
84 | start_loss = loss.detach().numpy()
85 |
86 | conjugate_gradient = _cg(grad)
87 | if conjugate_gradient is None:
88 | constraint = torch.as_tensor(0., dtype=torch.float32)
89 | loss = torch.as_tensor(0., dtype=torch.float32)
90 | steps = torch.as_tensor(0, dtype=torch.int32)
91 | return constraint, loss, steps
92 |
93 |         alpha = np.sqrt(2 * self.constraint_threshold / (np.dot(
94 |             conjugate_gradient, _hx(conjugate_gradient)) + FLOAT_EPSILON))
95 |
96 | if self.backtrack_steps is None or self.backtrack_coefficient is None:
97 | constraint, loss = _update(
98 | alpha, conjugate_gradient, 1, start_variables)
99 |             return constraint, loss, torch.as_tensor(1, dtype=torch.int32)
100 |
101 | for i in range(self.backtrack_steps):
102 | constraint, loss = _update(
103 | alpha, conjugate_gradient, self.backtrack_coefficient ** i,
104 | start_variables)
105 |
106 | if (constraint.numpy() <= self.constraint_threshold and
107 | loss.numpy() <= start_loss):
108 | break
109 |
110 | if i == self.backtrack_steps - 1:
111 | constraint, loss = _update(
112 | alpha, conjugate_gradient, 0, start_variables)
113 | i = self.backtrack_steps
114 |
115 | return constraint, loss, torch.as_tensor(i + 1, dtype=torch.int32)
116 |
--------------------------------------------------------------------------------
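
ConjugateGradient implements the natural-gradient step used by TRPO: _hx computes Hessian-vector products of the constraint by double backprop, _cg approximately solves Hx = g, and the backtracking loop shrinks the step until the constraint stays below the threshold and the loss does not increase. The toy problem below is invented for illustration, a quadratic loss with a quadratic trust region standing in for the policy surrogate and the KL divergence:

    import torch

    from tonic.torch.updaters.optimizers import ConjugateGradient

    theta = torch.nn.Parameter(torch.zeros(3))
    theta_old = theta.detach().clone()
    target = torch.tensor([1., -2., 0.5])

    def loss_function():
        return ((theta - target) ** 2).sum()

    def constraint_function():
        # Quadratic trust region around the starting point.
        return ((theta - theta_old) ** 2).sum()

    optimizer = ConjugateGradient(constraint_threshold=0.01)
    constraint, loss, steps = optimizer.optimize(
        loss_function, constraint_function, [theta])
    print(constraint, loss, steps, theta.data)
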
/tonic/torch/updaters/utils.py:
--------------------------------------------------------------------------------
1 | def tile(x, n):
2 | return x[None].repeat([n] + [1] * len(x.shape))
3 |
4 |
5 | def merge_first_two_dims(x):
6 | return x.view(x.shape[0] * x.shape[1], *x.shape[2:])
7 |
--------------------------------------------------------------------------------
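
tile and merge_first_two_dims are used by the sample-based updaters (for example ExpectedSARSA above) to evaluate several sampled actions per observation: tile repeats a batch along a new leading axis and merge_first_two_dims folds that axis back into the batch dimension. A quick shape check:

    import torch

    from tonic.torch.updaters import merge_first_two_dims, tile

    observations = torch.randn(32, 17)
    tiled = tile(observations, 20)           # shape (20, 32, 17)
    merged = merge_first_two_dims(tiled)     # shape (640, 17)
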
/tonic/train.py:
--------------------------------------------------------------------------------
1 | '''Script used to train agents.'''
2 |
3 | import argparse
4 | import os
5 |
6 | import tonic
7 | import yaml
8 |
9 |
10 | def train(
11 | header, agent, environment, test_environment, trainer, before_training,
12 | after_training, parallel, sequential, seed, name, environment_name,
13 | checkpoint, path
14 | ):
15 | '''Trains an agent on an environment.'''
16 |
17 | # Capture the arguments to save them, e.g. to play with the trained agent.
18 | args = dict(locals())
19 |
20 | checkpoint_path = None
21 |
22 |     # Process the checkpoint path the same way as in tonic.play.
23 | if path:
24 | tonic.logger.log(f'Loading experiment from {path}')
25 |
26 |     # Use no checkpoint; the agent is freshly created.
27 | if checkpoint == 'none' or agent is not None:
28 | tonic.logger.log('Not loading any weights')
29 |
30 | else:
31 | checkpoint_path = os.path.join(path, 'checkpoints')
32 | if not os.path.isdir(checkpoint_path):
33 | tonic.logger.error(f'{checkpoint_path} is not a directory')
34 | checkpoint_path = None
35 |
36 | # List all the checkpoints.
37 | checkpoint_ids = []
38 | for file in os.listdir(checkpoint_path):
39 | if file[:5] == 'step_':
40 | checkpoint_id = file.split('.')[0]
41 | checkpoint_ids.append(int(checkpoint_id[5:]))
42 |
43 | if checkpoint_ids:
44 | # Use the last checkpoint.
45 | if checkpoint == 'last':
46 | checkpoint_id = max(checkpoint_ids)
47 | checkpoint_path = os.path.join(
48 | checkpoint_path, f'step_{checkpoint_id}')
49 |
50 | # Use the specified checkpoint.
51 | else:
52 | checkpoint_id = int(checkpoint)
53 | if checkpoint_id in checkpoint_ids:
54 | checkpoint_path = os.path.join(
55 | checkpoint_path, f'step_{checkpoint_id}')
56 | else:
57 | tonic.logger.error(f'Checkpoint {checkpoint_id} '
58 | f'not found in {checkpoint_path}')
59 | checkpoint_path = None
60 |
61 | else:
62 | tonic.logger.error(f'No checkpoint found in {checkpoint_path}')
63 | checkpoint_path = None
64 |
65 | # Load the experiment configuration.
66 | arguments_path = os.path.join(path, 'config.yaml')
67 | with open(arguments_path, 'r') as config_file:
68 | config = yaml.load(config_file, Loader=yaml.FullLoader)
69 | config = argparse.Namespace(**config)
70 |
71 | header = header or config.header
72 | agent = agent or config.agent
73 | environment = environment or config.test_environment
74 | environment = environment or config.environment
75 | trainer = trainer or config.trainer
76 |
77 | # Run the header first, e.g. to load an ML framework.
78 | if header:
79 | exec(header)
80 |
81 | # Build the training environment.
82 | _environment = environment
83 | environment = tonic.environments.distribute(
84 | lambda: eval(_environment), parallel, sequential)
85 | environment.initialize(seed=seed)
86 |
87 | # Build the testing environment.
88 | _test_environment = test_environment if test_environment else _environment
89 | test_environment = tonic.environments.distribute(
90 | lambda: eval(_test_environment))
91 | test_environment.initialize(seed=seed + 10000)
92 |
93 | # Build the agent.
94 | if not agent:
95 | raise ValueError('No agent specified.')
96 | agent = eval(agent)
97 | agent.initialize(
98 | observation_space=environment.observation_space,
99 | action_space=environment.action_space, seed=seed)
100 |
101 |     # Load the weights of the agent from a checkpoint.
102 | if checkpoint_path:
103 | agent.load(checkpoint_path)
104 |
105 | # Initialize the logger to save data to the path environment/name/seed.
106 | if not environment_name:
107 | if hasattr(test_environment, 'name'):
108 | environment_name = test_environment.name
109 | else:
110 | environment_name = test_environment.__class__.__name__
111 | if not name:
112 | if hasattr(agent, 'name'):
113 | name = agent.name
114 | else:
115 | name = agent.__class__.__name__
116 | if parallel != 1 or sequential != 1:
117 | name += f'-{parallel}x{sequential}'
118 | path = os.path.join(environment_name, name, str(seed))
119 | tonic.logger.initialize(path, script_path=__file__, config=args)
120 |
121 | # Build the trainer.
122 | trainer = trainer or 'tonic.Trainer()'
123 | trainer = eval(trainer)
124 | trainer.initialize(
125 | agent=agent, environment=environment,
126 | test_environment=test_environment)
127 |
128 | # Run some code before training.
129 | if before_training:
130 | exec(before_training)
131 |
132 | # Train.
133 | trainer.run()
134 |
135 | # Run some code after training.
136 | if after_training:
137 | exec(after_training)
138 |
139 |
140 | if __name__ == '__main__':
141 | # Argument parsing.
142 | parser = argparse.ArgumentParser()
143 | parser.add_argument('--header')
144 | parser.add_argument('--agent')
145 | parser.add_argument('--environment', '--env')
146 | parser.add_argument('--test_environment', '--test_env')
147 | parser.add_argument('--trainer')
148 | parser.add_argument('--before_training')
149 | parser.add_argument('--after_training')
150 | parser.add_argument('--parallel', type=int, default=1)
151 | parser.add_argument('--sequential', type=int, default=1)
152 | parser.add_argument('--seed', type=int, default=0)
153 | parser.add_argument('--name')
154 | parser.add_argument('--environment_name')
155 | parser.add_argument('--checkpoint', default='last')
156 | parser.add_argument('--path')
157 |
158 | args = vars(parser.parse_args())
159 | train(**args)
160 |
--------------------------------------------------------------------------------
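
train() assembles everything from strings: the header is exec'd first (typically to import a framework), then the environment, agent and trainer strings are eval'd. It can also be called programmatically, as sketched below; the agent and environment names (tonic.torch.agents.PPO, tonic.environments.Gym, 'BipedalWalker-v3') are assumed to be available elsewhere in the package and are not defined in this file.

    from tonic.train import train

    train(
        header='import tonic.torch',
        agent='tonic.torch.agents.PPO()',
        environment="tonic.environments.Gym('BipedalWalker-v3')",
        test_environment=None,
        trainer='tonic.Trainer(steps=int(1e6))',
        before_training=None, after_training=None,
        parallel=1, sequential=1, seed=0,
        name=None, environment_name=None,
        checkpoint='last', path=None)
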
/tonic/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import time
4 |
5 | import numpy as np
6 | import termcolor
7 | import yaml
8 |
9 |
10 | current_logger = None
11 |
12 |
13 | class Logger:
14 | '''Logger used to display and save logs, and save experiment configs.'''
15 |
16 | def __init__(self, path=None, width=60, script_path=None, config=None):
17 | self.path = path or str(time.time())
18 | self.log_file_path = os.path.join(self.path, 'log.csv')
19 |
20 | # Save the launch script.
21 | if script_path:
22 | with open(script_path, 'r') as script_file:
23 | script = script_file.read()
24 | try:
25 | os.makedirs(self.path, exist_ok=True)
26 | except Exception:
27 | pass
28 | script_path = os.path.join(self.path, 'script.py')
29 |             with open(script_path, 'w') as script_file:
30 |                 script_file.write(script)
31 | log(f'Script file saved to {script_path}')
32 |
33 | # Save the configuration.
34 | if config:
35 | try:
36 | os.makedirs(self.path, exist_ok=True)
37 | except Exception:
38 | pass
39 | config_path = os.path.join(self.path, 'config.yaml')
40 | with open(config_path, 'w') as config_file:
41 | yaml.dump(config, config_file)
42 | log(f'Config file saved to {config_path}')
43 |
44 | self.known_keys = set()
45 | self.stat_keys = set()
46 | self.epoch_dict = {}
47 | self.width = width
48 | self.last_epoch_progress = None
49 | self.start_time = time.time()
50 |
51 | def store(self, key, value, stats=False):
52 | '''Keeps named values during an epoch.'''
53 |
54 | if key not in self.epoch_dict:
55 | self.epoch_dict[key] = [value]
56 | if stats:
57 | self.stat_keys.add(key)
58 | else:
59 | self.epoch_dict[key].append(value)
60 |
61 | def dump(self):
62 | '''Displays and saves the values at the end of an epoch.'''
63 |
64 | # Compute statistics if needed.
65 | keys = list(self.epoch_dict.keys())
66 | for key in keys:
67 | values = self.epoch_dict[key]
68 | if key in self.stat_keys:
69 | self.epoch_dict[key + '/mean'] = np.mean(values)
70 | self.epoch_dict[key + '/std'] = np.std(values)
71 | self.epoch_dict[key + '/min'] = np.min(values)
72 | self.epoch_dict[key + '/max'] = np.max(values)
73 | self.epoch_dict[key + '/size'] = len(values)
74 | del self.epoch_dict[key]
75 | else:
76 | self.epoch_dict[key] = np.mean(values)
77 |
78 | # Check if new keys were added.
79 | new_keys = [key for key in self.epoch_dict.keys()
80 | if key not in self.known_keys]
81 | if new_keys:
82 | first_row = len(self.known_keys) == 0
83 | if not first_row:
84 | print()
85 | warning(f'Logging new keys {new_keys}')
86 | # List the keys and prepare the display layout.
87 | for key in new_keys:
88 | self.known_keys.add(key)
89 | self.final_keys = list(sorted(self.known_keys))
90 | self.console_formats = []
91 | known_keys = set()
92 | for key in self.final_keys:
93 | *left_keys, right_key = key.split('/')
94 | for i, k in enumerate(left_keys):
95 | left_key = '/'.join(left_keys[:i + 1])
96 | if left_key not in known_keys:
97 | left = ' ' * i + k.replace('_', ' ')
98 | self.console_formats.append((left, None))
99 | known_keys.add(left_key)
100 | indent = ' ' * len(left_keys)
101 | right_key = right_key.replace('_', ' ')
102 | self.console_formats.append((indent + right_key, key))
103 |
104 | # Display the values following the layout.
105 | print()
106 | for left, key in self.console_formats:
107 | if key:
108 | val = self.epoch_dict.get(key)
109 | str_type = str(type(val))
110 | if 'tensorflow' in str_type:
111 | warning(f'Logging TensorFlow tensor {key}')
112 | elif 'torch' in str_type:
113 | warning(f'Logging Torch tensor {key}')
114 | if np.issubdtype(type(val), np.floating):
115 | right = f'{val:8.3g}'
116 | elif np.issubdtype(type(val), np.integer):
117 | right = f'{val:,}'
118 | else:
119 | right = str(val)
120 | spaces = ' ' * (self.width - len(left) - len(right))
121 | print(left + spaces + right)
122 | else:
123 | spaces = ' ' * (self.width - len(left))
124 | print(left + spaces)
125 | print()
126 |
127 | # Save the data to the log file
128 | vals = [self.epoch_dict.get(key) for key in self.final_keys]
129 | if new_keys:
130 | if first_row:
131 | log(f'Logging data to {self.log_file_path}')
132 | try:
133 | os.makedirs(self.path, exist_ok=True)
134 | except Exception:
135 | pass
136 | with open(self.log_file_path, 'w') as file:
137 | file.write(','.join(self.final_keys) + '\n')
138 | file.write(','.join(map(str, vals)) + '\n')
139 | else:
140 | with open(self.log_file_path, 'r') as file:
141 | lines = file.read().splitlines()
142 | old_keys = lines[0].split(',')
143 | old_lines = [line.split(',') for line in lines[1:]]
144 | new_indices = []
145 | j = 0
146 | for i, key in enumerate(self.final_keys):
147 | if key == old_keys[j]:
148 | j += 1
149 | else:
150 | new_indices.append(i)
151 | assert j == len(old_keys)
152 | for line in old_lines:
153 | for i in new_indices:
154 | line.insert(i, 'None')
155 | with open(self.log_file_path, 'w') as file:
156 | file.write(','.join(self.final_keys) + '\n')
157 | for line in old_lines:
158 | file.write(','.join(line) + '\n')
159 | file.write(','.join(map(str, vals)) + '\n')
160 | else:
161 | with open(self.log_file_path, 'a') as file:
162 | file.write(','.join(map(str, vals)) + '\n')
163 |
164 | self.epoch_dict.clear()
165 | self.last_epoch_progress = None
166 | self.last_epoch_time = time.time()
167 |
168 | def show_progress(
169 | self, steps, num_epoch_steps, num_steps, color='white',
170 | on_color='on_blue'
171 | ):
172 | '''Shows a progress bar for the current epoch and total training.'''
173 |
174 | epoch_steps = (steps - 1) % num_epoch_steps + 1
175 | epoch_progress = int(self.width * epoch_steps / num_epoch_steps)
176 | if epoch_progress != self.last_epoch_progress:
177 | current_time = time.time()
178 | seconds = current_time - self.start_time
179 | seconds_per_step = seconds / steps
180 | epoch_rem_steps = num_epoch_steps - epoch_steps
181 | epoch_rem_secs = max(epoch_rem_steps * seconds_per_step, 0)
182 | epoch_rem_secs = datetime.timedelta(seconds=epoch_rem_secs + 1e-6)
183 | epoch_rem_secs = str(epoch_rem_secs)[:-7]
184 | total_rem_steps = num_steps - steps
185 | total_rem_secs = max(total_rem_steps * seconds_per_step, 0)
186 | total_rem_secs = datetime.timedelta(seconds=total_rem_secs)
187 | total_rem_secs = str(total_rem_secs)[:-7]
188 | msg = f'Time left: epoch {epoch_rem_secs} total {total_rem_secs}'
189 | msg = msg.center(self.width)
190 | print(termcolor.colored(
191 | '\r' + msg[:epoch_progress], color, on_color), end='')
192 | print(msg[epoch_progress:], sep='', end='')
193 | self.last_epoch_progress = epoch_progress
194 |
195 |
196 | def initialize(*args, **kwargs):
197 | global current_logger
198 | current_logger = Logger(*args, **kwargs)
199 | return current_logger
200 |
201 |
202 | def get_current_logger():
203 | global current_logger
204 | if current_logger is None:
205 | current_logger = Logger()
206 | return current_logger
207 |
208 |
209 | def store(*args, **kwargs):
210 | logger = get_current_logger()
211 | return logger.store(*args, **kwargs)
212 |
213 |
214 | def dump(*args, **kwargs):
215 | logger = get_current_logger()
216 | return logger.dump(*args, **kwargs)
217 |
218 |
219 | def show_progress(*args, **kwargs):
220 | logger = get_current_logger()
221 | return logger.show_progress(*args, **kwargs)
222 |
223 |
224 | def get_path():
225 | logger = get_current_logger()
226 | return logger.path
227 |
228 |
229 | def log(msg, color='green'):
230 | print(termcolor.colored(msg, color, attrs=['bold']))
231 |
232 |
233 | def warning(msg, color='yellow'):
234 | print(termcolor.colored('Warning: ' + msg, color, attrs=['bold']))
235 |
236 |
237 | def error(msg, color='red'):
238 | print(termcolor.colored('Error: ' + msg, color, attrs=['bold']))
239 |
--------------------------------------------------------------------------------
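
The logger keeps one module-level Logger: store() accumulates values during an epoch, optionally with full statistics, and dump() prints the formatted block and appends a row to log.csv, extending the header when new keys appear. A small standalone sketch (the path name is arbitrary):

    import numpy as np

    from tonic import logger

    logger.initialize(path='demo_logs')
    for step in range(100):
        logger.store('train/loss', float(np.exp(-step / 50)))
        logger.store('train/reward', float(np.random.randn()), stats=True)
    logger.dump()   # prints the epoch summary and writes demo_logs/log.csv
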
/tonic/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import numpy as np
5 |
6 | from tonic import logger
7 |
8 |
9 | class Trainer:
10 | '''Trainer used to train and evaluate an agent on an environment.'''
11 |
12 | def __init__(
13 | self, steps=int(1e7), epoch_steps=int(2e4), save_steps=int(5e5),
14 | test_episodes=5, show_progress=True, replace_checkpoint=False,
15 | ):
16 | self.max_steps = steps
17 | self.epoch_steps = epoch_steps
18 | self.save_steps = save_steps
19 | self.test_episodes = test_episodes
20 | self.show_progress = show_progress
21 | self.replace_checkpoint = replace_checkpoint
22 |
23 | def initialize(self, agent, environment, test_environment=None):
24 | self.agent = agent
25 | self.environment = environment
26 | self.test_environment = test_environment
27 |
28 | def run(self):
29 | '''Runs the main training loop.'''
30 |
31 | start_time = last_epoch_time = time.time()
32 |
33 | # Start the environments.
34 | observations = self.environment.start()
35 |
36 | num_workers = len(observations)
37 | scores = np.zeros(num_workers)
38 | lengths = np.zeros(num_workers, int)
39 | self.steps, epoch_steps, epochs, episodes = 0, 0, 0, 0
40 | steps_since_save = 0
41 |
42 | while True:
43 | # Select actions.
44 | actions = self.agent.step(observations, self.steps)
45 | assert not np.isnan(actions.sum())
46 | logger.store('train/action', actions, stats=True)
47 |
48 | # Take a step in the environments.
49 | observations, infos = self.environment.step(actions)
50 | self.agent.update(**infos, steps=self.steps)
51 |
52 | scores += infos['rewards']
53 | lengths += 1
54 | self.steps += num_workers
55 | epoch_steps += num_workers
56 | steps_since_save += num_workers
57 |
58 | # Show the progress bar.
59 | if self.show_progress:
60 | logger.show_progress(
61 | self.steps, self.epoch_steps, self.max_steps)
62 |
63 | # Check the finished episodes.
64 | for i in range(num_workers):
65 | if infos['resets'][i]:
66 | logger.store('train/episode_score', scores[i], stats=True)
67 | logger.store(
68 | 'train/episode_length', lengths[i], stats=True)
69 | scores[i] = 0
70 | lengths[i] = 0
71 | episodes += 1
72 |
73 | # End of the epoch.
74 | if epoch_steps >= self.epoch_steps:
75 | # Evaluate the agent on the test environment.
76 | if self.test_environment:
77 | self._test()
78 |
79 | # Log the data.
80 | epochs += 1
81 | current_time = time.time()
82 | epoch_time = current_time - last_epoch_time
83 | sps = epoch_steps / epoch_time
84 | logger.store('train/episodes', episodes)
85 | logger.store('train/epochs', epochs)
86 | logger.store('train/seconds', current_time - start_time)
87 | logger.store('train/epoch_seconds', epoch_time)
88 | logger.store('train/epoch_steps', epoch_steps)
89 | logger.store('train/steps', self.steps)
90 | logger.store('train/worker_steps', self.steps // num_workers)
91 | logger.store('train/steps_per_second', sps)
92 | logger.dump()
93 | last_epoch_time = time.time()
94 | epoch_steps = 0
95 |
96 | # End of training.
97 | stop_training = self.steps >= self.max_steps
98 |
99 | # Save a checkpoint.
100 | if stop_training or steps_since_save >= self.save_steps:
101 | path = os.path.join(logger.get_path(), 'checkpoints')
102 | if os.path.isdir(path) and self.replace_checkpoint:
103 | for file in os.listdir(path):
104 | if file.startswith('step_'):
105 | os.remove(os.path.join(path, file))
106 | checkpoint_name = f'step_{self.steps}'
107 | save_path = os.path.join(path, checkpoint_name)
108 | self.agent.save(save_path)
109 | steps_since_save = self.steps % self.save_steps
110 |
111 | if stop_training:
112 | break
113 |
114 | def _test(self):
115 | '''Tests the agent on the test environment.'''
116 |
117 | # Start the environment.
118 | if not hasattr(self, 'test_observations'):
119 | self.test_observations = self.test_environment.start()
120 | assert len(self.test_observations) == 1
121 |
122 | # Test loop.
123 | for _ in range(self.test_episodes):
124 | score, length = 0, 0
125 |
126 | while True:
127 | # Select an action.
128 | actions = self.agent.test_step(
129 | self.test_observations, self.steps)
130 | assert not np.isnan(actions.sum())
131 | logger.store('test/action', actions, stats=True)
132 |
133 | # Take a step in the environment.
134 | self.test_observations, infos = self.test_environment.step(
135 | actions)
136 | self.agent.test_update(**infos, steps=self.steps)
137 |
138 | score += infos['rewards'][0]
139 | length += 1
140 |
141 | if infos['resets'][0]:
142 | break
143 |
144 | # Log the data.
145 | logger.store('test/episode_score', score, stats=True)
146 | logger.store('test/episode_length', length, stats=True)
147 |
--------------------------------------------------------------------------------
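
The same pieces can be assembled without tonic/train.py: build distributed environments, initialize an agent, and hand both to a Trainer. A hedged sketch, assuming the tonic.environments.Gym builder and the torch PPO agent exported elsewhere in the package (any registered Gym id works in place of 'Pendulum-v1'):

    import tonic
    import tonic.torch

    tonic.logger.initialize('Pendulum-v1/PPO/0')

    environment = tonic.environments.distribute(
        lambda: tonic.environments.Gym('Pendulum-v1'))
    test_environment = tonic.environments.distribute(
        lambda: tonic.environments.Gym('Pendulum-v1'))
    environment.initialize(seed=0)
    test_environment.initialize(seed=10000)

    agent = tonic.torch.agents.PPO()
    agent.initialize(
        observation_space=environment.observation_space,
        action_space=environment.action_space, seed=0)

    trainer = tonic.Trainer(steps=int(1e5), epoch_steps=int(1e4))
    trainer.initialize(
        agent=agent, environment=environment,
        test_environment=test_environment)
    trainer.run()
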