├── helpers
│   ├── __init__.py
│   ├── a2c_ppo_acktr
│   │   ├── __init__.py
│   │   ├── algo
│   │   │   ├── __init__.py
│   │   │   ├── a2c_acktr.py
│   │   │   ├── ppo.py
│   │   │   ├── gail.py
│   │   │   └── kfac.py
│   │   ├── utils.py
│   │   ├── distributions.py
│   │   ├── arguments.py
│   │   ├── envs.py
│   │   ├── storage.py
│   │   └── model.py
│   ├── lr_scheduling.py
│   └── random_agent.py
├── requirements.txt
├── environment.yml
├── configs
│   ├── exmpl_config.yml
│   └── full_config.yml
├── vis.py
├── README.md
├── visualize_agent.py
├── a2c_fast.py
├── attention_module.py
├── .gitignore
├── a2c.py
└── a2c_dist.py

/helpers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/helpers/a2c_ppo_acktr/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/helpers/a2c_ppo_acktr/algo/__init__.py:
--------------------------------------------------------------------------------
1 | from .a2c_acktr import A2C_ACKTR
2 | from .ppo import PPO
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # requires Python 3.7.6 (provided by the interpreter/conda, not installable via pip)
2 | matplotlib==3.1.3
3 | numpy==1.18.1
4 | pyyaml==5.3
5 | torch==1.3.1
6 | gym==0.17.0
7 | imageio==2.6.1
8 | pandas==1.0.3
9 | seaborn==0.10.0
10 | baselines==0.1.6
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: drrl
2 | channels:
3 |   - defaults
4 | dependencies:
5 |   - python=3.7.6
6 |   - pip=20.0.2
7 |   - matplotlib=3.1.3
8 |   - numpy=1.18.1
9 |   - pytorch=1.3.1
10 |   - imageio=2.6.1
11 |   - pandas=1.0.3
12 |   - seaborn=0.10.0
13 |   - pip:
14 |     - gym==0.17.0
15 |     - pyyaml==5.3
16 |     - baselines==0.1.6
--------------------------------------------------------------------------------
/helpers/lr_scheduling.py:
--------------------------------------------------------------------------------
1 | class Linear_decay:
2 |     def __init__(self, lr_init, lr_term, ep_max):
3 |         """Returns an object calculating linear decay factors that's callable with one parameter (episode). After
4 |         ep_max is reached, lr plateaus at lr_term. It is a callable so it can be pickled but still statically
5 |         parameterized.
6 |         params:
7 |             lr_init: initial learning rate (episode 1)
8 |             lr_term: final learning rate
9 |             ep_max: when to reach final lr
10 |         """
11 |         self.lr_init = lr_init
12 |         self.lr_term = lr_term
13 |         self.ep_max = ep_max
14 |     def __call__(self, ep):
15 |         return 1 - ep * (1 - (self.lr_term / self.lr_init)) / self.ep_max if ep < self.ep_max else self.lr_term / self.lr_init
--------------------------------------------------------------------------------
/vis.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | import seaborn as sns
7 | 
8 | def paths_from_savedir(dirpath):
9 |     net_path = [f for f in os.listdir(dirpath) if f.endswith(".pt")]  # network checkpoint file
10 |     stats_path = [f for f in os.listdir(dirpath) if f.endswith(".csv")]  # logged training stats
11 |     if len(net_path) > 1 or len(stats_path) > 1:
12 |         raise Exception(f"More than one savepoint found in {os.path.abspath(dirpath)}")
13 |     net_path = None if len(net_path)==0 else os.path.join(dirpath, net_path[0])
14 |     stats_path = None if len(stats_path)==0 else os.path.join(dirpath, stats_path[0])
15 |     return net_path, stats_path
16 | 
17 | def append_cumsteps(df):
18 |     """Append columns with episode steps (sum over workers) and global steps (cumulative sum over episode steps)."""
19 |     df = df.copy()
20 |     df["global steps"] = df["steps"].cumsum()
21 |     df["global steps"] = df.groupby(["global ep"])["global steps"].transform(max)
22 |     return df
23 | 
24 | def aggregate_df(df):
25 |     n_workers = sum(df["global ep"] == 0) # how many workers worked in parallel?
26 |     df = df.copy()
27 |     df = df.groupby("global ep", as_index=False).aggregate([np.mean, np.var]).reset_index()
28 |     df["cumulative steps"] = np.cumsum(df["steps", "mean"])*n_workers
29 |     return df
30 | 
31 | data1 = pd.read_csv(paths_from_savedir("saves/1e4_newR")[1]).drop(columns=["Unnamed: 0"])
32 | data2 = pd.read_csv(paths_from_savedir("saves/w19")[1]).drop(columns=["Unnamed: 0"])
33 | 
34 | data = append_cumsteps(data1)
35 | #
36 | # agg_1e4 = aggregate_df(data_1e4)
37 | # agg_1e5 = aggregate_df(data_1e5)
38 | 
39 | measures = ["cumulative reward", "loss", "steps"]
40 | plt.figure()
41 | for i, measure in enumerate(measures):
42 |     plt.subplot(len(measures),1,i+1)
43 |     sns.lineplot(x="global steps", y=measure, data=data)
44 | 
--------------------------------------------------------------------------------
/configs/full_config.yml:
--------------------------------------------------------------------------------
1 | #THIS CONFIG FILE CONTAINS ALL PARAMETERIZABLE VARIABLES WITH DEFAULT VALUES,
2 | #most of them are kwargs in the respective inits, so they can be left out of the config file,
3 | # but they are listed here for reference.
4 | seed:
5 |   123
6 | n_cpus: #number of cpus used
7 |   10
8 | cuda:
9 |   False
10 | env_name: #only used in a2c_fast.py
11 |   gym_boxworld:boxworldRandom-v0
12 | env_config:
13 |   n: 12 #size of board
14 |   list_goal_lengths: [5] #length of correct path (e.g. 4 means goal can be unlocked with 3rd key), can be list
15 |   list_num_distractors: [2] #number of distractor branches, can be list
16 |   list_distractor_lengths: [2] #length/"depth" of each distractor branch, can be list
17 |   reward_gem: 10 #reward structure
18 |   step_cost: 0 #assumed to be negative
19 |   reward_dead: 0
20 |   reward_correct_key: 1
21 |   reward_wrong_key: -1
22 |   num_colors: 20
23 |   max_steps: 3000 #maximum number of steps before environment terminates
24 |   verbose: False
25 | net_config:
26 |   n_f_conv1: 12
27 |   n_f_conv2: 24
28 |   att_emb_size: 64
29 |   n_heads: 2
30 |   n_att_stack: 2
31 |   n_fc_layers: 4
32 |   pad: True #padding will maintain size of state space
33 |   baseline_mode: False #will replace attentional module with several convolutional layers to create baseline module
34 |   n_baseMods: 3 #3 and 6 are default in paper
35 | gamma: #temporal discount factor
36 |   0.99
37 | n_step: #number of a2c updates, i.e. number of episodes each worker samples
38 |   10
39 | optimizer:
40 |   RMSprop #RMSprop #or Adam
41 | lr:
42 |   0.00001
43 | lr_decay:
44 |   True
45 | lr_term:
46 |   0.00005
47 | n_env_steps:
48 |   100000000
49 | update_every_n_steps: #only used in a2c_fast.py
50 |   5
51 | e_schedule: #only used in a2c_dist.py
52 |   False
53 | tensorboard: #only used in a2c_dist.py
54 |   False
55 | plot_gradients: #only used in a2c_dist.py
56 |   False
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A2C training of Relational Deep Reinforcement Learning Architecture
2 | 
3 | ## Introduction
4 | 
5 | Torch implementation of the deep relational architecture from the paper ["Relational Deep Reinforcement Learning"](https://arxiv.org/pdf/1806.01830.pdf) together with (synchronous) advantage actor-critic training as discussed for example [here](https://arxiv.org/abs/1602.01783).
6 | 
7 | The Box-World environment used in this script can be found at [this repo](https://github.com/mavischer/Box-World).
8 | 
9 | Training is performed in `a2c_fast.py`. The implementation is based on [this repo](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail), which turned out to be cleverer and substantially faster than my own implementation, `a2c_dist.py`.
10 | However, this latter file contains routines to plot the gradients in the network and the computation graph.
11 | 
12 | The relational module and the general architecture are both implemented as `torch.nn.Module`s in `attention_module.py`. However, `a2c_fast.py` uses almost identical adaptations of these classes in `helpers/a2c_ppo_acktr/model.py` that comply with the training algorithm's `Policy` class.
13 | 
14 | An example YAML config file parsed from the arguments is `configs/exmpl_config.yml`. Training, the environment and the network can be parameterized there. A copy of the loaded configuration file will be saved with checkpoints and logs for documentation.
15 | 
16 | A suitable environment can be created e.g. by `conda env create -f environment.yml` or
17 | `pip install -r requirements.txt`. Afterwards, install and register the [Box-World environment](https://github.com/mavischer/Box-World) by cloning the repo and running `pip install -e gym-boxworld`.
18 | *Remember that after changing the environment code you need to re-register the environment before the changes become effective.*
19 | You can find the details of the state space, action space and reward structure there.
20 | 
21 | `vis.py` contains some plotting functionality.
22 | 
23 | ## Example Run
24 | 
25 | ```bash
26 | python a2c.py -c configs/exmpl_config.yml -s example_run
27 | ```
--------------------------------------------------------------------------------
/helpers/a2c_ppo_acktr/utils.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | from helpers.a2c_ppo_acktr.envs import VecNormalize
8 | 
9 | 
10 | # Get a render function
11 | def get_render_func(venv):
12 |     if hasattr(venv, 'envs'):
13 |         return venv.envs[0].render
14 |     elif hasattr(venv, 'venv'):
15 |         return get_render_func(venv.venv)
16 |     elif hasattr(venv, 'env'):
17 |         return get_render_func(venv.env)
18 | 
19 |     return None
20 | 
21 | 
22 | def get_vec_normalize(venv):
23 |     if isinstance(venv, VecNormalize):
24 |         return venv
25 |     elif hasattr(venv, 'venv'):
26 |         return get_vec_normalize(venv.venv)
27 | 
28 |     return None
29 | 
30 | 
31 | # Necessary for my KFAC implementation.
32 | class AddBias(nn.Module):
33 |     def __init__(self, bias):
34 |         super(AddBias, self).__init__()
35 |         self._bias = nn.Parameter(bias.unsqueeze(1))
36 | 
37 |     def forward(self, x):
38 |         if x.dim() == 2:
39 |             bias = self._bias.t().view(1, -1)
40 |         else:
41 |             bias = self._bias.t().view(1, -1, 1, 1)
42 | 
43 |         return x + bias
44 | 
45 | 
46 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
47 |     """Decreases the learning rate linearly"""
48 |     lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
49 |     for param_group in optimizer.param_groups:
50 |         param_group['lr'] = lr
51 | 
52 | 
53 | def init(module, weight_init, bias_init, gain=1):
54 |     weight_init(module.weight.data, gain=gain)
55 |     bias_init(module.bias.data)
56 |     return module
57 | 
58 | def init_flexible(module, weight_init, bias_init, kwargs={"gain":1}):
59 |     weight_init(module.weight.data, **kwargs)
60 |     bias_init(module.bias.data)
61 |     return module
62 | 
63 | 
64 | def cleanup_log_dir(log_dir):
65 |     try:
66 |         os.makedirs(log_dir)
67 |     except OSError:
68 |         files = glob.glob(os.path.join(log_dir, '*.monitor.csv'))
69 |         for f in files:
70 |             os.rename(f, f+'_old')
71 | 
--------------------------------------------------------------------------------
/helpers/a2c_ppo_acktr/distributions.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | from helpers.a2c_ppo_acktr.utils import AddBias, init
8 | 
9 | """
10 | Modify standard PyTorch distributions so they are compatible with this code.
11 | """
12 | 
13 | #
14 | # Standardize distribution interfaces
15 | #
16 | 
17 | # Categorical
18 | class FixedCategorical(torch.distributions.Categorical):
19 |     def sample(self):
20 |         return super().sample().unsqueeze(-1)
21 | 
22 |     def log_probs(self, actions):
23 |         return (
24 |             super()
25 |             .log_prob(actions.squeeze(-1))
26 |             .view(actions.size(0), -1)
27 |             .sum(-1)
28 |             .unsqueeze(-1)
29 |         )
30 | 
31 |     def mode(self):
32 |         return self.probs.argmax(dim=-1, keepdim=True)
33 | 
34 | 
35 | # Normal
36 | class FixedNormal(torch.distributions.Normal):
37 |     def log_probs(self, actions):
38 |         return super().log_prob(actions).sum(-1, keepdim=True)
39 | 
40 |     def entropy(self):
41 |         return super().entropy().sum(-1)
42 | 
43 |     def mode(self):
44 |         return self.mean
45 | 
46 | 
47 | # Bernoulli
48 | class FixedBernoulli(torch.distributions.Bernoulli):
49 |     def log_probs(self, actions):
50 |         return super().log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1)
51 | 
52 |     def entropy(self):
53 |         return super().entropy().sum(-1)
54 | 
55 |     def mode(self):
56 |         return torch.gt(self.probs, 0.5).float()
57 | 
58 | 
59 | class Categorical(nn.Module):
60 |     def __init__(self, num_inputs, num_outputs):
61 |         super(Categorical, self).__init__()
62 | 
63 |         init_ = lambda m: init(
64 |             m,
65 |             nn.init.orthogonal_,
66 |             lambda x: nn.init.constant_(x, 0),
67 |             gain=0.01)
68 | 
69 |         self.linear = init_(nn.Linear(num_inputs, num_outputs))
70 | 
71 |     def forward(self, x):
72 |         x = self.linear(x)
73 |         return FixedCategorical(logits=x)
74 | 
75 | 
76 | class DiagGaussian(nn.Module):
77 |     def __init__(self, num_inputs, num_outputs):
78 |         super(DiagGaussian, self).__init__()
79 | 
80 |         init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
81 | constant_(x, 0)) 82 | 83 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 84 | self.logstd = AddBias(torch.zeros(num_outputs)) 85 | 86 | def forward(self, x): 87 | action_mean = self.fc_mean(x) 88 | 89 | # An ugly hack for my KFAC implementation. 90 | zeros = torch.zeros(action_mean.size()) 91 | if x.is_cuda: 92 | zeros = zeros.cuda() 93 | 94 | action_logstd = self.logstd(zeros) 95 | return FixedNormal(action_mean, action_logstd.exp()) 96 | 97 | 98 | class Bernoulli(nn.Module): 99 | def __init__(self, num_inputs, num_outputs): 100 | super(Bernoulli, self).__init__() 101 | 102 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 103 | constant_(x, 0)) 104 | 105 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 106 | 107 | def forward(self, x): 108 | x = self.linear(x) 109 | return FixedBernoulli(logits=x) 110 | -------------------------------------------------------------------------------- /helpers/a2c_ppo_acktr/algo/a2c_acktr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | from helpers.a2c_ppo_acktr.algo.kfac import KFACOptimizer 6 | 7 | 8 | class A2C_ACKTR(): 9 | def __init__(self, 10 | actor_critic, 11 | value_loss_coef, 12 | entropy_coef, 13 | lr=None, 14 | lr_decay=None, 15 | lr_sched_fn=None, 16 | eps=None, 17 | alpha=None, 18 | max_grad_norm=None, 19 | acktr=False): 20 | 21 | self.actor_critic = actor_critic 22 | self.acktr = acktr 23 | 24 | self.lr_decay = lr_decay 25 | 26 | self.value_loss_coef = value_loss_coef 27 | self.entropy_coef = entropy_coef 28 | 29 | self.max_grad_norm = max_grad_norm 30 | 31 | if acktr: 32 | self.optimizer = KFACOptimizer(actor_critic) 33 | else: 34 | self.optimizer = optim.RMSprop( 35 | actor_critic.parameters(), lr, eps=eps, alpha=alpha) 36 | if lr_decay: 37 | if not lr_sched_fn: 38 | raise ValueError("Please specify learning rate multiplicative factor function") 39 | self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_sched_fn) 40 | 41 | def update(self, rollouts): 42 | obs_shape = rollouts.obs.size()[2:] 43 | action_shape = rollouts.actions.size()[-1] 44 | num_steps, num_processes, _ = rollouts.rewards.size() 45 | 46 | values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions( 47 | rollouts.obs[:-1].view(-1, *obs_shape), 48 | rollouts.recurrent_hidden_states[0].view( 49 | -1, self.actor_critic.recurrent_hidden_state_size), 50 | rollouts.masks[:-1].view(-1, 1), 51 | rollouts.actions.view(-1, action_shape)) 52 | 53 | values = values.view(num_steps, num_processes, 1) 54 | action_log_probs = action_log_probs.view(num_steps, num_processes, 1) 55 | 56 | advantages = rollouts.returns[:-1] - values 57 | value_loss = advantages.pow(2).mean() 58 | 59 | action_loss = -(advantages.detach() * action_log_probs).mean() 60 | 61 | if self.acktr and self.optimizer.steps % self.optimizer.Ts == 0: 62 | # Compute fisher, see Martens 2014 63 | self.actor_critic.zero_grad() 64 | pg_fisher_loss = -action_log_probs.mean() 65 | 66 | value_noise = torch.randn(values.size()) 67 | if values.is_cuda: 68 | value_noise = value_noise.cuda() 69 | 70 | sample_values = values + value_noise 71 | vf_fisher_loss = -(values - sample_values.detach()).pow(2).mean() 72 | 73 | fisher_loss = pg_fisher_loss + vf_fisher_loss 74 | self.optimizer.acc_stats = True 75 | fisher_loss.backward(retain_graph=True) 76 | self.optimizer.acc_stats = False 77 | 78 | self.optimizer.zero_grad() 79 | (value_loss 
* self.value_loss_coef + action_loss - 80 | dist_entropy * self.entropy_coef).backward() 81 | 82 | if self.acktr == False: 83 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 84 | self.max_grad_norm) 85 | 86 | self.optimizer.step() 87 | if self.lr_decay: #update learning rate 88 | self.lr_scheduler.step() 89 | 90 | return value_loss.item(), action_loss.item(), dist_entropy.item() 91 | -------------------------------------------------------------------------------- /helpers/a2c_ppo_acktr/algo/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | 7 | class PPO(): 8 | def __init__(self, 9 | actor_critic, 10 | clip_param, 11 | ppo_epoch, 12 | num_mini_batch, 13 | value_loss_coef, 14 | entropy_coef, 15 | lr=None, 16 | eps=None, 17 | max_grad_norm=None, 18 | use_clipped_value_loss=True): 19 | 20 | self.actor_critic = actor_critic 21 | 22 | self.clip_param = clip_param 23 | self.ppo_epoch = ppo_epoch 24 | self.num_mini_batch = num_mini_batch 25 | 26 | self.value_loss_coef = value_loss_coef 27 | self.entropy_coef = entropy_coef 28 | 29 | self.max_grad_norm = max_grad_norm 30 | self.use_clipped_value_loss = use_clipped_value_loss 31 | 32 | self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps) 33 | 34 | def update(self, rollouts): 35 | advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1] 36 | advantages = (advantages - advantages.mean()) / ( 37 | advantages.std() + 1e-5) 38 | 39 | value_loss_epoch = 0 40 | action_loss_epoch = 0 41 | dist_entropy_epoch = 0 42 | 43 | for e in range(self.ppo_epoch): 44 | if self.actor_critic.is_recurrent: 45 | data_generator = rollouts.recurrent_generator( 46 | advantages, self.num_mini_batch) 47 | else: 48 | data_generator = rollouts.feed_forward_generator( 49 | advantages, self.num_mini_batch) 50 | 51 | for sample in data_generator: 52 | obs_batch, recurrent_hidden_states_batch, actions_batch, \ 53 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, \ 54 | adv_targ = sample 55 | 56 | # Reshape to do in a single forward pass for all steps 57 | values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions( 58 | obs_batch, recurrent_hidden_states_batch, masks_batch, 59 | actions_batch) 60 | 61 | ratio = torch.exp(action_log_probs - 62 | old_action_log_probs_batch) 63 | surr1 = ratio * adv_targ 64 | surr2 = torch.clamp(ratio, 1.0 - self.clip_param, 65 | 1.0 + self.clip_param) * adv_targ 66 | action_loss = -torch.min(surr1, surr2).mean() 67 | 68 | if self.use_clipped_value_loss: 69 | value_pred_clipped = value_preds_batch + \ 70 | (values - value_preds_batch).clamp(-self.clip_param, self.clip_param) 71 | value_losses = (values - return_batch).pow(2) 72 | value_losses_clipped = ( 73 | value_pred_clipped - return_batch).pow(2) 74 | value_loss = 0.5 * torch.max(value_losses, 75 | value_losses_clipped).mean() 76 | else: 77 | value_loss = 0.5 * (return_batch - values).pow(2).mean() 78 | 79 | self.optimizer.zero_grad() 80 | (value_loss * self.value_loss_coef + action_loss - 81 | dist_entropy * self.entropy_coef).backward() 82 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 83 | self.max_grad_norm) 84 | self.optimizer.step() 85 | 86 | value_loss_epoch += value_loss.item() 87 | action_loss_epoch += action_loss.item() 88 | dist_entropy_epoch += dist_entropy.item() 89 | 90 | num_updates = self.ppo_epoch * self.num_mini_batch 91 | 92 | 
value_loss_epoch /= num_updates 93 | action_loss_epoch /= num_updates 94 | dist_entropy_epoch /= num_updates 95 | 96 | return value_loss_epoch, action_loss_epoch, dist_entropy_epoch 97 | -------------------------------------------------------------------------------- /helpers/a2c_ppo_acktr/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser(description='RL') 8 | parser.add_argument( 9 | '--algo', default='a2c', help='algorithm to use: a2c | ppo | acktr') 10 | parser.add_argument( 11 | '--gail', 12 | action='store_true', 13 | default=False, 14 | help='do imitation learning with gail') 15 | parser.add_argument( 16 | '--gail-experts-dir', 17 | default='./gail_experts', 18 | help='directory that contains expert demonstrations for gail') 19 | parser.add_argument( 20 | '--gail-batch-size', 21 | type=int, 22 | default=128, 23 | help='gail batch size (default: 128)') 24 | parser.add_argument( 25 | '--gail-epoch', type=int, default=5, help='gail epochs (default: 5)') 26 | parser.add_argument( 27 | '--lr', type=float, default=7e-4, help='learning rate (default: 7e-4)') 28 | parser.add_argument( 29 | '--eps', 30 | type=float, 31 | default=1e-5, 32 | help='RMSprop optimizer epsilon (default: 1e-5)') 33 | parser.add_argument( 34 | '--alpha', 35 | type=float, 36 | default=0.99, 37 | help='RMSprop optimizer apha (default: 0.99)') 38 | parser.add_argument( 39 | '--gamma', 40 | type=float, 41 | default=0.99, 42 | help='discount factor for rewards (default: 0.99)') 43 | parser.add_argument( 44 | '--use-gae', 45 | action='store_true', 46 | default=False, 47 | help='use generalized advantage estimation') 48 | parser.add_argument( 49 | '--gae-lambda', 50 | type=float, 51 | default=0.95, 52 | help='gae lambda parameter (default: 0.95)') 53 | parser.add_argument( 54 | '--entropy-coef', 55 | type=float, 56 | default=0.01, 57 | help='entropy term coefficient (default: 0.01)') 58 | parser.add_argument( 59 | '--value-loss-coef', 60 | type=float, 61 | default=0.5, 62 | help='value loss coefficient (default: 0.5)') 63 | parser.add_argument( 64 | '--max-grad-norm', 65 | type=float, 66 | default=0.5, 67 | help='max norm of gradients (default: 0.5)') 68 | parser.add_argument( 69 | '--seed', type=int, default=1, help='random seed (default: 1)') 70 | parser.add_argument( 71 | '--cuda-deterministic', 72 | action='store_true', 73 | default=False, 74 | help="sets flags for determinism when using CUDA (potentially slow!)") 75 | parser.add_argument( 76 | '--num-processes', 77 | type=int, 78 | default=16, 79 | help='how many training CPU processes to use (default: 16)') 80 | parser.add_argument( 81 | '--num-steps', 82 | type=int, 83 | default=5, 84 | help='number of forward steps in A2C (default: 5)') 85 | parser.add_argument( 86 | '--ppo-epoch', 87 | type=int, 88 | default=4, 89 | help='number of ppo epochs (default: 4)') 90 | parser.add_argument( 91 | '--num-mini-batch', 92 | type=int, 93 | default=32, 94 | help='number of batches for ppo (default: 32)') 95 | parser.add_argument( 96 | '--clip-param', 97 | type=float, 98 | default=0.2, 99 | help='ppo clip parameter (default: 0.2)') 100 | parser.add_argument( 101 | '--log-interval', 102 | type=int, 103 | default=10, 104 | help='log interval, one log per n updates (default: 10)') 105 | parser.add_argument( 106 | '--save-interval', 107 | type=int, 108 | default=100, 109 | help='save interval, one save per n updates (default: 100)') 
110 | parser.add_argument( 111 | '--eval-interval', 112 | type=int, 113 | default=None, 114 | help='eval interval, one eval per n updates (default: None)') 115 | parser.add_argument( 116 | '--num-env-steps', 117 | type=int, 118 | default=10e6, 119 | help='number of environment steps to train (default: 10e6)') 120 | parser.add_argument( 121 | '--env-name', 122 | default='PongNoFrameskip-v4', 123 | help='environment to train on (default: PongNoFrameskip-v4)') 124 | parser.add_argument( 125 | '--log-dir', 126 | default='/tmp/gym/', 127 | help='directory to save agent logs (default: /tmp/gym)') 128 | parser.add_argument( 129 | '--save-dir', 130 | default='./trained_models/', 131 | help='directory to save agent logs (default: ./trained_models/)') 132 | parser.add_argument( 133 | '--no-cuda', 134 | action='store_true', 135 | default=False, 136 | help='disables CUDA training') 137 | parser.add_argument( 138 | '--use-proper-time-limits', 139 | action='store_true', 140 | default=False, 141 | help='compute returns taking into account time limits') 142 | parser.add_argument( 143 | '--recurrent-policy', 144 | action='store_true', 145 | default=False, 146 | help='use a recurrent policy') 147 | parser.add_argument( 148 | '--use-linear-lr-decay', 149 | action='store_true', 150 | default=False, 151 | help='use a linear schedule on the learning rate') 152 | args = parser.parse_args() 153 | 154 | args.cuda = not args.no_cuda and torch.cuda.is_available() 155 | 156 | assert args.algo in ['a2c', 'ppo', 'acktr'] 157 | if args.recurrent_policy: 158 | assert args.algo in ['a2c', 'ppo'], \ 159 | 'Recurrent policy is not implemented for ACKTR' 160 | 161 | if args.attarch and args.attarchbaseline: 162 | raise ValueError("attarch and attarchbaseline can't both be true, choose either of them") 163 | 164 | return args 165 | -------------------------------------------------------------------------------- /helpers/a2c_ppo_acktr/algo/gail.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.utils.data 7 | from torch import autograd 8 | 9 | from baselines.common.running_mean_std import RunningMeanStd 10 | 11 | 12 | class Discriminator(nn.Module): 13 | def __init__(self, input_dim, hidden_dim, device): 14 | super(Discriminator, self).__init__() 15 | 16 | self.device = device 17 | 18 | self.trunk = nn.Sequential( 19 | nn.Linear(input_dim, hidden_dim), nn.Tanh(), 20 | nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), 21 | nn.Linear(hidden_dim, 1)).to(device) 22 | 23 | self.trunk.train() 24 | 25 | self.optimizer = torch.optim.Adam(self.trunk.parameters()) 26 | 27 | self.returns = None 28 | self.ret_rms = RunningMeanStd(shape=()) 29 | 30 | def compute_grad_pen(self, 31 | expert_state, 32 | expert_action, 33 | policy_state, 34 | policy_action, 35 | lambda_=10): 36 | alpha = torch.rand(expert_state.size(0), 1) 37 | expert_data = torch.cat([expert_state, expert_action], dim=1) 38 | policy_data = torch.cat([policy_state, policy_action], dim=1) 39 | 40 | alpha = alpha.expand_as(expert_data).to(expert_data.device) 41 | 42 | mixup_data = alpha * expert_data + (1 - alpha) * policy_data 43 | mixup_data.requires_grad = True 44 | 45 | disc = self.trunk(mixup_data) 46 | ones = torch.ones(disc.size()).to(disc.device) 47 | grad = autograd.grad( 48 | outputs=disc, 49 | inputs=mixup_data, 50 | grad_outputs=ones, 51 | create_graph=True, 52 | retain_graph=True, 53 | 
only_inputs=True)[0] 54 | 55 | grad_pen = lambda_ * (grad.norm(2, dim=1) - 1).pow(2).mean() 56 | return grad_pen 57 | 58 | def update(self, expert_loader, rollouts, obsfilt=None): 59 | self.train() 60 | 61 | policy_data_generator = rollouts.feed_forward_generator( 62 | None, mini_batch_size=expert_loader.batch_size) 63 | 64 | loss = 0 65 | n = 0 66 | for expert_batch, policy_batch in zip(expert_loader, 67 | policy_data_generator): 68 | policy_state, policy_action = policy_batch[0], policy_batch[2] 69 | policy_d = self.trunk( 70 | torch.cat([policy_state, policy_action], dim=1)) 71 | 72 | expert_state, expert_action = expert_batch 73 | expert_state = obsfilt(expert_state.numpy(), update=False) 74 | expert_state = torch.FloatTensor(expert_state).to(self.device) 75 | expert_action = expert_action.to(self.device) 76 | expert_d = self.trunk( 77 | torch.cat([expert_state, expert_action], dim=1)) 78 | 79 | expert_loss = F.binary_cross_entropy_with_logits( 80 | expert_d, 81 | torch.ones(expert_d.size()).to(self.device)) 82 | policy_loss = F.binary_cross_entropy_with_logits( 83 | policy_d, 84 | torch.zeros(policy_d.size()).to(self.device)) 85 | 86 | gail_loss = expert_loss + policy_loss 87 | grad_pen = self.compute_grad_pen(expert_state, expert_action, 88 | policy_state, policy_action) 89 | 90 | loss += (gail_loss + grad_pen).item() 91 | n += 1 92 | 93 | self.optimizer.zero_grad() 94 | (gail_loss + grad_pen).backward() 95 | self.optimizer.step() 96 | return loss / n 97 | 98 | def predict_reward(self, state, action, gamma, masks, update_rms=True): 99 | with torch.no_grad(): 100 | self.eval() 101 | d = self.trunk(torch.cat([state, action], dim=1)) 102 | s = torch.sigmoid(d) 103 | reward = s.log() - (1 - s).log() 104 | if self.returns is None: 105 | self.returns = reward.clone() 106 | 107 | if update_rms: 108 | self.returns = self.returns * masks * gamma + reward 109 | self.ret_rms.update(self.returns.cpu().numpy()) 110 | 111 | return reward / np.sqrt(self.ret_rms.var[0] + 1e-8) 112 | 113 | 114 | class ExpertDataset(torch.utils.data.Dataset): 115 | def __init__(self, file_name, num_trajectories=4, subsample_frequency=20): 116 | all_trajectories = torch.load(file_name) 117 | 118 | perm = torch.randperm(all_trajectories['states'].size(0)) 119 | idx = perm[:num_trajectories] 120 | 121 | self.trajectories = {} 122 | 123 | # See https://github.com/pytorch/pytorch/issues/14886 124 | # .long() for fixing bug in torch v0.4.1 125 | start_idx = torch.randint( 126 | 0, subsample_frequency, size=(num_trajectories, )).long() 127 | 128 | for k, v in all_trajectories.items(): 129 | data = v[idx] 130 | 131 | if k != 'lengths': 132 | samples = [] 133 | for i in range(num_trajectories): 134 | samples.append(data[i, start_idx[i]::subsample_frequency]) 135 | self.trajectories[k] = torch.stack(samples) 136 | else: 137 | self.trajectories[k] = data // subsample_frequency 138 | 139 | self.i2traj_idx = {} 140 | self.i2i = {} 141 | 142 | self.length = self.trajectories['lengths'].sum().item() 143 | 144 | traj_idx = 0 145 | i = 0 146 | 147 | self.get_idx = [] 148 | 149 | for j in range(self.length): 150 | 151 | while self.trajectories['lengths'][traj_idx].item() <= i: 152 | i -= self.trajectories['lengths'][traj_idx].item() 153 | traj_idx += 1 154 | 155 | self.get_idx.append((traj_idx, i)) 156 | 157 | i += 1 158 | 159 | 160 | def __len__(self): 161 | return self.length 162 | 163 | def __getitem__(self, i): 164 | traj_idx, i = self.get_idx[i] 165 | 166 | return self.trajectories['states'][traj_idx][i], self.trajectories[ 167 | 
'actions'][traj_idx][i] 168 | -------------------------------------------------------------------------------- /visualize_agent.py: -------------------------------------------------------------------------------- 1 | # Script to load a trained policy, play an example game and visualize attentional weights 2 | import argparse 3 | import os 4 | import yaml 5 | import torch 6 | import gym 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | from helpers.a2c_ppo_acktr import algo, utils 11 | from helpers.a2c_ppo_acktr.envs import make_vec_envs 12 | from helpers.a2c_ppo_acktr.model import Policy, DRRLBase 13 | from helpers.a2c_ppo_acktr.storage import RolloutStorage 14 | 15 | plt.ioff() #only plot when asked explicitly 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch A2C BoxWorld Agent Visualization') 18 | parser.add_argument("-s", "--savepath", type=str, required=True, help="path/to/savedirectory") 19 | parser.add_argument("-i", "--imagepath", type=str, required=True, help="path/to/save/images") 20 | args = parser.parse_args() 21 | 22 | #create target directory 23 | if not os.path.exists(args.imagepath): 24 | try: 25 | os.makedirs(args.imagepath) 26 | except OSError: 27 | print('Error: Creating images target directory. ') 28 | #load config 29 | with open(os.path.join(args.savepath, "config.yml"), 'r') as file: 30 | config = yaml.safe_load(file) 31 | net_config = config["net_config"] 32 | #load agent 33 | [start_upd, actor_critic, agent] = torch.load(os.path.join(args.savepath, "saves", "ckpt.pt")) 34 | actor_critic.eval() 35 | 36 | #create environment 37 | env = gym.make(config["env_name"], **config["env_config"]) 38 | 39 | #play environment 40 | done = False 41 | i_step = 0 42 | img = env.reset() 43 | xsize = img.shape[0] 44 | ysize = img.shape[1] 45 | while not done: 46 | obs = torch.tensor([np.moveaxis(img, -1, 0)], dtype=torch.uint8) #todo: seems like vecEnv permutes the image 47 | # correctly 48 | _, action, _, _ = actor_critic.act(obs, None, None) 49 | att_weights = actor_critic.base.get_attention_weights(obs) #att_weights is a list of lists, outer level contains 50 | # stacks of attention module, inner level heads 51 | 52 | fig = plt.figure(figsize=(18.8, 9), constrained_layout=False) 53 | # fig = plt.figure(figsize=(9, 9), constrained_layout=False) 54 | black_img = np.zeros(img.shape) 55 | 56 | # gridspec inside gridspec 57 | outer_grid = fig.add_gridspec(net_config["n_att_stack"], net_config["n_heads"], wspace=0.1, hspace=0.1) 58 | for i_head in range(net_config["n_heads"]): 59 | for i_stack in range(net_config["n_att_stack"]): 60 | #extract current weight map 61 | weightmap_curr = att_weights[i_stack][i_head].numpy().squeeze() # map for specific stack and head 62 | 63 | #PLOTTING OVERHEAD 64 | #set up outer box corresponding to stack x head 65 | ax = fig.add_subplot(outer_grid[i_stack, i_head]) 66 | ax.set_xticks([]) 67 | ax.set_yticks([]) 68 | ax.set_title(f"head{ i_head}, pass {i_stack}") 69 | #all weights for specific stack number, head go inside as a subgridspec 70 | inner_grid = outer_grid[i_stack, i_head].subgridspec(xsize, ysize, wspace=0.0, hspace=0.0) 71 | 72 | #now loop over all source (query) entities: 73 | for ent_i in range(weightmap_curr.shape[0]): #index ent_i is single dim (e.g. entity 32 in a 7x7 grid) 74 | w_max = np.max(weightmap_curr[ent_i, :]) # memorize maximum weight. 
btw: this vector sums to 1 75 | x_s, y_s = np.unravel_index(ent_i, img.shape[:-1]) #translate to 2d idx for plotting 76 | 77 | # prepare grid 78 | ax = fig.add_subplot(inner_grid[x_s, y_s]) #inner_grid also has matrix-convention x,y-arrangement, 79 | # so surprisingly we don't need to revert the indices here! 80 | # ax.annotate(f"x:{x_s},y:{y_s}", (1,3)) #for bugfixing 81 | ax.set_xticks([]) 82 | ax.set_yticks([]) 83 | 84 | if (x_s==0 and y_s!=0) or (x_s!=0 and y_s==0) or x_s==img.shape[0]-1 or y_s==img.shape[1]-1: 85 | #source entity is on black border 86 | ax.imshow(black_img, vmin=0, vmax=255, interpolation="none") 87 | else: 88 | ax.imshow(img, vmin=0, vmax=255, interpolation="none") 89 | 90 | #loop over target (key) entitites: 91 | for ent_j in range(weightmap_curr.shape[1]): 92 | weight = weightmap_curr[ent_i, ent_j] 93 | if weight > 0.5 * w_max: 94 | x_t, y_t = np.unravel_index(ent_j, img.shape[:-1]) 95 | ax.scatter(y_t, x_t, s=3, c='red', marker='o') 96 | # ax.scatter(y_s, x_s, s=5, c='blue', marker='o') 97 | 98 | 99 | ax.arrow(y_s, x_s, y_t-y_s, x_t-x_s, #DUE TO IMSHOW AXES ARE INVERTED!!! 100 | length_includes_head=True, 101 | head_width=0.2, 102 | head_length=0.3, 103 | alpha=weight*5)# / w_max) 104 | #save figure 105 | plt.savefig(os.path.join(args.imagepath, 'frame_{}.png'.format(i_step))) 106 | 107 | #next step 108 | img, _, done, _ = env.step(action.item()) 109 | obs = torch.tensor([np.moveaxis(img, -1, 0)], dtype=torch.uint8) 110 | i_step += 1 111 | 112 | # 113 | # 114 | # #next step 115 | # img, _, done, _ = env.step(action.item()) 116 | # obs = torch.tensor([np.moveaxis(img, -1, 0)], dtype=torch.uint8) 117 | # 118 | # _, action, _, _ = actor_critic.act(obs, None, None) 119 | # att_weights = actor_critic.base.get_attention_weights(obs) #att_weights is a list of lists, outer level contains 120 | # 121 | # #example 1,3: 122 | # xcoor = 5 123 | # ycoor = 2 124 | # # threshld = 5/(7*7) 125 | # flat_idx = np.ravel_multi_index((xcoor, ycoor), [7,7]) 126 | # weights_1 = att_weights[0][0][0,flat_idx,:].numpy().squeeze() 127 | # weights_2 = att_weights[0][1][0,flat_idx,:].numpy().squeeze() 128 | # w_max = np.max(weights_1) 129 | # target_idxs = [(idx, val) for (idx, val) in enumerate(weights_1) if val > w_max/3] 130 | # 131 | # plt.subplot(1,2,1) 132 | # plt.imshow(img) 133 | # for target_idx, val in target_idxs: 134 | # target_idx = np.unravel_index(target_idx, [7,7]) 135 | # plt.arrow(xcoor, ycoor, target_idx[0]-xcoor, target_idx[1]-ycoor, 136 | # length_includes_head=True, 137 | # head_width=0.2, 138 | # head_length=0.3, 139 | # alpha=val/w_max) 140 | # plt.subplot(1,2,2) 141 | # plt.imshow(weights_1.reshape([7,7]).T) 142 | # 143 | # 144 | # 145 | # 146 | # flat_idx = np.ravel_multi_index((xcoor, ycoor), [7,7]) 147 | # weights_1 = att_weights[0][0][0,flat_idx,:].reshape([7,7]).numpy() 148 | # weights_2 = att_weights[0][1][0,flat_idx,:].reshape([7,7]).numpy() 149 | # img = np.moveaxis(obs.numpy().squeeze(), 0,-1) 150 | # 151 | # 152 | # target_idx = np.unravel_index(np.argmax(weights_1), [7,7]) 153 | # plt.subplot(1,2,1) 154 | # plt.imshow(img) 155 | # 156 | # plt.arrow(xcoor, ycoor, target_idx[0]-xcoor, target_idx[1]-ycoor, 157 | # length_includes_head=True, 158 | # head_width=0.2, 159 | # head_length=0.3, 160 | # alpha=1) 161 | # plt.subplot(1,2,2) 162 | # plt.imshow(weights_1) 163 | # 164 | # 165 | # # a = obs.numpy().squeeze().T 166 | # plt.subplot(2,2,1) 167 | # plt.imshow(a) 168 | # plt.subplot(2,2,2) 169 | # plt.imshow(weights_1) 170 | # plt.subplot(2,2,4) 171 | # 
plt.imshow(weights_2) 172 | 173 | 174 | 175 | 176 | 177 | # fig2 = plt.imshow(img, vmin=0, vmax=255, interpolation='none') 178 | # fig2.axes.get_xaxis().set_visible(False) 179 | # fig2.axes.get_yaxis().set_visible(False) 180 | # plt.savefig(os.path.join('images', 'observation_{}_{}.png'.format(i_episode, t))) 181 | # 182 | # #next step 183 | # img, _, done, _ = env.step(action.item()) 184 | # obs = torch.tensor([img.T], dtype=torch.uint8) 185 | # 186 | # i_episode += 1 187 | 188 | 189 | -------------------------------------------------------------------------------- /a2c_fast.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | import argparse 5 | import yaml 6 | import time 7 | import csv 8 | from collections import deque 9 | 10 | from helpers.a2c_ppo_acktr import algo, utils 11 | from helpers.a2c_ppo_acktr.envs import make_vec_envs 12 | from helpers.a2c_ppo_acktr.model import Policy, DRRLBase 13 | from helpers.a2c_ppo_acktr.storage import RolloutStorage 14 | from helpers.lr_scheduling import Linear_decay 15 | # from baselines.common import plot_util 16 | 17 | # if "e_schedule" in config.keys(): #todo: implement entropy weight scheduling 18 | # e_schedule = config["e_schedule"] 19 | # else: 20 | # e_schedule = False 21 | 22 | def main(): 23 | # parse yaml config file from cmdline 24 | parser = argparse.ArgumentParser(description='PyTorch A2C BoxWorld Experiment') 25 | parser.add_argument("-c", "--configpath", type=str, required=True, help="path/to/configfile.yml") 26 | parser.add_argument("-s", "--savepath", type=str, required=True, help="path/to/savedirectory") 27 | args = parser.parse_args() 28 | with open(os.path.abspath(args.configpath), 'r') as file: 29 | config = yaml.safe_load(file) 30 | 31 | SAVE_EVERY = 1000 32 | LOG_EVERY = 100 33 | 34 | # todo: implement adam optimizer? 
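    # A hedged sketch (not part of the original file) for the todo above: the "optimizer" key
    # documented in configs/full_config.yml (RMSprop or Adam) does not appear to be read in this
    # script; one minimal way to honor it would be to override the optimizer that A2C_ACKTR
    # builds internally, e.g.
    #
    #     import torch.optim as optim
    #     if config.get("optimizer", "RMSprop") == "Adam":
    #         agent.optimizer = optim.Adam(actor_critic.parameters(), lr=config["lr"], eps=1e-5)
    #
    # (actor_critic and agent are constructed further below; if lr_decay is enabled, the LambdaLR
    # scheduler created in A2C_ACKTR.__init__ would also have to be re-bound to the new optimizer.)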
35 | 36 | #set up torch 37 | torch.manual_seed(config["seed"]) 38 | torch.cuda.manual_seed_all(config["seed"]) 39 | 40 | torch.set_num_threads(config["n_cpus"]) #intra-op parallelism 41 | device = torch.device("cuda:0" if config["cuda"] and torch.cuda.is_available() else "cpu") 42 | 43 | #set up logging 44 | log_dir = os.path.expanduser(os.path.join(args.savepath, "logs")) 45 | utils.cleanup_log_dir(log_dir) 46 | save_path = os.path.join(args.savepath, "saves") 47 | 48 | save_stats_path = os.path.join(log_dir, "training_losses.csv") 49 | 50 | #make environments in vectorizer wrapper sharing memory 51 | envs = make_vec_envs((config["env_name"], config["env_config"]), config["seed"], config["n_cpus"], 52 | config["gamma"], log_dir, device, False, num_frame_stack=1) # default is stacking 4 frames 53 | 54 | #load from startpoint 55 | modelckptpath = os.path.join(save_path, "ckpt.pt") 56 | if os.path.isfile(modelckptpath): 57 | #check whether configs are identical 58 | # load config and check whether identical 59 | with open(os.path.join(args.savepath, "config.yml"), 'r') as file: 60 | config_old = yaml.safe_load(file) 61 | if config_old != config: 62 | raise Exception("Existing config different from current config") 63 | #load iteration, algo and agent 64 | [start_upd, actor_critic, agent] = torch.load(modelckptpath) 65 | print(f"loaded from savepoint {start_upd} in folder {modelckptpath}") 66 | 67 | #or start fresh 68 | else: 69 | #start new entropy logging file 70 | with open(save_stats_path, "w") as f: 71 | f.write("value loss,action loss,entropy\n") 72 | 73 | # write config to new directory 74 | with open(os.path.join(args.savepath, "config.yml"), "w+") as f: 75 | f.write(yaml.dump(config)) 76 | 77 | # start new training process 78 | stats = [] 79 | i_start = 0 80 | print("starting new training process") 81 | start_upd = 0 82 | base_kwargs = config["net_config"] 83 | base_kwargs["w"] = base_kwargs["h"] = config["env_config"]["n"]+2 #+2 is for black edge 84 | 85 | actor_critic = Policy( 86 | envs.observation_space.shape, 87 | envs.action_space, 88 | base=DRRLBase, 89 | base_kwargs=base_kwargs) 90 | actor_critic.to(device) 91 | 92 | #set up linear learning rate decay 93 | if config["lr_decay"]: 94 | ep_max = 3e8 / (config["n_cpus"] * config["update_every_n_steps"]) 95 | # ep_max = config["n_env_steps"] / (config["n_cpus"] * config["update_every_n_steps"]) 96 | if config["lr_term"]: 97 | lr_term = config["lr_term"] 98 | else: 99 | lr_term = 1e-5 100 | lr_sched_fn = Linear_decay(lr_init=config["lr"], lr_term=lr_term, ep_max=ep_max) 101 | else: 102 | lr_sched_fn = None 103 | 104 | agent = algo.A2C_ACKTR( 105 | actor_critic, 106 | value_loss_coef=0.5, 107 | entropy_coef=0.01, 108 | lr=config["lr"], 109 | lr_decay=config["lr_decay"], 110 | lr_sched_fn=lr_sched_fn, 111 | eps=1e-5, 112 | alpha=0.99, #RMSProp optimizer alpha 113 | max_grad_norm=0.5) #max norm of grads 114 | 115 | rollouts = RolloutStorage(config["update_every_n_steps"], config["n_cpus"], 116 | envs.observation_space.shape, envs.action_space, 117 | actor_critic.recurrent_hidden_state_size) 118 | 119 | obs = envs.reset() 120 | rollouts.obs[0].copy_(obs) 121 | rollouts.to(device) 122 | 123 | episode_rewards = deque(maxlen=10) 124 | loss_stats = [] 125 | 126 | start = time.time() 127 | num_updates = int(config["n_env_steps"]) // config["update_every_n_steps"] // config["n_cpus"] 128 | for j in range(start_upd, num_updates): #main training loop: global iteration counted in weight updates according 129 | # to num_steps number of 
environment steps used for each update 130 | 131 | for step in range(config["update_every_n_steps"]): #a batch update of num_steps for each num_process will be created in this 132 | # loop 133 | with torch.no_grad(): # Sample actions 134 | value, action, action_log_prob, recurrent_hidden_states = actor_critic.act( 135 | rollouts.obs[step], rollouts.recurrent_hidden_states[step], 136 | rollouts.masks[step]) 137 | 138 | # Obser reward and next obs 139 | obs, reward, done, infos = envs.step(action) 140 | 141 | for info in infos: 142 | if 'episode' in info.keys(): 143 | episode_rewards.append(info['episode']['r']) #todo: is this needed? 144 | 145 | # If done then clean the history of observations. 146 | masks = torch.FloatTensor([[0.0] if done_ else [1.0] for done_ in done]) 147 | bad_masks = torch.FloatTensor([[0.0] if 'bad_transition' in info.keys() else [1.0] for info in infos]) 148 | rollouts.insert(obs, recurrent_hidden_states, action, 149 | action_log_prob, value, reward, masks, bad_masks) 150 | 151 | with torch.no_grad(): #get value at end of look-ahead 152 | next_value = actor_critic.get_value( 153 | rollouts.obs[-1], rollouts.recurrent_hidden_states[-1], 154 | rollouts.masks[-1]).detach() 155 | 156 | rollouts.compute_returns(next_value, use_gae=False, gamma=config["gamma"], gae_lambda=None) 157 | 158 | value_loss, action_loss, dist_entropy = agent.update(rollouts) 159 | loss_stats.append([value_loss, action_loss, dist_entropy]) 160 | 161 | rollouts.after_update() 162 | 163 | # save for every interval-th episode or for the last epoch 164 | if (j % SAVE_EVERY == 0 165 | or j == num_updates - 1) and save_path != "": 166 | try: 167 | os.makedirs(save_path) 168 | except OSError: 169 | pass 170 | 171 | torch.save([j, actor_critic, agent], os.path.join(modelckptpath)) 172 | 173 | #write and clear loss_stats 174 | with open(save_stats_path, "a", newline="") as f: 175 | writer = csv.writer(f) 176 | writer.writerows(loss_stats) 177 | loss_stats = [] 178 | 179 | if j % LOG_EVERY == 0 and len(episode_rewards) > 1: 180 | total_num_steps = (j + 1) * config["n_cpus"] * config["update_every_n_steps"] 181 | end = time.time() 182 | print( 183 | "Updates {}, num timesteps {}, FPS {} \n Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}\n" 184 | .format(j, total_num_steps, 185 | int(total_num_steps / (end - start)), 186 | len(episode_rewards), np.mean(episode_rewards), 187 | np.median(episode_rewards), np.min(episode_rewards), 188 | np.max(episode_rewards))) 189 | 190 | if __name__ == "__main__": 191 | main() 192 | -------------------------------------------------------------------------------- /helpers/a2c_ppo_acktr/algo/kfac.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | from helpers.a2c_ppo_acktr.utils import AddBias 9 | 10 | # TODO: In order to make this code faster: 11 | # 1) Implement _extract_patches as a single cuda kernel 12 | # 2) Compute QR decomposition in a separate process 13 | # 3) Actually make a general KFAC optimizer so it fits PyTorch 14 | 15 | 16 | def _extract_patches(x, kernel_size, stride, padding): 17 | if padding[0] + padding[1] > 0: 18 | x = F.pad(x, (padding[1], padding[1], padding[0], 19 | padding[0])).data # Actually check dims 20 | x = x.unfold(2, kernel_size[0], stride[0]) 21 | x = x.unfold(3, kernel_size[1], stride[1]) 22 | x = x.transpose_(1, 2).transpose_(2, 
3).contiguous() 23 | x = x.view( 24 | x.size(0), x.size(1), x.size(2), 25 | x.size(3) * x.size(4) * x.size(5)) 26 | return x 27 | 28 | 29 | def compute_cov_a(a, classname, layer_info, fast_cnn): 30 | batch_size = a.size(0) 31 | 32 | if classname == 'Conv2d': 33 | if fast_cnn: 34 | a = _extract_patches(a, *layer_info) 35 | a = a.view(a.size(0), -1, a.size(-1)) 36 | a = a.mean(1) 37 | else: 38 | a = _extract_patches(a, *layer_info) 39 | a = a.view(-1, a.size(-1)).div_(a.size(1)).div_(a.size(2)) 40 | elif classname == 'AddBias': 41 | is_cuda = a.is_cuda 42 | a = torch.ones(a.size(0), 1) 43 | if is_cuda: 44 | a = a.cuda() 45 | 46 | return a.t() @ (a / batch_size) 47 | 48 | 49 | def compute_cov_g(g, classname, layer_info, fast_cnn): 50 | batch_size = g.size(0) 51 | 52 | if classname == 'Conv2d': 53 | if fast_cnn: 54 | g = g.view(g.size(0), g.size(1), -1) 55 | g = g.sum(-1) 56 | else: 57 | g = g.transpose(1, 2).transpose(2, 3).contiguous() 58 | g = g.view(-1, g.size(-1)).mul_(g.size(1)).mul_(g.size(2)) 59 | elif classname == 'AddBias': 60 | g = g.view(g.size(0), g.size(1), -1) 61 | g = g.sum(-1) 62 | 63 | g_ = g * batch_size 64 | return g_.t() @ (g_ / g.size(0)) 65 | 66 | 67 | def update_running_stat(aa, m_aa, momentum): 68 | # Do the trick to keep aa unchanged and not create any additional tensors 69 | m_aa *= momentum / (1 - momentum) 70 | m_aa += aa 71 | m_aa *= (1 - momentum) 72 | 73 | 74 | class SplitBias(nn.Module): 75 | def __init__(self, module): 76 | super(SplitBias, self).__init__() 77 | self.module = module 78 | self.add_bias = AddBias(module.bias.data) 79 | self.module.bias = None 80 | 81 | def forward(self, input): 82 | x = self.module(input) 83 | x = self.add_bias(x) 84 | return x 85 | 86 | 87 | class KFACOptimizer(optim.Optimizer): 88 | def __init__(self, 89 | model, 90 | lr=0.25, 91 | momentum=0.9, 92 | stat_decay=0.99, 93 | kl_clip=0.001, 94 | damping=1e-2, 95 | weight_decay=0, 96 | fast_cnn=False, 97 | Ts=1, 98 | Tf=10): 99 | defaults = dict() 100 | 101 | def split_bias(module): 102 | for mname, child in module.named_children(): 103 | if hasattr(child, 'bias') and child.bias is not None: 104 | module._modules[mname] = SplitBias(child) 105 | else: 106 | split_bias(child) 107 | 108 | split_bias(model) 109 | 110 | super(KFACOptimizer, self).__init__(model.parameters(), defaults) 111 | 112 | self.known_modules = {'Linear', 'Conv2d', 'AddBias'} 113 | 114 | self.modules = [] 115 | self.grad_outputs = {} 116 | 117 | self.model = model 118 | self._prepare_model() 119 | 120 | self.steps = 0 121 | 122 | self.m_aa, self.m_gg = {}, {} 123 | self.Q_a, self.Q_g = {}, {} 124 | self.d_a, self.d_g = {}, {} 125 | 126 | self.momentum = momentum 127 | self.stat_decay = stat_decay 128 | 129 | self.lr = lr 130 | self.kl_clip = kl_clip 131 | self.damping = damping 132 | self.weight_decay = weight_decay 133 | 134 | self.fast_cnn = fast_cnn 135 | 136 | self.Ts = Ts 137 | self.Tf = Tf 138 | 139 | self.optim = optim.SGD( 140 | model.parameters(), 141 | lr=self.lr * (1 - self.momentum), 142 | momentum=self.momentum) 143 | 144 | def _save_input(self, module, input): 145 | if torch.is_grad_enabled() and self.steps % self.Ts == 0: 146 | classname = module.__class__.__name__ 147 | layer_info = None 148 | if classname == 'Conv2d': 149 | layer_info = (module.kernel_size, module.stride, 150 | module.padding) 151 | 152 | aa = compute_cov_a(input[0].data, classname, layer_info, 153 | self.fast_cnn) 154 | 155 | # Initialize buffers 156 | if self.steps == 0: 157 | self.m_aa[module] = aa.clone() 158 | 159 | 
update_running_stat(aa, self.m_aa[module], self.stat_decay) 160 | 161 | def _save_grad_output(self, module, grad_input, grad_output): 162 | # Accumulate statistics for Fisher matrices 163 | if self.acc_stats: 164 | classname = module.__class__.__name__ 165 | layer_info = None 166 | if classname == 'Conv2d': 167 | layer_info = (module.kernel_size, module.stride, 168 | module.padding) 169 | 170 | gg = compute_cov_g(grad_output[0].data, classname, layer_info, 171 | self.fast_cnn) 172 | 173 | # Initialize buffers 174 | if self.steps == 0: 175 | self.m_gg[module] = gg.clone() 176 | 177 | update_running_stat(gg, self.m_gg[module], self.stat_decay) 178 | 179 | def _prepare_model(self): 180 | for module in self.model.modules(): 181 | classname = module.__class__.__name__ 182 | if classname in self.known_modules: 183 | assert not ((classname in ['Linear', 'Conv2d']) and module.bias is not None), \ 184 | "You must have a bias as a separate layer" 185 | 186 | self.modules.append(module) 187 | module.register_forward_pre_hook(self._save_input) 188 | module.register_backward_hook(self._save_grad_output) 189 | 190 | def step(self): 191 | # Add weight decay 192 | if self.weight_decay > 0: 193 | for p in self.model.parameters(): 194 | p.grad.data.add_(self.weight_decay, p.data) 195 | 196 | updates = {} 197 | for i, m in enumerate(self.modules): 198 | assert len(list(m.parameters()) 199 | ) == 1, "Can handle only one parameter at the moment" 200 | classname = m.__class__.__name__ 201 | p = next(m.parameters()) 202 | 203 | la = self.damping + self.weight_decay 204 | 205 | if self.steps % self.Tf == 0: 206 | # My asynchronous implementation exists, I will add it later. 207 | # Experimenting with different ways to this in PyTorch. 208 | self.d_a[m], self.Q_a[m] = torch.symeig( 209 | self.m_aa[m], eigenvectors=True) 210 | self.d_g[m], self.Q_g[m] = torch.symeig( 211 | self.m_gg[m], eigenvectors=True) 212 | 213 | self.d_a[m].mul_((self.d_a[m] > 1e-6).float()) 214 | self.d_g[m].mul_((self.d_g[m] > 1e-6).float()) 215 | 216 | if classname == 'Conv2d': 217 | p_grad_mat = p.grad.data.view(p.grad.data.size(0), -1) 218 | else: 219 | p_grad_mat = p.grad.data 220 | 221 | v1 = self.Q_g[m].t() @ p_grad_mat @ self.Q_a[m] 222 | v2 = v1 / ( 223 | self.d_g[m].unsqueeze(1) * self.d_a[m].unsqueeze(0) + la) 224 | v = self.Q_g[m] @ v2 @ self.Q_a[m].t() 225 | 226 | v = v.view(p.grad.data.size()) 227 | updates[p] = v 228 | 229 | vg_sum = 0 230 | for p in self.model.parameters(): 231 | v = updates[p] 232 | vg_sum += (v * p.grad.data * self.lr * self.lr).sum() 233 | 234 | nu = min(1, math.sqrt(self.kl_clip / vg_sum)) 235 | 236 | for p in self.model.parameters(): 237 | v = updates[p] 238 | p.grad.data.copy_(v) 239 | p.grad.data.mul_(nu) 240 | 241 | self.optim.step() 242 | self.steps += 1 243 | -------------------------------------------------------------------------------- /helpers/a2c_ppo_acktr/envs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | from gym.spaces.box import Box 7 | 8 | from baselines import bench 9 | from baselines.common.atari_wrappers import make_atari, wrap_deepmind 10 | from baselines.common.vec_env import VecEnvWrapper 11 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 12 | from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv 13 | from baselines.common.vec_env.vec_normalize import \ 14 | VecNormalize as VecNormalize_ 15 | 16 | try: 17 | import dm_control2gym 18 | except 
ImportError: 19 | pass 20 | 21 | try: 22 | import roboschool 23 | except ImportError: 24 | pass 25 | 26 | try: 27 | import pybullet_envs 28 | except ImportError: 29 | pass 30 | 31 | 32 | def make_env(env_id, seed, rank, log_dir, allow_early_resets): 33 | """Creates a vectorizable thunk wrapper of specified environment. 34 | 35 | Args: 36 | env_id: either the name of a registered gym environment (gym.make(env_id)) or tuple containing [0] name of 37 | registered environment and [1] kwargs, since gym since recently features parameterizable .make()s. 38 | 39 | """ 40 | def _thunk(): 41 | if type(env_id) == tuple: 42 | env = gym.make(env_id[0], **env_id[1]) 43 | env_name_id = env_id[0] 44 | else: 45 | env = gym.make(env_id) 46 | env_name_id = env_id 47 | 48 | is_atari = hasattr(gym.envs, 'atari') and isinstance( 49 | env.unwrapped, gym.envs.atari.atari_env.AtariEnv) 50 | if is_atari: 51 | env = make_atari(env_id) 52 | 53 | env.seed(seed + rank) 54 | 55 | if str(env.__class__.__name__).find('TimeLimit') >= 0: 56 | env = TimeLimitMask(env) 57 | 58 | if log_dir is not None: 59 | env = bench.Monitor( 60 | env, 61 | os.path.join(log_dir, str(rank)), 62 | allow_early_resets=allow_early_resets) 63 | 64 | if is_atari: 65 | if len(env.observation_space.shape) == 3: 66 | env = wrap_deepmind(env) 67 | elif "boxworld" in env_name_id: 68 | pass 69 | elif len(env.observation_space.shape) == 3: 70 | raise NotImplementedError( 71 | "CNN models work only for atari,\n" 72 | "please use a custom wrapper for a custom pixel input env.\n" 73 | "See wrap_deepmind for an example.") 74 | 75 | # If the input has shape (W,H,3), wrap for PyTorch convolutions 76 | obs_shape = env.observation_space.shape 77 | if len(obs_shape) == 3 and obs_shape[2] in [1, 3]: 78 | env = TransposeImage(env, op=[2, 0, 1]) 79 | 80 | return env 81 | 82 | return _thunk 83 | 84 | 85 | def make_vec_envs(env_name, 86 | seed, 87 | num_processes, 88 | gamma, 89 | log_dir, 90 | device, 91 | allow_early_resets, 92 | num_frame_stack=None): 93 | envs = [ 94 | make_env(env_name, seed, i, log_dir, allow_early_resets) 95 | for i in range(num_processes) 96 | ] 97 | 98 | if len(envs) > 1: 99 | envs = ShmemVecEnv(envs, context='fork') 100 | else: 101 | envs = DummyVecEnv(envs) 102 | 103 | if len(envs.observation_space.shape) == 1: 104 | if gamma is None: 105 | envs = VecNormalize(envs, ret=False) 106 | else: 107 | envs = VecNormalize(envs, gamma=gamma) 108 | envs = VecPyTorch(envs, device) 109 | 110 | if num_frame_stack is not None: 111 | envs = VecPyTorchFrameStack(envs, num_frame_stack, device) 112 | elif len(envs.observation_space.shape) == 3: 113 | envs = VecPyTorchFrameStack(envs, 4, device) 114 | return envs 115 | 116 | 117 | # Checks whether done was caused my timit limits or not 118 | class TimeLimitMask(gym.Wrapper): 119 | def step(self, action): 120 | obs, rew, done, info = self.env.step(action) 121 | if done and self.env._max_episode_steps == self.env._elapsed_steps: 122 | info['bad_transition'] = True 123 | 124 | return obs, rew, done, info 125 | 126 | def reset(self, **kwargs): 127 | return self.env.reset(**kwargs) 128 | 129 | 130 | # Can be used to test recurrent policies for Reacher-v2 131 | class MaskGoal(gym.ObservationWrapper): 132 | def observation(self, observation): 133 | if self.env._elapsed_steps > 0: 134 | observation[-2:] = 0 135 | return observation 136 | 137 | 138 | class TransposeObs(gym.ObservationWrapper): 139 | def __init__(self, env=None): 140 | """ 141 | Transpose observation space (base class) 142 | """ 143 | 
super(TransposeObs, self).__init__(env) 144 | 145 | 146 | class TransposeImage(TransposeObs): 147 | def __init__(self, env=None, op=[2, 0, 1]): 148 | """ 149 | Transpose observation space for images 150 | """ 151 | super(TransposeImage, self).__init__(env) 152 | assert len(op) == 3, "Error: Operation, " + str(op) + ", must be dim3" 153 | self.op = op 154 | obs_shape = self.observation_space.shape 155 | self.observation_space = Box( 156 | self.observation_space.low[0, 0, 0], 157 | self.observation_space.high[0, 0, 0], [ 158 | obs_shape[self.op[0]], obs_shape[self.op[1]], 159 | obs_shape[self.op[2]] 160 | ], 161 | dtype=self.observation_space.dtype) 162 | 163 | def observation(self, ob): 164 | return ob.transpose(self.op[0], self.op[1], self.op[2]) 165 | 166 | 167 | class VecPyTorch(VecEnvWrapper): 168 | def __init__(self, venv, device): 169 | """Return only every `skip`-th frame""" 170 | super(VecPyTorch, self).__init__(venv) 171 | self.device = device 172 | # TODO: Fix data types 173 | 174 | def reset(self): 175 | obs = self.venv.reset() 176 | obs = torch.from_numpy(obs).float().to(self.device) 177 | return obs 178 | 179 | def step_async(self, actions): 180 | if isinstance(actions, torch.LongTensor): 181 | # Squeeze the dimension for discrete actions 182 | actions = actions.squeeze(1) 183 | actions = actions.cpu().numpy() 184 | self.venv.step_async(actions) 185 | 186 | def step_wait(self): 187 | obs, reward, done, info = self.venv.step_wait() 188 | obs = torch.from_numpy(obs).float().to(self.device) 189 | reward = torch.from_numpy(reward).unsqueeze(dim=1).float() 190 | return obs, reward, done, info 191 | 192 | 193 | class VecNormalize(VecNormalize_): 194 | def __init__(self, *args, **kwargs): 195 | super(VecNormalize, self).__init__(*args, **kwargs) 196 | self.training = True 197 | 198 | def _obfilt(self, obs, update=True): 199 | if self.ob_rms: 200 | if self.training and update: 201 | self.ob_rms.update(obs) 202 | obs = np.clip((obs - self.ob_rms.mean) / 203 | np.sqrt(self.ob_rms.var + self.epsilon), 204 | -self.clipob, self.clipob) 205 | return obs 206 | else: 207 | return obs 208 | 209 | def train(self): 210 | self.training = True 211 | 212 | def eval(self): 213 | self.training = False 214 | 215 | 216 | # Derived from 217 | # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_frame_stack.py 218 | class VecPyTorchFrameStack(VecEnvWrapper): 219 | def __init__(self, venv, nstack, device=None): 220 | self.venv = venv 221 | self.nstack = nstack 222 | 223 | wos = venv.observation_space # wrapped ob space 224 | self.shape_dim0 = wos.shape[0] 225 | 226 | low = np.repeat(wos.low, self.nstack, axis=0) 227 | high = np.repeat(wos.high, self.nstack, axis=0) 228 | 229 | if device is None: 230 | device = torch.device('cpu') 231 | self.stacked_obs = torch.zeros((venv.num_envs, ) + 232 | low.shape).to(device) 233 | 234 | observation_space = gym.spaces.Box( 235 | low=low, high=high, dtype=venv.observation_space.dtype) 236 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 237 | 238 | def step_wait(self): 239 | obs, rews, news, infos = self.venv.step_wait() 240 | self.stacked_obs[:, :-self.shape_dim0] = \ 241 | self.stacked_obs[:, self.shape_dim0:].clone() 242 | for (i, new) in enumerate(news): 243 | if new: 244 | self.stacked_obs[i] = 0 245 | self.stacked_obs[:, -self.shape_dim0:] = obs 246 | return self.stacked_obs, rews, news, infos 247 | 248 | def reset(self): 249 | obs = self.venv.reset() 250 | if torch.backends.cudnn.deterministic: 251 | 
self.stacked_obs = torch.zeros(self.stacked_obs.shape) 252 | else: 253 | self.stacked_obs.zero_() 254 | self.stacked_obs[:, -self.shape_dim0:] = obs 255 | return self.stacked_obs 256 | 257 | def close(self): 258 | self.venv.close() 259 | -------------------------------------------------------------------------------- /helpers/a2c_ppo_acktr/storage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 3 | 4 | 5 | def _flatten_helper(T, N, _tensor): 6 | return _tensor.view(T * N, *_tensor.size()[2:]) 7 | 8 | 9 | class RolloutStorage(object): 10 | def __init__(self, num_steps, num_processes, obs_shape, action_space, 11 | recurrent_hidden_state_size): 12 | self.obs = torch.zeros(num_steps + 1, num_processes, *obs_shape) 13 | self.recurrent_hidden_states = torch.zeros( 14 | num_steps + 1, num_processes, recurrent_hidden_state_size) 15 | self.rewards = torch.zeros(num_steps, num_processes, 1) 16 | self.value_preds = torch.zeros(num_steps + 1, num_processes, 1) 17 | self.returns = torch.zeros(num_steps + 1, num_processes, 1) 18 | self.action_log_probs = torch.zeros(num_steps, num_processes, 1) 19 | if action_space.__class__.__name__ == 'Discrete': 20 | action_shape = 1 21 | else: 22 | action_shape = action_space.shape[0] 23 | self.actions = torch.zeros(num_steps, num_processes, action_shape) 24 | if action_space.__class__.__name__ == 'Discrete': 25 | self.actions = self.actions.long() 26 | self.masks = torch.ones(num_steps + 1, num_processes, 1) 27 | 28 | # Masks that indicate whether it's a true terminal state 29 | # or time limit end state 30 | self.bad_masks = torch.ones(num_steps + 1, num_processes, 1) 31 | 32 | self.num_steps = num_steps 33 | self.step = 0 34 | 35 | def to(self, device): 36 | self.obs = self.obs.to(device) 37 | self.recurrent_hidden_states = self.recurrent_hidden_states.to(device) 38 | self.rewards = self.rewards.to(device) 39 | self.value_preds = self.value_preds.to(device) 40 | self.returns = self.returns.to(device) 41 | self.action_log_probs = self.action_log_probs.to(device) 42 | self.actions = self.actions.to(device) 43 | self.masks = self.masks.to(device) 44 | self.bad_masks = self.bad_masks.to(device) 45 | 46 | def insert(self, obs, recurrent_hidden_states, actions, action_log_probs, 47 | value_preds, rewards, masks, bad_masks): 48 | self.obs[self.step + 1].copy_(obs) 49 | self.recurrent_hidden_states[self.step + 50 | 1].copy_(recurrent_hidden_states) 51 | self.actions[self.step].copy_(actions) 52 | self.action_log_probs[self.step].copy_(action_log_probs) 53 | self.value_preds[self.step].copy_(value_preds) 54 | self.rewards[self.step].copy_(rewards) 55 | self.masks[self.step + 1].copy_(masks) 56 | self.bad_masks[self.step + 1].copy_(bad_masks) 57 | 58 | self.step = (self.step + 1) % self.num_steps 59 | 60 | def after_update(self): 61 | self.obs[0].copy_(self.obs[-1]) 62 | self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) 63 | self.masks[0].copy_(self.masks[-1]) 64 | self.bad_masks[0].copy_(self.bad_masks[-1]) 65 | 66 | def compute_returns(self, 67 | next_value, 68 | use_gae, 69 | gamma, 70 | gae_lambda, 71 | use_proper_time_limits=True): 72 | if use_proper_time_limits: 73 | if use_gae: 74 | self.value_preds[-1] = next_value 75 | gae = 0 76 | for step in reversed(range(self.rewards.size(0))): 77 | delta = self.rewards[step] + gamma * self.value_preds[ 78 | step + 1] * self.masks[step + 79 | 1] - self.value_preds[step] 80 | gae 
= delta + gamma * gae_lambda * self.masks[step + 81 | 1] * gae 82 | gae = gae * self.bad_masks[step + 1] 83 | self.returns[step] = gae + self.value_preds[step] 84 | else: 85 | self.returns[-1] = next_value 86 | for step in reversed(range(self.rewards.size(0))): 87 | self.returns[step] = (self.returns[step + 1] * \ 88 | gamma * self.masks[step + 1] + self.rewards[step]) * self.bad_masks[step + 1] \ 89 | + (1 - self.bad_masks[step + 1]) * self.value_preds[step] 90 | else: 91 | if use_gae: 92 | self.value_preds[-1] = next_value 93 | gae = 0 94 | for step in reversed(range(self.rewards.size(0))): 95 | delta = self.rewards[step] + gamma * self.value_preds[ 96 | step + 1] * self.masks[step + 97 | 1] - self.value_preds[step] 98 | gae = delta + gamma * gae_lambda * self.masks[step + 99 | 1] * gae 100 | self.returns[step] = gae + self.value_preds[step] 101 | else: 102 | self.returns[-1] = next_value 103 | for step in reversed(range(self.rewards.size(0))): 104 | self.returns[step] = self.returns[step + 1] * \ 105 | gamma * self.masks[step + 1] + self.rewards[step] 106 | 107 | def feed_forward_generator(self, 108 | advantages, 109 | num_mini_batch=None, 110 | mini_batch_size=None): 111 | num_steps, num_processes = self.rewards.size()[0:2] 112 | batch_size = num_processes * num_steps 113 | 114 | if mini_batch_size is None: 115 | assert batch_size >= num_mini_batch, ( 116 | "PPO requires the number of processes ({}) " 117 | "* number of steps ({}) = {} " 118 | "to be greater than or equal to the number of PPO mini batches ({})." 119 | "".format(num_processes, num_steps, num_processes * num_steps, 120 | num_mini_batch)) 121 | mini_batch_size = batch_size // num_mini_batch 122 | sampler = BatchSampler( 123 | SubsetRandomSampler(range(batch_size)), 124 | mini_batch_size, 125 | drop_last=True) 126 | for indices in sampler: 127 | obs_batch = self.obs[:-1].view(-1, *self.obs.size()[2:])[indices] 128 | recurrent_hidden_states_batch = self.recurrent_hidden_states[:-1].view( 129 | -1, self.recurrent_hidden_states.size(-1))[indices] 130 | actions_batch = self.actions.view(-1, 131 | self.actions.size(-1))[indices] 132 | value_preds_batch = self.value_preds[:-1].view(-1, 1)[indices] 133 | return_batch = self.returns[:-1].view(-1, 1)[indices] 134 | masks_batch = self.masks[:-1].view(-1, 1)[indices] 135 | old_action_log_probs_batch = self.action_log_probs.view(-1, 136 | 1)[indices] 137 | if advantages is None: 138 | adv_targ = None 139 | else: 140 | adv_targ = advantages.view(-1, 1)[indices] 141 | 142 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \ 143 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 144 | 145 | def recurrent_generator(self, advantages, num_mini_batch): 146 | num_processes = self.rewards.size(1) 147 | assert num_processes >= num_mini_batch, ( 148 | "PPO requires the number of processes ({}) " 149 | "to be greater than or equal to the number of " 150 | "PPO mini batches ({}).".format(num_processes, num_mini_batch)) 151 | num_envs_per_batch = num_processes // num_mini_batch 152 | perm = torch.randperm(num_processes) 153 | for start_ind in range(0, num_processes, num_envs_per_batch): 154 | obs_batch = [] 155 | recurrent_hidden_states_batch = [] 156 | actions_batch = [] 157 | value_preds_batch = [] 158 | return_batch = [] 159 | masks_batch = [] 160 | old_action_log_probs_batch = [] 161 | adv_targ = [] 162 | 163 | for offset in range(num_envs_per_batch): 164 | ind = perm[start_ind + offset] 165 | obs_batch.append(self.obs[:-1, ind]) 166 | 
recurrent_hidden_states_batch.append( 167 | self.recurrent_hidden_states[0:1, ind]) 168 | actions_batch.append(self.actions[:, ind]) 169 | value_preds_batch.append(self.value_preds[:-1, ind]) 170 | return_batch.append(self.returns[:-1, ind]) 171 | masks_batch.append(self.masks[:-1, ind]) 172 | old_action_log_probs_batch.append( 173 | self.action_log_probs[:, ind]) 174 | adv_targ.append(advantages[:, ind]) 175 | 176 | T, N = self.num_steps, num_envs_per_batch 177 | # These are all tensors of size (T, N, -1) 178 | obs_batch = torch.stack(obs_batch, 1) 179 | actions_batch = torch.stack(actions_batch, 1) 180 | value_preds_batch = torch.stack(value_preds_batch, 1) 181 | return_batch = torch.stack(return_batch, 1) 182 | masks_batch = torch.stack(masks_batch, 1) 183 | old_action_log_probs_batch = torch.stack( 184 | old_action_log_probs_batch, 1) 185 | adv_targ = torch.stack(adv_targ, 1) 186 | 187 | # States is just a (N, -1) tensor 188 | recurrent_hidden_states_batch = torch.stack( 189 | recurrent_hidden_states_batch, 1).view(N, -1) 190 | 191 | # Flatten the (T, N, ...) tensors to (T * N, ...) 192 | obs_batch = _flatten_helper(T, N, obs_batch) 193 | actions_batch = _flatten_helper(T, N, actions_batch) 194 | value_preds_batch = _flatten_helper(T, N, value_preds_batch) 195 | return_batch = _flatten_helper(T, N, return_batch) 196 | masks_batch = _flatten_helper(T, N, masks_batch) 197 | old_action_log_probs_batch = _flatten_helper(T, N, \ 198 | old_action_log_probs_batch) 199 | adv_targ = _flatten_helper(T, N, adv_targ) 200 | 201 | yield obs_batch, recurrent_hidden_states_batch, actions_batch, \ 202 | value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, adv_targ 203 | -------------------------------------------------------------------------------- /attention_module.py: -------------------------------------------------------------------------------- 1 | """Implementation of the deep relational architecture used in https://arxiv.org/pdf/1806.01830.pdf. 2 | """ 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import math 7 | import torch.optim 8 | import torch.autograd 9 | from collections import OrderedDict 10 | 11 | class AttentionHead(nn.Module): 12 | 13 | def __init__(self, n_elems, elem_size, emb_size): 14 | super(AttentionHead, self).__init__() 15 | self.sqrt_emb_size = int(math.sqrt(emb_size)) 16 | #queries, keys, values 17 | self.query = nn.Linear(elem_size, emb_size) 18 | self.key = nn.Linear(elem_size, emb_size) 19 | self.value = nn.Linear(elem_size, elem_size) 20 | #layer norms: 21 | # In the paper the authors normalize the projected Q,K and V with layer normalization. They don't state 22 | # explicitly over which dimensions they normalize and how exactly gains and biases are shared. I decided to 23 | # stick with with the solution from https://github.com/gyh75520/Relational_DRL/ because it makes the most 24 | # sense to me: 0,1-normalize every projected entity and apply separate gain and bias to each entry in the 25 | # embeddings. Weights are shared across entites, but not accross Q,K,V or heads. 
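        # Shape sketch (illustrative, not executed): with x of shape (batch, n_elems, elem_size),
        # self.qln(self.query(x)) has shape (batch, n_elems, emb_size); LayerNorm(emb_size) then
        # normalizes each projected entity over its last dimension, with one learned gain and bias
        # per embedding entry.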
26 | self.qln = nn.LayerNorm(emb_size, elementwise_affine=True) 27 | self.kln = nn.LayerNorm(emb_size, elementwise_affine=True) 28 | self.vln = nn.LayerNorm(elem_size, elementwise_affine=True) 29 | 30 | def forward(self, x): 31 | # print(f"input: {x.shape}") 32 | Q = self.qln(self.query(x)) 33 | K = self.kln(self.key(x)) 34 | V = self.vln(self.value(x)) 35 | # softmax is taken over last dimension (rows) of QK': All the attentional weights going into a column/entity 36 | # of V thus sum up to 1. 37 | softmax = F.softmax(torch.bmm(Q,K.transpose(1,2))/self.sqrt_emb_size, dim=-1) 38 | # print(f"softmax shape: {softmax.shape} and sum across batch 1, column 1: {torch.sum(softmax[0,0,:])}") 39 | return torch.bmm(softmax,V) 40 | 41 | def attention_weights(self, x): 42 | # print(f"input: {x.shape}") 43 | Q = self.qln(self.query(x)) 44 | K = self.kln(self.key(x)) 45 | V = self.vln(self.value(x)) 46 | # softmax is taken over last dimension (rows) of QK': All the attentional weights going into a column/entity 47 | # of V thus sum up to 1. 48 | softmax = F.softmax(torch.bmm(Q,K.transpose(1,2))/self.sqrt_emb_size, dim=-1) 49 | return softmax 50 | 51 | class AttentionModule(nn.Module): 52 | 53 | def __init__(self, n_elems, elem_size, emb_size, n_heads): 54 | super(AttentionModule, self).__init__() 55 | # self.input_shape = input_shape 56 | # self.elem_size = elem_size 57 | # self.emb_size = emb_size #honestly not really needed 58 | self.heads = nn.ModuleList(AttentionHead(n_elems, elem_size, emb_size) for _ in range(n_heads)) 59 | self.linear1 = nn.Linear(n_heads*elem_size, elem_size) 60 | self.linear2 = nn.Linear(elem_size, elem_size) 61 | 62 | self.ln = nn.LayerNorm(elem_size, elementwise_affine=True) 63 | 64 | def forward(self, x): 65 | #concatenate all heads' outputs 66 | A_cat = torch.cat([head(x) for head in self.heads], -1) 67 | # projecting down to original element size with 2-layer MLP, layer size = entity size 68 | mlp_out = F.relu(self.linear2(F.relu(self.linear1(A_cat)))) 69 | # residual connection and final layer normalization 70 | return self.ln(x + mlp_out) 71 | 72 | def get_att_weights(self, x): 73 | """Version of forward function that also returns softmax-normalized QK' attention weights""" 74 | #concatenate all heads' outputs 75 | A_cat = torch.cat([head(x) for head in self.heads], -1) 76 | # projecting down to original element size with 2-layer MLP, layer size = entity size 77 | mlp_out = F.relu(self.linear2(F.relu(self.linear1(A_cat)))) 78 | # residual connection and final layer normalization 79 | output = self.ln(x + mlp_out) 80 | attention_weights = [head.attention_weights(x).detach() for head in self.heads] 81 | return [output, attention_weights] 82 | 83 | class DRRLnet(nn.Module): 84 | 85 | def __init__(self, h, w, outputs, n_f_conv1 = 12, n_f_conv2 = 24, 86 | att_emb_size=64, n_heads=2, n_att_stack=2, n_fc_layers=4, pad=True, 87 | baseline_mode=False, n_baseMods=3): 88 | """ 89 | Args: 90 | baseline_mode: if True, n_baseMods residual-convolutional 91 | blocks will be placed at the core of the model instead of the attentional module.
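            h, w: height and width of the input observation.
            outputs: number of discrete actions (size of the policy readout).
            n_f_conv1, n_f_conv2: number of filters in the first and second convolutional layer.
            att_emb_size: size of the query/key embeddings inside each attention head.
            n_heads: number of attention heads per attention module.
            n_att_stack: how often the (weight-shared) attention module is applied in sequence.
            n_fc_layers: number of fully-connected 256-unit layers before the policy/value readouts.
            pad: if True, zero-pad the convolutions so the spatial size of the state is maintained.
            n_baseMods: number of residual-convolutional blocks used when baseline_mode is True.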
92 | """ 93 | 94 | #internal action replay buffer for simple training algorithms 95 | self.baseline_mode = baseline_mode 96 | self.saved_actions = [] 97 | self.rewards = [] 98 | 99 | self.pad = pad 100 | self.n_baseMods = n_baseMods 101 | super(DRRLnet, self).__init__() 102 | 103 | self.conv1 = nn.Conv2d(3, n_f_conv1, kernel_size=2, stride=1) 104 | #possibly batch or layer norm, neither was mentioned in the paper though 105 | # self.ln1 = nn.LayerNorm([n_f_conv1,conv1w,conv1h]) 106 | # self.bn1 = nn.BatchNorm2d(n_f_conv1) 107 | self.conv2 = nn.Conv2d(n_f_conv1, n_f_conv2, kernel_size=2, stride=1) 108 | # self.ln2 = nn.LayerNorm([n_f_conv2,conv2w,conv2h]) 109 | # self.bn2 = nn.BatchNorm2d(n_f_conv2) 110 | 111 | # calculate size of convolution module output 112 | def conv2d_size_out(size, kernel_size=2, stride=1): 113 | return (size - (kernel_size - 1) - 1) // stride + 1 114 | if self.pad: 115 | conv1w = conv2w = w 116 | conv1h = conv2h = h 117 | else: 118 | conv1w = conv2d_size_out(w) 119 | conv1h = conv2d_size_out(h) 120 | conv2w = conv2d_size_out(conv1w) 121 | conv2h = conv2d_size_out(conv1h) 122 | 123 | # create x,y coordinate matrices to append to convolution output 124 | xmap = np.linspace(-np.ones(conv2h), np.ones(conv2h), num=conv2w, endpoint=True, axis=0) 125 | xmap = torch.tensor(np.expand_dims(np.expand_dims(xmap,0),0), dtype=torch.float32, requires_grad=False) 126 | ymap = np.linspace(-np.ones(conv2w), np.ones(conv2w), num=conv2h, endpoint=True, axis=1) 127 | ymap = torch.tensor(np.expand_dims(np.expand_dims(ymap,0),0), dtype=torch.float32, requires_grad=False) 128 | self.register_buffer("xymap", torch.cat((xmap,ymap),dim=1)) # shape (1, 2, conv2w, conv2h) 129 | 130 | # an "attendable" entity has 24 CNN channels + 2 coordinate channels = 26 features. 
this is also the default 131 | # number of baseline module conv layer filter number 132 | att_elem_size = n_f_conv2 + 2 133 | if not self.baseline_mode: 134 | # create attention module with n_heads heads and remember how many times to stack it 135 | self.n_att_stack = n_att_stack #how many times the attentional module is to be stacked (weight-sharing -> reuse) 136 | self.attMod = AttentionModule(conv2w*conv2h, att_elem_size, att_emb_size, n_heads) 137 | else: # create baseline module of several residual-convolutional layers 138 | base_dict = {} 139 | for i in range(self.n_baseMods): 140 | base_dict[f"baseline_identity_{i}"] = nn.Identity() 141 | base_dict[f"baseline_conv_{i}_0"] = nn.Conv2d(att_elem_size, att_elem_size, kernel_size=3, stride=1) 142 | base_dict[f"baseline_batchnorm_{i}_0"] = nn.BatchNorm2d(att_elem_size) 143 | base_dict[f"baseline_conv_{i}_1"] = nn.Conv2d(att_elem_size, att_elem_size, kernel_size=3, stride=1) 144 | base_dict[f"baseline_batchnorm_{i}_1"] = nn.BatchNorm2d(att_elem_size) 145 | 146 | self.baseMod = nn.ModuleDict(base_dict) 147 | #max pooling 148 | # print(f"attnl element size:{att_elem_size}") 149 | # self.maxpool = nn.MaxPool1d(kernel_size=att_emb_size,return_indices=False) #don't know why maxpool reduces 150 | # kernel_size by 1 151 | 152 | # FC256 layers, 4 is default 153 | if n_fc_layers < 1: 154 | raise ValueError("At least 1 linear readout layer is required.") 155 | fc_dict = OrderedDict([('fc1', nn.Linear(att_elem_size, 256)), 156 | ('relu1', nn.ReLU())]) #first one has different input size 157 | for i in range(n_fc_layers-1): 158 | fc_dict[f"fc{i+2}"] = nn.Linear(256, 256) 159 | fc_dict[f"relu{i+2}"] = nn.ReLU() 160 | self.fc_seq = nn.Sequential(fc_dict) #sequential container from ordered dict 161 | self.logits = nn.Linear(256, outputs) 162 | self.value = nn.Linear(256, 1) 163 | self.outputmap = nn.Linear(256,outputs+1) 164 | 165 | # def init_weights(m): 166 | # print(m) 167 | # if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 168 | # torch.nn.init.orthogonal_(m.weight) 169 | # if m.bias is not None: 170 | # torch.nn.init.zeros_(m.bias) 171 | # 172 | # self.apply(init_weights) 173 | 174 | # Called with either a single observation to determine the next action, or a batch 175 | # during optimization; see forward() below for the exact return value.
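    # Usage sketch (illustrative only; a 12x12 board and 4 actions are assumed):
    #   net = DRRLnet(12, 12, 4)
    #   obs = torch.zeros((1, 12, 12, 3))   # channels-last input, forward() permutes it to NCHW
    #   out = net(obs)                      # softmax over the outputmap readout, shape (1, outputs+1)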
176 | def forward(self, x): 177 | #cast environment observation into appropriate torch tensor 178 | x = x.permute(0,3,1,2) 179 | #convolutional module 180 | if self.pad: 181 | x = F.pad(x, (1,0,1,0)) #zero padding so state size stays constant 182 | c = F.relu(self.conv1(x)) 183 | if self.pad: 184 | c = F.pad(c, (1,0,1,0)) 185 | c = F.relu(self.conv2(c)) 186 | #append x,y coordinates to every sample in batch 187 | batchsize = c.size(0) 188 | # Filewriter complains about the this way of repeating the xymap, hope repeat is just as fine 189 | # batch_maps = torch.cat(batchsize*[self.xymap]) 190 | batch_maps = self.xymap.repeat(batchsize,1,1,1,) 191 | c = torch.cat((c,batch_maps),1) 192 | if not self.baseline_mode: 193 | #attentional module 194 | #careful: we are flattening out x,y dimensions into 1 dimension, so shape changes from (batchsize, #filters, 195 | # #conv2w, conv2h) to (batchsize, conv2w*conv2h, #filters), because downstream linear layers take last 196 | # dimension to be input features 197 | a = c.view(c.size(0),c.size(1), -1).transpose(1,2) 198 | # n_att_mod passes through attentional module -> n_att_mod stacked modules with weight sharing 199 | for i_att in range(self.n_att_stack): 200 | a = self.attMod(a) 201 | else: 202 | #baseline module 203 | for i in range(self.n_baseMods): 204 | inp = self.baseMod[f"baseline_identity_{i}"](c) #save input for residual connection 205 | #todo: make padding adaptive to kernel size and stride 206 | c = F.pad(c, (1, 1, 1, 1)) #padding so input maintains size 207 | c = self.baseMod[f"baseline_conv_{i}_0"](c) #conv1 208 | c = self.baseMod[f"baseline_batchnorm_{i}_0"](c) #batch-norm 209 | c = F.relu(c) #relu 210 | c = F.pad(c, (1, 1, 1, 1)) #padding so input maintains size 211 | c = self.baseMod[f"baseline_conv_{i}_1"](c) #conv2 212 | c = c + inp #residual connecton 213 | c = self.baseMod[f"baseline_batchnorm_{i}_1"](c) #batch-norm 214 | c = F.relu(c) #relu 215 | a = c.view(c.size(0),c.size(1), -1).transpose(1,2) #flatten (transpose not necessary but we do 216 | # it for consistency w/ attentional module 217 | 218 | #max pooling over "space", i.e. max scalar within each feature map m x n x f -> f 219 | # pool over entity dimension #isn't this a problem with gradients? 220 | # todo: try pooling over feature dimension 221 | kernelsize = a.shape[1] #but during forward passes called by SummaryWriter, a.shape[1] returns a tensor instead 222 | # of an int. if this causes any trouble it can be replaced by w*h 223 | if type(kernelsize) == torch.Tensor: 224 | kernelsize = kernelsize.item() 225 | pooled = F.max_pool1d(a.transpose(1,2), kernel_size=kernelsize) #pool out entity dimension 226 | #policy module: 4xFC256, then project to logits and value 227 | p = self.fc_seq(pooled.view(pooled.size(0),pooled.size(1))) 228 | # pi = F.softmax(self.logits(p), dim=1) 229 | # v = self.value(p) #todo: no normalization? 
230 | # return pi, v 231 | #for A3C implementation: 232 | return F.softmax(self.outputmap(p),-1) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | a2c_single.py 2 | a3c_attention.py 3 | ac_attention.py 4 | train_attention.py 5 | configEXMPL.py 6 | 7 | 8 | ### Linux ### 9 | *~ 10 | 11 | # temporary files which can be created if a process still has a handle open of a deleted file 12 | .fuse_hidden* 13 | 14 | # KDE directory preferences 15 | .directory 16 | 17 | # Linux trash folder which might appear on any partition or disk 18 | .Trash-* 19 | 20 | # .nfs files are created when an open file is removed but is still being accessed 21 | .nfs* 22 | 23 | ### macOS ### 24 | # General 25 | .DS_Store 26 | .AppleDouble 27 | .LSOverride 28 | 29 | # Icon must end with two \r 30 | Icon 31 | 32 | # Thumbnails 33 | ._* 34 | 35 | # Files that might appear in the root of a volume 36 | .DocumentRevisions-V100 37 | .fseventsd 38 | .Spotlight-V100 39 | .TemporaryItems 40 | .Trashes 41 | .VolumeIcon.icns 42 | .com.apple.timemachine.donotpresent 43 | 44 | # Directories potentially created on remote AFP share 45 | .AppleDB 46 | .AppleDesktop 47 | Network Trash Folder 48 | Temporary Items 49 | .apdisk 50 | 51 | ### PyCharm ### 52 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 53 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 54 | 55 | # User-specific stuff 56 | .idea/ 57 | .idea/**/workspace.xml 58 | .idea/**/tasks.xml 59 | .idea/**/usage.statistics.xml 60 | .idea/**/dictionaries 61 | .idea/**/shelf 62 | 63 | # Generated files 64 | .idea/**/contentModel.xml 65 | 66 | # Sensitive or high-churn files 67 | .idea/**/dataSources/ 68 | .idea/**/dataSources.ids 69 | .idea/**/dataSources.local.xml 70 | .idea/**/sqlDataSources.xml 71 | .idea/**/dynamic.xml 72 | .idea/**/uiDesigner.xml 73 | .idea/**/dbnavigator.xml 74 | 75 | # Gradle 76 | .idea/**/gradle.xml 77 | .idea/**/libraries 78 | 79 | # Gradle and Maven with auto-import 80 | # When using Gradle or Maven with auto-import, you should exclude module files, 81 | # since they will be recreated, and may cause churn. Uncomment if using 82 | # auto-import. 
83 | # .idea/modules.xml 84 | # .idea/*.iml 85 | # .idea/modules 86 | # *.iml 87 | # *.ipr 88 | 89 | # CMake 90 | cmake-build-*/ 91 | 92 | # Mongo Explorer plugin 93 | .idea/**/mongoSettings.xml 94 | 95 | # File-based project format 96 | *.iws 97 | 98 | # IntelliJ 99 | out/ 100 | 101 | # mpeltonen/sbt-idea plugin 102 | .idea_modules/ 103 | 104 | # JIRA plugin 105 | atlassian-ide-plugin.xml 106 | 107 | # Cursive Clojure plugin 108 | .idea/replstate.xml 109 | 110 | # Crashlytics plugin (for Android Studio and IntelliJ) 111 | com_crashlytics_export_strings.xml 112 | crashlytics.properties 113 | crashlytics-build.properties 114 | fabric.properties 115 | 116 | # Editor-based Rest Client 117 | .idea/httpRequests 118 | 119 | # Android studio 3.1+ serialized cache file 120 | .idea/caches/build_file_checksums.ser 121 | 122 | ### PyCharm Patch ### 123 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 124 | 125 | # *.iml 126 | # modules.xml 127 | # .idea/misc.xml 128 | # *.ipr 129 | 130 | # Sonarlint plugin 131 | .idea/**/sonarlint/ 132 | 133 | # SonarQube Plugin 134 | .idea/**/sonarIssues.xml 135 | 136 | # Markdown Navigator plugin 137 | .idea/**/markdown-navigator.xml 138 | .idea/**/markdown-navigator/ 139 | 140 | ### Python ### 141 | # Byte-compiled / optimized / DLL files 142 | __pycache__/ 143 | *.py[cod] 144 | *$py.class 145 | 146 | # C extensions 147 | *.so 148 | 149 | # Distribution / packaging 150 | .Python 151 | build/ 152 | develop-eggs/ 153 | dist/ 154 | downloads/ 155 | eggs/ 156 | .eggs/ 157 | lib/ 158 | lib64/ 159 | parts/ 160 | sdist/ 161 | var/ 162 | wheels/ 163 | pip-wheel-metadata/ 164 | share/python-wheels/ 165 | *.egg-info/ 166 | .installed.cfg 167 | *.egg 168 | MANIFEST 169 | 170 | # PyInstaller 171 | # Usually these files are written by a python script from a template 172 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 173 | *.manifest 174 | *.spec 175 | 176 | # Installer logs 177 | pip-log.txt 178 | pip-delete-this-directory.txt 179 | 180 | # Unit test / coverage reports 181 | htmlcov/ 182 | .tox/ 183 | .nox/ 184 | .coverage 185 | .coverage.* 186 | .cache 187 | nosetests.xml 188 | coverage.xml 189 | *.cover 190 | .hypothesis/ 191 | .pytest_cache/ 192 | 193 | # Translations 194 | *.mo 195 | *.pot 196 | 197 | # Scrapy stuff: 198 | .scrapy 199 | 200 | # Sphinx documentation 201 | docs/_build/ 202 | 203 | # PyBuilder 204 | target/ 205 | 206 | # pyenv 207 | .python-version 208 | 209 | # pipenv 210 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 211 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 212 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 213 | # install all needed dependencies. 
214 | #Pipfile.lock 215 | 216 | # celery beat schedule file 217 | celerybeat-schedule 218 | 219 | # SageMath parsed files 220 | *.sage.py 221 | 222 | # Spyder project settings 223 | .spyderproject 224 | .spyproject 225 | 226 | # Rope project settings 227 | .ropeproject 228 | 229 | # Mr Developer 230 | .mr.developer.cfg 231 | .project 232 | .pydevproject 233 | 234 | # mkdocs documentation 235 | /site 236 | 237 | # mypy 238 | .mypy_cache/ 239 | .dmypy.json 240 | dmypy.json 241 | 242 | # Pyre type checker 243 | .pyre/ 244 | 245 | ### Windows ### 246 | # Windows thumbnail cache files 247 | Thumbs.db 248 | Thumbs.db:encryptable 249 | ehthumbs.db 250 | ehthumbs_vista.db 251 | 252 | # Dump file 253 | *.stackdump 254 | 255 | # Folder config file 256 | [Dd]esktop.ini 257 | 258 | # Recycle Bin used on file shares 259 | $RECYCLE.BIN/ 260 | 261 | # Windows Installer files 262 | *.cab 263 | *.msi 264 | *.msix 265 | *.msm 266 | *.msp 267 | 268 | # Windows shortcuts 269 | *.lnk 270 | 271 | ### VisualStudio ### 272 | ## Ignore Visual Studio temporary files, build results, and 273 | ## files generated by popular Visual Studio add-ons. 274 | ## 275 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 276 | 277 | # User-specific files 278 | *.rsuser 279 | *.suo 280 | *.user 281 | *.userosscache 282 | *.sln.docstates 283 | 284 | # User-specific files (MonoDevelop/Xamarin Studio) 285 | *.userprefs 286 | 287 | # Mono auto generated files 288 | mono_crash.* 289 | 290 | # Build results 291 | [Dd]ebug/ 292 | [Dd]ebugPublic/ 293 | [Rr]elease/ 294 | [Rr]eleases/ 295 | x64/ 296 | x86/ 297 | [Aa][Rr][Mm]/ 298 | [Aa][Rr][Mm]64/ 299 | bld/ 300 | [Bb]in/ 301 | [Oo]bj/ 302 | [Ll]og/ 303 | 304 | # Visual Studio 2015/2017 cache/options directory 305 | .vs/ 306 | # Uncomment if you have tasks that create the project's static files in wwwroot 307 | #wwwroot/ 308 | 309 | # Visual Studio 2017 auto generated files 310 | Generated\ Files/ 311 | 312 | # MSTest test Results 313 | [Tt]est[Rr]esult*/ 314 | [Bb]uild[Ll]og.* 315 | 316 | # NUnit 317 | *.VisualState.xml 318 | TestResult.xml 319 | nunit-*.xml 320 | 321 | # Build Results of an ATL Project 322 | [Dd]ebugPS/ 323 | [Rr]eleasePS/ 324 | dlldata.c 325 | 326 | # Benchmark Results 327 | BenchmarkDotNet.Artifacts/ 328 | 329 | # .NET Core 330 | project.lock.json 331 | project.fragment.lock.json 332 | artifacts/ 333 | 334 | # StyleCop 335 | StyleCopReport.xml 336 | 337 | # Files built by Visual Studio 338 | *_i.c 339 | *_p.c 340 | *_h.h 341 | *.ilk 342 | *.meta 343 | *.obj 344 | *.iobj 345 | *.pch 346 | *.pdb 347 | *.ipdb 348 | *.pgc 349 | *.pgd 350 | *.rsp 351 | *.sbr 352 | *.tlb 353 | *.tli 354 | *.tlh 355 | *.tmp 356 | *.tmp_proj 357 | *_wpftmp.csproj 358 | *.log 359 | *.vspscc 360 | *.vssscc 361 | .builds 362 | *.pidb 363 | *.svclog 364 | *.scc 365 | 366 | # Chutzpah Test files 367 | _Chutzpah* 368 | 369 | # Visual C++ cache files 370 | ipch/ 371 | *.aps 372 | *.ncb 373 | *.opendb 374 | *.opensdf 375 | *.sdf 376 | *.cachefile 377 | *.VC.db 378 | *.VC.VC.opendb 379 | 380 | # Visual Studio profiler 381 | *.psess 382 | *.vsp 383 | *.vspx 384 | *.sap 385 | 386 | # Visual Studio Trace Files 387 | *.e2e 388 | 389 | # TFS 2012 Local Workspace 390 | $tf/ 391 | 392 | # Guidance Automation Toolkit 393 | *.gpState 394 | 395 | # ReSharper is a .NET coding add-in 396 | _ReSharper*/ 397 | *.[Rr]e[Ss]harper 398 | *.DotSettings.user 399 | 400 | # JustCode is a .NET coding add-in 401 | .JustCode 402 | 403 | # TeamCity is a build add-in 404 | _TeamCity* 405 | 
406 | # DotCover is a Code Coverage Tool 407 | *.dotCover 408 | 409 | # AxoCover is a Code Coverage Tool 410 | .axoCover/* 411 | !.axoCover/settings.json 412 | 413 | # Visual Studio code coverage results 414 | *.coverage 415 | *.coveragexml 416 | 417 | # NCrunch 418 | _NCrunch_* 419 | .*crunch*.local.xml 420 | nCrunchTemp_* 421 | 422 | # MightyMoose 423 | *.mm.* 424 | AutoTest.Net/ 425 | 426 | # Web workbench (sass) 427 | .sass-cache/ 428 | 429 | # Installshield output folder 430 | [Ee]xpress/ 431 | 432 | # DocProject is a documentation generator add-in 433 | DocProject/buildhelp/ 434 | DocProject/Help/*.HxT 435 | DocProject/Help/*.HxC 436 | DocProject/Help/*.hhc 437 | DocProject/Help/*.hhk 438 | DocProject/Help/*.hhp 439 | DocProject/Help/Html2 440 | DocProject/Help/html 441 | 442 | # Click-Once directory 443 | publish/ 444 | 445 | # Publish Web Output 446 | *.[Pp]ublish.xml 447 | *.azurePubxml 448 | # Note: Comment the next line if you want to checkin your web deploy settings, 449 | # but database connection strings (with potential passwords) will be unencrypted 450 | *.pubxml 451 | *.publishproj 452 | 453 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 454 | # checkin your Azure Web App publish settings, but sensitive information contained 455 | # in these scripts will be unencrypted 456 | PublishScripts/ 457 | 458 | # NuGet Packages 459 | *.nupkg 460 | # NuGet Symbol Packages 461 | *.snupkg 462 | # The packages folder can be ignored because of Package Restore 463 | **/[Pp]ackages/* 464 | # except build/, which is used as an MSBuild target. 465 | !**/[Pp]ackages/build/ 466 | # Uncomment if necessary however generally it will be regenerated when needed 467 | #!**/[Pp]ackages/repositories.config 468 | # NuGet v3's project.json files produces more ignorable files 469 | *.nuget.props 470 | *.nuget.targets 471 | 472 | # Microsoft Azure Build Output 473 | csx/ 474 | *.build.csdef 475 | 476 | # Microsoft Azure Emulator 477 | ecf/ 478 | rcf/ 479 | 480 | # Windows Store app package directories and files 481 | AppPackages/ 482 | BundleArtifacts/ 483 | Package.StoreAssociation.xml 484 | _pkginfo.txt 485 | *.appx 486 | *.appxbundle 487 | *.appxupload 488 | 489 | # Visual Studio cache files 490 | # files ending in .cache can be ignored 491 | *.[Cc]ache 492 | # but keep track of directories ending in .cache 493 | !?*.[Cc]ache/ 494 | 495 | # Others 496 | ClientBin/ 497 | ~$* 498 | *.dbmdl 499 | *.dbproj.schemaview 500 | *.jfm 501 | *.pfx 502 | *.publishsettings 503 | orleans.codegen.cs 504 | 505 | # Including strong name files can present a security risk 506 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 507 | #*.snk 508 | 509 | # Since there are multiple workflows, uncomment next line to ignore bower_components 510 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 511 | #bower_components/ 512 | 513 | # RIA/Silverlight projects 514 | Generated_Code/ 515 | 516 | # Backup & report files from converting an old project file 517 | # to a newer Visual Studio version. 
Backup files are not needed, 518 | # because we have git ;-) 519 | _UpgradeReport_Files/ 520 | Backup*/ 521 | UpgradeLog*.XML 522 | UpgradeLog*.htm 523 | ServiceFabricBackup/ 524 | *.rptproj.bak 525 | 526 | # SQL Server files 527 | *.mdf 528 | *.ldf 529 | *.ndf 530 | 531 | # Business Intelligence projects 532 | *.rdl.data 533 | *.bim.layout 534 | *.bim_*.settings 535 | *.rptproj.rsuser 536 | *- [Bb]ackup.rdl 537 | *- [Bb]ackup ([0-9]).rdl 538 | *- [Bb]ackup ([0-9][0-9]).rdl 539 | 540 | # Microsoft Fakes 541 | FakesAssemblies/ 542 | 543 | # GhostDoc plugin setting file 544 | *.GhostDoc.xml 545 | 546 | # Node.js Tools for Visual Studio 547 | .ntvs_analysis.dat 548 | node_modules/ 549 | 550 | # Visual Studio 6 build log 551 | *.plg 552 | 553 | # Visual Studio 6 workspace options file 554 | *.opt 555 | 556 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 557 | *.vbw 558 | 559 | # Visual Studio LightSwitch build output 560 | **/*.HTMLClient/GeneratedArtifacts 561 | **/*.DesktopClient/GeneratedArtifacts 562 | **/*.DesktopClient/ModelManifest.xml 563 | **/*.Server/GeneratedArtifacts 564 | **/*.Server/ModelManifest.xml 565 | _Pvt_Extensions 566 | 567 | # Paket dependency manager 568 | .paket/paket.exe 569 | paket-files/ 570 | 571 | # FAKE - F# Make 572 | .fake/ 573 | 574 | # CodeRush personal settings 575 | .cr/personal 576 | 577 | # Python Tools for Visual Studio (PTVS) 578 | *.pyc 579 | 580 | # Cake - Uncomment if you are using it 581 | # tools/** 582 | # !tools/packages.config 583 | 584 | # Tabs Studio 585 | *.tss 586 | 587 | # Telerik's JustMock configuration file 588 | *.jmconfig 589 | 590 | # BizTalk build output 591 | *.btp.cs 592 | *.btm.cs 593 | *.odx.cs 594 | *.xsd.cs 595 | 596 | # OpenCover UI analysis results 597 | OpenCover/ 598 | 599 | # Azure Stream Analytics local run output 600 | ASALocalRun/ 601 | 602 | # MSBuild Binary and Structured Log 603 | *.binlog 604 | 605 | # NVidia Nsight GPU debugger configuration file 606 | *.nvuser 607 | 608 | # MFractors (Xamarin productivity tool) working folder 609 | .mfractor/ 610 | 611 | # Local History for Visual Studio 612 | .localhistory/ 613 | 614 | # BeatPulse healthcheck temp database 615 | healthchecksdb 616 | 617 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 618 | MigrationBackup/ 619 | 620 | -------------------------------------------------------------------------------- /a2c.py: -------------------------------------------------------------------------------- 1 | """Base version of advantage actor critic training of the DRRL architecture on a 2 | BoxWorld task. Can be used with local actor instances on CPUS and global learner instance on GPU. 3 | Make sure to have the gym-boxworld environment registered: https://github.com/mavischer/Box-World 4 | Script made with, among others, inspiration from https://github.com/MorvanZhou/pytorch-A3C/ and 5 | https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html. 
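Example invocation (configs/exmpl_config.yml is included in the repo; the save directory, here just an
arbitrary example name, is created if it does not exist):
    python a2c.py -c configs/exmpl_config.yml -s saves/example_run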
6 | """ 7 | import torch 8 | import torch.multiprocessing as mp 9 | import gym 10 | from attention_module import DRRLnet 11 | from torch.distributions import Categorical 12 | import os 13 | import argparse 14 | import yaml 15 | import time 16 | 17 | #parse yaml config file from cmdline 18 | parser = argparse.ArgumentParser(description='PyTorch A2C BoxWorld Experiment') 19 | parser.add_argument("-c", "--configpath", type=str, required=True, help="path/to/configfile.yml") 20 | parser.add_argument("-s", "--savepath", type=str, required=True, help="path/to/savedirectory") 21 | args = parser.parse_args() 22 | with open(os.path.abspath(args.configpath), 'r') as file: 23 | config = yaml.safe_load(file) 24 | SAVEPATH = args.savepath 25 | 26 | #set stage 27 | if not os.path.isdir(SAVEPATH): 28 | os.mkdir(SAVEPATH) 29 | torch.manual_seed(config["seed"]) 30 | ENV_CONFIG = config["env_config"] 31 | if config["n_cpus"] == -1: 32 | config["n_cpus"] = mp.cpu_count() 33 | N_W = config["n_cpus"] 34 | g_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | print(f"running global net on {g_device}") 36 | l_device = torch.device("cpu") 37 | with open(os.path.join(SAVEPATH, "config.yml"), "w+") as f: 38 | f.write(yaml.dump(config)) 39 | 40 | #make environment 41 | env = gym.make('gym_boxworld:boxworld-v0', **ENV_CONFIG) 42 | N_ACT = env.action_space.n 43 | INP_W = env.observation_space.shape[0] 44 | INP_H = env.observation_space.shape[1] 45 | 46 | #configure learning 47 | N_STEP = config["n_step"] 48 | GAMMA = config["gamma"] 49 | 50 | # #todo: check whether we really need this? 51 | # class SharedAdam(torch.optim.Adam): 52 | # """Shared optimizer, the parameters in the optimizer will be shared across multiprocessors.""" 53 | # def __init__(self, params, lr=1e-3, betas=(0.9, 0.9), eps=1e-8, 54 | # weight_decay=0): 55 | # super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 56 | # # State initialization 57 | # for group in self.param_groups: 58 | # for p in group['params']: 59 | # state = self.state[p] 60 | # state['step'] = 0 61 | # state['exp_avg'] = torch.zeros_like(p.data) 62 | # state['exp_avg_sq'] = torch.zeros_like(p.data) 63 | # 64 | # # share in memory 65 | # state['exp_avg'].share_memory_() 66 | # state['exp_avg_sq'].share_memory_() 67 | 68 | 69 | class Worker(mp.Process): 70 | def __init__(self, gnet, w_idx, device=l_device, verbose=False): 71 | """ 72 | Args: 73 | gnet: global network that performs parameter updates 74 | w_idx: integer index of worker process for identification 75 | device: assigned device 76 | verbose:whether to print results of current game (better disable for large number of workers) 77 | """ 78 | super(Worker, self).__init__() 79 | self.name = f"w{w_idx:02}" 80 | self.g_net = gnet 81 | self.l_net = DRRLnet(INP_W, INP_H, N_ACT).to(device) # local network 82 | self.l_net.eval() #otherwise gradients clutter memory 83 | print(f"{self.name}: running local net on {l_device}") 84 | self.env = gym.make('gym_boxworld:boxworld-v0', **ENV_CONFIG) 85 | self.device = device 86 | self.verbose = verbose 87 | 88 | def start(self): 89 | """Runs an entire episode, in the end return trajectory as list of s,a,r. 90 | 91 | Returns: 92 | Ready-to-digest trajectory as torch.tensors of state, action, reward+discounted future rewards 93 | """ 94 | s = self.env.reset() 95 | s_, a_, r_ = [], [], [] #trajectory of episode goes here 96 | ep_r = 0. 
#total episode reward 97 | ep_t = 0 #episode step t, both just for oversight 98 | while True: #generate variable-length trajectory in this loop 99 | s = torch.tensor([s.T], dtype=torch.float, device=self.device) # transpose for CWH-order, apparently 100 | # conv layer want floats 101 | p, _ = self.l_net(s) 102 | m = Categorical(p) # create a categorical distribution over the list of probabilities of actions 103 | a = m.sample().item() # and sample an action using the distribution 104 | s_new, r, done, _ = self.env.step(a) 105 | ep_r += r 106 | 107 | # append current step's elements to lists 108 | s_.append(s) 109 | a_.append(a) 110 | r_.append(r) 111 | 112 | if done: # return trajectory as lists of elements 113 | if self.verbose: 114 | print(f"{self.name}: episode ended after step {ep_t} with total reward {ep_r}") 115 | return(self.prettify_trajectory(s_, a_, r_), ep_r) 116 | s = s_new 117 | ep_t += 1 118 | 119 | def prettify_trajectory(self, s_, a_, r_): 120 | """Prepares trajectory to be sent to central learner in a more orderly fashion. 121 | 122 | Calculate temporally discounted future rewards for TD error and reverse lists so everything is in correct 123 | temporal order. Cast everything to appropriate tensorflow tensors. Return what's necessary to compute 124 | gradients upstream. 125 | 126 | Args: 127 | s_: List of states 128 | a_: List of actions 129 | r_: List of returns 130 | 131 | Returns: 132 | Ready-to-digest trajectory as torch.tensors of state, action, reward+discounted future rewards 133 | 134 | """ 135 | # calculate temporal-discounted rewards 136 | r_acc = 0 137 | r_disc = [] 138 | for r in r_[::-1]: # reverse buffer r, discount accordingly and add in value at for t=t_end 139 | r_acc = r + GAMMA * r_acc 140 | r_disc.append(r_acc) # discounted trajectory in reverse order 141 | r_disc = r_disc[::-1] 142 | # every element in r_disc now contains reward at corresponding step plus future discounted rewards 143 | 144 | #cast everything to tensors (states are already cast) 145 | s_ = torch.cat(s_).to(device=g_device).detach() 146 | a_ = torch.tensor(a_, dtype=torch.uint8, device=g_device).detach() #torch can only compute gradients for float 147 | # tensors, but this shouldn't be a problem 148 | r_disc = torch.tensor(r_disc, dtype=torch.float16, device=g_device).detach() 149 | 150 | return(s_,a_,r_disc) 151 | 152 | def pull_params(self): 153 | """Update own params from global network.""" 154 | self.l_net.load_state_dict(self.g_net.state_dict(), strict=True) 155 | 156 | def update_step(net, trajectories, opt, opt_step): 157 | """Calculate advantage-actor-critic loss on batch with network, updates parameters 158 | Args: 159 | net: network to perform training on 160 | batch: list of trajectories as tuple of 1 tensor object for each 161 | b_s: states 162 | b_a: chosen actions 163 | b_r_disc: returns and discounted future rewards for TD error with value function 164 | Returns: Sum of loss of all trajectories 165 | """ 166 | rezip = zip(*trajectories) 167 | b_s, b_a, b_r_disc = [torch.cat(elems) for elems in list(rezip)] #concatenate torch tensors of all trajectories 168 | try: 169 | b_p, b_v = net.forward(b_s) 170 | except RuntimeError: 171 | print(f"failed to handle batch of size {b_s.shape}") 172 | raise 173 | 174 | td = b_r_disc - b_v 175 | m = torch.distributions.Categorical(b_p) 176 | # e_w = min(1, 2*0.995**opt_step) #todo: try out entropy annealing! 
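    # The total_loss below combines, per time step: 0.5 * td^2 (critic term),
    # -log pi(a|s) * td.detach() (actor term), and + e_w * entropy. Note that the sign of the
    # entropy term differs from a2c_dist.py's a2c_loss, which subtracts e_w * entropy instead.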
177 | e_w = 0.005 # like in paper 178 | total_loss = (0.5 * td.pow(2) - m.log_prob(b_a) * td.detach().squeeze() + e_w * m.entropy()).mean() 179 | 180 | opt.zero_grad() 181 | total_loss.backward() 182 | opt.step() 183 | return total_loss.sum() 184 | 185 | def save_step(i_step, g_net, steps, losses, rewards): 186 | """Saves statistics to global var SAVEPATH and cleans up outdated save files 187 | Args: 188 | i_step: iteration step 189 | g_net: global net's state dictionary containing all variables' values 190 | steps: global number of environment steps (all workers combined) 191 | losses: global loss of episodes 192 | rewards: average rewards of episodes 193 | """ 194 | try: 195 | ending = f"{i_step:05}.pt" 196 | for name, var in zip(["g_net", "steps", "losses", "rewards"], [g_net.state_dict(), steps, losses, rewards]): 197 | torch.save(var, os.path.join(SAVEPATH, name+ending)) 198 | except Exception as e: 199 | print(f"failed to write step {i_step} to disk:") 200 | print(e) 201 | try: 202 | # clean out old files 203 | oldfiles = [f for f in os.listdir(SAVEPATH) 204 | if (f.startswith("g_net") or f.startswith("steps") or 205 | f.startswith("losses") or f.startswith("rewards")) 206 | and not f.endswith(ending)] 207 | for f in oldfiles: 208 | os.remove(os.path.join(SAVEPATH, f)) 209 | except Exception as e: 210 | print("failed to erase old saves:") 211 | print(e) 212 | 213 | def load_step(): 214 | """Loads statistics from global var SAVEPATH and loads g_nets parameters from saved state_dict 215 | Returns: 216 | list of loaded variables 217 | g_net: global net's state dictionary containing all variables' values 218 | steps: global number of environment steps (all workers combined) 219 | losses: global loss of episodes 220 | rewards: average rewards of episodes 221 | """ 222 | loaded_vars = [] 223 | for name in ["g_net", "steps", "losses", "rewards"]: 224 | files = [file for file in os.listdir(SAVEPATH) if file.startswith(name)] 225 | if len(files) > 1: 226 | raise Exception(f"more than one savefile found for {name}") 227 | else: 228 | loaded_vars.append(torch.load(os.path.join(SAVEPATH, files[0]))) 229 | return(loaded_vars) 230 | 231 | if __name__ == "__main__": 232 | #create global network and pipeline 233 | g_net = DRRLnet(INP_W, INP_H, N_ACT).to(g_device) # global network 234 | #todo: only implicit init so far 235 | g_net.share_memory() # share the global parameters in multiprocessing #todo: check whether this makes a difference 236 | # optimizer = SharedAdam(g_net.parameters(), lr=0.0001) # global optimizer 237 | if config["optimizer"] == "RMSprop": 238 | #RMSprop optimizer was used for the large state space, not the small ones and impala instead of a3c. 239 | # "Learning rate was tuned between 1e-5 and 2e-4" probably means they did hyperparameter search. 240 | # scheduling is also possible conveniently using torch torch.optim.lr_scheduler 241 | # perhaps use smaller decay term 0.9 242 | optimizer = torch.optim.RMSprop(g_net.parameters(), eps=0.1, lr=config["lr"]) 243 | else: 244 | #Adam optimizer was used for the starcraft games with learning rate decaying linearly over 1e10 steps from 245 | # 1e-4 to 1e-5. 
other params are torch defaults 246 | optimizer = torch.optim.Adam(g_net.parameters(),lr=config["lr"]) 247 | g_step = 0 248 | 249 | #create workers 250 | losses = [] 251 | steps = [] 252 | rewards = [] 253 | trajectories = [] 254 | workers = [Worker(g_net, i) for i in range(N_W)] 255 | 256 | [w.pull_params() for w in workers] #make workers identical copies of global network before training begins 257 | for i_step in range(N_STEP): #performing one parallel update step 258 | #parallel trajectory sampling 259 | episodes = [w.start() for w in workers] #list comprehension automatically waits for workers to finish 260 | trajectories, cum_rewards = zip(*episodes) 261 | #concatenate and push tracjectories to global network for learning (synchronized update) 262 | loss = update_step(g_net, trajectories, optimizer, i_step) 263 | #pull new parameters 264 | [w.pull_params() for w in workers] 265 | #trying to free some gpu memory... 266 | if g_device.type == "cuda": #these only release memory to be visible, should not make a substantial difference 267 | torch.cuda.empty_cache() 268 | # torch.cuda.synchronize() 269 | #bookkeeping 270 | len_iter = sum([len(traj[0])for traj in trajectories]) 271 | g_step += len_iter 272 | steps.append(g_step) 273 | losses.append(loss.item()) 274 | rewards.append(sum(cum_rewards)/N_W) 275 | print(f"{time.strftime('%a %d %b %H:%M:%S', time.gmtime())}: it: {i_step}, steps:{len_iter}, " 276 | f"cum. steps:{g_step}, total loss:{loss.item():.2f}, avg. reward:{rewards[-1]:.2f}.") 277 | if i_step%1 == 0: #save global network 278 | save_step(i_step, g_net, steps, losses, rewards) 279 | 280 | if config["plot"]: 281 | import matplotlib.pyplot as plt 282 | plt.plot(steps, losses) 283 | 284 | if config["tensorboard"]: 285 | from torch.utils.tensorboard import SummaryWriter 286 | # create writers 287 | g_writer = SummaryWriter(os.path.join(SAVEPATH, "tb_g_net")) 288 | l_writer = SummaryWriter(os.path.join(SAVEPATH, "tb_l_net")) 289 | # write graph to file 290 | rezip = zip(*trajectories) 291 | b_s, _, _ = [torch.cat(elems) for elems in list(rezip)] # concatenate torch tensors of all trajectories 292 | g_writer.add_graph(g_net,b_s) 293 | l_writer.add_graph(workers[0].l_net,b_s) 294 | g_writer.close() 295 | l_writer.close() 296 | 297 | -------------------------------------------------------------------------------- /a2c_dist.py: -------------------------------------------------------------------------------- 1 | """Version of a2c.py to calculate gradients in local worker instances and send them to gobal optimizer. 
This way, 2 | no GPU is required at the cost of slightly longer steps.""" 3 | import torch 4 | import torch.multiprocessing as mp 5 | import gym 6 | from attention_module import DRRLnet 7 | from torch.distributions import Categorical 8 | import os 9 | import argparse 10 | import yaml 11 | import time 12 | import pandas as pd 13 | import random 14 | 15 | #parse yaml config file from cmdline 16 | parser = argparse.ArgumentParser(description='PyTorch A2C BoxWorld Experiment') 17 | parser.add_argument("-c", "--configpath", type=str, required=True, help="path/to/configfile.yml") 18 | parser.add_argument("-s", "--savepath", type=str, required=True, help="path/to/savedirectory") 19 | args = parser.parse_args() 20 | with open(os.path.abspath(args.configpath), 'r') as file: 21 | config = yaml.safe_load(file) 22 | 23 | SAVEPATH = args.savepath 24 | SAVE_IVAL = 1 25 | 26 | torch.manual_seed(config["seed"]) 27 | ENV_CONFIG = config["env_config"] 28 | NET_CONFIG = config["net_config"] 29 | 30 | if config["n_cpus"] == -1: 31 | config["n_cpus"] = mp.cpu_count() -1 32 | N_W = config["n_cpus"] 33 | 34 | def random_config(raw_config=ENV_CONFIG): 35 | """Field in config can contain a list of values that a worker randomly chooses from when starting to generate a 36 | trajectory, i.e. sampling solution paths of different depth or different number of distractor branches. 37 | 38 | """ 39 | config = {} 40 | for key, value in raw_config.items(): 41 | if type(value) == list: 42 | config[key] = random.choice(value) 43 | else: 44 | config[key] = value 45 | return config 46 | 47 | #obtain action and state space size 48 | env = gym.make('gym_boxworld:boxworld-v0', **random_config()) 49 | N_ACT = env.action_space.n 50 | INP_W = env.observation_space.shape[0] 51 | INP_H = env.observation_space.shape[1] 52 | del env 53 | 54 | #configure learning 55 | N_STEP = config["n_step"] 56 | GAMMA = config["gamma"] 57 | if "e_schedule" in config.keys(): 58 | e_schedule = config["e_schedule"] 59 | else: 60 | e_schedule = False 61 | 62 | class Worker(mp.Process): 63 | def __init__(self, g_net, stats_q, grads_q, w_idx, i_start=0, e_schedule=True, verbose=False): 64 | """ 65 | Args: 66 | gnet: global network that performs parameter updates 67 | stats_q:queue to put statistics of sampled trajectory 68 | grads_q:queue to put gradients 69 | w_idx: integer index of worker process for identification 70 | e_schedule: schedule entropy weight 71 | verbose:whether to print results of current game (better disable for large number of workers) 72 | """ 73 | super(Worker, self).__init__() 74 | self.name = f"w{w_idx:02}" #overwrites Processes' name 75 | self.g_net = g_net 76 | self.stats_q = stats_q 77 | self.grads_q = grads_q 78 | self.l_net = DRRLnet(INP_W, INP_H, N_ACT, **NET_CONFIG) # local network 79 | self.l_net.train() # sets net in training mode so gradient's don't clutter memory 80 | self.e_schedule = e_schedule 81 | self.verbose = verbose 82 | self.iter = i_start #basically a private i_step 83 | 84 | def run(self): 85 | """Runs an entire episode, calculates gradients for all weights 86 | 87 | Writes to stats_queue the accumulated returns (not discounted), number of environment steps and loss of sampled 88 | episode. 89 | Write to grads_queue the gradients are written directly to the central learner's parameters grads. 
90 | """ 91 | 92 | ### sampling trajectory 93 | while start_cond.wait(1000): #wait for background process to signal start of an episode (if timeout reached 94 | # wait returns false and run is aborted 95 | ### generate random environment for this episode 96 | env_config = random_config() 97 | env = gym.make('gym_boxworld:boxworld-v0', **random_config()) 98 | 99 | # print(f"{self.name}: starting iteration") 100 | t_start = time.time() 101 | self.pull_params() 102 | self.l_net.eval() 103 | s = env.reset() 104 | s_, a_, r_ = [], [], [] #trajectory of episode goes here 105 | ep_r = 0. #total episode reward 106 | ep_t = 0 #episode step t, both just for oversight 107 | while True: #generate variable-length trajectory in this loop 108 | s = torch.tensor([s.T], dtype=torch.float) # transpose for CWH-order, apparently 109 | # conv layer want floats 110 | p, _ = self.l_net(s) 111 | m = Categorical(p) # create a categorical distribution over the list of probabilities of actions 112 | a = m.sample().item() # and sample an action using the distribution 113 | s_new, r, done, _ = env.step(a) 114 | ep_r += r 115 | 116 | # append current step's elements to lists 117 | s_.append(s) 118 | a_.append(a) 119 | r_.append(r) 120 | 121 | if done: # return trajectory as lists of elements 122 | if self.verbose: 123 | print(f"{self.name}: episode ended after step {ep_t} with total reward {ep_r}") 124 | break 125 | s = s_new 126 | ep_t += 1 127 | # t_sample = time.time() 128 | # print(f"{self.name}: sampling took {t_sample-t_start:.2f}s") 129 | ### forward and backward pass of entire episode 130 | # preprocess trajectory 131 | self.l_net.zero_grad() 132 | self.l_net.train() 133 | 134 | s_,a_,r_disc = self.prettify_trajectory(s_,a_,r_) 135 | p_, v_ = self.l_net.forward(s_) 136 | 137 | #backward pass to calculate gradients 138 | loss, loss_dict = self.a2c_loss(s_,a_,r_disc,p_, v_) 139 | loss.backward() 140 | # t_grads = time.time() 141 | # print(f"{self.name}: calculating gradients took {t_grads-t_sample:.2f}s") 142 | 143 | ### shipping out gradients to centralized learner as named dict 144 | grads = [] 145 | for name, param in self.l_net.named_parameters(): 146 | grads.append((name, param.grad)) 147 | grad_dict = dict(grads) 148 | t_end = time.time() 149 | 150 | self.stats_q.put({**{"cumulative reward": ep_r, 151 | "loss": loss.item(), 152 | "success": (r==env.reward_gem+env.step_cost), 153 | "steps": ep_t + 1, 154 | "walltime": t_end-t_start}, 155 | **loss_dict, 156 | **env_config}) 157 | self.grads_q.put(grad_dict) 158 | # print(f"{self.name}: distributing gradients took {t_end-t_grads:.2f}s") 159 | # print(f"{self.name}: episode took {t_end-t_start}s") 160 | self.iter += 1 161 | 162 | def prettify_trajectory(self, s_, a_, r_): 163 | """Prepares trajectory to compute loss on, just to make the code clearer 164 | 165 | Calculate temporally discounted future rewards for TD error and reverse lists so everything is in correct 166 | temporal order. Cast everything to appropriate tensorflow tensors. Return what's necessary to compute 167 | gradients upstream. 
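        For example, with GAMMA = 0.99 and per-step rewards [0, 0, 1], the returned discounted
        rewards are [0.9801, 0.99, 1.0].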
168 | 169 | Args: 170 | s_: List of states 171 | a_: List of actions 172 | r_: List of returns 173 | 174 | Returns: 175 | Ready-to-digest trajectory as torch.tensors of state, action, reward+discounted future rewards 176 | """ 177 | # calculate temporal-discounted rewards 178 | r_acc = 0 179 | r_disc = [] 180 | for r in r_[::-1]: # reverse buffer r, discount accordingly and add in value at for t=t_end 181 | r_acc = r + GAMMA * r_acc 182 | r_disc.append(r_acc) # discounted trajectory in reverse order 183 | r_disc = r_disc[::-1] 184 | # every element in r_disc now contains reward at corresponding step plus future discounted rewards 185 | 186 | #cast everything to tensors (states are already cast) 187 | s_ = torch.cat(s_).detach() 188 | a_ = torch.tensor(a_, dtype=torch.uint8).detach() #torch can only compute gradients for float 189 | # tensors, but this shouldn't be a problem 190 | r_disc = torch.tensor(r_disc, dtype=torch.float16).detach() 191 | 192 | return(s_,a_,r_disc) 193 | 194 | def a2c_loss(self, s_,a_,r_disc,p_, v_): 195 | """Calculate advantage-actor-critic loss on entire episode 196 | Args: 197 | for the entire trajectory, one tensor each of 198 | s_: states 199 | a_: actions 200 | r_disc: temporally discounted future rewards 201 | p_: action probabilities 202 | v_: value estimates 203 | 204 | Returns: Summed losses of trajectory 205 | """ 206 | 207 | # critic loss 208 | td = r_disc - v_.squeeze() 209 | c_loss = td.pow(2) 210 | # actor loss 211 | m = torch.distributions.Categorical(p_) 212 | a_loss = - m.log_prob(a_) * td.detach() 213 | # entropy term 214 | if self.e_schedule: 215 | # e_w = max(0, min(2, -self.iter/200 + 2.5)) #linear annealing between episode 100 and 500 from 2 to 0 216 | # e_w = max(0, min(1, -self.iter/400 + 1.25)) #linear annealing between episode 100 and 500 from 1 to 0 217 | e_w = max(0, min(0.5, -self.iter/800 + 0.625)) #linear annealing between episode 100 and 500 from 0.5 to 0 218 | 219 | else: 220 | e_w = 0.5 #in paper they report 0.05 but probably take the sum instaead of the mean 221 | e_loss = m.entropy() 222 | total_loss = (0.5 * c_loss + a_loss - e_w * e_loss).mean() #why was there a .detach here? 223 | 224 | # m = torch.distributions.Categorical(p_) 225 | # # e_w = min(1, 2*0.995**opt_step) #todo: try out entropy annealing! 226 | # e_w = 0.005 # like in paper 227 | # total_loss = (0.5 * td.pow(2) - m.log_prob(a_) * td.squeeze() + e_w * m.entropy()).mean() 228 | 229 | return total_loss, {"critic loss": c_loss.mean().item()*0.5, 230 | "actor loss": a_loss.mean().item(), 231 | "entropy term": e_loss.mean().item(), 232 | "ent. 
weight": e_w} 233 | 234 | def pull_params(self): 235 | """Update own params from global network.""" 236 | self.l_net.load_state_dict(self.g_net.state_dict(), strict=True) 237 | 238 | 239 | def save_step(i_step, g_net, stats): 240 | """Saves statistics to global var SAVEPATH and cleans up outdated save files 241 | Args: 242 | i_step: iteration step 243 | g_net: global net's state dictionary containing all variables' values 244 | stats: list of dicts to be saved to disc as pandas dataframe 245 | """ 246 | try: 247 | ending = f"{i_step:05}" 248 | torch.save(g_net.state_dict(), os.path.join(SAVEPATH, "net."+ending)) 249 | pd.DataFrame(stats).to_csv(os.path.join(SAVEPATH, "stats."+ending)) 250 | #remove old files 251 | oldfiles = [f for f in os.listdir(SAVEPATH) 252 | if (f.startswith("net") or f.startswith("stats")) 253 | and not f.endswith(ending)] 254 | for f in oldfiles: 255 | os.remove(os.path.join(SAVEPATH, f)) 256 | except Exception as e: 257 | print(f"failed to write step {i_step} to disk:") 258 | print(e) 259 | 260 | def load_step(): 261 | """Loads statistics from global var SAVEPATH and loads g_nets parameters from saved state_dict 262 | Returns: 263 | list of loaded variables 264 | g_net: global net's state dictionary containing all variables' values 265 | steps: global number of environment steps (all workers combined) 266 | losses: global loss of episodes 267 | rewards: average rewards of episodes 268 | """ 269 | files = os.listdir(SAVEPATH) 270 | statsfiles = [f for f in files if f.startswith("stats")] 271 | netfiles = [f for f in files if f.startswith("net")] 272 | if len(statsfiles) > 1 or len(netfiles) > 1: 273 | raise Exception(f"more than one savefile found for in folder {SAVEPATH}") 274 | 275 | i_step = int(statsfiles[0].split(".")[1]) 276 | state_dict = torch.load(os.path.join(SAVEPATH, netfiles[0])) 277 | df = pd.read_csv(os.path.join(SAVEPATH, statsfiles[0])).drop(columns=["Unnamed: 0"]) 278 | #retranslate df to stats 279 | stats = df.to_dict(orient="records") 280 | return(i_step, state_dict, stats) 281 | 282 | #set up plotting gradients 283 | if config["plot_gradients"]: 284 | import matplotlib.pyplot as plt 285 | 286 | fig, axes = plt.subplots(nrows=6, ncols=N_STEP, sharex=True) 287 | plt.tight_layout() 288 | if config["plot_gradients"] or config["tensorboard"]: 289 | import numpy as np 290 | 291 | 292 | def plot_gradients(g_net, i_step, N_STEP, axes): 293 | try: 294 | attMod = g_net.attMod.linear1 295 | except: 296 | attMod = g_net.baseMod.baseline_conv_1_1 297 | layers = [g_net.conv1, g_net.conv2, attMod, g_net.fc_seq.fc1, g_net.logits, g_net.value] 298 | l_names = ["conv1", "conv2", "attLin1", "fc1", "p", "v"] 299 | bins = np.arange(-4,2,0.1) 300 | for i_l,layer in enumerate(layers): 301 | axes[i_l,i_step].hist(layer.weight.grad.flatten(), bins=bins) 302 | axes[i_l,i_step].set_yticklabels([]) 303 | if i_step == 0: 304 | axes[i_l,i_step].set_ylabel(l_names[i_l]) 305 | if i_l == 0: 306 | axes[i_l,i_step].set_title(f"it {i_step}") 307 | 308 | 309 | if __name__ == "__main__": 310 | mp.set_start_method("fork") #fork is unix default and means child process inherits all resources from parent 311 | # process. 
in case problems occur, might use "forkserver" 312 | #create global network and pipeline 313 | g_net = DRRLnet(INP_W, INP_H, N_ACT, **NET_CONFIG) # global network 314 | g_net.zero_grad() 315 | g_net.share_memory() # share the global parameters in multiprocessing #todo: check whether this makes a difference 316 | stats_queue = mp.SimpleQueue() #statistics about the episodes will be returned in this queue 317 | grads_queue = mp.SimpleQueue() #the calculated gradients will be returned as dicts in this queue 318 | start_cond = mp.Event() #condition object to signal processes to perform another iteration # iteration 319 | # so worker process needs to be still alive when queue is accessed) 320 | if config["optimizer"] == "RMSprop": 321 | #RMSprop optimizer was used for the large state space, not the small ones and impala instead of a3c. 322 | # "Learning rate was tuned between 1e-5 and 2e-4" probably means they did hyperparameter search. 323 | # scheduling is also possible conveniently using torch torch.optim.lr_scheduler 324 | # perhaps use smaller decay term 0.9 325 | optimizer = torch.optim.RMSprop(g_net.parameters(), eps=0.1, lr=config["lr"]) 326 | else: 327 | #Adam optimizer was used for the starcraft games with learning rate decaying linearly over 1e10 steps from 328 | # 1e-4 to 1e-5. other params are torch defaults 329 | optimizer = torch.optim.Adam(g_net.parameters(),lr=config["lr"]) 330 | 331 | # set stage 332 | if not os.path.isdir(SAVEPATH): #directory does not exist or is empty 333 | os.mkdir(SAVEPATH) 334 | #write config to new directory 335 | with open(os.path.join(SAVEPATH, "config.yml"), "w+") as f: 336 | f.write(yaml.dump(config)) 337 | #start new training process 338 | stats = [] 339 | i_start = 0 340 | print("starting new training process") 341 | else: 342 | #load config and check whether identical 343 | with open(os.path.join(SAVEPATH, "config.yml"), 'r') as file: 344 | config_old = yaml.safe_load(file) 345 | if config_old != config: 346 | raise Exception("Existing config different from current config") 347 | i_start, net_dict, stats = load_step() 348 | g_net.load_state_dict(net_dict, strict=True) 349 | print(f"starting from loaded iteration {i_start+1}") 350 | 351 | #create workers 352 | workers = [Worker(g_net, stats_queue, grads_queue, i, 353 | i_start=i_start, e_schedule=e_schedule) for i in range(N_W)] 354 | [w.start() for w in workers] # workers will write the gradients to the parameters directly 355 | # [w.pull_params() for w in workers] #make workers identical copies of global network before training begins 356 | for i_step in range(i_start, N_STEP): #performing one parallel update step 357 | t0 = time.time() 358 | ###parallel trajectory sampling and gradient computation 359 | start_cond.set() # all processes start an iteration 360 | time.sleep(0.1) 361 | start_cond.clear() # this will halt processes' run method at the end of the current episode 362 | 363 | ###copying gradients to global net (also saving statistics) 364 | optimizer.zero_grad() 365 | for i_w in range(N_W): 366 | grad_dict = grads_queue.get() 367 | for name, param in g_net.named_parameters(): 368 | try: 369 | param.grad += grad_dict[name] 370 | except TypeError: 371 | param.grad = grad_dict[name] #at the very beginning, gradients are initialized to None 372 | stats_curr = stats_queue.get() 373 | stats_curr["global ep"] = i_step #append current global step to dictionary 374 | stats.append(stats_curr) 375 | 376 | if config["plot_gradients"]: #visualize gradients at crucial points in network as histogram 
for every iteration 377 | plot_gradients(g_net, i_step, N_STEP, axes) 378 | # ### copying gradients and perform optimizer step on global network 379 | # while not grads_queue.empty(): 380 | # grad_dict = grads_queue.get() 381 | # for name, param in g_net.named_parameters(): 382 | # try: 383 | # param.grad += grad_dict[name] 384 | # except TypeError: 385 | # param.grad = grad_dict[name] #at the very beginning, gradients are initialized to None 386 | ### centralized optimizer step 387 | optimizer.step() # centralized optimizer updates global network 388 | #bookkeeping 389 | if i_step%SAVE_IVAL == 0: #save global network 390 | save_step(i_step, g_net, stats) 391 | t1 = time.time() 392 | stats_recent = stats[-N_W:] 393 | steps_total_recent = 0 394 | for s in stats_recent: 395 | steps_total_recent += s["steps"] 396 | print(f"{time.strftime('%a %d %b %H:%M:%S', time.gmtime())}: iteration {i_step}: {t1-t0:.1f}s, " 397 | f"{steps_total_recent / (t1 - t0)}FPS ") 398 | 399 | save_step(i_step, g_net, stats) 400 | [w.terminate() for w in workers] 401 | # 402 | # if config["plot"]: 403 | # import matplotlib.pyplot as plt 404 | # import seaborn as sns 405 | # data = pd.DataFrame(stats) 406 | # for i,measure in enumerate(["cumulative reward", "loss", "steps"]): 407 | # plt.figure() 408 | # sns.lineplot(x="global ep",y=measure,data=data) 409 | 410 | #currently deprecated because it doesn't use the current worker 411 | if config["tensorboard"]: 412 | from torch.utils.tensorboard import SummaryWriter 413 | # create writers 414 | g_writer = SummaryWriter(os.path.join(SAVEPATH, "tb_g_net")) 415 | l_writer_train = SummaryWriter(os.path.join(SAVEPATH, "tb_l_net_train")) 416 | l_writer_eval = SummaryWriter(os.path.join(SAVEPATH, "tb_l_net_eval")) 417 | #generate random trajectory to feed-forward as batch 418 | w0 = workers[0] 419 | env = gym.make('gym_boxworld:boxworld-v0', **random_config()) 420 | 421 | env.reset() 422 | 423 | trajectory = [] 424 | while True: 425 | s, _, done, _ = env.step(np.random.choice(4)) 426 | s = torch.tensor([s.T], dtype=torch.float) 427 | trajectory.append(s) 428 | if done: 429 | break 430 | # write graph to file 431 | s_ = torch.cat(trajectory).detach() 432 | 433 | g_writer.add_graph(g_net,s_) 434 | w0.l_net.eval() 435 | l_writer_eval.add_graph(w0.l_net,s_) 436 | w0.l_net.train() 437 | l_writer_train.add_graph(w0.l_net,s_) 438 | g_writer.close() 439 | l_writer_eval.close() 440 | l_writer_train.close() 441 | 442 | ###visualize gradients at critical points 443 | 444 | 445 | -------------------------------------------------------------------------------- /helpers/a2c_ppo_acktr/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from DRRL.attention_module import AttentionModule 6 | from collections import OrderedDict 7 | 8 | from helpers.a2c_ppo_acktr.distributions import Bernoulli, Categorical, DiagGaussian 9 | from helpers.a2c_ppo_acktr.utils import init, init_flexible 10 | 11 | 12 | class Flatten(nn.Module): 13 | def forward(self, x): 14 | return x.view(x.size(0), -1) 15 | 16 | 17 | class Policy(nn.Module): 18 | def __init__(self, obs_shape, action_space, base=None, base_kwargs=None): 19 | super(Policy, self).__init__() 20 | if base_kwargs is None: 21 | base_kwargs = {} 22 | if base is None: 23 | if len(obs_shape) == 3: 24 | base = CNNBase 25 | elif len(obs_shape) == 1: 26 | base = MLPBase 27 | else: 28 | raise NotImplementedError 29 | 
self.base = base(obs_shape[0], **base_kwargs) 30 | 31 | if action_space.__class__.__name__ == "Discrete": 32 | num_outputs = action_space.n 33 | self.dist = Categorical(self.base.output_size, num_outputs) 34 | elif action_space.__class__.__name__ == "Box": 35 | num_outputs = action_space.shape[0] 36 | self.dist = DiagGaussian(self.base.output_size, num_outputs) 37 | elif action_space.__class__.__name__ == "MultiBinary": 38 | num_outputs = action_space.shape[0] 39 | self.dist = Bernoulli(self.base.output_size, num_outputs) 40 | else: 41 | raise NotImplementedError 42 | 43 | @property 44 | def is_recurrent(self): 45 | return self.base.is_recurrent 46 | 47 | @property 48 | def recurrent_hidden_state_size(self): 49 | """Size of rnn_hx.""" 50 | return self.base.recurrent_hidden_state_size 51 | 52 | def forward(self, inputs, rnn_hxs, masks): 53 | raise NotImplementedError 54 | 55 | def act(self, inputs, rnn_hxs, masks, deterministic=False): 56 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 57 | dist = self.dist(actor_features) 58 | 59 | if deterministic: 60 | action = dist.mode() 61 | else: 62 | action = dist.sample() 63 | 64 | action_log_probs = dist.log_probs(action) 65 | dist_entropy = dist.entropy().mean() 66 | 67 | return value, action, action_log_probs, rnn_hxs 68 | 69 | def get_value(self, inputs, rnn_hxs, masks): 70 | value, _, _ = self.base(inputs, rnn_hxs, masks) 71 | return value 72 | 73 | def evaluate_actions(self, inputs, rnn_hxs, masks, action): 74 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 75 | dist = self.dist(actor_features) 76 | 77 | action_log_probs = dist.log_probs(action) 78 | dist_entropy = dist.entropy().mean() 79 | 80 | return value, action_log_probs, dist_entropy, rnn_hxs 81 | 82 | class NNBase(nn.Module): 83 | def __init__(self, recurrent, recurrent_input_size, hidden_size): 84 | super(NNBase, self).__init__() 85 | 86 | self._hidden_size = hidden_size 87 | self._recurrent = recurrent 88 | 89 | if recurrent: 90 | self.gru = nn.GRU(recurrent_input_size, hidden_size) 91 | for name, param in self.gru.named_parameters(): 92 | if 'bias' in name: 93 | nn.init.constant_(param, 0) 94 | elif 'weight' in name: 95 | nn.init.orthogonal_(param) 96 | 97 | @property 98 | def is_recurrent(self): 99 | return self._recurrent 100 | 101 | @property 102 | def recurrent_hidden_state_size(self): 103 | if self._recurrent: 104 | return self._hidden_size 105 | return 1 106 | 107 | @property 108 | def output_size(self): 109 | return self._hidden_size 110 | 111 | def _forward_gru(self, x, hxs, masks): 112 | if x.size(0) == hxs.size(0): 113 | x, hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0)) 114 | x = x.squeeze(0) 115 | hxs = hxs.squeeze(0) 116 | else: 117 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 118 | N = hxs.size(0) 119 | T = int(x.size(0) / N) 120 | 121 | # unflatten 122 | x = x.view(T, N, x.size(1)) 123 | 124 | # Same deal with masks 125 | masks = masks.view(T, N) 126 | 127 | # Let's figure out which steps in the sequence have a zero for any agent 128 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 129 | has_zeros = ((masks[1:] == 0.0) \ 130 | .any(dim=-1) 131 | .nonzero() 132 | .squeeze() 133 | .cpu()) 134 | 135 | # +1 to correct the masks[1:] 136 | if has_zeros.dim() == 0: 137 | # Deal with scalar 138 | has_zeros = [has_zeros.item() + 1] 139 | else: 140 | has_zeros = (has_zeros + 1).numpy().tolist() 141 | 142 | # add t=0 and t=T to the list 143 | has_zeros = [0] + has_zeros + [T] 
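            # has_zeros now holds the boundaries (plus 0 and T) of the maximal sub-sequences that contain
            # no episode reset; the loop below runs the GRU over each sub-sequence in a single call and
            # only zeroes the hidden state (through the mask) at sub-sequence starts, instead of stepping
            # through the sequence one timestep at a time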
144 | 145 | hxs = hxs.unsqueeze(0) 146 | outputs = [] 147 | for i in range(len(has_zeros) - 1): 148 | # We can now process steps that don't have any zeros in masks together! 149 | # This is much faster 150 | start_idx = has_zeros[i] 151 | end_idx = has_zeros[i + 1] 152 | 153 | rnn_scores, hxs = self.gru( 154 | x[start_idx:end_idx], 155 | hxs * masks[start_idx].view(1, -1, 1)) 156 | 157 | outputs.append(rnn_scores) 158 | 159 | # assert len(outputs) == T 160 | # x is a (T, N, -1) tensor 161 | x = torch.cat(outputs, dim=0) 162 | # flatten 163 | x = x.view(T * N, -1) 164 | hxs = hxs.squeeze(0) 165 | 166 | return x, hxs 167 | 168 | 169 | class CNNBase(NNBase): 170 | def __init__(self, num_inputs, recurrent=False, hidden_size=512): 171 | super(CNNBase, self).__init__(recurrent, hidden_size, hidden_size) 172 | 173 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 174 | constant_(x, 0), nn.init.calculate_gain('relu')) 175 | 176 | # self.main = nn.Sequential( 177 | # init_(nn.Conv2d(num_inputs, 32, 8, stride=4)), nn.ReLU(), 178 | # init_(nn.Conv2d(32, 64, 4, stride=2)), nn.ReLU(), 179 | # init_(nn.Conv2d(64, 32, 3, stride=1)), nn.ReLU(), Flatten(), 180 | # init_(nn.Linear(32 * 7 * 7, hidden_size)), nn.ReLU()) 181 | self.main = nn.Sequential( 182 | init_(nn.Conv2d(num_inputs, 16, 2, stride=1)), nn.ReLU(), 183 | init_(nn.Conv2d(16, 16, 2, stride=1)), nn.ReLU(), Flatten(), 184 | init_(nn.Linear(16*5*5, hidden_size)), nn.ReLU()) 185 | 186 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 187 | constant_(x, 0)) 188 | 189 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 190 | 191 | self.train() 192 | 193 | def forward(self, inputs, rnn_hxs, masks): 194 | x = self.main(inputs / 255.0) 195 | 196 | if self.is_recurrent: 197 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 198 | 199 | return self.critic_linear(x), x, rnn_hxs 200 | 201 | 202 | class MLPBase(NNBase): 203 | def __init__(self, num_inputs, recurrent=False, hidden_size=64): 204 | super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size) 205 | 206 | if recurrent: 207 | num_inputs = hidden_size 208 | 209 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
210 | constant_(x, 0), np.sqrt(2)) 211 | 212 | self.actor = nn.Sequential( 213 | init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 214 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 215 | 216 | self.critic = nn.Sequential( 217 | init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 218 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 219 | 220 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 221 | 222 | self.train() 223 | 224 | def forward(self, inputs, rnn_hxs, masks): 225 | x = inputs 226 | 227 | if self.is_recurrent: 228 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 229 | 230 | hidden_critic = self.critic(x) 231 | hidden_actor = self.actor(x) 232 | 233 | return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs 234 | 235 | 236 | class DRRLBase(NNBase): 237 | """Adaptation of the DRRLnet class so it is more easily usable with the pytorch-a2c-ppo-acktr-gail implementation""" 238 | def __init__(self, num_inputs, recurrent=False, hidden_size=512, w=12, h=12, n_f_conv1 = 12, n_f_conv2 = 24, pad=True, 239 | att_emb_size=64, n_heads=2, n_att_stack=2, n_fc_layers=4, w_init = "orthogonal", 240 | baseline_mode=False, n_baseMods=3): 241 | """ 242 | Args: 243 | num_inputs: number of input channels (usually 3 RGB channels) 244 | recurrent: not implemented; hidden states can be passed in and are returned unchanged 245 | hidden_size: size of the base's output layer, i.e. before the action layer projects into action space inside Policy 246 | 247 | w: width of input image (including black boundary) 248 | h: height of input image (including black boundary) 249 | n_f_conv1: #filters in first conv layer 250 | n_f_conv2: #filters in second conv layer 251 | pad: whether input images are padded through convolution layers so size is maintained 252 | 253 | att_emb_size: #attentional filters inside each head 254 | n_heads: #heads in parallel inside attentional module 255 | n_att_stack: #times attentional module is stacked 256 | 257 | n_fc_layers: #fully connected output layers on top of attentional module 258 | 259 | baseline_mode: use residual-convolutional baseline core instead of attentional module 260 | n_baseMods: #residual-convolutional blocks inside baseline core""" 261 | 262 | if recurrent: 263 | raise NotImplementedError("Currently no recurrent version of DRRL architecture implemented.") 264 | #internal action replay buffer for simple training algorithms 265 | self.baseline_mode = baseline_mode 266 | self.saved_actions = [] 267 | self.rewards = [] 268 | 269 | self.pad = pad 270 | self.n_baseMods = n_baseMods 271 | super(DRRLBase, self).__init__(recurrent, hidden_size, hidden_size) 272 | 273 | if w_init=="orthogonal": #nn.init.orthogonal doesn't compute gain on its own 274 | init_ = lambda m: init_flexible(m, nn.init.orthogonal_, lambda x: nn.init. 275 | constant_(x, 0), {"gain": nn.init.calculate_gain('relu')}) 276 | elif w_init=="kaiming": 277 | init_ = lambda m: init_flexible(m, nn.init.kaiming_uniform_, lambda x: nn.init.
278 | constant_(x, 0), {"nonlinearity": "relu"}) 279 | else: 280 | raise NotImplementedError("init function not implemented") 281 | 282 | self.conv1 = init_(nn.Conv2d(3, n_f_conv1, kernel_size=2, stride=1)) 283 | #possibly batch or layer norm, neither was mentioned in the paper though 284 | # self.ln1 = nn.LayerNorm([n_f_conv1,conv1w,conv1h]) 285 | # self.bn1 = nn.BatchNorm2d(n_f_conv1) 286 | self.conv2 = init_(nn.Conv2d(n_f_conv1, n_f_conv2, kernel_size=2, stride=1)) 287 | # self.ln2 = nn.LayerNorm([n_f_conv2,conv2w,conv2h]) 288 | # self.bn2 = nn.BatchNorm2d(n_f_conv2) 289 | 290 | # calculate size of convolution module output 291 | def conv2d_size_out(size, kernel_size=2, stride=1): 292 | return (size - (kernel_size - 1) - 1) // stride + 1 293 | if self.pad: 294 | conv1w = conv2w = w 295 | conv1h = conv2h = h 296 | else: 297 | conv1w = conv2d_size_out(w) 298 | conv1h = conv2d_size_out(h) 299 | conv2w = conv2d_size_out(conv1w) 300 | conv2h = conv2d_size_out(conv1h) 301 | 302 | # create x,y coordinate matrices to append to convolution output 303 | xmap = np.linspace(-np.ones(conv2h), np.ones(conv2h), num=conv2w, endpoint=True, axis=0) 304 | xmap = torch.tensor(np.expand_dims(np.expand_dims(xmap,0),0), dtype=torch.float32, requires_grad=False) 305 | ymap = np.linspace(-np.ones(conv2w), np.ones(conv2w), num=conv2h, endpoint=True, axis=1) 306 | ymap = torch.tensor(np.expand_dims(np.expand_dims(ymap,0),0), dtype=torch.float32, requires_grad=False) 307 | self.register_buffer("xymap", torch.cat((xmap,ymap),dim=1)) # shape (1, 2, conv2w, conv2h) 308 | 309 | # an "attendable" entity has 24 CNN channels + 2 coordinate channels = 26 features. this is also the default 310 | # number of baseline module conv layer filter number 311 | att_elem_size = n_f_conv2 + 2 312 | if not self.baseline_mode: 313 | # create attention module with n_heads heads and remember how many times to stack it 314 | self.n_att_stack = n_att_stack #how many times the attentional module is to be stacked (weight-sharing -> reuse) 315 | self.attMod = AttentionModule(conv2w*conv2h, att_elem_size, att_emb_size, n_heads) 316 | 317 | for m in self.attMod.modules(): #.modules() iterates recursively 318 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 319 | init_(m) 320 | 321 | else: # create baseline module of several residual-convolutional layers 322 | base_dict = {} 323 | for i in range(self.n_baseMods): 324 | base_dict[f"baseline_identity_{i}"] = nn.Identity() 325 | base_dict[f"baseline_conv_{i}_0"] = init_(nn.Conv2d(att_elem_size, att_elem_size, kernel_size=3, 326 | stride=1)) 327 | base_dict[f"baseline_batchnorm_{i}_0"] = nn.BatchNorm2d(att_elem_size) 328 | base_dict[f"baseline_conv_{i}_1"] = init_(nn.Conv2d(att_elem_size, att_elem_size, kernel_size=3, 329 | stride=1)) 330 | base_dict[f"baseline_batchnorm_{i}_1"] = nn.BatchNorm2d(att_elem_size) 331 | 332 | self.baseMod = nn.ModuleDict(base_dict) 333 | #max pooling 334 | # print(f"attnl element size:{att_elem_size}") 335 | # self.maxpool = nn.MaxPool1d(kernel_size=att_emb_size,return_indices=False) #don't know why maxpool reduces 336 | # kernel_size by 1 337 | 338 | # FC256 layers, 4 is default 339 | if n_fc_layers < 1: 340 | raise ValueError("At least 1 linear readout layer is required.") 341 | fc_dict = OrderedDict([('fc1', init_(nn.Linear(att_elem_size, hidden_size))), 342 | ('relu1', nn.ReLU())]) #first one has different inpuz size 343 | for i in range(n_fc_layers-1): 344 | fc_dict[f"fc{i+2}"] = init_(nn.Linear(hidden_size, hidden_size)) 345 | fc_dict[f"relu{i+2}"] = 
nn.ReLU() 346 | self.fc_seq = nn.Sequential(fc_dict) #sequential container from ordered dict 347 | 348 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 349 | 350 | self.train() 351 | 352 | def forward(self, inputs, rnn_hxs=None, masks=None): 353 | """hidden states rnn_hxs and masks are not implemented because there currently is no recurrent version of the 354 | attentional architecture. 355 | """ 356 | x = inputs / 255.0 357 | 358 | #convolutional module 359 | if self.pad: 360 | x = F.pad(x, (1,0,1,0)) #zero padding so state size stays constant 361 | c = F.relu(self.conv1(x)) 362 | if self.pad: 363 | c = F.pad(c, (1,0,1,0)) 364 | c = F.relu(self.conv2(c)) 365 | #append x,y coordinates to every sample in batch 366 | batchsize = c.size(0) 367 | # Filewriter complains about the this way of repeating the xymap, hope repeat is just as fine 368 | # batch_maps = torch.cat(batchsize*[self.xymap]) 369 | batch_maps = self.xymap.repeat(batchsize,1,1,1,) 370 | c = torch.cat((c,batch_maps),1) 371 | if not self.baseline_mode: 372 | #attentional module 373 | #careful: we are flattening out x,y dimensions into 1 dimension, so shape changes from (batchsize, #filters, 374 | # #conv2w, conv2h) to (batchsize, conv2w*conv2h, #filters), because downstream linear layers take last 375 | # dimension to be input features 376 | a = c.view(c.size(0),c.size(1), -1).transpose(1,2) 377 | # n_att_mod passes through attentional module -> n_att_mod stacked modules with weight sharing 378 | for i_att in range(self.n_att_stack): 379 | a = self.attMod(a) 380 | else: 381 | #baseline module 382 | for i in range(self.n_baseMods): 383 | inp = self.baseMod[f"baseline_identity_{i}"](c) #save input for residual connection 384 | #todo: make padding adaptive to kernel size and stride 385 | c = F.pad(c, (1, 1, 1, 1)) #padding so input maintains size 386 | c = self.baseMod[f"baseline_conv_{i}_0"](c) #conv1 387 | c = self.baseMod[f"baseline_batchnorm_{i}_0"](c) #batch-norm 388 | c = F.relu(c) #relu 389 | c = F.pad(c, (1, 1, 1, 1)) #padding so input maintains size 390 | c = self.baseMod[f"baseline_conv_{i}_1"](c) #conv2 391 | c = c + inp #residual connecton 392 | c = self.baseMod[f"baseline_batchnorm_{i}_1"](c) #batch-norm 393 | c = F.relu(c) #relu 394 | a = c.view(c.size(0),c.size(1), -1).transpose(1,2) #flatten (transpose not necessary but we do 395 | # it for consistency w/ attentional module 396 | 397 | #max pooling over "space", i.e. max scalar within each feature map m x n x f -> f 398 | # pool over entity dimension #isn't this a problem with gradients? 399 | # todo: try pooling over feature dimension 400 | kernelsize = a.shape[1] #but during forward passes called by SummaryWriter, a.shape[1] returns a tensor instead 401 | # of an int. if this causes any trouble it can be replaced by w*h 402 | if type(kernelsize) == torch.Tensor: 403 | kernelsize = kernelsize.item() 404 | pooled = F.max_pool1d(a.transpose(1,2), kernel_size=kernelsize) #pool out entity dimension 405 | #policy module: 4xFC256, then project to logits and value 406 | p = self.fc_seq(pooled.view(pooled.size(0),pooled.size(1))) 407 | 408 | return self.critic_linear(p), p, rnn_hxs 409 | 410 | def get_attention_weights(self, inputs, rnn_hxs=None, masks=None): 411 | """ Forward pass through the architecture but only to the point where attention weights are calculated. 412 | Identical up to that point to forward() 413 | """ 414 | 415 | if self.baseline_mode: 416 | raise Exception("Baseline mode set to True. 
No attention.") 417 | x = inputs / 255.0 418 | #convolutional module 419 | if self.pad: 420 | x = F.pad(x, (1,0,1,0)) #zero padding so state size stays constant 421 | c = F.relu(self.conv1(x)) 422 | if self.pad: 423 | c = F.pad(c, (1,0,1,0)) 424 | c = F.relu(self.conv2(c)) 425 | #append x,y coordinates to every sample in batch 426 | batchsize = c.size(0) 427 | # Filewriter complains about the this way of repeating the xymap, hope repeat is just as fine 428 | # batch_maps = torch.cat(batchsize*[self.xymap]) 429 | batch_maps = self.xymap.repeat(batchsize,1,1,1,) 430 | c = torch.cat((c,batch_maps),1) 431 | #attentional module 432 | #careful: we are flattening out x,y dimensions into 1 dimension, so shape changes from (batchsize, #filters, 433 | # #conv2w, conv2h) to (batchsize, conv2w*conv2h, #filters), because downstream linear layers take last 434 | # dimension to be input features 435 | a = c.view(c.size(0),c.size(1), -1).transpose(1,2) 436 | # n_att_mod passes through attentional module -> n_att_mod stacked modules with weight sharing 437 | att_weights = [] 438 | for i_att in range(self.n_att_stack): 439 | a, weights = self.attMod.get_att_weights(a) 440 | att_weights.append(weights) 441 | return att_weights 442 | 443 | # def get_body_output(self, x): 444 | # #convolutional module 445 | # if self.pad: 446 | # x = F.pad(x, (1,0,1,0)) #zero padding so state size stays constant 447 | # c = F.relu(self.conv1(x)) 448 | # if self.pad: 449 | # c = F.pad(c, (1,0,1,0)) 450 | # c = F.relu(self.conv2(c)) 451 | # #append x,y coordinates to every sample in batch 452 | # batchsize = c.size(0) 453 | # # Filewriter complains about the this way of repeating the xymap, hope repeat is just as fine 454 | # # batch_maps = torch.cat(batchsize*[self.xymap]) 455 | # batch_maps = self.xymap.repeat(batchsize,1,1,1,) 456 | # c = torch.cat((c,batch_maps),1) 457 | # if not self.baseline_mode: 458 | # #attentional module 459 | # #careful: we are flattening out x,y dimensions into 1 dimension, so shape changes from (batchsize, #filters, 460 | # # #conv2w, conv2h) to (batchsize, conv2w*conv2h, #filters), because downstream linear layers take last 461 | # # dimension to be input features 462 | # a = c.view(c.size(0),c.size(1), -1).transpose(1,2) 463 | # # n_att_mod passes through attentional module -> n_att_mod stacked modules with weight sharing 464 | # for i_att in range(self.n_att_stack): 465 | # a = self.attMod(a) 466 | # else: 467 | # #baseline module 468 | # for i in range(self.n_baseMods): 469 | # inp = self.baseMod[f"baseline_identity_{i}"](c) #save input for residual connection 470 | # #todo: make padding adaptive to kernel size and stride 471 | # c = F.pad(c, (1, 1, 1, 1)) #padding so input maintains size 472 | # c = self.baseMod[f"baseline_conv_{i}_0"](c) #conv1 473 | # c = self.baseMod[f"baseline_batchnorm_{i}_0"](c) #batch-norm 474 | # c = F.relu(c) #relu 475 | # c = F.pad(c, (1, 1, 1, 1)) #padding so input maintains size 476 | # c = self.baseMod[f"baseline_conv_{i}_1"](c) #conv2 477 | # c = c + inp #residual connecton 478 | # c = self.baseMod[f"baseline_batchnorm_{i}_1"](c) #batch-norm 479 | # c = F.relu(c) #relu 480 | # a = c.view(c.size(0),c.size(1), -1).transpose(1,2) #flatten (transpose not necessary but we do 481 | # # it for consistency w/ attentional module 482 | # 483 | # #max pooling over "space", i.e. max scalar within each feature map m x n x f -> f 484 | # # pool over entity dimension #isn't this a problem with gradients? 
485 | # # todo: try pooling over feature dimension 486 | # kernelsize = a.shape[1] #but during forward passes called by SummaryWriter, a.shape[1] returns a tensor instead 487 | # # of an int. if this causes any trouble it can be replaced by w*h 488 | # if type(kernelsize) == torch.Tensor: 489 | # kernelsize = kernelsize.item() 490 | # pooled = F.max_pool1d(a.transpose(1,2), kernel_size=kernelsize) #pool out entity dimension 491 | # #policy module: 4xFC256, then project to logits and value 492 | # p = self.fc_seq(pooled.view(pooled.size(0),pooled.size(1))) 493 | # return p 494 | # 495 | # def predict(self, state): 496 | # body_output = self.get_body_output(state) 497 | # pi = F.softmax(self.logits(body_output), dim=1) 498 | # return pi, self.value(body_output) 499 | # 500 | # def get_action(self, state): 501 | # probs = self.predict(state)[0].detach().squeeze().numpy() 502 | # action = np.random.choice(4, p=probs) 503 | # return action 504 | # 505 | # def get_log_probs(self, state): 506 | # body_output = self.get_body_output(state) 507 | # logprobs = F.log_softmax(self.logits(body_output), dim=1) 508 | # return logprobs 509 | --------------------------------------------------------------------------------
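A few standalone sketches of numerical details from the code above; they are illustrations under stated assumptions, not part of the repository.

First, the reverse discounting performed in Worker.prettify_trajectory, with an invented three-step reward list and gamma fixed to 0.99 for the example:

GAMMA = 0.99
r_ = [0.0, -1.0, 10.0]          # invented per-step rewards of a short episode

r_acc, r_disc = 0.0, []
for r in r_[::-1]:              # walk the episode backwards
    r_acc = r + GAMMA * r_acc   # reward at this step plus discounted future reward
    r_disc.append(r_acc)
r_disc = r_disc[::-1]           # restore temporal order

print(r_disc)                   # ≈ [8.811, 8.9, 10.0]

Second, the entropy-weight schedule used in Worker.a2c_loss when e_schedule is enabled: constant at 0.5 up to iteration 100, then linear decay to 0 at iteration 500 (the helper name entropy_weight is made up for this sketch):

def entropy_weight(it):
    return max(0, min(0.5, -it / 800 + 0.625))

assert entropy_weight(0) == 0.5      # clamped at the upper bound
assert entropy_weight(100) == 0.5    # decay starts here
assert entropy_weight(300) == 0.25   # halfway through the decay
assert entropy_weight(500) == 0      # fully annealed
assert entropy_weight(800) == 0      # clamped at zero afterwards

Third, a standalone check of the coordinate planes that DRRLBase.__init__ registers as the xymap buffer and later concatenates onto the CNN output; the construction mirrors the code above, here for a hypothetical 12x12 post-convolution feature map:

import numpy as np
import torch

conv2w = conv2h = 12
xmap = np.linspace(-np.ones(conv2h), np.ones(conv2h), num=conv2w, endpoint=True, axis=0)
ymap = np.linspace(-np.ones(conv2w), np.ones(conv2w), num=conv2h, endpoint=True, axis=1)
xymap = torch.cat([torch.tensor(m[None, None], dtype=torch.float32) for m in (xmap, ymap)], dim=1)

print(xymap.shape)                                            # torch.Size([1, 2, 12, 12])
print(xymap[0, 0, 0, 0].item(), xymap[0, 0, -1, 0].item())    # -1.0 1.0: the x-plane varies along dim 2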