├── LICENSE ├── README.md ├── assets ├── expert_traj │ └── Hopper-v2_expert_traj.p └── learned_models │ └── Hopper-v2_ppo.p ├── core ├── a2c.py ├── agent.py ├── common.py ├── ppo.py └── trpo.py ├── examples ├── a2c_gym.py ├── ppo_gym.py └── trpo_gym.py ├── gail ├── gail_gym.py └── save_expert_traj.py ├── models ├── mlp_critic.py ├── mlp_discriminator.py ├── mlp_policy.py └── mlp_policy_disc.py └── utils ├── __init__.py ├── math.py ├── replay_memory.py ├── tools.py ├── torch.py └── zfilter.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ye Yuan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of reinforcement learning algorithms 2 | This repository contains: 3 | 1. policy gradient methods (TRPO, PPO, A2C) 4 | 2. [Generative Adversarial Imitation Learning (GAIL)](https://arxiv.org/pdf/1606.03476.pdf) 5 | 6 | ## Important notes 7 | - The code now works for PyTorch 0.4. For PyTorch 0.3, please check out the 0.3 branch. 8 | - To run mujoco environments, first install [mujoco-py](https://github.com/openai/mujoco-py) and [gym](https://github.com/openai/gym). 9 | - If you have a GPU, I recommend setting OMP_NUM_THREADS to 1, because PyTorch creates additional threads for its computations and these can hurt multiprocessing performance. The problem is most severe on Linux, where multiprocessing can become even slower than a single thread: 10 | ``` 11 | export OMP_NUM_THREADS=1 12 | ``` 13 | 14 | ## Features 15 | * Supports discrete and continuous action spaces. 16 | * Supports multiprocessing, so the agent can collect samples in multiple environments simultaneously (roughly 8x faster than a single thread). 17 | * Fast Fisher vector product calculation. For this part, Ankur kindly wrote a [blog](http://www.telesens.co/2018/06/09/efficiently-computing-the-fisher-vector-product-in-trpo/) explaining the implementation details.
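For readers unfamiliar with the Fisher-vector-product trick mentioned above, the key ingredient is a Hessian-vector product computed by double backprop (this is what `Fvp_direct` in core/trpo.py does; `Fvp_fim` is the faster FIM-based variant). A minimal, self-contained sketch using a toy quadratic in place of the mean KL (all names below are illustrative):
```
import torch

# Hessian-vector product via double backprop, on a toy quadratic standing in
# for the mean KL used by core/trpo.py.
params = torch.randn(5, requires_grad=True)
loss = (params ** 2).sum()                    # stand-in for the mean KL
v = torch.randn(5)                            # vector supplied by conjugate gradients

grad = torch.autograd.grad(loss, params, create_graph=True)[0]
grad_v = (grad * v).sum()
hvp = torch.autograd.grad(grad_v, params)[0]  # H @ v, without ever forming H
```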
18 | ## Policy gradient methods 19 | * [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf) -> [examples/trpo_gym.py](https://github.com/Khrylx/PyTorch-RL/blob/master/examples/trpo_gym.py) 20 | * [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf) -> [examples/ppo_gym.py](https://github.com/Khrylx/PyTorch-RL/blob/master/examples/ppo_gym.py) 21 | * [Synchronous A3C (A2C)](https://arxiv.org/pdf/1602.01783.pdf) -> [examples/a2c_gym.py](https://github.com/Khrylx/PyTorch-RL/blob/master/examples/a2c_gym.py) 22 | 23 | ### Example 24 | * python examples/ppo_gym.py --env-name Hopper-v2 25 | 26 | ### Reference 27 | * [ikostrikov/pytorch-trpo](https://github.com/ikostrikov/pytorch-trpo) 28 | * [openai/baselines](https://github.com/openai/baselines) 29 | 30 | 31 | ## Generative Adversarial Imitation Learning (GAIL) 32 | ### To save trajectory 33 | * python gail/save_expert_traj.py --model-path assets/learned_models/Hopper-v2_ppo.p 34 | ### To do imitation learning 35 | * python gail/gail_gym.py --env-name Hopper-v2 --expert-traj-path assets/expert_traj/Hopper-v2_expert_traj.p 36 | -------------------------------------------------------------------------------- /assets/expert_traj/Hopper-v2_expert_traj.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Khrylx/PyTorch-RL/72069237b4d86bcb9675b899ea94228019a4f003/assets/expert_traj/Hopper-v2_expert_traj.p -------------------------------------------------------------------------------- /assets/learned_models/Hopper-v2_ppo.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Khrylx/PyTorch-RL/72069237b4d86bcb9675b899ea94228019a4f003/assets/learned_models/Hopper-v2_ppo.p -------------------------------------------------------------------------------- /core/a2c.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def a2c_step(policy_net, value_net, optimizer_policy, optimizer_value, states, actions, returns, advantages, l2_reg): 5 | 6 | """update critic""" 7 | values_pred = value_net(states) 8 | value_loss = (values_pred - returns).pow(2).mean() 9 | # weight decay 10 | for param in value_net.parameters(): 11 | value_loss += param.pow(2).sum() * l2_reg 12 | optimizer_value.zero_grad() 13 | value_loss.backward() 14 | optimizer_value.step() 15 | 16 | """update policy""" 17 | log_probs = policy_net.get_log_prob(states, actions) 18 | policy_loss = -(log_probs * advantages).mean() 19 | optimizer_policy.zero_grad() 20 | policy_loss.backward() 21 | torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 40) 22 | optimizer_policy.step() 23 | -------------------------------------------------------------------------------- /core/agent.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from utils.replay_memory import Memory 3 | from utils.torch import * 4 | import math 5 | import time 6 | import os 7 | os.environ["OMP_NUM_THREADS"] = "1" 8 | 9 | 10 | def collect_samples(pid, queue, env, policy, custom_reward, 11 | mean_action, render, running_state, min_batch_size): 12 | if pid > 0: 13 | torch.manual_seed(torch.randint(0, 5000, (1,)) * pid) 14 | if hasattr(env, 'np_random'): 15 | env.np_random.seed(env.np_random.randint(5000) * pid) 16 | if hasattr(env, 'env') and hasattr(env.env, 'np_random'): 17 | env.env.np_random.seed(env.env.np_random.randint(5000) * pid) 18 | log 
= dict() 19 | memory = Memory() 20 | num_steps = 0 21 | total_reward = 0 22 | min_reward = 1e6 23 | max_reward = -1e6 24 | total_c_reward = 0 25 | min_c_reward = 1e6 26 | max_c_reward = -1e6 27 | num_episodes = 0 28 | 29 | while num_steps < min_batch_size: 30 | state = env.reset() 31 | if running_state is not None: 32 | state = running_state(state) 33 | reward_episode = 0 34 | 35 | for t in range(10000): 36 | state_var = tensor(state).unsqueeze(0) 37 | with torch.no_grad(): 38 | if mean_action: 39 | action = policy(state_var)[0][0].numpy() 40 | else: 41 | action = policy.select_action(state_var)[0].numpy() 42 | action = int(action) if policy.is_disc_action else action.astype(np.float64) 43 | next_state, reward, done, _ = env.step(action) 44 | reward_episode += reward 45 | if running_state is not None: 46 | next_state = running_state(next_state) 47 | 48 | if custom_reward is not None: 49 | reward = custom_reward(state, action) 50 | total_c_reward += reward 51 | min_c_reward = min(min_c_reward, reward) 52 | max_c_reward = max(max_c_reward, reward) 53 | 54 | mask = 0 if done else 1 55 | 56 | memory.push(state, action, mask, next_state, reward) 57 | 58 | if render: 59 | env.render() 60 | if done: 61 | break 62 | 63 | state = next_state 64 | 65 | # log stats 66 | num_steps += (t + 1) 67 | num_episodes += 1 68 | total_reward += reward_episode 69 | min_reward = min(min_reward, reward_episode) 70 | max_reward = max(max_reward, reward_episode) 71 | 72 | log['num_steps'] = num_steps 73 | log['num_episodes'] = num_episodes 74 | log['total_reward'] = total_reward 75 | log['avg_reward'] = total_reward / num_episodes 76 | log['max_reward'] = max_reward 77 | log['min_reward'] = min_reward 78 | if custom_reward is not None: 79 | log['total_c_reward'] = total_c_reward 80 | log['avg_c_reward'] = total_c_reward / num_steps 81 | log['max_c_reward'] = max_c_reward 82 | log['min_c_reward'] = min_c_reward 83 | 84 | if queue is not None: 85 | queue.put([pid, memory, log]) 86 | else: 87 | return memory, log 88 | 89 | 90 | def merge_log(log_list): 91 | log = dict() 92 | log['total_reward'] = sum([x['total_reward'] for x in log_list]) 93 | log['num_episodes'] = sum([x['num_episodes'] for x in log_list]) 94 | log['num_steps'] = sum([x['num_steps'] for x in log_list]) 95 | log['avg_reward'] = log['total_reward'] / log['num_episodes'] 96 | log['max_reward'] = max([x['max_reward'] for x in log_list]) 97 | log['min_reward'] = min([x['min_reward'] for x in log_list]) 98 | if 'total_c_reward' in log_list[0]: 99 | log['total_c_reward'] = sum([x['total_c_reward'] for x in log_list]) 100 | log['avg_c_reward'] = log['total_c_reward'] / log['num_steps'] 101 | log['max_c_reward'] = max([x['max_c_reward'] for x in log_list]) 102 | log['min_c_reward'] = min([x['min_c_reward'] for x in log_list]) 103 | 104 | return log 105 | 106 | 107 | class Agent: 108 | 109 | def __init__(self, env, policy, device, custom_reward=None, running_state=None, num_threads=1): 110 | self.env = env 111 | self.policy = policy 112 | self.device = device 113 | self.custom_reward = custom_reward 114 | self.running_state = running_state 115 | self.num_threads = num_threads 116 | 117 | def collect_samples(self, min_batch_size, mean_action=False, render=False): 118 | t_start = time.time() 119 | to_device(torch.device('cpu'), self.policy) 120 | thread_batch_size = int(math.floor(min_batch_size / self.num_threads)) 121 | queue = multiprocessing.Queue() 122 | workers = [] 123 | 124 | for i in range(self.num_threads-1): 125 | worker_args = (i+1, queue, self.env, 
self.policy, self.custom_reward, mean_action, 126 | False, self.running_state, thread_batch_size) 127 | workers.append(multiprocessing.Process(target=collect_samples, args=worker_args)) 128 | for worker in workers: 129 | worker.start() 130 | 131 | memory, log = collect_samples(0, None, self.env, self.policy, self.custom_reward, mean_action, 132 | render, self.running_state, thread_batch_size) 133 | 134 | worker_logs = [None] * len(workers) 135 | worker_memories = [None] * len(workers) 136 | for _ in workers: 137 | pid, worker_memory, worker_log = queue.get() 138 | worker_memories[pid - 1] = worker_memory 139 | worker_logs[pid - 1] = worker_log 140 | for worker_memory in worker_memories: 141 | memory.append(worker_memory) 142 | batch = memory.sample() 143 | if self.num_threads > 1: 144 | log_list = [log] + worker_logs 145 | log = merge_log(log_list) 146 | to_device(self.device, self.policy) 147 | t_end = time.time() 148 | log['sample_time'] = t_end - t_start 149 | log['action_mean'] = np.mean(np.vstack(batch.action), axis=0) 150 | log['action_min'] = np.min(np.vstack(batch.action), axis=0) 151 | log['action_max'] = np.max(np.vstack(batch.action), axis=0) 152 | return batch, log 153 | -------------------------------------------------------------------------------- /core/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import to_device 3 | 4 | 5 | def estimate_advantages(rewards, masks, values, gamma, tau, device): 6 | rewards, masks, values = to_device(torch.device('cpu'), rewards, masks, values) 7 | tensor_type = type(rewards) 8 | deltas = tensor_type(rewards.size(0), 1) 9 | advantages = tensor_type(rewards.size(0), 1) 10 | 11 | prev_value = 0 12 | prev_advantage = 0 13 | for i in reversed(range(rewards.size(0))): 14 | deltas[i] = rewards[i] + gamma * prev_value * masks[i] - values[i] 15 | advantages[i] = deltas[i] + gamma * tau * prev_advantage * masks[i] 16 | 17 | prev_value = values[i, 0] 18 | prev_advantage = advantages[i, 0] 19 | 20 | returns = values + advantages 21 | advantages = (advantages - advantages.mean()) / advantages.std() 22 | 23 | advantages, returns = to_device(device, advantages, returns) 24 | return advantages, returns 25 | -------------------------------------------------------------------------------- /core/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def ppo_step(policy_net, value_net, optimizer_policy, optimizer_value, optim_value_iternum, states, actions, 5 | returns, advantages, fixed_log_probs, clip_epsilon, l2_reg): 6 | 7 | """update critic""" 8 | for _ in range(optim_value_iternum): 9 | values_pred = value_net(states) 10 | value_loss = (values_pred - returns).pow(2).mean() 11 | # weight decay 12 | for param in value_net.parameters(): 13 | value_loss += param.pow(2).sum() * l2_reg 14 | optimizer_value.zero_grad() 15 | value_loss.backward() 16 | optimizer_value.step() 17 | 18 | """update policy""" 19 | log_probs = policy_net.get_log_prob(states, actions) 20 | ratio = torch.exp(log_probs - fixed_log_probs) 21 | surr1 = ratio * advantages 22 | surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages 23 | policy_surr = -torch.min(surr1, surr2).mean() 24 | optimizer_policy.zero_grad() 25 | policy_surr.backward() 26 | torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 40) 27 | optimizer_policy.step() 28 | -------------------------------------------------------------------------------- 
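As a quick standalone illustration of the clipped surrogate that ppo_step computes above (toy tensors only; nothing here is part of the repository):
```
import torch

# Toy illustration of PPO's clipped surrogate, mirroring the lines in ppo_step.
ratio = torch.tensor([0.7, 1.0, 1.4])          # exp(log_probs - fixed_log_probs)
advantages = torch.tensor([1.0, -1.0, 1.0])
clip_epsilon = 0.2

surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
policy_surr = -torch.min(surr1, surr2).mean()  # the min caps the gain from pushing the ratio outside [1-eps, 1+eps]
```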
/core/trpo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.optimize 3 | from utils import * 4 | 5 | 6 | def conjugate_gradients(Avp_f, b, nsteps, rdotr_tol=1e-10): 7 | x = zeros(b.size(), device=b.device) 8 | r = b.clone() 9 | p = b.clone() 10 | rdotr = torch.dot(r, r) 11 | for i in range(nsteps): 12 | Avp = Avp_f(p) 13 | alpha = rdotr / torch.dot(p, Avp) 14 | x += alpha * p 15 | r -= alpha * Avp 16 | new_rdotr = torch.dot(r, r) 17 | betta = new_rdotr / rdotr 18 | p = r + betta * p 19 | rdotr = new_rdotr 20 | if rdotr < rdotr_tol: 21 | break 22 | return x 23 | 24 | 25 | def line_search(model, f, x, fullstep, expected_improve_full, max_backtracks=10, accept_ratio=0.1): 26 | fval = f(True).item() 27 | 28 | for stepfrac in [.5**x for x in range(max_backtracks)]: 29 | x_new = x + stepfrac * fullstep 30 | set_flat_params_to(model, x_new) 31 | fval_new = f(True).item() 32 | actual_improve = fval - fval_new 33 | expected_improve = expected_improve_full * stepfrac 34 | ratio = actual_improve / expected_improve 35 | 36 | if ratio > accept_ratio: 37 | return True, x_new 38 | return False, x 39 | 40 | 41 | def trpo_step(policy_net, value_net, states, actions, returns, advantages, max_kl, damping, l2_reg, use_fim=True): 42 | 43 | """update critic""" 44 | 45 | def get_value_loss(flat_params): 46 | set_flat_params_to(value_net, tensor(flat_params)) 47 | for param in value_net.parameters(): 48 | if param.grad is not None: 49 | param.grad.data.fill_(0) 50 | values_pred = value_net(states) 51 | value_loss = (values_pred - returns).pow(2).mean() 52 | 53 | # weight decay 54 | for param in value_net.parameters(): 55 | value_loss += param.pow(2).sum() * l2_reg 56 | value_loss.backward() 57 | return value_loss.item(), get_flat_grad_from(value_net.parameters()).cpu().numpy() 58 | 59 | flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(get_value_loss, 60 | get_flat_params_from(value_net).detach().cpu().numpy(), 61 | maxiter=25) 62 | set_flat_params_to(value_net, tensor(flat_params)) 63 | 64 | """update policy""" 65 | with torch.no_grad(): 66 | fixed_log_probs = policy_net.get_log_prob(states, actions) 67 | """define the loss function for TRPO""" 68 | def get_loss(volatile=False): 69 | with torch.set_grad_enabled(not volatile): 70 | log_probs = policy_net.get_log_prob(states, actions) 71 | action_loss = -advantages * torch.exp(log_probs - fixed_log_probs) 72 | return action_loss.mean() 73 | 74 | """use fisher information matrix for Hessian*vector""" 75 | def Fvp_fim(v): 76 | M, mu, info = policy_net.get_fim(states) 77 | mu = mu.view(-1) 78 | filter_input_ids = set() if policy_net.is_disc_action else set([info['std_id']]) 79 | 80 | t = ones(mu.size(), requires_grad=True, device=mu.device) 81 | mu_t = (mu * t).sum() 82 | Jt = compute_flat_grad(mu_t, policy_net.parameters(), filter_input_ids=filter_input_ids, create_graph=True) 83 | Jtv = (Jt * v).sum() 84 | Jv = torch.autograd.grad(Jtv, t)[0] 85 | MJv = M * Jv.detach() 86 | mu_MJv = (MJv * mu).sum() 87 | JTMJv = compute_flat_grad(mu_MJv, policy_net.parameters(), filter_input_ids=filter_input_ids).detach() 88 | JTMJv /= states.shape[0] 89 | if not policy_net.is_disc_action: 90 | std_index = info['std_index'] 91 | JTMJv[std_index: std_index + M.shape[0]] += 2 * v[std_index: std_index + M.shape[0]] 92 | return JTMJv + v * damping 93 | 94 | """directly compute Hessian*vector from KL""" 95 | def Fvp_direct(v): 96 | kl = policy_net.get_kl(states) 97 | kl = kl.mean() 98 | 99 | grads = 
torch.autograd.grad(kl, policy_net.parameters(), create_graph=True) 100 | flat_grad_kl = torch.cat([grad.view(-1) for grad in grads]) 101 | 102 | kl_v = (flat_grad_kl * v).sum() 103 | grads = torch.autograd.grad(kl_v, policy_net.parameters()) 104 | flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).detach() 105 | 106 | return flat_grad_grad_kl + v * damping 107 | 108 | Fvp = Fvp_fim if use_fim else Fvp_direct 109 | 110 | loss = get_loss() 111 | grads = torch.autograd.grad(loss, policy_net.parameters()) 112 | loss_grad = torch.cat([grad.view(-1) for grad in grads]).detach() 113 | stepdir = conjugate_gradients(Fvp, -loss_grad, 10) 114 | 115 | shs = 0.5 * (stepdir.dot(Fvp(stepdir))) 116 | lm = math.sqrt(max_kl / shs) 117 | fullstep = stepdir * lm 118 | expected_improve = -loss_grad.dot(fullstep) 119 | 120 | prev_params = get_flat_params_from(policy_net) 121 | success, new_params = line_search(policy_net, get_loss, prev_params, fullstep, expected_improve) 122 | set_flat_params_to(policy_net, new_params) 123 | 124 | return success 125 | -------------------------------------------------------------------------------- /examples/a2c_gym.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import os 4 | import sys 5 | import pickle 6 | import time 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 8 | 9 | from utils import * 10 | from models.mlp_policy import Policy 11 | from models.mlp_critic import Value 12 | from models.mlp_policy_disc import DiscretePolicy 13 | from core.a2c import a2c_step 14 | from core.common import estimate_advantages 15 | from core.agent import Agent 16 | 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch A2C example') 19 | parser.add_argument('--env-name', default="Hopper-v2", metavar='G', 20 | help='name of the environment to run') 21 | parser.add_argument('--model-path', metavar='G', 22 | help='path of pre-trained model') 23 | parser.add_argument('--render', action='store_true', default=False, 24 | help='render the environment') 25 | parser.add_argument('--log-std', type=float, default=-0.0, metavar='G', 26 | help='log std for the policy (default: -0.0)') 27 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G', 28 | help='discount factor (default: 0.99)') 29 | parser.add_argument('--tau', type=float, default=0.95, metavar='G', 30 | help='gae (default: 0.95)') 31 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G', 32 | help='l2 regularization regression (default: 1e-3)') 33 | parser.add_argument('--num-threads', type=int, default=4, metavar='N', 34 | help='number of threads for agent (default: 4)') 35 | parser.add_argument('--seed', type=int, default=1, metavar='N', 36 | help='random seed (default: 1)') 37 | parser.add_argument('--min-batch-size', type=int, default=2048, metavar='N', 38 | help='minimal batch size per A2C update (default: 2048)') 39 | parser.add_argument('--eval-batch-size', type=int, default=2048, metavar='N', 40 | help='minimal batch size for evaluation (default: 2048)') 41 | parser.add_argument('--max-iter-num', type=int, default=500, metavar='N', 42 | help='maximal number of main iterations (default: 500)') 43 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', 44 | help='interval between training status logs (default: 1)') 45 | parser.add_argument('--save-model-interval', type=int, default=0, metavar='N', 46 | help="interval between saving model 
(default: 0, means don't save)") 47 | parser.add_argument('--gpu-index', type=int, default=0, metavar='N') 48 | args = parser.parse_args() 49 | 50 | dtype = torch.float64 51 | torch.set_default_dtype(dtype) 52 | device = torch.device('cuda', index=args.gpu_index) if torch.cuda.is_available() else torch.device('cpu') 53 | if torch.cuda.is_available(): 54 | torch.cuda.set_device(args.gpu_index) 55 | 56 | """environment""" 57 | env = gym.make(args.env_name) 58 | state_dim = env.observation_space.shape[0] 59 | is_disc_action = len(env.action_space.shape) == 0 60 | running_state = ZFilter((state_dim,), clip=5) 61 | # running_reward = ZFilter((1,), demean=False, clip=10) 62 | 63 | """seeding""" 64 | np.random.seed(args.seed) 65 | torch.manual_seed(args.seed) 66 | env.seed(args.seed) 67 | 68 | """define actor and critic""" 69 | if args.model_path is None: 70 | if is_disc_action: 71 | policy_net = DiscretePolicy(state_dim, env.action_space.n) 72 | else: 73 | policy_net = Policy(state_dim, env.action_space.shape[0], log_std=args.log_std) 74 | value_net = Value(state_dim) 75 | else: 76 | policy_net, value_net, running_state = pickle.load(open(args.model_path, "rb")) 77 | policy_net.to(device) 78 | value_net.to(device) 79 | 80 | optimizer_policy = torch.optim.Adam(policy_net.parameters(), lr=0.01) 81 | optimizer_value = torch.optim.Adam(value_net.parameters(), lr=0.01) 82 | 83 | """create agent""" 84 | agent = Agent(env, policy_net, device, running_state=running_state, num_threads=args.num_threads) 85 | 86 | 87 | def update_params(batch): 88 | states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device) 89 | actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device) 90 | rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device) 91 | masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device) 92 | with torch.no_grad(): 93 | values = value_net(states) 94 | 95 | """get advantage estimation from the trajectories""" 96 | advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, device) 97 | 98 | """perform TRPO update""" 99 | a2c_step(policy_net, value_net, optimizer_policy, optimizer_value, states, actions, returns, advantages, args.l2_reg) 100 | 101 | 102 | def main_loop(): 103 | for i_iter in range(args.max_iter_num): 104 | """generate multiple trajectories that reach the minimum batch_size""" 105 | batch, log = agent.collect_samples(args.min_batch_size, render=args.render) 106 | t0 = time.time() 107 | update_params(batch) 108 | t1 = time.time() 109 | """evaluate with determinstic action (remove noise for exploration)""" 110 | _, log_eval = agent.collect_samples(args.eval_batch_size, mean_action=True) 111 | t2 = time.time() 112 | 113 | if i_iter % args.log_interval == 0: 114 | print('{}\tT_sample {:.4f}\tT_update {:.4f}\tT_eval {:.4f}\ttrain_R_min {:.2f}\ttrain_R_max {:.2f}\ttrain_R_avg {:.2f}\teval_R_avg {:.2f}'.format( 115 | i_iter, log['sample_time'], t1-t0, t2-t1, log['min_reward'], log['max_reward'], log['avg_reward'], log_eval['avg_reward'])) 116 | 117 | if args.save_model_interval > 0 and (i_iter+1) % args.save_model_interval == 0: 118 | to_device(torch.device('cpu'), policy_net, value_net) 119 | pickle.dump((policy_net, value_net, running_state), 120 | open(os.path.join(assets_dir(), 'learned_models/{}_a2c.p'.format(args.env_name)), 'wb')) 121 | to_device(device, policy_net, value_net) 122 | 123 | """clean up gpu memory""" 124 | torch.cuda.empty_cache() 125 | 126 | 127 | main_loop() 128 | 
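A hedged usage sketch (not part of examples/a2c_gym.py): reloading a `(policy_net, value_net, running_state)` pickle written by the `--save-model-interval` option and rolling out the mean action, in the same way gail/save_expert_traj.py does. The model path below is an assumption, and the gym calls assume the old 4-tuple step API used throughout this repository.
```
import pickle
import gym
import torch

torch.set_default_dtype(torch.float64)

# Assumed path: only exists after training with --save-model-interval > 0.
policy_net, value_net, running_state = pickle.load(
    open('assets/learned_models/Hopper-v2_a2c.p', 'rb'))
running_state.fix = True                    # freeze the observation normalizer

env = gym.make('Hopper-v2')
state = running_state(env.reset())
for t in range(1000):
    state_var = torch.tensor(state).unsqueeze(0)
    action = policy_net(state_var)[0][0].detach().numpy()   # mean (deterministic) action
    next_state, reward, done, _ = env.step(action)
    if done:
        break
    state = running_state(next_state)
```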
-------------------------------------------------------------------------------- /examples/ppo_gym.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import os 4 | import sys 5 | import pickle 6 | import time 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 8 | 9 | from utils import * 10 | from models.mlp_policy import Policy 11 | from models.mlp_critic import Value 12 | from models.mlp_policy_disc import DiscretePolicy 13 | from core.ppo import ppo_step 14 | from core.common import estimate_advantages 15 | from core.agent import Agent 16 | 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch PPO example') 19 | parser.add_argument('--env-name', default="Hopper-v2", metavar='G', 20 | help='name of the environment to run') 21 | parser.add_argument('--model-path', metavar='G', 22 | help='path of pre-trained model') 23 | parser.add_argument('--render', action='store_true', default=False, 24 | help='render the environment') 25 | parser.add_argument('--log-std', type=float, default=-0.0, metavar='G', 26 | help='log std for the policy (default: -0.0)') 27 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G', 28 | help='discount factor (default: 0.99)') 29 | parser.add_argument('--tau', type=float, default=0.95, metavar='G', 30 | help='gae (default: 0.95)') 31 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G', 32 | help='l2 regularization regression (default: 1e-3)') 33 | parser.add_argument('--learning-rate', type=float, default=3e-4, metavar='G', 34 | help='learning rate (default: 3e-4)') 35 | parser.add_argument('--clip-epsilon', type=float, default=0.2, metavar='N', 36 | help='clipping epsilon for PPO') 37 | parser.add_argument('--num-threads', type=int, default=4, metavar='N', 38 | help='number of threads for agent (default: 4)') 39 | parser.add_argument('--seed', type=int, default=1, metavar='N', 40 | help='random seed (default: 1)') 41 | parser.add_argument('--min-batch-size', type=int, default=2048, metavar='N', 42 | help='minimal batch size per PPO update (default: 2048)') 43 | parser.add_argument('--eval-batch-size', type=int, default=2048, metavar='N', 44 | help='minimal batch size for evaluation (default: 2048)') 45 | parser.add_argument('--max-iter-num', type=int, default=500, metavar='N', 46 | help='maximal number of main iterations (default: 500)') 47 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', 48 | help='interval between training status logs (default: 10)') 49 | parser.add_argument('--save-model-interval', type=int, default=0, metavar='N', 50 | help="interval between saving model (default: 0, means don't save)") 51 | parser.add_argument('--gpu-index', type=int, default=0, metavar='N') 52 | args = parser.parse_args() 53 | 54 | dtype = torch.float64 55 | torch.set_default_dtype(dtype) 56 | device = torch.device('cuda', index=args.gpu_index) if torch.cuda.is_available() else torch.device('cpu') 57 | if torch.cuda.is_available(): 58 | torch.cuda.set_device(args.gpu_index) 59 | 60 | """environment""" 61 | env = gym.make(args.env_name) 62 | state_dim = env.observation_space.shape[0] 63 | is_disc_action = len(env.action_space.shape) == 0 64 | running_state = ZFilter((state_dim,), clip=5) 65 | # running_reward = ZFilter((1,), demean=False, clip=10) 66 | 67 | """seeding""" 68 | np.random.seed(args.seed) 69 | torch.manual_seed(args.seed) 70 | env.seed(args.seed) 71 | 72 | """define actor and critic""" 73 | 
if args.model_path is None: 74 | if is_disc_action: 75 | policy_net = DiscretePolicy(state_dim, env.action_space.n) 76 | else: 77 | policy_net = Policy(state_dim, env.action_space.shape[0], log_std=args.log_std) 78 | value_net = Value(state_dim) 79 | else: 80 | policy_net, value_net, running_state = pickle.load(open(args.model_path, "rb")) 81 | policy_net.to(device) 82 | value_net.to(device) 83 | 84 | optimizer_policy = torch.optim.Adam(policy_net.parameters(), lr=args.learning_rate) 85 | optimizer_value = torch.optim.Adam(value_net.parameters(), lr=args.learning_rate) 86 | 87 | # optimization epoch number and batch size for PPO 88 | optim_epochs = 10 89 | optim_batch_size = 64 90 | 91 | """create agent""" 92 | agent = Agent(env, policy_net, device, running_state=running_state, num_threads=args.num_threads) 93 | 94 | 95 | def update_params(batch, i_iter): 96 | states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device) 97 | actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device) 98 | rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device) 99 | masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device) 100 | with torch.no_grad(): 101 | values = value_net(states) 102 | fixed_log_probs = policy_net.get_log_prob(states, actions) 103 | 104 | """get advantage estimation from the trajectories""" 105 | advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, device) 106 | 107 | """perform mini-batch PPO update""" 108 | optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size)) 109 | for _ in range(optim_epochs): 110 | perm = np.arange(states.shape[0]) 111 | np.random.shuffle(perm) 112 | perm = LongTensor(perm).to(device) 113 | 114 | states, actions, returns, advantages, fixed_log_probs = \ 115 | states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), fixed_log_probs[perm].clone() 116 | 117 | for i in range(optim_iter_num): 118 | ind = slice(i * optim_batch_size, min((i + 1) * optim_batch_size, states.shape[0])) 119 | states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \ 120 | states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind] 121 | 122 | ppo_step(policy_net, value_net, optimizer_policy, optimizer_value, 1, states_b, actions_b, returns_b, 123 | advantages_b, fixed_log_probs_b, args.clip_epsilon, args.l2_reg) 124 | 125 | 126 | def main_loop(): 127 | for i_iter in range(args.max_iter_num): 128 | """generate multiple trajectories that reach the minimum batch_size""" 129 | batch, log = agent.collect_samples(args.min_batch_size, render=args.render) 130 | t0 = time.time() 131 | update_params(batch, i_iter) 132 | t1 = time.time() 133 | """evaluate with determinstic action (remove noise for exploration)""" 134 | _, log_eval = agent.collect_samples(args.eval_batch_size, mean_action=True) 135 | t2 = time.time() 136 | 137 | if i_iter % args.log_interval == 0: 138 | print('{}\tT_sample {:.4f}\tT_update {:.4f}\tT_eval {:.4f}\ttrain_R_min {:.2f}\ttrain_R_max {:.2f}\ttrain_R_avg {:.2f}\teval_R_avg {:.2f}'.format( 139 | i_iter, log['sample_time'], t1-t0, t2-t1, log['min_reward'], log['max_reward'], log['avg_reward'], log_eval['avg_reward'])) 140 | 141 | if args.save_model_interval > 0 and (i_iter+1) % args.save_model_interval == 0: 142 | to_device(torch.device('cpu'), policy_net, value_net) 143 | pickle.dump((policy_net, value_net, running_state), 144 | open(os.path.join(assets_dir(), 'learned_models/{}_ppo.p'.format(args.env_name)), 'wb')) 145 | 
to_device(device, policy_net, value_net) 146 | 147 | """clean up gpu memory""" 148 | torch.cuda.empty_cache() 149 | 150 | 151 | main_loop() 152 | -------------------------------------------------------------------------------- /examples/trpo_gym.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import os 4 | import sys 5 | import pickle 6 | import time 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 8 | 9 | from utils import * 10 | from models.mlp_policy import Policy 11 | from models.mlp_critic import Value 12 | from models.mlp_policy_disc import DiscretePolicy 13 | from core.trpo import trpo_step 14 | from core.common import estimate_advantages 15 | from core.agent import Agent 16 | 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch TRPO example') 19 | parser.add_argument('--env-name', default="Hopper-v2", metavar='G', 20 | help='name of the environment to run') 21 | parser.add_argument('--model-path', metavar='G', 22 | help='path of pre-trained model') 23 | parser.add_argument('--render', action='store_true', default=False, 24 | help='render the environment') 25 | parser.add_argument('--log-std', type=float, default=-0.0, metavar='G', 26 | help='log std for the policy (default: -0.0)') 27 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G', 28 | help='discount factor (default: 0.99)') 29 | parser.add_argument('--tau', type=float, default=0.95, metavar='G', 30 | help='gae (default: 0.95)') 31 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G', 32 | help='l2 regularization regression (default: 1e-3)') 33 | parser.add_argument('--max-kl', type=float, default=1e-2, metavar='G', 34 | help='max kl value (default: 1e-2)') 35 | parser.add_argument('--damping', type=float, default=1e-2, metavar='G', 36 | help='damping (default: 1e-2)') 37 | parser.add_argument('--num-threads', type=int, default=4, metavar='N', 38 | help='number of threads for agent (default: 4)') 39 | parser.add_argument('--seed', type=int, default=1, metavar='N', 40 | help='random seed (default: 1)') 41 | parser.add_argument('--min-batch-size', type=int, default=2048, metavar='N', 42 | help='minimal batch size per TRPO update (default: 2048)') 43 | parser.add_argument('--eval-batch-size', type=int, default=2048, metavar='N', 44 | help='minimal batch size for evaluation (default: 2048)') 45 | parser.add_argument('--max-iter-num', type=int, default=500, metavar='N', 46 | help='maximal number of main iterations (default: 500)') 47 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', 48 | help='interval between training status logs (default: 10)') 49 | parser.add_argument('--save-model-interval', type=int, default=0, metavar='N', 50 | help="interval between saving model (default: 0, means don't save)") 51 | parser.add_argument('--gpu-index', type=int, default=0, metavar='N') 52 | args = parser.parse_args() 53 | 54 | dtype = torch.float64 55 | torch.set_default_dtype(dtype) 56 | device = torch.device('cuda', index=args.gpu_index) if torch.cuda.is_available() else torch.device('cpu') 57 | if torch.cuda.is_available(): 58 | torch.cuda.set_device(args.gpu_index) 59 | 60 | """environment""" 61 | env = gym.make(args.env_name) 62 | state_dim = env.observation_space.shape[0] 63 | is_disc_action = len(env.action_space.shape) == 0 64 | running_state = ZFilter((state_dim,), clip=5) 65 | # running_reward = ZFilter((1,), demean=False, clip=10) 66 | 67 | 
"""seeding""" 68 | np.random.seed(args.seed) 69 | torch.manual_seed(args.seed) 70 | env.seed(args.seed) 71 | 72 | """define actor and critic""" 73 | if args.model_path is None: 74 | if is_disc_action: 75 | policy_net = DiscretePolicy(state_dim, env.action_space.n) 76 | else: 77 | policy_net = Policy(state_dim, env.action_space.shape[0], log_std=args.log_std) 78 | value_net = Value(state_dim) 79 | else: 80 | policy_net, value_net, running_state = pickle.load(open(args.model_path, "rb")) 81 | policy_net.to(device) 82 | value_net.to(device) 83 | 84 | """create agent""" 85 | agent = Agent(env, policy_net, device, running_state=running_state, num_threads=args.num_threads) 86 | 87 | 88 | def update_params(batch): 89 | states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device) 90 | actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device) 91 | rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device) 92 | masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device) 93 | with torch.no_grad(): 94 | values = value_net(states) 95 | 96 | """get advantage estimation from the trajectories""" 97 | advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, device) 98 | 99 | """perform TRPO update""" 100 | trpo_step(policy_net, value_net, states, actions, returns, advantages, args.max_kl, args.damping, args.l2_reg) 101 | 102 | 103 | def main_loop(): 104 | for i_iter in range(args.max_iter_num): 105 | """generate multiple trajectories that reach the minimum batch_size""" 106 | batch, log = agent.collect_samples(args.min_batch_size, render=args.render) 107 | t0 = time.time() 108 | update_params(batch) 109 | t1 = time.time() 110 | """evaluate with determinstic action (remove noise for exploration)""" 111 | _, log_eval = agent.collect_samples(args.eval_batch_size, mean_action=True) 112 | t2 = time.time() 113 | 114 | if i_iter % args.log_interval == 0: 115 | print('{}\tT_sample {:.4f}\tT_update {:.4f}\tT_eval {:.4f}\ttrain_R_min {:.2f}\ttrain_R_max {:.2f}\ttrain_R_avg {:.2f}\teval_R_avg {:.2f}'.format( 116 | i_iter, log['sample_time'], t1-t0, t2-t1, log['min_reward'], log['max_reward'], log['avg_reward'], log_eval['avg_reward'])) 117 | 118 | if args.save_model_interval > 0 and (i_iter+1) % args.save_model_interval == 0: 119 | to_device(torch.device('cpu'), policy_net, value_net) 120 | pickle.dump((policy_net, value_net, running_state), 121 | open(os.path.join(assets_dir(), 'learned_models/{}_trpo.p'.format(args.env_name)), 'wb')) 122 | to_device(device, policy_net, value_net) 123 | 124 | """clean up gpu memory""" 125 | torch.cuda.empty_cache() 126 | 127 | 128 | main_loop() 129 | -------------------------------------------------------------------------------- /gail/gail_gym.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import os 4 | import sys 5 | import pickle 6 | import time 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 8 | 9 | from utils import * 10 | from models.mlp_policy import Policy 11 | from models.mlp_critic import Value 12 | from models.mlp_policy_disc import DiscretePolicy 13 | from models.mlp_discriminator import Discriminator 14 | from torch import nn 15 | from core.ppo import ppo_step 16 | from core.common import estimate_advantages 17 | from core.agent import Agent 18 | 19 | 20 | parser = argparse.ArgumentParser(description='PyTorch GAIL example') 21 | parser.add_argument('--env-name', default="Hopper-v2", 
metavar='G', 22 | help='name of the environment to run') 23 | parser.add_argument('--expert-traj-path', metavar='G', 24 | help='path of the expert trajectories') 25 | parser.add_argument('--render', action='store_true', default=False, 26 | help='render the environment') 27 | parser.add_argument('--log-std', type=float, default=-0.0, metavar='G', 28 | help='log std for the policy (default: -0.0)') 29 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G', 30 | help='discount factor (default: 0.99)') 31 | parser.add_argument('--tau', type=float, default=0.95, metavar='G', 32 | help='gae (default: 0.95)') 33 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G', 34 | help='l2 regularization regression (default: 1e-3)') 35 | parser.add_argument('--learning-rate', type=float, default=3e-4, metavar='G', 36 | help='gae (default: 3e-4)') 37 | parser.add_argument('--clip-epsilon', type=float, default=0.2, metavar='N', 38 | help='clipping epsilon for PPO') 39 | parser.add_argument('--num-threads', type=int, default=4, metavar='N', 40 | help='number of threads for agent (default: 4)') 41 | parser.add_argument('--seed', type=int, default=1, metavar='N', 42 | help='random seed (default: 1)') 43 | parser.add_argument('--min-batch-size', type=int, default=2048, metavar='N', 44 | help='minimal batch size per PPO update (default: 2048)') 45 | parser.add_argument('--eval-batch-size', type=int, default=2048, metavar='N', 46 | help='minimal batch size for evaluation (default: 2048)') 47 | parser.add_argument('--max-iter-num', type=int, default=500, metavar='N', 48 | help='maximal number of main iterations (default: 500)') 49 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', 50 | help='interval between training status logs (default: 10)') 51 | parser.add_argument('--save-model-interval', type=int, default=0, metavar='N', 52 | help="interval between saving model (default: 0, means don't save)") 53 | parser.add_argument('--gpu-index', type=int, default=0, metavar='N') 54 | args = parser.parse_args() 55 | 56 | dtype = torch.float64 57 | torch.set_default_dtype(dtype) 58 | device = torch.device('cuda', index=args.gpu_index) if torch.cuda.is_available() else torch.device('cpu') 59 | if torch.cuda.is_available(): 60 | torch.cuda.set_device(args.gpu_index) 61 | 62 | """environment""" 63 | env = gym.make(args.env_name) 64 | state_dim = env.observation_space.shape[0] 65 | is_disc_action = len(env.action_space.shape) == 0 66 | action_dim = 1 if is_disc_action else env.action_space.shape[0] 67 | running_state = ZFilter((state_dim,), clip=5) 68 | # running_reward = ZFilter((1,), demean=False, clip=10) 69 | 70 | """seeding""" 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | env.seed(args.seed) 74 | 75 | """define actor and critic""" 76 | if is_disc_action: 77 | policy_net = DiscretePolicy(state_dim, env.action_space.n) 78 | else: 79 | policy_net = Policy(state_dim, env.action_space.shape[0], log_std=args.log_std) 80 | value_net = Value(state_dim) 81 | discrim_net = Discriminator(state_dim + action_dim) 82 | discrim_criterion = nn.BCELoss() 83 | to_device(device, policy_net, value_net, discrim_net, discrim_criterion) 84 | 85 | optimizer_policy = torch.optim.Adam(policy_net.parameters(), lr=args.learning_rate) 86 | optimizer_value = torch.optim.Adam(value_net.parameters(), lr=args.learning_rate) 87 | optimizer_discrim = torch.optim.Adam(discrim_net.parameters(), lr=args.learning_rate) 88 | 89 | # optimization epoch number and batch size for PPO 90 
| optim_epochs = 10 91 | optim_batch_size = 64 92 | 93 | # load trajectory 94 | expert_traj, running_state = pickle.load(open(args.expert_traj_path, "rb")) 95 | running_state.fix = True 96 | 97 | 98 | def expert_reward(state, action): 99 | state_action = tensor(np.hstack([state, action]), dtype=dtype) 100 | with torch.no_grad(): 101 | return -math.log(discrim_net(state_action)[0].item()) 102 | 103 | 104 | """create agent""" 105 | agent = Agent(env, policy_net, device, custom_reward=expert_reward, 106 | running_state=running_state, num_threads=args.num_threads) 107 | 108 | 109 | def update_params(batch, i_iter): 110 | states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device) 111 | actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device) 112 | rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device) 113 | masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device) 114 | with torch.no_grad(): 115 | values = value_net(states) 116 | fixed_log_probs = policy_net.get_log_prob(states, actions) 117 | 118 | """get advantage estimation from the trajectories""" 119 | advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, device) 120 | 121 | """update discriminator""" 122 | for _ in range(1): 123 | expert_state_actions = torch.from_numpy(expert_traj).to(dtype).to(device) 124 | g_o = discrim_net(torch.cat([states, actions], 1)) 125 | e_o = discrim_net(expert_state_actions) 126 | optimizer_discrim.zero_grad() 127 | discrim_loss = discrim_criterion(g_o, ones((states.shape[0], 1), device=device)) + \ 128 | discrim_criterion(e_o, zeros((expert_traj.shape[0], 1), device=device)) 129 | discrim_loss.backward() 130 | optimizer_discrim.step() 131 | 132 | """perform mini-batch PPO update""" 133 | optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size)) 134 | for _ in range(optim_epochs): 135 | perm = np.arange(states.shape[0]) 136 | np.random.shuffle(perm) 137 | perm = LongTensor(perm).to(device) 138 | 139 | states, actions, returns, advantages, fixed_log_probs = \ 140 | states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), fixed_log_probs[perm].clone() 141 | 142 | for i in range(optim_iter_num): 143 | ind = slice(i * optim_batch_size, min((i + 1) * optim_batch_size, states.shape[0])) 144 | states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \ 145 | states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind] 146 | 147 | ppo_step(policy_net, value_net, optimizer_policy, optimizer_value, 1, states_b, actions_b, returns_b, 148 | advantages_b, fixed_log_probs_b, args.clip_epsilon, args.l2_reg) 149 | 150 | 151 | def main_loop(): 152 | for i_iter in range(args.max_iter_num): 153 | """generate multiple trajectories that reach the minimum batch_size""" 154 | discrim_net.to(torch.device('cpu')) 155 | batch, log = agent.collect_samples(args.min_batch_size, render=args.render) 156 | discrim_net.to(device) 157 | 158 | t0 = time.time() 159 | update_params(batch, i_iter) 160 | t1 = time.time() 161 | """evaluate with determinstic action (remove noise for exploration)""" 162 | discrim_net.to(torch.device('cpu')) 163 | _, log_eval = agent.collect_samples(args.eval_batch_size, mean_action=True) 164 | discrim_net.to(device) 165 | t2 = time.time() 166 | 167 | if i_iter % args.log_interval == 0: 168 | print('{}\tT_sample {:.4f}\tT_update {:.4f}\ttrain_discrim_R_avg {:.2f}\ttrain_R_avg {:.2f}\teval_discrim_R_avg {:.2f}\teval_R_avg {:.2f}'.format( 169 | i_iter, 
log['sample_time'], t1-t0, log['avg_c_reward'], log['avg_reward'], log_eval['avg_c_reward'], log_eval['avg_reward'])) 170 | 171 | if args.save_model_interval > 0 and (i_iter+1) % args.save_model_interval == 0: 172 | to_device(torch.device('cpu'), policy_net, value_net, discrim_net) 173 | pickle.dump((policy_net, value_net, discrim_net), open(os.path.join(assets_dir(), 'learned_models/{}_gail.p'.format(args.env_name)), 'wb')) 174 | to_device(device, policy_net, value_net, discrim_net) 175 | 176 | """clean up gpu memory""" 177 | torch.cuda.empty_cache() 178 | 179 | 180 | main_loop() 181 | -------------------------------------------------------------------------------- /gail/save_expert_traj.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import os 4 | import sys 5 | import pickle 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | 8 | from itertools import count 9 | from utils import * 10 | 11 | 12 | parser = argparse.ArgumentParser(description='Save expert trajectory') 13 | parser.add_argument('--env-name', default="Hopper-v2", metavar='G', 14 | help='name of the environment to run') 15 | parser.add_argument('--model-path', metavar='G', 16 | help='name of the expert model') 17 | parser.add_argument('--render', action='store_true', default=False, 18 | help='render the environment') 19 | parser.add_argument('--seed', type=int, default=1, metavar='N', 20 | help='random seed (default: 1)') 21 | parser.add_argument('--max-expert-state-num', type=int, default=50000, metavar='N', 22 | help='maximal number of main iterations (default: 50000)') 23 | args = parser.parse_args() 24 | 25 | dtype = torch.float64 26 | torch.set_default_dtype(dtype) 27 | env = gym.make(args.env_name) 28 | env.seed(args.seed) 29 | torch.manual_seed(args.seed) 30 | is_disc_action = len(env.action_space.shape) == 0 31 | state_dim = env.observation_space.shape[0] 32 | 33 | policy_net, _, running_state = pickle.load(open(args.model_path, "rb")) 34 | running_state.fix = True 35 | expert_traj = [] 36 | 37 | 38 | def main_loop(): 39 | 40 | num_steps = 0 41 | 42 | for i_episode in count(): 43 | 44 | state = env.reset() 45 | state = running_state(state) 46 | reward_episode = 0 47 | 48 | for t in range(10000): 49 | state_var = tensor(state).unsqueeze(0).to(dtype) 50 | # choose mean action 51 | action = policy_net(state_var)[0][0].detach().numpy() 52 | # choose stochastic action 53 | # action = policy_net.select_action(state_var)[0].cpu().numpy() 54 | action = int(action) if is_disc_action else action.astype(np.float64) 55 | next_state, reward, done, _ = env.step(action) 56 | next_state = running_state(next_state) 57 | reward_episode += reward 58 | num_steps += 1 59 | 60 | expert_traj.append(np.hstack([state, action])) 61 | 62 | if args.render: 63 | env.render() 64 | if done or num_steps >= args.max_expert_state_num: 65 | break 66 | 67 | state = next_state 68 | 69 | print('Episode {}\t reward: {:.2f}'.format(i_episode, reward_episode)) 70 | 71 | if num_steps >= args.max_expert_state_num: 72 | break 73 | 74 | 75 | main_loop() 76 | expert_traj = np.stack(expert_traj) 77 | pickle.dump((expert_traj, running_state), open(os.path.join(assets_dir(), 'expert_traj/{}_expert_traj.p'.format(args.env_name)), 'wb')) 78 | -------------------------------------------------------------------------------- /models/mlp_critic.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 
import torch 3 | 4 | 5 | class Value(nn.Module): 6 | def __init__(self, state_dim, hidden_size=(128, 128), activation='tanh'): 7 | super().__init__() 8 | if activation == 'tanh': 9 | self.activation = torch.tanh 10 | elif activation == 'relu': 11 | self.activation = torch.relu 12 | elif activation == 'sigmoid': 13 | self.activation = torch.sigmoid 14 | 15 | self.affine_layers = nn.ModuleList() 16 | last_dim = state_dim 17 | for nh in hidden_size: 18 | self.affine_layers.append(nn.Linear(last_dim, nh)) 19 | last_dim = nh 20 | 21 | self.value_head = nn.Linear(last_dim, 1) 22 | self.value_head.weight.data.mul_(0.1) 23 | self.value_head.bias.data.mul_(0.0) 24 | 25 | def forward(self, x): 26 | for affine in self.affine_layers: 27 | x = self.activation(affine(x)) 28 | 29 | value = self.value_head(x) 30 | return value 31 | -------------------------------------------------------------------------------- /models/mlp_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class Discriminator(nn.Module): 6 | def __init__(self, num_inputs, hidden_size=(128, 128), activation='tanh'): 7 | super().__init__() 8 | if activation == 'tanh': 9 | self.activation = torch.tanh 10 | elif activation == 'relu': 11 | self.activation = torch.relu 12 | elif activation == 'sigmoid': 13 | self.activation = torch.sigmoid 14 | 15 | self.affine_layers = nn.ModuleList() 16 | last_dim = num_inputs 17 | for nh in hidden_size: 18 | self.affine_layers.append(nn.Linear(last_dim, nh)) 19 | last_dim = nh 20 | 21 | self.logic = nn.Linear(last_dim, 1) 22 | self.logic.weight.data.mul_(0.1) 23 | self.logic.bias.data.mul_(0.0) 24 | 25 | def forward(self, x): 26 | for affine in self.affine_layers: 27 | x = self.activation(affine(x)) 28 | 29 | prob = torch.sigmoid(self.logic(x)) 30 | return prob 31 | -------------------------------------------------------------------------------- /models/mlp_policy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from utils.math import * 4 | 5 | 6 | class Policy(nn.Module): 7 | def __init__(self, state_dim, action_dim, hidden_size=(128, 128), activation='tanh', log_std=0): 8 | super().__init__() 9 | self.is_disc_action = False 10 | if activation == 'tanh': 11 | self.activation = torch.tanh 12 | elif activation == 'relu': 13 | self.activation = torch.relu 14 | elif activation == 'sigmoid': 15 | self.activation = torch.sigmoid 16 | 17 | self.affine_layers = nn.ModuleList() 18 | last_dim = state_dim 19 | for nh in hidden_size: 20 | self.affine_layers.append(nn.Linear(last_dim, nh)) 21 | last_dim = nh 22 | 23 | self.action_mean = nn.Linear(last_dim, action_dim) 24 | self.action_mean.weight.data.mul_(0.1) 25 | self.action_mean.bias.data.mul_(0.0) 26 | 27 | self.action_log_std = nn.Parameter(torch.ones(1, action_dim) * log_std) 28 | 29 | def forward(self, x): 30 | for affine in self.affine_layers: 31 | x = self.activation(affine(x)) 32 | 33 | action_mean = self.action_mean(x) 34 | action_log_std = self.action_log_std.expand_as(action_mean) 35 | action_std = torch.exp(action_log_std) 36 | 37 | return action_mean, action_log_std, action_std 38 | 39 | def select_action(self, x): 40 | action_mean, _, action_std = self.forward(x) 41 | action = torch.normal(action_mean, action_std) 42 | return action 43 | 44 | def get_kl(self, x): 45 | mean1, log_std1, std1 = self.forward(x) 46 | 47 | mean0 = mean1.detach() 48 | log_std0 = 
log_std1.detach() 49 | std0 = std1.detach() 50 | kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5 51 | return kl.sum(1, keepdim=True) 52 | 53 | def get_log_prob(self, x, actions): 54 | action_mean, action_log_std, action_std = self.forward(x) 55 | return normal_log_density(actions, action_mean, action_log_std, action_std) 56 | 57 | def get_fim(self, x): 58 | mean, _, _ = self.forward(x) 59 | cov_inv = self.action_log_std.exp().pow(-2).squeeze(0).repeat(x.size(0)) 60 | param_count = 0 61 | std_index = 0 62 | id = 0 63 | for name, param in self.named_parameters(): 64 | if name == "action_log_std": 65 | std_id = id 66 | std_index = param_count 67 | param_count += param.view(-1).shape[0] 68 | id += 1 69 | return cov_inv.detach(), mean, {'std_id': std_id, 'std_index': std_index} 70 | 71 | 72 | -------------------------------------------------------------------------------- /models/mlp_policy_disc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from utils.math import * 4 | 5 | 6 | class DiscretePolicy(nn.Module): 7 | def __init__(self, state_dim, action_num, hidden_size=(128, 128), activation='tanh'): 8 | super().__init__() 9 | self.is_disc_action = True 10 | if activation == 'tanh': 11 | self.activation = torch.tanh 12 | elif activation == 'relu': 13 | self.activation = torch.relu 14 | elif activation == 'sigmoid': 15 | self.activation = torch.sigmoid 16 | 17 | self.affine_layers = nn.ModuleList() 18 | last_dim = state_dim 19 | for nh in hidden_size: 20 | self.affine_layers.append(nn.Linear(last_dim, nh)) 21 | last_dim = nh 22 | 23 | self.action_head = nn.Linear(last_dim, action_num) 24 | self.action_head.weight.data.mul_(0.1) 25 | self.action_head.bias.data.mul_(0.0) 26 | 27 | def forward(self, x): 28 | for affine in self.affine_layers: 29 | x = self.activation(affine(x)) 30 | 31 | action_prob = torch.softmax(self.action_head(x), dim=1) 32 | return action_prob 33 | 34 | def select_action(self, x): 35 | action_prob = self.forward(x) 36 | action = action_prob.multinomial(1) 37 | return action 38 | 39 | def get_kl(self, x): 40 | action_prob1 = self.forward(x) 41 | action_prob0 = action_prob1.detach() 42 | kl = action_prob0 * (torch.log(action_prob0) - torch.log(action_prob1)) 43 | return kl.sum(1, keepdim=True) 44 | 45 | def get_log_prob(self, x, actions): 46 | action_prob = self.forward(x) 47 | return torch.log(action_prob.gather(1, actions.long().unsqueeze(1))) 48 | 49 | def get_fim(self, x): 50 | action_prob = self.forward(x) 51 | M = action_prob.pow(-1).view(-1).detach() 52 | return M, action_prob, {} 53 | 54 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.replay_memory import * 2 | from utils.zfilter import * 3 | from utils.torch import * 4 | from utils.math import * 5 | from utils.tools import * 6 | -------------------------------------------------------------------------------- /utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def normal_entropy(std): 6 | var = std.pow(2) 7 | entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi) 8 | return entropy.sum(1, keepdim=True) 9 | 10 | 11 | def normal_log_density(x, mean, log_std, std): 12 | var = std.pow(2) 13 | log_density = -(x - mean).pow(2) / (2 * var) - 0.5 * math.log(2 * math.pi) 
- log_std 14 | return log_density.sum(1, keepdim=True) 15 | -------------------------------------------------------------------------------- /utils/replay_memory.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import random 3 | 4 | # Taken from 5 | # https://github.com/pytorch/tutorials/blob/master/Reinforcement%20(Q-)Learning%20with%20PyTorch.ipynb 6 | 7 | Transition = namedtuple('Transition', ('state', 'action', 'mask', 'next_state', 8 | 'reward')) 9 | 10 | 11 | class Memory(object): 12 | def __init__(self): 13 | self.memory = [] 14 | 15 | def push(self, *args): 16 | """Saves a transition.""" 17 | self.memory.append(Transition(*args)) 18 | 19 | def sample(self, batch_size=None): 20 | if batch_size is None: 21 | return Transition(*zip(*self.memory)) 22 | else: 23 | random_batch = random.sample(self.memory, batch_size) 24 | return Transition(*zip(*random_batch)) 25 | 26 | def append(self, new_memory): 27 | self.memory += new_memory.memory 28 | 29 | def __len__(self): 30 | return len(self.memory) 31 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | 4 | def assets_dir(): 5 | return path.abspath(path.join(path.dirname(path.abspath(__file__)), '../assets')) 6 | -------------------------------------------------------------------------------- /utils/torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | tensor = torch.tensor 5 | DoubleTensor = torch.DoubleTensor 6 | FloatTensor = torch.FloatTensor 7 | LongTensor = torch.LongTensor 8 | ByteTensor = torch.ByteTensor 9 | ones = torch.ones 10 | zeros = torch.zeros 11 | 12 | 13 | def to_device(device, *args): 14 | return [x.to(device) for x in args] 15 | 16 | 17 | def get_flat_params_from(model): 18 | params = [] 19 | for param in model.parameters(): 20 | params.append(param.view(-1)) 21 | 22 | flat_params = torch.cat(params) 23 | return flat_params 24 | 25 | 26 | def set_flat_params_to(model, flat_params): 27 | prev_ind = 0 28 | for param in model.parameters(): 29 | flat_size = int(np.prod(list(param.size()))) 30 | param.data.copy_( 31 | flat_params[prev_ind:prev_ind + flat_size].view(param.size())) 32 | prev_ind += flat_size 33 | 34 | 35 | def get_flat_grad_from(inputs, grad_grad=False): 36 | grads = [] 37 | for param in inputs: 38 | if grad_grad: 39 | grads.append(param.grad.grad.view(-1)) 40 | else: 41 | if param.grad is None: 42 | grads.append(zeros(param.view(-1).shape)) 43 | else: 44 | grads.append(param.grad.view(-1)) 45 | 46 | flat_grad = torch.cat(grads) 47 | return flat_grad 48 | 49 | 50 | def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False): 51 | if create_graph: 52 | retain_graph = True 53 | 54 | inputs = list(inputs) 55 | params = [] 56 | for i, param in enumerate(inputs): 57 | if i not in filter_input_ids: 58 | params.append(param) 59 | 60 | grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph) 61 | 62 | j = 0 63 | out_grads = [] 64 | for i, param in enumerate(inputs): 65 | if i in filter_input_ids: 66 | out_grads.append(zeros(param.view(-1).shape, device=param.device, dtype=param.dtype)) 67 | else: 68 | out_grads.append(grads[j].view(-1)) 69 | j += 1 70 | grads = torch.cat(out_grads) 71 | 72 | for param in params: 73 | 
param.grad = None 74 | return grads 75 | -------------------------------------------------------------------------------- /utils/zfilter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # from https://github.com/joschu/modular_rl 4 | # http://www.johndcook.com/blog/standard_deviation/ 5 | 6 | 7 | class RunningStat(object): 8 | def __init__(self, shape): 9 | self._n = 0 10 | self._M = np.zeros(shape) 11 | self._S = np.zeros(shape) 12 | 13 | def push(self, x): 14 | x = np.asarray(x) 15 | assert x.shape == self._M.shape 16 | self._n += 1 17 | if self._n == 1: 18 | self._M[...] = x 19 | else: 20 | oldM = self._M.copy() 21 | self._M[...] = oldM + (x - oldM) / self._n 22 | self._S[...] = self._S + (x - oldM) * (x - self._M) 23 | 24 | @property 25 | def n(self): 26 | return self._n 27 | 28 | @property 29 | def mean(self): 30 | return self._M 31 | 32 | @property 33 | def var(self): 34 | return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) 35 | 36 | @property 37 | def std(self): 38 | return np.sqrt(self.var) 39 | 40 | @property 41 | def shape(self): 42 | return self._M.shape 43 | 44 | 45 | class ZFilter: 46 | """ 47 | y = (x-mean)/std 48 | using running estimates of mean,std 49 | """ 50 | 51 | def __init__(self, shape, demean=True, destd=True, clip=10.0): 52 | self.demean = demean 53 | self.destd = destd 54 | self.clip = clip 55 | 56 | self.rs = RunningStat(shape) 57 | self.fix = False 58 | 59 | def __call__(self, x, update=True): 60 | if update and not self.fix: 61 | self.rs.push(x) 62 | if self.demean: 63 | x = x - self.rs.mean 64 | if self.destd: 65 | x = x / (self.rs.std + 1e-8) 66 | if self.clip: 67 | x = np.clip(x, -self.clip, self.clip) 68 | return x 69 | 70 | --------------------------------------------------------------------------------
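ZFilter is the running observation normalizer (y = (x - mean) / std, clipped) that the training scripts wrap around raw environment states. A tiny standalone illustration, assuming the repository root is on PYTHONPATH (the data here is made up):
```
import numpy as np
from utils.zfilter import ZFilter

zf = ZFilter((3,), clip=5)
for _ in range(1000):
    zf(np.random.randn(3) * 10.0 + 2.0)    # every call updates the running mean/std
zf.fix = True                              # freeze the statistics, as gail/save_expert_traj.py does
obs = zf(np.random.randn(3) * 10.0 + 2.0)  # roughly zero-mean, unit-std, clipped to [-5, 5]
```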