├── README.md
├── dril
│   ├── .gitignore
│   ├── LICENSE
│   ├── a2c_ppo_acktr
│   │   ├── __init__.py
│   │   ├── algo
│   │   │   ├── __init__.py
│   │   │   ├── a2c_acktr.py
│   │   │   ├── behavior_cloning.py
│   │   │   ├── dril.py
│   │   │   ├── ensemble.py
│   │   │   ├── gail.py
│   │   │   ├── kfac.py
│   │   │   └── ppo.py
│   │   ├── arguments.py
│   │   ├── distributions.py
│   │   ├── duckietown
│   │   │   ├── env.py
│   │   │   ├── teacher.py
│   │   │   └── wrappers.py
│   │   ├── ensemble_models.py
│   │   ├── envs.py
│   │   ├── expert_dataset.py
│   │   ├── model.py
│   │   ├── retro
│   │   │   ├── .gitignore
│   │   │   ├── README.md
│   │   │   ├── pygame_controller.py
│   │   │   ├── retro_interactive.py
│   │   │   └── retro_joystick.py
│   │   ├── stable_baselines
│   │   │   ├── base_vec_env.py
│   │   │   └── running_mean_std.py
│   │   ├── storage.py
│   │   └── utils.py
│   ├── enjoy.py
│   ├── evaluation.py
│   ├── generate_demonstration_data.py
│   └── main.py
├── plot.py
├── pngs
│   ├── atari.png
│   └── continous_control.png
└── setup.py
/README.md:
--------------------------------------------------------------------------------
1 | **Due to a normalization bug, the expert trajectories have lower performance than the experts reported in [rl_baseline_zoo](). Please see the following link to the place in the codebase where the bug was fixed.** [[link]()]
2 | 
3 | # Disagreement-Regularized Imitation Learning
4 | 
5 | Code to train the models described in the paper ["Disagreement-Regularized Imitation Learning"](), by Kianté Brantley, Wen Sun and Mikael Henaff.
6 | 
7 | ## Usage:
8 | 
9 | ### Install using pip
10 | Install the DRIL package
11 | 
12 | ```
13 | pip install -e .
14 | ```
15 | 
16 | ### Software Dependencies
17 | ["stable-baselines"](), ["rl-baselines-zoo"](), ["baselines"](), ["gym"](), ["pytorch"](), ["pybullet"]()
18 | 
19 | ### Data
20 | 
21 | We provide a Python script to generate expert data from pre-trained models using the ["rl-baselines-zoo"]() repository. Click ["Here"]() to see all of the pre-trained agents available and their respective performance. Replace ```<env-name>``` with the name of the pre-trained agent environment you would like to collect expert data for.
22 | 
23 | ```
24 | python -u generate_demonstration_data.py --seed <seed> --env-name <env-name> --rl_baseline_zoo_dir <rl_baseline_zoo_dir>
25 | ```
26 | 
27 | ### Training
28 | DRIL requires a pre-trained ensemble model and a pre-trained behavior-cloning model.
29 | 
30 | **Note that ```<rl_baseline_zoo_dir>``` is the full path to the top-level directory of the rl_baseline_zoo repository.**
31 | 
32 | To train **only** a behavior-cloning model run:
33 | ```
34 | python -u main.py --env-name <env-name> --num-trajs <num-trajs> --behavior_cloning --rl_baseline_zoo_dir <rl_baseline_zoo_dir> --seed <seed>
35 | ```
36 | 
37 | To train **only** an ensemble model run:
38 | ```
39 | python -u main.py --env-name <env-name> --num-trajs <num-trajs> --pretrain_ensemble_only --rl_baseline_zoo_dir <rl_baseline_zoo_dir> --seed <seed>
40 | ```
41 | 
42 | To train a **DRIL** model, run the command below. The command first checks that both the behavior-cloning model and the ensemble model have been trained; if they have not, the script will automatically train both the **ensemble** and the **behavior-cloning** model.
43 | 
44 | ```
45 | python -u main.py --env-name <env-name> --default_experiment_params <atari|continous-control> --num-trajs <num-trajs> --rl_baseline_zoo_dir <rl_baseline_zoo_dir> --seed <seed> --dril
46 | ```
47 | 
48 | ```--default_experiment_params``` sets the default parameters we use in the **DRIL** experiments and has two options: ```atari``` and ```continous-control```.
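As a concrete end-to-end sketch (not a run from the paper: the environment name, seed, trajectory count, and zoo path below are placeholders, assuming the rl-baselines-zoo `PongNoFrameskip-v4` agent is available), the data-generation and DRIL-training steps above would look like:

```
# Collect expert demonstrations from the pre-trained rl-baselines-zoo agent
python -u generate_demonstration_data.py --seed 0 --env-name PongNoFrameskip-v4 --rl_baseline_zoo_dir /path/to/rl-baselines-zoo
# Train DRIL on one demonstration trajectory with the Atari defaults
# (the ensemble and behavior-cloning models are trained first if they do not already exist)
python -u main.py --env-name PongNoFrameskip-v4 --default_experiment_params atari --num-trajs 1 --rl_baseline_zoo_dir /path/to/rl-baselines-zoo --seed 0 --dril
```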
49 | 
50 | ### Visualization
51 | After training the models, the results are stored in a folder called ```trained_results```. Run the command below to reproduce the plots in our paper. If you change any of the hyperparameters, you will also need to change them in the plot file naming convention.
52 | ```
53 | python -u plot.py -env <env-name>
54 | ```
55 | 
56 | ## Empirical evaluation
57 | ### Atari
58 | Results on Atari environments.
59 | ![Empirical evaluation](pngs/atari.png)
60 | 
61 | ### Continuous Control
62 | Results on continuous control tasks.
63 | ![Empirical evaluation](pngs/continous_control.png)
64 | 
65 | ## Acknowledgement:
66 | We would like to thank Ilya Kostrikov for creating this ["repo"]() that our codebase builds on.
67 | 
--------------------------------------------------------------------------------
/dril/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | trained_models/
3 | trained_results/
4 | tmp/
5 | 
--------------------------------------------------------------------------------
/dril/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Ilya Kostrikov
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xkianteb/dril/57eac5c3a5b0f4766821a0bedff043471f91e4f1/dril/a2c_ppo_acktr/__init__.py -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/algo/__init__.py: -------------------------------------------------------------------------------- 1 | from .a2c_acktr import A2C_ACKTR 2 | from .ppo import PPO -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/algo/a2c_acktr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | from dril.a2c_ppo_acktr.algo.behavior_cloning import BehaviorCloning 6 | 7 | class A2C_ACKTR(): 8 | def __init__(self, 9 | actor_critic, 10 | value_loss_coef, 11 | entropy_coef, 12 | lr=None, 13 | eps=None, 14 | alpha=None, 15 | max_grad_norm=None, 16 | acktr=False, 17 | dril=None): 18 | 19 | self.actor_critic = actor_critic 20 | self.acktr = acktr 21 | 22 | self.value_loss_coef = value_loss_coef 23 | self.entropy_coef = entropy_coef 24 | 25 | self.max_grad_norm = max_grad_norm 26 | 27 | if acktr: 28 | self.optimizer = KFACOptimizer(actor_critic) 29 | else: 30 | self.optimizer = optim.RMSprop( 31 | actor_critic.parameters(), lr, eps=eps, alpha=alpha) 32 | 33 | self.dril = dril 34 | 35 | def update(self, rollouts): 36 | obs_shape = rollouts.obs.size()[2:] 37 | action_shape = rollouts.actions.size()[-1] 38 | num_steps, num_processes, _ = rollouts.rewards.size() 39 | 40 | values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions( 41 | rollouts.obs[:-1].view(-1, *obs_shape), 42 | rollouts.recurrent_hidden_states[0].view( 43 | -1, self.actor_critic.recurrent_hidden_state_size), 44 | rollouts.masks[:-1].view(-1, 1), 45 | rollouts.actions.view(-1, action_shape)) 46 | 47 | values = values.view(num_steps, num_processes, 1) 48 | action_log_probs = action_log_probs.view(num_steps, num_processes, 1) 49 | 50 | advantages = rollouts.returns[:-1] - values 51 | value_loss = advantages.pow(2).mean() 52 | 53 | action_loss = -(advantages.detach() * action_log_probs).mean() 54 | 55 | if self.acktr and self.optimizer.steps % self.optimizer.Ts == 0: 56 | # Compute fisher, see Martens 2014 57 | self.actor_critic.zero_grad() 58 | pg_fisher_loss = -action_log_probs.mean() 59 | 60 | value_noise = torch.randn(values.size()) 61 | if values.is_cuda: 62 | value_noise = value_noise.cuda() 63 | 64 | sample_values = values + value_noise 65 | vf_fisher_loss = -(values - sample_values.detach()).pow(2).mean() 66 | 67 | fisher_loss = pg_fisher_loss + vf_fisher_loss 68 | self.optimizer.acc_stats = True 69 | fisher_loss.backward(retain_graph=True) 70 | self.optimizer.acc_stats = False 71 | 72 | self.optimizer.zero_grad() 73 | (value_loss * self.value_loss_coef + action_loss - 74 | dist_entropy * self.entropy_coef).backward() 75 | 76 | if self.acktr == False: 77 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 78 | self.max_grad_norm) 79 | 80 | self.optimizer.step() 81 | 82 | if self.dril: 83 | self.dril.bc_update() 84 | 85 | return value_loss.item(), action_loss.item(), dist_entropy.item() 86 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/algo/behavior_cloning.py: 
-------------------------------------------------------------------------------- 1 | # prerequisites 2 | import copy 3 | import glob 4 | import sys 5 | import os 6 | import time 7 | from collections import deque 8 | 9 | import gym 10 | 11 | import numpy as np 12 | import copy 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader, TensorDataset 17 | 18 | class BehaviorCloning: 19 | def __init__(self, policy, device, batch_size=None, lr=None, expert_dataset=None, 20 | num_batches=np.float('inf'), training_data_split=None, envs=None, ensemble_size=None): 21 | super(BehaviorCloning, self).__init__() 22 | 23 | self.actor_critic = policy 24 | 25 | self.optimizer = torch.optim.Adam(self.actor_critic.parameters(), lr=lr) 26 | self.device = device 27 | self.lr = lr 28 | self.batch_size = batch_size 29 | 30 | datasets = expert_dataset.load_demo_data(training_data_split, batch_size, ensemble_size) 31 | self.trdata = datasets['trdata'] 32 | self.tedata = datasets['tedata'] 33 | 34 | self.num_batches = num_batches 35 | self.action_space = envs.action_space 36 | 37 | def update(self, update=True, data_loader_type=None): 38 | if data_loader_type == 'train': 39 | data_loader = self.trdata 40 | elif data_loader_type == 'test': 41 | data_loader = self.tedata 42 | else: 43 | raise Exception("Unknown Data loader specified") 44 | 45 | total_loss = 0 46 | for batch_idx, batch in enumerate(data_loader, 1): 47 | self.optimizer.zero_grad() 48 | (states, actions) = batch 49 | expert_states = states.float().to(self.device) 50 | expert_actions = actions.float().to(self.device) 51 | 52 | dynamic_batch_size = expert_states.shape[0] 53 | try: 54 | # Regular Behavior Cloning 55 | pred_actions = self.actor_critic.get_action(expert_states).view(dynamic_batch_size, -1) 56 | except AttributeError: 57 | # Ensemble Behavior Cloning 58 | pred_actions = self.actor_critic(expert_states).view(dynamic_batch_size, -1) 59 | 60 | if isinstance(self.action_space, gym.spaces.Box): 61 | pred_actions = torch.clamp(pred_actions, self.action_space.low[0],self.action_space.high[0]) 62 | expert_actions = torch.clamp(expert_actions.float(), self.action_space.low[0],self.action_space.high[0]) 63 | loss = F.mse_loss(pred_actions, expert_actions) 64 | elif isinstance(self.action_space, gym.spaces.discrete.Discrete): 65 | loss = F.cross_entropy(pred_actions, expert_actions.flatten().long()) 66 | elif self.action_space.__class__.__name__ == "MultiBinary": 67 | loss = torch.binary_cross_entropy_with_logits(pred_actions, expert_actions).mean() 68 | 69 | if update: 70 | loss.backward() 71 | self.optimizer.step() 72 | 73 | total_loss += loss.item() 74 | 75 | if batch_idx >= self.num_batches: 76 | break 77 | 78 | return (total_loss / batch_idx) 79 | 80 | def reset(self): 81 | self.optimizer = torch.optim.Adam(self.actor_critic.parameters(), lr=self.lr) 82 | 83 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/algo/dril.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import gym 5 | import pandas as pd 6 | 7 | import dril.a2c_ppo_acktr.ensemble_models as ensemble_models 8 | from baselines.common.running_mean_std import RunningMeanStd 9 | from collections import defaultdict 10 | 11 | from torch.utils.data import DataLoader, TensorDataset 12 | 13 | # This file creates the reward function used by dril. 
Both reinforcement algorithms 14 | # ppo (line: 102) and a2c (line: 92), have dril bc udpates. 15 | 16 | class DRIL: 17 | def __init__(self, device=None, envs=None, ensemble_policy=None, env_name=None, 18 | expert_dataset=None, ensemble_size=None, ensemble_quantile_threshold=None, 19 | dril_bc_model=None, dril_cost_clip=None, num_dril_bc_train_epoch=None,\ 20 | training_data_split=None): 21 | 22 | self.ensemble_quantile_threshold = ensemble_quantile_threshold 23 | self.dril_cost_clip = dril_cost_clip 24 | self.device = device 25 | self.num_dril_bc_train_epoch = num_dril_bc_train_epoch 26 | self.env_name = env_name 27 | self.returns = None 28 | self.ret_rms = RunningMeanStd(shape=()) 29 | self.observation_space = envs.observation_space 30 | 31 | if envs.action_space.__class__.__name__ == "Discrete": 32 | self.num_actions = envs.action_space.n 33 | elif envs.action_space.__class__.__name__ == "Box": 34 | self.num_actions = envs.action_space.shape[0] 35 | elif envs.action_space.__class__.__name__ == "MultiBinary": 36 | self.num_actions = envs.action_space.shape[0] 37 | 38 | self.ensemble_size = ensemble_size 39 | # use full data since we don't use a validation set 40 | self.trdata = expert_dataset.load_demo_data(1.0, 1, self.ensemble_size)['trdata'] 41 | 42 | self.ensemble = ensemble_policy 43 | self.bc = dril_bc_model 44 | self.bc.num_batches = num_dril_bc_train_epoch 45 | self.clip_variance = self.policy_variance(envs=envs) 46 | 47 | def policy_variance(self, q=0.98, envs=None): 48 | q = self.ensemble_quantile_threshold 49 | obs = None 50 | acs = None 51 | 52 | variance = defaultdict(lambda:[]) 53 | for batch_idx, batch in enumerate(self.trdata): 54 | (state, action) = batch 55 | action = action.float().to(self.device) 56 | 57 | # Image observation 58 | if len(self.observation_space.shape) == 3: 59 | state = state.repeat(self.ensemble_size, 1,1,1).float().to(self.device) 60 | # Feature observations 61 | else: 62 | state = state.repeat(self.ensemble_size, 1).float().to(self.device) 63 | 64 | if isinstance(envs.action_space, gym.spaces.discrete.Discrete): 65 | # Note: this is just a place holder 66 | action_idx = int(action.item()) 67 | one_hot_action = torch.FloatTensor(np.eye(self.num_actions)[int(action.item())]) 68 | action = one_hot_action 69 | elif envs.action_space.__class__.__name__ == "MultiBinary": 70 | # create unique id for each combination 71 | action_idx = int("".join(str(int(x)) for x in action[0].tolist()), 2) 72 | else: 73 | action_idx = 0 74 | 75 | with torch.no_grad(): 76 | ensemble_action = self.ensemble(state).squeeze() 77 | if isinstance(envs.action_space, gym.spaces.Box): 78 | action = torch.clamp(action, envs.action_space.low[0], envs.action_space.high[0]) 79 | 80 | ensemble_action = torch.clamp(ensemble_action, envs.action_space.low[0],\ 81 | envs. 
action_space.high[0]) 82 | 83 | cov = np.cov(ensemble_action.T.cpu().numpy()) 84 | action = action.cpu().numpy() 85 | 86 | # If the env has only one action then we need to reshape cov 87 | if envs.action_space.__class__.__name__ == "Box": 88 | if envs.action_space.shape[0] == 1: 89 | cov = cov.reshape(-1,1) 90 | 91 | #variance.append(np.matmul(np.matmul(action, cov), action.T).item()) 92 | if isinstance(envs.action_space, gym.spaces.discrete.Discrete): 93 | for action_idx in range(envs.action_space.n): 94 | one_hot_action = torch.FloatTensor(np.eye(self.num_actions)[action_idx]) 95 | variance[action_idx].append(np.matmul(np.matmul(one_hot_action, cov), one_hot_action.T).item()) 96 | else: 97 | variance[action_idx].append(np.matmul(np.matmul(action, cov), action.T).item()) 98 | 99 | 100 | quantiles = {key: np.quantile(np.array(variance[key]), q) for key in list(variance.keys())} 101 | if self.dril_cost_clip == '-1_to_1': 102 | return {key: lambda x: -1 if x > quantiles[key] else 1 for key in list(variance.keys())} 103 | elif self.dril_cost_clip == 'no_clipping': 104 | return {key: lambda x: x for i in list(variance.keys())} 105 | elif self.dril_cost_clip == '-1_to_0': 106 | return {key: lambda x: -1 if x > quantiles[key] else 0 for key in list(variance.keys())} 107 | 108 | def predict_reward(self, actions, states, envs): 109 | rewards = [] 110 | for idx in range(actions.shape[0]): 111 | 112 | # Image observation 113 | if len(self.observation_space.shape) == 3: 114 | state = states[[idx]].repeat(self.ensemble_size, 1,1,1).float().to(self.device) 115 | # Feature observations 116 | else: 117 | state = states[[idx]].repeat(self.ensemble_size, 1).float().to(self.device) 118 | 119 | if isinstance(envs.action_space, gym.spaces.discrete.Discrete): 120 | one_hot_action = torch.FloatTensor(np.eye(self.num_actions)[int(actions[idx].item())]) 121 | action = one_hot_action 122 | action_idx = int(actions[idx].item()) 123 | elif isinstance(envs.action_space, gym.spaces.Box): 124 | action = actions[[idx]] 125 | action_idx = 0 126 | elif isinstance(envs.action_space, gym.spaces.MultiBinary): 127 | raise Exception('Envrionment shouldnt be MultiBinary') 128 | else: 129 | raise Exception("Unknown Action Space") 130 | 131 | with torch.no_grad(): 132 | ensemble_action = self.ensemble(state).squeeze().detach() 133 | 134 | if isinstance(envs.action_space, gym.spaces.Box): 135 | action = torch.clamp(action, envs.action_space.low[0], envs.action_space.high[0]) 136 | ensemble_action = torch.clamp(ensemble_action, envs.action_space.low[0],\ 137 | envs. 
action_space.high[0]) 138 | 139 | cov = np.cov(ensemble_action.T.cpu().numpy()) 140 | action = action.cpu().numpy() 141 | 142 | # If the env has only one action then we need to reshape cov 143 | if envs.action_space.__class__.__name__ == "Box": 144 | if envs.action_space.shape[0] == 1: 145 | cov = cov.reshape(-1,1) 146 | 147 | ensemble_variance = (np.matmul(np.matmul(action, cov), action.T).item()) 148 | 149 | if action_idx in self.clip_variance: 150 | reward = self.clip_variance[action_idx](ensemble_variance) 151 | else: 152 | reward = -1 153 | rewards.append(reward) 154 | return torch.FloatTensor(np.array(rewards)[np.newaxis].T) 155 | 156 | def normalize_reward(self, state, action, gamma, masks, reward, update_rms=True): 157 | if self.returns is None: 158 | self.returns = reward.clone() 159 | 160 | if update_rms: 161 | self.returns = self.returns * masks * gamma + reward 162 | self.ret_rms.update(self.returns.cpu().numpy()) 163 | 164 | return reward / np.sqrt(self.ret_rms.var[0] + 1e-8) 165 | 166 | def bc_update(self): 167 | for dril_epoch in range(self.num_dril_bc_train_epoch): 168 | dril_train_loss = self.bc.update(update=True, data_loader_type='train') 169 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/algo/ensemble.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import gym 5 | import pandas as pd 6 | import copy 7 | 8 | from dril.a2c_ppo_acktr.algo.behavior_cloning import BehaviorCloning 9 | import dril.a2c_ppo_acktr.ensemble_models as ensemble_models 10 | from baselines.common.running_mean_std import RunningMeanStd 11 | 12 | from torch.utils.data import DataLoader, TensorDataset 13 | 14 | def Ensemble (uncertainty_reward=None, device=None, envs=None,\ 15 | ensemble_hidden_size=None, ensemble_drop_rate=None, ensemble_size=None, ensemble_lr=None,\ 16 | ensemble_batch_size=None, env_name=None, expert_dataset=None,num_trajs=None, seed=None,\ 17 | num_ensemble_train_epoch=None,training_data_split=None, save_model_dir=None, save_results_dir=None): 18 | 19 | ensemble_size = ensemble_size 20 | device = device 21 | env_name = env_name 22 | observation_space = envs.observation_space 23 | 24 | num_inputs = envs.observation_space.shape[0] 25 | try: 26 | num_actions = envs.action_space.n 27 | except: 28 | num_actions = envs.action_space.shape[0] 29 | 30 | ensemble_args = (num_inputs, num_actions, ensemble_hidden_size, ensemble_size) 31 | if len(observation_space.shape) == 3: 32 | if env_name in ['duckietown']: 33 | ensemble_policy = ensemble_models.PolicyEnsembleDuckieTownCNN 34 | elif uncertainty_reward == 'ensemble': 35 | ensemble_policy = ensemble_models.PolicyEnsembleCNN 36 | elif uncertainty_reward == 'dropout': 37 | ensemble_policy = ensemble_models.PolicyEnsembleCNNDropout 38 | else: 39 | raise Exception("Unknown uncertainty_reward type") 40 | else: 41 | if uncertainty_reward == 'ensemble': 42 | ensemble_policy = ensemble_models.PolicyEnsembleMLP 43 | else: 44 | raise Exception("Unknown uncertainty_reward type") 45 | 46 | ensemble_policy = ensemble_policy(*ensemble_args).to(device) 47 | 48 | ensemblebc = BehaviorCloning(ensemble_policy,device, batch_size=ensemble_batch_size,\ 49 | lr=ensemble_lr, envs=envs, training_data_split=training_data_split,\ 50 | expert_dataset=expert_dataset,ensemble_size=ensemble_size ) 51 | 52 | ensemble_model_save_path = os.path.join(save_model_dir, 'ensemble') 53 | ensemble_file_name = 
f'ensemble_{env_name}_policy_ntrajs={num_trajs}_seed={seed}' 54 | ensemble_model_path = os.path.join(ensemble_model_save_path, f'{ensemble_file_name}.model') 55 | ensemble_results_save_path = os.path.join(save_results_dir, 'ensemble', f'{ensemble_file_name}.perf') 56 | # Check if model already exist 57 | best_test_loss, best_test_model = np.float('inf'), None 58 | if os.path.exists(ensemble_model_path): 59 | best_test_params = torch.load(ensemble_model_path, map_location=device) 60 | print(f'*** Loading ensemble policy: {ensemble_model_path} ***') 61 | else: 62 | ensemble_results = [] 63 | for ensemble_epoch in range(num_ensemble_train_epoch): 64 | ensemble_train_loss = ensemblebc.update(update=True, data_loader_type='train') 65 | with torch.no_grad(): 66 | ensemble_test_loss = ensemblebc.update(update=False, data_loader_type='test') 67 | print(f'ensemble-epoch {ensemble_epoch}/{num_ensemble_train_epoch} | train loss: {ensemble_train_loss:.4f}, test loss: {ensemble_test_loss:.4f}') 68 | ensemble_results.append({'epoch': ensemble_epoch, 'trloss':ensemble_train_loss,\ 69 | 'teloss': ensemble_test_loss, 'test_reward': 0}) 70 | best_test_params = copy.deepcopy(ensemble_policy.state_dict()) 71 | 72 | # Save the Ensemble model and training results 73 | torch.save(best_test_params, ensemble_model_path) 74 | df = pd.DataFrame(ensemble_results, columns=np.hstack(['epoch', 'trloss', 'teloss','test_reward'])) 75 | df.to_csv(ensemble_results_save_path) 76 | 77 | ensemble_policy.load_state_dict(best_test_params) 78 | return ensemble_policy 79 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/algo/gail.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.data 7 | from torch import autograd 8 | import gym 9 | 10 | from torch.utils.data import DataLoader, TensorDataset 11 | 12 | from baselines.common.running_mean_std import RunningMeanStd 13 | from dril.a2c_ppo_acktr.utils import init 14 | 15 | class Flatten(nn.Module): 16 | def forward(self, x): 17 | return x.view(x.size(0), -1) 18 | 19 | 20 | 21 | class Discriminator(nn.Module): 22 | def __init__(self, input_dim, hidden_dim, device, gail_reward_type=None, 23 | clip_gail_action=None, envs=None, disc_lr=None): 24 | super(Discriminator, self).__init__() 25 | 26 | self.device = device 27 | 28 | self.trunk = nn.Sequential( 29 | nn.Linear(input_dim, hidden_dim), nn.Tanh(), 30 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), 31 | nn.Linear(hidden_dim, 1)).to(device) 32 | 33 | self.trunk.train() 34 | 35 | self.optimizer = torch.optim.Adam(self.trunk.parameters(), lr=disc_lr) 36 | 37 | self.returns = None 38 | self.ret_rms = RunningMeanStd(shape=()) 39 | 40 | self.reward_type = gail_reward_type 41 | self.clip_gail_action = clip_gail_action 42 | self.action_space = envs.action_space 43 | 44 | def compute_grad_pen(self, 45 | expert_state, 46 | expert_action, 47 | policy_state, 48 | policy_action, 49 | lambda_=10): 50 | alpha = torch.rand(expert_state.size(0), 1) 51 | if self.clip_gail_action and isinstance(self.action_space, gym.spaces.Box): 52 | expert_action = torch.clamp(expert_action, self.action_space.low[0], self.action_space.high[0]) 53 | policy_action = torch.clamp(policy_action, self.action_space.low[0], self.action_space.high[0]) 54 | 55 | expert_data = torch.cat([expert_state, expert_action], dim=1) 56 | policy_data = 
torch.cat([policy_state, policy_action], dim=1) 57 | 58 | alpha = alpha.expand_as(expert_data).to(expert_data.device) 59 | 60 | mixup_data = alpha * expert_data + (1 - alpha) * policy_data 61 | mixup_data.requires_grad = True 62 | 63 | disc = self.trunk(mixup_data) 64 | ones = torch.ones(disc.size()).to(disc.device) 65 | grad = autograd.grad( 66 | outputs=disc, 67 | inputs=mixup_data, 68 | grad_outputs=ones, 69 | create_graph=True, 70 | retain_graph=True, 71 | only_inputs=True)[0] 72 | 73 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean() 74 | return grad_pen 75 | 76 | def update(self,expert_loader, rollouts, obsfilt=None): 77 | self.train() 78 | 79 | policy_data_generator = rollouts.feed_forward_generator( 80 | None, mini_batch_size=expert_loader.batch_size) 81 | 82 | loss = 0 83 | n = 0 84 | for expert_batch, policy_batch in zip(expert_loader, 85 | policy_data_generator): 86 | policy_state, policy_action = policy_batch[0], policy_batch[2] 87 | policy_d = self.trunk( 88 | torch.cat([policy_state, policy_action], dim=1)) 89 | 90 | expert_state, expert_action = expert_batch 91 | expert_state = obsfilt(expert_state.numpy(), update=False) 92 | expert_state = torch.FloatTensor(expert_state).to(self.device) 93 | expert_action = expert_action.to(self.device) 94 | expert_d = self.trunk( 95 | torch.cat([expert_state, expert_action], dim=1)) 96 | 97 | expert_loss = F.binary_cross_entropy_with_logits( 98 | expert_d, 99 | torch.ones(expert_d.size()).to(self.device)) 100 | policy_loss = F.binary_cross_entropy_with_logits( 101 | policy_d, 102 | torch.zeros(policy_d.size()).to(self.device)) 103 | 104 | gail_loss = expert_loss + policy_loss 105 | grad_pen = self.compute_grad_pen(expert_state, expert_action, 106 | policy_state, policy_action) 107 | 108 | loss += (gail_loss + grad_pen).item() 109 | n += 1 110 | 111 | self.optimizer.zero_grad() 112 | (gail_loss + grad_pen).backward() 113 | self.optimizer.step() 114 | return loss / n 115 | 116 | def predict_reward(self, state, action, gamma, masks, update_rms=True): 117 | with torch.no_grad(): 118 | self.eval() 119 | d = self.trunk(torch.cat([state, action], dim=1)) 120 | s = torch.sigmoid(d) 121 | if self.reward_type == 'unbias': 122 | reward = s.log() - (1 - s).log() 123 | elif self.reward_type == 'favor_zero_reward': 124 | reward = reward = s.log() 125 | elif self.reward_type == 'favor_non_zero_reward': 126 | reward = - (1 - s).log() 127 | 128 | if self.returns is None: 129 | self.returns = reward.clone() 130 | 131 | if update_rms: 132 | self.returns = self.returns * masks * gamma + reward 133 | self.ret_rms.update(self.returns.cpu().numpy()) 134 | 135 | return reward / np.sqrt(self.ret_rms.var[0] + 1e-8) 136 | 137 | 138 | class DiscriminatorCNN(nn.Module): 139 | def __init__(self, obs_shape, hidden_dim, num_actions, device, disc_lr,\ 140 | gail_reward_type=None, envs=None): 141 | super(DiscriminatorCNN, self).__init__() 142 | 143 | self.device = device 144 | 145 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
146 | constant_(x, 0), nn.init.calculate_gain('relu')) 147 | 148 | self.num_actions = num_actions 149 | self.action_emb = nn.Embedding(num_actions, num_actions).cuda() 150 | num_inputs = obs_shape.shape[0] + num_actions 151 | 152 | 153 | self.cnn = nn.Sequential( 154 | init_(nn.Conv2d(num_inputs, 32, 8, stride=4)), nn.ReLU(), 155 | init_(nn.Conv2d(32, 64, 4, stride=2)), nn.ReLU(), 156 | init_(nn.Conv2d(64, 32, 3, stride=1)), nn.ReLU(), Flatten(), 157 | init_(nn.Linear(32 * 7 * 7, hidden_dim)), nn.ReLU()).to(device) 158 | 159 | 160 | self.trunk = nn.Sequential( 161 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), 162 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), 163 | nn.Linear(hidden_dim, 1)).to(device) 164 | 165 | self.cnn.train() 166 | self.trunk.train() 167 | 168 | self.optimizer = torch.optim.Adam(list(self.trunk.parameters()) + list(self.cnn.parameters()), lr=disc_lr) 169 | 170 | self.returns = None 171 | self.ret_rms = RunningMeanStd(shape=()) 172 | 173 | self.reward_type = gail_reward_type 174 | 175 | def compute_grad_pen(self, 176 | expert_state, 177 | expert_action, 178 | policy_state, 179 | policy_action, 180 | lambda_=10): 181 | alpha = torch.rand(expert_state.size(0), 1) 182 | 183 | ''' 184 | expert_data = torch.cat([expert_state, expert_action], dim=1) 185 | policy_data = torch.cat([policy_state, policy_action], dim=1) 186 | ''' 187 | 188 | expert_data = self.combine_states_actions(expert_state, expert_action, detach=True) 189 | policy_data = self.combine_states_actions(policy_state, policy_action, detach=True) 190 | 191 | alpha = alpha.view(-1, 1, 1, 1).expand_as(expert_data).to(expert_data.device) 192 | 193 | mixup_data = alpha * expert_data + (1 - alpha) * policy_data 194 | mixup_data.requires_grad = True 195 | 196 | disc = self.trunk(self.cnn(mixup_data)) 197 | ones = torch.ones(disc.size()).to(disc.device) 198 | grad = autograd.grad( 199 | outputs=disc, 200 | inputs=mixup_data, 201 | grad_outputs=ones, 202 | create_graph=True, 203 | retain_graph=True, 204 | only_inputs=True)[0] 205 | 206 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean() 207 | return grad_pen 208 | 209 | def combine_states_actions(self, states, actions, detach=False): 210 | batch_size, height, width = states.shape[0], states.shape[2], states.shape[3] 211 | action_emb = self.action_emb(actions).squeeze() 212 | action_emb = action_emb.view(batch_size, self.num_actions, 1, 1).expand(batch_size, self.num_actions, height, width) 213 | if detach: 214 | action_emb = action_emb.detach() 215 | state_actions = torch.cat((states / 255.0, action_emb), dim=1) 216 | return state_actions 217 | 218 | def update(self, expert_loader, rollouts, obsfilt=None): 219 | self.train() 220 | 221 | policy_data_generator = rollouts.feed_forward_generator( 222 | None, mini_batch_size=expert_loader.batch_size) 223 | 224 | loss = 0 225 | n = 0 226 | for expert_batch, policy_batch in zip(expert_loader, 227 | policy_data_generator): 228 | policy_state, policy_action = policy_batch[0], policy_batch[2] 229 | policy_data = self.combine_states_actions(policy_state, policy_action) 230 | policy_d = self.trunk(self.cnn(policy_data)) 231 | 232 | expert_state, expert_action = expert_batch 233 | 234 | if obsfilt is not None: 235 | expert_state = obsfilt(expert_state.numpy(), update=False) 236 | expert_state = torch.FloatTensor(expert_state).to(self.device) 237 | expert_action = expert_action.to(self.device) 238 | expert_state = expert_state.to(self.device) 239 | 240 | expert_data = self.combine_states_actions(expert_state, 
expert_action) 241 | 242 | expert_d = self.trunk(self.cnn(expert_data)) 243 | 244 | expert_loss = F.binary_cross_entropy_with_logits( 245 | expert_d, 246 | torch.ones(expert_d.size()).to(self.device)) 247 | policy_loss = F.binary_cross_entropy_with_logits( 248 | policy_d, 249 | torch.zeros(policy_d.size()).to(self.device)) 250 | 251 | gail_loss = expert_loss + policy_loss 252 | grad_pen = self.compute_grad_pen(expert_state, expert_action, 253 | policy_state, policy_action) 254 | 255 | loss += (gail_loss + grad_pen).item() 256 | n += 1 257 | self.optimizer.zero_grad() 258 | (gail_loss + grad_pen).backward() 259 | self.optimizer.step() 260 | return loss / n 261 | 262 | def predict_reward(self, state, action, gamma, masks, update_rms=True): 263 | with torch.no_grad(): 264 | self.eval() 265 | policy_data = self.combine_states_actions(state, action) 266 | d = self.trunk(self.cnn(policy_data)) 267 | s = torch.sigmoid(d) 268 | 269 | if self.reward_type == 'unbias': 270 | reward = s.log() - (1 - s).log() 271 | elif self.reward_type == 'favor_zero_reward': 272 | reward = reward = s.log() 273 | elif self.reward_type == 'favor_non_zero_reward': 274 | reward = - (1 - s).log() 275 | 276 | if self.returns is None: 277 | self.returns = reward.clone() 278 | 279 | if update_rms: 280 | self.returns = self.returns * masks * gamma + reward 281 | self.ret_rms.update(self.returns.cpu().numpy()) 282 | 283 | return reward / np.sqrt(self.ret_rms.var[0] + 1e-8) 284 | 285 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/algo/kfac.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | from dril.a2c_ppo_acktr.utils import AddBias 9 | 10 | # TODO: In order to make this code faster: 11 | # 1) Implement _extract_patches as a single cuda kernel 12 | # 2) Compute QR decomposition in a separate process 13 | # 3) Actually make a general KFAC optimizer so it fits PyTorch 14 | 15 | 16 | def _extract_patches(x, kernel_size, stride, padding): 17 | if padding[0] + padding[1] > 0: 18 | x = F.pad(x, (padding[1], padding[1], padding[0], 19 | padding[0])).data # Actually check dims 20 | x = x.unfold(2, kernel_size[0], stride[0]) 21 | x = x.unfold(3, kernel_size[1], stride[1]) 22 | x = x.transpose_(1, 2).transpose_(2, 3).contiguous() 23 | x = x.view( 24 | x.size(0), x.size(1), x.size(2), 25 | x.size(3) * x.size(4) * x.size(5)) 26 | return x 27 | 28 | 29 | def compute_cov_a(a, classname, layer_info, fast_cnn): 30 | batch_size = a.size(0) 31 | 32 | if classname == 'Conv2d': 33 | if fast_cnn: 34 | a = _extract_patches(a, *layer_info) 35 | a = a.view(a.size(0), -1, a.size(-1)) 36 | a = a.mean(1) 37 | else: 38 | a = _extract_patches(a, *layer_info) 39 | a = a.view(-1, a.size(-1)).div_(a.size(1)).div_(a.size(2)) 40 | elif classname == 'AddBias': 41 | is_cuda = a.is_cuda 42 | a = torch.ones(a.size(0), 1) 43 | if is_cuda: 44 | a = a.cuda() 45 | 46 | return a.t() @ (a / batch_size) 47 | 48 | 49 | def compute_cov_g(g, classname, layer_info, fast_cnn): 50 | batch_size = g.size(0) 51 | 52 | if classname == 'Conv2d': 53 | if fast_cnn: 54 | g = g.view(g.size(0), g.size(1), -1) 55 | g = g.sum(-1) 56 | else: 57 | g = g.transpose(1, 2).transpose(2, 3).contiguous() 58 | g = g.view(-1, g.size(-1)).mul_(g.size(1)).mul_(g.size(2)) 59 | elif classname == 'AddBias': 60 | g = g.view(g.size(0), g.size(1), -1) 61 | g = g.sum(-1) 62 | 63 | g_ 
= g * batch_size 64 | return g_.t() @ (g_ / g.size(0)) 65 | 66 | 67 | def update_running_stat(aa, m_aa, momentum): 68 | # Do the trick to keep aa unchanged and not create any additional tensors 69 | m_aa *= momentum / (1 - momentum) 70 | m_aa += aa 71 | m_aa *= (1 - momentum) 72 | 73 | 74 | class SplitBias(nn.Module): 75 | def __init__(self, module): 76 | super(SplitBias, self).__init__() 77 | self.module = module 78 | self.add_bias = AddBias(module.bias.data) 79 | self.module.bias = None 80 | 81 | def forward(self, input): 82 | x = self.module(input) 83 | x = self.add_bias(x) 84 | return x 85 | 86 | 87 | class KFACOptimizer(optim.Optimizer): 88 | def __init__(self, 89 | model, 90 | lr=0.25, 91 | momentum=0.9, 92 | stat_decay=0.99, 93 | kl_clip=0.001, 94 | damping=1e-2, 95 | weight_decay=0, 96 | fast_cnn=False, 97 | Ts=1, 98 | Tf=10): 99 | defaults = dict() 100 | 101 | def split_bias(module): 102 | for mname, child in module.named_children(): 103 | if hasattr(child, 'bias') and child.bias is not None: 104 | module._modules[mname] = SplitBias(child) 105 | else: 106 | split_bias(child) 107 | 108 | split_bias(model) 109 | 110 | super(KFACOptimizer, self).__init__(model.parameters(), defaults) 111 | 112 | self.known_modules = {'Linear', 'Conv2d', 'AddBias'} 113 | 114 | self.modules = [] 115 | self.grad_outputs = {} 116 | 117 | self.model = model 118 | self._prepare_model() 119 | 120 | self.steps = 0 121 | 122 | self.m_aa, self.m_gg = {}, {} 123 | self.Q_a, self.Q_g = {}, {} 124 | self.d_a, self.d_g = {}, {} 125 | 126 | self.momentum = momentum 127 | self.stat_decay = stat_decay 128 | 129 | self.lr = lr 130 | self.kl_clip = kl_clip 131 | self.damping = damping 132 | self.weight_decay = weight_decay 133 | 134 | self.fast_cnn = fast_cnn 135 | 136 | self.Ts = Ts 137 | self.Tf = Tf 138 | 139 | self.optim = optim.SGD( 140 | model.parameters(), 141 | lr=self.lr * (1 - self.momentum), 142 | momentum=self.momentum) 143 | 144 | def _save_input(self, module, input): 145 | if torch.is_grad_enabled() and self.steps % self.Ts == 0: 146 | classname = module.__class__.__name__ 147 | layer_info = None 148 | if classname == 'Conv2d': 149 | layer_info = (module.kernel_size, module.stride, 150 | module.padding) 151 | 152 | aa = compute_cov_a(input[0].data, classname, layer_info, 153 | self.fast_cnn) 154 | 155 | # Initialize buffers 156 | if self.steps == 0: 157 | self.m_aa[module] = aa.clone() 158 | 159 | update_running_stat(aa, self.m_aa[module], self.stat_decay) 160 | 161 | def _save_grad_output(self, module, grad_input, grad_output): 162 | # Accumulate statistics for Fisher matrices 163 | if self.acc_stats: 164 | classname = module.__class__.__name__ 165 | layer_info = None 166 | if classname == 'Conv2d': 167 | layer_info = (module.kernel_size, module.stride, 168 | module.padding) 169 | 170 | gg = compute_cov_g(grad_output[0].data, classname, layer_info, 171 | self.fast_cnn) 172 | 173 | # Initialize buffers 174 | if self.steps == 0: 175 | self.m_gg[module] = gg.clone() 176 | 177 | update_running_stat(gg, self.m_gg[module], self.stat_decay) 178 | 179 | def _prepare_model(self): 180 | for module in self.model.modules(): 181 | classname = module.__class__.__name__ 182 | if classname in self.known_modules: 183 | assert not ((classname in ['Linear', 'Conv2d']) and module.bias is not None), \ 184 | "You must have a bias as a separate layer" 185 | 186 | self.modules.append(module) 187 | module.register_forward_pre_hook(self._save_input) 188 | module.register_backward_hook(self._save_grad_output) 189 | 190 | def 
step(self): 191 | # Add weight decay 192 | if self.weight_decay > 0: 193 | for p in self.model.parameters(): 194 | p.grad.data.add_(self.weight_decay, p.data) 195 | 196 | updates = {} 197 | for i, m in enumerate(self.modules): 198 | assert len(list(m.parameters()) 199 | ) == 1, "Can handle only one parameter at the moment" 200 | classname = m.__class__.__name__ 201 | p = next(m.parameters()) 202 | 203 | la = self.damping + self.weight_decay 204 | 205 | if self.steps % self.Tf == 0: 206 | # My asynchronous implementation exists, I will add it later. 207 | # Experimenting with different ways to this in PyTorch. 208 | self.d_a[m], self.Q_a[m] = torch.symeig( 209 | self.m_aa[m], eigenvectors=True) 210 | self.d_g[m], self.Q_g[m] = torch.symeig( 211 | self.m_gg[m], eigenvectors=True) 212 | 213 | self.d_a[m].mul_((self.d_a[m] > 1e-6).float()) 214 | self.d_g[m].mul_((self.d_g[m] > 1e-6).float()) 215 | 216 | if classname == 'Conv2d': 217 | p_grad_mat = p.grad.data.view(p.grad.data.size(0), -1) 218 | else: 219 | p_grad_mat = p.grad.data 220 | 221 | v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m] 222 | v2 = v1 / ( 223 | self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + la) 224 | v = self.Q_g[m] @ v2 @ self.Q_a[m].t() 225 | 226 | v = v.view(p.grad.data.size()) 227 | updates[p] = v 228 | 229 | vg_sum = 0 230 | for p in self.model.parameters(): 231 | v = updates[p] 232 | vg_sum += (v * p.grad.data * self.lr * self.lr).sum() 233 | 234 | nu = min(1, math.sqrt(self.kl_clip / vg_sum)) 235 | 236 | for p in self.model.parameters(): 237 | v = updates[p] 238 | p.grad.data.copy_(v) 239 | p.grad.data.mul_(nu) 240 | 241 | self.optim.step() 242 | self.steps += 1 243 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/algo/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | from dril.a2c_ppo_acktr.algo.behavior_cloning import BehaviorCloning 7 | 8 | class PPO(): 9 | def __init__(self, 10 | actor_critic, 11 | clip_param, 12 | ppo_epoch, 13 | num_mini_batch, 14 | value_loss_coef, 15 | entropy_coef, 16 | lr=None, 17 | eps=None, 18 | max_grad_norm=None, 19 | use_clipped_value_loss=True, 20 | dril=None): 21 | 22 | self.actor_critic = actor_critic 23 | 24 | self.clip_param = clip_param 25 | self.ppo_epoch = ppo_epoch 26 | self.num_mini_batch = num_mini_batch 27 | 28 | self.value_loss_coef = value_loss_coef 29 | self.entropy_coef = entropy_coef 30 | 31 | self.max_grad_norm = max_grad_norm 32 | self.use_clipped_value_loss = use_clipped_value_loss 33 | 34 | self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps) 35 | 36 | self.dril = dril 37 | 38 | def update(self, rollouts): 39 | advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1] 40 | advantages = (advantages - advantages.mean()) / ( 41 | advantages.std() + 1e-5) 42 | 43 | value_loss_epoch = 0 44 | action_loss_epoch = 0 45 | dist_entropy_epoch = 0 46 | 47 | for e in range(self.ppo_epoch): 48 | if self.actor_critic.is_recurrent: 49 | data_generator = rollouts.recurrent_generator( 50 | advantages, self.num_mini_batch) 51 | else: 52 | data_generator = rollouts.feed_forward_generator( 53 | advantages, self.num_mini_batch) 54 | 55 | for sample in data_generator: 56 | obs_batch, recurrent_hidden_states_batch, actions_batch, \ 57 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, \ 58 | adv_targ = sample 59 | 60 | # 
Reshape to do in a single forward pass for all steps 61 | values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions( 62 | obs_batch, recurrent_hidden_states_batch, masks_batch, 63 | actions_batch) 64 | 65 | ratio = torch.exp(action_log_probs - 66 | old_action_log_probs_batch) 67 | surr1 = ratio * adv_targ 68 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 69 | 1.0 + self.clip_param) * adv_targ 70 | action_loss = -torch.min(surr1, surr2).mean() 71 | 72 | if self.use_clipped_value_loss: 73 | value_pred_clipped = value_preds_batch + \ 74 | (values - value_preds_batch).clamp(-self.clip_param, self.clip_param) 75 | value_losses = (values - return_batch).pow(2) 76 | value_losses_clipped = ( 77 | value_pred_clipped - return_batch).pow(2) 78 | value_loss = 0.5 * torch.max(value_losses, 79 | value_losses_clipped).mean() 80 | else: 81 | value_loss = 0.5 * (return_batch - values).pow(2).mean() 82 | 83 | self.optimizer.zero_grad() 84 | (value_loss * self.value_loss_coef + action_loss - 85 | dist_entropy * self.entropy_coef).backward() 86 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 87 | self.max_grad_norm) 88 | self.optimizer.step() 89 | 90 | value_loss_epoch += value_loss.item() 91 | action_loss_epoch += action_loss.item() 92 | dist_entropy_epoch += dist_entropy.item() 93 | 94 | if self.dril: 95 | self.dril.bc_update() 96 | 97 | num_updates = self.ppo_epoch * self.num_mini_batch 98 | 99 | value_loss_epoch /= num_updates 100 | action_loss_epoch /= num_updates 101 | dist_entropy_epoch /= num_updates 102 | 103 | return value_loss_epoch, action_loss_epoch, dist_entropy_epoch 104 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import uuid 4 | 5 | import torch 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser(description='RL') 10 | # Behavior Cloning --------------------------------- 11 | parser.add_argument( 12 | '--bc_lr', type=float, default=2.5e-4, help='behavior cloning learning rate (default: 2.5e-4)') 13 | parser.add_argument( 14 | '--bc_batch_size', type=int, default=100, help='behavior cloning batch size (default: 100') 15 | parser.add_argument( 16 | '--bc_train_epoch', type=int, default=2001, help='behavior cloning training epochs (default=500)') 17 | parser.add_argument( 18 | '--behavior_cloning', default=False, action='store_true', 19 | help='**_Only_** train model with behavior cloning (default: False)') 20 | parser.add_argument( 21 | '--warm_start', default=False, action='store_true', 22 | help='train model with behavior cloning and then train with reinforcement learning starting with learned policy (default: False)') 23 | 24 | # DRIL --------------------------------- 25 | parser.add_argument( 26 | '--dril', default=False, action='store_true', 27 | help='train model using dril (default: False)') 28 | parser.add_argument( 29 | '--dril_uncertainty_reward', choices=['ensemble', 'dropout'], default='ensemble', 30 | help='dril uncertainty score to use for the reward function (default: ensemble)') 31 | parser.add_argument( 32 | '--pretain_ensemble_only', default=False, action='store_true', 33 | help='train the ensemble only and then exit') 34 | parser.add_argument( 35 | '--ensemble_hidden_size', default=512, 36 | help='dril ensemble network number of hidden units (default: 512)') 37 | parser.add_argument( 38 | '--ensemble_drop_rate', default=0.1, 39 | 
help='dril dropout ensemble netwrok rate (default: 0.1)') 40 | parser.add_argument( 41 | '--ensemble_size', type=int, default=5, 42 | help='numnber of polices in the ensemble (default: 5)') 43 | parser.add_argument( 44 | '--ensemble_batch_size', type=int, default=100, 45 | help='dril ensemble training batch size (default: 100)') 46 | parser.add_argument( 47 | '--ensemble_lr', type=float, default=2.5e-4, 48 | help='dril ensemble learning rate (default: 2.5e-4)') 49 | parser.add_argument( 50 | '--num_ensemble_train_epoch', type=int, default=2001, 51 | help='dril ensemble number of training epoch (default: 500)') 52 | parser.add_argument( 53 | '--ensemble_quantile_threshold', type=float, default=0.98, 54 | help='dril reward quantile threshold (default: 0.98)') 55 | parser.add_argument( 56 | '--num_dril_bc_train_epoch', type=int, default=1, 57 | help='number of epochs to do behavior cloning updates after reinforcement learning updates (default: 1)') 58 | parser.add_argument( 59 | '--ensemble_shuffle_type', 60 | choices=['no_shuffle', 'sample_w_replace', 'norm_shuffle'], 61 | default='sample_w_replace') 62 | #TODO: Think of better way to handle this 63 | parser.add_argument( 64 | '--dril_cost_clip', 65 | choices=['-1_to_1', 'no_clipping', '-1_to_0'], 66 | default='-1_to_1', 67 | help='dril uncertainty reward clipping range "lower bound"_to_"upper bound" (default: -1_to_1)') 68 | parser.add_argument( 69 | '--use_obs_norm', default=False, action='store_true', 70 | help='Normallize the observation (default: False)') 71 | parser.add_argument( 72 | '--pretrain_ensemble_only', default=False, action='store_true', 73 | help='pretrain ensemble only on gpu') 74 | 75 | # GAIL ----------------------------------------------- 76 | parser.add_argument( 77 | '--gail_reward_type', 78 | choices=['unbias', 'favor_zero_reward', 'favor_non_zero_reward'], 79 | default='unbias', 80 | help='specifiy the reward function used by gail (default: unbias)') 81 | 82 | parser.add_argument( 83 | '--clip_gail_action', default=True, action='store_true', 84 | help='continous control actions are clipped, so this clips actions between expert and policy trained (defualt: True)') 85 | parser.add_argument( 86 | '--gail-disc-lr', type=float, default=2.5e-3, 87 | help='learning rate for gail discriminator (default: 2.5e-3)') 88 | 89 | 90 | # General Paramteres --------------------------------- 91 | #TODO: Cleaner way to deal with this 92 | parser.add_argument( 93 | '--atari_max_steps', default=100000, 94 | help='Max steps in atari game') 95 | parser.add_argument( 96 | '--default_experiment_params', choices=['atari', 'continous-control', 'None', 'retro'], default='None', 97 | help='Default params ran in the DRIL experiments') 98 | parser.add_argument( 99 | '--rl_baseline_zoo_dir', type=str, default='rl-baselines-zoo', help='directory of rl baseline zoo') 100 | parser.add_argument( 101 | '--demo_data_dir', type=str, default=f'{os.getcwd()}/demo_data', help='directory of demonstration data') 102 | parser.add_argument( 103 | '--num-trajs', type=int, default=1, help='Number of demonstration trajectories') 104 | parser.add_argument( 105 | '--save-model-dir', 106 | default=f'{os.getcwd()}/trained_models/', 107 | help='directory to save agents (default: ./trained_models/)') 108 | parser.add_argument( 109 | '--save-results-dir', 110 | default='./trained_results/', 111 | help='directory to save agent training logs (default: ./trained_results/)') 112 | parser.add_argument( 113 | '--training_data_split', type=float, default=0.8, 114 | 
help='training split for the behavior cloning data between (0-1) (default: 0.8)') 115 | parser.add_argument( 116 | '--load_expert', default=False, action='store_true', 117 | help='load pretrained expert from rl-baseline-zoo (default: False)') 118 | parser.add_argument( 119 | '--subsample_frequency', type=int, default=20, 120 | help='frequency to subsample demonstration data (default: 20)') 121 | parser.add_argument( 122 | '--subsample', action='store_true', 123 | default=False, 124 | help='boolean to indicate if the demonstration data will be subsampled (default: False)') 125 | parser.add_argument( 126 | '--norm-reward-stable-baseline', 127 | action='store_true', 128 | default=False, 129 | help='Stable-Basline Normalize reward if applicable (trained with VecNormalize) (default: False)') 130 | parser.add_argument( 131 | '--num_eval_episodes', 132 | default=10, 133 | type=int, 134 | help='Number of evaluation epsiodes (default: 10)') 135 | 136 | parser.add_argument( 137 | '--system', 138 | default='', 139 | type=str) 140 | 141 | # Original Params --------------------------------------- 142 | parser.add_argument( 143 | '--algo', default='a2c', help='algorithm to use: a2c | ppo | acktr') 144 | parser.add_argument( 145 | '--gail', 146 | action='store_true', 147 | default=False, 148 | help='do imitation learning with gail') 149 | parser.add_argument( 150 | '--gail-experts-dir', 151 | default='./gail_experts', 152 | help='directory that contains expert demonstrations for gail') 153 | parser.add_argument( 154 | '--gail-batch-size', 155 | type=int, 156 | default=128, 157 | help='gail batch size (default: 128)') 158 | parser.add_argument( 159 | '--gail-epoch', type=int, default=5, help='gail epochs (default: 5)') 160 | parser.add_argument( 161 | '--lr', type=float, default=7e-4, help='learning rate (default: 7e-4)') 162 | parser.add_argument( 163 | '--eps', 164 | type=float, 165 | default=1e-5, 166 | help='RMSprop optimizer epsilon (default: 1e-5)') 167 | parser.add_argument( 168 | '--alpha', 169 | type=float, 170 | default=0.99, 171 | help='RMSprop optimizer apha (default: 0.99)') 172 | parser.add_argument( 173 | '--gamma', 174 | type=float, 175 | default=0.99, 176 | help='discount factor for rewards (default: 0.99)') 177 | parser.add_argument( 178 | '--use-gae', 179 | action='store_true', 180 | default=False, 181 | help='use generalized advantage estimation') 182 | parser.add_argument( 183 | '--gae-lambda', 184 | type=float, 185 | default=0.95, 186 | help='gae lambda parameter (default: 0.95)') 187 | parser.add_argument( 188 | '--entropy-coef', 189 | type=float, 190 | default=0.01, 191 | help='entropy term coefficient (default: 0.01)') 192 | parser.add_argument( 193 | '--value-loss-coef', 194 | type=float, 195 | default=0.5, 196 | help='value loss coefficient (default: 0.5)') 197 | parser.add_argument( 198 | '--max-grad-norm', 199 | type=float, 200 | default=0.5, 201 | help='max norm of gradients (default: 0.5)') 202 | parser.add_argument( 203 | '--seed', type=int, default=1, help='random seed (default: 1)') 204 | parser.add_argument( 205 | '--cuda-deterministic', 206 | action='store_true', 207 | default=False, 208 | help="sets flags for determinism when using CUDA (potentially slow!)") 209 | parser.add_argument( 210 | '--num-processes', 211 | type=int, 212 | default=16, 213 | help='how many training CPU processes to use (default: 16)') 214 | parser.add_argument( 215 | '--num-steps', 216 | type=int, 217 | default=5, 218 | help='number of forward steps in A2C (default: 5)') 219 | 
parser.add_argument( 220 | '--ppo-epoch', 221 | type=int, 222 | default=4, 223 | help='number of ppo epochs (default: 4)') 224 | parser.add_argument( 225 | '--num-mini-batch', 226 | type=int, 227 | default=32, 228 | help='number of batches for ppo (default: 32)') 229 | parser.add_argument( 230 | '--clip-param', 231 | type=float, 232 | default=0.2, 233 | help='ppo clip parameter (default: 0.2)') 234 | parser.add_argument( 235 | '--log-interval', 236 | type=int, 237 | default=10, 238 | help='log interval, one log per n updates (default: 10)') 239 | parser.add_argument( 240 | '--save-interval', 241 | type=int, 242 | default=10, 243 | help='save interval, one save per n updates (default: 100)') 244 | parser.add_argument( 245 | '--eval-interval', 246 | type=int, 247 | default=None, 248 | help='eval interval, one eval per n updates (default: None)') 249 | parser.add_argument( 250 | '--num-env-steps', 251 | type=int, 252 | default=10e6, 253 | help='number of environment steps to train (default: 10e6)') 254 | parser.add_argument( 255 | '--env-name', 256 | default='PongNoFrameskip-v4', 257 | help='environment to train on (default: PongNoFrameskip-v4)') 258 | parser.add_argument( 259 | '--log-dir', 260 | default=f'{os.getcwd()}/tmp/{uuid.uuid4()}/tmp/gym/', 261 | help='directory to save agent logs (default: /tmp/gym)') 262 | parser.add_argument( 263 | '--no-cuda', 264 | action='store_true', 265 | default=False, 266 | help='disables CUDA training') 267 | parser.add_argument( 268 | '--use-proper-time-limits', 269 | action='store_true', 270 | default=False, 271 | help='compute returns taking into account time limits') 272 | parser.add_argument( 273 | '--recurrent-policy', 274 | action='store_true', 275 | default=False, 276 | help='use a recurrent policy') 277 | parser.add_argument( 278 | '--use-linear-lr-decay', 279 | action='store_true', 280 | default=False, 281 | help='use a linear schedule on the learning rate') 282 | #args = parser.parse_args() 283 | args, unknown = parser.parse_known_args() 284 | 285 | args.cuda = not args.no_cuda and torch.cuda.is_available() 286 | 287 | assert args.algo in ['a2c', 'ppo', 'acktr'] 288 | if args.recurrent_policy: 289 | assert args.algo in ['a2c', 'ppo'], \ 290 | 'Recurrent policy is not implemented for ACKTR' 291 | 292 | if args.env_name in ['AntBulletEnv-v0']: 293 | args.expert_algo = 'trpo' 294 | else: 295 | args.expert_algo = 'ppo2' 296 | 297 | def create_dir(dir): 298 | try: 299 | os.makedirs(dir) 300 | except OSError: 301 | pass 302 | 303 | if args.env_name in ['duckietown']: 304 | print('** Duckietown only works with 1 process') 305 | args.num_processes = 1 306 | 307 | if args.env_name in ['duckietown', 'highway-v0'] and args.load_expert: 308 | raise Exception("Can not load expert because it does not exist") 309 | 310 | if args.algo == 'acktr': 311 | raise Exception("Code base was not test with acktr: comment this line!") 312 | 313 | if args.default_experiment_params != 'None': 314 | # Continous control default settings 315 | if args.default_experiment_params == 'continous-control': 316 | args.algo = 'ppo' 317 | args.use_gae = True 318 | args.log_interval = 1 319 | args.num_steps = 2048 320 | args.num_processes = 1 321 | args.lr = 3e-4 322 | #args.clip_param = 0.1 323 | args.entropy_coef = 0 324 | args.value_loss_coef = 0.5 325 | args.ppo_epoch = 10 326 | args.num_mini_batch = 32 327 | args.gamma = 0.99 328 | args.gae_lambda = 0.95 329 | args.num_env_steps = 20e6 330 | args.use_linear_lr_decay = True 331 | args.use_proper_time_limits = True 332 | 
args.eval_interval = 200 333 | args.ensemble_quantile_threshold = 0.98 334 | 335 | args.ensemble_shuffle_type = 'sample_w_replace' 336 | args.bc_lr = 2.5e-4 337 | args.ensemble_lr = 2.5e-4 338 | args.ensemble_size = 5 339 | args.num_ensemble_train_epoch = 2001 340 | args.bc_train_epoch = 2001 341 | args.gail_disc_lr = 1e-3 342 | 343 | if args.gail: 344 | args.num_env_steps = 10e6 345 | 346 | ## Atari default settings 347 | elif args.default_experiment_params == 'atari': 348 | args.algo = 'a2c' 349 | args.use_gae = True 350 | args.lr = 2.5e-3 351 | args.clip_param = 0.1 352 | args.value_loss_coef = 0.5 353 | args.num_processes = 8 354 | args.num_steps = 128 355 | #args.num_mini_batch = 4 356 | args.log_interval = 10 357 | args.use_linear_lr_decay = True 358 | args.entropy_coef = 0.01 359 | args.ensemble_quantile_threshold = 0.98 360 | args.num_env_steps = 20e6 361 | args.eval_interval = 1000 362 | args.ensemble_quantile_threshold = 0.98 363 | args.num_dril_bc_train_epoch = 1 364 | 365 | args.ensemble_shuffle_type = 'sample_w_replace' 366 | args.bc_lr = 2.5e-4 367 | args.ensemble_lr = 2.5e-4 368 | args.ensemble_size = 5 369 | args.num_ensemble_train_epoch = 1001 370 | args.bc_train_epoch = 1001 371 | args.gail_disc_lr = 2.5e-3 372 | 373 | elif args.default_experiment_params == 'retro': 374 | args.algo = 'ppo' 375 | args.use_gae = True 376 | args.log_interval = 1 377 | args.num_steps = 128 378 | args.num_processes = 8 379 | args.lr = 2.5e-4 380 | args.clip_param = 0.1 381 | args.entropy_coef = 0 382 | args.value_loss_coef = 0.5 383 | args.num_mini_batch = 4 384 | args.gamma = 0 385 | args.gae_lambda = 0.95 386 | args.num_env_steps = 100e6 387 | args.use_linear_lr_decay = True 388 | args.use_proper_time_limits = True 389 | args.eval_interval = 200 390 | args.ensemble_quantile_threshold = 0.98 391 | 392 | args.ensemble_shuffle_type = 'sample_w_replace' 393 | args.bc_lr = 2.5e-4 394 | args.ensemble_lr = 2.5e-4 395 | args.ensemble_size = 5 396 | args.num_ensemble_train_epoch = 1001 397 | args.bc_train_epoch = 1001 398 | args.gail_disc_lr = 2.5e-3 399 | else: 400 | raise Exception('Unknown Defult experiments') 401 | 402 | # Ensure directories are created 403 | create_dir(os.path.join(args.save_model_dir, args.algo)) 404 | create_dir(os.path.join(args.save_model_dir, 'bc')) 405 | create_dir(os.path.join(args.save_model_dir, 'ensemble')) 406 | create_dir(os.path.join(args.save_model_dir, 'gail')) 407 | create_dir(os.path.join(args.save_model_dir, 'dril')) 408 | create_dir(os.path.join(args.save_model_dir, 'a2c')) 409 | create_dir(os.path.join(args.save_model_dir, 'ppo')) 410 | 411 | create_dir(os.path.join(args.save_results_dir, args.algo)) 412 | create_dir(os.path.join(args.save_results_dir, 'bc')) 413 | create_dir(os.path.join(args.save_results_dir, 'ensemble')) 414 | create_dir(os.path.join(args.save_results_dir, 'gail')) 415 | create_dir(os.path.join(args.save_results_dir, 'dril')) 416 | create_dir(os.path.join(args.save_results_dir, 'a2c')) 417 | create_dir(os.path.join(args.save_results_dir, 'ppo')) 418 | create_dir(os.path.join(args.save_results_dir, 'expert')) 419 | 420 | return args 421 | 422 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/distributions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from dril.a2c_ppo_acktr.utils import AddBias, init 8 | 9 | """ 10 | Modify standard PyTorch 
distributions so they are compatible with this code. 11 | """ 12 | 13 | # 14 | # Standardize distribution interfaces 15 | # 16 | 17 | # Categorical 18 | FixedCategorical = torch.distributions.Categorical 19 | 20 | old_sample = FixedCategorical.sample 21 | FixedCategorical.sample = lambda self: old_sample(self).unsqueeze(-1) 22 | 23 | log_prob_cat = FixedCategorical.log_prob 24 | FixedCategorical.log_probs = lambda self, actions: log_prob_cat( 25 | self, actions.squeeze(-1)).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 26 | 27 | FixedCategorical.mode = lambda self: self.probs.argmax(dim=-1, keepdim=True) 28 | 29 | # Normal 30 | FixedNormal = torch.distributions.Normal 31 | 32 | log_prob_normal = FixedNormal.log_prob 33 | FixedNormal.log_probs = lambda self, actions: log_prob_normal( 34 | self, actions).sum( 35 | -1, keepdim=True) 36 | 37 | normal_entropy = FixedNormal.entropy 38 | FixedNormal.entropy = lambda self: normal_entropy(self).sum(-1) 39 | 40 | FixedNormal.mode = lambda self: self.mean 41 | 42 | # Bernoulli 43 | FixedBernoulli = torch.distributions.Bernoulli 44 | 45 | log_prob_bernoulli = FixedBernoulli.log_prob 46 | FixedBernoulli.log_probs = lambda self, actions: log_prob_bernoulli( 47 | self, actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 48 | 49 | bernoulli_entropy = FixedBernoulli.entropy 50 | FixedBernoulli.entropy = lambda self: bernoulli_entropy(self).sum(-1) 51 | FixedBernoulli.mode = lambda self: torch.gt(self.probs, 0.5).float() 52 | 53 | class Categorical(nn.Module): 54 | def __init__(self, num_inputs, num_outputs): 55 | super(Categorical, self).__init__() 56 | 57 | init_ = lambda m: init( 58 | m, 59 | nn.init.orthogonal_, 60 | lambda x: nn.init.constant_(x, 0), 61 | gain=0.01) 62 | 63 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 64 | 65 | def forward(self, x): 66 | x = self.linear(x) 67 | return FixedCategorical(logits=x) 68 | 69 | def get_logits(self, x): 70 | x = self.linear(x) 71 | return FixedCategorical(logits=x).logits 72 | 73 | 74 | class DiagGaussian(nn.Module): 75 | def __init__(self, num_inputs, num_outputs): 76 | super(DiagGaussian, self).__init__() 77 | 78 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 79 | constant_(x, 0)) 80 | 81 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 82 | self.logstd = AddBias(torch.zeros(num_outputs)) 83 | 84 | def forward(self, x): 85 | action_mean = self.fc_mean(x) 86 | 87 | # An ugly hack for my KFAC implementation. 88 | zeros = torch.zeros(action_mean.size()) 89 | if x.is_cuda: 90 | zeros = zeros.cuda() 91 | 92 | action_logstd = self.logstd(zeros) 93 | return FixedNormal(action_mean, action_logstd.exp()) 94 | 95 | def get_mean(self, x): 96 | return self.fc_mean(x) 97 | 98 | class Bernoulli(nn.Module): 99 | def __init__(self, num_inputs, num_outputs): 100 | super(Bernoulli, self).__init__() 101 | 102 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
103 | constant_(x, 0)) 104 | 105 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 106 | 107 | def forward(self, x): 108 | x = self.linear(x) 109 | return FixedBernoulli(logits=x) 110 | 111 | def get_logits(self, x): 112 | return self.linear(x) 113 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/duckietown/env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import gym_duckietown 3 | 4 | def launch_env(id=None): 5 | env = None 6 | if id is None: 7 | # Launch the environment 8 | from gym_duckietown.simulator import Simulator 9 | env = Simulator( 10 | seed=123, # random seed 11 | map_name="loop_empty", 12 | max_steps=10001, # we don't want the gym to reset itself 13 | domain_rand=0, 14 | camera_width=640, 15 | camera_height=480, 16 | accept_start_angle_deg=4, # start close to straight 17 | full_transparency=True, 18 | distortion=True, 19 | ) 20 | else: 21 | env = gym.make(id) 22 | 23 | return env 24 | 25 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/duckietown/teacher.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # parameters for the pure pursuit controller 5 | POSITION_THRESHOLD = 0.04 6 | REF_VELOCITY = 0.8 7 | GAIN = 10 8 | FOLLOWING_DISTANCE = 0.3 9 | 10 | 11 | class PurePursuitExpert: 12 | def __init__(self, env, ref_velocity=REF_VELOCITY, position_threshold=POSITION_THRESHOLD, 13 | following_distance=FOLLOWING_DISTANCE, max_iterations=1000): 14 | self.env = env.unwrapped 15 | self.following_distance = following_distance 16 | self.max_iterations = max_iterations 17 | self.ref_velocity = ref_velocity 18 | self.position_threshold = position_threshold 19 | 20 | def predict(self, observation): # we don't really care about the observation for this implementation 21 | closest_point, closest_tangent = self.env.closest_curve_point(self.env.cur_pos, self.env.cur_angle) 22 | 23 | iterations = 0 24 | lookup_distance = self.following_distance 25 | curve_point = None 26 | while iterations < self.max_iterations: 27 | # Project a point ahead along the curve tangent, 28 | # then find the closest point to to that 29 | follow_point = closest_point + closest_tangent * lookup_distance 30 | curve_point, _ = self.env.closest_curve_point(follow_point, self.env.cur_angle) 31 | 32 | # If we have a valid point on the curve, stop 33 | if curve_point is not None: 34 | break 35 | 36 | iterations += 1 37 | lookup_distance *= 0.5 38 | 39 | # Compute a normalized vector to the curve point 40 | point_vec = curve_point - self.env.cur_pos 41 | point_vec /= np.linalg.norm(point_vec) 42 | 43 | dot = np.dot(self.env.get_right_vec(), point_vec) 44 | steering = GAIN * -dot 45 | 46 | return self.ref_velocity, steering 47 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/duckietown/wrappers.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy as np 4 | 5 | 6 | class ResizeWrapper(gym.ObservationWrapper): 7 | def __init__(self, env=None, shape=(120, 160, 3)): 8 | super(ResizeWrapper, self).__init__(env) 9 | self.observation_space.shape = shape 10 | self.observation_space = spaces.Box( 11 | self.observation_space.low[0, 0, 0], 12 | self.observation_space.high[0, 0, 0], 13 | shape, 14 | dtype=self.observation_space.dtype) 15 | self.shape = shape 
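# The observation() method below resizes each frame with PIL.  Note that
# PIL's Image.resize expects a (width, height) tuple, so self.shape[0:2]
# is passed as (120, 160), and the trailing reshape forces the resized
# array back into the (120, 160, 3) shape declared for observation_space
# in __init__ above.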
16 | 17 | def observation(self, observation): 18 | from PIL import Image 19 | output = np.array(Image.fromarray(observation).resize(self.shape[0:2])) 20 | #return np.array(Image.fromarray(observation).resize(self.shape[0:2])) 21 | return output.reshape((120,160,3)) 22 | 23 | 24 | class NormalizeWrapper(gym.ObservationWrapper): 25 | def __init__(self, env=None): 26 | super(NormalizeWrapper, self).__init__(env) 27 | self.obs_lo = self.observation_space.low[0, 0, 0] 28 | self.obs_hi = self.observation_space.high[0, 0, 0] 29 | obs_shape = self.observation_space.shape 30 | self.observation_space = spaces.Box(0.0, 1.0, obs_shape, dtype=np.float32) 31 | 32 | def observation(self, obs): 33 | if self.obs_lo == 0.0 and self.obs_hi == 1.0: 34 | return obs 35 | else: 36 | return (obs - self.obs_lo) / (self.obs_hi - self.obs_lo) 37 | 38 | 39 | class ImgWrapper(gym.ObservationWrapper): 40 | def __init__(self, env=None): 41 | super(ImgWrapper, self).__init__(env) 42 | obs_shape = self.observation_space.shape 43 | self.observation_space = spaces.Box( 44 | self.observation_space.low[0, 0, 0], 45 | self.observation_space.high[0, 0, 0], 46 | [obs_shape[2], obs_shape[0], obs_shape[1]], 47 | dtype=self.observation_space.dtype) 48 | 49 | def observation(self, observation): 50 | return observation.transpose(2, 0, 1) 51 | 52 | 53 | class DtRewardWrapper(gym.RewardWrapper): 54 | def __init__(self, env): 55 | super(DtRewardWrapper, self).__init__(env) 56 | 57 | def reward(self, reward): 58 | if reward == -1000: 59 | reward = -10 60 | elif reward > 0: 61 | reward += 10 62 | else: 63 | reward += 4 64 | 65 | return reward 66 | 67 | 68 | # Deprecated 69 | class ActionWrapper(gym.ActionWrapper): 70 | def __init__(self, env): 71 | super(ActionWrapper, self).__init__(env) 72 | 73 | def action(self, action): 74 | action_ = [action[0], action[1]] 75 | return action_ 76 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/ensemble_models.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pdb 5 | #from dril.a2c_ppo_acktr.utils import init 6 | 7 | class Flatten(nn.Module): 8 | def forward(self, x): 9 | return x.view(x.size(0), -1) 10 | 11 | # ensemble of linear layers parallelized for GPU 12 | class EnsembleLinearGPU(nn.Module): 13 | def __init__(self, in_features, out_features, n_ensemble, bias=True): 14 | super(EnsembleLinearGPU, self).__init__() 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | self.n_ensemble = n_ensemble 18 | self.bias = bias 19 | self.weights = nn.Parameter(torch.Tensor(n_ensemble, out_features, in_features)) 20 | if bias: 21 | self.biases = nn.Parameter(torch.Tensor(n_ensemble, out_features)) 22 | else: 23 | self.register_parameter('biases', None) 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | for weight in self.weights: 28 | w = nn.Linear(self.in_features, self.out_features) 29 | torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) 30 | if self.biases is not None: 31 | for bias in self.biases: 32 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weights[0]) 33 | bound = 1 / math.sqrt(fan_in) 34 | torch.nn.init.uniform_(bias, -bound, bound) 35 | 36 | def forward(self, inputs): 37 | # check input sizes 38 | if inputs.dim() == 3: 39 | # assuming size is [n_ensemble x batch_size x features] 40 | assert(inputs.size(0) == self.n_ensemble and inputs.size(2) == 
self.in_features) 41 | elif inputs.dim() == 2: 42 | n_samples, n_features = inputs.size(0), inputs.size(1) 43 | assert (n_samples % self.n_ensemble == 0 and n_features == self.in_features), [n_samples, self.n_ensemble, n_features, self.in_features] 44 | batch_size = int(n_samples / self.n_ensemble) 45 | inputs = inputs.view(self.n_ensemble, batch_size, n_features) 46 | 47 | # reshape to [n_ensemble x n_features x batch_size] 48 | inputs = inputs.permute(0, 2, 1) 49 | outputs = torch.bmm(self.weights, inputs) 50 | outputs = outputs 51 | if self.bias: 52 | outputs = outputs + self.biases.unsqueeze(2) 53 | # reshape to [n_ensemble x batch_size x n_features] 54 | outputs = outputs.permute(0, 2, 1).contiguous() 55 | return outputs 56 | 57 | 58 | class Policy(nn.Module): 59 | def __init__(self, num_inputs, n_actions, n_hidden): 60 | super(Policy, self).__init__() 61 | self.layer1 = nn.Linear(num_inputs, n_hidden) 62 | self.layer2 = nn.Linear(n_hidden, n_hidden) 63 | self.layer3 = nn.Linear(n_hidden, n_actions + 1) 64 | 65 | def forward(self, obs): 66 | h = F.relu(self.layer1(obs)) 67 | h = F.relu(self.layer2(h)) 68 | h = self.layer3(h) 69 | a = h[:, :self.n_actions] 70 | v = h[:, -1] 71 | return a, v 72 | 73 | class ValueNetwork(nn.Module): 74 | def __init__(self, n_inputs, n_actions, n_hidden): 75 | super(ValueNetwork, self).__init__() 76 | self.layer1 = nn.Linear(n_inputs, n_hidden) 77 | self.layer2 = nn.Linear(n_hidden, n_hidden) 78 | self.layer3 = nn.Linear(n_hidden, 1) 79 | 80 | def forward(self, obs): 81 | h = F.relu(self.layer1(obs)) 82 | h = F.relu(self.layer2(h)) 83 | v = self.layer3(h) 84 | return v 85 | 86 | class PolicyEnsembleCNN(nn.Module): 87 | def __init__(self, num_inputs, n_actions, n_hidden, n_ensemble): 88 | super(PolicyEnsembleCNN, self).__init__() 89 | 90 | hidden_size = 512 91 | 92 | self.conv1 = nn.Conv2d(in_channels=num_inputs, out_channels=32, kernel_size=8,\ 93 | stride=4, padding=0, bias=True) 94 | self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4,\ 95 | stride=2, padding=0, bias=True) 96 | self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,\ 97 | stride=1, padding=0,bias=True) 98 | self.fc1 = nn.Linear(3136, hidden_size) 99 | self.relu = nn.ReLU() 100 | 101 | 102 | self.layer1 = EnsembleLinearGPU(hidden_size, n_hidden, n_ensemble) 103 | self.layer2 = EnsembleLinearGPU(n_hidden, n_hidden, n_ensemble) 104 | self.layer3 = EnsembleLinearGPU(n_hidden, n_actions, n_ensemble) 105 | 106 | def forward(self, obs): 107 | x = self.relu(self.conv1(obs/ 255.0)) 108 | x = self.relu(self.conv2(x)) 109 | x = self.relu(self.conv3(x)) 110 | x = x.permute(0, 2, 3, 1).contiguous() 111 | x = x.view(x.size(0), -1) 112 | out = self.relu(self.fc1(x)) 113 | 114 | h = F.relu(self.layer1(out)) 115 | h = F.relu(self.layer2(h)) 116 | a = self.layer3(h) 117 | return a 118 | 119 | 120 | 121 | 122 | class PolicyEnsembleCNNDropout(nn.Module): 123 | def __init__(self, num_inputs, n_actions, n_hidden, p_dropout=0.1): 124 | super(PolicyEnsembleCNNDropout, self).__init__() 125 | 126 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
127 | constant_(x, 0), nn.init.calculate_gain('relu')) 128 | 129 | hidden_size = 512 130 | self.p_dropout = p_dropout 131 | 132 | self.conv1 = nn.Conv2d(in_channels=num_inputs, out_channels=32, kernel_size=8,\ 133 | stride=4, padding=0, bias=True) 134 | self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4,\ 135 | stride=2, padding=0, bias=True) 136 | self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,\ 137 | stride=1, padding=0,bias=True) 138 | self.fc1 = nn.Linear(3136, hidden_size) 139 | self.relu = nn.ReLU() 140 | 141 | 142 | self.layer1 = nn.Linear(hidden_size, n_hidden) 143 | self.layer2 = nn.Linear(n_hidden, n_hidden) 144 | self.layer3 = nn.Linear(n_hidden, n_actions) 145 | 146 | def forward(self, obs): 147 | x = self.relu(self.conv1(obs/ 255.0)) 148 | x = F.dropout2d(x, p = self.p_dropout) 149 | x = self.relu(self.conv2(x)) 150 | x = F.dropout2d(x, p = self.p_dropout) 151 | x = self.relu(self.conv3(x)) 152 | x = F.dropout2d(x, p = self.p_dropout) 153 | x = x.permute(0, 2, 3, 1).contiguous() 154 | x = x.view(x.size(0), -1) 155 | out = self.relu(self.fc1(x)) 156 | out = F.dropout(out, p = self.p_dropout) 157 | 158 | h = F.relu(self.layer1(out)) 159 | h = F.dropout(h, p = self.p_dropout) 160 | h = F.relu(self.layer2(h)) 161 | h = F.dropout(h, p = self.p_dropout) 162 | a = self.layer3(h) 163 | return a 164 | 165 | 166 | class PolicyEnsembleMLP(nn.Module): 167 | def __init__(self, n_inputs, n_actions, n_hidden, n_ensemble): 168 | super(PolicyEnsembleMLP, self).__init__() 169 | 170 | self.layer1 = EnsembleLinearGPU(n_inputs, n_hidden, n_ensemble) 171 | self.layer2 = EnsembleLinearGPU(n_hidden, n_hidden, n_ensemble) 172 | self.layer3 = EnsembleLinearGPU(n_hidden, n_actions, n_ensemble) 173 | 174 | def forward(self, obs): 175 | h = F.relu(self.layer1(obs)) 176 | h = F.relu(self.layer2(h)) 177 | a = self.layer3(h) 178 | return a 179 | 180 | class PolicyEnsembleDuckieTownCNN(nn.Module): 181 | def __init__(self, num_inputs, n_actions, n_hidden, n_ensemble): 182 | super(PolicyEnsembleDuckieTownCNN, self).__init__() 183 | 184 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
185 | constant_(x, 0), nn.init.calculate_gain('relu')) 186 | 187 | hidden_size = 512 188 | flat_size = 32 * 9 * 14 189 | 190 | self.lr = nn.LeakyReLU() 191 | 192 | self.conv1 = nn.Conv2d(3, 32, 8, stride=2) 193 | self.conv2 = nn.Conv2d(32, 32, 4, stride=2) 194 | self.conv3 = nn.Conv2d(32, 32, 4, stride=2) 195 | self.conv4 = nn.Conv2d(32, 32, 4, stride=1) 196 | 197 | self.bn1 = nn.BatchNorm2d(32) 198 | self.bn2 = nn.BatchNorm2d(32) 199 | self.bn3 = nn.BatchNorm2d(32) 200 | self.bn4 = nn.BatchNorm2d(32) 201 | 202 | self.dropout = nn.Dropout(.5) 203 | self.lin1 = nn.Linear(flat_size, hidden_size) 204 | 205 | self.layer1 = EnsembleLinearGPU(hidden_size, n_hidden, n_ensemble) 206 | self.layer2 = EnsembleLinearGPU(n_hidden, n_hidden, n_ensemble) 207 | self.layer3 = EnsembleLinearGPU(n_hidden, n_actions, n_ensemble) 208 | 209 | def forward(self, obs): 210 | x = obs 211 | x = self.bn1(self.lr(self.conv1(x))) 212 | x = self.bn2(self.lr(self.conv2(x))) 213 | x = self.bn3(self.lr(self.conv3(x))) 214 | x = self.bn4(self.lr(self.conv4(x))) 215 | x = x.view(x.size(0), -1) # flatten 216 | x = self.dropout(x) 217 | out = self.lr(self.lin1(x)) 218 | 219 | h = F.relu(self.layer1(out)) 220 | h = F.relu(self.layer2(h)) 221 | a = self.layer3(h) 222 | return a 223 | 224 | 225 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/envs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gym 4 | from gym.wrappers import TimeLimit 5 | import numpy as np 6 | import torch 7 | from gym.spaces.box import Box 8 | import pickle 9 | 10 | from baselines import bench 11 | from baselines.common.atari_wrappers import make_atari, wrap_deepmind, WarpFrame, ClipRewardEnv, FrameStack, ScaledFloatFrame 12 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 13 | from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv 14 | from baselines.common.vec_env.vec_normalize import \ 15 | VecNormalize as VecNormalize_ 16 | from baselines.common.retro_wrappers import make_retro #, wrap_deepmind_retro 17 | 18 | from dril.a2c_ppo_acktr.stable_baselines.base_vec_env import VecEnvWrapper 19 | from dril.a2c_ppo_acktr.stable_baselines.running_mean_std import RunningMeanStd 20 | 21 | try: 22 | import dm_control2gym 23 | except ImportError: 24 | pass 25 | 26 | try: 27 | import pybullet_envs 28 | except ImportError: 29 | pass 30 | 31 | env_hyperparam = ['BipedalWalkerHardcore-v2', 'BipedalWalker-v2',\ 32 | 'HalfCheetahBulletEnv-v0', 'HopperBulletEnv-v0',\ 33 | 'HumanoidBulletEnv-v0', 'MinitaurBulletEnv-v0',\ 34 | 'MinitaurBulletDuckEnv-v0', 'Walker2DBulletEnv-v0',\ 35 | 'AntBulletEnv-v0', 'LunarLanderContinuous-v2', 36 | 'CartPole-v1','Acrobot-v1', 'Pendulum-v0', 'MountainCarContinuous-v0', 37 | 'CartPoleContinuousBulletEnv-v0','ReacherBulletEnv-v0'] 38 | 39 | retro_envs = ['SuperMarioKart-Snes', 'StreetFighterIISpecialChampionEdition-Genesis',\ 40 | 'AyrtonSennasSuperMonacoGPII-Genesis'] 41 | 42 | def make_env(env_id, seed, rank, log_dir, allow_early_resets, time=False, max_steps=None): 43 | def _thunk(): 44 | if env_id.startswith("dm"): 45 | _, domain, task = env_id.split('.') 46 | env = dm_control2gym.make(domain_name=domain, task_name=task) 47 | elif env_id in ['duckietown']: 48 | from a2c_ppo_acktr.duckietown.env import launch_env 49 | from a2c_ppo_acktr.duckietown.wrappers import NormalizeWrapper, ImgWrapper,\ 50 | DtRewardWrapper, ActionWrapper, ResizeWrapper 51 | from a2c_ppo_acktr.duckietown.teacher import 
PurePursuitExpert 52 | env = launch_env() 53 | env = ResizeWrapper(env) 54 | env = NormalizeWrapper(env) 55 | env = ImgWrapper(env) 56 | env = ActionWrapper(env) 57 | env = DtRewardWrapper(env) 58 | elif env_id in retro_envs: 59 | env = make_retro(game=env_id) 60 | #env = SuperMarioKartDiscretizer(env) 61 | else: 62 | env = gym.make(env_id) 63 | 64 | is_atari = hasattr(gym.envs, 'atari') and isinstance( 65 | env.unwrapped, gym.envs.atari.atari_env.AtariEnv) 66 | if is_atari: 67 | env = make_atari(env_id, max_episode_steps=max_steps) 68 | 69 | env.seed(seed + rank) 70 | 71 | #TODO: Figure out what todo here 72 | if is_atari: 73 | env = TimeLimitMask(env) 74 | 75 | if log_dir is not None: 76 | env = bench.Monitor( 77 | env, 78 | os.path.join(log_dir, str(rank)), 79 | allow_early_resets=allow_early_resets) 80 | 81 | if is_atari: 82 | if len(env.observation_space.shape) == 3: 83 | env = wrap_deepmind(env) 84 | elif env_id in retro_envs: 85 | if len(env.observation_space.shape) == 3: 86 | env = wrap_deepmind_retro(env, frame_stack=0) 87 | elif len(env.observation_space.shape) == 3: 88 | if env_id not in ['duckietown'] and env_id not in retro_envs: 89 | raise NotImplementedError( 90 | "CNN models work only for atari,\n" 91 | "please use a custom wrapper for a custom pixel input env.\n" 92 | "See wrap_deepmind for an example.") 93 | 94 | # If the input has shape (W,H,3), wrap for PyTorch convolutions 95 | if env_id not in ['duckietown']: 96 | obs_shape = env.observation_space.shape 97 | if len(obs_shape) == 3 and obs_shape[2] in [1, 3]: 98 | env = TransposeImage(env, op=[2, 0, 1]) 99 | 100 | if time: 101 | env = TimeFeatureWrapper(env) 102 | 103 | return env 104 | 105 | return _thunk 106 | 107 | def wrap_deepmind_retro(env, scale=True, frame_stack=0): 108 | """ 109 | Configure environment for retro games, using config similar to DeepMind-style Atari in wrap_deepmind 110 | """ 111 | env = WarpFrame(env, grayscale=False) 112 | env = ClipRewardEnv(env) 113 | if frame_stack > 1: 114 | env = FrameStack(env, frame_stack) 115 | if scale: 116 | env = ScaledFloatFrame(env) 117 | return env 118 | 119 | class SuperMarioKartDiscretizer(gym.ActionWrapper): 120 | """ 121 | Wrap a gym-retro environment and make it use discrete 122 | actions for the Sonic game. 
123 | """ 124 | def __init__(self, env): 125 | super(SuperMarioKartDiscretizer, self).__init__(env) 126 | buttons = ['B', 'Y', 'SELECT', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'A', 'X', 'L', 'R'] 127 | actions = [['B'], ['B', 'LEFT', 'R'], ['LEFT'], ['B', 'LEFT'], ['B', 'RIGHT'], ['B', 'DOWN', 'LEFT'], ['DOWN', 'RIGHT'], ['RIGHT'], ['DOWN', 'LEFT'], ['RIGHT', 'A'], ['A'], [], ['B', 'R'], ['LEFT', 'A'], ['B', 'UP', 'RIGHT'], ['B', 'RIGHT', 'R'], ['B', 'DOWN'], ['B', 'DOWN', 'RIGHT'], ['B', 'UP', 'LEFT'], ['DOWN', 'RIGHT', 'A'], ['B', 'UP', 'LEFT', 'R'], ['B', 'RIGHT', 'A'], ['B', 'LEFT', 'A'], ['DOWN', 'LEFT', 'A'], ['B', 'A'], ['R']] 128 | self._actions = [] 129 | for action in actions: 130 | arr = np.array([False] * 12) 131 | for button in action: 132 | arr[buttons.index(button)] = True 133 | self._actions.append(arr) 134 | self.action_space = gym.spaces.Discrete(len(self._actions)) 135 | 136 | def action(self, a): # pylint: disable=W0221 137 | try: 138 | assert(len(a) == 1) 139 | return self._actions[a[0]].copy() 140 | except: 141 | return self._actions[a].copy() 142 | 143 | 144 | #TODO: Set max_steps as a hyperparameter 145 | def make_vec_envs(env_name, 146 | seed, 147 | num_processes, 148 | gamma, 149 | log_dir, 150 | device, 151 | allow_early_resets, 152 | max_steps=100000, 153 | num_frame_stack=None, 154 | stats_path=None, 155 | hyperparams=None, 156 | training=False, 157 | norm_obs=False, 158 | time=False, 159 | use_obs_norm=False): 160 | 161 | envs = [ 162 | make_env(env_name, seed, i, log_dir, allow_early_resets, time=time, max_steps=max_steps) 163 | for i in range(num_processes) 164 | ] 165 | 166 | if len(envs) > 1: 167 | envs = ShmemVecEnv(envs, context='fork') 168 | else: 169 | envs = DummyVecEnv(envs) 170 | 171 | if env_name in env_hyperparam and hyperparams is not None: 172 | if stats_path is not None: 173 | if hyperparams['normalize']: 174 | print("Loading running average") 175 | print("with params: {}".format(hyperparams['normalize_kwargs'])) 176 | envs = VecNormalizeBullet(envs, training=False, **hyperparams['normalize_kwargs']) 177 | envs.load_running_average(stats_path) 178 | else: 179 | if len(envs.observation_space.shape) == 1: 180 | if gamma is None: 181 | envs = VecNormalize(envs, ret=False, ob=use_obs_norm) 182 | else: 183 | envs = VecNormalize(envs, gamma=gamma, ob=use_obs_norm) 184 | 185 | envs = VecPyTorch(envs, device) 186 | 187 | if env_name not in ['duckietown']: 188 | if num_frame_stack is not None: 189 | envs = VecPyTorchFrameStack(envs, num_frame_stack, device) 190 | elif len(envs.observation_space.shape) == 3: 191 | envs = VecPyTorchFrameStack(envs, 4, device) 192 | 193 | return envs 194 | 195 | 196 | # Checks whether done was caused my timit limits or not 197 | class TimeLimitMask(gym.Wrapper): 198 | def step(self, action): 199 | obs, rew, done, info = self.env.step(action) 200 | if done and self.env._max_episode_steps == self.env._elapsed_steps: 201 | info['bad_transition'] = True 202 | 203 | return obs, rew, done, info 204 | 205 | def reset(self, **kwargs): 206 | return self.env.reset(**kwargs) 207 | 208 | 209 | # Can be used to test recurrent policies for Reacher-v2 210 | class MaskGoal(gym.ObservationWrapper): 211 | def observation(self, observation): 212 | if self.env._elapsed_steps > 0: 213 | observation[-2:] = 0 214 | return observation 215 | 216 | 217 | class TransposeObs(gym.ObservationWrapper): 218 | def __init__(self, env=None): 219 | """ 220 | Transpose observation space (base class) 221 | """ 222 | super(TransposeObs, self).__init__(env) 223 
| 224 | 225 | class TransposeImage(TransposeObs): 226 | def __init__(self, env=None, op=[2, 0, 1]): 227 | """ 228 | Transpose observation space for images 229 | """ 230 | super(TransposeImage, self).__init__(env) 231 | assert len(op) == 3, "Error: Operation, " + str(op) + ", must be dim3" 232 | self.op = op 233 | obs_shape = self.observation_space.shape 234 | self.observation_space = Box( 235 | self.observation_space.low[0, 0, 0], 236 | self.observation_space.high[0, 0, 0], [ 237 | obs_shape[self.op[0]], obs_shape[self.op[1]], 238 | obs_shape[self.op[2]] 239 | ], 240 | dtype=self.observation_space.dtype) 241 | 242 | def observation(self, ob): 243 | return ob.transpose(self.op[0], self.op[1], self.op[2]) 244 | 245 | 246 | class VecPyTorch(VecEnvWrapper): 247 | def __init__(self, venv, device): 248 | """Return only every `skip`-th frame""" 249 | super(VecPyTorch, self).__init__(venv) 250 | self.device = device 251 | # TODO: Fix data types 252 | 253 | def reset(self): 254 | obs = self.venv.reset() 255 | obs = torch.from_numpy(obs).float().to(self.device) 256 | return obs 257 | 258 | def step_async(self, actions): 259 | if isinstance(actions, torch.LongTensor): 260 | # Squeeze the dimension for discrete actions 261 | actions = actions.squeeze(1) 262 | actions = actions.cpu().numpy() 263 | self.venv.step_async(actions) 264 | 265 | def step_wait(self): 266 | obs, reward, done, info = self.venv.step_wait() 267 | obs = torch.from_numpy(obs).float().to(self.device) 268 | reward = torch.from_numpy(reward).unsqueeze(dim=1).float() 269 | return obs, reward, done, info 270 | 271 | 272 | class VecNormalize(VecNormalize_): 273 | def __init__(self, *args, **kwargs): 274 | super(VecNormalize, self).__init__(*args, **kwargs) 275 | self.training = True 276 | 277 | def _obfilt(self, obs, update=True): 278 | if self.ob_rms: 279 | if self.training and update: 280 | self.ob_rms.update(obs) 281 | obs = np.clip((obs - self.ob_rms.mean) / 282 | np.sqrt(self.ob_rms.var + self.epsilon), 283 | -self.clipob, self.clipob) 284 | return obs 285 | else: 286 | return obs 287 | 288 | def train(self): 289 | self.training = True 290 | 291 | def eval(self): 292 | self.training = False 293 | 294 | 295 | # Derived from 296 | # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_frame_stack.py 297 | class VecPyTorchFrameStack(VecEnvWrapper): 298 | def __init__(self, venv, nstack, device=None): 299 | self.venv = venv 300 | self.nstack = nstack 301 | 302 | wos = venv.observation_space # wrapped ob space 303 | self.shape_dim0 = wos.shape[0] 304 | 305 | low = np.repeat(wos.low, self.nstack, axis=0) 306 | high = np.repeat(wos.high, self.nstack, axis=0) 307 | 308 | if device is None: 309 | device = torch.device('cpu') 310 | self.stacked_obs = torch.zeros((venv.num_envs, ) + 311 | low.shape).to(device) 312 | 313 | observation_space = gym.spaces.Box( 314 | low=low, high=high, dtype=venv.observation_space.dtype) 315 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 316 | 317 | def step_wait(self): 318 | obs, rews, news, infos = self.venv.step_wait() 319 | #self.stacked_obs[:, :-self.shape_dim0] = \ 320 | # self.stacked_obs[:, self.shape_dim0:] 321 | self.stacked_obs[:, :-self.shape_dim0] = \ 322 | self.stacked_obs[:, self.shape_dim0:].clone() 323 | for (i, new) in enumerate(news): 324 | if new: 325 | self.stacked_obs[i] = 0 326 | self.stacked_obs[:, -self.shape_dim0:] = obs 327 | return self.stacked_obs, rews, news, infos 328 | 329 | def reset(self): 330 | obs = self.venv.reset() 331 | if 
torch.backends.cudnn.deterministic: 332 | self.stacked_obs = torch.zeros(self.stacked_obs.shape) 333 | else: 334 | self.stacked_obs.zero_() 335 | self.stacked_obs[:, -self.shape_dim0:] = obs 336 | return self.stacked_obs 337 | 338 | def close(self): 339 | self.venv.close() 340 | 341 | 342 | # Code taken from stable-baselines 343 | class VecNormalizeBullet(VecEnvWrapper): 344 | """ 345 | A moving average, normalizing wrapper for vectorized environment. 346 | has support for saving/loading moving average, 347 | :param venv: (VecEnv) the vectorized environment to wrap 348 | :param training: (bool) Whether to update or not the moving average 349 | :param norm_obs: (bool) Whether to normalize observation or not (default: True) 350 | :param norm_reward: (bool) Whether to normalize rewards or not (default: True) 351 | :param clip_obs: (float) Max absolute value for observation 352 | :param clip_reward: (float) Max value absolute for discounted reward 353 | :param gamma: (float) discount factor 354 | :param epsilon: (float) To avoid division by zero 355 | """ 356 | 357 | def __init__(self, venv, training=True, norm_obs=True, norm_reward=False, 358 | clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8): 359 | VecEnvWrapper.__init__(self, venv) 360 | self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) 361 | self.ret_rms = RunningMeanStd(shape=()) 362 | self.clip_obs = clip_obs 363 | self.clip_reward = clip_reward 364 | # Returns: discounted rewards 365 | self.ret = np.zeros(self.num_envs) 366 | self.gamma = gamma 367 | self.epsilon = epsilon 368 | self.training = training 369 | self.norm_obs = norm_obs 370 | self.norm_reward = norm_reward 371 | self.old_obs = np.array([]) 372 | 373 | def step_wait(self): 374 | """ 375 | Apply sequence of actions to sequence of environments 376 | actions -> (observations, rewards, news) 377 | where 'news' is a boolean vector indicating whether each element is new. 
378 | """ 379 | obs, rews, news, infos = self.venv.step_wait() 380 | self.ret = self.ret * self.gamma + rews 381 | if isinstance(self.venv.envs[0], TimeFeatureWrapper): 382 | # Remove index corresponding to time 383 | self.old_obs = obs[:,:-1] 384 | else: 385 | self.old_obs = obs 386 | 387 | obs = self._normalize_observation(obs) 388 | if self.norm_reward: 389 | if self.training: 390 | self.ret_rms.update(self.ret) 391 | rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) 392 | self.ret[news] = 0 393 | return obs, rews, news, infos 394 | 395 | def _normalize_observation(self, obs): 396 | """ 397 | :param obs: (numpy tensor) 398 | """ 399 | if self.norm_obs: 400 | if self.training: 401 | self.obs_rms.update(obs) 402 | obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs) 403 | return obs 404 | else: 405 | return obs 406 | 407 | def get_original_obs(self): 408 | """ 409 | returns the unnormalized observation 410 | :return: (numpy float) 411 | """ 412 | return self.old_obs 413 | 414 | def reset(self): 415 | """ 416 | Reset all environments 417 | """ 418 | obs = self.venv.reset() 419 | if len(np.array(obs).shape) == 1: # for when num_cpu is 1 420 | #self.old_obs = [obs] 421 | if isinstance(self.venv.envs[0], TimeFeatureWrapper): 422 | # Remove index corresponding to time 423 | self.old_obs = [obs[:,:-1]] 424 | else: 425 | self.old_obs = [obs] 426 | else: 427 | #self.old_obs = obs 428 | if isinstance(self.venv.envs[0], TimeFeatureWrapper): 429 | # Remove index corresponding to time 430 | self.old_obs = obs[:,:-1] 431 | else: 432 | self.old_obs = obs 433 | 434 | self.ret = np.zeros(self.num_envs) 435 | return self._normalize_observation(obs) 436 | 437 | def save_running_average(self, path): 438 | """ 439 | :param path: (str) path to log dir 440 | """ 441 | for rms, name in zip([self.obs_rms], ['obs_rms']): 442 | with open("{}/{}.pkl".format(path, name), 'wb') as file_handler: 443 | pickle.dump(rms, file_handler) 444 | 445 | def load_running_average(self, path): 446 | """ 447 | :param path: (str) path to log dir 448 | """ 449 | #for name in ['obs_rms', 'ret_rms']: 450 | for name in ['obs_rms']: 451 | with open("{}/{}.pkl".format(path, name), 'rb') as file_handler: 452 | setattr(self, name, pickle.load(file_handler)) 453 | 454 | 455 | # Code taken from stable-baslines 456 | class TimeFeatureWrapper(gym.Wrapper): 457 | """ 458 | Add remaining time to observation space for fixed length episodes. 459 | See https://arxiv.org/abs/1712.00378 and https://github.com/aravindr93/mjrl/issues/13. 460 | :param env: (gym.Env) 461 | :param max_steps: (int) Max number of steps of an episode 462 | if it is not wrapped in a TimeLimit object. 463 | :param test_mode: (bool) In test mode, the time feature is constant, 464 | equal to zero. This allow to check that the agent did not overfit this feature, 465 | learning a deterministic pre-defined sequence of actions. 
466 | """ 467 | def __init__(self, env, max_steps=1000, test_mode=False): 468 | assert isinstance(env.observation_space, gym.spaces.Box) 469 | # Add a time feature to the observation 470 | low, high = env.observation_space.low, env.observation_space.high 471 | low, high= np.concatenate((low, [0])), np.concatenate((high, [1.])) 472 | env.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32) 473 | 474 | super(TimeFeatureWrapper, self).__init__(env) 475 | 476 | if isinstance(env, TimeLimit): 477 | self._max_steps = env._max_episode_steps 478 | else: 479 | self._max_steps = max_steps 480 | self._current_step = 0 481 | self._test_mode = test_mode 482 | self.untimed_obs = None 483 | 484 | def reset(self): 485 | self._current_step = 0 486 | return self._get_obs(self.env.reset()) 487 | 488 | def step(self, action): 489 | self._current_step += 1 490 | obs, reward, done, info = self.env.step(action) 491 | return self._get_obs(obs), reward, done, info 492 | 493 | def get_original_obs(self): 494 | """ 495 | returns the unnormalized observation 496 | :return: (numpy float) 497 | """ 498 | return self.untimed_obs[np.newaxis,:] 499 | 500 | def _get_obs(self, obs): 501 | """ 502 | Concatenate the time feature to the current observation. 503 | :param obs: (np.ndarray) 504 | :return: (np.ndarray) 505 | """ 506 | self.untimed_obs = obs 507 | # Remaining time is more general 508 | time_feature = 1 - (self._current_step / self._max_steps) 509 | if self._test_mode: 510 | time_feature = 1.0 511 | # Optionnaly: concatenate [time_feature, time_feature ** 2] 512 | return np.concatenate((obs, [time_feature])) 513 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/expert_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | from torch.utils.data import DataLoader, TensorDataset 6 | 7 | class ExpertDataset: 8 | def __init__(self, demo_data_dir, env_name, num_trajs, seed, ensemble_shuffle_type): 9 | self.demo_data_dir = demo_data_dir 10 | self.env_name = env_name 11 | self.num_trajs = num_trajs 12 | self.seed = seed 13 | self.ensemble_shuffle_type = ensemble_shuffle_type 14 | 15 | 16 | def load_demo_data(self, training_data_split, batch_size, ensemble_size): 17 | obs_file = f'{self.demo_data_dir}/obs_{self.env_name}_seed={self.seed}_ntraj={self.num_trajs}.npy' 18 | acs_file = f'{self.demo_data_dir}/acs_{self.env_name}_seed={self.seed}_ntraj={self.num_trajs}.npy' 19 | 20 | print(f'loading: {obs_file}') 21 | obs = torch.from_numpy(np.load(obs_file)) 22 | acs = torch.from_numpy(np.load(acs_file)) 23 | perm = torch.randperm(obs.size(0)) 24 | obs = obs[perm] 25 | acs = acs[perm] 26 | 27 | n_train = int(obs.size(0)*training_data_split) 28 | obs_train = obs[:n_train] 29 | acs_train = acs[:n_train] 30 | obs_test = obs[n_train:] 31 | acs_test = acs[n_train:] 32 | 33 | if self.ensemble_shuffle_type == 'norm_shuffle' or ensemble_size is None: 34 | shuffle = True 35 | elif self.ensemble_shuffle_type == 'no_shuffle' and ensemble_size is not None: 36 | shuffle = False 37 | elif self.ensemble_shuffle_type == 'sample_w_replace' and ensemble_size is not None: 38 | print('***** sample_w_replace *****') 39 | # sample with replacement 40 | obs_train_resamp, acs_train_resamp = [], [] 41 | for k in range(n_train * ensemble_size): 42 | indx = random.randint(0, n_train - 1) 43 | obs_train_resamp.append(obs_train[indx]) 44 | acs_train_resamp.append(acs_train[indx]) 45 | 
obs_train = torch.stack(obs_train_resamp) 46 | acs_train = torch.stack(acs_train_resamp) 47 | shuffle = False 48 | 49 | tr_batch_size = min(batch_size, len(obs_train)) 50 | # If Droplast is False, insure that that dataset is divisible by 51 | # the number of polices in the ensemble 52 | tr_drop_last = (tr_batch_size!=len(obs_train)) 53 | if not tr_drop_last and ensemble_size is not None: 54 | tr_batch_size = int(ensemble_size * np.floor(tr_batch_size/ensemble_size)) 55 | obs_train = obs_train[:tr_batch_size] 56 | acs_train = acs_train[:tr_batch_size] 57 | trdata = DataLoader(TensorDataset(obs_train, acs_train),\ 58 | batch_size = tr_batch_size, shuffle=shuffle, drop_last=tr_drop_last) 59 | 60 | if len(obs_test) == 0: 61 | tedata = None 62 | else: 63 | te_batch_size = min(batch_size, len(obs_test)) 64 | # If Droplast is False, insure that that dataset is divisible by 65 | # the number of polices in the ensemble 66 | te_drop_last = (te_batch_size!=len(obs_test)) 67 | if not te_drop_last and ensemble_size is not None: 68 | te_batch_size = int(ensemble_size * np.floor(te_batch_size/ensemble_size)) 69 | obs_test = obs_test[:te_batch_size] 70 | acs_test = acs_test[:te_batch_size] 71 | tedata = DataLoader(TensorDataset(obs_test, acs_test),\ 72 | batch_size = te_batch_size, shuffle=shuffle, drop_last=te_drop_last) 73 | return {'trdata':trdata, 'tedata': tedata} 74 | 75 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import os 8 | import random 9 | 10 | from a2c_ppo_acktr.distributions import Bernoulli, Categorical, DiagGaussian 11 | from a2c_ppo_acktr.utils import init 12 | 13 | # Convert weights from tensorflow to pytorch 14 | def copy_mlp_weights(baselines_model): 15 | model_params = baselines_model.get_parameters() 16 | 17 | params = { 18 | 'base.actor.0.weight':model_params['model/pi_fc0/w:0'].T, 19 | 'base.actor.0.bias':model_params['model/pi_fc0/b:0'].squeeze(), 20 | 'base.actor.2.weight':model_params['model/pi_fc1/w:0'].T, 21 | 'base.actor.2.bias':model_params['model/pi_fc1/b:0'].squeeze(), 22 | 'base.critic.0.weight':model_params['model/vf_fc0/w:0'].T, 23 | 'base.critic.0.bias':model_params['model/vf_fc0/b:0'].squeeze(), 24 | 'base.critic.2.weight':model_params['model/vf_fc1/w:0'].T, 25 | 'base.critic.2.bias':model_params['model/vf_fc1/b:0'].squeeze(), 26 | 'base.critic_linear.weight':model_params['model/vf/w:0'].T, 27 | 'base.critic_linear.bias':model_params['model/vf/b:0'], 28 | 'dist.fc_mean.weight':model_params['model/pi/w:0'].T, 29 | 'dist.fc_mean.bias':model_params['model/pi/b:0'], 30 | 'dist.logstd._bias':model_params['model/pi/logstd:0'].T 31 | } 32 | 33 | for key in params.keys(): 34 | params[key] = torch.tensor(params[key]) 35 | return params 36 | 37 | def copy_cnn_weights(baselines_model): 38 | model_params = baselines_model.get_parameters() 39 | 40 | # Convert images to torch format 41 | def conv_to_torch(obs): 42 | obs = np.transpose(obs, (3, 2, 0, 1)) 43 | return obs 44 | 45 | params = { 46 | 'base.conv1.weight':conv_to_torch(model_params['model/c1/w:0']), 47 | 'base.conv1.bias':conv_to_torch(model_params['model/c1/b:0']).squeeze(), 48 | 'base.conv2.weight':conv_to_torch(model_params['model/c2/w:0']), 49 | 'base.conv2.bias':conv_to_torch(model_params['model/c2/b:0']).squeeze(), 50 | 
'base.conv3.weight':conv_to_torch(model_params['model/c3/w:0']), 51 | 'base.conv3.bias':conv_to_torch(model_params['model/c3/b:0']).squeeze(), 52 | 'base.fc1.weight': model_params['model/fc1/w:0'].T, 53 | 'base.fc1.bias': model_params['model/fc1/b:0'].squeeze(), 54 | 'base.critic_linear.weight': model_params['model/vf/w:0'].T, 55 | 'base.critic_linear.bias': model_params['model/vf/b:0'], 56 | 'dist.linear.weight': model_params['model/pi/w:0'].T, 57 | 'dist.linear.bias': model_params['model/pi/b:0'].squeeze() 58 | } 59 | 60 | for key in params.keys(): 61 | params[key] = torch.tensor(params[key]) 62 | return params 63 | 64 | 65 | class Flatten(nn.Module): 66 | def forward(self, x): 67 | return x.view(x.size(0), -1) 68 | 69 | 70 | class Policy(nn.Module): 71 | def __init__(self, obs_shape, action_space, base=None, base_kwargs=None, load_expert=None, 72 | env_name=None, rl_baseline_zoo_dir=None, expert_algo=None, normalize=True): 73 | super(Policy, self).__init__() 74 | 75 | #TODO: Pass these parameters in 76 | self.epsilon = 0.1 77 | self.dril = True 78 | 79 | if base_kwargs is None: 80 | base_kwargs = {} 81 | if base is None: 82 | if env_name in ['duckietown']: 83 | base = DuckieTownCNN 84 | elif len(obs_shape) == 3: 85 | base = CNNBase 86 | elif len(obs_shape) == 1: 87 | base = MLPBase 88 | else: 89 | raise NotImplementedError 90 | 91 | self.base = base(obs_shape[0], normalize=normalize, **base_kwargs) 92 | self.action_space = None 93 | if action_space.__class__.__name__ == "Discrete": 94 | num_outputs = action_space.n 95 | self.dist = Categorical(self.base.output_size, num_outputs) 96 | self.action_space = "Discrete" 97 | elif action_space.__class__.__name__ == "Box": 98 | num_outputs = action_space.shape[0] 99 | self.dist = DiagGaussian(self.base.output_size, num_outputs) 100 | self.action_space = "Box" 101 | elif action_space.__class__.__name__ == "MultiBinary": 102 | raise Exception('Error') 103 | else: 104 | raise NotImplementedError 105 | 106 | if load_expert == True and env_name not in ['duckietown', 'highway-v0']: 107 | print('[Loading Expert --- Base]') 108 | model_path = os.path.join(rl_baseline_zoo_dir, 'trained_agents', f'{expert_algo}') 109 | try: 110 | import mpi4py 111 | from stable_baselines import TRPO 112 | except ImportError: 113 | mpi4py = None 114 | DDPG, TRPO = None, None 115 | 116 | from stable_baselines import PPO2 117 | 118 | model_path = f'{model_path}/{env_name}.pkl' 119 | if env_name in ['AntBulletEnv-v0']: 120 | baselines_model = TRPO.load(model_path) 121 | else: 122 | baselines_model = PPO2.load(model_path) 123 | for key, value in baselines_model.get_parameters().items(): 124 | print(key, value.shape) 125 | 126 | if base.__name__ == 'CNNBase': 127 | print(['Loading CNNBase expert model']) 128 | params = copy_cnn_weights(baselines_model) 129 | elif load_expert == True and base.__name__ == 'MLPBase': 130 | print(['Loading MLPBase expert model']) 131 | params = copy_mlp_weights(baselines_model) 132 | 133 | #TODO: I am not sure what this is doing 134 | try: 135 | self.load_state_dict(params) 136 | self.obs_shape = obs_shape[0] 137 | except: 138 | self.base = base(obs_shape[0]+ 1, **base_kwargs) 139 | self.load_state_dict(params) 140 | self.obs_shape = obs_shape[0] +1 141 | 142 | 143 | @property 144 | def is_recurrent(self): 145 | return self.base.is_recurrent 146 | 147 | @property 148 | def recurrent_hidden_state_size(self): 149 | """Size of rnn_hx.""" 150 | return self.base.recurrent_hidden_state_size 151 | 152 | def forward(self, inputs, rnn_hxs, masks): 153 | 
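# forward() is not used directly by this codebase: rollout collection goes
# through act()/get_value() and the policy updates go through
# evaluate_actions(), so forward() simply raises below.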
raise NotImplementedError 154 | 155 | def get_action(self, inputs, deterministic=False): 156 | value, actor_features, rnn_hxs = self.base(inputs, None, None)#, rnn_hxs, masks) 157 | if self.action_space == "Discrete": 158 | return self.dist.get_logits(actor_features) 159 | elif self.action_space == "MultiBinary": 160 | return self.dist.get_logits(actor_features) 161 | elif self.action_space == "Box": 162 | return self.dist.get_mean(actor_features) 163 | 164 | 165 | 166 | def act(self, inputs, rnn_hxs, masks, deterministic=False): 167 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 168 | dist = self.dist(actor_features) 169 | 170 | if (self.dril and random.random() <= self.epsilon) or deterministic: 171 | action = dist.mode() 172 | else: 173 | action = dist.sample() 174 | 175 | action_log_probs = dist.log_probs(action) 176 | dist_entropy = dist.entropy().mean() 177 | 178 | return value, action, action_log_probs, rnn_hxs 179 | 180 | def get_value(self, inputs, rnn_hxs, masks): 181 | value, _, _ = self.base(inputs, rnn_hxs, masks) 182 | return value 183 | 184 | def evaluate_actions(self, inputs, rnn_hxs, masks, action): 185 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 186 | dist = self.dist(actor_features) 187 | 188 | action_log_probs = dist.log_probs(action) 189 | dist_entropy = dist.entropy().mean() 190 | 191 | return value, action_log_probs, dist_entropy, rnn_hxs 192 | 193 | 194 | class NNBase(nn.Module): 195 | def __init__(self, recurrent, recurrent_input_size, hidden_size): 196 | super(NNBase, self).__init__() 197 | 198 | self._hidden_size = hidden_size 199 | self._recurrent = recurrent 200 | 201 | if recurrent: 202 | self.gru = nn.GRU(recurrent_input_size, hidden_size) 203 | for name, param in self.gru.named_parameters(): 204 | if 'bias' in name: 205 | nn.init.constant_(param, 0) 206 | elif 'weight' in name: 207 | nn.init.orthogonal_(param) 208 | 209 | @property 210 | def is_recurrent(self): 211 | return self._recurrent 212 | 213 | @property 214 | def recurrent_hidden_state_size(self): 215 | if self._recurrent: 216 | return self._hidden_size 217 | return 1 218 | 219 | @property 220 | def output_size(self): 221 | return self._hidden_size 222 | 223 | def _forward_gru(self, x, hxs, masks): 224 | if x.size(0) == hxs.size(0): 225 | x, hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0)) 226 | x = x.squeeze(0) 227 | hxs = hxs.squeeze(0) 228 | else: 229 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 230 | N = hxs.size(0) 231 | T = int(x.size(0) / N) 232 | 233 | # unflatten 234 | x = x.view(T, N, x.size(1)) 235 | 236 | # Same deal with masks 237 | masks = masks.view(T, N) 238 | 239 | # Let's figure out which steps in the sequence have a zero for any agent 240 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 241 | has_zeros = ((masks[1:] == 0.0) \ 242 | .any(dim=-1) 243 | .nonzero() 244 | .squeeze() 245 | .cpu()) 246 | 247 | # +1 to correct the masks[1:] 248 | if has_zeros.dim() == 0: 249 | # Deal with scalar 250 | has_zeros = [has_zeros.item() + 1] 251 | else: 252 | has_zeros = (has_zeros + 1).numpy().tolist() 253 | 254 | # add t=0 and t=T to the list 255 | has_zeros = [0] + has_zeros + [T] 256 | 257 | hxs = hxs.unsqueeze(0) 258 | outputs = [] 259 | for i in range(len(has_zeros) - 1): 260 | # We can now process steps that don't have any zeros in masks together! 
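# (each [start_idx, end_idx) segment handled below contains no zero masks,
#  so the GRU can consume the whole segment in a single call and the hidden
#  state only has to be masked once, at the segment boundary)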
261 | # This is much faster 262 | start_idx = has_zeros[i] 263 | end_idx = has_zeros[i + 1] 264 | 265 | rnn_scores, hxs = self.gru( 266 | x[start_idx:end_idx], 267 | hxs * masks[start_idx].view(1, -1, 1)) 268 | 269 | outputs.append(rnn_scores) 270 | 271 | # assert len(outputs) == T 272 | # x is a (T, N, -1) tensor 273 | x = torch.cat(outputs, dim=0) 274 | # flatten 275 | x = x.view(T * N, -1) 276 | hxs = hxs.squeeze(0) 277 | 278 | return x, hxs 279 | 280 | class CNNBase(NNBase): 281 | def __init__(self, num_inputs, recurrent=False, hidden_size=512, normalize=True): 282 | super(CNNBase, self).__init__(recurrent, hidden_size, hidden_size) 283 | 284 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 285 | constant_(x, 0), nn.init.calculate_gain('relu')) 286 | 287 | self.conv1 = (nn.Conv2d(num_inputs, 32, 8, stride=4)) 288 | self.conv2 = (nn.Conv2d(32, 64, 4, stride=2)) 289 | self.conv3 = (nn.Conv2d(64, 64, 3, stride=1)) 290 | self.fc1 = (nn.Linear(32*7*7*2, hidden_size)) 291 | self.relu = nn.ReLU() 292 | self.flatten = Flatten() 293 | self.critic_linear = (nn.Linear(hidden_size, 1)) 294 | 295 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 296 | constant_(x, 0)) 297 | 298 | self.critic_linear = (nn.Linear(hidden_size, 1)) 299 | self.normalize = normalize 300 | 301 | self.train() 302 | 303 | def forward(self, inputs, rnn_hxs, masks): 304 | if self.normalize: 305 | x = (inputs/ 255.0) 306 | else: 307 | x = inputs 308 | x = self.relu(self.conv1(x)) 309 | x = self.relu(self.conv2(x)) 310 | x = self.relu(self.conv3(x)) 311 | 312 | x = x.permute(0, 2, 3, 1).contiguous() 313 | x = self.flatten(x) 314 | x = self.relu(self.fc1(x)) 315 | 316 | if self.is_recurrent: 317 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 318 | 319 | return self.critic_linear(x), x, rnn_hxs 320 | 321 | class MLPBase(NNBase): 322 | def __init__(self, num_inputs, recurrent=False, hidden_size=64, normalize=None): 323 | super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size) 324 | 325 | if recurrent: 326 | num_inputs = hidden_size 327 | 328 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 329 | constant_(x, 0), np.sqrt(2)) 330 | 331 | self.actor = nn.Sequential( 332 | init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 333 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 334 | 335 | self.critic = nn.Sequential( 336 | init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 337 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 338 | 339 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 340 | 341 | self.train() 342 | 343 | def forward(self, inputs, rnn_hxs, masks): 344 | x = inputs 345 | 346 | if self.is_recurrent: 347 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 348 | 349 | hidden_critic = self.critic(x) 350 | hidden_actor = self.actor(x) 351 | 352 | return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs 353 | 354 | # https://github.com/duckietown/gym-duckietown/blob/master/learning/imitation/iil-dagger/model/squeezenet.py 355 | class DuckieTownCNN(NNBase): 356 | def __init__(self, num_inputs, recurrent=False, hidden_size=512): 357 | super(DuckieTownCNN, self).__init__(recurrent, hidden_size, hidden_size) 358 | 359 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
360 | constant_(x, 0), nn.init.calculate_gain('relu')) 361 | 362 | flat_size = 32 * 9 * 14 363 | 364 | self.lr = nn.LeakyReLU() 365 | 366 | self.conv1 = nn.Conv2d(3, 32, 8, stride=2) 367 | self.conv2 = nn.Conv2d(32, 32, 4, stride=2) 368 | self.conv3 = nn.Conv2d(32, 32, 4, stride=2) 369 | self.conv4 = nn.Conv2d(32, 32, 4, stride=1) 370 | 371 | self.bn1 = nn.BatchNorm2d(32) 372 | self.bn2 = nn.BatchNorm2d(32) 373 | self.bn3 = nn.BatchNorm2d(32) 374 | self.bn4 = nn.BatchNorm2d(32) 375 | 376 | self.dropout = nn.Dropout(.5) 377 | 378 | self.lin1 = nn.Linear(flat_size, hidden_size) 379 | 380 | 381 | self.actor = nn.Sequential( 382 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh(), 383 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 384 | 385 | self.critic = nn.Sequential( 386 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh(), 387 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 388 | 389 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 390 | 391 | self.train() 392 | 393 | def forward(self, inputs, rnn_hxs, masks): 394 | x = (inputs/255.0) 395 | x = self.bn1(self.lr(self.conv1(x))) 396 | x = self.bn2(self.lr(self.conv2(x))) 397 | x = self.bn3(self.lr(self.conv3(x))) 398 | x = self.bn4(self.lr(self.conv4(x))) 399 | x = x.view(x.size(0), -1) # flatten 400 | x = self.dropout(x) 401 | x = self.lr(self.lin1(x)) 402 | 403 | hidden_critic = self.critic(x) 404 | hidden_actor = self.actor(x) 405 | 406 | return self.critic_linear(x), x, rnn_hxs 407 | 408 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/retro/.gitignore: -------------------------------------------------------------------------------- 1 | log 2 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/retro/README.md: -------------------------------------------------------------------------------- 1 | Below are the steps to add game roms to retro 2 | 3 | 1. Download the rom files 4 | 2. Place the rom files in this folder 5 | 3. run: python3 -m retro.import ./ 6 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/retro/pygame_controller.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pygame 3 | from pygame.locals import * 4 | 5 | # The individual event object that is returned 6 | # This serves as a proxy to pygame's event object 7 | # and the key field is one of the strings in the button list listed below 8 | # in the InputManager's constructor 9 | # This comment is actually longer than the class definition itself. 10 | class InputEvent: 11 | def __init__(self, key, down): 12 | self.key = key 13 | self.down = down 14 | self.up = not down 15 | 16 | # This is where all the magic happens 17 | class InputManager: 18 | def __init__(self): 19 | 20 | self.init_joystick() 21 | 22 | # I like SNES button designations. My decision to use them are arbitrary 23 | # and are only used internally to consistently identify buttons. 24 | # Or you could pretend that these were XBox button layout designations. 25 | # Either way. Up to you. You could change them altogether if you want. 26 | self.buttons = ['up', 'down', 'left', 'right', 'A', 'B', 'X', 'Y', 'L', 'R'] 27 | 28 | # If you would like there to be a keyboard fallback configuration, fill those out 29 | # here in this mapping. 
If you wanted the keyboard keys to be configurable, you could 30 | # probably copy the same sort of system I use for the joystick configuration for the 31 | # keyboard. But that's getting fancy for a simple tutorial. 32 | self.key_map = { 33 | K_UP : 'up', 34 | K_DOWN : 'down', 35 | K_LEFT : 'left', 36 | K_RIGHT : 'right', 37 | K_a : 'A', 38 | K_b : 'B', 39 | K_x : 'X', 40 | K_y : 'Y', 41 | K_l : 'L', 42 | K_r : 'R', 43 | } 44 | 45 | # This dictionary will tell you which logical buttons are pressed, whether it's 46 | # via the keyboard or joystick 47 | self.keys_pressed = {} 48 | for button in self.buttons: 49 | self.keys_pressed[button] = False 50 | 51 | # This is a list of joystick configurations that will be populated during the 52 | # configuration phase 53 | self.joystick_config = {} 54 | 55 | # Quitting the window is raised as an input event. And typically you also want 56 | # that event raised when the user presses escape which is not something you 57 | # want to configure on the joystick. That's why it's wired separately from 58 | # everything else. When escape is pressed or the user closes the window via its 59 | # chrome, this flag is set to True. 60 | self.quit_attempt = False 61 | 62 | # button is a string of the designation in the list above 63 | def is_pressed(self, button): 64 | return self.keys_pressed[button] 65 | 66 | # This will pump the pygame events. If this is not called every frame, 67 | # then the PyGame window will start to lock up. 68 | # This is basically a proxy method for pygame's event pump and will likewise return 69 | # a list of event proxies. 70 | def get_events(self): 71 | events = [] 72 | for event in pygame.event.get(): 73 | if event.type == QUIT or (event.type == KEYDOWN and event.key == K_ESCAPE): 74 | self.quit_attempt = True 75 | 76 | # This is where the keyboard events are checked 77 | if event.type == KEYDOWN or event.type == KEYUP: 78 | key_pushed_down = event.type == KEYDOWN 79 | button = self.key_map.get(event.key) 80 | if button != None: 81 | events.append(InputEvent(button, key_pushed_down)) 82 | self.keys_pressed[button] = key_pushed_down 83 | 84 | # And this is where each configured button is checked... 85 | for button in self.buttons: 86 | 87 | # determine what something like "Y" actually means in terms of the joystick 88 | config = self.joystick_config.get(button) 89 | if config != None: 90 | 91 | # if the button is configured to an actual button... 92 | if config[0] == 'is_button': 93 | pushed = self.joystick.get_button(config[1]) 94 | if pushed != self.keys_pressed[button]: 95 | events.append(InputEvent(button, pushed)) 96 | self.keys_pressed[button] = pushed 97 | 98 | # if the button is configured to a hat direction... 99 | elif config[0] == 'is_hat': 100 | status = self.joystick.get_hat(config[1]) 101 | if config[2] == 'x': 102 | amount = status[0] 103 | else: 104 | amount = status[1] 105 | pushed = amount == config[3] 106 | if pushed != self.keys_pressed[button]: 107 | events.append(InputEvent(button, pushed)) 108 | self.keys_pressed[button] = pushed 109 | 110 | # if the button is configured to a trackball direction... 
111 | elif config[0] == 'is_ball': 112 | status = self.joystick.get_ball(config[1]) 113 | if config[2] == 'x': 114 | amount = status[0] 115 | else: 116 | amount = status[1] 117 | if config[3] == 1: 118 | pushed = amount > 0.5 119 | else: 120 | pushed = amount < -0.5 121 | if pushed != self.keys_pressed[button]: 122 | events.append(InputEvent(button, pushed)) 123 | self.keys_pressed[button] = pushed 124 | 125 | # if the button is configured to an axis direction... 126 | elif config[0] == 'is_axis': 127 | status = self.joystick.get_axis(config[1]) 128 | if config[2] == 1: 129 | pushed = status > 0.5 130 | else: 131 | pushed = status < -0.5 132 | if pushed != self.keys_pressed[button]: 133 | events.append(InputEvent(button, pushed)) 134 | self.keys_pressed[button] = pushed 135 | 136 | return events 137 | 138 | # Any button that is currently pressed on the game pad will be toggled 139 | # to the button designation passed in as the 'button' parameter. 140 | # (as long as it isn't already in use for a different button) 141 | def configure_button(self, button): 142 | 143 | js = self.joystick 144 | 145 | # check buttons for activity... 146 | for button_index in range(js.get_numbuttons()): 147 | button_pushed = js.get_button(button_index) 148 | if button_pushed and not self.is_button_used(button_index): 149 | self.joystick_config[button] = ('is_button', button_index) 150 | return True 151 | 152 | # check hats for activity... 153 | # (hats are the basic direction pads) 154 | for hat_index in range(js.get_numhats()): 155 | hat_status = js.get_hat(hat_index) 156 | if hat_status[0] < -.5 and not self.is_hat_used(hat_index, 'x', -1): 157 | self.joystick_config[button] = ('is_hat', hat_index, 'x', -1) 158 | return True 159 | elif hat_status[0] > .5 and not self.is_hat_used(hat_index, 'x', 1): 160 | self.joystick_config[button] = ('is_hat', hat_index, 'x', 1) 161 | return True 162 | if hat_status[1] < -.5 and not self.is_hat_used(hat_index, 'y', -1): 163 | self.joystick_config[button] = ('is_hat', hat_index, 'y', -1) 164 | return True 165 | elif hat_status[1] > .5 and not self.is_hat_used(hat_index, 'y', 1): 166 | self.joystick_config[button] = ('is_hat', hat_index, 'y', 1) 167 | return True 168 | 169 | # check trackballs for activity... 170 | # (I don't actually have a gamepad with a trackball on it. So this code 171 | # is completely untested! Let me know if it works and is typo-free.) 172 | for ball_index in range(js.get_numballs()): 173 | ball_status = js.get_ball(ball_index) 174 | if ball_status[0] < -.5 and not self.is_ball_used(ball_index, 'x', -1): 175 | self.joystick_config[button] = ('is_ball', ball_index, 'x', -1) 176 | return True 177 | elif ball_status[0] > .5 and not self.is_ball_used(ball_index, 'x', 1): 178 | self.joystick_config[button] = ('is_ball', ball_index, 'x', 1) 179 | return True 180 | if ball_status[1] < -.5 and not self.is_ball_used(ball_index, 'y', -1): 181 | self.joystick_config[button] = ('is_ball', ball_index, 'y', -1) 182 | return True 183 | elif ball_status[1] > .5 and not self.is_ball_used(ball_index, 'y', 1): 184 | self.joystick_config[button] = ('is_ball', ball_index, 'y', 1) 185 | return True 186 | 187 | # check axes for activity... 188 | # (that's plural of axis. Not a tree chopping tool. Although a USB Axe would be awesome!) 
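# (Note: joystick.get_axis() returns a float in [-1, 1]; the 0.5 cutoff used
# below treats a stick deflected at least halfway as a digital press in that
# direction, which also acts as a dead-zone against small drift.)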
189 | for axis_index in range(js.get_numaxes()): 190 | axis_status = js.get_axis(axis_index) 191 | if axis_status < -.5 and not self.is_axis_used(axis_index, -1): 192 | self.joystick_config[button] = ('is_axis', axis_index, -1) 193 | return True 194 | elif axis_status > .5 and not self.is_axis_used(axis_index, 1): 195 | self.joystick_config[button] = ('is_axis', axis_index, 1) 196 | return True 197 | 198 | return False 199 | 200 | # The following 4 methods are helper methods used by the above method 201 | # to determine if a particular button/axis/hat/trackball are already 202 | # configured to a particular button designation 203 | def is_button_used(self, button_index): 204 | for button in self.buttons: 205 | config = self.joystick_config.get(button) 206 | if config != None and config[0] == 'is_button' and config[1] == button_index: 207 | return True 208 | return False 209 | 210 | def is_hat_used(self, hat_index, axis, direction): 211 | for button in self.buttons: 212 | config = self.joystick_config.get(button) 213 | if config != None and config[0] == 'is_hat': 214 | if config[1] == hat_index and config[2] == axis and config[3] == direction: 215 | return True 216 | return False 217 | 218 | def is_ball_used(self, ball_index, axis, direction): 219 | for button in self.buttons: 220 | config = self.joystick_config.get(button) 221 | if config != None and config[0] == 'is_ball': 222 | if config[1] == ball_index and config[2] == axis and config[3] == direction: 223 | return True 224 | return False 225 | 226 | def is_axis_used(self, axis_index, direction): 227 | for button in self.buttons: 228 | config = self.joystick_config.get(button) 229 | if config != None and config[0] == 'is_axis': 230 | if config[1] == axis_index and config[2] == direction: 231 | return True 232 | return False 233 | 234 | # Set joystick information. 235 | # The joystick needs to be plugged in before this method is called (see main() method) 236 | def init_joystick(self): 237 | joystick = pygame.joystick.Joystick(0) 238 | joystick.init() 239 | self.joystick = joystick 240 | self.joystick_name = joystick.get_name() 241 | 242 | # A simple player object. This only keeps track of position. 243 | class Player: 244 | def __init__(self): 245 | self.x = 320 246 | self.y = 240 247 | self.speed = 4 248 | 249 | def move_left(self): 250 | self.x -= self.speed 251 | def move_right(self): 252 | self.x += self.speed 253 | def move_up(self): 254 | self.y -= self.speed 255 | def move_down(self): 256 | self.y += self.speed 257 | 258 | # The main method...duh! 259 | def main(): 260 | 261 | fps = 30 262 | 263 | print("Plug in a USB gamepad. Do it! Do it now! Press enter after you have done this.") 264 | wait_for_enter() 265 | 266 | pygame.init() 267 | 268 | num_joysticks = pygame.joystick.get_count() 269 | if num_joysticks < 1: 270 | print("You didn't plug in a joystick. FORSHAME!") 271 | return 272 | 273 | input_manager = InputManager() 274 | 275 | screen = pygame.display.set_mode((640, 480)) 276 | 277 | button_index = 0 278 | 279 | player = Player() 280 | 281 | 282 | # The main game loop 283 | while not input_manager.quit_attempt: 284 | start = time.time() 285 | 286 | screen.fill((0,0,0)) 287 | 288 | # There will be two phases to our "game". 289 | is_configured = button_index >= len(input_manager.buttons) 290 | 291 | # In the first phase, the user will be prompted to configure the joystick by pressing 292 | # the key that is indicated on the screen 293 | # You would probably do this in an input menu in your real game. 
294 | if not is_configured: 295 | success = configure_phase(screen, input_manager.buttons[button_index], input_manager) 296 | # if the user pressed a button and configured it... 297 | if success: 298 | # move on to the next button that needs to be configured 299 | button_index += 1 300 | 301 | # In the second phase, the user will control a "character" on the screen (which will 302 | # be represented by a simple blue ball) that obeys the directional commands, whether 303 | # it's from the joystick or the keyboard. 304 | else: 305 | interaction_phase(screen, player, input_manager) 306 | 307 | pygame.display.flip() 308 | 309 | # maintain frame rate 310 | difference = start - time.time() 311 | delay = 1.0 / fps - difference 312 | if delay > 0: 313 | time.sleep(delay) 314 | 315 | def configure_phase(screen, button, input_manager): 316 | 317 | # need to pump windows events otherwise the window will lock up and die 318 | input_manager.get_events() 319 | 320 | # configure_button looks at the state of ALL buttons pressed on the joystick 321 | # and will map the first pressed button it sees to the current button you pass 322 | # in here. 323 | success = input_manager.configure_button(button) 324 | 325 | # tell user which button to configure 326 | write_text(screen, "Press the " + button + " button", 100, 100) 327 | 328 | # If a joystick button was successfully configured, return True 329 | return success 330 | 331 | def interaction_phase(screen, player, input_manager): 332 | # I dunno. This doesn't do anything. But this is how 333 | # you would access key hit events and the like. 334 | # Ideal for "shooting a weapon" or "jump" sort of events 335 | for event in input_manager.get_events(): 336 | if event.key == 'A' and event.down: 337 | pass # weeeeeeee 338 | if event.key == 'X' and event.up: 339 | input_manager.quit_attempted = True 340 | 341 | # ...but for things like "move in this direction", you want 342 | # to know if a button is pressed and held 343 | 344 | if input_manager.is_pressed('left'): 345 | player.move_left() 346 | elif input_manager.is_pressed('right'): 347 | player.move_right() 348 | if input_manager.is_pressed('up'): 349 | player.move_up() 350 | elif input_manager.is_pressed('down'): 351 | player.move_down() 352 | 353 | # Draw the player 354 | pygame.draw.circle(screen, (0, 0, 255), (player.x, player.y), 20) 355 | 356 | # There was probably a more robust way of doing this. But 357 | # command line interaction was not the point of the tutorial. 358 | def wait_for_enter(): 359 | try: input() 360 | except: pass 361 | 362 | # This renders text on the game screen. 363 | # Also not the point of this tutorial. 364 | cached_text = {} 365 | cached_font = None 366 | def write_text(screen, text, x, y): 367 | global cached_text, cached_font 368 | image = cached_text.get(text) 369 | if image == None: 370 | if cached_font == None: 371 | cached_font = pygame.font.Font(pygame.font.get_default_font(), 12) 372 | image = cached_font.render(text, True, (255, 255, 255)) 373 | cached_text[text] = image 374 | screen.blit(image, (x, y - image.get_height())) 375 | 376 | # Kick things off. 377 | if __name__ == "__main__": 378 | main() 379 | 380 | # fin. 
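The `InputManager` above is reusable outside of this demo's `main()`. The sketch below shows the minimal calling pattern: configure each logical button once, then pump `get_events()` every frame and poll held buttons with `is_pressed()`. It is only a sketch: it assumes this file is importable as `pygame_controller` (as `retro_joystick.py` does) and that a gamepad is already plugged in; the `drive_with_gamepad` helper name, window size, and frame rate are arbitrary.

```
import time
import pygame
from pygame_controller import InputManager

def drive_with_gamepad(fps=30):
    pygame.init()
    pygame.display.set_mode((320, 240))    # a window is required for the keyboard fallback
    manager = InputManager()                # grabs joystick 0; it must already be plugged in

    # Bind every logical button to the first physical control the user activates.
    for button in manager.buttons:
        print(f'press the control you want to use for "{button}"')
        while not manager.configure_button(button):
            manager.get_events()            # keep pumping events while waiting
            time.sleep(1.0 / fps)

    # Poll held buttons until the user presses ESC or closes the window.
    while not manager.quit_attempt:
        manager.get_events()                # must run every frame or the window locks up
        if manager.is_pressed('left'):
            print('steer left')
        elif manager.is_pressed('right'):
            print('steer right')
        time.sleep(1.0 / fps)

if __name__ == '__main__':
    drive_with_gamepad()
```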
381 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/retro/retro_interactive.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import retro 4 | from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv 5 | from .interactive import Interactive 6 | 7 | 8 | class RetroInteractive(Interactive): 9 | """ 10 | Interactive setup for retro games 11 | """ 12 | def __init__(self, game, state, scenario): 13 | def make_env(): 14 | return retro.make(game=game, state=state, scenario=scenario) 15 | 16 | env = make_env() 17 | self._buttons = env.buttons 18 | env.close() 19 | venv = SubprocVecEnv([make_env]) 20 | super().__init__(venv=venv, sync=False, tps=60, aspect_ratio=4/3) 21 | 22 | def get_screen(self, _obs, venv): 23 | return venv.render(mode='rgb_array') 24 | 25 | def keys_to_act(self, keys): 26 | inputs = { 27 | None: False, 28 | 29 | 'BUTTON': 'Z' in keys, 30 | 'A': 'Z' in keys, 31 | 'B': 'X' in keys, 32 | 33 | 'C': 'C' in keys, 34 | 'X': 'A' in keys, 35 | 'Y': 'S' in keys, 36 | 'Z': 'D' in keys, 37 | 38 | 'L': 'Q' in keys, 39 | 'R': 'W' in keys, 40 | 41 | 'UP': 'UP' in keys, 42 | 'DOWN': 'DOWN' in keys, 43 | 'LEFT': 'LEFT' in keys, 44 | 'RIGHT': 'RIGHT' in keys, 45 | 46 | 'MODE': 'TAB' in keys, 47 | 'SELECT': 'TAB' in keys, 48 | 'RESET': 'ENTER' in keys, 49 | 'START': 'ENTER' in keys, 50 | } 51 | return [inputs[b] for b in self._buttons] 52 | 53 | 54 | def main(): 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--game', default='StreetFighterIISpecialChampionEdition-Genesis') 57 | parser.add_argument('--state', default=retro.State.DEFAULT) 58 | parser.add_argument('--scenario', default='scenario') 59 | args = parser.parse_args() 60 | 61 | ia = RetroInteractive(game=args.game, state=args.state, scenario=args.scenario) 62 | ia.run() 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/retro/retro_joystick.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | import sys, time, pdb 5 | import numpy as np 6 | import os 7 | import pygame 8 | from itertools import count 9 | import argparse 10 | import pandas as pd 11 | import uuid 12 | from pygame.locals import * 13 | from pygame_controller import InputManager, wait_for_enter, configure_phase 14 | import retro 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '--env_name', 19 | choices=['SuperMarioKart-Snes', 'StreetFighterIISpecialChampionEdition-Genesis', 'AyrtonSennasSuperMonacoGPII-Genesis'], 20 | default='AyrtonSennasSuperMonacoGPII-Genesis' ) 21 | parser.add_argument( 22 | '--fps', 23 | default=30) 24 | parser.add_argument('--state', default=retro.State.DEFAULT) 25 | args = parser.parse_args() 26 | 27 | def interaction_phase(env, input_manager): 28 | buttons = env.unwrapped.envs[0].unwrapped.buttons 29 | actions = np.zeros(12) 30 | 31 | for event in input_manager.get_events(): 32 | if event.key == 'A' and event.down: 33 | pass # weeeeeeee 34 | if event.key == 'X' and event.up: 35 | input_manager.quit_attempted = True 36 | 37 | if input_manager.is_pressed('left'): 38 | actions[buttons.index("LEFT")] = 1 39 | if input_manager.is_pressed('right'): 40 | actions[buttons.index("RIGHT")] = 1 41 | if input_manager.is_pressed('up'): 42 | actions[buttons.index("UP")] = 
1 43 | if input_manager.is_pressed('down'): 44 | actions[buttons.index("DOWN")] = 1 45 | if input_manager.is_pressed('X'): 46 | actions[buttons.index("X")] = 1 47 | if input_manager.is_pressed('A'): 48 | actions[buttons.index("A")] = 1 49 | if input_manager.is_pressed('B'): 50 | actions[buttons.index("B")] = 1 51 | if input_manager.is_pressed('Y'): 52 | actions[buttons.index("Y")] = 1 53 | if input_manager.is_pressed('L'): 54 | actions[buttons.index("L")] = 1 55 | if input_manager.is_pressed('R'): 56 | actions[buttons.index("R")] = 1 57 | return actions 58 | 59 | def rollout(env, input_manager): 60 | import torch 61 | rtn_obs, rtn_acs, rtn_lens, ep_rewards = [], [], [], [] 62 | obser = env.reset() 63 | skip = 0 64 | total_reward = 0 65 | total_timesteps = 0 66 | 67 | while 1: 68 | start = time.time() 69 | action = interaction_phase(env, input_manager) 70 | total_timesteps += 1 71 | 72 | rtn_obs.append(obser.cpu().numpy().copy()) 73 | rtn_acs.append([action]) 74 | obser, env_reward, done, infos = env.step(torch.tensor(action)) 75 | 76 | for info in infos or done: 77 | if 'episode' in info.keys(): 78 | ep_rewards.append(info['episode']['r']) 79 | 80 | total_reward += env_reward 81 | window_still_open = env.render() 82 | 83 | if done: break 84 | # maintain frame rate 85 | difference = start - time.time() 86 | delay = 1.0 / args.fps - difference 87 | if delay > 0: 88 | time.sleep(delay) 89 | 90 | print("timesteps %i reward %0.2f" % (total_timesteps, total_reward)) 91 | rtn_obs_ = np.concatenate(rtn_obs) 92 | rtn_acs_ = np.concatenate(rtn_acs) 93 | return (rtn_obs_, rtn_acs_, total_reward) 94 | 95 | def setup_controller(): 96 | # configure the pygame controller 97 | print("Plug in a USB gamepad. Do it! Do it now! Press enter after you have done this.") 98 | wait_for_enter() 99 | pygame.init() 100 | 101 | num_joysticks = pygame.joystick.get_count() 102 | if num_joysticks < 1: 103 | print("You didn't plug in a joystick. FORSHAME!") 104 | return 105 | 106 | input_manager = InputManager() 107 | 108 | screen = pygame.display.set_mode((640, 480)) 109 | button_index = 0 110 | 111 | is_configured = False 112 | while not is_configured: 113 | start = time.time() 114 | 115 | screen.fill((0,0,0)) 116 | 117 | # There will be two phases to our "game". 118 | is_configured = button_index >= len(input_manager.buttons) 119 | 120 | # configure the joystrick 121 | if not is_configured: 122 | success = configure_phase(screen, input_manager.buttons[button_index], input_manager) 123 | # if the user pressed a button and configured it... 124 | if success: 125 | # move on to the next button that needs to be configured 126 | button_index += 1 127 | 128 | pygame.display.flip() 129 | 130 | # maintain frame rate 131 | difference = start - time.time() 132 | delay = 1.0 / args.fps - difference 133 | if delay > 0: 134 | time.sleep(delay) 135 | 136 | pygame.display.quit() 137 | return input_manager 138 | 139 | def main(input_manager): 140 | from baselines.common.retro_wrappers import make_retro, wrap_deepmind_retro 141 | from dril.a2c_ppo_acktr.envs import make_vec_envs 142 | import torch 143 | import gym, retro 144 | 145 | log_dir = os.path.expanduser(f'{os.getcwd()}/log') 146 | env = make_vec_envs(args.env_name, 0, 1, None, 147 | log_dir, 'cpu', True, use_obs_norm=False) 148 | 149 | pygame.init() 150 | 151 | # Initialize the joysticks. 
152 | pygame.joystick.init() 153 | 154 | ep_rewards = [] 155 | 156 | for num_games in count(1): 157 | env.render() 158 | (rtn_obs_, rtn_acs_, reward) = rollout(env, input_manager) 159 | ep_rewards.append(reward) 160 | 161 | demo_data_dir = os.getcwd() 162 | unique_uuid = uuid.uuid4() 163 | if os.name == 'nt': 164 | desktop = os.path.join(os.path.join(os.environ['USERPROFILE']), 'Desktop') 165 | obs_path = os.path.join(desktop,f'obs_{args.env_name}_seed=0_ntraj=1_{unique_uuid}.npy') 166 | acs_path = os.path.join(desktop,f'acs_{args.env_name}_seed=0_ntraj=1_{unique_uuid}.npy') 167 | else: 168 | obs_path = f'{demo_data_dir}/obs_{args.env_name}_seed=0_ntraj=1_{unique_uuid}.npy' 169 | acs_path = f'{demo_data_dir}/acs_{args.env_name}_seed=0_ntraj=1_{unique_uuid}.npy' 170 | 171 | 172 | np.save(obs_path, rtn_obs_) 173 | np.save(acs_path, rtn_acs_) 174 | 175 | to_continue = input('Continue "y" or "n": ') 176 | if to_continue.lower() == 'y': 177 | pass 178 | else: 179 | break 180 | 181 | print(f'expert: {np.mean(ep_rewards)}') 182 | results_save_path = os.path.join(f'{os.getcwd()}', f'expert_{args.env_name}_seed=0.perf') 183 | results = [{'total_num_steps':0 , 'train_loss': 0, 'test_loss': 0, 'num_trajs': 0 ,\ 184 | 'test_reward':np.mean(ep_rewards), 'u_reward': 0}] 185 | df = pd.DataFrame(results, columns=np.hstack(['x', 'steps', 'train_loss', 'test_loss',\ 186 | 'train_reward', 'test_reward', 'label', 'u_reward'])) 187 | df.to_csv(results_save_path) 188 | 189 | if __name__ == "__main__": 190 | input_manager = setup_controller() 191 | main(input_manager) 192 | 193 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/stable_baselines/base_vec_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import inspect 3 | import pickle 4 | 5 | import cloudpickle 6 | from stable_baselines import logger 7 | 8 | class AlreadySteppingError(Exception): 9 | """ 10 | Raised when an asynchronous step is running while 11 | step_async() is called again. 12 | """ 13 | 14 | def __init__(self): 15 | msg = 'already running an async step' 16 | Exception.__init__(self, msg) 17 | 18 | 19 | class NotSteppingError(Exception): 20 | """ 21 | Raised when an asynchronous step is not running but 22 | step_wait() is called. 23 | """ 24 | 25 | def __init__(self): 26 | msg = 'not running an async step' 27 | Exception.__init__(self, msg) 28 | 29 | 30 | class VecEnv(ABC): 31 | """ 32 | An abstract asynchronous, vectorized environment. 33 | :param num_envs: (int) the number of environments 34 | :param observation_space: (Gym Space) the observation space 35 | :param action_space: (Gym Space) the action space 36 | """ 37 | metadata = { 38 | 'render.modes': ['human', 'rgb_array'] 39 | } 40 | 41 | def __init__(self, num_envs, observation_space, action_space): 42 | self.num_envs = num_envs 43 | self.observation_space = observation_space 44 | self.action_space = action_space 45 | 46 | @abstractmethod 47 | def reset(self): 48 | """ 49 | Reset all the environments and return an array of 50 | observations, or a tuple of observation arrays. 51 | If step_async is still doing work, that work will 52 | be cancelled and step_wait() should not be called 53 | until step_async() is invoked again. 54 | :return: ([int] or [float]) observation 55 | """ 56 | pass 57 | 58 | @abstractmethod 59 | def step_async(self, actions): 60 | """ 61 | Tell all the environments to start taking a step 62 | with the given actions. 
63 | Call step_wait() to get the results of the step. 64 | You should not call this if a step_async run is 65 | already pending. 66 | """ 67 | pass 68 | 69 | @abstractmethod 70 | def step_wait(self): 71 | """ 72 | Wait for the step taken with step_async(). 73 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information 74 | """ 75 | pass 76 | 77 | @abstractmethod 78 | def close(self): 79 | """ 80 | Clean up the environment's resources. 81 | """ 82 | pass 83 | 84 | @abstractmethod 85 | def get_attr(self, attr_name, indices=None): 86 | """ 87 | Return attribute from vectorized environment. 88 | :param attr_name: (str) The name of the attribute whose value to return 89 | :param indices: (list,int) Indices of envs to get attribute from 90 | :return: (list) List of values of 'attr_name' in all environments 91 | """ 92 | pass 93 | 94 | @abstractmethod 95 | def set_attr(self, attr_name, value, indices=None): 96 | """ 97 | Set attribute inside vectorized environments. 98 | :param attr_name: (str) The name of attribute to assign new value 99 | :param value: (obj) Value to assign to `attr_name` 100 | :param indices: (list,int) Indices of envs to assign value 101 | :return: (NoneType) 102 | """ 103 | pass 104 | 105 | @abstractmethod 106 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 107 | """ 108 | Call instance methods of vectorized environments. 109 | :param method_name: (str) The name of the environment method to invoke. 110 | :param indices: (list,int) Indices of envs whose method to call 111 | :param method_args: (tuple) Any positional arguments to provide in the call 112 | :param method_kwargs: (dict) Any keyword arguments to provide in the call 113 | :return: (list) List of items returned by the environment's method call 114 | """ 115 | pass 116 | 117 | def step(self, actions): 118 | """ 119 | Step the environments with the given action 120 | :param actions: ([int] or [float]) the action 121 | :return: ([int] or [float], [float], [bool], dict) observation, reward, done, information 122 | """ 123 | self.step_async(actions) 124 | return self.step_wait() 125 | 126 | def get_images(self): 127 | """ 128 | Return RGB images from each environment 129 | """ 130 | raise NotImplementedError 131 | 132 | def render(self, *args, **kwargs): 133 | """ 134 | Gym environment rendering 135 | :param mode: (str) the rendering type 136 | """ 137 | logger.warn('Render not defined for %s' % self) 138 | 139 | @property 140 | def unwrapped(self): 141 | if isinstance(self, VecEnvWrapper): 142 | return self.venv.unwrapped 143 | else: 144 | return self 145 | 146 | def getattr_depth_check(self, name, already_found): 147 | """Check if an attribute reference is being hidden in a recursive call to __getattr__ 148 | :param name: (str) name of attribute to check for 149 | :param already_found: (bool) whether this attribute has already been found in a wrapper 150 | :return: (str or None) name of module whose attribute is being shadowed, if any. 151 | """ 152 | if hasattr(self, name) and already_found: 153 | return "{0}.{1}".format(type(self).__module__, type(self).__name__) 154 | else: 155 | return None 156 | 157 | def _get_indices(self, indices): 158 | """ 159 | Convert a flexibly-typed reference to environment indices to an implied list of indices. 160 | :param indices: (None,int,Iterable) refers to indices of envs. 161 | :return: (list) the implied list of indices. 
162 | """ 163 | if indices is None: 164 | indices = range(self.num_envs) 165 | elif isinstance(indices, int): 166 | indices = [indices] 167 | return indices 168 | 169 | 170 | class VecEnvWrapper(VecEnv): 171 | """ 172 | Vectorized environment base class 173 | :param venv: (VecEnv) the vectorized environment to wrap 174 | :param observation_space: (Gym Space) the observation space (can be None to load from venv) 175 | :param action_space: (Gym Space) the action space (can be None to load from venv) 176 | """ 177 | 178 | def __init__(self, venv, observation_space=None, action_space=None): 179 | self.venv = venv 180 | VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space, 181 | action_space=action_space or venv.action_space) 182 | self.class_attributes = dict(inspect.getmembers(self.__class__)) 183 | 184 | def step_async(self, actions): 185 | self.venv.step_async(actions) 186 | 187 | @abstractmethod 188 | def reset(self): 189 | pass 190 | 191 | @abstractmethod 192 | def step_wait(self): 193 | pass 194 | 195 | def close(self): 196 | return self.venv.close() 197 | 198 | def render(self, *args, **kwargs): 199 | return self.venv.render(*args, **kwargs) 200 | 201 | def get_images(self): 202 | return self.venv.get_images() 203 | 204 | def get_attr(self, attr_name, indices=None): 205 | return self.venv.get_attr(attr_name, indices) 206 | 207 | def set_attr(self, attr_name, value, indices=None): 208 | return self.venv.set_attr(attr_name, value, indices) 209 | 210 | def env_method(self, method_name, *method_args, indices=None, **method_kwargs): 211 | return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs) 212 | 213 | def __getattr__(self, name): 214 | """Find attribute from wrapped venv(s) if this wrapper does not have it. 215 | Useful for accessing attributes from venvs which are wrapped with multiple wrappers 216 | which have unique attributes of interest. 217 | """ 218 | blocked_class = self.getattr_depth_check(name, already_found=False) 219 | if blocked_class is not None: 220 | own_class = "{0}.{1}".format(type(self).__module__, type(self).__name__) 221 | format_str = ("Error: Recursive attribute lookup for {0} from {1} is " 222 | "ambiguous and hides attribute from {2}") 223 | raise AttributeError(format_str.format(name, own_class, blocked_class)) 224 | 225 | return self.getattr_recursive(name) 226 | 227 | def _get_all_attributes(self): 228 | """Get all (inherited) instance and class attributes 229 | :return: (dict) all_attributes 230 | """ 231 | all_attributes = self.__dict__.copy() 232 | all_attributes.update(self.class_attributes) 233 | return all_attributes 234 | 235 | def getattr_recursive(self, name): 236 | """Recursively check wrappers to find attribute. 237 | :param name (str) name of attribute to look for 238 | :return: (object) attribute 239 | """ 240 | all_attributes = self._get_all_attributes() 241 | if name in all_attributes: # attribute is present in this wrapper 242 | attr = getattr(self, name) 243 | elif hasattr(self.venv, 'getattr_recursive'): 244 | # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr 245 | # to avoid a duplicate call to getattr_depth_check. 246 | attr = self.venv.getattr_recursive(name) 247 | else: # attribute not present, child is an unwrapped VecEnv 248 | attr = getattr(self.venv, name) 249 | 250 | return attr 251 | 252 | def getattr_depth_check(self, name, already_found): 253 | """See base class. 
254 | :return: (str or None) name of module whose attribute is being shadowed, if any. 255 | """ 256 | all_attributes = self._get_all_attributes() 257 | if name in all_attributes and already_found: 258 | # this venv's attribute is being hidden because of a higher venv. 259 | shadowed_wrapper_class = "{0}.{1}".format(type(self).__module__, type(self).__name__) 260 | elif name in all_attributes and not already_found: 261 | # we have found the first reference to the attribute. Now check for duplicates. 262 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, True) 263 | else: 264 | # this wrapper does not have the attribute. Keep searching. 265 | shadowed_wrapper_class = self.venv.getattr_depth_check(name, already_found) 266 | 267 | return shadowed_wrapper_class 268 | 269 | 270 | class CloudpickleWrapper(object): 271 | def __init__(self, var): 272 | """ 273 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 274 | :param var: (Any) the variable you wish to wrap for pickling with cloudpickle 275 | """ 276 | self.var = var 277 | 278 | def __getstate__(self): 279 | return cloudpickle.dumps(self.var) 280 | 281 | def __setstate__(self, obs): 282 | self.var = pickle.loads(obs) 283 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/stable_baselines/running_mean_std.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RunningMeanStd(object): 5 | def __init__(self, epsilon=1e-4, shape=()): 6 | """ 7 | calulates the running mean and std of a data stream 8 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 9 | :param epsilon: (float) helps with arithmetic issues 10 | :param shape: (tuple) the shape of the data stream's output 11 | """ 12 | self.mean = np.zeros(shape, 'float64') 13 | self.var = np.ones(shape, 'float64') 14 | self.count = epsilon 15 | 16 | def update(self, arr): 17 | batch_mean = np.mean(arr, axis=0) 18 | batch_var = np.var(arr, axis=0) 19 | batch_count = arr.shape[0] 20 | self.update_from_moments(batch_mean, batch_var, batch_count) 21 | 22 | def update_from_moments(self, batch_mean, batch_var, batch_count): 23 | delta = batch_mean - self.mean 24 | tot_count = self.count + batch_count 25 | 26 | new_mean = self.mean + delta * batch_count / tot_count 27 | m_a = self.var * self.count 28 | m_b = batch_var * batch_count 29 | m_2 = m_a + m_b + np.square(delta) * self.count * batch_count / (self.count + batch_count) 30 | new_var = m_2 / (self.count + batch_count) 31 | 32 | new_count = batch_count + self.count 33 | 34 | self.mean = new_mean 35 | self.var = new_var 36 | self.count = new_count 37 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 3 | 4 | 5 | def _flatten_helper(T, N, _tensor): 6 | return _tensor.view(T * N, *_tensor.size()[2:]) 7 | 8 | 9 | class RolloutStorage(object): 10 | def __init__(self, num_steps, num_processes, obs_shape, action_space, 11 | recurrent_hidden_state_size): 12 | self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape) 13 | self.recurrent_hidden_states = torch.zeros( 14 | num_steps + 1, num_processes, recurrent_hidden_state_size) 15 | self.rewards = torch.zeros(num_steps, num_processes, 1) 16 | self.value_preds = 
torch.zeros(num_steps + 1, num_processes, 1) 17 | self.returns = torch.zeros(num_steps + 1, num_processes, 1) 18 | self.action_log_probs = torch.zeros(num_steps, num_processes, 1) 19 | if action_space.__class__.__name__ == 'Discrete': 20 | action_shape = 1 21 | else: 22 | action_shape = action_space.shape[0] 23 | self.actions = torch.zeros(num_steps, num_processes, action_shape) 24 | if action_space.__class__.__name__ == 'Discrete': 25 | self.actions = self.actions.long() 26 | self.masks = torch.ones(num_steps + 1, num_processes, 1) 27 | 28 | # Masks that indicate whether it's a true terminal state 29 | # or time limit end state 30 | self.bad_masks = torch.ones(num_steps + 1, num_processes, 1) 31 | 32 | self.num_steps = num_steps 33 | self.step = 0 34 | 35 | def to(self, device): 36 | self.obs = self.obs.to(device) 37 | self.recurrent_hidden_states = self.recurrent_hidden_states.to(device) 38 | self.rewards = self.rewards.to(device) 39 | self.value_preds = self.value_preds.to(device) 40 | self.returns = self.returns.to(device) 41 | self.action_log_probs = self.action_log_probs.to(device) 42 | self.actions = self.actions.to(device) 43 | self.masks = self.masks.to(device) 44 | self.bad_masks = self.bad_masks.to(device) 45 | 46 | def insert(self, obs, recurrent_hidden_states, actions, action_log_probs, 47 | value_preds, rewards, masks, bad_masks): 48 | self.obs[self.step + 1].copy_(obs) 49 | self.recurrent_hidden_states[self.step + 50 | 1].copy_(recurrent_hidden_states) 51 | self.actions[self.step].copy_(actions) 52 | self.action_log_probs[self.step].copy_(action_log_probs) 53 | self.value_preds[self.step].copy_(value_preds) 54 | self.rewards[self.step].copy_(rewards) 55 | self.masks[self.step + 1].copy_(masks) 56 | self.bad_masks[self.step + 1].copy_(bad_masks) 57 | 58 | self.step = (self.step + 1) % self.num_steps 59 | 60 | def after_update(self): 61 | self.obs[0].copy_(self.obs[-1]) 62 | self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) 63 | self.masks[0].copy_(self.masks[-1]) 64 | self.bad_masks[0].copy_(self.bad_masks[-1]) 65 | 66 | def compute_returns(self, 67 | next_value, 68 | use_gae, 69 | gamma, 70 | gae_lambda, 71 | use_proper_time_limits=True): 72 | if use_proper_time_limits: 73 | if use_gae: 74 | self.value_preds[-1] = next_value 75 | gae = 0 76 | for step in reversed(range(self.rewards.size(0))): 77 | delta = self.rewards[step] + gamma * self.value_preds[ 78 | step + 1] * self.masks[step + 79 | 1] - self.value_preds[step] 80 | gae = delta + gamma * gae_lambda * self.masks[step + 81 | 1] * gae 82 | gae = gae * self.bad_masks[step + 1] 83 | self.returns[step] = gae + self.value_preds[step] 84 | else: 85 | self.returns[-1] = next_value 86 | for step in reversed(range(self.rewards.size(0))): 87 | self.returns[step] = (self.returns[step + 1] * \ 88 | gamma * self.masks[step + 1] + self.rewards[step]) * self.bad_masks[step + 1] \ 89 | + (1 - self.bad_masks[step + 1]) * self.value_preds[step] 90 | else: 91 | if use_gae: 92 | self.value_preds[-1] = next_value 93 | gae = 0 94 | for step in reversed(range(self.rewards.size(0))): 95 | delta = self.rewards[step] + gamma * self.value_preds[ 96 | step + 1] * self.masks[step + 97 | 1] - self.value_preds[step] 98 | gae = delta + gamma * gae_lambda * self.masks[step + 99 | 1] * gae 100 | self.returns[step] = gae + self.value_preds[step] 101 | else: 102 | self.returns[-1] = next_value 103 | for step in reversed(range(self.rewards.size(0))): 104 | self.returns[step] = self.returns[step + 1] * \ 105 | gamma * self.masks[step + 
1] + self.rewards[step] 106 | 107 | def feed_forward_generator(self, 108 | advantages, 109 | num_mini_batch=None, 110 | mini_batch_size=None): 111 | num_steps, num_processes = self.rewards.size()[0:2] 112 | batch_size = num_processes * num_steps 113 | 114 | if mini_batch_size is None: 115 | assert batch_size >= num_mini_batch, ( 116 | "PPO requires the number of processes ({}) " 117 | "* number of steps ({}) = {} " 118 | "to be greater than or equal to the number of PPO mini batches ({})." 119 | "".format(num_processes, num_steps, num_processes * num_steps, 120 | num_mini_batch)) 121 | mini_batch_size = batch_size // num_mini_batch 122 | sampler = BatchSampler( 123 | SubsetRandomSampler(range(batch_size)), 124 | mini_batch_size, 125 | drop_last=True) 126 | for indices in sampler: 127 | obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices] 128 | recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view( 129 | -1, self.recurrent_hidden_states.size(-1))[indices] 130 | actions_batch = self.actions.view(-1, 131 | self.actions.size(-1))[indices] 132 | value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices] 133 | return_batch = self.returns[:-1].view(-1, 1)[indices] 134 | masks_batch = self.masks[:-1].view(-1, 1)[indices] 135 | old_action_log_probs_batch = self.action_log_probs.view(-1, 136 | 1)[indices] 137 | if advantages is None: 138 | adv_targ = None 139 | else: 140 | adv_targ = advantages.view(-1, 1)[indices] 141 | 142 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \ 143 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 144 | 145 | def recurrent_generator(self, advantages, num_mini_batch): 146 | num_processes = self.rewards.size(1) 147 | assert num_processes >= num_mini_batch, ( 148 | "PPO requires the number of processes ({}) " 149 | "to be greater than or equal to the number of " 150 | "PPO mini batches ({}).".format(num_processes, num_mini_batch)) 151 | num_envs_per_batch = num_processes // num_mini_batch 152 | perm = torch.randperm(num_processes) 153 | for start_ind in range(0, num_processes, num_envs_per_batch): 154 | obs_batch = [] 155 | recurrent_hidden_states_batch = [] 156 | actions_batch = [] 157 | value_preds_batch = [] 158 | return_batch = [] 159 | masks_batch = [] 160 | old_action_log_probs_batch = [] 161 | adv_targ = [] 162 | 163 | for offset in range(num_envs_per_batch): 164 | ind = perm[start_ind + offset] 165 | obs_batch.append(self.obs[:-1, ind]) 166 | recurrent_hidden_states_batch.append( 167 | self.recurrent_hidden_states[0:1, ind]) 168 | actions_batch.append(self.actions[:, ind]) 169 | value_preds_batch.append(self.value_preds[:-1, ind]) 170 | return_batch.append(self.returns[:-1, ind]) 171 | masks_batch.append(self.masks[:-1, ind]) 172 | old_action_log_probs_batch.append( 173 | self.action_log_probs[:, ind]) 174 | adv_targ.append(advantages[:, ind]) 175 | 176 | T, N = self.num_steps, num_envs_per_batch 177 | # These are all tensors of size (T, N, -1) 178 | obs_batch = torch.stack(obs_batch, 1) 179 | actions_batch = torch.stack(actions_batch, 1) 180 | value_preds_batch = torch.stack(value_preds_batch, 1) 181 | return_batch = torch.stack(return_batch, 1) 182 | masks_batch = torch.stack(masks_batch, 1) 183 | old_action_log_probs_batch = torch.stack( 184 | old_action_log_probs_batch, 1) 185 | adv_targ = torch.stack(adv_targ, 1) 186 | 187 | # States is just a (N, -1) tensor 188 | recurrent_hidden_states_batch = torch.stack( 189 | recurrent_hidden_states_batch, 1).view(N, -1) 190 | 191 | # 
Flatten the (T, N, ...) tensors to (T * N, ...) 192 | obs_batch = _flatten_helper(T, N, obs_batch) 193 | actions_batch = _flatten_helper(T, N, actions_batch) 194 | value_preds_batch = _flatten_helper(T, N, value_preds_batch) 195 | return_batch = _flatten_helper(T, N, return_batch) 196 | masks_batch = _flatten_helper(T, N, masks_batch) 197 | old_action_log_probs_batch = _flatten_helper(T, N, \ 198 | old_action_log_probs_batch) 199 | adv_targ = _flatten_helper(T, N, adv_targ) 200 | 201 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \ 202 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 203 | -------------------------------------------------------------------------------- /dril/a2c_ppo_acktr/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import yaml 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from dril.a2c_ppo_acktr.envs import VecNormalize 9 | 10 | def get_saved_hyperparams(stats_path, norm_reward=False, test_mode=False): 11 | """ 12 | :param stats_path: (str) 13 | :param norm_reward: (bool) 14 | :param test_mode: (bool) 15 | :return: (dict, str) 16 | """ 17 | hyperparams = {} 18 | if not os.path.isdir(stats_path): 19 | stats_path = None 20 | else: 21 | config_file = os.path.join(stats_path, 'config.yml') 22 | if os.path.isfile(config_file): 23 | # Load saved hyperparameters 24 | with open(os.path.join(stats_path, 'config.yml'), 'r') as f: 25 | hyperparams = yaml.load(f, Loader=yaml.FullLoader) 26 | hyperparams['normalize'] = hyperparams.get('normalize', False) 27 | else: 28 | obs_rms_path = os.path.join(stats_path, 'obs_rms.pkl') 29 | hyperparams['normalize'] = os.path.isfile(obs_rms_path) 30 | 31 | # Load normalization params 32 | if hyperparams['normalize']: 33 | if isinstance(hyperparams['normalize'], str): 34 | normalize_kwargs = eval(hyperparams['normalize']) 35 | if test_mode: 36 | normalize_kwargs['norm_reward'] = norm_reward 37 | else: 38 | normalize_kwargs = {'norm_obs': hyperparams['normalize'], 'norm_reward': norm_reward} 39 | hyperparams['normalize_kwargs'] = normalize_kwargs 40 | return hyperparams, stats_path 41 | 42 | 43 | 44 | # Get a render function 45 | def get_render_func(venv): 46 | if hasattr(venv, 'envs'): 47 | return venv.envs[0].render 48 | elif hasattr(venv, 'venv'): 49 | return get_render_func(venv.venv) 50 | elif hasattr(venv, 'env'): 51 | return get_render_func(venv.env) 52 | 53 | return None 54 | 55 | 56 | def get_vec_normalize(venv): 57 | if isinstance(venv, VecNormalize): 58 | return venv 59 | elif hasattr(venv, 'venv'): 60 | return get_vec_normalize(venv.venv) 61 | 62 | return None 63 | 64 | 65 | # Necessary for my KFAC implementation. 
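The lookup helpers above are easiest to read alongside how the rest of the codebase calls them. The sketch below mirrors the pattern used in `evaluation.py` and `enjoy.py` for re-attaching saved observation statistics to a freshly built vectorized env; the environment name and checkpoint path are placeholders, and the `(actor_critic, ob_rms)` tuple layout is the one `enjoy.py` expects.

```
import os
import torch

from dril.a2c_ppo_acktr import utils
from dril.a2c_ppo_acktr.envs import make_vec_envs

env_name, load_dir = 'HalfCheetahBulletEnv-v0', './trained_models/'   # placeholder values

# Same positional call shape as enjoy.py: (env_name, seed, num_processes, gamma, log_dir, ...)
env = make_vec_envs(env_name, 1000, 1, None, None, device='cpu', allow_early_resets=False)

# enjoy.py stores the policy together with the observation running statistics
actor_critic, ob_rms = torch.load(os.path.join(load_dir, env_name + '.pt'), map_location='cpu')

vec_norm = utils.get_vec_normalize(env)    # walk the wrapper stack to find VecNormalize
if vec_norm is not None:
    vec_norm.eval()                        # stop updating the running mean/std
    vec_norm.ob_rms = ob_rms               # reuse the statistics gathered during training
```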
66 | class AddBias(nn.Module): 67 | def __init__(self, bias): 68 | super(AddBias, self).__init__() 69 | self._bias = nn.Parameter(bias.unsqueeze(1)) 70 | 71 | def forward(self, x): 72 | if x.dim() == 2: 73 | bias = self._bias.t().view(1, -1) 74 | else: 75 | bias = self._bias.t().view(1, -1, 1, 1) 76 | 77 | return x + bias 78 | 79 | 80 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 81 | """Decreases the learning rate linearly""" 82 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 83 | for param_group in optimizer.param_groups: 84 | param_group['lr'] = lr 85 | 86 | 87 | def init(module, weight_init, bias_init, gain=1): 88 | weight_init(module.weight.data, gain=gain) 89 | bias_init(module.bias.data) 90 | return module 91 | 92 | 93 | def cleanup_log_dir(log_dir): 94 | try: 95 | os.makedirs(log_dir) 96 | except OSError: 97 | files = glob.glob(os.path.join(log_dir, '*.monitor.csv')) 98 | for f in files: 99 | os.remove(f) 100 | -------------------------------------------------------------------------------- /dril/enjoy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | # workaround to unpickle olf model files 4 | import sys 5 | import time 6 | import numpy as np 7 | from PIL import Image 8 | import glob 9 | import PIL 10 | from PIL import Image 11 | from PIL import ImageFont 12 | from PIL import ImageDraw 13 | import sys 14 | from skimage import transform 15 | from torch.distributions import Categorical 16 | 17 | import numpy as np 18 | import torch 19 | from dril.a2c_ppo_acktr.model import Policy 20 | 21 | from dril.a2c_ppo_acktr.envs import VecPyTorch, make_vec_envs 22 | from dril.a2c_ppo_acktr.utils import get_render_func, get_vec_normalize 23 | from dril.a2c_ppo_acktr.algo.ensemble import Ensemble 24 | import dril.a2c_ppo_acktr.ensemble_models as ensemble_models 25 | from dril.a2c_ppo_acktr.arguments import get_args 26 | from dril.a2c_ppo_acktr.algo.behavior_cloning import BehaviorCloning 27 | from dril.a2c_ppo_acktr.algo.dril import DRIL 28 | 29 | import gym, os 30 | import numpy as np 31 | import argparse 32 | import random 33 | import pandas as pd 34 | 35 | import sys 36 | import torch 37 | from gym import wrappers 38 | import random 39 | 40 | from dril.a2c_ppo_acktr.envs import make_vec_envs 41 | from dril.a2c_ppo_acktr.model import Policy 42 | from dril.a2c_ppo_acktr.utils import get_saved_hyperparams 43 | from dril.a2c_ppo_acktr.arguments import get_args 44 | 45 | 46 | import os 47 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 48 | 49 | sys.path.append('a2c_ppo_acktr') 50 | 51 | parser = argparse.ArgumentParser(description='RL') 52 | parser.add_argument( 53 | '--seed', type=int, default=1, help='random seed (default: 1)') 54 | parser.add_argument( 55 | '--log-interval', 56 | type=int, 57 | default=10, 58 | help='log interval, one log per n updates (default: 10)') 59 | parser.add_argument( 60 | '--env-name', 61 | default='PongNoFrameskip-v4', 62 | help='environment to train on (default: PongNoFrameskip-v4)') 63 | parser.add_argument( 64 | '--load-dir', 65 | default='./trained_models/', 66 | help='directory to save agent logs (default: ./trained_models/)') 67 | parser.add_argument( 68 | '--non-det', 69 | action='store_true', 70 | default=False, 71 | help='whether to use a non-deterministic policy') 72 | parser.add_argument( 73 | '--recurrent-policy', 74 | action='store_true', 75 | default=False, 76 | help='use a recurrent policy') 77 | parser.add_argument( 78 | 
'--rl_baseline_zoo_dir', 79 | type=str, default='', help='directory of rl baseline zoo') 80 | parser.add_argument( 81 | '--ensemble_hidden_size', default=512, 82 | help='dril ensemble network number of hidden units (default: 512)') 83 | parser.add_argument( 84 | '--ensemble_size', type=int, default=5, 85 | help='numnber of polices in the ensemble (default: 5)') 86 | args, unknown = parser.parse_known_args() 87 | 88 | default_args = get_args() 89 | 90 | args.det = not args.non_det 91 | 92 | device='cpu' 93 | env = make_vec_envs( 94 | args.env_name, 95 | args.seed + 1000, 96 | 1, 97 | None, 98 | None, 99 | device='cpu', 100 | allow_early_resets=False) 101 | 102 | # Get a render function 103 | render_func = get_render_func(env) 104 | 105 | # We need to use the same statistics for normalization as used in training 106 | actor_critic = Policy( 107 | env.observation_space.shape, 108 | env.action_space, 109 | load_expert=False, 110 | env_name=args.env_name, 111 | rl_baseline_zoo_dir=args.rl_baseline_zoo_dir, 112 | expert_algo='a2c', 113 | base_kwargs={'recurrent': args.recurrent_policy}) 114 | 115 | 116 | try: 117 | actor_critic, ob_rms = \ 118 | torch.load(os.path.join(args.load_dir, args.env_name + ".pt"), map_location='cpu') 119 | except: 120 | params = \ 121 | torch.load(os.path.join(args.load_dir, args.env_name + ".pt"), map_location='cpu') 122 | actor_critic.load_state_dict(params) 123 | 124 | vec_norm = get_vec_normalize(env) 125 | if vec_norm is not None: 126 | vec_norm.eval() 127 | vec_norm.ob_rms = None 128 | 129 | #recurrent_hidden_states = torch.zeros(1, 130 | # actor_critic.recurrent_hidden_state_size) 131 | #masks = torch.zeros(1, 1) 132 | 133 | obs = env.reset() 134 | 135 | if render_func is not None: 136 | import gym, os 137 | 138 | #if args.env_name.find('Bullet') > -1: 139 | # import pybullet as p 140 | # 141 | # torsoId = -1 142 | # if (p.getBodyInfo(i)[0].decode() == "torso"): 143 | # torsoId = i 144 | 145 | num_inputs = env.observation_space.shape[0] 146 | try: 147 | num_actions = env.action_space.n 148 | except: 149 | num_actions = env.action_space.shape[0] 150 | 151 | ensemble_args = (num_inputs, num_actions, args.ensemble_hidden_size, args.ensemble_size) 152 | #if num_inputs) == 3: 153 | #if env_name in ['duckietown']: 154 | # ensemble_policy = ensemble_models.PolicyEnsembleDuckieTownCNN 155 | #elif uncertainty_reward == 'ensemble': 156 | ensemble_policy = ensemble_models.PolicyEnsembleCNN 157 | #elif uncertainty_reward == 'dropout': 158 | # ensemble_policy = ensemble_models.PolicyEnsembleCNNDropout 159 | #else: 160 | # raise Exception("Unknown uncertainty_reward type") 161 | ensemble_policy = ensemble_policy(*ensemble_args).to(device) 162 | best_test_params = torch.load('', map_location=device) 163 | ensemble_policy.load_state_dict(best_test_params) 164 | 165 | #if save_traces: 166 | if True: 167 | traces, u_rewards_raw, u_rewards_quant, actions = [], [], [], [] 168 | variance = np.load('') 169 | quantile = np.quantile(np.array(variance), .40) 170 | clip = (lambda x: -1 if x > quantile else 1) 171 | 172 | step=0 173 | while True: 174 | with torch.no_grad(): 175 | value, action, _, _ = actor_critic.act(obs, None, None, deterministic=args.det) 176 | 177 | state = obs.repeat(args.ensemble_size, 1,1,1).float().to(device) 178 | ensemble_action = ensemble_policy(state).squeeze().detach() 179 | 180 | if isinstance(env.action_space, gym.spaces.Box): 181 | action = torch.clamp(action, env.action_space.low[0], env.action_space.high[0]) 182 | ensemble_action = 
torch.clamp(ensemble_action, env.action_space.low[0],\ 183 | env. action_space.high[0]) 184 | 185 | #action = ensemble_action[[4]].squeeze().max(0)[1].unsqueeze(0).unsqueeze(0) 186 | 187 | cov = np.cov(ensemble_action.T.cpu().numpy()) 188 | 189 | # If the env has only one action then we need to reshape cov 190 | if env.action_space.__class__.__name__ == "Box": 191 | if env.action_space.shape[0] == 1: 192 | cov = cov.reshape(-1,1) 193 | 194 | if isinstance(env.action_space, gym.spaces.discrete.Discrete): 195 | one_hot_action = torch.FloatTensor(np.eye(num_actions)[int(action.item())]) 196 | action_vec = one_hot_action 197 | elif isinstance(env.action_space, gym.spaces.Box): 198 | action_vec = action.clone() 199 | elif isinstance(env.action_space, gym.spaces.MultiBinary): 200 | #action = actions[[idx]] 201 | raise Exception('Envrionment shouldnt be MultiBinary') 202 | else: 203 | raise Exception("Unknown Action Space") 204 | 205 | ensemble_variance = (np.matmul(np.matmul(action_vec, cov), action_vec.T).item()) 206 | print(f'step: {step} ensemble_variance[{action.item()}]: {ensemble_variance} u:{clip(ensemble_variance)}') 207 | step+=1 208 | 209 | traces.append(obs[0][:3].permute(1,2,0).cpu().numpy().copy()) 210 | u_rewards_raw.append(ensemble_variance) 211 | actions.append(action.item()) 212 | u_rewards_quant.append(clip(ensemble_variance)) 213 | 214 | # Obser reward and next obs 215 | obs, reward, done, _ = env.step(action) 216 | 217 | 218 | # masks.fill_(0.0 if done else 1.0) 219 | if args.env_name.find('Bullet') > -1: 220 | if torsoId > -1: 221 | distance = 5 222 | yaw = 0 223 | humanPos, humanOrn = p.getBasePositionAndOrientation(torsoId) 224 | p.resetDebugVisualizerCamera(distance, yaw, -20, humanPos) 225 | 226 | if done:break 227 | #if step > 1200: 228 | # import pdb; pdb.set_trace() 229 | 230 | if render_func is not None: 231 | render_func('human') 232 | 233 | traces_dir = 'video' 234 | os.system(f'mkdir -p {traces_dir}') 235 | for i in range(len(traces)): 236 | fname = f'{traces_dir}/im{i:05d}.png'#_uq{u_rewards_quant[i]}.png' 237 | #else: 238 | # fname = f'{traces_dir}/im{i:05d}.png' 239 | #scipy.misc.imsave(fname, traces[i][0].cpu().numpy()) 240 | #im = Image.fromarray(traces[i]*255.999) 241 | #im = transform.resize(traces[i].reshape(84,84,3),(252,252)) 242 | #im = transform.resize(traces[i].permute(1,2,0)),(252,252)) 243 | img = Image.fromarray((traces[i]* 255).astype(np.uint8)) 244 | draw = ImageDraw.Draw(img) 245 | draw.text((3,3), f"q:{u_rewards_quant[i]} a:{actions[i]} ", fill=(255,255,0)) 246 | img.save(fname) 247 | 248 | import imageio 249 | images = [] 250 | import pdb; pdb.set_trace() 251 | for filename in sorted(glob.glob(f'{traces_dir}/*.png')): 252 | img = imageio.imread(filename) 253 | images.append(img) 254 | imageio.mimsave('output.gif', images, fps=10) 255 | -------------------------------------------------------------------------------- /dril/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import gym 4 | 5 | from dril.a2c_ppo_acktr import utils 6 | from dril.a2c_ppo_acktr.envs import make_vec_envs 7 | 8 | 9 | def evaluate(actor_critic, ob_rms, env_name, seed, num_processes, eval_log_dir, 10 | device, num_episodes=None, atari_max_steps=None): 11 | eval_envs = make_vec_envs(env_name, seed + num_processes, num_processes, 12 | None, eval_log_dir, device, True, atari_max_steps) 13 | 14 | vec_norm = utils.get_vec_normalize(eval_envs) 15 | if vec_norm is not None: 16 | 
vec_norm.eval() 17 | vec_norm.ob_rms = ob_rms 18 | 19 | eval_episode_rewards = [] 20 | 21 | obs = eval_envs.reset() 22 | eval_recurrent_hidden_states = torch.zeros( 23 | num_processes, actor_critic.recurrent_hidden_state_size, device=device) 24 | eval_masks = torch.zeros(num_processes, 1, device=device) 25 | 26 | while len(eval_episode_rewards) < num_episodes: 27 | with torch.no_grad(): 28 | _, action, _, eval_recurrent_hidden_states = actor_critic.act( 29 | obs, 30 | eval_recurrent_hidden_states, 31 | eval_masks, 32 | deterministic=True) 33 | 34 | # Obser reward and next obs 35 | if isinstance(eval_envs.action_space, gym.spaces.Box): 36 | clip_action = torch.clamp(action, float(eval_envs.action_space.low[0]),\ 37 | float(eval_envs.action_space.high[0])) 38 | else: 39 | clip_action = action 40 | 41 | # Obser reward and next obs 42 | obs, _, done, infos = eval_envs.step(clip_action) 43 | 44 | eval_masks = torch.tensor( 45 | [[0.0] if done_ else [1.0] for done_ in done], 46 | dtype=torch.float32, 47 | device=device) 48 | 49 | for info in infos: 50 | if 'episode' in info.keys(): 51 | eval_episode_rewards.append(info['episode']['r']) 52 | 53 | eval_envs.close() 54 | 55 | print(" Evaluation using {} episodes: mean reward {:.5f}\n".format( 56 | len(eval_episode_rewards), np.mean(eval_episode_rewards))) 57 | 58 | return np.mean(eval_episode_rewards) 59 | -------------------------------------------------------------------------------- /dril/generate_demonstration_data.py: -------------------------------------------------------------------------------- 1 | import gym, os 2 | import numpy as np 3 | import argparse 4 | import random 5 | import pandas as pd 6 | 7 | import sys 8 | import torch 9 | from gym import wrappers 10 | import random 11 | import torch.nn.functional as F 12 | import torch.nn as nn 13 | import torch as th 14 | 15 | from dril.a2c_ppo_acktr.envs import make_vec_envs 16 | from dril.a2c_ppo_acktr.model import Policy 17 | from dril.a2c_ppo_acktr.arguments import get_args 18 | import os 19 | 20 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 21 | 22 | args = get_args() 23 | 24 | args.recurrent_policy = False 25 | args.load_expert = True 26 | 27 | os.system(f'mkdir -p {args.demo_data_dir}') 28 | os.system(f'mkdir -p {args.demo_data_dir}/tmp/gym') 29 | sys.path.insert(1,os.path.join(args.rl_baseline_zoo_dir, 'utils')) 30 | from utils import get_saved_hyperparams 31 | 32 | #device = torch.device("cpu") 33 | device = torch.device("cuda:0" if args.cuda else "cpu") 34 | print(f'device: {device}') 35 | seed = args.seed 36 | print(f'seed: {seed}') 37 | 38 | if args.env_name in ['highway-v0']: 39 | import highway_env 40 | from rl_agents.agents.common.factory import agent_factory 41 | 42 | env = make_vec_envs(args.env_name, seed, 1, 0.99, f'{args.emo_data_dir}/tmp/gym', device,\ 43 | True, stats_path=stats_path, hyperparams=hyperparams, time=time, 44 | atari_max_steps=args.atari_max_steps) 45 | 46 | # Make agent 47 | agent_config = { 48 | "__class__": "", 49 | "budget": 50, 50 | "gamma": 0.7, 51 | } 52 | th_model = agent_factory(gym.make(args.env_name), agent_config) 53 | time = False 54 | elif args.env_name in ['duckietown']: 55 | from a2c_ppo_acktr.duckietown.env import launch_env 56 | from a2c_ppo_acktr.duckietown.wrappers import NormalizeWrapper, ImgWrapper,\ 57 | DtRewardWrapper, ActionWrapper, ResizeWrapper 58 | from a2c_ppo_acktr.duckietown.teacher import PurePursuitExpert 59 | env = launch_env() 60 | env = ResizeWrapper(env) 61 | env = NormalizeWrapper(env) 62 | env = ImgWrapper(env) 63 | env 
= ActionWrapper(env) 64 | env = DtRewardWrapper(env) 65 | 66 | # Create an imperfect demonstrator 67 | expert = PurePursuitExpert(env=env) 68 | time = False 69 | else: 70 | print('[Setting environemnt hyperparams variables]') 71 | stats_path = os.path.join(args.rl_baseline_zoo_dir, 'trained_agents', f'{args.expert_algo}',\ 72 | f'{args.env_name}') 73 | hyperparams, stats_path = get_saved_hyperparams(stats_path, test_mode=True,\ 74 | norm_reward=args.norm_reward_stable_baseline) 75 | 76 | ## Load saved policy 77 | 78 | # subset of the environments have time wrapper 79 | time_wrapper_envs = ['HalfCheetahBulletEnv-v0', 'Walker2DBulletEnv-v0', 'AntBulletEnv-v0'] 80 | if args.env_name in time_wrapper_envs: 81 | time=True 82 | else: 83 | time = False 84 | 85 | env = make_vec_envs(args.env_name, seed, 1, 0.99, f'{args.demo_data_dir}/tmp/gym', device,\ 86 | True, stats_path=stats_path, hyperparams=hyperparams, time=time) 87 | 88 | th_model = Policy( 89 | env.observation_space.shape, 90 | env.action_space, 91 | load_expert=True, 92 | env_name=args.env_name, 93 | rl_baseline_zoo_dir=args.rl_baseline_zoo_dir, 94 | expert_algo=args.expert_algo, 95 | # [Bug]: normalize=False, 96 | normalize=True if hasattr(gym.envs, 'atari') else False, 97 | base_kwargs={'recurrent': args.recurrent_policy}).to(device) 98 | 99 | rtn_obs, rtn_acs, rtn_lens, ep_rewards = [], [], [], [] 100 | obs = env.reset() 101 | if args.env_name in ['duckietown']: 102 | obs = torch.FloatTensor([obs]) 103 | 104 | save = True 105 | print(f'[running]') 106 | 107 | step = 0 108 | args.seed = args.seed 109 | idx = random.randint(1,args.subsample_frequency) 110 | 111 | obs_path_suffix = f'{args.demo_data_dir}/obs_{args.env_name}_seed={args.seed}' 112 | acs_path_suffix = f'{args.demo_data_dir}/acs_{args.env_name}_seed={args.seed}' 113 | 114 | 115 | while True: 116 | with torch.no_grad(): 117 | if args.env_name in ['highway-v0']: 118 | action = torch.tensor([[th_model.act(obs)]]) 119 | elif args.env_name in ['duckietown']: 120 | action = torch.FloatTensor([expert.predict(None)]) 121 | elif hasattr(gym.envs, 'atari'): 122 | _, actor_features, _ = th_model.base(obs, None, None) 123 | #action = th.argmax(th_model.dist.linear(actor_features)).reshape(-1,1) 124 | dist = th_model.dist(actor_features) 125 | action = dist.sample() 126 | else: 127 | _, action, _, _ = th_model.act(obs, None, None, deterministic=True) 128 | 129 | if isinstance(env.action_space, gym.spaces.Box): 130 | clip_action = np.clip(action.cpu(), env.action_space.low, env.action_space.high) 131 | else: 132 | clip_action = action 133 | 134 | if (step == idx and args.subsample) or not args.subsample: 135 | #if args.env_name in env_hyperparam: 136 | if time: 137 | try: # If vectornormalize is on 138 | rtn_obs.append(env.venv.get_original_obs()) 139 | except: # if vectornormalize is off 140 | rtn_obs.append(env.venv.envs[0].get_original_obs()) 141 | else: 142 | try: # If time is on and vectornormalize is on 143 | rtn_obs.append(env.venv.get_original_obs()) 144 | except: # If time is off and vectornormalize is off 145 | rtn_obs.append(obs.cpu().numpy().copy()) 146 | 147 | rtn_acs.append(action.cpu().numpy().copy()) 148 | idx += args.subsample_frequency 149 | 150 | if args.env_name in ['duckietown']: 151 | obs, reward, done, infos = env.step(clip_action.squeeze()) 152 | obs = torch.FloatTensor([obs]) 153 | else: 154 | obs, reward, done, infos = env.step(clip_action) 155 | 156 | step += 1 157 | if args.env_name in ['duckietown']: 158 | if done: 159 | print(f"reward: {reward}") 160 | 
ep_rewards.append(reward) 161 | save = True 162 | obs = env.reset() 163 | obs = torch.FloatTensor([obs]) 164 | step = 0 165 | idx=random.randint(1,args.subsample_frequency) 166 | else: 167 | for info in infos or done: 168 | if 'episode' in info.keys(): 169 | print(f"reward: {info['episode']['r']}") 170 | ep_rewards.append(info['episode']['r']) 171 | save = True 172 | obs = env.reset() 173 | step = 0 174 | idx=random.randint(1,args.subsample_frequency) 175 | 176 | if (len(ep_rewards) in [1, 3, 5, 10, 15, 20]) and save: 177 | rtn_obs_ = np.concatenate(rtn_obs) 178 | rtn_acs_ = np.concatenate(rtn_acs) 179 | 180 | obs_path = f'{obs_path_suffix}_ntraj={len(ep_rewards)}.npy' 181 | acs_path = f'{acs_path_suffix}_ntraj={len(ep_rewards)}.npy' 182 | 183 | print(f'saving to: {obs_path}') 184 | print(f'saving to: {acs_path}') 185 | 186 | np.save(obs_path, rtn_obs_) 187 | np.save(acs_path, rtn_acs_) 188 | print(f'done, length :{len(ep_rewards)}') 189 | save = False 190 | if len(ep_rewards) % 20 == 0: 191 | break 192 | 193 | print(f'expert: {np.mean(ep_rewards)}') 194 | results_save_path = os.path.join(args.save_results_dir, 'expert', f'expert_{args.env_name}_seed={args.seed}.perf') 195 | results = [{'total_num_steps':0 , 'train_loss': 0, 'test_loss': 0, 'num_trajs': 0 ,\ 196 | 'test_reward':np.mean(ep_rewards), 'u_reward': 0}] 197 | df = pd.DataFrame(results, columns=np.hstack(['x', 'steps', 'train_loss', 'test_loss',\ 198 | 'train_reward', 'test_reward', 'label', 'u_reward'])) 199 | df.to_csv(results_save_path) 200 | -------------------------------------------------------------------------------- /dril/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import copy 4 | import glob 5 | import os 6 | import time 7 | from collections import deque 8 | import sys 9 | import warnings 10 | 11 | import gym 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | 18 | 19 | from dril.a2c_ppo_acktr import algo, utils 20 | from dril.a2c_ppo_acktr.algo import gail 21 | from dril.a2c_ppo_acktr.algo.behavior_cloning import BehaviorCloning 22 | from dril.a2c_ppo_acktr.algo.ensemble import Ensemble 23 | from dril.a2c_ppo_acktr.algo.dril import DRIL 24 | from dril.a2c_ppo_acktr.arguments import get_args 25 | from dril.a2c_ppo_acktr.envs import make_vec_envs 26 | from dril.a2c_ppo_acktr.model import Policy 27 | from dril.a2c_ppo_acktr.expert_dataset import ExpertDataset 28 | from dril.a2c_ppo_acktr.storage import RolloutStorage 29 | from evaluation import evaluate 30 | import pandas as pd 31 | 32 | 33 | def main(): 34 | args = get_args() 35 | 36 | torch.manual_seed(args.seed) 37 | torch.cuda.manual_seed_all(args.seed) 38 | 39 | if args.system == 'philly': 40 | args.demo_data_dir = os.getenv('PT_OUTPUT_DIR') + '/demo_data/' 41 | args.save_model_dir = os.getenv('PT_OUTPUT_DIR') + '/trained_models/' 42 | args.save_results_dir = os.getenv('PT_OUTPUT_DIR') + '/trained_results/' 43 | 44 | 45 | if args.cuda and torch.cuda.is_available() and args.cuda_deterministic: 46 | torch.backends.cudnn.benchmark = False 47 | torch.backends.cudnn.deterministic = True 48 | 49 | log_dir = os.path.expanduser(args.log_dir) 50 | eval_log_dir = log_dir + "_eval" 51 | utils.cleanup_log_dir(log_dir) 52 | utils.cleanup_log_dir(eval_log_dir) 53 | 54 | torch.set_num_threads(1) 55 | device = torch.device("cuda:0" if args.cuda else "cpu") 56 | 57 | 58 | envs = make_vec_envs(args.env_name, 
args.seed, args.num_processes, args.gamma, 59 | args.log_dir, device, False, use_obs_norm=args.use_obs_norm, 60 | max_steps=args.atari_max_steps) 61 | 62 | actor_critic = Policy( 63 | envs.observation_space.shape, 64 | envs.action_space, 65 | load_expert=args.load_expert, 66 | env_name=args.env_name, 67 | rl_baseline_zoo_dir=args.rl_baseline_zoo_dir, 68 | expert_algo=args.expert_algo, 69 | base_kwargs={'recurrent': args.recurrent_policy}) 70 | actor_critic.to(device) 71 | 72 | # stores results 73 | main_results = [] 74 | 75 | 76 | if args.behavior_cloning or args.dril or args.warm_start: 77 | expert_dataset = ExpertDataset(args.demo_data_dir, args.env_name,\ 78 | args.num_trajs, args.seed, args.ensemble_shuffle_type) 79 | bc_model_save_path = os.path.join(args.save_model_dir, 'bc') 80 | bc_file_name = f'bc_{args.env_name}_policy_ntrajs={args.num_trajs}_seed={args.seed}' 81 | #bc_file_name = f'{args.env_name}_bc_policy_ntraj={args.num_trajs}_seed={args.seed}' 82 | bc_model_path = os.path.join(bc_model_save_path, f'{bc_file_name}.model.pth') 83 | bc_results_save_path = os.path.join(args.save_results_dir, 'bc', f'{bc_file_name}.perf') 84 | 85 | bc_model = BehaviorCloning(actor_critic, device, batch_size=args.bc_batch_size,\ 86 | lr=args.bc_lr, training_data_split=args.training_data_split, 87 | expert_dataset=expert_dataset, envs=envs) 88 | 89 | # Check if model already exist 90 | test_reward = None 91 | if os.path.exists(bc_model_path): 92 | best_test_params = torch.load(bc_model_path, map_location=device) 93 | print(f'*** Loading behavior cloning policy: {bc_model_path} ***') 94 | else: 95 | bc_results = [] 96 | best_test_loss, best_test_model = np.float('inf'), None 97 | for bc_epoch in range(args.bc_train_epoch): 98 | train_loss = bc_model.update(update=True, data_loader_type='train') 99 | with torch.no_grad(): 100 | test_loss = bc_model.update(update=False, data_loader_type='test') 101 | #if test_loss < best_test_loss: 102 | # best_test_loss = test_loss 103 | # best_test_params = copy.deepcopy(actor_critic.state_dict()) 104 | if test_loss < best_test_loss: 105 | print('model has improved') 106 | best_test_loss = test_loss 107 | best_test_params = copy.deepcopy(actor_critic.state_dict()) 108 | patience = 20 109 | else: 110 | patience -= 1 111 | print('model has not improved') 112 | if patience == 0: 113 | print('model has not improved in 20 epochs, breaking') 114 | break 115 | 116 | print(f'bc-epoch {bc_epoch}/{args.bc_train_epoch} | train loss: {train_loss:.4f}, test loss: {test_loss:.4f}') 117 | # Save the Behavior Cloning model and training results 118 | test_reward = evaluate(actor_critic, None, args.env_name, args.seed, 119 | args.num_processes, eval_log_dir, device, num_episodes=10, 120 | atari_max_steps=args.atari_max_steps) 121 | bc_results.append({'epoch': bc_epoch, 'trloss':train_loss, 'teloss': test_loss,\ 122 | 'test_reward': test_reward}) 123 | 124 | torch.save(best_test_params, bc_model_path) 125 | df = pd.DataFrame(bc_results, columns=np.hstack(['epoch', 'trloss', 'teloss', 'test_reward'])) 126 | df.to_csv(bc_results_save_path) 127 | 128 | # Load Behavior cloning model 129 | actor_critic.load_state_dict(best_test_params) 130 | if test_reward is None: 131 | bc_model_reward = evaluate(actor_critic, None, args.env_name, args.seed, 132 | args.num_processes, eval_log_dir, device, num_episodes=10, 133 | atari_max_steps=args.atari_max_steps) 134 | else: 135 | bc_model_reward = test_reward 136 | print(f'Behavior cloning model performance: {bc_model_reward}') 137 | # If behavior 
cloning terminate the script early 138 | if args.behavior_cloning: 139 | sys.exit() 140 | # Reset the behavior cloning optimizer 141 | bc_model.reset() 142 | 143 | 144 | if args.dril: 145 | expert_dataset = ExpertDataset(args.demo_data_dir, args.env_name, 146 | args.num_trajs, args.seed, args.ensemble_shuffle_type) 147 | 148 | # Train or load ensemble policy 149 | ensemble_policy = Ensemble(device=device, envs=envs, 150 | expert_dataset=expert_dataset, 151 | uncertainty_reward=args.dril_uncertainty_reward, 152 | ensemble_hidden_size=args.ensemble_hidden_size, 153 | ensemble_drop_rate=args.ensemble_drop_rate, 154 | ensemble_size=args.ensemble_size, 155 | ensemble_batch_size=args.ensemble_batch_size, 156 | ensemble_lr=args.ensemble_lr, 157 | num_ensemble_train_epoch=args.num_ensemble_train_epoch, 158 | num_trajs=args.num_trajs, 159 | seed=args.seed, 160 | env_name=args.env_name, 161 | training_data_split=args.training_data_split, 162 | save_model_dir=args.save_model_dir, 163 | save_results_dir=args.save_results_dir) 164 | 165 | # If only training ensemble 166 | if args.pretrain_ensemble_only: 167 | sys.exit() 168 | 169 | # Train or load behavior cloning policy 170 | dril_bc_model = bc_model 171 | 172 | dril = DRIL(device=device,envs=envs,ensemble_policy=ensemble_policy, 173 | dril_bc_model=dril_bc_model, expert_dataset=expert_dataset, 174 | ensemble_quantile_threshold=args.ensemble_quantile_threshold, 175 | ensemble_size=args.ensemble_size, dril_cost_clip=args.dril_cost_clip, 176 | env_name=args.env_name, num_dril_bc_train_epoch=args.num_dril_bc_train_epoch, 177 | training_data_split=args.training_data_split) 178 | else: 179 | dril = None 180 | 181 | 182 | if args.algo == 'a2c': 183 | #TODO: Not sure why this is needed 184 | from dril.a2c_ppo_acktr import algo 185 | agent = algo.A2C_ACKTR( 186 | actor_critic, 187 | args.value_loss_coef, 188 | args.entropy_coef, 189 | lr=args.lr, 190 | eps=args.eps, 191 | alpha=args.alpha, 192 | max_grad_norm=args.max_grad_norm, 193 | dril=dril) 194 | elif args.algo == 'ppo': 195 | #TODO: Not sure why this is needed 196 | from dril.a2c_ppo_acktr import algo 197 | agent = algo.PPO( 198 | actor_critic, 199 | args.clip_param, 200 | args.ppo_epoch, 201 | args.num_mini_batch, 202 | args.value_loss_coef, 203 | args.entropy_coef, 204 | lr=args.lr, 205 | eps=args.eps, 206 | max_grad_norm=args.max_grad_norm, 207 | dril=dril) 208 | elif args.algo == 'acktr': 209 | agent = algo.A2C_ACKTR( 210 | actor_critic, args.value_loss_coef, args.entropy_coef, acktr=True) 211 | 212 | 213 | if args.gail: 214 | if len(envs.observation_space.shape) == 1: 215 | discr = gail.Discriminator( 216 | envs.observation_space.shape[0] + envs.action_space.shape[0], 10, 217 | device, args.gail_reward_type, args.clip_gail_action, 218 | envs, args.gail_disc_lr) 219 | else: 220 | discr = gail.DiscriminatorCNN( 221 | envs.observation_space, 10, envs.action_space.n, 222 | device, args.gail_disc_lr,args.gail_reward_type, envs) 223 | 224 | file_name = os.path.join( 225 | args.gail_experts_dir, "trajs_{}.pt".format( 226 | args.env_name.split('-')[0].lower())) 227 | 228 | expert_dataset = ExpertDataset(args.demo_data_dir, args.env_name, 229 | args.num_trajs, args.seed, args.ensemble_shuffle_type) 230 | dataset = expert_dataset.load_demo_data(args.training_data_split, args.gail_batch_size, None) 231 | gail_train_loader = dataset['trdata'] 232 | 233 | rollouts = RolloutStorage(args.num_steps, args.num_processes, 234 | envs.observation_space.shape, envs.action_space, 235 | 
actor_critic.recurrent_hidden_state_size) 236 | 237 | obs = envs.reset() 238 | rollouts.obs[0].copy_(obs) 239 | rollouts.to(device) 240 | 241 | episode_rewards = deque(maxlen=10) 242 | episode_uncertainty_rewards = deque(maxlen=10) 243 | running_uncertainty_reward = np.zeros(args.num_processes) 244 | 245 | start = time.time() 246 | num_updates = int( 247 | args.num_env_steps) // args.num_steps // args.num_processes 248 | 249 | previous_action = None 250 | for j in range(num_updates): 251 | 252 | if args.use_linear_lr_decay: 253 | # decrease learning rate linearly 254 | utils.update_linear_schedule( 255 | agent.optimizer, j, num_updates, 256 | agent.optimizer.lr if args.algo == "acktr" else args.lr) 257 | 258 | for step in range(args.num_steps): 259 | # Sample actions 260 | with torch.no_grad(): 261 | value, action, action_log_prob, recurrent_hidden_states = actor_critic.act( 262 | rollouts.obs[step], rollouts.recurrent_hidden_states[step], 263 | rollouts.masks[step]) 264 | 265 | # Obser reward and next obs 266 | if isinstance(envs.action_space, gym.spaces.Box): 267 | clip_action = torch.clamp(action, float(envs.action_space.low[0]), float(envs.action_space.high[0])) 268 | else: 269 | clip_action = action 270 | 271 | if args.dril: 272 | dril_reward = dril.predict_reward(clip_action, obs, envs) 273 | running_uncertainty_reward += dril_reward.view(-1).numpy() 274 | 275 | obs, env_reward, done, infos = envs.step(clip_action) 276 | 277 | if args.dril: 278 | reward = dril_reward 279 | else: 280 | reward = env_reward 281 | 282 | #for info in infos: 283 | for i, info in enumerate(infos): 284 | if 'episode' in info.keys(): 285 | episode_rewards.append(info['episode']['r']) 286 | episode_uncertainty_rewards.append(running_uncertainty_reward[i] / info['episode']['l']) 287 | running_uncertainty_reward[i] = 0 288 | 289 | # If done then clean the history of observations. 
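
In the rollout loop above, when `--dril` is set the environment reward is discarded and replaced by `dril.predict_reward(clip_action, obs, envs)`, a reward derived from how strongly the pre-trained ensemble disagrees about the action the policy just took. A rough, self-contained sketch of that idea for a discrete action space (this is only an approximation of the method, not the code in `a2c_ppo_acktr/algo/dril.py`, which also uses the `--ensemble_quantile_threshold` and `--dril_cost_clip` settings passed to `DRIL` above):

```
import torch

def disagreement_reward(ensemble_logits, action, variance_threshold):
    # ensemble_logits: (E, num_actions) action scores from E ensemble members
    # for the current observation. High variance in the probability each member
    # assigns to the chosen action means the ensemble disagrees, so the
    # state-action pair is likely far from the demonstrations and is penalized.
    probs = torch.softmax(ensemble_logits, dim=-1)
    chosen = probs[:, action]
    variance = chosen.var(unbiased=False)
    return -1.0 if variance > variance_threshold else 1.0

ensemble_logits = torch.randn(5, 4)   # 5 ensemble members, 4 discrete actions
print(disagreement_reward(ensemble_logits, action=2, variance_threshold=0.01))
```
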
290 | masks = torch.FloatTensor( 291 | [[0.0] if done_ else [1.0] for done_ in done]) 292 | bad_masks = torch.FloatTensor( 293 | [[0.0] if 'bad_transition' in info.keys() else [1.0] 294 | for info in infos]) 295 | rollouts.insert(obs, recurrent_hidden_states, action, 296 | action_log_prob, value, reward, masks, bad_masks) 297 | 298 | with torch.no_grad(): 299 | next_value = actor_critic.get_value( 300 | rollouts.obs[-1], rollouts.recurrent_hidden_states[-1], 301 | rollouts.masks[-1]).detach() 302 | 303 | if args.dril and args.algo == 'ppo': 304 | # Normalize the rewards for ppo 305 | # (Implementation Matters in Deep RL: A Case Study on PPO and TRPO) 306 | # (https://openreview.net/forum?id=r1etN1rtPB) 307 | for step in range(args.num_steps): 308 | rollouts.rewards[step] = dril.normalize_reward( 309 | rollouts.obs[step], rollouts.actions[step], args.gamma, 310 | rollouts.masks[step], rollouts.rewards[step]) 311 | 312 | if args.gail: 313 | #if j >= 10: 314 | # envs.venv.eval() 315 | 316 | gail_epoch = args.gail_epoch 317 | if j < 10: 318 | gail_epoch = 10 # Warm up 319 | for _ in range(gail_epoch): 320 | try: 321 | # Continous control task have obfilt 322 | obfilt = utils.get_vec_normalize(envs)._obfilt 323 | except: 324 | # CNN doesnt have obfilt 325 | obfilt = None 326 | discr.update(gail_train_loader, rollouts, obfilt) 327 | 328 | for step in range(args.num_steps): 329 | rollouts.rewards[step] = discr.predict_reward( 330 | rollouts.obs[step], rollouts.actions[step], args.gamma, 331 | rollouts.masks[step]) 332 | 333 | rollouts.compute_returns(next_value, args.use_gae, args.gamma, 334 | args.gae_lambda, args.use_proper_time_limits) 335 | 336 | value_loss, action_loss, dist_entropy = agent.update(rollouts) 337 | 338 | rollouts.after_update() 339 | 340 | # save for every interval-th episode or for the last epoch 341 | if (j % args.save_interval == 0 342 | or j == num_updates - 1) and args.save_model_dir != "": 343 | save_path = os.path.join(args.save_model_dir, args.algo) 344 | model_file_name = f'{args.env_name}_policy_ntrajs={args.num_trajs}_seed={args.seed}' 345 | torch.save([ 346 | actor_critic, 347 | getattr(utils.get_vec_normalize(envs), 'ob_rms', None) 348 | ], os.path.join(save_path, f'{model_file_name}.pt')) 349 | 350 | if j % args.log_interval == 0 and len(episode_rewards) > 1: 351 | total_num_steps = (j + 1) * args.num_processes * args.num_steps 352 | end = time.time() 353 | print( 354 | "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f} mean/median U reward {:.4f}/{:.4f}\n\n" 355 | .format(j, total_num_steps, 356 | int(total_num_steps / (end - start)), 357 | len(episode_rewards), np.mean(episode_rewards), 358 | np.median(episode_rewards), np.min(episode_rewards), 359 | np.max(episode_rewards), np.mean(episode_uncertainty_rewards), 360 | np.median(episode_uncertainty_rewards))) 361 | 362 | 363 | 364 | if (args.eval_interval is not None and len(episode_rewards) > 1 365 | and j % args.eval_interval == 0): 366 | if args.dril: 367 | ob_rms = None 368 | else: 369 | try: 370 | ob_rms = utils.get_vec_normalize(envs).ob_rms 371 | except: 372 | ob_rms = None 373 | 374 | print(f'ob_rms: {ob_rms}') 375 | test_reward = evaluate(actor_critic, ob_rms, args.env_name, args.seed, 376 | args.num_processes, eval_log_dir, device, args.num_eval_episodes, 377 | atari_max_steps=args.atari_max_steps) 378 | main_results.append({'total_num_steps': total_num_steps, 'train_loss': 0,\ 379 | 'test_loss': 0, 
'test_reward':test_reward, 'num_trajs': args.num_trajs,\
380 |                 'train_reward': np.mean(episode_rewards),\
381 |                 'u_reward': np.mean(episode_uncertainty_rewards)})
382 |             save_results(args, main_results, algo, args.dril, args.gail)
383 | 
384 | 
385 |     if dril: algo = 'dril'
386 |     elif args.gail: algo = 'gail'
387 |     else: algo = args.algo
388 |     save_path = os.path.join(args.save_model_dir, algo)
389 |     file_name = f'{algo}_{args.env_name}_policy_ntrajs={args.num_trajs}_seed={args.seed}'
390 | 
391 |     torch.save([
392 |         actor_critic,
393 |         getattr(utils.get_vec_normalize(envs), 'ob_rms', None)
394 |     ], os.path.join(save_path, f"{file_name}.pt"))
395 | 
396 |     # Final evaluation
397 |     try:
398 |         ob_rms = utils.get_vec_normalize(envs).ob_rms
399 |     except:
400 |         ob_rms = None
401 |     test_reward = evaluate(actor_critic, ob_rms, args.env_name, args.seed,
402 |         args.num_processes, eval_log_dir, device, num_episodes=10, atari_max_steps=args.atari_max_steps)
403 |     main_results.append({'total_num_steps': total_num_steps, 'train_loss': 0, 'test_loss': 0,\
404 |         'num_trajs': args.num_trajs, 'test_reward':test_reward,\
405 |         'train_reward': np.mean(episode_rewards),\
406 |         'u_reward': np.mean(episode_uncertainty_rewards)})
407 |     save_results(args, main_results, algo, args.dril, args.gail)
408 | 
409 | 
410 | def save_results(args, main_results, algo, dril, gail):
411 |     if dril: algo = 'dril'
412 |     elif gail: algo = 'gail'
413 |     else: algo = args.algo
414 | 
415 |     if dril:
416 |         exp_name = f'{algo}_{args.env_name}_ntraj={args.num_trajs}_'
417 |         exp_name += f'ensemble_lr={args.ensemble_lr}_'
418 |         exp_name += f'lr={args.bc_lr}_bcep={args.bc_train_epoch}_shuffle={args.ensemble_shuffle_type}_'
419 |         exp_name += f'quantile={args.ensemble_quantile_threshold}_'
420 |         exp_name += f'cost_{args.dril_cost_clip}_seed={args.seed}.perf'
421 |     elif gail:
422 |         exp_name = f'{algo}_{args.env_name}_ntraj={args.num_trajs}_'
423 |         exp_name += f'gail_lr={args.gail_disc_lr}_lr={args.bc_lr}_bcep={args.bc_train_epoch}_'
424 |         exp_name += f'gail_reward_type={args.gail_reward_type}_seed={args.seed}.perf'
425 |     else:
426 |         exp_name = f'{algo}_{args.env_name}.perf'
427 | 
428 |     results_save_path = os.path.join(args.save_results_dir, f'{algo}', f'{exp_name}')
429 |     df = pd.DataFrame(main_results, columns=np.hstack(['x', 'total_num_steps', 'train_loss', 'test_loss', 'train_reward', 'test_reward', 'num_trajs', 'u_reward']))
430 |     df.to_csv(results_save_path)
431 | 
432 | if __name__ == "__main__":
433 |     main()
434 | 
--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import glob, csv, pdb, numpy, torch, os, argparse
3 | import pandas
4 | import matplotlib.pyplot as plt
5 | import os
6 | import numpy as np
7 | 
8 | import seaborn as sns
9 | import cycler
10 | import matplotlib
11 | 
12 | 
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('-env', type=str, default='SpaceInvadersNoFrameskip-v4')
15 | parser.add_argument('-n_bc_epochs', type=int, default=1)
16 | parser.add_argument('-shuffle', type=int, default=2)
17 | parser.add_argument('-lr', type=float, default=0.00025)
18 | parser.add_argument('-quantile', type=float, default=0.98)
19 | parser.add_argument('-decay', type=int, default=1)
20 | parser.add_argument('-exp', type=str, default='exp1')
21 | parser.add_argument('-plot_u_reward', type=int, default=0)
22 | args = parser.parse_args()
23 | 
24 | data_dir = 
f'{os.getcwd()}/dril/trained_results/' 25 | 26 | def get_results(result_files, filter=False): 27 | rewards = [] 28 | u_rewards = [] 29 | steps = [] 30 | test_reward = [] 31 | for r in result_files: 32 | try: 33 | data = pandas.read_csv(r) 34 | idx = len(data['test_reward']) - 1 35 | rewards.append(data['test_reward'][idx]) 36 | except: 37 | pass 38 | 39 | try: 40 | u_rewards.append(data['u_reward'].tolist()) 41 | steps.append(data['total_num_steps'].tolist()) 42 | test_reward.append(data['test_reward'].tolist()) 43 | except: 44 | pass 45 | 46 | return (rewards, u_rewards, steps, test_reward) 47 | 48 | 49 | def load_results(n_demo): 50 | # Expert results ----------------- 51 | expert_results = glob.glob(f'{data_dir}/expert/expert_{args.env}_seed=*.perf') 52 | (expert_reward, _, _, _) = get_results(expert_results) 53 | 54 | # Behavior Cloing results -------- 55 | bc_mse_results = glob.glob(f'{data_dir}/bc/bc_{args.env}_policy_ntrajs={n_demo}_seed=*.perf') 56 | (bc_mse_reward, _, _,_) = get_results(bc_mse_results) 57 | 58 | # DRIL results ------------------- 59 | exp_name = f'dril_{args.env}_ntraj={n_demo}_ensemble_lr=0.00025_lr=0.00025_bcep=1001_' 60 | exp_name += f'shuffle=sample_w_replace_quantile=0.98_cost_-1_to_1_seed=*.perf' 61 | 62 | bc_mse_variance_results = glob.glob(f'{data_dir}/dril/{exp_name}') 63 | (bc_mse_variance_reward, bc_variance_u_reward, bc_variance_steps, bc_mse_variance_reward_curve) = get_results(bc_mse_variance_results, filter=True) 64 | 65 | # Random results ----------------- 66 | random_reward = [] 67 | random_results = glob.glob(f'{data_dir}/random/{args.env}/random*.perf') 68 | for r in random_results: 69 | random_reward.append(pandas.read_csv(r)['test_reward'].max()) 70 | 71 | # Gail results -------------------- 72 | params = [(clipped_loss, zero_expert_reward, use_obs_norm, use_bc, gail_normalized_reward, bc_loss, clamp_gail_action) 73 | for clipped_loss in [True] 74 | for zero_expert_reward in [True, False] 75 | for use_obs_norm in [False] 76 | for use_bc in [True] 77 | for gail_normalized_reward in [True] 78 | for clamp_gail_action in [False] 79 | for bc_loss in ['mse']] 80 | 81 | gail = {} 82 | for gail_reward_type in ['unbias', 'favor_zero_reward', 'favor_non_zero_reward']: 83 | gail_results = f'gail_{args.env}_ntraj={n_demo}_' 84 | gail_results += f'gail_lr=0.001_lr=0.00025_bcep=2001_' 85 | gail_results += f'gail_reward_type={gail_reward_type}_seed=*.perf' 86 | results = glob.glob(f'{data_dir}/gail/{gail_results}') 87 | 88 | label = f'GAIL {gail_reward_type}' 89 | (results, _, _, _) = get_results(results) 90 | if results: 91 | gail[label] = results 92 | else: 93 | gail[label] = [] 94 | 95 | return {'expert': numpy.array(expert_reward), 96 | 'bc_mse': numpy.array(bc_mse_reward), 97 | 'bc_mse_variance': numpy.array(bc_mse_variance_reward), 98 | 'bc_variance_u_reward_curve': bc_variance_u_reward, 99 | 'bc_mse_variance_reward_curve': bc_mse_variance_reward_curve, 100 | 'bc_variance_steps': bc_variance_steps, 101 | 'random': numpy.array(random_reward), 102 | **gail} 103 | 104 | def add_line_plot(perf_results, color=None, style=None): 105 | width = 3 106 | s = 10 107 | alpha=0.1 108 | 109 | mean = [numpy.mean(perf) for perf in perf_results] 110 | for perf in perf_results: 111 | numpy.std(perf) 112 | std = [numpy.std(perf) for perf in perf_results] 113 | 114 | plt.plot([1, 3, 5, 10, 15, 20], mean, style, c=color, linewidth=width, markersize=s) 115 | plt.xticks([1, 3, 5, 10, 15, 20]) 116 | plt.fill_between([1, 3, 5, 10, 15, 20], numpy.array(mean) - numpy.array(std), 
numpy.array(mean) + numpy.array(std), color=color, alpha=alpha) 117 | 118 | styles = {'expert': '--', 119 | 'bc': 'o-', 120 | 'dril': '^-', 121 | 'gail0': 'v-', 122 | 'gail1': 'D-', 123 | 'gail2': '<-', 124 | 'gail3': '*-', 125 | 'random': '.-'} 126 | 127 | n = 12 128 | color = numpy.array(sns.color_palette("colorblind", n_colors=n)) 129 | matplotlib.rcParams['axes.prop_cycle'] = cycler.cycler('color', color) 130 | 131 | c1 = color[7]*0.9 132 | c2 = color[2] 133 | c3 = color[4] 134 | c4 = color[3] 135 | c5 = color[1] 136 | c6 = color[10] 137 | c7 = color[10] 138 | c8 = color[10] 139 | 140 | colors = {'expert': c1, 141 | 'bc': c2, 142 | 'dril': c3, 143 | 'gail0': c4, 144 | 'gail1': c5, 145 | 'gail2': c7, 146 | 'gail3': c8, 147 | 'random': c6} 148 | 149 | def main(): 150 | # Expert --------------- 151 | expert = [load_results(n_demo)['expert'] for n_demo in [1, 3, 5, 10, 15, 20]] 152 | add_line_plot(expert, colors['expert'], styles['expert']) 153 | 154 | # Behavior Cloning ------ 155 | bc_mse = [load_results(n_demo)['bc_mse'] for n_demo in [1, 3, 5, 10, 15, 20]] 156 | add_line_plot(bc_mse, colors['bc'], styles['bc']) 157 | 158 | # DRIL ------------------ 159 | bc_mse_variance = [load_results(n_demo)['bc_mse_variance'] for n_demo in [1, 3, 5, 10, 15, 20]] 160 | add_line_plot(bc_mse_variance, colors['dril'], styles['dril']) 161 | 162 | # Random ------------------ 163 | random = [load_results(n_demo)['random'] for n_demo in [1, 3, 5, 10, 15, 20]] 164 | add_line_plot(random, colors['random'], styles['random']) 165 | 166 | # GAIL ----------------- 167 | keys = [] 168 | for n_demo in [1, 3, 5, 10, 15, 20]: 169 | keys += load_results(n_demo).keys() 170 | gail_keys = sorted(list(set([key for key in keys if 'GAIL' in key]))) 171 | 172 | final_keys = [] 173 | for idx, key in enumerate(gail_keys): 174 | gail_results = [load_results(n_demo)[key] for n_demo in [1, 3, 5, 10, 15, 20]] 175 | add_line_plot(gail_results, colors[f'gail{idx}'], styles[f'gail{idx}']) 176 | final_keys.append(key) 177 | 178 | plt.legend(['Expert','BC','DRIL', 'RANDOM']+final_keys, fontsize=6, loc='bottom right') 179 | fsize=16 180 | 181 | plt.xlabel('Expert Trajectories', fontsize=fsize) 182 | plt.ylabel('Reward', fontsize=fsize) 183 | env = args.env.replace('-v4', '').replace('NoFrameskip', '') 184 | plt.title(env, fontsize=fsize) 185 | plt.savefig(f'{env}.pdf') 186 | 187 | plt.clf() 188 | test_rewards = load_results(10)['bc_mse_variance_reward_curve'] 189 | num_values = min([len(test_rewards[0]), len(test_rewards[1])]) 190 | test_rewards = np.array([x[:num_values] for x in test_rewards]) 191 | 192 | u_rewards = load_results(10)['bc_variance_u_reward_curve'] 193 | u_rewards = np.array([x[:num_values] for x in u_rewards]) 194 | 195 | steps = np.array(load_results(10)['bc_variance_steps'][0][:num_values]) 196 | 197 | u_rewards_mean = -numpy.mean(u_rewards, axis=0) 198 | u_rewards_std = numpy.std(u_rewards, axis=0) 199 | 200 | test_rewards_mean = numpy.mean(test_rewards, axis=0) 201 | test_rewards_std = numpy.std(test_rewards, axis=0) 202 | 203 | fig, axs = plt.subplots(2, 1) 204 | ax1 = axs[0] 205 | ax2 = axs[1] 206 | box = dict(facecolor='yellow', pad=5, alpha=0.2) 207 | c1 = color[7] 208 | c2 = color[10] 209 | 210 | ax1.plot(steps, u_rewards_mean, color='black') 211 | ax1.fill_between(steps, u_rewards_mean - u_rewards_std, u_rewards_mean + u_rewards_std, color=c1, alpha=0.3) 212 | ax1.set_xlabel('steps', fontsize=16) 213 | ax1.set_ylabel('Uncertainty Cost', fontsize=12) 214 | ax1.set_title(env, fontsize=16) 215 | ax2.plot(steps, 
test_rewards_mean, color=c2) 216 | ax2.fill_between(steps, test_rewards_mean - test_rewards_std, test_rewards_mean + test_rewards_std, color=c2, alpha=0.3) 217 | ax2.set_xlabel('steps', fontsize=16) 218 | ax2.set_ylabel('Episode Reward', fontsize=12) 219 | plt.savefig(f'{env}_u_reward.pdf') 220 | 221 | 222 | if __name__== "__main__": 223 | main() 224 | 225 | -------------------------------------------------------------------------------- /pngs/atari.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xkianteb/dril/57eac5c3a5b0f4766821a0bedff043471f91e4f1/pngs/atari.png -------------------------------------------------------------------------------- /pngs/continous_control.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xkianteb/dril/57eac5c3a5b0f4766821a0bedff043471f91e4f1/pngs/continous_control.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup 3 | 4 | setup(name='dril', 5 | version='1.0', 6 | author='', 7 | packages=['dril'], 8 | ) 9 | --------------------------------------------------------------------------------
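
For reference, the demonstration files written by `generate_demonstration_data.py` are plain NumPy arrays and can be inspected directly. A minimal sketch of loading them back (the directory and environment name below are assumed example values; the real locations are controlled by `--demo_data_dir` and `--env-name`):

```
import numpy as np

# File names follow the pattern used in generate_demonstration_data.py:
#   {demo_data_dir}/obs_{env_name}_seed={seed}_ntraj={n}.npy
#   {demo_data_dir}/acs_{env_name}_seed={seed}_ntraj={n}.npy
demo_data_dir = 'demo_data'                                    # assumed example path
env_name, seed, ntraj = 'SpaceInvadersNoFrameskip-v4', 0, 10   # assumed example values

obs = np.load(f'{demo_data_dir}/obs_{env_name}_seed={seed}_ntraj={ntraj}.npy')
acs = np.load(f'{demo_data_dir}/acs_{env_name}_seed={seed}_ntraj={ntraj}.npy')
print(obs.shape, acs.shape)   # one row per stored (observation, action) pair
```
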