├── LICENSE ├── README.md └── mujoco ├── 2iwil.py ├── conjugate_gradients.py ├── demonstrations ├── Ant-v2_stage1.npy ├── Ant-v2_stage2.npy ├── HalfCheetah-v2_stage1.npy ├── HalfCheetah-v2_stage2.npy ├── Hopper-v2_stage1.npy ├── Hopper-v2_stage2.npy ├── Walker2d-v2_stage1_conf.npy └── Walker2d-v2_stage2.npy ├── gail.py ├── loss.py ├── models.py ├── replay_memory.py ├── reward_model ├── ant_reward_stage1.pth ├── ant_reward_stage2.pth ├── halfcheetah_reward_stage1.pth ├── halfcheetah_reward_stage2.pth ├── hopper_reward_stage1.pth ├── hopper_reward_stage2.pth ├── walker2d_reward_stage1.pth └── walker2d_reward_stage2.pth ├── running_state.py ├── trpo.py ├── trpo_irl.py ├── utils.py └── wgail.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yunke 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Weight Imperfect Demonstrations 2 | 3 | This repository contains the PyTorch code for the paper "Learning to Weight Imperfect Demonstrations" in ICML 2021. Code for Atari experiments can be found in [this repo](https://github.com/naivety77/gail_atari). 4 | 5 | ## Requirement 6 | * Python 3.7 7 | * torch 1.3.1 8 | * gym 0.15.7 9 | * mujoco 2.0.2.4 10 | * numpy 1.16.2 11 | 12 | ## Execute 13 | * WGAIL 14 | ``` 15 | python wgail.py --env Ant-v2 --num-epochs 5000 --traj-size 1000 --stage 2 16 | ``` 17 | * 2IWIL 18 | ``` 19 | python 2iwil.py --env Ant-v2 --num-epochs 5000 --traj-size 1000 --stage 2 20 | ``` 21 | * GAIL 22 | ``` 23 | python gail.py --env Ant-v2 --num-epochs 5000 --traj-size 1000 --stage 2 24 | ``` 25 | * T-REX 26 | ``` 27 | python trpo_irl.py --env Ant-v2 --num-epochs 5000 --reward-path 'reward_model/ant_reward_stage2.pth' --stage 2 28 | ``` 29 | The re-implementation of T-REX/D-REX can be found in [SAIL](https://github.com/naivety77/SAIL). 30 | 31 | ## Acknowledegement 32 | We would like to thank the authors of [2IWIL/IC-GAIL](https://github.com/kristery/Imitation-Learning-from-Imperfect-Demonstration). Our code structure is largely based on their source code. 
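## Demonstrations
The training scripts load imperfect demonstrations from `mujoco/demonstrations/` as NumPy arrays. The snippet below is a minimal inspection sketch (not part of the repository), assuming each row is a flattened state-action pair, which is how `gail.py`, `wgail.py`, and `2iwil.py` feed these arrays to the discriminator.
```
import numpy as np

# Each row is assumed to be a concatenated (state, action) sample.
traj = np.load("mujoco/demonstrations/Ant-v2_stage2.npy")
print(traj.shape)  # (num_samples, state_dim + action_dim)

# 2iwil.py additionally loads "<env>_stage<stage>_conf.npy", a per-row
# confidence label (provided here for Walker2d-v2 stage 1).
```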
33 | -------------------------------------------------------------------------------- /mujoco/2iwil.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from itertools import count 3 | 4 | import gym 5 | import gym.spaces 6 | import scipy.optimize 7 | import numpy as np 8 | import math 9 | from tqdm import tqdm 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from models import * 15 | from replay_memory import Memory 16 | from running_state import ZFilter 17 | from torch.autograd import Variable 18 | from trpo import trpo_step 19 | from utils import * 20 | from loss import * 21 | 22 | torch.utils.backcompat.broadcast_warning.enabled = True 23 | torch.utils.backcompat.keepdim_warning.enabled = True 24 | 25 | torch.set_default_tensor_type('torch.DoubleTensor') 26 | device = torch.device("cpu") 27 | parser = argparse.ArgumentParser(description='PyTorch actor-critic example') 28 | parser.add_argument('--gamma', type=float, default=0.995, metavar='G', 29 | help='discount factor (default: 0.995)') 30 | parser.add_argument('--env', type=str, default="Walker2d-v2", metavar='G', 31 | help='name of the environment to run') 32 | parser.add_argument('--tau', type=float, default=0.97, metavar='G', 33 | help='gae (default: 0.97)') 34 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G', 35 | help='l2 regularization regression (default: 1e-3)') 36 | parser.add_argument('--max-kl', type=float, default=1e-2, metavar='G', 37 | help='max kl value (default: 1e-2)') 38 | parser.add_argument('--damping', type=float, default=1e-1, metavar='G', 39 | help='damping (default: 1e-1)') 40 | parser.add_argument('--seed', type=int, default=1111, metavar='N', 41 | help='random seed (default: 1111') 42 | parser.add_argument('--batch-size', type=int, default=5000, metavar='N', 43 | help='size of a single batch') 44 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', 45 | help='interval between training status logs (default: 10)') 46 | parser.add_argument('--fname', type=str, default='expert', metavar='F', 47 | help='the file name to save trajectory') 48 | parser.add_argument('--num-epochs', type=int, default=8000, metavar='N', 49 | help='number of epochs to train an expert') 50 | parser.add_argument('--hidden-dim', type=int, default=100, metavar='H', 51 | help='the size of hidden layers') 52 | parser.add_argument('--lr', type=float, default=1e-3, metavar='L', 53 | help='learning rate') 54 | parser.add_argument('--weight', action='store_true', 55 | help='consider confidence into loss') 56 | parser.add_argument('--vf-iters', type=int, default=30, metavar='V', 57 | help='number of iterations of value function optimization iterations per each policy optimization step') 58 | parser.add_argument('--vf-lr', type=float, default=3e-4, metavar='V', 59 | help='learning rate of value network') 60 | parser.add_argument('--eval-epochs', type=int, default=3, metavar='E', 61 | help='epochs to evaluate model') 62 | parser.add_argument('--traj-size', type=int, default=2000) 63 | parser.add_argument('--ifolder', type=str, default='demonstrations') 64 | parser.add_argument('--optimal-policy', type=float, default=4018.19) 65 | parser.add_argument('--random-policy', type=float, default=249.50) 66 | parser.add_argument('--method', type=str, default='2iwil') 67 | parser.add_argument('--stage', type=int, default=1) 68 | 69 | args = parser.parse_args() 70 | env = gym.make(args.env) 
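# The semi-supervised confidence step further below uses args.prior (the
# fraction of demonstrations that come with ground-truth confidence labels),
# which is not declared in the argument list above. The fallback below is an
# assumed placeholder (0.2), added only to keep the script runnable; adjust it
# to the labeled ratio actually used.
if not hasattr(args, 'prior'):
    args.prior = 0.2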
71 | 72 | num_inputs = env.observation_space.shape[0] 73 | num_actions = env.action_space.shape[0] 74 | 75 | env.seed(args.seed) 76 | torch.manual_seed(args.seed) 77 | np.random.seed(args.seed) 78 | 79 | policy_net = Policy(num_inputs, num_actions, args.hidden_dim) 80 | value_net = Value(num_inputs, args.hidden_dim).to(device) 81 | discriminator = Discriminator(num_inputs + num_actions, args.hidden_dim).to(device) 82 | disc_criterion = nn.BCEWithLogitsLoss() 83 | value_criterion = nn.MSELoss() 84 | disc_optimizer = optim.Adam(discriminator.parameters(), args.lr) 85 | value_optimizer = optim.Adam(value_net.parameters(), args.vf_lr) 86 | 87 | 88 | def max_normalization(x): 89 | x = (x - args.random_policy) / (args.optimal_policy - args.random_policy) 90 | return x 91 | 92 | 93 | def select_action(state): 94 | state = torch.from_numpy(state).unsqueeze(0) 95 | # action_mean, _, action_std = policy_net(Variable(state)) 96 | action_mean, _, action_std = policy_net(Variable(state)) 97 | action = torch.normal(action_mean, action_std) 98 | return action 99 | 100 | 101 | def update_params(batch): 102 | rewards = torch.Tensor(batch.reward).to(device) 103 | masks = torch.Tensor(batch.mask).to(device) 104 | actions = torch.Tensor(np.concatenate(batch.action, 0)).to(device) 105 | states = torch.Tensor(batch.state).to(device) 106 | values = value_net(Variable(states)) 107 | 108 | returns = torch.Tensor(actions.size(0), 1).to(device) 109 | deltas = torch.Tensor(actions.size(0), 1).to(device) 110 | advantages = torch.Tensor(actions.size(0), 1).to(device) 111 | 112 | prev_return = 0 113 | prev_value = 0 114 | prev_advantage = 0 115 | for i in reversed(range(rewards.size(0))): 116 | returns[i] = rewards[i] + args.gamma * prev_return * masks[i] 117 | deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i] 118 | advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i] 119 | 120 | prev_return = returns[i, 0] 121 | prev_value = values.data[i, 0] 122 | prev_advantage = advantages[i, 0] 123 | 124 | targets = Variable(returns) 125 | 126 | batch_size = math.ceil(states.shape[0] / args.vf_iters) 127 | idx = np.random.permutation(states.shape[0]) 128 | for i in range(args.vf_iters): 129 | smp_idx = idx[i * batch_size: (i + 1) * batch_size] 130 | smp_states = states[smp_idx, :] 131 | smp_targets = targets[smp_idx, :] 132 | 133 | value_optimizer.zero_grad() 134 | value_loss = value_criterion(value_net(Variable(smp_states)), smp_targets) 135 | value_loss.backward() 136 | value_optimizer.step() 137 | 138 | advantages = (advantages - advantages.mean()) / advantages.std() 139 | 140 | action_means, action_log_stds, action_stds = policy_net(Variable(states.cpu())) 141 | fixed_log_prob = normal_log_density(Variable(actions.cpu()), action_means, action_log_stds, 142 | action_stds).data.clone() 143 | 144 | def get_loss(): 145 | action_means, action_log_stds, action_stds = policy_net(Variable(states.cpu())) 146 | log_prob = normal_log_density(Variable(actions.cpu()), action_means, action_log_stds, action_stds) 147 | action_loss = -Variable(advantages.cpu()) * torch.exp(log_prob - Variable(fixed_log_prob)) 148 | return action_loss.mean() 149 | 150 | def get_kl(): 151 | mean1, log_std1, std1 = policy_net(Variable(states.cpu())) 152 | 153 | mean0 = Variable(mean1.data) 154 | log_std0 = Variable(log_std1.data) 155 | std0 = Variable(std1.data) 156 | kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5 157 | return kl.sum(1, keepdim=True) 158 | 159 | 
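    # Natural-gradient policy update: trpo_step (defined in trpo.py) minimizes
    # the surrogate loss from get_loss (negative importance-weighted advantage)
    # subject to a KL trust region of size args.max_kl, solving for the step
    # with conjugate gradients whose Fisher-vector products are regularized by
    # args.damping, followed by a backtracking line search.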
trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping) 160 | 161 | 162 | def expert_reward(states, actions): 163 | states = np.concatenate(states) 164 | actions = np.concatenate(actions) 165 | state_action = torch.Tensor(np.concatenate([states, actions], 1)).to(device) 166 | return -F.logsigmoid(discriminator(state_action)).cpu().detach().numpy() 167 | 168 | 169 | def evaluate(episode): 170 | avg_reward = 0.0 171 | avg_dist = 0.0 172 | for _ in range(args.eval_epochs): 173 | state = env.reset() 174 | unwrapped = env 175 | while hasattr(unwrapped, 'env'): 176 | unwrapped = unwrapped.env 177 | 178 | for _ in range(10000): # Don't infinite loop while learning 179 | state = torch.from_numpy(state).unsqueeze(0) 180 | action, _, _ = policy_net(Variable(state)) 181 | action = action.data[0].numpy() 182 | next_state, reward, done, _ = env.step(action) 183 | avg_reward += reward 184 | if done: 185 | break 186 | state = next_state 187 | avg_dist += unwrapped.sim.data.qpos[0] 188 | 189 | reward_eva_in = avg_reward / args.eval_epochs 190 | reward_eva_in_normal = max_normalization(reward_eva_in) 191 | avg_dist = avg_dist / args.eval_epochs 192 | writer.log(episode, reward_eva_in, reward_eva_in_normal, avg_dist) 193 | return reward_eva_in, reward_eva_in_normal, avg_dist 194 | 195 | plabel = '' 196 | try: 197 | expert_traj = np.load("./{}/{}_stage{}.npy".format(args.ifolder, args.env, args.stage)) 198 | expert_conf = np.load("./{}/{}_stage{}_conf.npy".format(args.ifolder, args.env, args.stage)) 199 | except: 200 | print('Mixture demonstrations not loaded successfully.') 201 | assert False 202 | print("./{}/{}_stage{}.npy".format(args.ifolder, args.env, args.stage)) 203 | 204 | idx = np.random.choice(expert_traj.shape[0], args.traj_size, replace=False) 205 | expert_traj = expert_traj[idx, :] 206 | expert_conf = expert_conf[idx, :] 207 | 208 | ##### semi-confidence learning ##### 209 | num_label = int(args.prior * expert_conf.shape[0]) 210 | 211 | p_idx = np.random.permutation(expert_traj.shape[0]) 212 | expert_traj = expert_traj[p_idx, :] 213 | expert_conf = expert_conf[p_idx, :] 214 | 215 | labeled_traj = torch.Tensor(expert_traj[:num_label, :]).to(device) 216 | unlabeled_traj = torch.Tensor(expert_traj[num_label:, :]).to(device) 217 | label = torch.Tensor(expert_conf[:num_label, :]).to(device) 218 | 219 | classifier = Classifier(expert_traj.shape[1], 40).to(device) 220 | optim = optim.Adam(classifier.parameters(), 5e-4, amsgrad=True) 221 | cu_loss = CULoss(expert_conf, beta=1 - args.prior, non=True) 222 | 223 | batch = min(128, labeled_traj.shape[0]) 224 | ubatch = int(batch / labeled_traj.shape[0] * unlabeled_traj.shape[0]) 225 | iters = 25000 226 | for i in range(iters): 227 | l_idx = np.random.choice(labeled_traj.shape[0], batch) 228 | u_idx = np.random.choice(unlabeled_traj.shape[0], ubatch) 229 | 230 | labeled = classifier(Variable(labeled_traj[l_idx, :])) 231 | unlabeled = classifier(Variable(unlabeled_traj[u_idx, :])) 232 | smp_conf = Variable(label[l_idx, :]) 233 | 234 | optim.zero_grad() 235 | risk = cu_loss(smp_conf, labeled, unlabeled) 236 | 237 | risk.backward() 238 | optim.step() 239 | 240 | if i % 1000 == 0: 241 | print('iteration: {}\tcu loss: {:.3f}'.format(i, risk.data.item())) 242 | 243 | classifier = classifier.eval() 244 | expert_conf = torch.sigmoid(classifier(torch.Tensor(expert_traj).to(device))).detach().cpu().numpy() 245 | expert_conf[:num_label, :] = label.cpu().detach().numpy() 246 | ################################### 247 | Z = expert_conf.mean() 248 | print(Z) 
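# Z is the mean predicted confidence over all demonstrations. The discriminator
# update below weights each expert sample by expert_conf / Z, so the weights
# average to one over the full demonstration set.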
249 | 250 | fname = '' 251 | method = '2iwil_{}_stage{}'.format(args.traj_size,args.stage) 252 | from pathlib import Path 253 | import os 254 | 255 | logdir = Path(os.path.abspath(os.path.join('2iwil', str(args.env), method, str(args.seed)))) 256 | if logdir.exists(): 257 | print('orinal logdir is already exist.') 258 | 259 | writer = Writer(args.env, args.seed, args.weight, 'stage{}'.format(args.stage), args.traj_size, folder=str(logdir)) 260 | 261 | 262 | for i_episode in tqdm(range(args.num_epochs), dynamic_ncols=True): 263 | memory = Memory() 264 | 265 | num_steps = 0 266 | num_episodes = 0 267 | 268 | reward_batch = [] 269 | states = [] 270 | actions = [] 271 | mem_actions = [] 272 | mem_mask = [] 273 | mem_next = [] 274 | 275 | while num_steps < args.batch_size: 276 | state = env.reset() 277 | 278 | reward_sum = 0 279 | for t in range(10000): # Don't infinite loop while learning 280 | action = select_action(state) 281 | action = action.data[0].numpy() 282 | states.append(np.array([state])) 283 | actions.append(np.array([action])) 284 | next_state, true_reward, done, _ = env.step(action) 285 | reward_sum += true_reward 286 | 287 | mask = 1 288 | if done: 289 | mask = 0 290 | 291 | mem_mask.append(mask) 292 | mem_next.append(next_state) 293 | 294 | # env.render() 295 | if done: 296 | break 297 | 298 | state = next_state 299 | num_steps += (t - 1) 300 | num_episodes += 1 301 | 302 | reward_batch.append(reward_sum) 303 | 304 | # 2. evaluate distance 305 | reward_eva, reward_eva_normal, dist = evaluate(i_episode) 306 | 307 | rewards = expert_reward(states, actions) 308 | 309 | for idx in range(len(states)): 310 | memory.push(states[idx][0], actions[idx], mem_mask[idx], mem_next[idx], \ 311 | rewards[idx][0]) 312 | batch = memory.sample() 313 | update_params(batch) 314 | 315 | ### update discriminator ### 316 | actions = torch.from_numpy(np.concatenate(actions)) 317 | states = torch.from_numpy(np.concatenate(states)) 318 | 319 | idx = np.random.randint(0, expert_traj.shape[0], num_steps) 320 | expert_state_action = torch.Tensor(expert_traj[idx, :]) 321 | expert_conf_batch = expert_conf[idx, :] 322 | 323 | state_action = torch.cat((states, actions), 1).to(device) 324 | 325 | fake = discriminator(state_action) 326 | real = discriminator(expert_state_action) 327 | 328 | disc_optimizer.zero_grad() 329 | weighted_loss = nn.BCEWithLogitsLoss(weight=torch.Tensor(expert_conf_batch/Z)) 330 | disc_loss = disc_criterion(fake, torch.ones(states.shape[0], 1).to(device)) + \ 331 | weighted_loss(real, torch.zeros(real.shape[0], 1).to(device)) 332 | 333 | disc_loss.backward() 334 | disc_optimizer.step() 335 | 336 | if i_episode % args.log_interval == 0: 337 | tqdm.write( 338 | 'Episode {}\tReward: {:.2f}\tReward_nor: {:.2f}\tAverage distance: {:.2f}\tLoss: {:.2f}'.format(i_episode, reward_eva, 339 | reward_eva_normal, dist, disc_loss.detach().numpy())) 340 | -------------------------------------------------------------------------------- /mujoco/conjugate_gradients.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10): 5 | x = torch.zeros(b.size()) 6 | r = b - Avp(x) 7 | p = r 8 | rdotr = torch.dot(r, r) 9 | 10 | for i in range(nsteps): 11 | _Avp = Avp(p) 12 | alpha = rdotr / torch.dot(p, _Avp) 13 | x += alpha * p 14 | r -= alpha * _Avp 15 | new_rdotr = torch.dot(r, r) 16 | betta = new_rdotr / rdotr 17 | p = r + betta * p 18 | rdotr = new_rdotr 19 | if rdotr < residual_tol: 20 | break 21 
| return x 22 | 23 | 24 | def flat_grad_from(net, grad_grad=False): 25 | grads = [] 26 | for param in net.parameters(): 27 | if grad_grad: 28 | grads.append(param.grad.grad.view(-1)) 29 | else: 30 | grads.append(param.grad.view(-1)) 31 | 32 | flat_grad = torch.cat(grads) 33 | return flat_grad 34 | -------------------------------------------------------------------------------- /mujoco/demonstrations/Ant-v2_stage1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/demonstrations/Ant-v2_stage1.npy -------------------------------------------------------------------------------- /mujoco/demonstrations/Ant-v2_stage2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/demonstrations/Ant-v2_stage2.npy -------------------------------------------------------------------------------- /mujoco/demonstrations/HalfCheetah-v2_stage1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/demonstrations/HalfCheetah-v2_stage1.npy -------------------------------------------------------------------------------- /mujoco/demonstrations/HalfCheetah-v2_stage2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/demonstrations/HalfCheetah-v2_stage2.npy -------------------------------------------------------------------------------- /mujoco/demonstrations/Hopper-v2_stage1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/demonstrations/Hopper-v2_stage1.npy -------------------------------------------------------------------------------- /mujoco/demonstrations/Hopper-v2_stage2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/demonstrations/Hopper-v2_stage2.npy -------------------------------------------------------------------------------- /mujoco/demonstrations/Walker2d-v2_stage1_conf.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/demonstrations/Walker2d-v2_stage1_conf.npy -------------------------------------------------------------------------------- /mujoco/demonstrations/Walker2d-v2_stage2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/demonstrations/Walker2d-v2_stage2.npy -------------------------------------------------------------------------------- /mujoco/gail.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from itertools import count 3 | 4 | import gym 5 | import gym.spaces 6 | import scipy.optimize 7 | import numpy as np 8 | import math 9 | from tqdm import tqdm 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from 
models import * 15 | from replay_memory import Memory 16 | from running_state import ZFilter 17 | from torch.autograd import Variable 18 | from trpo import trpo_step 19 | from utils import * 20 | from loss import * 21 | 22 | torch.utils.backcompat.broadcast_warning.enabled = True 23 | torch.utils.backcompat.keepdim_warning.enabled = True 24 | 25 | torch.set_default_tensor_type('torch.DoubleTensor') 26 | device = torch.device("cpu") 27 | parser = argparse.ArgumentParser(description='PyTorch actor-critic example') 28 | parser.add_argument('--gamma', type=float, default=0.995, metavar='G', 29 | help='discount factor (default: 0.995)') 30 | parser.add_argument('--env', type=str, default="Ant-v2", metavar='G', 31 | help='name of the environment to run') 32 | parser.add_argument('--tau', type=float, default=0.97, metavar='G', 33 | help='gae (default: 0.97)') 34 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G', 35 | help='l2 regularization regression (default: 1e-3)') 36 | parser.add_argument('--max-kl', type=float, default=1e-2, metavar='G', 37 | help='max kl value (default: 1e-2)') 38 | parser.add_argument('--damping', type=float, default=1e-1, metavar='G', 39 | help='damping (default: 1e-1)') 40 | parser.add_argument('--seed', type=int, default=1111, metavar='N', 41 | help='random seed (default: 1111') 42 | parser.add_argument('--batch-size', type=int, default=5000, metavar='N', 43 | help='size of a single batch') 44 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', 45 | help='interval between training status logs (default: 10)') 46 | parser.add_argument('--fname', type=str, default='expert', metavar='F', 47 | help='the file name to save trajectory') 48 | parser.add_argument('--num-epochs', type=int, default=5000, metavar='N', 49 | help='number of epochs to train an expert') 50 | parser.add_argument('--hidden-dim', type=int, default=100, metavar='H', 51 | help='the size of hidden layers') 52 | parser.add_argument('--lr', type=float, default=1e-3, metavar='L', 53 | help='learning rate') 54 | parser.add_argument('--vf-iters', type=int, default=30, metavar='V', 55 | help='number of iterations of value function optimization iterations per each policy optimization step') 56 | parser.add_argument('--vf-lr', type=float, default=3e-4, metavar='V', 57 | help='learning rate of value network') 58 | parser.add_argument('--eval-epochs', type=int, default=3, metavar='E', 59 | help='epochs to evaluate model') 60 | parser.add_argument('--traj-size', type=int, default=2000) 61 | parser.add_argument('--ifolder', type=str, default='demonstrations') 62 | parser.add_argument('--optimal-policy', type=float, default=4143.10) 63 | parser.add_argument('--random-policy', type=float, default=-72.30) 64 | parser.add_argument('--stage', type=int, default=1) 65 | 66 | args = parser.parse_args() 67 | env = gym.make(args.env) 68 | 69 | num_inputs = env.observation_space.shape[0] 70 | num_actions = env.action_space.shape[0] 71 | 72 | env.seed(args.seed) 73 | torch.manual_seed(args.seed) 74 | np.random.seed(args.seed) 75 | 76 | policy_net = Policy(num_inputs, num_actions, args.hidden_dim) 77 | value_net = Value(num_inputs, args.hidden_dim).to(device) 78 | discriminator = Discriminator(num_inputs + num_actions, args.hidden_dim).to(device) 79 | disc_criterion = nn.BCEWithLogitsLoss() 80 | value_criterion = nn.MSELoss() 81 | disc_optimizer = optim.Adam(discriminator.parameters(), args.lr) 82 | value_optimizer = optim.Adam(value_net.parameters(), args.vf_lr) 83 | 84 | 85 | def 
max_normalization(x): 86 | x = (x - args.random_policy) / (args.optimal_policy - args.random_policy) 87 | return x 88 | 89 | 90 | def select_action(state): 91 | state = torch.from_numpy(state).unsqueeze(0) 92 | # action_mean, _, action_std = policy_net(Variable(state)) 93 | action_mean, _, action_std = policy_net(Variable(state)) 94 | action = torch.normal(action_mean, action_std) 95 | return action 96 | 97 | 98 | def update_params(batch): 99 | rewards = torch.Tensor(batch.reward).to(device) 100 | masks = torch.Tensor(batch.mask).to(device) 101 | actions = torch.Tensor(np.concatenate(batch.action, 0)).to(device) 102 | states = torch.Tensor(batch.state).to(device) 103 | values = value_net(Variable(states)) 104 | 105 | returns = torch.Tensor(actions.size(0), 1).to(device) 106 | deltas = torch.Tensor(actions.size(0), 1).to(device) 107 | advantages = torch.Tensor(actions.size(0), 1).to(device) 108 | 109 | prev_return = 0 110 | prev_value = 0 111 | prev_advantage = 0 112 | for i in reversed(range(rewards.size(0))): 113 | returns[i] = rewards[i] + args.gamma * prev_return * masks[i] 114 | deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i] 115 | advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i] 116 | 117 | prev_return = returns[i, 0] 118 | prev_value = values.data[i, 0] 119 | prev_advantage = advantages[i, 0] 120 | 121 | targets = Variable(returns) 122 | 123 | batch_size = math.ceil(states.shape[0] / args.vf_iters) 124 | idx = np.random.permutation(states.shape[0]) 125 | for i in range(args.vf_iters): 126 | smp_idx = idx[i * batch_size: (i + 1) * batch_size] 127 | smp_states = states[smp_idx, :] 128 | smp_targets = targets[smp_idx, :] 129 | 130 | value_optimizer.zero_grad() 131 | value_loss = value_criterion(value_net(Variable(smp_states)), smp_targets) 132 | value_loss.backward() 133 | value_optimizer.step() 134 | 135 | advantages = (advantages - advantages.mean()) / advantages.std() 136 | 137 | action_means, action_log_stds, action_stds = policy_net(Variable(states.cpu())) 138 | fixed_log_prob = normal_log_density(Variable(actions.cpu()), action_means, action_log_stds, 139 | action_stds).data.clone() 140 | 141 | def get_loss(): 142 | action_means, action_log_stds, action_stds = policy_net(Variable(states.cpu())) 143 | log_prob = normal_log_density(Variable(actions.cpu()), action_means, action_log_stds, action_stds) 144 | action_loss = -Variable(advantages.cpu()) * torch.exp(log_prob - Variable(fixed_log_prob)) 145 | return action_loss.mean() 146 | 147 | def get_kl(): 148 | mean1, log_std1, std1 = policy_net(Variable(states.cpu())) 149 | 150 | mean0 = Variable(mean1.data) 151 | log_std0 = Variable(log_std1.data) 152 | std0 = Variable(std1.data) 153 | kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5 154 | return kl.sum(1, keepdim=True) 155 | 156 | trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping) 157 | 158 | 159 | def expert_reward(states, actions): 160 | states = np.concatenate(states) 161 | actions = np.concatenate(actions) 162 | state_action = torch.Tensor(np.concatenate([states, actions], 1)).to(device) 163 | return -F.logsigmoid(discriminator(state_action)).cpu().detach().numpy() 164 | 165 | 166 | def evaluate(episode): 167 | avg_reward = 0.0 168 | avg_dist = 0.0 169 | for _ in range(args.eval_epochs): 170 | state = env.reset() 171 | unwrapped = env 172 | while hasattr(unwrapped, 'env'): 173 | unwrapped = unwrapped.env 174 | 175 | for _ in range(10000): # Don't infinite loop 
while learning 176 | state = torch.from_numpy(state).unsqueeze(0) 177 | action, _, _ = policy_net(Variable(state)) 178 | action = action.data[0].numpy() 179 | next_state, reward, done, _ = env.step(action) 180 | avg_reward += reward 181 | if done: 182 | break 183 | state = next_state 184 | avg_dist += unwrapped.sim.data.qpos[0] 185 | 186 | reward_eva_in = avg_reward / args.eval_epochs 187 | reward_eva_in_normal = max_normalization(reward_eva_in) 188 | avg_dist = avg_dist / args.eval_epochs 189 | writer.log(episode, reward_eva_in, reward_eva_in_normal, avg_dist) 190 | return reward_eva_in, reward_eva_in_normal, avg_dist 191 | 192 | 193 | try: 194 | expert_traj = np.load("./{}/{}_stage{}.npy".format(args.ifolder, args.env, args.stage)) 195 | except: 196 | print('Mixture demonstrations not loaded successfully.') 197 | assert False 198 | print("./{}/{}_single_level{}.npy".format(args.ifolder, args.env, args.level)) 199 | 200 | idx = np.random.choice(expert_traj.shape[0], args.traj_size, replace=False) 201 | expert_traj = expert_traj[idx, :] 202 | 203 | method = 'gail_{}_stage{}'.format(args.traj_size,args.stage) 204 | from pathlib import Path 205 | import os 206 | 207 | logdir = Path(os.path.abspath(os.path.join('gail', str(args.env), method, str(args.seed)))) 208 | if logdir.exists(): 209 | print('orinal logdir is already exist.') 210 | 211 | writer = Writer(args.env, args.seed, 'stage{}'.format(args.stage), args.traj_size, folder=str(logdir)) 212 | 213 | 214 | for i_episode in tqdm(range(args.num_epochs), dynamic_ncols=True): 215 | memory = Memory() 216 | 217 | num_steps = 0 218 | num_episodes = 0 219 | 220 | reward_batch = [] 221 | states = [] 222 | actions = [] 223 | mem_actions = [] 224 | mem_mask = [] 225 | mem_next = [] 226 | 227 | while num_steps < args.batch_size: 228 | state = env.reset() 229 | 230 | reward_sum = 0 231 | for t in range(10000): # Don't infinite loop while learning 232 | action = select_action(state) 233 | action = action.data[0].numpy() 234 | states.append(np.array([state])) 235 | actions.append(np.array([action])) 236 | next_state, true_reward, done, _ = env.step(action) 237 | reward_sum += true_reward 238 | 239 | mask = 1 240 | if done: 241 | mask = 0 242 | 243 | mem_mask.append(mask) 244 | mem_next.append(next_state) 245 | 246 | # env.render() 247 | if done: 248 | break 249 | 250 | state = next_state 251 | num_steps += (t - 1) 252 | num_episodes += 1 253 | 254 | reward_batch.append(reward_sum) 255 | 256 | # 2. 
evaluate distance 257 | reward_eva, reward_eva_normal, dist = evaluate(i_episode) 258 | 259 | rewards = expert_reward(states, actions) 260 | 261 | for idx in range(len(states)): 262 | memory.push(states[idx][0], actions[idx], mem_mask[idx], mem_next[idx], \ 263 | rewards[idx][0]) 264 | batch = memory.sample() 265 | update_params(batch) 266 | 267 | actions = torch.from_numpy(np.concatenate(actions)) 268 | states = torch.from_numpy(np.concatenate(states)) 269 | 270 | idx = np.random.randint(0, expert_traj.shape[0], num_steps) 271 | expert_state_action = torch.Tensor(expert_traj[idx, :]).to(device) 272 | 273 | state_action = torch.cat((states, actions), 1).to(device) 274 | 275 | fake = discriminator(state_action) 276 | real = discriminator(expert_state_action) 277 | 278 | disc_optimizer.zero_grad() 279 | disc_loss = disc_criterion(fake, torch.ones(states.shape[0], 1).to(device)) + \ 280 | disc_criterion(real, torch.zeros(real.shape[0], 1).to(device)) 281 | 282 | disc_loss.backward() 283 | disc_optimizer.step() 284 | 285 | 286 | if i_episode % args.log_interval == 0: 287 | tqdm.write( 288 | 'Episode {}\tReward: {:.2f}\tReward_nor: {:.2f}\tAverage distance: {:.2f}'.format(i_episode, reward_eva, 289 | reward_eva_normal, dist)) 290 | -------------------------------------------------------------------------------- /mujoco/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable as V 5 | 6 | device = torch.device("cpu") 7 | class CULoss(nn.Module): 8 | def __init__(self, conf, beta, non=False): 9 | super(CULoss, self).__init__() 10 | self.loss = nn.SoftMarginLoss() 11 | self.beta = beta 12 | self.non = non 13 | if conf.mean() > 0.5: 14 | self.UP = True 15 | else: 16 | self.UP = False 17 | 18 | def forward(self, conf, labeled, unlabeled): 19 | y_conf_pos = self.loss(labeled, torch.ones(labeled.shape).to(device)) 20 | y_conf_neg = self.loss(labeled, -torch.ones(labeled.shape).to(device)) 21 | 22 | if self.UP: 23 | #conf_risk = torch.mean((1-conf) * (y_conf_neg - y_conf_pos) + (1 - self.beta) * y_conf_pos) 24 | unlabeled_risk = torch.mean(self.beta * self.loss(unlabeled, torch.ones(unlabeled.shape).to(device))) 25 | neg_risk = torch.mean((1 - conf) * y_conf_neg) 26 | pos_risk = torch.mean((conf - self.beta) * y_conf_pos) + unlabeled_risk 27 | else: 28 | #conf_risk = torch.mean(conf * (y_conf_pos - y_conf_neg) + (1 - self.beta) * y_conf_neg) 29 | unlabeled_risk = torch.mean(self.beta * self.loss(unlabeled, -torch.ones(unlabeled.shape).to(device))) 30 | pos_risk = torch.mean(conf * y_conf_pos) 31 | neg_risk = torch.mean((1 - self.beta - conf) * y_conf_neg) + unlabeled_risk 32 | if self.non: 33 | objective = torch.clamp(neg_risk, min=0) + torch.clamp(pos_risk, min=0) 34 | else: 35 | objective = neg_risk + pos_risk 36 | return objective 37 | 38 | 39 | class PNLoss(nn.Module): 40 | def __init__(self): 41 | super(PNLoss, self).__init__() 42 | self.loss = nn.SoftMarginLoss() 43 | 44 | def forward(self, conf, labeled): 45 | y_conf_pos = self.loss(labeled, torch.ones(labeled.shape).to(device)) 46 | y_conf_neg = self.loss(labeled, -torch.ones(labeled.shape).to(device)) 47 | 48 | objective = torch.mean(conf * y_conf_pos + (1 - conf) * y_conf_neg) 49 | return objective 50 | -------------------------------------------------------------------------------- /mujoco/models.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import deepdish as dd 6 | 7 | 8 | class Policy(nn.Module): 9 | def __init__(self, num_inputs, num_outputs, hidden_size): 10 | super(Policy, self).__init__() 11 | self.affine1 = nn.Linear(num_inputs, hidden_size) 12 | self.affine2 = nn.Linear(hidden_size, hidden_size) 13 | 14 | self.action_mean = nn.Linear(hidden_size, num_outputs) 15 | self.action_mean.weight.data.mul_(0.1) 16 | self.action_mean.bias.data.mul_(0.0) 17 | 18 | self.action_log_std = nn.Parameter(torch.zeros(1, num_outputs)) 19 | 20 | self.saved_actions = [] 21 | self.rewards = [] 22 | self.final_value = 0 23 | 24 | def forward(self, x): 25 | x = F.tanh(self.affine1(x)) 26 | x = F.tanh(self.affine2(x)) 27 | 28 | action_mean = self.action_mean(x) 29 | action_log_std = self.action_log_std.expand_as(action_mean) 30 | action_std = torch.exp(action_log_std) 31 | 32 | return action_mean, action_log_std, action_std 33 | 34 | 35 | class Value(nn.Module): 36 | def __init__(self, num_inputs, hidden_size): 37 | super(Value, self).__init__() 38 | self.affine1 = nn.Linear(num_inputs, hidden_size) 39 | self.affine2 = nn.Linear(hidden_size, hidden_size) 40 | self.value_head = nn.Linear(hidden_size, 1) 41 | self.value_head.weight.data.mul_(0.1) 42 | self.value_head.bias.data.mul_(0.0) 43 | 44 | def forward(self, x): 45 | x = F.tanh(self.affine1(x)) 46 | x = F.tanh(self.affine2(x)) 47 | 48 | state_values = self.value_head(x) 49 | return state_values 50 | 51 | 52 | class Discriminator(nn.Module): 53 | def __init__(self, num_inputs, hidden_size): 54 | super(Discriminator, self).__init__() 55 | self.linear1 = nn.Linear(num_inputs, hidden_size) 56 | self.linear2 = nn.Linear(hidden_size, hidden_size) 57 | self.linear3 = nn.Linear(hidden_size, 1) 58 | self.linear3.weight.data.mul_(0.1) 59 | self.linear3.bias.data.mul_(0.0) 60 | 61 | def forward(self, x): 62 | x = F.tanh(self.linear1(x)) 63 | x = F.tanh(self.linear2(x)) 64 | #prob = F.sigmoid(self.linear3(x)) 65 | output = self.linear3(x) 66 | return output 67 | 68 | 69 | class Generator(nn.Module): 70 | def __init__(self, num_inputs, hidden_size, num_outputs): 71 | super(Generator, self).__init__() 72 | self.fc1 = nn.Linear(num_inputs, hidden_size) 73 | self.fc2 = nn.Linear(hidden_size, hidden_size) 74 | self.fc3 = nn.Linear(hidden_size, num_outputs) 75 | 76 | def forward(self, x): 77 | x = torch.tanh(self.fc1(x)) 78 | x = torch.tanh(self.fc2(x)) 79 | x = self.fc3(x) 80 | return x 81 | 82 | 83 | class Classifier(nn.Module): 84 | def __init__(self, num_inputs, hidden_dim): 85 | super(Classifier, self).__init__() 86 | self.fc1 = nn.Linear(num_inputs, hidden_dim) 87 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 88 | self.fc3 = nn.Linear(hidden_dim, 1) 89 | 90 | self.d1 = nn.Dropout(0.5) 91 | self.d2 = nn.Dropout(0.5) 92 | 93 | self.fc3.weight.data.mul_(0.1) 94 | self.fc3.bias.data.mul_(0.0) 95 | 96 | def forward(self, x): 97 | x = self.d1(torch.tanh(self.fc1(x))) 98 | x = self.d2(torch.tanh(self.fc2(x))) 99 | x = self.fc3(x) 100 | return x 101 | 102 | 103 | class Reward(nn.Module): 104 | def __init__(self, num_inputs, hidden_size=256): 105 | super(Reward, self).__init__() 106 | self.fc1 = nn.Linear(num_inputs, hidden_size) 107 | self.fc2 = nn.Linear(hidden_size, hidden_size) 108 | self.fc3 = nn.Linear(hidden_size, 1) 109 | 110 | def forward(self, x): 111 | x = F.relu(self.fc1(x)) 112 | x = F.relu(self.fc2(x)) 113 | x = self.fc3(x) 114 | return x 115 | 116 | def init_weight(self, model_path): 
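        # Load pretrained reward-network weights saved with deepdish (HDF5);
        # the file is expected to contain one numpy array per parameter under
        # the keys used below.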
117 | weights = dd.io.load(model_path) 118 | self.fc1.weight.data = torch.from_numpy(weights['fc1_weight']) 119 | self.fc1.bias.data = torch.from_numpy(weights['fc1_bias']) 120 | self.fc2.weight.data = torch.from_numpy(weights['fc2_weight']) 121 | self.fc2.bias.data = torch.from_numpy(weights['fc2_bias']) 122 | self.fc3.weight.data = torch.from_numpy(weights['fc3_weight']) 123 | self.fc3.bias.data = torch.from_numpy(weights['fc3_bias']) 124 | 125 | -------------------------------------------------------------------------------- /mujoco/replay_memory.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import namedtuple 3 | 4 | # Taken from 5 | # https://github.com/pytorch/tutorials/blob/master/Reinforcement%20(Q-)Learning%20with%20PyTorch.ipynb 6 | 7 | Transition = namedtuple('Transition', ('state', 'action', 'mask', 'next_state', 8 | 'reward')) 9 | 10 | 11 | class Memory(object): 12 | def __init__(self): 13 | self.memory = [] 14 | 15 | def push(self, *args): 16 | """Saves a transition.""" 17 | self.memory.append(Transition(*args)) 18 | 19 | def sample(self): 20 | return Transition(*zip(*self.memory)) 21 | 22 | def __len__(self): 23 | return len(self.memory) 24 | -------------------------------------------------------------------------------- /mujoco/reward_model/ant_reward_stage1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/reward_model/ant_reward_stage1.pth -------------------------------------------------------------------------------- /mujoco/reward_model/ant_reward_stage2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/reward_model/ant_reward_stage2.pth -------------------------------------------------------------------------------- /mujoco/reward_model/halfcheetah_reward_stage1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/reward_model/halfcheetah_reward_stage1.pth -------------------------------------------------------------------------------- /mujoco/reward_model/halfcheetah_reward_stage2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/reward_model/halfcheetah_reward_stage2.pth -------------------------------------------------------------------------------- /mujoco/reward_model/hopper_reward_stage1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/reward_model/hopper_reward_stage1.pth -------------------------------------------------------------------------------- /mujoco/reward_model/hopper_reward_stage2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/reward_model/hopper_reward_stage2.pth -------------------------------------------------------------------------------- /mujoco/reward_model/walker2d_reward_stage1.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/reward_model/walker2d_reward_stage1.pth -------------------------------------------------------------------------------- /mujoco/reward_model/walker2d_reward_stage2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunke-wang/WGAIL/d737ccaff3e878c9d347cfbcbf8974eaf1d2e8f5/mujoco/reward_model/walker2d_reward_stage2.pth -------------------------------------------------------------------------------- /mujoco/running_state.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import numpy as np 3 | 4 | # from https://github.com/joschu/modular_rl 5 | # http://www.johndcook.com/blog/standard_deviation/ 6 | class RunningStat(object): 7 | def __init__(self, shape): 8 | self._n = 0 9 | self._M = np.zeros(shape) 10 | self._S = np.zeros(shape) 11 | 12 | def push(self, x): 13 | x = np.asarray(x) 14 | assert x.shape == self._M.shape 15 | self._n += 1 16 | if self._n == 1: 17 | self._M[...] = x 18 | else: 19 | oldM = self._M.copy() 20 | self._M[...] = oldM + (x - oldM) / self._n 21 | self._S[...] = self._S + (x - oldM) * (x - self._M) 22 | 23 | @property 24 | def n(self): 25 | return self._n 26 | 27 | @property 28 | def mean(self): 29 | return self._M 30 | 31 | @property 32 | def var(self): 33 | return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) 34 | 35 | @property 36 | def std(self): 37 | return np.sqrt(self.var) 38 | 39 | @property 40 | def shape(self): 41 | return self._M.shape 42 | 43 | 44 | class ZFilter: 45 | """ 46 | y = (x-mean)/std 47 | using running estimates of mean,std 48 | """ 49 | def __init__(self, shape, demean=True, destd=True, clip=10.0): 50 | self.demean = demean 51 | self.destd = destd 52 | self.clip = clip 53 | 54 | self.rs = RunningStat(shape) 55 | 56 | def __call__(self, x, update=True): 57 | if update: self.rs.push(x) 58 | if self.demean: 59 | x = x - self.rs.mean 60 | if self.destd: 61 | x = x / (self.rs.std + 1e-8) 62 | if self.clip: 63 | x = np.clip(x, -self.clip, self.clip) 64 | return x 65 | 66 | def output_shape(self, input_space): 67 | return input_space.shape 68 | 69 | 70 | class RunningMeanStd(object): 71 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 72 | def __init__(self, epsilon=1e-8, shape=()): 73 | self.mean = np.zeros(shape, 'float64') 74 | self.var = np.ones(shape, 'float64') 75 | self.count = epsilon 76 | 77 | def update(self, x): 78 | batch_mean = np.mean(x, axis=0) 79 | batch_var = np.var(x, axis=0) 80 | batch_count = x.shape[0] 81 | self.update_from_moments(batch_mean, batch_var, batch_count) 82 | 83 | def update_from_moments(self, batch_mean, batch_var, batch_count): 84 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 85 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count) 86 | 87 | def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): 88 | delta = batch_mean - mean 89 | tot_count = count + batch_count 90 | 91 | new_mean = mean + delta * batch_count / tot_count 92 | m_a = var * count 93 | m_b = batch_var * batch_count 94 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 95 | new_var = M2 / tot_count 96 | new_count = tot_count 97 | 98 | return new_mean, new_var, new_count 99 | -------------------------------------------------------------------------------- 
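Usage sketch (illustrative, not taken from the repository): trpo_irl.py rescales the learned reward with RunningMeanStd before the policy update, along the lines of

import numpy as np
from running_state import RunningMeanStd

rew_rms = RunningMeanStd(shape=())
rewards = np.random.randn(5000, 1)   # placeholder rewards from a reward model
rew_rms.update(rewards)
rewards = np.clip(rewards / np.sqrt(rew_rms.var + 1e-8), -10., 10.)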
/mujoco/trpo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | from utils import * 6 | 7 | 8 | def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10): 9 | x = torch.zeros(b.size()) 10 | r = b.clone() 11 | p = b.clone() 12 | rdotr = torch.dot(r, r) 13 | for i in range(nsteps): 14 | _Avp = Avp(p) 15 | alpha = rdotr / torch.dot(p, _Avp) 16 | x += alpha * p 17 | r -= alpha * _Avp 18 | new_rdotr = torch.dot(r, r) 19 | betta = new_rdotr / rdotr 20 | p = r + betta * p 21 | rdotr = new_rdotr 22 | if rdotr < residual_tol: 23 | break 24 | return x 25 | 26 | 27 | def linesearch(model, 28 | f, 29 | x, 30 | fullstep, 31 | expected_improve_rate, 32 | max_backtracks=10, 33 | accept_ratio=.1): 34 | fval = f().data 35 | #print("fval before", fval[0]) 36 | for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)): 37 | xnew = x + stepfrac * fullstep 38 | set_flat_params_to(model, xnew) 39 | newfval = f().data 40 | actual_improve = fval - newfval 41 | expected_improve = expected_improve_rate * stepfrac 42 | ratio = actual_improve / expected_improve 43 | #print("a/e/r", actual_improve[0], expected_improve[0], ratio[0]) 44 | 45 | if ratio.item() > accept_ratio and actual_improve.item() > 0: 46 | #print("fval after", newfval[0]) 47 | return True, xnew 48 | return False, x 49 | 50 | 51 | def trpo_step(model, get_loss, get_kl, max_kl, damping): 52 | loss = get_loss() 53 | grads = torch.autograd.grad(loss, model.parameters()) 54 | loss_grad = torch.cat([grad.view(-1) for grad in grads]).data 55 | 56 | def Fvp(v): 57 | kl = get_kl() 58 | kl = kl.mean() 59 | 60 | grads = torch.autograd.grad(kl, model.parameters(), create_graph=True) 61 | flat_grad_kl = torch.cat([grad.view(-1) for grad in grads]) 62 | 63 | kl_v = (flat_grad_kl * Variable(v)).sum() 64 | grads = torch.autograd.grad(kl_v, model.parameters()) 65 | flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data 66 | 67 | return flat_grad_grad_kl + v * damping 68 | 69 | stepdir = conjugate_gradients(Fvp, -loss_grad, 10) 70 | 71 | shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True) 72 | 73 | lm = torch.sqrt(shs / max_kl) # 1 / beta 74 | fullstep = stepdir / lm[0] # beta * s 75 | 76 | neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True) 77 | #print(("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm())) 78 | 79 | prev_params = get_flat_params_from(model) 80 | success, new_params = linesearch(model, get_loss, prev_params, fullstep, 81 | neggdotstepdir / lm[0]) 82 | set_flat_params_to(model, new_params) 83 | 84 | return loss 85 | -------------------------------------------------------------------------------- /mujoco/trpo_irl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from itertools import count 3 | from tensorboardX import SummaryWriter 4 | import gym 5 | import gym.spaces 6 | import scipy.optimize 7 | import numpy as np 8 | import math 9 | from tqdm import tqdm 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from models import * 15 | from replay_memory import Memory 16 | from running_state import ZFilter 17 | from torch.autograd import Variable 18 | from trpo import trpo_step 19 | from utils import * 20 | from loss import * 21 | from running_state import * 22 | 23 | torch.utils.backcompat.broadcast_warning.enabled = True 24 | 
torch.utils.backcompat.keepdim_warning.enabled = True 25 | 26 | torch.set_default_tensor_type('torch.DoubleTensor') 27 | device = torch.device("cpu") 28 | parser = argparse.ArgumentParser(description='PyTorch actor-critic example') 29 | parser.add_argument('--gamma', type=float, default=0.995, metavar='G', 30 | help='discount factor (default: 0.995)') 31 | parser.add_argument('--env', type=str, default="Ant-v2", metavar='G', 32 | help='name of the environment to run') 33 | parser.add_argument('--tau', type=float, default=0.97, metavar='G', 34 | help='gae (default: 0.97)') 35 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G', 36 | help='l2 regularization regression (default: 1e-3)') 37 | parser.add_argument('--max-kl', type=float, default=1e-2, metavar='G', 38 | help='max kl value (default: 1e-2)') 39 | parser.add_argument('--damping', type=float, default=1e-1, metavar='G', 40 | help='damping (default: 1e-1)') 41 | parser.add_argument('--seed', type=int, default=1111, metavar='N', 42 | help='random seed (default: 1111') 43 | parser.add_argument('--batch-size', type=int, default=5000, metavar='N', 44 | help='size of a single batch') 45 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', 46 | help='interval between training status logs (default: 10)') 47 | parser.add_argument('--fname', type=str, default='expert', metavar='F', 48 | help='the file name to save trajectory') 49 | parser.add_argument('--num-epochs', type=int, default=5000, metavar='N', 50 | help='number of epochs to train an expert') 51 | parser.add_argument('--hidden-dim', type=int, default=100, metavar='H', 52 | help='the size of hidden layers') 53 | parser.add_argument('--lr', type=float, default=1e-3, metavar='L', 54 | help='learning rate') 55 | parser.add_argument('--vf-iters', type=int, default=30, metavar='V', 56 | help='number of iterations of value function optimization iterations per each policy optimization step') 57 | parser.add_argument('--vf-lr', type=float, default=3e-4, metavar='V', 58 | help='learning rate of value network') 59 | parser.add_argument('--eval-epochs', type=int, default=3, metavar='E', 60 | help='epochs to evaluate model') 61 | parser.add_argument('--traj-size', type=int, default=1000) 62 | parser.add_argument('--ifolder', type=str, default='demonstrations') 63 | parser.add_argument('--optimal-policy', type=float, default=4145.89) 64 | parser.add_argument('--random-policy', type=float, default=992.18) 65 | parser.add_argument('--reward-path', type=str, default='reward_model/ant_reward_stage1.pth') 66 | parser.add_argument('--stage', type=int, default=1) 67 | 68 | 69 | args = parser.parse_args() 70 | env = gym.make(args.env) 71 | args.weight = True 72 | num_inputs = env.observation_space.shape[0] 73 | num_actions = env.action_space.shape[0] 74 | 75 | env.seed(args.seed) 76 | torch.manual_seed(args.seed) 77 | np.random.seed(args.seed) 78 | 79 | policy_net = Policy(num_inputs, num_actions, args.hidden_dim) 80 | value_net = Value(num_inputs, args.hidden_dim).to(device) 81 | value_criterion = nn.MSELoss() 82 | value_optimizer = optim.Adam(value_net.parameters(), args.vf_lr) 83 | reward_net = Reward(num_inputs + num_actions) 84 | reward_net.load_state_dict(torch.load(args.reward_path)) 85 | fname = 'TRPO_dist' 86 | 87 | def max_normalization(x): 88 | x = (x - args.random_policy) / (args.optimal_policy - args.random_policy) 89 | return x 90 | 91 | def select_action(state): 92 | state = torch.from_numpy(state).unsqueeze(0) 93 | # action_mean, _, action_std = 
policy_net(Variable(state)) 94 | action_mean, _, action_std = policy_net(Variable(state)) 95 | action = torch.normal(action_mean, action_std) 96 | return action 97 | 98 | def update_params(batch): 99 | rewards = torch.Tensor(batch.reward).to(device) 100 | masks = torch.Tensor(batch.mask).to(device) 101 | actions = torch.Tensor(np.concatenate(batch.action, 0)).to(device) 102 | states = torch.Tensor(batch.state).to(device) 103 | values = value_net(Variable(states)) 104 | 105 | returns = torch.Tensor(actions.size(0),1).to(device) 106 | deltas = torch.Tensor(actions.size(0),1).to(device) 107 | advantages = torch.Tensor(actions.size(0),1).to(device) 108 | 109 | prev_return = 0 110 | prev_value = 0 111 | prev_advantage = 0 112 | for i in reversed(range(rewards.size(0))): 113 | returns[i] = rewards[i] + args.gamma * prev_return * masks[i] 114 | deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i] 115 | advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i] 116 | 117 | prev_return = returns[i, 0] 118 | prev_value = values.data[i, 0] 119 | prev_advantage = advantages[i, 0] 120 | 121 | targets = Variable(returns) 122 | 123 | batch_size = math.ceil(states.shape[0] / args.vf_iters) 124 | idx = np.random.permutation(states.shape[0]) 125 | for i in range(args.vf_iters): 126 | smp_idx = idx[i * batch_size: (i + 1) * batch_size] 127 | smp_states = states[smp_idx, :] 128 | smp_targets = targets[smp_idx, :] 129 | 130 | value_optimizer.zero_grad() 131 | value_loss = value_criterion(value_net(Variable(smp_states)), smp_targets) 132 | value_loss.backward() 133 | value_optimizer.step() 134 | 135 | advantages = (advantages - advantages.mean()) / advantages.std() 136 | 137 | action_means, action_log_stds, action_stds = policy_net(Variable(states.cpu())) 138 | fixed_log_prob = normal_log_density(Variable(actions.cpu()), action_means, action_log_stds, action_stds).data.clone() 139 | 140 | def get_loss(): 141 | action_means, action_log_stds, action_stds = policy_net(Variable(states.cpu())) 142 | log_prob = normal_log_density(Variable(actions.cpu()), action_means, action_log_stds, action_stds) 143 | action_loss = -Variable(advantages.cpu()) * torch.exp(log_prob - Variable(fixed_log_prob)) 144 | return action_loss.mean() 145 | 146 | 147 | def get_kl(): 148 | mean1, log_std1, std1 = policy_net(Variable(states.cpu())) 149 | 150 | mean0 = Variable(mean1.data) 151 | log_std0 = Variable(log_std1.data) 152 | std0 = Variable(std1.data) 153 | kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5 154 | return kl.sum(1, keepdim=True) 155 | 156 | trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping) 157 | 158 | def expert_reward(states, actions): 159 | states = np.concatenate(states) 160 | actions = np.concatenate(actions) 161 | state_action = torch.Tensor(np.concatenate([states, actions], 1)).to(device) 162 | reward = reward_net(state_action) 163 | return reward.cpu().detach().numpy() 164 | 165 | 166 | def evaluate(episode): 167 | avg_reward = 0.0 168 | avg_dist = 0.0 169 | for _ in range(args.eval_epochs): 170 | state = env.reset() 171 | unwrapped = env 172 | while hasattr(unwrapped, 'env'): 173 | unwrapped = unwrapped.env 174 | 175 | for _ in range(10000): # Don't infinite loop while learning 176 | state = torch.from_numpy(state).unsqueeze(0) 177 | action, _, _ = policy_net(Variable(state)) 178 | action = action.data[0].numpy() 179 | next_state, reward, done, _ = env.step(action) 180 | avg_reward += reward 181 | if done: 182 | 
break 183 | state = next_state 184 | avg_dist += unwrapped.sim.data.qpos[0] 185 | 186 | reward_eva_in = avg_reward / args.eval_epochs 187 | reward_eva_in_normal = max_normalization(reward_eva_in) 188 | avg_dist = avg_dist / args.eval_epochs 189 | writer.log(episode, reward_eva_in, reward_eva_in_normal, avg_dist) 190 | return reward_eva_in, reward_eva_in_normal, avg_dist 191 | 192 | 193 | def evaluate_dist(episode): 194 | for _ in range(args.eval_epochs): 195 | state = env.reset() 196 | unwrapped = env 197 | while hasattr(unwrapped, 'env'): 198 | unwrapped = unwrapped.env 199 | 200 | for _ in range(10000): 201 | state = torch.from_numpy(state).unsqueeze(0) 202 | action, _, _ = policy_net(Variable(state)) 203 | action = action.data[0].numpy() 204 | next_state, reward, done, _ = env.step(action) 205 | 206 | if done: 207 | break 208 | state = next_state 209 | 210 | writer.log_dist(episode, unwrapped.sim.data.qpos[0]) 211 | return unwrapped.sim.data.qpos[0] #Final location of the object 212 | 213 | 214 | method = 'IRL_normal' 215 | from pathlib import Path 216 | import os 217 | logdir = Path(os.path.abspath(os.path.join('IRL',str(args.env),args.method+str(args.level)))) 218 | if logdir.exists(): 219 | print('orinal logdir is already exist.') 220 | 221 | writer = Writer(args.env, args.seed, 'IRL_stage{}'.format(args.stage), args.traj_size, folder=str(logdir)) 222 | 223 | rew_rms = RunningMeanStd(shape=()) 224 | cliprew = 10. 225 | epsilon = 1e-8 226 | 227 | for i_episode in tqdm(range(args.num_epochs), dynamic_ncols=True): 228 | memory = Memory() 229 | 230 | num_steps = 0 231 | num_episodes = 0 232 | 233 | reward_batch = [] 234 | states = [] 235 | actions = [] 236 | mem_actions = [] 237 | mem_mask = [] 238 | mem_next = [] 239 | true_rewards = [] 240 | while num_steps < args.batch_size: 241 | state = env.reset() 242 | 243 | 244 | reward_sum = 0 245 | for t in range(10000): # Don't infinite loop while learning 246 | action = select_action(state) 247 | action = action.data[0].numpy() 248 | states.append(np.array([state])) 249 | actions.append(np.array([action])) 250 | 251 | next_state, true_reward, done, _ = env.step(action) 252 | # true_reward = running_reward(np.array(true_reward).reshape(1,)) 253 | reward_sum += true_reward 254 | 255 | mask = 1 256 | if done: 257 | mask = 0 258 | 259 | mem_mask.append(mask) 260 | mem_next.append(next_state) 261 | # true_rewards.append(true_reward) 262 | # env.render() 263 | if done: 264 | break 265 | 266 | state = next_state 267 | num_steps += (t-1) 268 | num_episodes += 1 269 | 270 | reward_batch.append(reward_sum) 271 | 272 | reward_eva, reward_eva_normal, dist = evaluate(i_episode) 273 | 274 | rewards = expert_reward(states, actions) 275 | 276 | rew_rms.update(rewards) 277 | rewards = np.clip(rewards / np.sqrt(rew_rms.var + epsilon), -cliprew, cliprew) 278 | 279 | for idx in range(len(states)): 280 | memory.push(states[idx][0], actions[idx], mem_mask[idx], mem_next[idx], \ 281 | rewards[idx][0]) 282 | batch = memory.sample() 283 | update_params(batch) 284 | 285 | if i_episode % args.log_interval == 0: 286 | tqdm.write('Episode {}\tAvg_reward: {:.2f}\tReward_nor: {:.2f}\tAverage distance: {:.2f}'.format(i_episode, reward_eva, reward_eva_normal, dist)) 287 | -------------------------------------------------------------------------------- /mujoco/utils.py: -------------------------------------------------------------------------------- 1 | import math, os 2 | 3 | import numpy as np 4 | 5 | import torch 6 | 7 | 8 | def normal_entropy(std): 9 | var = std.pow(2) 10 
| entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi) 11 | return entropy.sum(1, keepdim=True) 12 | 13 | 14 | def normal_log_density(x, mean, log_std, std): 15 | var = std.pow(2) 16 | log_density = -(x - mean).pow(2) / ( 17 | 2 * var) - 0.5 * math.log(2 * math.pi) - log_std 18 | return log_density.sum(1, keepdim=True) 19 | 20 | 21 | def get_flat_params_from(model): 22 | params = [] 23 | for param in model.parameters(): 24 | params.append(param.data.view(-1)) 25 | 26 | flat_params = torch.cat(params) 27 | return flat_params 28 | 29 | 30 | def set_flat_params_to(model, flat_params): 31 | prev_ind = 0 32 | for param in model.parameters(): 33 | flat_size = int(np.prod(list(param.size()))) 34 | param.data.copy_( 35 | flat_params[prev_ind:prev_ind + flat_size].view(param.size())) 36 | prev_ind += flat_size 37 | 38 | 39 | def get_flat_grad_from(net, grad_grad=False): 40 | grads = [] 41 | for param in net.parameters(): 42 | if grad_grad: 43 | grads.append(param.grad.grad.view(-1)) 44 | else: 45 | grads.append(param.grad.view(-1)) 46 | 47 | flat_grad = torch.cat(grads) 48 | return flat_grad 49 | 50 | 51 | class Writer(object): 52 | def __init__(self, env, seed, epoch, traj_size, folder='PU_log', pbound='0.0'): 53 | 54 | if pbound != '0.0': 55 | pblabel = '_{}'.format(pbound) 56 | else: 57 | pblabel = '' 58 | 59 | self.fname = '{}_{}_{}_{}{}.csv'.format(env, seed, epoch, traj_size, pblabel) 60 | self.folder = folder 61 | if not os.path.isdir(self.folder): 62 | os.makedirs(self.folder) 63 | if os.path.exists('{}/{}'.format(self.folder, self.fname)): 64 | print('Overwrite {}/{}!'.format(self.folder, self.fname)) 65 | os.remove('{}/{}'.format(self.folder, self.fname)) 66 | 67 | def log_dist(self, epoch, dist): 68 | with open(self.folder + '/' + self.fname, 'a') as f: 69 | f.write('{},{}\n'.format(epoch, dist)) 70 | 71 | def log(self, epoch, reward, nor_reward, dist): 72 | with open(self.folder + '/' + self.fname, 'a') as f: 73 | f.write('{},{},{},{}\n'.format(epoch, reward, nor_reward, dist)) 74 | 75 | def digitize(arr, unit): 76 | if unit < 1e-6: 77 | return arr 78 | return np.round(arr / unit) * unit 79 | 80 | def save_model(model, name, folder): 81 | if not os.path.isdir(folder): 82 | os.makedirs(folder) 83 | torch.save(model.state_dict(), folder + name) 84 | -------------------------------------------------------------------------------- /mujoco/wgail.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from itertools import count 3 | 4 | import gym 5 | import gym.spaces 6 | import scipy.optimize 7 | import numpy as np 8 | import math 9 | from tqdm import tqdm 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from models import * 15 | from replay_memory import Memory 16 | from running_state import ZFilter 17 | from torch.autograd import Variable 18 | from trpo import trpo_step 19 | from utils import * 20 | from loss import * 21 | 22 | torch.utils.backcompat.broadcast_warning.enabled = True 23 | torch.utils.backcompat.keepdim_warning.enabled = True 24 | 25 | torch.set_default_tensor_type('torch.DoubleTensor') 26 | device = torch.device("cpu") 27 | parser = argparse.ArgumentParser(description='PyTorch actor-critic example') 28 | parser.add_argument('--gamma', type=float, default=0.995, metavar='G', 29 | help='discount factor (default: 0.995)') 30 | parser.add_argument('--env', type=str, default="Ant-v2", metavar='G', 31 | help='name of the environment to run') 32 | 
parser.add_argument('--tau', type=float, default=0.97, metavar='G', 33 | help='GAE parameter (default: 0.97)') 34 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G', 35 | help='l2 regularization regression (default: 1e-3)') 36 | parser.add_argument('--max-kl', type=float, default=1e-2, metavar='G', 37 | help='max kl value (default: 1e-2)') 38 | parser.add_argument('--damping', type=float, default=1e-1, metavar='G', 39 | help='damping (default: 1e-1)') 40 | parser.add_argument('--seed', type=int, default=1111, metavar='N', 41 | help='random seed (default: 1111)') 42 | parser.add_argument('--batch-size', type=int, default=5000, metavar='N', 43 | help='size of a single batch') 44 | parser.add_argument('--log-interval', type=int, default=1, metavar='N', 45 | help='interval between training status logs (default: 1)') 46 | parser.add_argument('--fname', type=str, default='expert', metavar='F', 47 | help='the file name to save trajectory') 48 | parser.add_argument('--num-epochs', type=int, default=5000, metavar='N', 49 | help='number of epochs to train an expert') 50 | parser.add_argument('--hidden-dim', type=int, default=100, metavar='H', 51 | help='the size of hidden layers') 52 | parser.add_argument('--lr', type=float, default=1e-3, metavar='L', 53 | help='learning rate') 54 | parser.add_argument('--vf-iters', type=int, default=30, metavar='V', 55 | help='number of value function optimization iterations per policy optimization step') 56 | parser.add_argument('--vf-lr', type=float, default=3e-4, metavar='V', 57 | help='learning rate of value network') 58 | parser.add_argument('--eval-epochs', type=int, default=3, metavar='E', 59 | help='epochs to evaluate model') 60 | parser.add_argument('--traj-size', type=int, default=1000) 61 | parser.add_argument('--ifolder', type=str, default='demonstrations') 62 | parser.add_argument('--optimal-policy', type=float, default=4145.89) 63 | parser.add_argument('--random-policy', type=float, default=992.18) 64 | parser.add_argument('--stage', type=int, default=1) 65 | parser.add_argument('--beta', type=int, default=1) 66 | parser.add_argument('--early-stop', type=int, default=400) 67 | 68 | args = parser.parse_args() 69 | env = gym.make(args.env) 70 | 71 | num_inputs = env.observation_space.shape[0] 72 | num_actions = env.action_space.shape[0] 73 | 74 | env.seed(args.seed) 75 | torch.manual_seed(args.seed) 76 | np.random.seed(args.seed) 77 | 78 | policy_net = Policy(num_inputs, num_actions, args.hidden_dim) 79 | value_net = Value(num_inputs, args.hidden_dim).to(device) 80 | discriminator = Discriminator(num_inputs + num_actions, args.hidden_dim).to(device) 81 | disc_criterion = nn.BCEWithLogitsLoss() 82 | value_criterion = nn.MSELoss() 83 | disc_optimizer = optim.Adam(discriminator.parameters(), args.lr) 84 | value_optimizer = optim.Adam(value_net.parameters(), args.vf_lr) 85 | 86 | 87 | def max_normalization(x): 88 | x = (x - args.random_policy) / (args.optimal_policy - args.random_policy) 89 | return x 90 | 91 | 92 | def select_action(state): 93 | state = torch.from_numpy(state).unsqueeze(0) 94 | # action_mean, _, action_std = policy_net(Variable(state)) 95 | action_mean, _, action_std = policy_net(Variable(state)) 96 | action = torch.normal(action_mean, action_std) 97 | return action 98 | 99 | 100 | def update_params(batch): 101 | rewards = torch.Tensor(batch.reward).to(device) 102 | masks = torch.Tensor(batch.mask).to(device) 103 | actions = torch.Tensor(np.concatenate(batch.action, 0)).to(device) 104 | states = 
torch.Tensor(batch.state).to(device) 105 | values = value_net(Variable(states)) 106 | 107 | returns = torch.Tensor(actions.size(0), 1).to(device) 108 | deltas = torch.Tensor(actions.size(0), 1).to(device) 109 | advantages = torch.Tensor(actions.size(0), 1).to(device) 110 | 111 | prev_return = 0 112 | prev_value = 0 113 | prev_advantage = 0 114 | for i in reversed(range(rewards.size(0))): 115 | returns[i] = rewards[i] + args.gamma * prev_return * masks[i] 116 | deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i] 117 | advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i] 118 | 119 | prev_return = returns[i, 0] 120 | prev_value = values.data[i, 0] 121 | prev_advantage = advantages[i, 0] 122 | 123 | targets = Variable(returns) 124 | 125 | batch_size = math.ceil(states.shape[0] / args.vf_iters) 126 | idx = np.random.permutation(states.shape[0]) 127 | for i in range(args.vf_iters): 128 | smp_idx = idx[i * batch_size: (i + 1) * batch_size] 129 | smp_states = states[smp_idx, :] 130 | smp_targets = targets[smp_idx, :] 131 | 132 | value_optimizer.zero_grad() 133 | value_loss = value_criterion(value_net(Variable(smp_states)), smp_targets) 134 | value_loss.backward() 135 | value_optimizer.step() 136 | 137 | advantages = (advantages - advantages.mean()) / advantages.std() 138 | 139 | action_means, action_log_stds, action_stds = policy_net(Variable(states.cpu())) 140 | fixed_log_prob = normal_log_density(Variable(actions.cpu()), action_means, action_log_stds, 141 | action_stds).data.clone() 142 | 143 | def get_loss(): 144 | action_means, action_log_stds, action_stds = policy_net(Variable(states.cpu())) 145 | log_prob = normal_log_density(Variable(actions.cpu()), action_means, action_log_stds, action_stds) 146 | action_loss = -Variable(advantages.cpu()) * torch.exp(log_prob - Variable(fixed_log_prob)) 147 | return action_loss.mean() 148 | 149 | def get_kl(): 150 | mean1, log_std1, std1 = policy_net(Variable(states.cpu())) 151 | 152 | mean0 = Variable(mean1.data) 153 | log_std0 = Variable(log_std1.data) 154 | std0 = Variable(std1.data) 155 | kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5 156 | return kl.sum(1, keepdim=True) 157 | 158 | trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping) 159 | 160 | 161 | def expert_reward(states, actions): 162 | states = np.concatenate(states) 163 | actions = np.concatenate(actions) 164 | state_action = torch.Tensor(np.concatenate([states, actions], 1)).to(device) 165 | return -F.logsigmoid(discriminator(state_action)).cpu().detach().numpy() 166 | 167 | 168 | def evaluate(episode): 169 | avg_reward = 0.0 170 | avg_dist = 0.0 171 | for _ in range(args.eval_epochs): 172 | state = env.reset() 173 | unwrapped = env 174 | while hasattr(unwrapped, 'env'): 175 | unwrapped = unwrapped.env 176 | 177 | for _ in range(10000): # Don't infinite loop while learning 178 | state = torch.from_numpy(state).unsqueeze(0) 179 | action, _, _ = policy_net(Variable(state)) 180 | action = action.data[0].numpy() 181 | next_state, reward, done, _ = env.step(action) 182 | avg_reward += reward 183 | if done: 184 | break 185 | state = next_state 186 | avg_dist += unwrapped.sim.data.qpos[0] 187 | 188 | reward_eva_in = avg_reward / args.eval_epochs 189 | reward_eva_in_normal = max_normalization(reward_eva_in) 190 | avg_dist = avg_dist / args.eval_epochs 191 | writer.log(episode, reward_eva_in, reward_eva_in_normal, avg_dist) 192 | return reward_eva_in, reward_eva_in_normal, avg_dist 193 | 194 
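# The block below is an illustrative, self-contained sketch (it is not called anywhere in
# this script, and the helper name _toy_gae is hypothetical): it re-implements the
# discounted-return / GAE-style advantage recursion from update_params() above on a toy
# rollout with plain Python floats, so the backward recursion and the role of the episode
# mask are easy to check by hand.
def _toy_gae(rewards, values, masks, gamma=0.995, tau=0.97):
    returns = [0.0] * len(rewards)
    advantages = [0.0] * len(rewards)
    prev_return = prev_value = prev_advantage = 0.0
    for i in reversed(range(len(rewards))):
        # Bootstrapping stops at episode boundaries, where masks[i] == 0.
        returns[i] = rewards[i] + gamma * prev_return * masks[i]
        delta = rewards[i] + gamma * prev_value * masks[i] - values[i]
        advantages[i] = delta + gamma * tau * prev_advantage * masks[i]
        prev_return, prev_value, prev_advantage = returns[i], values[i], advantages[i]
    return returns, advantages

# Example: three steps with reward 1.0, value estimate 0.5, terminal on the last step.
# _toy_gae([1.0, 1.0, 1.0], [0.5, 0.5, 0.5], [1, 1, 0]) gives
#   returns    ~ [2.985, 1.995, 1.0]
#   advantages ~ [2.426, 1.480, 0.5]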
| 195 | try: 196 | expert_traj = np.load("./{}/{}_stage{}.npy".format(args.ifolder, args.env, args.stage)) 197 | except: 198 | print('Mixture demonstrations could not be loaded.') 199 | raise 200 | 201 | idx = np.random.choice(expert_traj.shape[0], args.traj_size, replace=False) 202 | expert_traj = expert_traj[idx, :] 203 | expert_conf = torch.ones((expert_traj.shape[0], 1)) 204 | expert_traj = torch.Tensor(expert_traj) 205 | 206 | method = 'wgail_{}_stage{}_{}'.format(args.traj_size, args.stage, args.beta) 207 | from pathlib import Path 208 | import os 209 | 210 | logdir = Path(os.path.abspath(os.path.join('wgail', str(args.env), method, str(args.seed)))) 211 | if logdir.exists(): 212 | print('original logdir already exists.') 213 | 214 | writer = Writer(args.env, args.seed, 'stage{}_{}'.format(args.stage, args.beta), args.traj_size, folder=str(logdir)) 215 | 216 | for i_episode in tqdm(range(args.num_epochs), dynamic_ncols=True): 217 | memory = Memory() 218 | 219 | num_steps = 0 220 | num_episodes = 0 221 | 222 | reward_batch = [] 223 | states = [] 224 | actions = [] 225 | mem_actions = [] 226 | mem_mask = [] 227 | mem_next = [] 228 | 229 | while num_steps < args.batch_size: 230 | state = env.reset() 231 | 232 | reward_sum = 0 233 | for t in range(10000): # Don't infinite loop while learning 234 | action = select_action(state) 235 | action = action.data[0].numpy() 236 | states.append(np.array([state])) 237 | actions.append(np.array([action])) 238 | next_state, true_reward, done, _ = env.step(action) 239 | reward_sum += true_reward 240 | 241 | mask = 1 242 | if done: 243 | mask = 0 244 | 245 | mem_mask.append(mask) 246 | mem_next.append(next_state) 247 | 248 | # env.render() 249 | if done: 250 | break 251 | 252 | state = next_state 253 | num_steps += (t - 1) 254 | num_episodes += 1 255 | 256 | reward_batch.append(reward_sum) 257 | 258 | # 2. 
evaluate distance 259 | reward_eva, reward_eva_normal, dist = evaluate(i_episode) 260 | 261 | rewards = expert_reward(states, actions) 262 | 263 | for idx in range(len(states)): 264 | memory.push(states[idx][0], actions[idx], mem_mask[idx], mem_next[idx], \ 265 | rewards[idx][0]) 266 | batch = memory.sample() 267 | update_params(batch) 268 | 269 | if i_episode == 0: 270 | expert_conf = torch.ones((expert_traj.shape[0], 1)) 271 | prob = torch.ones_like(expert_conf) 272 | ac_mean = torch.zeros_like(prob) 273 | ac_std = torch.zeros_like(ac_mean) 274 | 275 | if (i_episode+1) % 50 == 0 and i_episode <= args.early_stop: 276 | with torch.no_grad(): 277 | ac_mean, _, ac_std = policy_net(expert_traj[:, :num_inputs]) 278 | ac = expert_traj[:, num_inputs:] 279 | ac_var = ac_std ** 2 280 | lg_prob = -((ac-ac_mean)**2) / (2*ac_var) - torch.log(ac_std) - math.log(math.sqrt(2*math.pi)) 281 | prob = torch.exp(lg_prob.sum(-1, keepdim=True)) 282 | 283 | expert_conf = ((1 / torch.sigmoid(discriminator(expert_traj)) - 1) * prob).pow(1 / (args.beta+1)) 284 | 285 | actions = torch.from_numpy(np.concatenate(actions)) 286 | states = torch.from_numpy(np.concatenate(states)) 287 | 288 | idx = np.random.randint(0, expert_traj.shape[0], num_steps) 289 | expert_state_action = expert_traj[idx, :] 290 | expert_conf_batch = expert_conf[idx, :] 291 | 292 | state_action = torch.cat((states, actions), 1).to(device) 293 | 294 | fake = discriminator(state_action) 295 | real = discriminator(expert_state_action) 296 | 297 | disc_optimizer.zero_grad() 298 | weighted_loss = nn.BCEWithLogitsLoss(weight=expert_conf_batch.detach()) 299 | disc_loss = disc_criterion(fake, torch.ones(states.shape[0], 1).to(device)) + \ 300 | weighted_loss(real, torch.zeros(real.shape[0], 1).to(device)) 301 | 302 | disc_loss.backward() 303 | disc_optimizer.step() 304 | 305 | if i_episode % args.log_interval == 0: 306 | tqdm.write( 307 | 'Episode {}\tReward: {:.2f}\tReward_nor: {:.2f}\tAverage distance: {:.2f}'.format(i_episode, reward_eva, 308 | reward_eva_normal, dist)) 309 | --------------------------------------------------------------------------------
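A minimal, self-contained sketch of the confidence-weighting rule used in `wgail.py` above, i.e. `conf = ((1 / sigmoid(D(s, a)) - 1) * pi(a | s)) ** (1 / (beta + 1))`. The single `nn.Linear` standing in for the repository's `Discriminator` and the unit-variance Gaussian standing in for the policy are assumptions made only to keep the example runnable on its own; they are not part of the original code.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
obs_dim, act_dim, beta = 3, 2, 1

demo = torch.randn(5, obs_dim + act_dim)      # five synthetic (state, action) rows
disc = nn.Linear(obs_dim + act_dim, 1)        # stand-in for Discriminator(...): outputs a logit

with torch.no_grad():
    ac = demo[:, obs_dim:]
    ac_mean = torch.zeros(5, act_dim)         # stand-in for the policy mean at these states
    ac_std = torch.ones(5, act_dim)           # stand-in for the policy std
    lg_prob = -((ac - ac_mean) ** 2) / (2 * ac_std ** 2) \
              - torch.log(ac_std) - math.log(math.sqrt(2 * math.pi))
    prob = torch.exp(lg_prob.sum(-1, keepdim=True))
    conf = ((1 / torch.sigmoid(disc(demo)) - 1) * prob).pow(1 / (beta + 1))

print(conf.shape)  # torch.Size([5, 1]): one non-negative weight per demonstration row
```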