├── maml_rl ├── __init__.py ├── utils │ ├── __init__.py │ ├── optimization.py │ ├── reinforcement_learning.py │ └── torch_utils.py ├── envs │ ├── mujoco │ │ ├── __init__.py │ │ ├── half_cheetah.py │ │ └── ant.py │ ├── utils.py │ ├── __init__.py │ ├── navigation.py │ ├── mdp.py │ ├── bandit.py │ ├── normalized_env.py │ └── subproc_vec_env.py ├── policies │ ├── __init__.py │ ├── policy.py │ ├── categorical_mlp.py │ └── normal_mlp.py ├── baseline.py ├── sampler.py ├── episode.py └── metalearner.py ├── requirements.txt ├── .idea ├── dictionaries │ └── i.xml └── vcs.xml ├── _assets └── halfcheetahdir.gif ├── test ├── test_pipe2.py └── test_pipe.py ├── LICENSE ├── .gitignore ├── README.md └── main.py /maml_rl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /maml_rl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /maml_rl/envs/mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.14.0 2 | torch>=0.4.0 3 | gym>=0.10.5 4 | tensorboardX>=1.2 -------------------------------------------------------------------------------- /.idea/dictionaries/i.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /_assets/halfcheetahdir.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragen1860/MAML-Pytorch-RL/HEAD/_assets/halfcheetahdir.gif -------------------------------------------------------------------------------- /maml_rl/policies/__init__.py: -------------------------------------------------------------------------------- 1 | from .categorical_mlp import CategoricalMLPPolicy 2 | from .normal_mlp import NormalMLPPolicy 3 | from .policy import Policy 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /maml_rl/envs/utils.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import load 2 | from .normalized_env import NormalizedActionWrapper 3 | 4 | 5 | 6 | 7 | 8 | 9 | def mujoco_wrapper(entry_point, **kwargs): 10 | 11 | # Load the environment from its entry point 12 | env_cls = load(entry_point) 13 | env = env_cls(**kwargs) 14 | # Normalization wrapper 15 | env = NormalizedActionWrapper(env) 16 | 17 | return env 18 | -------------------------------------------------------------------------------- /test/test_pipe2.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process, Pipe 2 | import numpy as np 3 | 4 | 5 | def writeToConnection(conn): 6 | conn.send(np.ones(3)) 7 | conn.close() 8 | 9 | 10 | if __name__ == '__main__': 11 | 12 | recv_conn, send_conn = Pipe(duplex=False) 13 | 14 | p = Process(target=writeToConnection, args=(send_conn,)) 15 | p.start() 16 | 
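    # recv() blocks here until the child process has written the array into its end of the pipe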
print(recv_conn.recv()) 17 | p.join() 18 | 19 | 20 | recv_conn, send_conn = Pipe(duplex=False) 21 | send_conn.send('hello') 22 | res = recv_conn.recv() 23 | print(res) 24 | -------------------------------------------------------------------------------- /maml_rl/utils/optimization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10): 5 | p = b.clone().detach() 6 | r = b.clone().detach() 7 | x = torch.zeros_like(b).float() 8 | rdotr = torch.dot(r, r) 9 | 10 | for i in range(cg_iters): 11 | z = f_Ax(p).detach() 12 | v = rdotr / torch.dot(p, z) 13 | x += v * p 14 | r -= v * z 15 | newrdotr = torch.dot(r, r) 16 | mu = newrdotr / rdotr 17 | p = r + mu * p 18 | 19 | rdotr = newrdotr 20 | if rdotr.item() < residual_tol: 21 | break 22 | 23 | return x.detach() 24 | -------------------------------------------------------------------------------- /maml_rl/utils/reinforcement_learning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def value_iteration(transitions, rewards, gamma=0.95, theta=1e-5): 5 | rewards = np.expand_dims(rewards, axis=2) 6 | values = np.zeros(transitions.shape[0], dtype=np.float32) 7 | delta = np.inf 8 | while delta >= theta: 9 | q_values = np.sum(transitions * (rewards + gamma * values), axis=2) 10 | new_values = np.max(q_values, axis=1) 11 | delta = np.max(np.abs(new_values - values)) 12 | values = new_values 13 | 14 | return values 15 | 16 | 17 | def value_iteration_finite_horizon(transitions, rewards, horizon=10, gamma=0.95): 18 | rewards = np.expand_dims(rewards, axis=2) 19 | values = np.zeros(transitions.shape[0], dtype=np.float32) 20 | for k in range(horizon): 21 | q_values = np.sum(transitions * (rewards + gamma * values), axis=2) 22 | values = np.max(q_values, axis=1) 23 | 24 | return values 25 | -------------------------------------------------------------------------------- /maml_rl/policies/policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from collections import OrderedDict 5 | 6 | 7 | def weight_init(module): 8 | if isinstance(module, nn.Linear): 9 | nn.init.xavier_uniform_(module.weight) 10 | module.bias.data.zero_() 11 | 12 | 13 | class Policy(nn.Module): 14 | def __init__(self, input_size, output_size): 15 | super(Policy, self).__init__() 16 | self.input_size = input_size 17 | self.output_size = output_size 18 | 19 | def update_params(self, loss, step_size=0.5, first_order=False): 20 | """ 21 | Apply one step of gradient descent on the loss function `loss`, with 22 | step-size `step_size`, and returns the updated parameters of the neural 23 | network. 
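        This is the MAML inner-loop adaptation step: each parameter becomes
        `param - step_size * grad`. With `first_order=False` (the default), the
        gradients are computed with `create_graph=True`, so the meta-update can
        later backpropagate through this step; `first_order=True` drops that
        second-order term.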
24 | """ 25 | grads = torch.autograd.grad(loss, self.parameters(), 26 | create_graph=not first_order) 27 | updated_params = OrderedDict() 28 | for (name, param), grad in zip(self.named_parameters(), grads): 29 | updated_params[name] = param - step_size * grad 30 | 31 | return updated_params 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tristan Deleu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /test/test_pipe.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | 3 | 4 | def prod(pipe): 5 | out_conn, _ = pipe 6 | for x in range(10): 7 | out_conn.send(x) 8 | 9 | out_conn.close() 10 | 11 | 12 | def square(pipe1, pipe2): 13 | close, in_conn = pipe1 14 | close.close() 15 | out_conn, _ = pipe2 16 | try: 17 | while True: 18 | x = in_conn.recv() 19 | out_conn.send(x * x) 20 | except EOFError: 21 | out_conn.close() 22 | 23 | 24 | def double(unused_pipes, in_pipe, out_pipe): 25 | for pipe in unused_pipes: 26 | close, _ = pipe 27 | close.close() 28 | 29 | closep, in_conn = in_pipe 30 | closep.close() 31 | 32 | out_conn, _ = out_pipe 33 | try: 34 | while True: 35 | x = in_conn.recv() 36 | out_conn.send(x * 2) 37 | except EOFError: 38 | out_conn.close() 39 | 40 | 41 | def test_pipes(): 42 | pipe1 = mp.Pipe(True) 43 | p1 = mp.Process(target=prod, args=(pipe1,)) 44 | p1.start() 45 | 46 | pipe2 = mp.Pipe(True) 47 | p2 = mp.Process(target=square, args=(pipe1, pipe2,)) 48 | p2.start() 49 | 50 | pipe3 = mp.Pipe(True) 51 | p3 = mp.Process(target=double, args=([pipe1], pipe2, pipe3,)) 52 | p3.start() 53 | 54 | pipe1[0].close() 55 | pipe2[0].close() 56 | pipe3[0].close() 57 | 58 | try: 59 | while True: 60 | print(pipe3[1].recv()) 61 | except EOFError: 62 | print("Finished") 63 | 64 | 65 | if __name__ == '__main__': 66 | test_pipes() 67 | -------------------------------------------------------------------------------- /maml_rl/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions import Categorical, Normal 3 | 4 | 5 | def weighted_mean(tensor, dim=None, weights=None): 6 | if weights is None: 7 | return torch.mean(tensor) 8 | if dim is None: 9 | sum_weights = 
torch.sum(weights) 10 | return torch.sum(tensor * weights) / sum_weights 11 | if isinstance(dim, int): 12 | dim = (dim,) 13 | numerator = tensor * weights 14 | denominator = weights 15 | for dimension in dim: 16 | numerator = torch.sum(numerator, dimension, keepdim=True) 17 | denominator = torch.sum(denominator, dimension, keepdim=True) 18 | return numerator / denominator 19 | 20 | 21 | def detach_distribution(pi): 22 | if isinstance(pi, Categorical): 23 | distribution = Categorical(logits=pi.logits.detach()) 24 | elif isinstance(pi, Normal): 25 | distribution = Normal(loc=pi.loc.detach(), scale=pi.scale.detach()) 26 | else: 27 | raise NotImplementedError('Only `Categorical` and `Normal` ' 28 | 'policies are valid policies.') 29 | return distribution 30 | 31 | 32 | def weighted_normalize(tensor, dim=None, weights=None, epsilon=1e-8): 33 | if weights is None: 34 | weights = torch.ones_like(tensor) 35 | mean = weighted_mean(tensor, dim=dim, weights=weights) 36 | centered = tensor * weights - mean 37 | std = torch.sqrt(weighted_mean(centered ** 2, dim=dim, weights=weights)) 38 | return centered / (std + epsilon) 39 | -------------------------------------------------------------------------------- /maml_rl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | # Bandit 4 | # ---------------------------------------- 5 | 6 | for k in [5, 10, 50]: 7 | register( 8 | 'Bandit-K{0}-v0'.format(k), 9 | entry_point='maml_rl.envs.bandit:BernoulliBanditEnv', 10 | kwargs={'k': k} 11 | ) 12 | 13 | # TabularMDP 14 | # ---------------------------------------- 15 | 16 | register( 17 | 'TabularMDP-v0', 18 | entry_point='maml_rl.envs.mdp:TabularMDPEnv', 19 | kwargs={'num_states': 10, 'num_actions': 5}, 20 | max_episode_steps=10 21 | ) 22 | 23 | # Mujoco 24 | # ---------------------------------------- 25 | 26 | register( 27 | 'AntVel-v1', 28 | entry_point='maml_rl.envs.utils:mujoco_wrapper', 29 | kwargs={'entry_point': 'maml_rl.envs.mujoco.ant:AntVelEnv'}, 30 | max_episode_steps=200 31 | ) 32 | 33 | register( 34 | 'AntDir-v1', 35 | entry_point='maml_rl.envs.utils:mujoco_wrapper', 36 | kwargs={'entry_point': 'maml_rl.envs.mujoco.ant:AntDirEnv'}, 37 | max_episode_steps=200 38 | ) 39 | 40 | register( 41 | 'AntPos-v0', 42 | entry_point='maml_rl.envs.utils:mujoco_wrapper', 43 | kwargs={'entry_point': 'maml_rl.envs.mujoco.ant:AntPosEnv'}, 44 | max_episode_steps=200 45 | ) 46 | 47 | register( 48 | 'HalfCheetahVel-v1', 49 | entry_point='maml_rl.envs.utils:mujoco_wrapper', 50 | kwargs={'entry_point': 'maml_rl.envs.mujoco.half_cheetah:HalfCheetahVelEnv'}, 51 | max_episode_steps=200 52 | ) 53 | 54 | register( 55 | 'HalfCheetahDir-v1', 56 | entry_point='maml_rl.envs.utils:mujoco_wrapper', 57 | kwargs={'entry_point': 'maml_rl.envs.mujoco.half_cheetah:HalfCheetahDirEnv'}, 58 | max_episode_steps=200 59 | ) 60 | 61 | # 2D Navigation 62 | # ---------------------------------------- 63 | 64 | register( 65 | '2DNavigation-v0', 66 | entry_point='maml_rl.envs.navigation:Navigation2DEnv', 67 | max_episode_steps=100 68 | ) 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | 
eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # Temporary 104 | tmp/ 105 | run.sh 106 | 107 | # Logs & Saves 108 | logs/ 109 | saves/ 110 | 111 | # Slurm 112 | *.out 113 | -------------------------------------------------------------------------------- /maml_rl/policies/categorical_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Categorical 5 | 6 | from collections import OrderedDict 7 | from maml_rl.policies.policy import Policy, weight_init 8 | 9 | 10 | class CategoricalMLPPolicy(Policy): 11 | """ 12 | Policy network based on a multi-layer perceptron (MLP), with a 13 | `Categorical` distribution output. This policy network can be used on tasks 14 | with discrete action spaces (eg. `TabularMDPEnv`). 
The code is adapted from 15 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/sandbox/rocky/tf/policies/maml_minimal_categorical_mlp_policy.py 16 | """ 17 | 18 | def __init__(self, input_size, output_size, hidden_sizes=(), nonlinearity=F.relu): 19 | super(CategoricalMLPPolicy, self).__init__(input_size=input_size, output_size=output_size) 20 | self.hidden_sizes = hidden_sizes 21 | self.nonlinearity = nonlinearity 22 | self.num_layers = len(hidden_sizes) + 1 23 | 24 | layer_sizes = (input_size,) + hidden_sizes + (output_size,) 25 | for i in range(1, self.num_layers + 1): 26 | self.add_module('layer{0}'.format(i), 27 | nn.Linear(layer_sizes[i - 1], layer_sizes[i])) 28 | self.apply(weight_init) 29 | 30 | def forward(self, input, params=None): 31 | if params is None: 32 | params = OrderedDict(self.named_parameters()) 33 | output = input 34 | for i in range(1, self.num_layers): 35 | output = F.linear(output, 36 | weight=params['layer{0}.weight'.format(i)], 37 | bias=params['layer{0}.bias'.format(i)]) 38 | output = self.nonlinearity(output) 39 | logits = F.linear(output, 40 | weight=params['layer{0}.weight'.format(self.num_layers)], 41 | bias=params['layer{0}.bias'.format(self.num_layers)]) 42 | 43 | return Categorical(logits=logits) 44 | -------------------------------------------------------------------------------- /maml_rl/policies/normal_mlp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.distributions import Normal 6 | 7 | from collections import OrderedDict 8 | from maml_rl.policies.policy import Policy, weight_init 9 | 10 | 11 | class NormalMLPPolicy(Policy): 12 | """ 13 | Policy network based on a multi-layer perceptron (MLP), with a 14 | `Normal` distribution output, with trainable standard deviation. This 15 | policy network can be used on tasks with continuous action spaces (eg. 16 | `HalfCheetahDir`). 
The code is adapted from 17 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/sandbox/rocky/tf/policies/maml_minimal_gauss_mlp_policy.py 18 | """ 19 | 20 | def __init__(self, input_size, output_size, hidden_sizes=(), 21 | nonlinearity=F.relu, init_std=1.0, min_std=1e-6): 22 | super(NormalMLPPolicy, self).__init__( 23 | input_size=input_size, output_size=output_size) 24 | self.hidden_sizes = hidden_sizes 25 | self.nonlinearity = nonlinearity 26 | self.min_log_std = math.log(min_std) 27 | self.num_layers = len(hidden_sizes) + 1 28 | 29 | layer_sizes = (input_size,) + hidden_sizes 30 | for i in range(1, self.num_layers): 31 | self.add_module('layer{0}'.format(i), 32 | nn.Linear(layer_sizes[i - 1], layer_sizes[i])) 33 | self.mu = nn.Linear(layer_sizes[-1], output_size) 34 | 35 | self.sigma = nn.Parameter(torch.Tensor(output_size)) 36 | self.sigma.data.fill_(math.log(init_std)) 37 | self.apply(weight_init) 38 | 39 | def forward(self, input, params=None): 40 | if params is None: 41 | params = OrderedDict(self.named_parameters()) 42 | output = input 43 | for i in range(1, self.num_layers): 44 | output = F.linear(output, 45 | weight=params['layer{0}.weight'.format(i)], 46 | bias=params['layer{0}.bias'.format(i)]) 47 | output = self.nonlinearity(output) 48 | mu = F.linear(output, weight=params['mu.weight'], 49 | bias=params['mu.bias']) 50 | scale = torch.exp(torch.clamp(params['sigma'], min=self.min_log_std)) 51 | 52 | return Normal(loc=mu, scale=scale) 53 | -------------------------------------------------------------------------------- /maml_rl/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LinearFeatureBaseline(nn.Module): 6 | """ 7 | Linear baseline based on handcrafted features, as described in [1] 8 | (Supplementary Material 2). 9 | 10 | [1] Yan Duan, Xi Chen, Rein Houthooft, John Schulman, Pieter Abbeel, 11 | "Benchmarking Deep Reinforcement Learning for Continuous Control", 2016 12 | (https://arxiv.org/abs/1604.06778) 13 | """ 14 | 15 | def __init__(self, input_size, reg_coeff=1e-5): 16 | super(LinearFeatureBaseline, self).__init__() 17 | self.input_size = input_size 18 | self._reg_coeff = reg_coeff 19 | self.linear = nn.Linear(self.feature_size, 1, bias=False) 20 | self.linear.weight.data.zero_() 21 | 22 | @property 23 | def feature_size(self): 24 | return 2 * self.input_size + 4 25 | 26 | def _feature(self, episodes): 27 | ones = episodes.mask.unsqueeze(2) 28 | observations = episodes.observations * ones 29 | cum_sum = torch.cumsum(ones, dim=0) * ones 30 | al = cum_sum / 100.0 31 | 32 | return torch.cat([observations, observations ** 2, 33 | al, al ** 2, al ** 3, ones], dim=2) 34 | 35 | def fit(self, episodes): 36 | # sequence_length * batch_size x feature_size 37 | featmat = self._feature(episodes).view(-1, self.feature_size) 38 | # sequence_length * batch_size x 1 39 | returns = episodes.returns.view(-1, 1) 40 | 41 | reg_coeff = self._reg_coeff 42 | eye = torch.eye(self.feature_size, dtype=torch.float32, 43 | device=self.linear.weight.device) 44 | for _ in range(5): 45 | try: 46 | coeffs, _ = torch.gels( 47 | torch.matmul(featmat.t(), returns), 48 | torch.matmul(featmat.t(), featmat) + reg_coeff * eye 49 | ) 50 | break 51 | except RuntimeError: 52 | reg_coeff += 10 53 | else: 54 | raise RuntimeError('Unable to solve the normal equations in ' 55 | '`LinearFeatureBaseline`. 
The matrix X^T*X (with X the design ' 56 | 'matrix) is not full-rank, regardless of the regularization ' 57 | '(maximum regularization: {0}).'.format(reg_coeff)) 58 | self.linear.weight.data = coeffs.data.t() 59 | 60 | def forward(self, episodes): 61 | features = self._feature(episodes) 62 | return self.linear(features) 63 | -------------------------------------------------------------------------------- /maml_rl/envs/navigation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import gym 4 | from gym import spaces 5 | from gym.utils import seeding 6 | 7 | 8 | class Navigation2DEnv(gym.Env): 9 | """2D navigation problems, as described in [1]. The code is adapted from 10 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/maml_examples/point_env_randgoal.py 11 | 12 | At each time step, the 2D agent takes an action (its velocity, clipped in 13 | [-0.1, 0.1]), and receives a penalty equal to its L2 distance to the goal 14 | position (ie. the reward is `-distance`). The 2D navigation tasks are 15 | generated by sampling goal positions from the uniform distribution 16 | on [-0.5, 0.5]^2. 17 | 18 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 19 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 20 | (https://arxiv.org/abs/1703.03400) 21 | """ 22 | 23 | def __init__(self, task={}): 24 | super(Navigation2DEnv, self).__init__() 25 | 26 | self.observation_space = spaces.Box(low=-np.inf, high=np.inf, 27 | shape=(2,), dtype=np.float32) 28 | self.action_space = spaces.Box(low=-0.1, high=0.1, 29 | shape=(2,), dtype=np.float32) 30 | 31 | self._task = task 32 | self._goal = task.get('goal', np.zeros(2, dtype=np.float32)) 33 | self._state = np.zeros(2, dtype=np.float32) 34 | self.seed() 35 | 36 | def seed(self, seed=None): 37 | self.np_random, seed = seeding.np_random(seed) 38 | return [seed] 39 | 40 | def sample_tasks(self, num_tasks): 41 | goals = self.np_random.uniform(-0.5, 0.5, size=(num_tasks, 2)) 42 | tasks = [{'goal': goal} for goal in goals] 43 | return tasks 44 | 45 | def reset_task(self, task): 46 | self._task = task 47 | self._goal = task['goal'] 48 | 49 | def reset(self, env=True): 50 | self._state = np.zeros(2, dtype=np.float32) 51 | return self._state 52 | 53 | def step(self, action): 54 | action = np.clip(action, -0.1, 0.1) 55 | assert self.action_space.contains(action) 56 | self._state = self._state + action 57 | 58 | x = self._state[0] - self._goal[0] 59 | y = self._state[1] - self._goal[1] 60 | reward = -np.sqrt(x ** 2 + y ** 2) 61 | done = ((np.abs(x) < 0.01) and (np.abs(y) < 0.01)) 62 | 63 | return self._state, reward, done, self._task 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reinforcement Learning with Model-Agnostic Meta-Learning (MAML) 2 | 3 | ![HalfCheetahDir](https://raw.githubusercontent.com/tristandeleu/pytorch-maml-rl/master/_assets/halfcheetahdir.gif) 4 | 5 | Implementation of Model-Agnostic Meta-Learning (MAML) applied on Reinforcement Learning problems in Pytorch. This repository includes environments introduced in ([Duan et al., 2016](https://arxiv.org/abs/1611.02779), [Finn et al., 2017](https://arxiv.org/abs/1703.03400)): multi-armed bandits, tabular MDPs, continuous control with MuJoCo, and 2D navigation task. 
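This repository also registers the benchmark tasks as standard Gym environments (see [`maml_rl/envs/__init__.py`](maml_rl/envs/__init__.py)), so they can be used outside of the training script. Below is a minimal sketch of the task-sampling interface, using the 2D navigation task (no MuJoCo required):
```python
import gym
import maml_rl.envs  # registers Bandit-K5-v0, TabularMDP-v0, HalfCheetahDir-v1, 2DNavigation-v0, ...

env = gym.make('2DNavigation-v0')
task = env.unwrapped.sample_tasks(num_tasks=1)[0]  # e.g. {'goal': array([x, y])}
env.unwrapped.reset_task(task)                     # switch the environment to this task
observation = env.reset()
observation, reward, done, info = env.step(env.action_space.sample())
```
The MuJoCo-based environments (`AntVel-v1`, `HalfCheetahDir-v1`, ...) follow the same interface but additionally require a working MuJoCo / `mujoco-py` installation.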
6 | 7 | ## Getting started 8 | To avoid any conflict with your existing Python setup, and to keep this project self-contained, it is suggested to work in a virtual environment with [`virtualenv`](http://docs.python-guide.org/en/latest/dev/virtualenvs/). To install `virtualenv`: 9 | ``` 10 | pip install --upgrade virtualenv 11 | ``` 12 | Create a virtual environment, activate it and install the requirements in [`requirements.txt`](requirements.txt). 13 | ``` 14 | virtualenv venv 15 | source venv/bin/activate 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Usage 20 | You can use the [`main.py`](main.py) script in order to run reinforcement learning experiments with MAML. This script was tested with Python 3.5. Note that some environments may also work with Python 2.7 (all experiments besides MuJoCo-based environments). 21 | ``` 22 | python main.py --env-name HalfCheetahDir-v1 --num-workers 8 --fast-lr 0.1 --max-kl 0.01 --fast-batch-size 20 --meta-batch-size 40 --num-layers 2 --hidden-size 100 --num-batches 1000 --gamma 0.99 --tau 1.0 --cg-damping 1e-5 --ls-max-steps 15 --output-folder maml-halfcheetah-dir --device cuda 23 | ``` 24 | 25 | ## References 26 | This project is, for the most part, a reproduction of the original implementation [cbfinn/maml_rl](https://github.com/cbfinn/maml_rl/) in Pytorch. These experiments are based on the paper 27 | > Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep 28 | networks. _International Conference on Machine Learning (ICML)_, 2017 [[ArXiv](https://arxiv.org/abs/1703.03400)] 29 | 30 | If you want to cite this paper 31 | ``` 32 | @article{DBLP:journals/corr/FinnAL17, 33 | author = {Chelsea Finn and Pieter Abbeel and Sergey Levine}, 34 | title = {Model-{A}gnostic {M}eta-{L}earning for {F}ast {A}daptation of {D}eep {N}etworks}, 35 | journal = {International Conference on Machine Learning (ICML)}, 36 | year = {2017}, 37 | url = {http://arxiv.org/abs/1703.03400} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /maml_rl/sampler.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import torch 3 | import multiprocessing as mp 4 | 5 | from maml_rl.envs.subproc_vec_env import SubprocVecEnv 6 | from maml_rl.episode import BatchEpisodes 7 | 8 | 9 | def make_env(env_name): 10 | """ 11 | return a function 12 | :param env_name: 13 | :return: 14 | """ 15 | def _make_env(): 16 | return gym.make(env_name) 17 | 18 | return _make_env 19 | 20 | 21 | class BatchSampler: 22 | 23 | def __init__(self, env_name, batch_size, num_workers=mp.cpu_count()): 24 | """ 25 | 26 | :param env_name: 27 | :param batch_size: fast batch size 28 | :param num_workers: 29 | """ 30 | self.env_name = env_name 31 | self.batch_size = batch_size 32 | self.num_workers = num_workers 33 | 34 | self.queue = mp.Queue() 35 | # [lambda function] 36 | env_factorys = [make_env(env_name) for _ in range(num_workers)] 37 | # this is the main process manager, and it will be in charge of num_workers sub-processes interacting with 38 | # environment. 
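        # Each worker repeatedly pulls an episode index from this queue; `sample()` below puts
        # `batch_size` indices plus one `None` sentinel per worker into it, so every worker
        # stops stepping its environment once all episodes for the current task are collected.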
39 | self.envs = SubprocVecEnv(env_factorys, queue_=self.queue) 40 | self._env = gym.make(env_name) 41 | 42 | def sample(self, policy, params=None, gamma=0.95, device='cpu'): 43 | """ 44 | 45 | :param policy: 46 | :param params: 47 | :param gamma: 48 | :param device: 49 | :return: 50 | """ 51 | episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma, device=device) 52 | for i in range(self.batch_size): 53 | self.queue.put(i) 54 | for _ in range(self.num_workers): 55 | self.queue.put(None) 56 | 57 | observations, batch_ids = self.envs.reset() 58 | dones = [False] 59 | while (not all(dones)) or (not self.queue.empty()): # if all done and queue is empty 60 | # for reinforcement learning, the forward process requires no-gradient 61 | with torch.no_grad(): 62 | # convert observation to cuda 63 | # compute policy on cuda 64 | # convert action to cpu 65 | observations_tensor = torch.from_numpy(observations).to(device=device) 66 | # forward via policy network 67 | # policy network will return Categorical(logits=logits) 68 | actions_tensor = policy(observations_tensor, params=params).sample() 69 | actions = actions_tensor.cpu().numpy() 70 | 71 | new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(actions) 72 | # here is observations NOT new_observations, batch_ids NOT new_batch_ids 73 | episodes.append(observations, actions, rewards, batch_ids) 74 | observations, batch_ids = new_observations, new_batch_ids 75 | 76 | return episodes 77 | 78 | def reset_task(self, task): 79 | tasks = [task for _ in range(self.num_workers)] 80 | reset = self.envs.reset_task(tasks) 81 | return all(reset) 82 | 83 | def sample_tasks(self, num_tasks): 84 | tasks = self._env.unwrapped.sample_tasks(num_tasks) 85 | return tasks 86 | -------------------------------------------------------------------------------- /maml_rl/envs/mdp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import gym 4 | from gym import spaces 5 | from gym.utils import seeding 6 | 7 | 8 | class TabularMDPEnv(gym.Env): 9 | """Tabular MDP problems, as described in [1]. 10 | 11 | At each time step, the agent chooses one of `num_actions` actions, say `i`, 12 | receives a reward sampled from a Normal distribution with mean `m_i` and 13 | variance 1 (fixed across all tasks), and reaches a new state following the 14 | dynamics of the Markov Decision Process (MDP). The tabular MDP tasks are 15 | generated by sampling the mean rewards from a Normal distribution with mean 16 | 1 and variance 1, and sampling the transition probabilities from a uniform 17 | Dirichlet distribution (ie. with parameter 1). 18 | 19 | [1] Yan Duan, John Schulman, Xi Chen, Peter L. 
Bartlett, Ilya Sutskever, 20 | Pieter Abbeel, "RL2: Fast Reinforcement Learning via Slow Reinforcement 21 | Learning", 2016 (https://arxiv.org/abs/1611.02779) 22 | """ 23 | 24 | def __init__(self, num_states, num_actions, task={}): 25 | super(TabularMDPEnv, self).__init__() 26 | self.num_states = num_states 27 | self.num_actions = num_actions 28 | 29 | self.action_space = spaces.Discrete(num_actions) 30 | self.observation_space = spaces.Box(low=0.0, 31 | high=1.0, shape=(num_states,), dtype=np.float32) 32 | 33 | self._task = task 34 | self._transitions = task.get('transitions', np.full((num_states, 35 | num_actions, num_states), 1.0 / num_states, 36 | dtype=np.float32)) 37 | self._rewards_mean = task.get('rewards_mean', np.zeros(num_states, 38 | num_actions), dtype=np.float32) 39 | self._state = 0 40 | self.seed() 41 | 42 | def seed(self, seed=None): 43 | self.np_random, seed = seeding.np_random(seed) 44 | return [seed] 45 | 46 | def sample_tasks(self, num_tasks): 47 | transitions = self.np_random.dirichlet(np.ones(self.num_states), 48 | size=(num_tasks, self.num_states, self.num_actions)) 49 | rewards_mean = self.np_random.normal(1.0, 1.0, 50 | size=(num_tasks, self.num_states, self.num_actions)) 51 | tasks = [{'transitions': transition, 'rewards_mean': reward_mean} 52 | for (transition, reward_mean) in zip(transitions, rewards_mean)] 53 | return tasks 54 | 55 | def reset_task(self, task): 56 | self._task = task 57 | self._transitions = task['transitions'] 58 | self._rewards_mean = task['rewards_mean'] 59 | 60 | def reset(self): 61 | # From [1]: "an episode always starts on the first state" 62 | self._state = 0 63 | observation = np.zeros(self.num_states, dtype=np.float32) 64 | observation[self._state] = 1.0 65 | 66 | return observation 67 | 68 | def step(self, action): 69 | assert self.action_space.contains(action) 70 | mean = self._rewards_mean[self._state, action] 71 | reward = self.np_random.normal(mean, 1.0) 72 | 73 | self._state = self.np_random.choice(self.num_states, 74 | p=self._transitions[self._state, action]) 75 | observation = np.zeros(self.num_states, dtype=np.float32) 76 | observation[self._state] = 1.0 77 | 78 | return observation, reward, False, self._task 79 | -------------------------------------------------------------------------------- /maml_rl/envs/bandit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import gym 4 | from gym import spaces 5 | from gym.utils import seeding 6 | 7 | 8 | class BernoulliBanditEnv(gym.Env): 9 | """Multi-armed bandit problems with Bernoulli observations, as described 10 | in [1]. 11 | 12 | At each time step, the agent pulls one of the `k` possible arms (actions), 13 | say `i`, and receives a reward sampled from a Bernoulli distribution with 14 | parameter `p_i`. The multi-armed bandit tasks are generated by sampling 15 | the parameters `p_i` from the uniform distribution on [0, 1]. 16 | 17 | [1] Yan Duan, John Schulman, Xi Chen, Peter L. 
Bartlett, Ilya Sutskever, 18 | Pieter Abbeel, "RL2: Fast Reinforcement Learning via Slow Reinforcement 19 | Learning", 2016 (https://arxiv.org/abs/1611.02779) 20 | """ 21 | 22 | def __init__(self, k, task={}): 23 | super(BernoulliBanditEnv, self).__init__() 24 | self.k = k 25 | 26 | self.action_space = spaces.Discrete(self.k) 27 | self.observation_space = spaces.Box(low=0, high=0, 28 | shape=(1,), dtype=np.float32) 29 | 30 | self._task = task 31 | self._means = task.get('mean', np.full((k,), 0.5, dtype=np.float32)) 32 | self.seed() 33 | 34 | def seed(self, seed=None): 35 | self.np_random, seed = seeding.np_random(seed) 36 | return [seed] 37 | 38 | def sample_tasks(self, num_tasks): 39 | means = self.np_random.rand(num_tasks, self.k) 40 | tasks = [{'mean': mean} for mean in means] 41 | return tasks 42 | 43 | def reset_task(self, task): 44 | self._task = task 45 | self._means = task['mean'] 46 | 47 | def reset(self): 48 | return np.zeros(1, dtype=np.float32) 49 | 50 | def step(self, action): 51 | assert self.action_space.contains(action) 52 | mean = self._means[action] 53 | reward = self.np_random.binomial(1, mean) 54 | observation = np.zeros(1, dtype=np.float32) 55 | 56 | return observation, reward, True, self._task 57 | 58 | 59 | class GaussianBanditEnv(gym.Env): 60 | """Multi-armed problems with Gaussian observations. 61 | 62 | At each time step, the agent pulls one of the `k` possible arms (actions), 63 | say `i`, and receives a reward sampled from a Normal distribution with 64 | mean `p_i` and standard deviation `std` (fixed across all tasks). The 65 | multi-armed bandit tasks are generated by sampling the parameters `p_i` 66 | from the uniform distribution on [0, 1]. 67 | """ 68 | 69 | def __init__(self, k, std=1.0, task={}): 70 | super(GaussianBanditEnv, self).__init__() 71 | self.k = k 72 | self.std = std 73 | 74 | self.action_space = spaces.Discrete(self.k) 75 | self.observation_space = spaces.Box(low=0, high=0, 76 | shape=(1,), dtype=np.float32) 77 | 78 | self._task = task 79 | self._means = task.get('mean', np.full((k,), 0.5, dtype=np.float32)) 80 | self.seed() 81 | 82 | def seed(self, seed=None): 83 | self.np_random, seed = seeding.np_random(seed) 84 | return [seed] 85 | 86 | def sample_tasks(self, num_tasks): 87 | means = self.np_random.rand(num_tasks, self.k) 88 | tasks = [{'mean': mean} for mean in means] 89 | return tasks 90 | 91 | def reset_task(self, task): 92 | self._task = task 93 | self._means = task['mean'] 94 | 95 | def reset(self): 96 | return np.zeros(1, dtype=np.float32) 97 | 98 | def step(self, action): 99 | assert self.action_space.contains(action) 100 | mean = self._means[action] 101 | reward = self.np_random.normal(mean, self.std) 102 | observation = np.zeros(1, dtype=np.float32) 103 | 104 | return observation, reward, True, self._task 105 | -------------------------------------------------------------------------------- /maml_rl/envs/normalized_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | from gym import spaces 4 | 5 | 6 | class NormalizedActionWrapper(gym.ActionWrapper): 7 | """ 8 | Environment wrapper to normalize the action space to [-1, 1]. 
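    Incoming actions are clipped to [-1, 1] and then mapped affinely onto the wrapped
    environment's original action bounds, via `lb + 0.5 * (action + 1.0) * (ub - lb)`.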
This 9 | wrapper is adapted from rllab's [1] wrapper `NormalizedEnv` 10 | https://github.com/rll/rllab/blob/b3a28992eca103cab3cb58363dd7a4bb07f250a0/rllab/envs/normalized_env.py 11 | 12 | [1] Yan Duan, Xi Chen, Rein Houthooft, John Schulman, Pieter Abbeel, 13 | "Benchmarking Deep Reinforcement Learning for Continuous Control", 2016 14 | (https://arxiv.org/abs/1604.06778) 15 | """ 16 | 17 | def __init__(self, env): 18 | super(NormalizedActionWrapper, self).__init__(env) 19 | self.action_space = spaces.Box(low=-1.0, high=1.0, 20 | shape=self.env.action_space.shape) 21 | 22 | def action(self, action): 23 | # Clip the action in [-1, 1] 24 | action = np.clip(action, -1.0, 1.0) 25 | # Map the normalized action to original action space 26 | lb, ub = self.env.action_space.low, self.env.action_space.high 27 | action = lb + 0.5 * (action + 1.0) * (ub - lb) 28 | return action 29 | 30 | def reverse_action(self, action): 31 | # Map the original action to normalized action space 32 | lb, ub = self.env.action_space.low, self.env.action_space.high 33 | action = 2.0 * (action - lb) / (ub - lb) - 1.0 34 | # Clip the action in [-1, 1] 35 | action = np.clip(action, -1.0, 1.0) 36 | return action 37 | 38 | 39 | class NormalizedObservationWrapper(gym.ObservationWrapper): 40 | """ 41 | Environment wrapper to normalize the observations with a running mean 42 | and standard deviation. This wrapper is adapted from rllab's [1] 43 | wrapper `NormalizedEnv` 44 | https://github.com/rll/rllab/blob/b3a28992eca103cab3cb58363dd7a4bb07f250a0/rllab/envs/normalized_env.py 45 | 46 | [1] Yan Duan, Xi Chen, Rein Houthooft, John Schulman, Pieter Abbeel, 47 | "Benchmarking Deep Reinforcement Learning for Continuous Control", 2016 48 | (https://arxiv.org/abs/1604.06778) 49 | """ 50 | 51 | def __init__(self, env, alpha=1e-3, epsilon=1e-8): 52 | super(NormalizedObservationWrapper, self).__init__(env) 53 | self.alpha = alpha 54 | self.epsilon = epsilon 55 | shape = self.observation_space.shape 56 | dtype = self.observation_space.dtype or np.float32 57 | self._mean = np.zeros(shape, dtype=dtype) 58 | self._var = np.ones(shape, dtype=dtype) 59 | 60 | def observation(self, observation): 61 | self._mean = (1.0 - self.alpha) * self._mean + self.alpha * observation 62 | self._var = (1.0 - self.alpha) * self._var + self.alpha * np.square(observation, self._mean) 63 | return (observation - self._mean) / (np.sqrt(self._var) + self.epsilon) 64 | 65 | 66 | class NormalizedRewardWrapper(gym.RewardWrapper): 67 | """ 68 | Environment wrapper to normalize the rewards with a running mean 69 | and standard deviation. 
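    Both statistics are tracked with an exponential moving average controlled by `alpha`,
    and `epsilon` guards against division by (near-)zero.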
This wrapper is adapted from rllab's [1] 70 | wrapper `NormalizedEnv` 71 | https://github.com/rll/rllab/blob/b3a28992eca103cab3cb58363dd7a4bb07f250a0/rllab/envs/normalized_env.py 72 | 73 | [1] Yan Duan, Xi Chen, Rein Houthooft, John Schulman, Pieter Abbeel, 74 | "Benchmarking Deep Reinforcement Learning for Continuous Control", 2016 75 | (https://arxiv.org/abs/1604.06778) 76 | """ 77 | 78 | def __init__(self, env, alpha=1e-3, epsilon=1e-8): 79 | super(NormalizedRewardWrapper, self).__init__(env) 80 | self.alpha = alpha 81 | self.epsilon = epsilon 82 | self._mean = 0.0 83 | self._var = 1.0 84 | 85 | def reward(self, reward): 86 | self._mean = (1.0 - self.alpha) * self._mean + self.alpha * reward 87 | self._var = (1.0 - self.alpha) * self._var + self.alpha * np.square(reward, self._mean) 88 | return (reward - self._mean) / (np.sqrt(self._var) + self.epsilon) 89 | -------------------------------------------------------------------------------- /maml_rl/episode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class BatchEpisodes: 7 | 8 | def __init__(self, batch_size, gamma=0.95, device='cpu'): 9 | self.batch_size = batch_size 10 | self.gamma = gamma 11 | self.device = device 12 | 13 | # [[], [],...batchsz of []] 14 | self._observations_list = [[] for _ in range(batch_size)] 15 | self._actions_list = [[] for _ in range(batch_size)] 16 | self._rewards_list = [[] for _ in range(batch_size)] 17 | self._mask_list = [] 18 | 19 | self._observations = None 20 | self._actions = None 21 | self._rewards = None 22 | self._returns = None 23 | self._mask = None 24 | 25 | @property 26 | def observations(self): 27 | if self._observations is None: 28 | observation_shape = self._observations_list[0][0].shape 29 | observations = np.zeros((len(self), self.batch_size) 30 | + observation_shape, dtype=np.float32) 31 | for i in range(self.batch_size): 32 | length = len(self._observations_list[i]) 33 | observations[:length, i] = np.stack(self._observations_list[i], axis=0) 34 | self._observations = torch.from_numpy(observations).to(self.device) 35 | return self._observations 36 | 37 | @property 38 | def actions(self): 39 | if self._actions is None: 40 | action_shape = self._actions_list[0][0].shape 41 | actions = np.zeros((len(self), self.batch_size) 42 | + action_shape, dtype=np.float32) 43 | for i in range(self.batch_size): 44 | length = len(self._actions_list[i]) 45 | actions[:length, i] = np.stack(self._actions_list[i], axis=0) 46 | self._actions = torch.from_numpy(actions).to(self.device) 47 | return self._actions 48 | 49 | @property 50 | def rewards(self): 51 | if self._rewards is None: 52 | rewards = np.zeros((len(self), self.batch_size), dtype=np.float32) 53 | for i in range(self.batch_size): 54 | length = len(self._rewards_list[i]) 55 | rewards[:length, i] = np.stack(self._rewards_list[i], axis=0) 56 | self._rewards = torch.from_numpy(rewards).to(self.device) 57 | return self._rewards 58 | 59 | @property 60 | def returns(self): 61 | if self._returns is None: 62 | return_ = np.zeros(self.batch_size, dtype=np.float32) 63 | returns = np.zeros((len(self), self.batch_size), dtype=np.float32) 64 | rewards = self.rewards.cpu().numpy() 65 | mask = self.mask.cpu().numpy() 66 | for i in range(len(self) - 1, -1, -1): 67 | return_ = self.gamma * return_ + rewards[i] * mask[i] 68 | returns[i] = return_ 69 | self._returns = torch.from_numpy(returns).to(self.device) 70 | return self._returns 71 | 72 | 
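    # mask is a [len(self), batch_size] tensor with 1.0 for valid time steps and 0.0 for
    # the zero-padding beyond each episode's length; `returns` and `gae` use it so that
    # padded steps contribute nothing.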
@property 73 | def mask(self): 74 | if self._mask is None: 75 | mask = np.zeros((len(self), self.batch_size), dtype=np.float32) 76 | for i in range(self.batch_size): 77 | length = len(self._actions_list[i]) 78 | mask[:length, i] = 1.0 79 | self._mask = torch.from_numpy(mask).to(self.device) 80 | return self._mask 81 | 82 | def gae(self, values, tau=1.0): 83 | """ 84 | 85 | :param values: [200, 20, 1], tensor 86 | :param tau: 87 | :return: 88 | """ 89 | # Add an additional 0 at the end of values for 90 | # the estimation at the end of the episode 91 | values = values.squeeze(2).detach() # [200, 20] 92 | values = F.pad(values * self.mask, (0, 0, 0, 1)) # [201, 20] 93 | 94 | deltas = self.rewards + self.gamma * values[1:] - values[:-1] # [200, 20] 95 | advantages = torch.zeros_like(deltas).float() # [200, 20] 96 | gae = torch.zeros_like(deltas[0]).float() # [20] 97 | for i in range(len(self) - 1, -1, -1): 98 | gae = gae * self.gamma * tau + deltas[i] 99 | advantages[i] = gae 100 | 101 | return advantages 102 | 103 | def append(self, observations, actions, rewards, batch_ids): 104 | for observation, action, reward, batch_id in zip(observations, actions, rewards, batch_ids): 105 | if batch_id is None: 106 | continue 107 | self._observations_list[batch_id].append(observation.astype(np.float32)) 108 | self._actions_list[batch_id].append(action.astype(np.float32)) 109 | self._rewards_list[batch_id].append(reward.astype(np.float32)) 110 | 111 | def __len__(self): 112 | return max(map(len, self._rewards_list)) 113 | -------------------------------------------------------------------------------- /maml_rl/envs/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import multiprocessing as mp 3 | import gym 4 | import sys 5 | 6 | import queue 7 | 8 | 9 | class EnvWorker(mp.Process): 10 | 11 | def __init__(self, remote, env_fn, queue_, lock): 12 | """ 13 | 14 | :param remote: send/recv connection, type of Pipe 15 | :param env_fn: construct environment function 16 | :param queue_: global queue instance 17 | :param lock: Every worker has a lock 18 | """ 19 | super(EnvWorker, self).__init__() 20 | 21 | self.remote = remote # Pipe() 22 | self.env = env_fn() # return a function 23 | self.queue = queue_ 24 | self.lock = lock 25 | self.task_id = None 26 | self.done = False 27 | 28 | def empty_step(self): 29 | """ 30 | conduct a dummy step 31 | :return: 32 | """ 33 | observation = np.zeros(self.env.observation_space.shape, dtype=np.float32) 34 | reward, done = 0.0, True 35 | 36 | return observation, reward, done, {} 37 | 38 | def try_reset(self): 39 | """ 40 | 41 | :return: 42 | """ 43 | with self.lock: 44 | try: 45 | self.task_id = self.queue.get(True) # block = True 46 | self.done = (self.task_id is None) 47 | except queue.Empty: 48 | self.done = True 49 | 50 | # construct empty state or get state from env.reset() 51 | observation = np.zeros(self.env.observation_space.shape, dtype=np.float32) if self.done else self.env.reset() 52 | 53 | return observation 54 | 55 | def run(self): 56 | """ 57 | 58 | :return: 59 | """ 60 | while True: 61 | command, data = self.remote.recv() 62 | 63 | if command == 'step': 64 | observation, reward, done, info = (self.empty_step() if self.done else self.env.step(data)) 65 | if done and (not self.done): 66 | observation = self.try_reset() 67 | self.remote.send((observation, reward, done, self.task_id, info)) 68 | 69 | elif command == 'reset': 70 | observation = self.try_reset() 71 | 
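                # reply with the initial observation and the episode index pulled from the
                # shared queue (self.task_id is None once the queue has been drained)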
self.remote.send((observation, self.task_id)) 72 | elif command == 'reset_task': 73 | self.env.unwrapped.reset_task(data) 74 | self.remote.send(True) 75 | elif command == 'close': 76 | self.remote.close() 77 | break 78 | elif command == 'get_spaces': 79 | self.remote.send((self.env.observation_space, self.env.action_space)) 80 | else: 81 | raise NotImplementedError() 82 | 83 | 84 | class SubprocVecEnv(gym.Env): 85 | 86 | def __init__(self, env_factorys, queue_): 87 | """ 88 | 89 | :param env_factorys: list of [lambda x: def p: envs.make(env_name), return p], len: num_workers 90 | :param queue: 91 | """ 92 | self.lock = mp.Lock() 93 | # remotes: all recv conn, len: 8, here duplex=True 94 | # works_remotes: all send conn, len: 8, here duplex=True 95 | self.remotes, self.work_remotes = zip(*[mp.Pipe() for _ in env_factorys]) 96 | 97 | # queue and lock is shared. 98 | self.workers = [EnvWorker(remote, env_fn, queue_, self.lock) 99 | for (remote, env_fn) in zip(self.work_remotes, env_factorys)] 100 | # start 8 processes to interact with environments. 101 | for worker in self.workers: 102 | worker.daemon = True 103 | worker.start() 104 | for remote in self.work_remotes: 105 | remote.close() 106 | 107 | self.waiting = False # for step_async 108 | self.closed = False 109 | 110 | # Since the main process need talk to children processes, we need a way to comunicate between these. 111 | # here we use mp.Pipe() to send/recv data. 112 | self.remotes[0].send(('get_spaces', None)) 113 | observation_space, action_space = self.remotes[0].recv() 114 | self.observation_space = observation_space 115 | self.action_space = action_space 116 | 117 | def step(self, actions): 118 | """ 119 | step synchronously 120 | :param actions: 121 | :return: 122 | """ 123 | self.step_async(actions) 124 | # wait until step state overdue 125 | return self.step_wait() 126 | 127 | def step_async(self, actions): 128 | """ 129 | step asynchronouly 130 | :param actions: 131 | :return: 132 | """ 133 | # let each sub-process step 134 | for remote, action in zip(self.remotes, actions): 135 | remote.send(('step', action)) 136 | self.waiting = True 137 | 138 | def step_wait(self): 139 | results = [remote.recv() for remote in self.remotes] 140 | self.waiting = False 141 | observations, rewards, dones, task_ids, infos = zip(*results) 142 | return np.stack(observations), np.stack(rewards), np.stack(dones), task_ids, infos 143 | 144 | def reset(self): 145 | """ 146 | reset synchronously 147 | :return: 148 | """ 149 | for remote in self.remotes: 150 | remote.send(('reset', None)) 151 | results = [remote.recv() for remote in self.remotes] 152 | observations, task_ids = zip(*results) 153 | return np.stack(observations), task_ids 154 | 155 | def reset_task(self, tasks): 156 | for remote, task in zip(self.remotes, tasks): 157 | remote.send(('reset_task', task)) 158 | return np.stack([remote.recv() for remote in self.remotes]) 159 | 160 | def close(self): 161 | if self.closed: 162 | return 163 | if self.waiting: # cope with step_async() 164 | for remote in self.remotes: 165 | remote.recv() 166 | for remote in self.remotes: 167 | remote.send(('close', None)) 168 | for worker in self.workers: 169 | worker.join() 170 | self.closed = True 171 | -------------------------------------------------------------------------------- /maml_rl/envs/mujoco/half_cheetah.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.envs.mujoco import HalfCheetahEnv as HalfCheetahEnv_ 3 | 4 | 5 | class 
HalfCheetahEnv(HalfCheetahEnv_): 6 | def _get_obs(self): 7 | return np.concatenate([ 8 | self.sim.data.qpos.flat[1:], 9 | self.sim.data.qvel.flat, 10 | self.get_body_com("torso").flat, 11 | ]).astype(np.float32).flatten() 12 | 13 | def viewer_setup(self): 14 | camera_id = self.model.camera_name2id('track') 15 | self.viewer.cam.type = 2 16 | self.viewer.cam.fixedcamid = camera_id 17 | self.viewer.cam.distance = self.model.stat.extent * 0.35 18 | # Hide the overlay 19 | self.viewer._hide_overlay = True 20 | 21 | def render(self, mode='human'): 22 | if mode == 'rgb_array': 23 | self._get_viewer().render() 24 | # window size used for old mujoco-py: 25 | width, height = 500, 500 26 | data = self._get_viewer().read_pixels(width, height, depth=False) 27 | return data 28 | elif mode == 'human': 29 | self._get_viewer().render() 30 | 31 | 32 | class HalfCheetahVelEnv(HalfCheetahEnv): 33 | """ 34 | Half-cheetah environment with target velocity, as described in [1]. The 35 | code is adapted from 36 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand.py 37 | 38 | The half-cheetah follows the dynamics from MuJoCo [2], and receives at each 39 | time step a reward composed of a control cost and a penalty equal to the 40 | difference between its current velocity and the target velocity. The tasks 41 | are generated by sampling the target velocities from the uniform 42 | distribution on [0, 2]. 43 | 44 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 45 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 46 | (https://arxiv.org/abs/1703.03400) 47 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for 48 | model-based control", 2012 49 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) 50 | """ 51 | 52 | def __init__(self, task={}): 53 | self._task = task 54 | self._goal_vel = task.get('velocity', 0.0) 55 | super(HalfCheetahVelEnv, self).__init__() 56 | 57 | def step(self, action): 58 | xposbefore = self.sim.data.qpos[0] 59 | self.do_simulation(action, self.frame_skip) 60 | xposafter = self.sim.data.qpos[0] 61 | 62 | forward_vel = (xposafter - xposbefore) / self.dt 63 | forward_reward = -1.0 * abs(forward_vel - self._goal_vel) 64 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) 65 | 66 | observation = self._get_obs() 67 | reward = forward_reward - ctrl_cost 68 | done = False 69 | infos = dict(reward_forward=forward_reward, 70 | reward_ctrl=-ctrl_cost, task=self._task) 71 | return (observation, reward, done, infos) 72 | 73 | def sample_tasks(self, num_tasks): 74 | velocities = self.np_random.uniform(0.0, 2.0, size=(num_tasks,)) 75 | tasks = [{'velocity': velocity} for velocity in velocities] 76 | return tasks 77 | 78 | def reset_task(self, task): 79 | self._task = task 80 | self._goal_vel = task['velocity'] 81 | 82 | 83 | class HalfCheetahDirEnv(HalfCheetahEnv): 84 | """ 85 | Half-cheetah environment with target direction, as described in [1]. The 86 | code is adapted from 87 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand_direc.py 88 | 89 | The half-cheetah follows the dynamics from MuJoCo [2], and receives at each 90 | time step a reward composed of a control cost and a reward equal to its 91 | velocity in the target direction. The tasks are generated by sampling the 92 | target directions from a Bernoulli distribution on {-1, 1} with parameter 93 | 0.5 (-1: backward, +1: forward). 
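    Concretely, the per-step reward computed below is
    `goal_dir * forward_vel - 0.05 * ||action||^2`.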
94 | 95 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 96 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 97 | (https://arxiv.org/abs/1703.03400) 98 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for 99 | model-based control", 2012 100 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) 101 | """ 102 | 103 | def __init__(self, task={}): 104 | self._task = task 105 | self._goal_dir = task.get('direction', 1) 106 | super(HalfCheetahDirEnv, self).__init__() 107 | 108 | def step(self, action): 109 | xposbefore = self.sim.data.qpos[0] 110 | self.do_simulation(action, self.frame_skip) 111 | xposafter = self.sim.data.qpos[0] 112 | 113 | forward_vel = (xposafter - xposbefore) / self.dt 114 | forward_reward = self._goal_dir * forward_vel 115 | ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action)) 116 | 117 | observation = self._get_obs() 118 | reward = forward_reward - ctrl_cost 119 | done = False 120 | infos = dict(reward_forward=forward_reward, 121 | reward_ctrl=-ctrl_cost, task=self._task) 122 | return (observation, reward, done, infos) 123 | 124 | def sample_tasks(self, num_tasks): 125 | directions = 2 * self.np_random.binomial(1, p=0.5, size=(num_tasks,)) - 1 126 | tasks = [{'direction': direction} for direction in directions] 127 | return tasks 128 | 129 | def reset_task(self, task): 130 | self._task = task 131 | self._goal_dir = task['direction'] 132 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import maml_rl.envs 2 | import gym 3 | import numpy as np 4 | import torch 5 | import json 6 | 7 | from maml_rl.metalearner import MetaLearner 8 | from maml_rl.policies import CategoricalMLPPolicy, NormalMLPPolicy 9 | from maml_rl.baseline import LinearFeatureBaseline 10 | from maml_rl.sampler import BatchSampler 11 | 12 | # from tensorboardX import SummaryWriter 13 | 14 | 15 | def total_rewards(episodes_rewards, aggregation=torch.mean): 16 | rewards = torch.mean(torch.stack([aggregation(torch.sum(rewards, dim=0)) 17 | for rewards in episodes_rewards], dim=0)) 18 | return rewards.item() 19 | 20 | 21 | def main(args): 22 | 23 | args.output_folder = args.env_name 24 | 25 | # TODO 26 | continuous_actions = (args.env_name in ['AntVel-v1', 'AntDir-v1', 27 | 'AntPos-v0', 'HalfCheetahVel-v1', 'HalfCheetahDir-v1', 28 | '2DNavigation-v0']) 29 | 30 | # writer = SummaryWriter('./logs/{0}'.format(args.output_folder)) 31 | save_folder = './saves/{0}'.format(args.output_folder) 32 | if not os.path.exists(save_folder): 33 | os.makedirs(save_folder) 34 | 35 | with open(os.path.join(save_folder, 'config.json'), 'w') as f: 36 | # config = {k: v for (k, v) in vars(args).iteritems() if k != 'device'} 37 | config = {k: v for (k, v) in vars(args).items() if k != 'device'} 38 | config.update(device=args.device.type) 39 | json.dump(config, f, indent=2) 40 | print(config) 41 | 42 | sampler = BatchSampler(args.env_name, batch_size=args.fast_batch_size, num_workers=args.num_workers) 43 | 44 | if continuous_actions: 45 | policy = NormalMLPPolicy( 46 | int(np.prod(sampler.envs.observation_space.shape)), # input shape 47 | int(np.prod(sampler.envs.action_space.shape)), # output shape 48 | hidden_sizes=(args.hidden_size,) * args.num_layers) # [100, 100] 49 | else: 50 | policy = CategoricalMLPPolicy( 51 | int(np.prod(sampler.envs.observation_space.shape)), 52 | sampler.envs.action_space.n, 53 | hidden_sizes=(args.hidden_size,) * 
args.num_layers) 54 | 55 | baseline = LinearFeatureBaseline(int(np.prod(sampler.envs.observation_space.shape))) 56 | 57 | metalearner = MetaLearner(sampler, policy, baseline, gamma=args.gamma, 58 | fast_lr=args.fast_lr, tau=args.tau, device=args.device) 59 | 60 | for batch in range(args.num_batches): # number of epochs 61 | 62 | tasks = sampler.sample_tasks(num_tasks=args.meta_batch_size) 63 | episodes = metalearner.sample(tasks, first_order=args.first_order) 64 | 65 | metalearner.step(episodes, max_kl=args.max_kl, cg_iters=args.cg_iters, 66 | cg_damping=args.cg_damping, ls_max_steps=args.ls_max_steps, 67 | ls_backtrack_ratio=args.ls_backtrack_ratio) 68 | 69 | # Tensorboard 70 | # writer.add_scalar('total_rewards/before_update', 71 | # total_rewards([ep.rewards for ep, _ in episodes]), batch) 72 | # writer.add_scalar('total_rewards/after_update', 73 | # total_rewards([ep.rewards for _, ep in episodes]), batch) 74 | 75 | 76 | # # Save policy network 77 | # with open(os.path.join(save_folder, 'policy-{0}.pt'.format(batch)), 'wb') as f: 78 | # torch.save(policy.state_dict(), f) 79 | 80 | print(batch, total_rewards([ep.rewards for ep, _ in episodes]), total_rewards([ep.rewards for _, ep in episodes])) 81 | 82 | 83 | if __name__ == '__main__': 84 | """ 85 | python main.py --env-name HalfCheetahDir-v1 --output-folder maml-halfcheetah-dir \ 86 | --fast-lr 0.1 --meta-batch-size 30 --fast-batch-size 20 --num-batches 1000 87 | """ 88 | import argparse 89 | import os 90 | import multiprocessing as mp 91 | 92 | parser = argparse.ArgumentParser(description='Reinforcement learning with ' 93 | 'Model-Agnostic Meta-Learning (MAML)') 94 | 95 | # General 96 | parser.add_argument('--env-name', type=str, default='HalfCheetahDir-v1', 97 | help='name of the environment') 98 | parser.add_argument('--gamma', type=float, default=0.95, 99 | help='value of the discount factor gamma') 100 | parser.add_argument('--tau', type=float, default=1.0, 101 | help='value of the GAE parameter (lambda)') 102 | parser.add_argument('--first-order', action='store_true', 103 | help='use the first-order approximation of MAML') 104 | 105 | # Policy network (relu activation function) 106 | parser.add_argument('--hidden-size', type=int, default=100, 107 | help='number of hidden units per layer') 108 | parser.add_argument('--num-layers', type=int, default=2, 109 | help='number of hidden layers') 110 | 111 | # Task-specific 112 | parser.add_argument('--fast-batch-size', type=int, default=20, 113 | help='batch size for each individual task') 114 | parser.add_argument('--fast-lr', type=float, default=0.1, # 0.5 115 | help='learning rate for the 1-step gradient update of MAML') 116 | 117 | # Optimization 118 | parser.add_argument('--num-batches', type=int, default=1000, 119 | help='number of batches, or number of epochs') 120 | parser.add_argument('--meta-batch-size', type=int, default=30, 121 | help='number of tasks per batch') 122 | parser.add_argument('--max-kl', type=float, default=1e-2, 123 | help='maximum value for the KL constraint in TRPO') 124 | parser.add_argument('--cg-iters', type=int, default=10, 125 | help='number of iterations of conjugate gradient') 126 | parser.add_argument('--cg-damping', type=float, default=1e-5, 127 | help='damping in conjugate gradient') 128 | parser.add_argument('--ls-max-steps', type=int, default=15, 129 | help='maximum number of iterations for line search') 130 | parser.add_argument('--ls-backtrack-ratio', type=float, default=0.8, 131 | help='backtracking ratio for the line search') 132 |
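    # The five optimization flags above are forwarded unchanged to metalearner.step() in main():
    # max_kl sets the size of the TRPO trust region, cg_iters and cg_damping control the
    # conjugate-gradient solve for the step direction, and ls_max_steps and ls_backtrack_ratio
    # control the backtracking line search.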
133 | # Miscellaneous 134 | parser.add_argument('--output-folder', type=str, default='HalfCheetahDir-v1', 135 | help='name of the output folder') 136 | parser.add_argument('--num-workers', type=int, default=mp.cpu_count(), 137 | help='number of workers for trajectories sampling') 138 | parser.add_argument('--device', type=str, default='cuda', 139 | help='set the device (cpu or cuda)') 140 | 141 | args = parser.parse_args() 142 | 143 | # Create the logs and saves folders if they don't exist 144 | if not os.path.exists('./logs'): 145 | os.makedirs('./logs') 146 | if not os.path.exists('./saves'): 147 | os.makedirs('./saves') 148 | # Device 149 | args.device = torch.device(args.device if torch.cuda.is_available() else 'cpu') 150 | # Slurm 151 | if 'SLURM_JOB_ID' in os.environ: 152 | args.output_folder += '-{0}'.format(os.environ['SLURM_JOB_ID']) 153 | 154 | main(args) 155 | -------------------------------------------------------------------------------- /maml_rl/metalearner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.convert_parameters import vector_to_parameters, parameters_to_vector 3 | from torch.distributions.kl import kl_divergence 4 | 5 | from maml_rl.utils.torch_utils import weighted_mean, detach_distribution, weighted_normalize 6 | from maml_rl.utils.optimization import conjugate_gradient 7 | 8 | 9 | class MetaLearner: 10 | """ 11 | Meta-learner 12 | 13 | The meta-learner is responsible for sampling the trajectories/episodes 14 | (before and after the one-step adaptation), computing the inner loss, computing 15 | the updated parameters based on the inner loss, and performing the meta-update. 16 | 17 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 18 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 19 | (https://arxiv.org/abs/1703.03400) 20 | [2] Richard Sutton, Andrew Barto, "Reinforcement learning: An introduction", 21 | 2018 (http://incompleteideas.net/book/the-book-2nd.html) 22 | [3] John Schulman, Philipp Moritz, Sergey Levine, Michael Jordan, 23 | Pieter Abbeel, "High-Dimensional Continuous Control Using Generalized 24 | Advantage Estimation", 2016 (https://arxiv.org/abs/1506.02438) 25 | [4] John Schulman, Sergey Levine, Philipp Moritz, Michael I. Jordan, 26 | Pieter Abbeel, "Trust Region Policy Optimization", 2015 27 | (https://arxiv.org/abs/1502.05477) 28 | """ 29 | 30 | def __init__(self, sampler, policy, baseline, gamma=0.95, fast_lr=0.5, tau=1.0, device='cpu'): 31 | self.sampler = sampler 32 | self.policy = policy 33 | self.baseline = baseline 34 | self.gamma = gamma 35 | self.fast_lr = fast_lr 36 | self.tau = tau 37 | self.to(device) 38 | 39 | def inner_loss(self, episodes, params=None): 40 | """ 41 | Compute the inner loss for the one-step gradient update. The inner 42 | loss is REINFORCE with baseline [2], computed on advantages estimated 43 | with Generalized Advantage Estimation (GAE, [3]).
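With per-step rewards r_t, baseline values V(s_t), discount `gamma` and GAE parameter `tau` (the lambda of [3]), the advantages are A_t = sum_{l >= 0} (gamma * tau)**l * delta_{t+l}, where delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); after normalization, the loss is the negative masked mean of log pi(a_t | s_t) * A_t.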
44 | """ 45 | values = self.baseline(episodes) 46 | advantages = episodes.gae(values, tau=self.tau) 47 | advantages = weighted_normalize(advantages, weights=episodes.mask) 48 | 49 | pi = self.policy(episodes.observations, params=params) 50 | # log-probability of the sampled actions under the (possibly adapted) policy 51 | log_probs = pi.log_prob(episodes.actions) # [200, 20, 6] 52 | if log_probs.dim() > 2: 53 | log_probs = torch.sum(log_probs, dim=2) 54 | loss = -weighted_mean(log_probs * advantages, weights=episodes.mask) 55 | 56 | return loss 57 | 58 | def adapt(self, episodes, first_order=False): 59 | """ 60 | Adapt the parameters of the policy network to a new task, from 61 | sampled trajectories `episodes`, with a one-step gradient update [1]. 62 | """ 63 | # Fit the baseline to the training episodes 64 | self.baseline.fit(episodes) 65 | # Get the loss on the training episodes 66 | loss = self.inner_loss(episodes) 67 | # Get the new parameters after a one-step gradient update 68 | params = self.policy.update_params(loss, step_size=self.fast_lr, first_order=first_order) 69 | 70 | return params 71 | 72 | def sample(self, tasks, first_order=False): 73 | """ 74 | Sample trajectories (before and after the update of the parameters) 75 | for all the tasks `tasks`. 76 | """ 77 | episodes = [] 78 | for task in tasks: 79 | self.sampler.reset_task(task) 80 | train_episodes = self.sampler.sample(self.policy, gamma=self.gamma, device=self.device) 81 | 82 | params = self.adapt(train_episodes, first_order=first_order) 83 | 84 | valid_episodes = self.sampler.sample(self.policy, params=params, gamma=self.gamma, device=self.device) 85 | episodes.append((train_episodes, valid_episodes)) 86 | return episodes 87 | 88 | def kl_divergence(self, episodes, old_pis=None): 89 | kls = [] 90 | if old_pis is None: 91 | old_pis = [None] * len(episodes) 92 | 93 | for (train_episodes, valid_episodes), old_pi in zip(episodes, old_pis): 94 | params = self.adapt(train_episodes) 95 | pi = self.policy(valid_episodes.observations, params=params) 96 | 97 | if old_pi is None: 98 | old_pi = detach_distribution(pi) 99 | 100 | mask = valid_episodes.mask 101 | if valid_episodes.actions.dim() > 2: 102 | mask = mask.unsqueeze(2) 103 | kl = weighted_mean(kl_divergence(pi, old_pi), weights=mask) 104 | kls.append(kl) 105 | 106 | return torch.mean(torch.stack(kls, dim=0)) 107 | 108 | def hessian_vector_product(self, episodes, damping=1e-2): 109 | """ 110 | Hessian-vector product, based on the Pearlmutter method.
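For a vector v, the product is computed as H v = grad((grad KL)^T v) + damping * v: the mean KL divergence is differentiated once with create_graph=True, the flat gradient is dotted with v, and the result is differentiated again, so the Hessian is never formed explicitly (B. Pearlmutter, "Fast exact multiplication by the Hessian", 1994).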
111 | """ 112 | 113 | def _product(vector): 114 | kl = self.kl_divergence(episodes) 115 | grads = torch.autograd.grad(kl, self.policy.parameters(), 116 | create_graph=True) 117 | flat_grad_kl = parameters_to_vector(grads) 118 | 119 | grad_kl_v = torch.dot(flat_grad_kl, vector) 120 | grad2s = torch.autograd.grad(grad_kl_v, self.policy.parameters()) 121 | flat_grad2_kl = parameters_to_vector(grad2s) 122 | 123 | return flat_grad2_kl + damping * vector 124 | 125 | return _product 126 | 127 | def surrogate_loss(self, episodes, old_pis=None): 128 | losses, kls, pis = [], [], [] 129 | if old_pis is None: 130 | old_pis = [None] * len(episodes) 131 | 132 | for (train_episodes, valid_episodes), old_pi in zip(episodes, old_pis): 133 | params = self.adapt(train_episodes) 134 | 135 | with torch.set_grad_enabled(old_pi is None): 136 | pi = self.policy(valid_episodes.observations, params=params) 137 | pis.append(detach_distribution(pi)) 138 | 139 | if old_pi is None: 140 | old_pi = detach_distribution(pi) 141 | 142 | values = self.baseline(valid_episodes) 143 | advantages = valid_episodes.gae(values, tau=self.tau) 144 | advantages = weighted_normalize(advantages, 145 | weights=valid_episodes.mask) 146 | 147 | log_ratio = (pi.log_prob(valid_episodes.actions) 148 | - old_pi.log_prob(valid_episodes.actions)) 149 | if log_ratio.dim() > 2: 150 | log_ratio = torch.sum(log_ratio, dim=2) 151 | ratio = torch.exp(log_ratio) 152 | 153 | loss = -weighted_mean(ratio * advantages, 154 | weights=valid_episodes.mask) 155 | losses.append(loss) 156 | 157 | mask = valid_episodes.mask 158 | if valid_episodes.actions.dim() > 2: 159 | mask = mask.unsqueeze(2) 160 | kl = weighted_mean(kl_divergence(pi, old_pi), weights=mask) 161 | kls.append(kl) 162 | 163 | return (torch.mean(torch.stack(losses, dim=0)), torch.mean(torch.stack(kls, dim=0)), pis) 164 | 165 | def step(self, episodes, max_kl=1e-3, cg_iters=10, cg_damping=1e-2, ls_max_steps=10, ls_backtrack_ratio=0.5): 166 | """ 167 | Meta-optimization step (ie. update of the initial parameters), based 168 | on Trust Region Policy Optimization (TRPO, [4]). 
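The step direction s is obtained by solving H s = g with conjugate gradient (g: gradient of the surrogate loss, H: Hessian of the mean KL divergence), the step is rescaled so that the quadratic approximation 0.5 * step^T H step of the KL divergence equals `max_kl`, and a backtracking line search (shrinking the step size by `ls_backtrack_ratio`, at most `ls_max_steps` times) accepts the first step that both decreases the surrogate loss and keeps the KL below `max_kl`; if no such step is found, the old parameters are restored.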
169 | """ 170 | old_loss, _, old_pis = self.surrogate_loss(episodes) 171 | grads = torch.autograd.grad(old_loss, self.policy.parameters()) 172 | grads = parameters_to_vector(grads) 173 | 174 | # Compute the step direction with Conjugate Gradient 175 | # return a function 176 | hessian_vector_product = self.hessian_vector_product(episodes, damping=cg_damping) 177 | stepdir = conjugate_gradient(hessian_vector_product, grads, cg_iters=cg_iters) 178 | 179 | # Compute the Lagrange multiplier 180 | shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir)) 181 | lagrange_multiplier = torch.sqrt(shs / max_kl) 182 | 183 | step = stepdir / lagrange_multiplier 184 | 185 | # Save the old parameters 186 | old_params = parameters_to_vector(self.policy.parameters()) 187 | 188 | # Line search 189 | step_size = 1.0 190 | for _ in range(ls_max_steps): 191 | vector_to_parameters(old_params - step_size * step, self.policy.parameters()) 192 | loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis) 193 | improve = loss - old_loss 194 | if (improve.item() < 0.0) and (kl.item() < max_kl): 195 | break 196 | step_size *= ls_backtrack_ratio 197 | else: 198 | vector_to_parameters(old_params, self.policy.parameters()) 199 | 200 | def to(self, device, **kwargs): 201 | self.policy.to(device, **kwargs) 202 | self.baseline.to(device, **kwargs) 203 | self.device = device 204 | -------------------------------------------------------------------------------- /maml_rl/envs/mujoco/ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym.envs.mujoco import AntEnv as AntEnv_ 3 | 4 | 5 | class AntEnv(AntEnv_): 6 | @property 7 | def action_scaling(self): 8 | if not hasattr(self, 'action_space'): 9 | return 1.0 10 | if self._action_scaling is None: 11 | lb, ub = self.action_space.low, self.action_space.high 12 | self._action_scaling = 0.5 * (ub - lb) 13 | return self._action_scaling 14 | 15 | def _get_obs(self): 16 | return np.concatenate([ 17 | self.sim.data.qpos.flat[2:], 18 | self.sim.data.qvel.flat, 19 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat, 20 | self.sim.data.get_body_xmat("torso").flat, 21 | self.get_body_com("torso").flat, 22 | ]).astype(np.float32).flatten() 23 | 24 | def viewer_setup(self): 25 | camera_id = self.model.camera_name2id('track') 26 | self.viewer.cam.type = 2 27 | self.viewer.cam.fixedcamid = camera_id 28 | self.viewer.cam.distance = self.model.stat.extent * 0.35 29 | # Hide the overlay 30 | self.viewer._hide_overlay = True 31 | 32 | def render(self, mode='human'): 33 | if mode == 'rgb_array': 34 | self._get_viewer().render() 35 | # window size used for old mujoco-py: 36 | width, height = 500, 500 37 | data = self._get_viewer().read_pixels(width, height, depth=False) 38 | return data 39 | elif mode == 'human': 40 | self._get_viewer().render() 41 | 42 | 43 | class AntVelEnv(AntEnv): 44 | """Ant environment with target velocity, as described in [1]. The 45 | code is adapted from 46 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/ant_env_rand.py 47 | 48 | The ant follows the dynamics from MuJoCo [2], and receives at each 49 | time step a reward composed of a control cost, a contact cost, a survival 50 | reward, and a penalty equal to the difference between its current velocity 51 | and the target velocity. The tasks are generated by sampling the target 52 | velocities from the uniform distribution on [0, 3]. 
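In the `step` method below, this amounts to reward = -|v_x - v_goal| + 1.0 + 0.05 - ctrl_cost - contact_cost, with v_x the forward velocity of the torso; an episode terminates when the torso height leaves [0.2, 1.0] or the state becomes non-finite.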
53 | 54 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 55 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 56 | (https://arxiv.org/abs/1703.03400) 57 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for 58 | model-based control", 2012 59 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) 60 | """ 61 | 62 | def __init__(self, task={}): 63 | self._task = task 64 | self._goal_vel = task.get('velocity', 0.0) 65 | self._action_scaling = None 66 | super(AntVelEnv, self).__init__() 67 | 68 | def step(self, action): 69 | xposbefore = self.get_body_com("torso")[0] 70 | self.do_simulation(action, self.frame_skip) 71 | xposafter = self.get_body_com("torso")[0] 72 | 73 | forward_vel = (xposafter - xposbefore) / self.dt 74 | forward_reward = -1.0 * np.abs(forward_vel - self._goal_vel) + 1.0 75 | survive_reward = 0.05 76 | 77 | ctrl_cost = 0.5 * 1e-2 * np.sum(np.square(action / self.action_scaling)) 78 | contact_cost = 0.5 * 1e-3 * np.sum( 79 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 80 | 81 | observation = self._get_obs() 82 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 83 | state = self.state_vector() 84 | notdone = np.isfinite(state).all() \ 85 | and state[2] >= 0.2 and state[2] <= 1.0 86 | done = not notdone 87 | infos = dict(reward_forward=forward_reward, reward_ctrl=-ctrl_cost, 88 | reward_contact=-contact_cost, reward_survive=survive_reward, 89 | task=self._task) 90 | return (observation, reward, done, infos) 91 | 92 | def sample_tasks(self, num_tasks): 93 | velocities = self.np_random.uniform(0.0, 3.0, size=(num_tasks,)) 94 | tasks = [{'velocity': velocity} for velocity in velocities] 95 | return tasks 96 | 97 | def reset_task(self, task): 98 | self._task = task 99 | self._goal_vel = task['velocity'] 100 | 101 | 102 | class AntDirEnv(AntEnv): 103 | """Ant environment with target direction, as described in [1]. The 104 | code is adapted from 105 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/ant_env_rand_direc.py 106 | 107 | The ant follows the dynamics from MuJoCo [2], and receives at each 108 | time step a reward composed of a control cost, a contact cost, a survival 109 | reward, and a reward equal to its velocity in the target direction. The 110 | tasks are generated by sampling the target directions from a Bernoulli 111 | distribution on {-1, 1} with parameter 0.5 (-1: backward, +1: forward). 
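The `step` method below computes reward = d * v_x + 0.05 - ctrl_cost - contact_cost, where d in {-1, +1} is the task direction and v_x the forward velocity of the torso; the same height/finiteness termination rule as in AntVelEnv applies.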
112 | 113 | [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic 114 | Meta-Learning for Fast Adaptation of Deep Networks", 2017 115 | (https://arxiv.org/abs/1703.03400) 116 | [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for 117 | model-based control", 2012 118 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) 119 | """ 120 | 121 | def __init__(self, task={}): 122 | self._task = task 123 | self._goal_dir = task.get('direction', 1) 124 | self._action_scaling = None 125 | super(AntDirEnv, self).__init__() 126 | 127 | def step(self, action): 128 | xposbefore = self.get_body_com("torso")[0] 129 | self.do_simulation(action, self.frame_skip) 130 | xposafter = self.get_body_com("torso")[0] 131 | 132 | forward_vel = (xposafter - xposbefore) / self.dt 133 | forward_reward = self._goal_dir * forward_vel 134 | survive_reward = 0.05 135 | 136 | ctrl_cost = 0.5 * 1e-2 * np.sum(np.square(action / self.action_scaling)) 137 | contact_cost = 0.5 * 1e-3 * np.sum( 138 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 139 | 140 | observation = self._get_obs() 141 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 142 | state = self.state_vector() 143 | notdone = np.isfinite(state).all() \ 144 | and state[2] >= 0.2 and state[2] <= 1.0 145 | done = not notdone 146 | infos = dict(reward_forward=forward_reward, reward_ctrl=-ctrl_cost, 147 | reward_contact=-contact_cost, reward_survive=survive_reward, 148 | task=self._task) 149 | return (observation, reward, done, infos) 150 | 151 | def sample_tasks(self, num_tasks): 152 | directions = 2 * self.np_random.binomial(1, p=0.5, size=(num_tasks,)) - 1 153 | tasks = [{'direction': direction} for direction in directions] 154 | return tasks 155 | 156 | def reset_task(self, task): 157 | self._task = task 158 | self._goal_dir = task['direction'] 159 | 160 | 161 | class AntPosEnv(AntEnv): 162 | """Ant environment with target position. The code is adapted from 163 | https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/ant_env_rand_goal.py 164 | 165 | The ant follows the dynamics from MuJoCo [1], and receives at each 166 | time step a reward composed of a control cost, a contact cost, a survival 167 | reward, and a penalty equal to its L1 distance to the target position. The 168 | tasks are generated by sampling the target positions from the uniform 169 | distribution on [-3, 3]^2. 
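The `step` method below computes reward = -||xy - goal||_1 + 4.0 + 0.05 - ctrl_cost - contact_cost, where xy is the position of the torso in the plane; the same height/finiteness termination rule as in the other ant environments applies.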
170 | 171 | [1] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for 172 | model-based control", 2012 173 | (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf) 174 | """ 175 | 176 | def __init__(self, task={}): 177 | self._task = task 178 | self._goal_pos = task.get('position', np.zeros((2,), dtype=np.float32)) 179 | self._action_scaling = None 180 | super(AntPosEnv, self).__init__() 181 | 182 | def step(self, action): 183 | self.do_simulation(action, self.frame_skip) 184 | xyposafter = self.get_body_com("torso")[:2] 185 | 186 | goal_reward = -np.sum(np.abs(xyposafter - self._goal_pos)) + 4.0 187 | survive_reward = 0.05 188 | 189 | ctrl_cost = 0.5 * 1e-2 * np.sum(np.square(action / self.action_scaling)) 190 | contact_cost = 0.5 * 1e-3 * np.sum( 191 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 192 | 193 | observation = self._get_obs() 194 | reward = goal_reward - ctrl_cost - contact_cost + survive_reward 195 | state = self.state_vector() 196 | notdone = np.isfinite(state).all() \ 197 | and state[2] >= 0.2 and state[2] <= 1.0 198 | done = not notdone 199 | infos = dict(reward_goal=goal_reward, reward_ctrl=-ctrl_cost, 200 | reward_contact=-contact_cost, reward_survive=survive_reward, 201 | task=self._task) 202 | return (observation, reward, done, infos) 203 | 204 | def sample_tasks(self, num_tasks): 205 | positions = self.np_random.uniform(-3.0, 3.0, size=(num_tasks, 2)) 206 | tasks = [{'position': position} for position in positions] 207 | return tasks 208 | 209 | def reset_task(self, task): 210 | self._task = task 211 | self._goal_pos = task['position'] 212 | --------------------------------------------------------------------------------
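All of the meta-RL environments above expose the same small task interface: sample_tasks(num_tasks) returns a list of task dicts, reset_task(task) switches the environment to one of them, and step() reports the active task in its info dict. The following minimal sketch (not part of the repository) illustrates that interface; it assumes, as main.py does, that importing maml_rl.envs registers ids such as 'AntDir-v1' and 'HalfCheetahVel-v1' with gym.

import gym
import maml_rl.envs  # registering the meta-RL environment ids is a side effect of this import

env = gym.make('AntDir-v1')
tasks = env.unwrapped.sample_tasks(num_tasks=2)  # e.g. [{'direction': 1}, {'direction': -1}]
for task in tasks:
    env.unwrapped.reset_task(task)
    observation = env.reset()
    for _ in range(5):
        action = env.action_space.sample()
        observation, reward, done, info = env.step(action)
        if done:
            break
    print(task, reward, info['reward_forward'])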