├── .gitignore ├── README.md ├── agents ├── __init__.py ├── drrn │ ├── drrn_agent.py │ ├── drrn_graph_inv_dyn_agent.py │ └── drrn_inv_dyn_agent.py └── random_agent.py ├── analysis ├── augment_wt.py └── sample_env.py ├── definitions └── defs.py ├── games ├── 905.z5 ├── acorncourt.z5 ├── advent.z5 ├── adventureland.z5 ├── anchor.z8 ├── awaken.z5 ├── balances.z5 ├── deephome.z5 ├── detective.z5 ├── dragon.z5 ├── enchanter.z3 ├── inhumane.z5 ├── jewel.z5 ├── karn.z5 ├── library.z5 ├── ludicorp.z5 ├── moonlit.z5 ├── omniquest.z5 ├── pentari.z5 ├── reverb.z5 ├── snacktime.z8 ├── sorcerer.z3 ├── spellbrkr.z3 ├── spirit.z5 ├── temple.z5 ├── zenon.z5 ├── zork1.z5 ├── zork3.z5 └── ztuu.z5 ├── models ├── __init__.py └── drrn │ ├── drrn.py │ └── drrn_inv_dyn.py ├── scripts ├── run_drrn.sh ├── run_inv_dy.sh ├── run_xtx.sh ├── run_xtx_ablation_det.sh ├── run_xtx_no_mix.sh ├── run_xtx_uniform.sh └── train_rl.py ├── trainers ├── __init__.py ├── drrn │ ├── drrn_graph_inv_dyn_trainer.py │ ├── drrn_inv_dyn_trainer.py │ └── drrn_trainer.py └── trainer.py ├── utils ├── drrn.py ├── env.py ├── il_buffer.py ├── inv_dyn.py ├── logger.py ├── memory.py ├── ngram.py ├── util.py └── vec_env.py └── yml_envs ├── jericho-no-wt.yml ├── jericho-wt.yml └── zork1-environment.yml /.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | run.slurm 3 | wandb/ 4 | *.swp 5 | *.txt 6 | logs/ 7 | .vscode/ 8 | saved_objects/ 9 | transformer-*.txt 10 | drrn_tf*.txt 11 | tf_ce*.txt 12 | evaluate.sh 13 | run.sh 14 | saved_models/ 15 | figures/ 16 | aws_secrets.json 17 | profiling/ 18 | wandb_downloads/ 19 | node_modules/ 20 | graph_viz/graphs/ 21 | local_files 22 | 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # celery beat schedule file 101 | celerybeat-schedule 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # XTX: eXploit - Then - eXplore 2 | 3 | **Project page:** https://sites.google.com/princeton.edu/xtx 4 | 5 | ## Requirements 6 | First, clone this repo using `git clone https://github.com/princeton-nlp/XTX.git`. 7 | 8 | Please create two conda environments as follows: 9 | 1. `conda env create -f yml_envs/jericho-wt.yml` 10 | a. `conda activate jericho-wt` 11 | b. `pip install git+https://github.com/jens321/jericho.git@iclr` 12 | 2. `conda env create -f yml_envs/jericho-no-wt.yml` 13 | 14 | The first set of commands creates a conda environment called `jericho-wt`, which installs a modified Jericho version that adds extra actions to the game grammar for specific games (games marked with * in the paper). The second command creates another conda environment called `jericho-no-wt`, which installs an unmodified version of the Jericho library. 15 | 16 | ## Training 17 | All code can be run from the root folder of this project. Use the commands below to train each model: 18 | - XTX: `sh scripts/run_xtx.sh` 19 | - XTX (no-mix): `sh scripts/run_xtx_no_mix.sh` 20 | - XTX (uniform): `sh scripts/run_xtx_uniform.sh` 21 | - XTX ($\lambda$ = 0, 0.5, or 1): `sh scripts/run_xtx_ablation_det.sh` 22 | - INV DY: `sh scripts/run_inv_dy.sh` 23 | - DRRN: `sh scripts/run_drrn.sh` 24 | 25 | ### Notes 26 | - You can use `analysis/sample_env.py` to quickly play around with a sample Jericho environment. Run it using `python3 -m analysis.sample_env`. 27 | 28 | - You can use `analysis/augment_wt.py` to generate the missing action candidates that can be added to the game grammar (games with * in the paper). Run it using `python3 -m analysis.augment_wt`. 29 | 30 | - Note that all models should finish within a day or two given 1 GPU and 8 CPUs, except for games where Jericho's valid action handicap is slow (e.g. Library, Dragon). Since Jericho's valid action handicap heavily relies on parallelization, increasing the number of CPUs also results in good speedups (e.g. 8 -> 16). 31 | 32 | ## Acknowledgements 33 | We used [Weights & Biases](https://wandb.ai/home) for experiment tracking and visualizations to develop insights for this paper. 34 | 35 | Some of the code borrows from the [TDQN](https://github.com/microsoft/tdqn) repo.
36 | 37 | For any questions please contact Jens Tuyls (`jtuyls@princeton.edu`). 38 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- 1 | from agents.drrn.drrn_agent import DrrnAgent 2 | from agents.drrn.drrn_inv_dyn_agent import DrrnInvDynAgent 3 | from agents.drrn.drrn_graph_inv_dyn_agent import DrrnGraphInvDynAgent 4 | -------------------------------------------------------------------------------- /agents/drrn/drrn_agent.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | import logging 3 | import traceback 4 | import pickle 5 | from os.path import join as pjoin 6 | from typing import Dict, Union, List, Callable 7 | 8 | # Libraries 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from transformers import BertTokenizer 13 | import wandb 14 | 15 | from jericho.util import clean 16 | 17 | # Custom imports 18 | import utils.logger as logger 19 | from utils.memory import PrioritizedReplayMemory, Transition, StateWithActs, State 20 | from utils.env import JerichoEnv 21 | import utils.ngram as Ngram 22 | 23 | from models import DrrnQNetwork 24 | 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | 28 | class DrrnAgent: 29 | def __init__( 30 | self, 31 | tb: logger.Logger, 32 | log: Callable[..., None], 33 | args: Dict[str, Union[str, int, float]], 34 | envs: List[JerichoEnv], 35 | action_models 36 | ): 37 | self.gamma = args.gamma 38 | self.batch_size = args.batch_size 39 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 40 | self.network = DrrnQNetwork( 41 | tb=tb, 42 | log=log, 43 | vocab_size=len(self.tokenizer), 44 | envs=envs, 45 | action_models=action_models, 46 | tokenizer=self.tokenizer, 47 | args=args 48 | ).to(device) 49 | self.target_network = DrrnQNetwork( 50 | tb=tb, 51 | log=log, 52 | vocab_size=len(self.tokenizer), 53 | envs=envs, 54 | action_models=action_models, 55 | tokenizer=self.tokenizer, 56 | args=args 57 | ).to(device) 58 | self.target_network.eval() 59 | # if args.wandb: 60 | # wandb.watch(self.network, log='all') 61 | 62 | self.memory = PrioritizedReplayMemory(args.memory_size, 63 | args.priority_fraction) 64 | self.clip = args.clip 65 | self.tb = tb 66 | self.log = log 67 | 68 | self.optimizer = torch.optim.Adam(self.network.parameters(), 69 | lr=args.learning_rate) 70 | self.action_models = action_models 71 | self.max_acts = args.max_acts 72 | self.envs = envs 73 | 74 | def observe(self, transition, is_prior=False): 75 | """ 76 | Push to replay memory. 77 | """ 78 | self.memory.push(transition, is_prior=is_prior) 79 | 80 | def build_skip_state(self, ob: str, info: Dict[str, Union[List[str], float, str]], next_action_str: str, traj_acts: List[str]) -> StateWithActs: 81 | """Returns a state representation built from various info sources. 82 | 83 | Args: 84 | ob (str): the observation. 85 | info (Dict[str, Union[List[str], float, str]]): info dict. 86 | traj_acts (List[str]): past actions. 87 | next_action_str (str): current action. 88 | 89 | Returns: 90 | StateWithActs: state representation. 
91 | """ 92 | acts = Ngram.build_traj_state(self, next_action_str, traj_acts) 93 | 94 | obs_ids = self.tokenizer.encode(ob) 95 | look_ids = self.tokenizer.encode(info['look']) 96 | inv_ids = self.tokenizer.encode(info['inv']) 97 | return StateWithActs(obs_ids, look_ids, inv_ids, acts, info['score']) 98 | 99 | def build_state(self, ob: str, info: Dict[str, Union[List[str], float, str]]) -> State: 100 | """Returns a state representation built from various info sources. 101 | 102 | Args: 103 | ob (str): the observation. 104 | info (Dict[str, Union[List[str], float, str]]): info dict. 105 | 106 | Returns: 107 | State: state representation. 108 | """ 109 | obs_ids = self.tokenizer.encode(ob) 110 | look_ids = self.tokenizer.encode(info['look']) 111 | inv_ids = self.tokenizer.encode(info['inv']) 112 | 113 | return State(obs_ids, look_ids, inv_ids, info['score']) 114 | 115 | def build_states( 116 | self, 117 | obs: List[str], 118 | infos: List[Dict[str, Union[List[str], float, str]]], 119 | action_strs: List[str] = None, 120 | traj_acts: List[List[str]] = None 121 | ) -> Union[List[State], List[StateWithActs]]: 122 | """Build list of state representations. 123 | 124 | Args: 125 | obs (List[str]): list of observations per env. 126 | infos (List[Dict[str, Union[List[str], float, str]]]): list of info dicts per env. 127 | action_strs (List[str], optional): list of current action strings per env. Defaults to None. 128 | traj_acts (List[List[str]], optional): list of past action strings per env. Defaults to None. 129 | 130 | Returns: 131 | Union[List[State], List[StateWithActs]]: list of state representations. 132 | """ 133 | if action_strs is None and traj_acts is None: 134 | return [self.build_state(ob, info) for ob, info in zip(obs, infos)] 135 | else: 136 | return [self.build_skip_state(ob, info, action_str, traj_act) for ob, info, action_str, traj_act in zip(obs, infos, action_strs, traj_acts)] 137 | 138 | def encode(self, obs: List[str]): 139 | """ 140 | Encode a list of strings with [SEP] at the end. 141 | """ 142 | return [self.tokenizer.encode(o) for o in obs] 143 | 144 | def transfer_weights(self): 145 | """ 146 | TODO 147 | """ 148 | self.target_network.load_state_dict(self.network.state_dict()) 149 | 150 | def act(self, states, poss_acts, poss_act_strs, sample=True): 151 | """ 152 | Parameters 153 | ---------- 154 | poss_acts: [ 155 | [[], [], ...], 156 | ... (* number of env) 157 | ] 158 | 159 | Returns 160 | ------- 161 | act_ids: the action IDs of the chosen action per env 162 | [ 163 | [, , ...], 164 | ... (* number of env) 165 | ] 166 | idxs: index of the chosen action per env 167 | [ 168 | , 169 | ... (* number of env) 170 | ] 171 | qvals: tuple of qvals per valid action set (i.e. per env) 172 | ( 173 | [, , ...], 174 | ... 
(* number of env) 175 | ) 176 | """ 177 | # Idxs: indices of the sampled (from the Q-vals) actions 178 | idxs, qvals = self.network.act(states, poss_acts, poss_act_strs) 179 | 180 | # Get the sampled action for each environment 181 | act_ids = [poss_acts[batch][idx] for batch, idx in enumerate(idxs)] 182 | return act_ids, idxs, qvals 183 | 184 | def act_topk(self, states, poss_acts): 185 | """ 186 | """ 187 | idxs = self.network.act_topk(states, poss_acts) 188 | 189 | return idxs 190 | 191 | def update(self): 192 | if len(self.memory) < self.batch_size: 193 | return 194 | 195 | transitions = self.memory.sample(self.batch_size) 196 | batch = Transition(*zip(*transitions)) 197 | 198 | # Compute Q(s', a') for all a' 199 | with torch.no_grad(): 200 | next_qvals = self.target_network(batch.next_state, batch.next_acts) 201 | 202 | # Take the max over next q-values 203 | next_qvals = torch.tensor([vals.max() for vals in next_qvals], 204 | device=device) 205 | 206 | # Zero all the next_qvals that are done 207 | next_qvals = next_qvals * ( 208 | 1 - torch.tensor(batch.done, dtype=torch.float, device=device)) 209 | targets = torch.tensor(batch.reward, dtype=torch.float, 210 | device=device) + self.gamma * next_qvals 211 | 212 | # Next compute Q(s, a) 213 | act_sizes = [1 for act in batch.act] 214 | 215 | nested_acts = tuple([[a] for a in batch.act]) 216 | qvals = self.network(batch.state, nested_acts) 217 | qvals = torch.cat(qvals) 218 | 219 | # Compute Huber loss 220 | loss = F.smooth_l1_loss(qvals, targets.detach()) 221 | 222 | self.tb.logkv_mean('Q', qvals.mean()) 223 | 224 | # Backprop 225 | self.optimizer.zero_grad() 226 | loss.backward() 227 | nn.utils.clip_grad_norm_(self.network.parameters(), self.clip) 228 | self.optimizer.step() 229 | 230 | return loss.item() 231 | 232 | def load(self, run_id: str, weight_file: str, memory_file: str): 233 | try: 234 | api = wandb.Api() 235 | run = api.run(f"princeton-nlp/text-games/{run_id}") 236 | run.file(f"{weight_file}.pt").download(wandb.run.dir) 237 | run.file(f"{memory_file}.pkl").download(wandb.run.dir) 238 | 239 | self.memory = pickle.load( 240 | open(pjoin(wandb.run.dir, f"{memory_file}.pkl"), 'rb')) 241 | self.network.load_state_dict( 242 | torch.load(pjoin(wandb.run.dir, f"{weight_file}.pt"))) 243 | except Exception as e: 244 | self.log(f"Error loading model {e}") 245 | logging.error(traceback.format_exc()) 246 | raise Exception("Didn't properly load model!") 247 | 248 | def load_memory(self, run_id: str, memory_file: str): 249 | try: 250 | api = wandb.Api() 251 | run = api.run(f"princeton-nlp/text-games/{run_id}") 252 | run.file(f"{memory_file}.pkl").download(wandb.run.dir) 253 | 254 | self.memory = pickle.load( 255 | open(pjoin(wandb.run.dir, f"{memory_file}.pkl"), 'rb')) 256 | except Exception as e: 257 | self.log(f"Error loading replay memory {e}") 258 | logging.error(traceback.format_exc()) 259 | raise Exception("Didn't properly load replay memory!") 260 | 261 | def save(self, step: int, traj: List = None): 262 | try: 263 | # save locally 264 | pickle.dump( 265 | self.memory, 266 | open(pjoin(wandb.run.dir, 'memory_{}.pkl'.format(step)), 'wb')) 267 | torch.save(self.network.state_dict(), 268 | pjoin(wandb.run.dir, 'weights_{}.pt'.format(step))) 269 | 270 | if traj is not None: 271 | pickle.dump( 272 | traj, open( 273 | pjoin(wandb.run.dir, 'traj_{}.pkl'.format(step)), 'wb') 274 | ) 275 | wandb.save(pjoin(wandb.run.dir, 'traj_{}.pkl'.format(step))) 276 | 277 | # upload to wandb 278 | wandb.save(pjoin(wandb.run.dir, 
'weights_{}.pt'.format(step))) 279 | wandb.save(pjoin(wandb.run.dir, 'memory_{}.pkl'.format(step))) 280 | except Exception as e: 281 | print("Error saving model.") 282 | logging.error(traceback.format_exc()) 283 | -------------------------------------------------------------------------------- /agents/drrn/drrn_graph_inv_dyn_agent.py: -------------------------------------------------------------------------------- 1 | # Libraries 2 | import torch 3 | 4 | # Custom Imports 5 | from agents import DrrnInvDynAgent 6 | from utils.il_buffer import ILBuffer 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | 11 | class DrrnGraphInvDynAgent(DrrnInvDynAgent): 12 | def __init__(self, args, tb, log, envs, action_models): 13 | super().__init__(args, action_models, tb, log, envs) 14 | 15 | if args.use_il_buffer_sampler: 16 | self.il_buffer = ILBuffer(self, args, log, tb) 17 | self.graph_policies = [None] * len(envs) 18 | self.fell_off_trajectory = [False for _ in range(len(envs))] 19 | self.use_il = args.use_il 20 | 21 | def act(self, states, poss_acts, poss_act_strs, sample=True): 22 | """ 23 | Parameters 24 | ---------- 25 | poss_acts: [ 26 | [[], [], ...], 27 | ... (* number of env) 28 | ] 29 | 30 | Returns 31 | ------- 32 | act_ids: the action IDs of the chosen action per env 33 | [ 34 | [, , ...], 35 | ... (* number of env) 36 | ] 37 | idxs: index of the chosen action per env 38 | [ 39 | , 40 | ... (* number of env) 41 | ] 42 | qvals: tuple of qvals per valid action set (i.e. per env) 43 | ( 44 | [, , ...], 45 | ... (* number of env) 46 | ) 47 | """ 48 | # Idxs: indices of the sampled (from the Q-vals) actions 49 | idxs, qvals = self.network.act( 50 | states, poss_acts, poss_act_strs) 51 | 52 | # Get the sampled action for each environment 53 | act_ids = [poss_acts[batch][idx] for batch, idx in enumerate(idxs)] 54 | return act_ids, idxs, qvals -------------------------------------------------------------------------------- /agents/drrn/drrn_inv_dyn_agent.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | import pickle 3 | from os.path import join as pjoin 4 | import logging 5 | from typing import List 6 | import traceback 7 | 8 | # Libraries 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import wandb 13 | 14 | # Custom Imports 15 | from utils.memory import ABReplayMemory, Transition, StateWithActs, State 16 | import utils.ngram as Ngram 17 | import utils.inv_dyn as InvDyn 18 | 19 | from agents import DrrnAgent 20 | 21 | from models import DrrnInvDynQNetwork 22 | 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | 26 | class DrrnInvDynAgent(DrrnAgent): 27 | def __init__(self, args, action_models, tb, log, envs): 28 | super().__init__(tb, log, args, envs, action_models) 29 | 30 | self.network = DrrnInvDynQNetwork( 31 | len(self.tokenizer), 32 | args, 33 | envs, 34 | self.tokenizer, 35 | action_models, 36 | tb, 37 | log 38 | ).to(device) 39 | 40 | self.target_network = DrrnInvDynQNetwork( 41 | len(self.tokenizer), 42 | args, 43 | envs, 44 | self.tokenizer, 45 | action_models, 46 | tb, 47 | log 48 | ).to(device) 49 | self.target_network.eval() 50 | self.network.tokenizer = self.tokenizer 51 | 52 | self.memory = ABReplayMemory(args.memory_size, args.memory_alpha) 53 | 54 | self.optimizer = torch.optim.Adam(self.network.parameters(), 55 | lr=args.learning_rate) 56 | 57 | # Inverse Dynamics Stuff 58 | self.type_inv = args.type_inv 59 | 
self.type_for = args.type_for 60 | self.w_inv = args.w_inv 61 | self.w_for = args.w_for 62 | self.w_act = args.w_act 63 | self.perturb = args.perturb 64 | self.act_obs = args.act_obs 65 | 66 | def build_skip_state(self, ob, info, action_str: str, traj_act: List[str]): 67 | """ Returns a state representation built from various info sources. """ 68 | if self.act_obs: 69 | acts = self.encode(info['valid']) 70 | obs_ids, look_ids, inv_ids = [], [], [] 71 | for act in acts: 72 | obs_ids += act 73 | return State(obs_ids, look_ids, inv_ids) 74 | obs_ids = self.tokenizer.encode(ob) 75 | look_ids = self.tokenizer.encode(info['look']) 76 | inv_ids = self.tokenizer.encode(info['inv']) 77 | 78 | acts = Ngram.build_traj_state(self, action_str, traj_act) 79 | 80 | return StateWithActs(obs_ids, look_ids, inv_ids, acts, info['score']) 81 | 82 | def build_state(self, ob, info): 83 | """ Returns a state representation built from various info sources. """ 84 | if self.act_obs: 85 | acts = self.encode(info['valid']) 86 | obs_ids, look_ids, inv_ids = [], [], [] 87 | for act in acts: 88 | obs_ids += act 89 | return State(obs_ids, look_ids, inv_ids) 90 | obs_ids = self.tokenizer.encode(ob) 91 | look_ids = self.tokenizer.encode(info['look']) 92 | inv_ids = self.tokenizer.encode(info['inv']) 93 | 94 | return State(obs_ids, look_ids, inv_ids, info['score']) 95 | 96 | def q_loss(self, transitions, need_qvals=False): 97 | batch = Transition(*zip(*transitions)) 98 | 99 | # Compute Q(s', a') for all a' 100 | # TODO: Use a target network??? 101 | with torch.no_grad(): 102 | next_qvals = self.target_network(batch.next_state, batch.next_acts) 103 | # Take the max over next q-values 104 | next_qvals = torch.tensor([vals.max() 105 | for vals in next_qvals], device=device) 106 | # Zero all the next_qvals that are done 107 | next_qvals = next_qvals * \ 108 | (1-torch.tensor(batch.done, dtype=torch.float, device=device)) 109 | targets = torch.tensor( 110 | batch.reward, dtype=torch.float, device=device) + self.gamma * next_qvals 111 | 112 | # Next compute Q(s, a) 113 | # Nest each action in a list - so that it becomes the only admissible cmd 114 | nested_acts = tuple([[a] for a in batch.act]) 115 | qvals = self.network(batch.state, nested_acts) 116 | # Combine the qvals: Maybe just do a greedy max for generality 117 | qvals = torch.cat(qvals) 118 | loss = F.smooth_l1_loss(qvals, targets.detach()) 119 | 120 | return (loss, qvals) if need_qvals else loss 121 | 122 | def update(self): 123 | if len(self.memory) < self.batch_size: 124 | return None 125 | 126 | transitions = self.memory.sample(self.batch_size) 127 | batch = Transition(*zip(*transitions)) 128 | nested_acts = tuple([[a] for a in batch.act]) 129 | terms, loss = {}, 0 130 | 131 | # Compute Q learning Huber loss 132 | terms['Loss_q'], qvals = self.q_loss(transitions, need_qvals=True) 133 | loss += terms['Loss_q'] 134 | 135 | # Compute Inverse dynamics loss 136 | if self.w_inv > 0: 137 | if self.type_inv == 'decode': 138 | terms['Loss_id'], terms['Acc_id'] = InvDyn.inv_loss_decode(self.network, 139 | batch.state, batch.next_state, nested_acts, hat=True) 140 | elif self.type_inv == 'ce': 141 | terms['Loss_id'], terms['Acc_id'] = InvDyn.inv_loss_ce(self.network, 142 | batch.state, batch.next_state, nested_acts, batch.acts) 143 | else: 144 | raise NotImplementedError 145 | loss += self.w_inv * terms['Loss_id'] 146 | 147 | # Compute Act reconstruction loss 148 | if self.w_act > 0: 149 | terms['Loss_act'], terms['Acc_act'] = InvDyn.inv_loss_decode(self.network, 150 | 
batch.state, batch.next_state, nested_acts, hat=False) 151 | loss += self.w_act * terms['Loss_act'] 152 | 153 | # Compute Forward dynamics loss 154 | if self.w_for > 0: 155 | if self.type_for == 'l2': 156 | terms['Loss_fd'] = InvDyn.for_loss_l2(self.network, 157 | batch.state, batch.next_state, nested_acts) 158 | elif self.type_for == 'ce': 159 | terms['Loss_fd'], terms['Acc_fd'] = InvDyn.for_loss_ce(self.network, 160 | batch.state, batch.next_state, nested_acts, batch.acts) 161 | elif self.type_for == 'decode': 162 | terms['Loss_fd'], terms['Acc_fd'] = InvDyn.for_loss_decode(self.network, 163 | batch.state, batch.next_state, nested_acts, hat=True) 164 | elif self.type_for == 'decode_obs': 165 | terms['Loss_fd'], terms['Acc_fd'] = InvDyn.for_loss_decode(self.network, 166 | batch.state, batch.next_state, nested_acts, hat=False) 167 | 168 | loss += self.w_for * terms['Loss_fd'] 169 | 170 | # Backward 171 | terms.update({'Loss': loss, 'Q': qvals.mean()}) 172 | self.optimizer.zero_grad() 173 | loss.backward() 174 | nn.utils.clip_grad_norm_(self.network.parameters(), self.clip) 175 | self.optimizer.step() 176 | return {k: float(v) for k, v in terms.items()} 177 | -------------------------------------------------------------------------------- /agents/random_agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import sentencepiece as spm 4 | 5 | class RandomAgent(): 6 | def __init__(self, args): 7 | self.sp = spm.SentencePieceProcessor() 8 | self.sp.Load(args.spm_path) 9 | 10 | def act(self, valid_ids): 11 | """ 12 | """ 13 | # print(valid_ids) 14 | return [np.random.randint(len(valid_acts)) for valid_acts in valid_ids] 15 | 16 | def act_topk(self, valid_ids): 17 | """ 18 | """ 19 | return [np.random.permutation(len(valid_acts)) for valid_acts in valid_ids] 20 | 21 | def encode(self, obs:[str]): 22 | """ 23 | Encode a list of strings 24 | """ 25 | return [self.sp.EncodeAsIds(o) for o in obs] -------------------------------------------------------------------------------- /analysis/augment_wt.py: -------------------------------------------------------------------------------- 1 | # Built-in imports 2 | from typing import List 3 | 4 | # Libraries 5 | from tqdm import tqdm 6 | 7 | # Custom imports 8 | from utils.util import inv_process_action, process_action, load_object 9 | from utils.env import JerichoEnv 10 | 11 | from scripts.train_rl import parse_args 12 | 13 | GAME_DIR = "./games" 14 | 15 | def is_act_missing_in_wt(candidates: List[str], act: str, env): 16 | """ 17 | Check if wt action is in Jericho candidates. 18 | """ 19 | state = env.get_state() 20 | env.step(act) 21 | wt_diff = str(env._get_world_diff()) 22 | env.set_state(state) 23 | if wt_diff == '((), (), (), ())': 24 | return False 25 | # return process_action(act) not in list(map(lambda x: process_action(x), candidates)) 26 | 27 | candidate_diffs = [] 28 | for candidate in candidates: 29 | env.set_state(state) 30 | env.step(candidate) 31 | candidate_diffs.append(str(env._get_world_diff())) 32 | env.set_state(state) 33 | 34 | gold_acts = [] 35 | for can_diff, can_act in zip(candidate_diffs, candidates): 36 | if can_diff == wt_diff: 37 | gold_acts.append(can_act) 38 | break 39 | 40 | return len(gold_acts) == 0 41 | 42 | 43 | def get_missing_wt_acts(game: str): 44 | """ 45 | Get missing wt acts. 
46 | """ 47 | 48 | cache = dict() 49 | args = parse_args() 50 | env = JerichoEnv("{}/{}".format(GAME_DIR, game), cache=cache, args=args) 51 | ob, info = env.reset() 52 | 53 | missing_wt = set() 54 | walkthrough = env.get_walkthrough() 55 | for i, act in tqdm(enumerate(walkthrough), desc='Getting missing wt acts ...'): 56 | candidates = info['valid'] 57 | missing = is_act_missing_in_wt(candidates, act, env.env) 58 | if missing: 59 | missing_wt = missing_wt.union( 60 | {process_action(act), inv_process_action(act)}) 61 | 62 | next_ob, reward, done, info = env.step(act) 63 | 64 | ob = next_ob 65 | 66 | print("Cache hits: {}".format(env.cache_hits)) 67 | return missing_wt 68 | 69 | def main(): 70 | games = ["zork1.z5"] 71 | for game in games: 72 | missing_acts = get_missing_wt_acts(game) 73 | with open('./missing_wt_acts_{}.txt'.format(game), 'w') as f: 74 | for act in missing_acts: 75 | f.write('{};'.format(act)) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /analysis/sample_env.py: -------------------------------------------------------------------------------- 1 | # Built-in imports 2 | import random 3 | 4 | # Libraries 5 | import numpy as np 6 | from jericho import * 7 | 8 | GAME_DIR = "./games" 9 | 10 | def main(): 11 | """Simple setup to be able to easily play around with Jericho env. 12 | """ 13 | seed = 1 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | env = FrotzEnv("{}/{}".format(GAME_DIR, 'zork1.z5'), seed=seed) 17 | obs, info = env.reset() 18 | print(obs) 19 | print("Valid acts:", env.get_valid_actions()) 20 | total = info['score'] 21 | 22 | i = 0 23 | done = False 24 | buffer = env.get_walkthrough() 25 | while True: 26 | if i >= len(buffer): 27 | act = input() 28 | else: 29 | act = buffer[i].strip() 30 | 31 | observation, reward, done, info = env.step(act) 32 | 33 | # state = env.get_state() 34 | # inv, _, _, _ = env.step('inventory') 35 | # env.set_state(state) 36 | 37 | # state = env.get_state() 38 | # loc, _, _, _ = env.step('look') 39 | # env.set_state(state) 40 | 41 | print("(reward: {})".format(reward) if reward > 0 else "") 42 | 43 | 44 | print("Action: {}".format(act)) 45 | print("Reward: {}".format(reward)) 46 | print("Obs: {}".format(observation)) 47 | print("Valid acts: {}".format(env.get_valid_actions())) 48 | 49 | total += reward 50 | i += 1 51 | 52 | print("SCORE: {}".format(total)) 53 | 54 | if __name__ == "__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /definitions/defs.py: -------------------------------------------------------------------------------- 1 | GPT = 'gpt' 2 | NGRAM = 'ngram' 3 | TRANSFORMER = 'transformer' 4 | 5 | DRRN = 'drrn' 6 | INV_DY = 'inv_dy' 7 | XTX = 'xtx' 8 | 9 | UNIFORM = 'uniform' 10 | SOFTMAX_LM = 'softmax_lm' 11 | UNIFORM_LM_TOPK = 'uniform_lm_topk' 12 | 13 | OBJECTS_DIR = './saved_objects' 14 | SAVED_MODEL_DIR = "./saved_models" 15 | ANALYSIS_DIR = "./trajectory_analysis" 16 | GAME_DIR = "./games" -------------------------------------------------------------------------------- /games/905.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/905.z5 -------------------------------------------------------------------------------- /games/acorncourt.z5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/acorncourt.z5 -------------------------------------------------------------------------------- /games/advent.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/advent.z5 -------------------------------------------------------------------------------- /games/adventureland.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/adventureland.z5 -------------------------------------------------------------------------------- /games/anchor.z8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/anchor.z8 -------------------------------------------------------------------------------- /games/awaken.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/awaken.z5 -------------------------------------------------------------------------------- /games/balances.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/balances.z5 -------------------------------------------------------------------------------- /games/deephome.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/deephome.z5 -------------------------------------------------------------------------------- /games/detective.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/detective.z5 -------------------------------------------------------------------------------- /games/dragon.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/dragon.z5 -------------------------------------------------------------------------------- /games/enchanter.z3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/enchanter.z3 -------------------------------------------------------------------------------- /games/inhumane.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/inhumane.z5 -------------------------------------------------------------------------------- /games/jewel.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/jewel.z5 -------------------------------------------------------------------------------- /games/karn.z5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/karn.z5 -------------------------------------------------------------------------------- /games/library.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/library.z5 -------------------------------------------------------------------------------- /games/ludicorp.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/ludicorp.z5 -------------------------------------------------------------------------------- /games/moonlit.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/moonlit.z5 -------------------------------------------------------------------------------- /games/omniquest.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/omniquest.z5 -------------------------------------------------------------------------------- /games/pentari.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/pentari.z5 -------------------------------------------------------------------------------- /games/reverb.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/reverb.z5 -------------------------------------------------------------------------------- /games/snacktime.z8: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/snacktime.z8 -------------------------------------------------------------------------------- /games/sorcerer.z3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/sorcerer.z3 -------------------------------------------------------------------------------- /games/spellbrkr.z3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/spellbrkr.z3 -------------------------------------------------------------------------------- /games/spirit.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/spirit.z5 -------------------------------------------------------------------------------- /games/temple.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/temple.z5 -------------------------------------------------------------------------------- /games/zenon.z5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/zenon.z5 -------------------------------------------------------------------------------- /games/zork1.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/zork1.z5 -------------------------------------------------------------------------------- /games/zork3.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/zork3.z5 -------------------------------------------------------------------------------- /games/ztuu.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/XTX/79f5ed80459bdc998edbb0b51160f912bf08c80f/games/ztuu.z5 -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.drrn.drrn import DrrnQNetwork 2 | from models.drrn.drrn_inv_dyn import DrrnInvDynQNetwork 3 | -------------------------------------------------------------------------------- /models/drrn/drrn.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | import itertools 3 | from typing import Callable, Dict, Union, List 4 | 5 | # Libraries 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # Custom Imports 11 | from utils import logger 12 | import utils.ngram as Ngram 13 | import utils.drrn as Drrn 14 | from utils.memory import State, StateWithActs 15 | from utils.env import JerichoEnv 16 | 17 | 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | class DrrnQNetwork(nn.Module): 22 | def __init__(self, 23 | tb: logger.Logger, 24 | log: Callable[..., None], 25 | vocab_size: int, 26 | envs: List[JerichoEnv], 27 | action_models, 28 | tokenizer, 29 | args: Dict[str, Union[str, int, float]]): 30 | super(DrrnQNetwork, self).__init__() 31 | self.sample_argmax = args.sample_argmax 32 | self.sample_uniform = args.sample_uniform 33 | self.envs = envs 34 | 35 | self.log = log 36 | self.tb = tb 37 | 38 | Drrn.init_model(self, args, vocab_size, tokenizer) 39 | 40 | self.use_action_model = args.use_action_model 41 | if self.use_action_model: 42 | Ngram.init_model(self, action_models, args) 43 | 44 | def forward(self, state_batch, act_batch): 45 | 46 | # Zip the state_batch into an easy access format 47 | if self.use_action_model: 48 | state = StateWithActs(*zip(*state_batch)) 49 | else: 50 | state = State(*zip(*state_batch)) 51 | act_sizes = [len(a) for a in act_batch] 52 | # Combine next actions into one long list 53 | act_batch = list(itertools.chain.from_iterable(act_batch)) 54 | act_out = Drrn.packed_rnn(self, act_batch, self.act_encoder) 55 | # Encode the various aspects of the state 56 | obs_out = Drrn.packed_rnn(self, state.obs, self.obs_encoder) 57 | look_out = Drrn.packed_rnn(self, state.description, self.look_encoder) 58 | inv_out = Drrn.packed_rnn(self, state.inventory, self.inv_encoder) 59 | state_out = torch.cat((obs_out, look_out, inv_out), dim=1) 60 | # Expand the state to match the batches of actions 61 | state_out = torch.cat( 62 | [state_out[i].repeat(j, 1) for i, j in enumerate(act_sizes)], 63 | dim=0) 64 | 65 | z = 
torch.cat((state_out, act_out), dim=1) # Concat along hidden_dim 66 | z = F.relu(self.hidden(z)) 67 | drrn_scores = self.act_scorer(z).squeeze(-1) 68 | 69 | # Split up the q-values by batch 70 | return drrn_scores.split(act_sizes) 71 | 72 | @torch.no_grad() 73 | def act( 74 | self, 75 | states: List[Union[State, StateWithActs]], 76 | valid_ids: List[List[List[int]]], 77 | valid_strs: List[List[str]], 78 | graph_masks=None 79 | ): 80 | """ 81 | Returns an action-string, optionally sampling from the distribution 82 | of Q-Values. 83 | """ 84 | return Drrn.act(self, states, valid_ids, valid_strs, self.log, graph_masks) 85 | -------------------------------------------------------------------------------- /models/drrn/drrn_inv_dyn.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | from typing import Dict, Union, Callable, List 3 | 4 | # Libraries 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | # Custom Imports 10 | from utils.vec_env import VecEnv 11 | from utils.memory import State, StateWithActs 12 | import utils.logger as logger 13 | import utils.ngram as Ngram 14 | import utils.inv_dyn as InvDyn 15 | 16 | from models.drrn.drrn import DrrnQNetwork 17 | 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | class DrrnInvDynQNetwork(DrrnQNetwork): 22 | """ 23 | Deep Reinforcement Relevance Network - He et al. '16 24 | """ 25 | 26 | def __init__( 27 | self, 28 | vocab_size: int, 29 | args: Dict[str, Union[str, int, float]], 30 | envs: VecEnv, 31 | tokenizer, 32 | action_models, 33 | tb: logger.Logger = None, 34 | log: Callable[..., None] = None 35 | ): 36 | super().__init__(tb, log, vocab_size, envs, action_models, tokenizer, args) 37 | self.embedding = nn.Embedding(vocab_size, args.drrn_embedding_dim) 38 | self.obs_encoder = nn.GRU( 39 | args.drrn_embedding_dim, args.drrn_hidden_dim) 40 | self.look_encoder = nn.GRU( 41 | args.drrn_embedding_dim, args.drrn_hidden_dim) 42 | self.inv_encoder = nn.GRU( 43 | args.drrn_embedding_dim, args.drrn_hidden_dim) 44 | self.act_encoder = nn.GRU( 45 | args.drrn_embedding_dim, args.drrn_hidden_dim) 46 | self.act_scorer = nn.Linear(args.drrn_hidden_dim, 1) 47 | 48 | self.drrn_hidden_dim = args.drrn_hidden_dim 49 | self.hidden = nn.Linear(2 * args.drrn_hidden_dim, args.drrn_hidden_dim) 50 | # self.hidden = nn.Sequential(nn.Linear(2 * args.drrn_hidden_dim, 2 * args.drrn_hidden_dim), nn.Linear(2 * args.drrn_hidden_dim, args.drrn_hidden_dim), nn.Linear(args.drrn_hidden_dim, args.drrn_hidden_dim)) 51 | 52 | self.state_encoder = nn.Linear( 53 | 3 * args.drrn_hidden_dim + (1 if self.augment_state_with_score else 0), args.drrn_hidden_dim) 54 | self.inverse_dynamics = nn.Sequential(nn.Linear( 55 | 2 * args.drrn_hidden_dim, 2 * args.drrn_hidden_dim), nn.ReLU(), nn.Linear(2 * args.drrn_hidden_dim, args.drrn_hidden_dim)) 56 | self.forward_dynamics = nn.Sequential(nn.Linear( 57 | 2 * args.drrn_hidden_dim, 2 * args.drrn_hidden_dim), nn.ReLU(), nn.Linear(2 * args.drrn_hidden_dim, args.drrn_hidden_dim)) 58 | 59 | self.act_decoder = nn.GRU( 60 | args.drrn_hidden_dim, args.drrn_embedding_dim) 61 | self.act_fc = nn.Linear(args.drrn_embedding_dim, vocab_size) 62 | 63 | self.obs_decoder = nn.GRU( 64 | args.drrn_hidden_dim, args.drrn_embedding_dim) 65 | self.obs_fc = nn.Linear(args.drrn_embedding_dim, vocab_size) 66 | 67 | self.fix_rep = args.fix_rep 68 | self.hash_rep = args.hash_rep 69 | self.act_obs = args.act_obs 70 | self.hash_cache = {} 71 | 72 
| self.use_action_model = args.use_action_model 73 | if self.use_action_model: 74 | Ngram.init_model(self, action_models, args) 75 | 76 | def forward(self, state_batch: List[Union[State, StateWithActs]], act_batch): 77 | """ 78 | Batched forward pass. 79 | obs_id_batch: iterable of unpadded sequence ids 80 | act_batch: iterable of lists of unpadded admissible command ids 81 | Returns a tuple of tensors containing q-values for each item in the batch 82 | """ 83 | state_out = InvDyn.state_rep(self, state_batch) 84 | act_sizes, act_out = InvDyn.act_rep(self, act_batch) 85 | # Expand the state to match the batches of actions 86 | state_out = torch.cat([state_out[i].repeat(j, 1) 87 | for i, j in enumerate(act_sizes)], dim=0) 88 | z = torch.cat((state_out, act_out), dim=1) # Concat along hidden_dim 89 | z = F.relu(self.hidden(z)) 90 | act_values = self.act_scorer(z).squeeze(-1) 91 | # Split up the q-values by batch 92 | return act_values.split(act_sizes) 93 | -------------------------------------------------------------------------------- /scripts/run_drrn.sh: -------------------------------------------------------------------------------- 1 | LOG_FOLDER='drrn_zork1' 2 | GAME='zork1.z5' 3 | SEED=0 4 | JERICHO_SEED=$SEED # set to -1 if you want stochastic version 5 | MODEL_NAME='drrn' 6 | JERICHO_ADD_WT='add_wt' # change to 'no_add_wt' if you don't want to add extra actions to game grammar 7 | 8 | python3 -m scripts.train_rl --output_dir logs/${LOG_FOLDER} \ 9 | --rom_path games/${GAME} \ 10 | --seed ${SEED} \ 11 | --jericho_seed ${JERICHO_SEED} \ 12 | --model_name ${MODEL_NAME} \ 13 | --eval_freq 10000000 \ 14 | --jericho_add_wt ${JERICHO_ADD_WT} -------------------------------------------------------------------------------- /scripts/run_inv_dy.sh: -------------------------------------------------------------------------------- 1 | LOG_FOLDER='inv_dy_zork1' 2 | GAME='zork1.z5' 3 | SEED=0 4 | JERICHO_SEED=$SEED # set to -1 if you want stochastic version 5 | MODEL_NAME='inv_dy' 6 | JERICHO_ADD_WT='add_wt' # change to 'no_add_wt' if you don't want to add extra actions to game grammar 7 | 8 | python3 -m scripts.train_rl --output_dir logs/${LOG_FOLDER} \ 9 | --rom_path games/${GAME} \ 10 | --seed ${SEED} \ 11 | --jericho_seed ${JERICHO_SEED} \ 12 | --model_name ${MODEL_NAME} \ 13 | --eval_freq 10000000 \ 14 | --memory_size 10000 \ 15 | --w_inv 1 \ 16 | --r_for 1 \ 17 | --w_act 1 \ 18 | --jericho_add_wt ${JERICHO_ADD_WT} -------------------------------------------------------------------------------- /scripts/run_xtx.sh: -------------------------------------------------------------------------------- 1 | LOG_FOLDER='xtx_zork1' 2 | GAME='zork1.z5' 3 | SEED=0 4 | JERICHO_SEED=$SEED # set to -1 if you want stochastic version 5 | MODEL_NAME='xtx' 6 | JERICHO_ADD_WT='add_wt' # change to 'no_add_wt' if you don't want to add extra actions to game grammar 7 | 8 | # NOTE: r_for below corresponds to alpha_1 in the paper 9 | 10 | python3 -m scripts.train_rl --output_dir logs/${LOG_FOLDER} \ 11 | --rom_path games/${GAME} \ 12 | --seed ${SEED} \ 13 | --jericho_seed ${JERICHO_SEED} \ 14 | --model_name ${MODEL_NAME} \ 15 | --eval_freq 10000000 \ 16 | --memory_size 10000 \ 17 | --T 1 \ 18 | --w_inv 1 \ 19 | --r_for 1 \ 20 | --w_act 1 \ 21 | --graph_num_explore_steps 50 \ 22 | --graph_rescore_freq 1000000 \ 23 | --env_step_limit 50 \ 24 | --graph_score_temp 1 \ 25 | --graph_q_temp 10000 \ 26 | --graph_alpha 0 \ 27 | --log_top_blue_acts_freq 500 \ 28 | --use_action_model 1 \ 29 | --action_model_update_freq 500 \ 30 | 
--action_model_type transformer \ 31 | --il_max_context 512 \ 32 | --max_acts 2 \ 33 | --il_vocab_size 2000 \ 34 | --il_k 10 \ 35 | --il_temp 3 \ 36 | --use_il 1 \ 37 | --il_batch_size 64 \ 38 | --il_max_num_epochs 40 \ 39 | --il_len_scale 1 \ 40 | --use_il_graph_sampler 0 \ 41 | --use_il_buffer_sampler 1 \ 42 | --il_top_p 1 \ 43 | --il_use_dropout 1 \ 44 | --traj_dropout_prob 0.005 \ 45 | --jericho_add_wt ${JERICHO_ADD_WT} -------------------------------------------------------------------------------- /scripts/run_xtx_ablation_det.sh: -------------------------------------------------------------------------------- 1 | LOG_FOLDER='xtx_ablation_zork1' 2 | GAME='zork1.z5' 3 | SEED=0 4 | JERICHO_SEED=$SEED # set to -1 if you want stochastic version 5 | MODEL_NAME='xtx' 6 | JERICHO_ADD_WT='add_wt' # change to 'no_add_wt' if you don't want to add extra actions to game grammar 7 | 8 | # NOTE: 9 | # - traj_dropout_prob below corresponds to the lambda param in the paper 10 | # - r_for below corresponds to alpha_1 in the paper 11 | 12 | python3 -m scripts.train_rl --output_dir logs/${LOG_FOLDER} \ 13 | --rom_path games/${GAME} \ 14 | --seed ${SEED} \ 15 | --jericho_seed ${JERICHO_SEED} \ 16 | --model_name ${MODEL_NAME} \ 17 | --eval_freq 10000000 \ 18 | --memory_size 10000 \ 19 | --T 1 \ 20 | --w_inv 1 \ 21 | --r_for 1 \ 22 | --w_act 1 \ 23 | --graph_num_explore_steps 50 \ 24 | --graph_rescore_freq 1000000 \ 25 | --env_step_limit 50 \ 26 | --graph_score_temp 1 \ 27 | --graph_q_temp 10000 \ 28 | --graph_alpha 0 \ 29 | --log_top_blue_acts_freq 500 \ 30 | --use_action_model 1 \ 31 | --action_model_update_freq 500 \ 32 | --action_model_type transformer \ 33 | --il_max_context 512 \ 34 | --max_acts 2 \ 35 | --il_vocab_size 2000 \ 36 | --il_k 10 \ 37 | --il_temp 3 \ 38 | --use_il 1 \ 39 | --il_batch_size 64 \ 40 | --il_max_num_epochs 40 \ 41 | --il_len_scale 1 \ 42 | --use_il_graph_sampler 0 \ 43 | --use_il_buffer_sampler 1 \ 44 | --il_top_p 1 \ 45 | --il_use_dropout 0 \ 46 | --traj_dropout_prob 0.5 \ 47 | --project_name "iclr-text-games" \ 48 | --jericho_add_wt ${JERICHO_ADD_WT} \ 49 | --il_use_only_dropout 1 -------------------------------------------------------------------------------- /scripts/run_xtx_no_mix.sh: -------------------------------------------------------------------------------- 1 | LOG_FOLDER='xtx_no_mix_zork1' 2 | GAME='zork1.z5' 3 | SEED=0 4 | JERICHO_SEED=$SEED # set to -1 if you want stochastic version 5 | MODEL_NAME='xtx' 6 | JERICHO_ADD_WT='add_wt' # change to 'no_add_wt' if you don't want to add extra actions to game grammar 7 | 8 | # NOTE: r_for below corresponds to alpha_1 in the paper 9 | 10 | python3 -m scripts.train_rl --output_dir logs/${LOG_FOLDER} \ 11 | --rom_path games/${GAME} \ 12 | --seed ${SEED} \ 13 | --jericho_seed ${JERICHO_SEED} \ 14 | --model_name ${MODEL_NAME} \ 15 | --eval_freq 10000000 \ 16 | --memory_size 10000 \ 17 | --T 1 \ 18 | --w_inv 1 \ 19 | --r_for 1 \ 20 | --w_act 1 \ 21 | --graph_num_explore_steps 50 \ 22 | --graph_rescore_freq 1000000 \ 23 | --env_step_limit 50 \ 24 | --graph_score_temp 1 \ 25 | --graph_q_temp 10000 \ 26 | --graph_alpha 0 \ 27 | --log_top_blue_acts_freq 500 \ 28 | --use_action_model 1 \ 29 | --action_model_update_freq 500 \ 30 | --action_model_type transformer \ 31 | --il_max_context 512 \ 32 | --max_acts 2 \ 33 | --il_vocab_size 2000 \ 34 | --il_k 10 \ 35 | --il_temp 3 \ 36 | --use_il 1 \ 37 | --il_batch_size 64 \ 38 | --il_max_num_epochs 40 \ 39 | --il_len_scale 1 \ 40 | --use_il_graph_sampler 0 \ 41 | --use_il_buffer_sampler 1 \ 42 |
--il_top_p 1 \ 43 | --il_use_dropout 0 \ 44 | --traj_dropout_prob 0.0 \ 45 | --jericho_add_wt ${JERICHO_ADD_WT} -------------------------------------------------------------------------------- /scripts/run_xtx_uniform.sh: -------------------------------------------------------------------------------- 1 | LOG_FOLDER='xtx_uniform_zork1' 2 | GAME='zork1.z5' 3 | SEED=0 4 | JERICHO_SEED=$SEED # set to -1 if you want stochastic version 5 | MODEL_NAME='xtx' 6 | JERICHO_ADD_WT='add_wt' # change to 'no_add_wt' if you don't want to add extra actions to game grammar 7 | 8 | python3 -m scripts.train_rl --output_dir logs/${LOG_FOLDER} \ 9 | --rom_path games/${GAME} \ 10 | --seed ${SEED} \ 11 | --jericho_seed ${JERICHO_SEED} \ 12 | --model_name ${MODEL_NAME} \ 13 | --eval_freq 10000000 \ 14 | --memory_size 10000 \ 15 | --T 1 \ 16 | --w_inv 1 \ 17 | --r_for 1 \ 18 | --w_act 1 \ 19 | --graph_num_explore_steps 50 \ 20 | --graph_rescore_freq 1000000 \ 21 | --env_step_limit 50 \ 22 | --graph_score_temp 1 \ 23 | --graph_q_temp 10000 \ 24 | --graph_alpha 0 \ 25 | --log_top_blue_acts_freq 500 \ 26 | --use_action_model 1 \ 27 | --action_model_update_freq 500 \ 28 | --action_model_type transformer \ 29 | --il_max_context 512 \ 30 | --max_acts 2 \ 31 | --il_vocab_size 2000 \ 32 | --il_k 10 \ 33 | --il_temp 3 \ 34 | --use_il 1 \ 35 | --il_batch_size 64 \ 36 | --il_max_num_epochs 40 \ 37 | --il_len_scale 1 \ 38 | --use_il_graph_sampler 0 \ 39 | --use_il_buffer_sampler 1 \ 40 | --il_top_p 1 \ 41 | --il_use_dropout 1 \ 42 | --traj_dropout_prob 0.005 \ 43 | --jericho_add_wt ${JERICHO_ADD_WT} \ 44 | --sample_uniform 1 -------------------------------------------------------------------------------- /scripts/train_rl.py: -------------------------------------------------------------------------------- 1 | # Built-in imports 2 | import argparse 3 | import random 4 | import logging 5 | 6 | # Third party imports 7 | import jericho 8 | import torch 9 | import numpy as np 10 | import wandb 11 | 12 | # Custom imports 13 | from agents import ( 14 | DrrnAgent, 15 | DrrnInvDynAgent, 16 | DrrnGraphInvDynAgent 17 | ) 18 | 19 | from trainers import ( 20 | DrrnTrainer, 21 | DrrnInvDynTrainer, 22 | DrrnGraphInvDynTrainer 23 | ) 24 | 25 | from transformers import GPT2LMHeadModel, GPT2Config 26 | 27 | import definitions.defs as defs 28 | from utils.env import JerichoEnv 29 | from utils.vec_env import VecEnv 30 | from utils import logger 31 | from utils.memory import State, Transition 32 | 33 | 34 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | logging.getLogger().setLevel(logging.CRITICAL) 36 | # torch.autograd.set_detect_anomaly(True) 37 | 38 | 39 | def configure_logger(args): 40 | """ 41 | Setup various logging channels (wandb, text files, etc.). 
42 | """ 43 | log_dir = args.output_dir 44 | wandb_on = args.wandb 45 | 46 | type_strs = ["json", "stdout"] 47 | if wandb_on and log_dir != "logs": 48 | type_strs += ["wandb"] 49 | tb = logger.Logger( 50 | log_dir, 51 | [ 52 | logger.make_output_format(type_str, log_dir, args=args) 53 | for type_str in type_strs 54 | ], 55 | ) 56 | 57 | logger.configure("{}-{}".format(log_dir, wandb.run.id), 58 | format_strs=["log", "stdout"], off=args.logging_off) 59 | log = logger.log 60 | 61 | return tb, log 62 | 63 | 64 | def parse_args(): 65 | parser = argparse.ArgumentParser() 66 | 67 | # General Settings 68 | parser.add_argument('--output_dir', default='logs') 69 | parser.add_argument('--rom_path', default='./games/detective.z5') 70 | parser.add_argument('--wandb', default=1, type=int) 71 | parser.add_argument('--save_path', default='princeton-nlp/text-games/') 72 | parser.add_argument('--logging_off', default=0, type=int) 73 | parser.add_argument('--weight_file', default=None, type=str) 74 | parser.add_argument('--memory_file', default=None, type=str) 75 | parser.add_argument('--traj_file', default=None, type=str) 76 | parser.add_argument('--run_id', default=None, type=str) 77 | parser.add_argument('--project_name', default='xtx', type=str) 78 | parser.add_argument('--debug', default=0, type=int) 79 | parser.add_argument('--jericho_add_wt', default='add_wt', type=str) 80 | 81 | # Environment settings 82 | parser.add_argument('--check_valid_actions_changed', default=0, type=int) 83 | 84 | # Training Settings 85 | parser.add_argument('--env_step_limit', default=100, type=int) 86 | parser.add_argument('--dynamic_episode_length', default=0, type=int) 87 | parser.add_argument('--episode_ext_type', default='steady_50', type=str) 88 | parser.add_argument('--seed', default=0, type=int) 89 | parser.add_argument('--jericho_seed', default=0, type=int) 90 | parser.add_argument('--num_envs', default=8, type=int) 91 | parser.add_argument('--max_steps', default=100000, type=int) 92 | parser.add_argument('--q_update_freq', default=1, type=int) 93 | parser.add_argument('--checkpoint_freq', default=5000, type=int) 94 | parser.add_argument('--eval_freq', default=5000, type=int) 95 | parser.add_argument('--log_freq', default=100, type=int) 96 | parser.add_argument('--target_update_freq', default=100, type=int) 97 | parser.add_argument('--dump_traj_freq', default=5000, type=int) 98 | parser.add_argument('--gamma', default=.9, type=float) 99 | parser.add_argument('--batch_size', default=64, type=int) 100 | parser.add_argument('--memory_size', default=500000, type=int) 101 | parser.add_argument('--memory_alpha', default=.4, type=float) 102 | parser.add_argument('--clip', default=5, type=float) 103 | parser.add_argument('--learning_rate', default=0.0001, type=float) 104 | parser.add_argument('--priority_fraction', default=0.5, type=float) 105 | parser.add_argument('--no_invalid_act_detect', default=0, type=int) 106 | parser.add_argument('--filter_invalid_acts', default=1, type=int) 107 | parser.add_argument('--start_from_reward', default=0, type=int) 108 | parser.add_argument('--start_from_wt', default=0, type=int) 109 | parser.add_argument('--filter_drop_acts', default=0, type=int) 110 | 111 | # Action Model Settings 112 | parser.add_argument('--max_acts', default=5, type=int) 113 | parser.add_argument('--tf_embedding_dim', default=128, type=int) 114 | parser.add_argument('--tf_hidden_dim', default=128, type=int) 115 | parser.add_argument('--nhead', default=4, type=int) 116 | parser.add_argument('--feedforward_dim', 
default=512, type=int) 117 | parser.add_argument('--tf_num_layers', default=3, type=int) 118 | parser.add_argument('--ngram', default=3, type=int) 119 | parser.add_argument('--traj_k', default=1, type=int) 120 | parser.add_argument('--action_model_update_freq', default=1e9, type=int) 121 | parser.add_argument('--smooth_alpha', default=0.00001, type=float) 122 | parser.add_argument('--cut_beta_at_threshold', default=0, type=int) 123 | parser.add_argument('--action_model_type', default='ngram', type=str) 124 | parser.add_argument('--tf_num_epochs', default=50, type=int) 125 | parser.add_argument( 126 | '--turn_action_model_off_after_falling', default=0, type=int) 127 | parser.add_argument('--traj_dropout_prob', default=0, type=float) 128 | parser.add_argument('--init_bin_prob', default=0.1, type=float) 129 | parser.add_argument('--num_bins', default=0, type=int) 130 | parser.add_argument('--binning_prob_update_freq', default=1e9, type=int) 131 | parser.add_argument('--random_action_dropout', default=0, type=int) 132 | parser.add_argument('--use_multi_ngram', default=0, type=int) 133 | parser.add_argument('--use_action_model', default=0, type=int) 134 | parser.add_argument('--sample_action_argmax', default=0, type=int) 135 | parser.add_argument('--il_max_context', default=512, type=int) 136 | parser.add_argument('--il_k', default=5, type=int) 137 | parser.add_argument('--il_batch_size', default=64, type=int) 138 | parser.add_argument('--il_lr', default=1e-3, type=float) 139 | parser.add_argument('--il_max_num_epochs', default=200, type=int) 140 | parser.add_argument('--il_num_eval_runs', default=3, type=int) 141 | parser.add_argument('--il_eval_freq', default=300, type=int) 142 | parser.add_argument('--il_vocab_size', default=2000, type=int) 143 | parser.add_argument('--il_temp', default=1., type=float) 144 | parser.add_argument('--use_il', default=0, type=int) 145 | parser.add_argument('--il_len_scale', default=1.0, type=float) 146 | parser.add_argument('--use_il_graph_sampler', default=0, type=int) 147 | parser.add_argument('--use_il_buffer_sampler', default=1, type=int) 148 | parser.add_argument('--il_top_p', default=0.9, type=float) 149 | parser.add_argument('--il_use_dropout', default=0, type=int) 150 | parser.add_argument('--il_use_only_dropout', default=0, type=int) 151 | 152 | # DRRN Model Settings 153 | parser.add_argument('--drrn_embedding_dim', default=128, type=int) 154 | parser.add_argument('--drrn_hidden_dim', default=128, type=int) 155 | parser.add_argument('--use_drrn_inv_look', default=1, type=int) 156 | parser.add_argument('--use_counts', default=0, type=int) 157 | parser.add_argument('--reset_counts_every_epoch', default=0, type=int) 158 | parser.add_argument('--sample_uniform', default=0, type=int) 159 | parser.add_argument('--T', default=1, type=float) 160 | parser.add_argument('--rotating_temp', default=0, type=int) 161 | parser.add_argument('--augment_state_with_score', default=0, type=int) 162 | 163 | # Graph Model Settings 164 | parser.add_argument('--graph_num_explore_steps', default=50, type=int) 165 | parser.add_argument('--graph_rescore_freq', default=500, type=int) 166 | parser.add_argument('--graph_merge_freq', default=500, type=int) 167 | parser.add_argument('--graph_hash', default='inv_loc_ob', type=str) 168 | parser.add_argument('--graph_score_temp', default=1, type=float) 169 | parser.add_argument('--graph_q_temp', default=1, type=float) 170 | parser.add_argument('--graph_alpha', default=0.5, type=float) 171 | parser.add_argument('--log_top_blue_acts_freq', 
default=100, type=int) 172 | 173 | # Offline Q Learning settings 174 | parser.add_argument('--offline_q_steps', default=1000, type=int) 175 | parser.add_argument('--offline_q_transfer_freq', default=100, type=int) 176 | parser.add_argument('--offline_q_eval_runs', default=10, type=int) 177 | 178 | # Inv-Dyn Settings 179 | parser.add_argument('--type_inv', default='decode') 180 | parser.add_argument('--type_for', default='ce') 181 | parser.add_argument('--w_inv', default=0, type=float) 182 | parser.add_argument('--w_for', default=0, type=float) 183 | parser.add_argument('--w_act', default=0, type=float) 184 | parser.add_argument('--r_for', default=0, type=float) 185 | 186 | parser.add_argument('--nor', default=0, type=int, help='no game reward') 187 | parser.add_argument('--randr', default=0, type=int, 188 | help='random game reward by objects and locations within episode') 189 | parser.add_argument('--perturb', default=0, type=int, 190 | help='perturb state and action') 191 | 192 | parser.add_argument('--hash_rep', default=0, type=int, 193 | help='hash for representation') 194 | parser.add_argument('--act_obs', default=0, type=int, 195 | help='action set as state representation') 196 | parser.add_argument('--fix_rep', default=0, type=int, 197 | help='fix representation') 198 | 199 | # Additional Model Settings 200 | parser.add_argument('--model_name', default='xtx', type=str) 201 | parser.add_argument('--beta', default=0.3, type=float) 202 | parser.add_argument('--beta_trainable', default=0, type=int) 203 | parser.add_argument( 204 | '--eps', 205 | default=0, 206 | type=int, 207 | help='0: ~ softmax act_value; 1: eps-greedy-exploration', 208 | ) 209 | parser.add_argument( 210 | '--eps_type', 211 | default='uniform', 212 | type=str, 213 | help='uniform (-1): uniform exploration; softmax_lm (0): ~ softmax lm_value; uniform_lm_topk (>0): ~ uniform(top k w.r.t. lm_value)', 214 | ) 215 | parser.add_argument( 216 | '--alpha', 217 | default=0, 218 | type=float, 219 | help='act_value = alpha * bert_value + (1-alpha) * q_value; only used when eps is None now', 220 | ) 221 | parser.add_argument('--sample_argmax', 222 | default=0, 223 | type=int, 224 | help='whether to replace sampling with argmax') 225 | 226 | return parser.parse_args() 227 | 228 | 229 | def main(): 230 | assert jericho.__version__.startswith( 231 | "3"), "This code is designed to be run with Jericho version >= 3.0.0." 
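# Parse arguments, seed the RNGs, build the Jericho environments, and dispatch to the agent/trainer pair selected by --model_name.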
232 | 233 | args = parse_args() 234 | print(args) 235 | print("device", device) 236 | print(args.model_name) 237 | 238 | # Set seed across imports 239 | torch.manual_seed(args.seed) 240 | np.random.seed(args.seed) 241 | random.seed(args.seed) 242 | 243 | # Start logger 244 | tb, log = configure_logger(args) 245 | 246 | if args.debug: 247 | import pdb 248 | pdb.set_trace() 249 | 250 | # Setup envs 251 | cache = dict() 252 | eval_env = JerichoEnv(args.rom_path, 253 | args.env_step_limit, 254 | get_valid=True, 255 | seed=args.jericho_seed, 256 | args=args, 257 | cache=cache, 258 | start_from_reward=args.start_from_reward, 259 | start_from_wt=args.start_from_wt, 260 | log=log) 261 | envs = [ 262 | JerichoEnv(args.rom_path, 263 | args.env_step_limit, 264 | get_valid=True, 265 | cache=cache, 266 | args=args, 267 | seed=args.jericho_seed, 268 | start_from_reward=args.start_from_reward, 269 | start_from_wt=args.start_from_wt, 270 | log=log) for _ in range(args.num_envs) 271 | ] 272 | 273 | # Setup rl model 274 | if args.model_name == defs.DRRN: 275 | assert args.use_action_model == 0, "'use_action_model' needs to be OFF" 276 | assert args.r_for == 0, "r_for needs to be zero when NOT using inverse dynamics." 277 | assert args.use_il == 0, "no il should be used when running DRRN." 278 | 279 | envs = VecEnv(args.num_envs, eval_env) 280 | 281 | agent = DrrnAgent(tb, log, args, envs, None) 282 | trainer = DrrnTrainer(tb, log, agent, envs, eval_env, args) 283 | 284 | elif args.model_name == defs.XTX: 285 | assert args.use_il == args.use_action_model, "action model stuff should be on when using IL." 286 | assert args.r_for > 0, "r_for needs to be ON when using inverse dynamics." 287 | if args.il_use_dropout or args.il_use_only_dropout: 288 | assert args.il_use_dropout != args.il_use_only_dropout, "cannot use two types of dropout at the same time." 289 | 290 | envs = VecEnv(args.num_envs, eval_env) 291 | 292 | config = GPT2Config(vocab_size=args.il_vocab_size, n_embd=args.tf_embedding_dim, 293 | n_layer=args.tf_num_layers, n_head=args.nhead, n_positions=args.il_max_context, n_ctx=args.il_max_context) 294 | lm = GPT2LMHeadModel(config) 295 | lm.train() 296 | agent = DrrnGraphInvDynAgent(args, tb, log, envs, action_models=lm) 297 | trainer = DrrnGraphInvDynTrainer(tb, log, agent, envs, eval_env, args) 298 | 299 | elif args.model_name == defs.INV_DY: 300 | assert args.r_for > 0, "r_for needs to be ON when using inverse dynamics." 301 | assert args.use_action_model == 0, "'use_action_model' needs to be OFF." 
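# INV DY: DRRN agent with inverse-dynamics auxiliary objectives; the curiosity bonus weighted by r_for is added to the game reward in the trainer.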
302 | 303 | envs = VecEnv(args.num_envs, eval_env) 304 | 305 | agent = DrrnInvDynAgent(args, None, tb, log, envs) 306 | trainer = DrrnInvDynTrainer(tb, log, agent, envs, eval_env, args) 307 | 308 | else: 309 | raise Exception("Unknown model type!") 310 | 311 | if args.weight_file is not None and args.memory_file is not None: 312 | agent.load(args.run_id, args.weight_file, args.memory_file) 313 | log("Successfully loaded network and replay buffer from checkpoint!") 314 | 315 | try: 316 | trainer.train() 317 | finally: 318 | for ps in envs.ps: 319 | ps.terminate() 320 | 321 | 322 | if __name__ == "__main__": 323 | main() 324 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from trainers.trainer import Trainer 2 | from trainers.drrn.drrn_trainer import DrrnTrainer 3 | from trainers.drrn.drrn_inv_dyn_trainer import DrrnInvDynTrainer 4 | from trainers.drrn.drrn_graph_inv_dyn_trainer import DrrnGraphInvDynTrainer 5 | -------------------------------------------------------------------------------- /trainers/drrn/drrn_graph_inv_dyn_trainer.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | from typing import Dict, Union, Callable, List 3 | import time 4 | from os.path import join as pjoin 5 | 6 | # Libraries 7 | from jericho.util import clean 8 | import wandb 9 | import torch 10 | 11 | # Custom Imports 12 | from trainers import DrrnInvDynTrainer 13 | 14 | from agents import DrrnInvDynAgent 15 | 16 | from utils.vec_env import VecEnv 17 | from utils.memory import Transition 18 | import utils.logger as logger 19 | from utils.env import JerichoEnv 20 | import utils.drrn as Drrn 21 | import utils.inv_dyn as InvDyn 22 | import utils.ngram as Ngram 23 | from utils.util import process_action 24 | 25 | 26 | class DrrnGraphInvDynTrainer(DrrnInvDynTrainer): 27 | def __init__( 28 | self, 29 | tb: logger.Logger, 30 | log: Callable[..., None], 31 | agent: DrrnInvDynAgent, 32 | envs: VecEnv, 33 | eval_env: JerichoEnv, 34 | args: Dict[str, Union[str, int, float]] 35 | ): 36 | super().__init__(tb, log, agent, envs, eval_env, args) 37 | 38 | self.graph_num_explore_steps = args.graph_num_explore_steps 39 | self.graph_rescore_freq = args.graph_rescore_freq 40 | self.graph_merge_freq = args.graph_merge_freq 41 | self.log_top_blue_acts_freq = args.log_top_blue_acts_freq 42 | 43 | self.use_il_graph_sampler = args.use_il_graph_sampler 44 | self.use_il_buffer_sampler = args.use_il_buffer_sampler 45 | self.use_il = args.use_il 46 | 47 | def train(self): 48 | start = time.time() 49 | max_score = 0 50 | 51 | obs, infos, states, valid_ids, transitions = Drrn.setup_env( 52 | self, self.envs) 53 | 54 | for step in range(1, self.max_steps + 1): 55 | self.steps = step 56 | self.log("Step {}".format(step)) 57 | action_ids, action_idxs, action_qvals = self.agent.act(states, 58 | valid_ids, 59 | [info['valid'] 60 | for info in infos], 61 | sample=True) 62 | 63 | # Get the actual next action string for each env 64 | action_strs = [ 65 | info['valid'][idx] for info, idx in zip(infos, action_idxs) 66 | ] 67 | 68 | # Log envs[0] 69 | s = '' 70 | for idx, (act, val) in enumerate( 71 | sorted(zip(infos[0]['valid'], action_qvals[0]), 72 | key=lambda x: x[1], 73 | reverse=True), 1): 74 | s += "{}){:.2f} {} ".format(idx, val.item(), act) 75 | self.log('Q-Values: {}'.format(s)) 76 | 77 | # Update all envs 78 | infos, next_states, next_valids, 
max_score, obs = self.update_envs( 79 | action_strs, action_ids, states, max_score, transitions, obs, infos, action_qvals) 80 | states, valid_ids = next_states, next_valids 81 | 82 | self.end_step(step, start, max_score, action_qvals) 83 | 84 | def update_envs(self, action_strs, action_ids, states, max_score: int, 85 | transitions, obs, infos, qvals): 86 | """ 87 | TODO 88 | """ 89 | next_obs, next_rewards, next_dones, next_infos = self.envs.step( 90 | action_strs) 91 | 92 | if self.use_il_graph_sampler: 93 | next_node_ids = [graph.state_hash(next_info, next_ob) for graph, next_ob, next_info in zip( 94 | self.agent.graphs, next_obs, next_infos)] 95 | 96 | # Add to environment trajectory 97 | trajs = self.envs.add_traj( 98 | list(map(lambda x: (process_action(x[0]), x[1]), 99 | zip(action_strs, next_node_ids)))) 100 | 101 | # Update graph depending on state of environment 102 | self.log('Updating graph ...') 103 | for i, (graph, ob, info, qvals, next_ob, next_info, act) in enumerate(zip(self.agent.graphs, obs, infos, qvals, next_obs, next_infos, action_strs)): 104 | graph.maybe_update(ob, info, next_ob, next_info, 105 | qvals.cpu().detach().tolist(), i, process_action(act)) 106 | if self.use_action_model: 107 | next_states = self.agent.build_states( 108 | next_obs, next_infos, action_strs, [state.acts for state in states]) 109 | else: 110 | next_states = self.agent.build_states(next_obs, next_infos) 111 | 112 | # Update valid acts if next node is already in the tree 113 | next_valids = [self.agent.encode(next_info['valid']) 114 | for next_info in next_infos] 115 | 116 | if self.r_for > 0: 117 | reward_curiosity, _ = InvDyn.inv_loss_decode(self.agent.network, 118 | states, next_states, [[a] for a in action_ids], hat=True, reduction='none') 119 | next_rewards = next_rewards + reward_curiosity.detach().numpy() * self.r_for 120 | self.tb.logkv_mean('Curiosity', reward_curiosity.mean().item()) 121 | 122 | for i, (next_ob, next_reward, next_done, next_info, state, next_state, next_action_str) in enumerate(zip(next_obs, next_rewards, next_dones, next_infos, states, next_states, action_strs)): 123 | # Log 124 | self.log('Action_{}: {}'.format( 125 | self.steps, next_action_str), condition=(i == 0)) 126 | self.log("Reward{}: {}, Score {}, Done {}".format( 127 | self.steps, next_reward, next_info['score'], next_done), condition=(i == 0)) 128 | self.log('Obs{}: {} Inv: {} Desc: {}'.format( 129 | self.steps, clean(next_ob), clean(next_info['inv']), 130 | clean(next_info['look'])), condition=(i == 0)) 131 | 132 | transition = Transition( 133 | state, action_ids[i], next_reward, next_state, next_valids[i], next_done) 134 | transitions[i].append(transition) 135 | self.agent.observe(transition) 136 | 137 | if next_done: 138 | # Add trajectory to graph 139 | if self.use_il_buffer_sampler: 140 | self.agent.il_buffer.add_traj(transitions[i]) 141 | 142 | if next_info['score'] >= max_score: # put in alpha queue 143 | if next_info['score'] > max_score: 144 | self.agent.memory.clear_alpha() 145 | max_score = next_info['score'] 146 | for transition in transitions[i]: 147 | self.agent.observe(transition, is_prior=True) 148 | transitions[i] = [] 149 | 150 | if self.use_action_model: 151 | Ngram.log_recovery_metrics(self, i) 152 | 153 | # Add last node to graph 154 | if self.use_il_graph_sampler: 155 | if next_infos[i]['look'] != 'unknown' and next_infos[i]['inv'] != 'unknown': 156 | with torch.no_grad(): 157 | _, qvals = self.agent.network.act( 158 | next_states, next_valids, [next_info['valid'] for next_info in 
next_infos]) 159 | self.agent.graphs[i].maybe_update( 160 | next_ob, next_info, None, None, qvals[i].cpu().tolist(), i, None) 161 | 162 | 163 | next_infos = list(next_infos) 164 | 165 | next_obs[i], next_infos[i] = self.envs.reset_one(i) 166 | 167 | if self.use_action_model: 168 | next_states[i] = self.agent.build_skip_state( 169 | next_obs[i], next_infos[i], 'reset', []) 170 | else: 171 | next_states[i] = self.agent.build_state( 172 | next_obs[i], next_infos[i]) 173 | 174 | next_valids[i] = self.agent.encode(next_infos[i]['valid']) 175 | 176 | return next_infos, next_states, next_valids, max_score, next_obs 177 | 178 | def end_step(self, step: int, start, max_score: int, action_qvals): 179 | """ 180 | TODO 181 | """ 182 | if step % self.q_update_freq == 0: 183 | self.update_agent() 184 | 185 | if step % self.target_update_freq == 0: 186 | self.agent.transfer_weights() 187 | 188 | if self.use_action_model: 189 | Ngram.end_step(self, step) 190 | 191 | if step % self.log_freq == 0: 192 | # rank_metrics = self.evaluate_optimal() 193 | rank_metrics = dict() 194 | self.write_to_logs(step, start, self.envs, max_score, action_qvals, 195 | rank_metrics) 196 | 197 | # Save model weights etc. 198 | if step % self.checkpoint_freq == 0: 199 | self.agent.save(int(step / self.checkpoint_freq)) 200 | 201 | if self.use_il: 202 | # save locally 203 | torch.save(self.agent.action_models.state_dict(), 204 | pjoin(wandb.run.dir, 'il_weights_{}.pt'.format(step))) 205 | 206 | # upload to wandb 207 | wandb.save( 208 | pjoin(wandb.run.dir, 'il_weights_{}.pt'.format(step))) 209 | 210 | def write_to_logs(self, step, start, envs, max_score, qvals, rank_metrics, 211 | *args): 212 | """ 213 | Log any relevant metrics. 214 | """ 215 | self.tb.logkv('Step', step) 216 | for key, val in rank_metrics.items(): 217 | self.tb.logkv(key, val) 218 | self.tb.logkv("FPS", int( 219 | (step*self.envs.num_envs)/(time.time()-start))) 220 | self.tb.logkv("EpisodeScores100", self.envs.get_end_scores().mean()) 221 | self.tb.logkv('MaxScore', max_score) 222 | if self.use_il_graph_sampler: 223 | self.tb.logkv('#BlueActs', sum( 224 | [len(node['blue_acts']) for node in self.agent.graphs[0].graph.values()])) 225 | self.tb.dumpkvs() 226 | -------------------------------------------------------------------------------- /trainers/drrn/drrn_inv_dyn_trainer.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | import time 3 | from typing import Dict, Union, Callable 4 | import random 5 | 6 | # Libraries 7 | import wandb 8 | import torch 9 | from jericho.util import clean 10 | 11 | # Custom imports 12 | from utils.util import check_exists, load_object, process_action, save_object 13 | from utils.memory import Transition 14 | from utils.env import JerichoEnv 15 | from utils.vec_env import VecEnv 16 | import utils.logger as logger 17 | import utils.ngram as Ngram 18 | import utils.inv_dyn as InvDyn 19 | 20 | from trainers import Trainer 21 | 22 | OBJECTS_DIR = './saved_objects' 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | 27 | class DrrnInvDynTrainer(Trainer): 28 | def __init__( 29 | self, 30 | tb: logger.Logger, 31 | log: Callable[..., None], 32 | agent, 33 | envs: VecEnv, 34 | eval_env: JerichoEnv, 35 | args: Dict[str, Union[str, int, float]] 36 | ): 37 | super().__init__(tb, log, agent, envs, eval_env, args) 38 | 39 | # Action model settings 40 | self.use_action_model = args.use_action_model 41 | self.rotating_temp = args.rotating_temp 42 | if 
self.use_action_model: 43 | Ngram.init_trainer(self, args) 44 | 45 | self.r_for = args.r_for 46 | 47 | self.collected_trajs = [] 48 | self.full_traj_folder = args.output_dir.split('/')[-1][:-3] 49 | self.dump_traj_freq = args.dump_traj_freq 50 | 51 | def setup_env(self, envs): 52 | """ 53 | Setup the environment. 54 | """ 55 | obs, infos = envs.reset() 56 | if self.use_action_model: 57 | states = self.agent.build_states( 58 | obs, infos, ['reset'] * 8, [[]] * 8) 59 | else: 60 | states = self.agent.build_states(obs, infos) 61 | valid_ids = [self.agent.encode(info['valid']) for info in infos] 62 | transitions = [[] for info in infos] 63 | 64 | return obs, infos, states, valid_ids, transitions 65 | 66 | def update_envs(self, action_strs, action_ids, states, max_score: int, 67 | transitions, obs, infos): 68 | """ 69 | TODO 70 | """ 71 | next_obs, next_rewards, next_dones, next_infos = self.envs.step( 72 | action_strs) 73 | 74 | if self.use_action_model: 75 | next_states = self.agent.build_states( 76 | next_obs, next_infos, action_strs, [state.acts for state in states]) 77 | else: 78 | next_states = self.agent.build_states(next_obs, next_infos) 79 | 80 | next_valids = [self.agent.encode(next_info['valid']) 81 | for next_info in next_infos] 82 | 83 | self.envs.add_full_traj( 84 | [ 85 | (ob, info['look'], info['inv'], act, r) for ob, info, act, r in zip(obs, infos, action_strs, next_rewards) 86 | ] 87 | ) 88 | 89 | if self.r_for > 0: 90 | reward_curiosity, _ = InvDyn.inv_loss_decode(self.agent.network, 91 | states, next_states, [[a] for a in action_ids], hat=True, reduction='none') 92 | next_rewards = next_rewards + reward_curiosity.detach().numpy() * self.r_for 93 | self.tb.logkv_mean('Curiosity', reward_curiosity.mean().item()) 94 | 95 | if self.use_action_model: 96 | # Add to environment trajectory 97 | trajs = self.envs.add_traj( 98 | list(map(lambda x: process_action(x), action_strs))) 99 | 100 | for next_reward, next_done, next_info, traj in zip(next_rewards, next_dones, next_infos, trajs): 101 | # Push to trajectory memory if reward was positive and the episode didn't end yet 102 | if next_reward > 0: 103 | Ngram.push_to_traj_mem(self, next_info, traj) 104 | 105 | for i, (next_ob, next_reward, next_done, next_info, state, next_state, next_action_str) in enumerate(zip(next_obs, next_rewards, next_dones, next_infos, states, next_states, action_strs)): 106 | # Log 107 | self.log('Action_{}: {}'.format( 108 | self.steps, next_action_str), condition=(i == 0)) 109 | self.log("Reward{}: {}, Score {}, Done {}".format( 110 | self.steps, next_reward, next_info['score'], next_done), condition=(i == 0)) 111 | self.log('Obs{}: {} Inv: {} Desc: {}'.format( 112 | self.steps, clean(next_ob), clean(next_info['inv']), 113 | clean(next_info['look'])), condition=(i == 0)) 114 | 115 | transition = Transition( 116 | state, action_ids[i], next_reward, next_state, next_valids[i], next_done) 117 | transitions[i].append(transition) 118 | self.agent.observe(transition) 119 | 120 | if next_done: 121 | self.tb.logkv_mean('EpisodeScore', next_info['score']) 122 | if next_info['score'] >= max_score: # put in alpha queue 123 | if next_info['score'] > max_score: 124 | self.agent.memory.clear_alpha() 125 | max_score = next_info['score'] 126 | for transition in transitions[i]: 127 | self.agent.observe(transition, is_prior=True) 128 | transitions[i] = [] 129 | 130 | if self.use_action_model: 131 | Ngram.log_recovery_metrics(self, i) 132 | 133 | if self.envs.get_ngram_needs_update(i): 134 | Ngram.update_ngram(self, i) 135 | 
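# Optionally resample this env's softmax temperature for the next episode.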
136 | if self.rotating_temp: 137 | self.agent.network.T[i] = random.choice([1.0, 2.0, 3.0]) 138 | 139 | next_infos = list(next_infos) 140 | # add finished to trajectory to collection 141 | traj = self.envs.add_full_traj_i( 142 | i, (next_obs[i], next_infos[i]['look'], next_infos[i]['inv'])) 143 | self.collected_trajs.append(traj) 144 | 145 | next_obs[i], next_infos[i] = self.envs.reset_one(i) 146 | 147 | if self.use_action_model: 148 | next_states[i] = self.agent.build_skip_state( 149 | next_obs[i], next_infos[i], 'reset', []) 150 | else: 151 | next_states[i] = self.agent.build_state( 152 | next_obs[i], next_infos[i]) 153 | 154 | next_valids[i] = self.agent.encode(next_infos[i]['valid']) 155 | 156 | return next_infos, next_states, next_valids, max_score, next_obs 157 | 158 | def train(self): 159 | """ 160 | Train the agent. 161 | """ 162 | start = time.time() 163 | max_score, max_eval, self.env_steps = 0, 0, 0 164 | obs, infos, states, valid_ids, transitions = self.setup_env(self.envs) 165 | 166 | for step in range(1, self.max_steps + 1): 167 | self.steps = step 168 | self.log("Step {}".format(step)) 169 | action_ids, action_idxs, action_qvals = self.agent.act(states, 170 | valid_ids, 171 | [info['valid'] 172 | for info in infos], 173 | sample=True) 174 | 175 | # Get the actual next action string for each env 176 | action_strs = [ 177 | info['valid'][idx] for info, idx in zip(infos, action_idxs) 178 | ] 179 | 180 | # Log envs[0] 181 | s = '' 182 | for idx, (act, val) in enumerate( 183 | sorted(zip(infos[0]['valid'], action_qvals[0]), 184 | key=lambda x: x[1], 185 | reverse=True), 1): 186 | s += "{}){:.2f} {} ".format(idx, val.item(), act) 187 | self.log('Q-Values: {}'.format(s)) 188 | 189 | # Update all envs 190 | infos, next_states, next_valids, max_score, obs = self.update_envs( 191 | action_strs, action_ids, states, max_score, transitions, obs, infos) 192 | states, valid_ids = next_states, next_valids 193 | 194 | self.end_step(step, start, max_score, action_qvals, max_eval) 195 | 196 | def update_agent(self): 197 | """ 198 | Update the agent with gradient descent. 199 | """ 200 | # Update 201 | loss = self.agent.update() 202 | 203 | # Log the loss 204 | if loss is not None: 205 | for k, v in loss.items(): 206 | self.tb.logkv_mean(k, v) 207 | 208 | def end_step(self, step: int, start, max_score: int, action_qvals, 209 | max_eval: int): 210 | """ 211 | TODO 212 | """ 213 | if step % self.q_update_freq == 0: 214 | self.update_agent() 215 | 216 | if step % self.target_update_freq == 0: 217 | self.agent.transfer_weights() 218 | 219 | if step % self.log_freq == 0: 220 | # rank_metrics = self.evaluate_optimal() 221 | rank_metrics = dict() 222 | self.write_to_logs(step, start, self.envs, max_score, action_qvals, 223 | rank_metrics) 224 | 225 | # Save model weights etc. 226 | if step % self.checkpoint_freq == 0: 227 | self.agent.save(int(step / self.checkpoint_freq), 228 | self.top_k_traj if self.use_action_model else None) 229 | 230 | # Evaluate agent across several runs 231 | if step % self.eval_freq == 0: 232 | eval_score = self.evaluate(nb_episodes=10) 233 | wandb.log({ 234 | 'EvalScore': eval_score, 235 | 'Step': step, 236 | "Env Steps": self.env_steps 237 | }) 238 | if eval_score >= max_eval: 239 | max_eval = eval_score 240 | self.agent.save(step, is_best=True) 241 | 242 | if self.use_action_model: 243 | Ngram.end_step(self, step) 244 | 245 | def write_to_logs(self, step, start, envs, max_score, qvals, rank_metrics, 246 | *args): 247 | """ 248 | Log any relevant metrics. 
249 | """ 250 | self.tb.logkv('Step', step) 251 | self.tb.logkv('Env Steps', self.env_steps) 252 | # self.tb.logkv('Beta', self.agent.network.beta) 253 | for key, val in rank_metrics.items(): 254 | self.tb.logkv(key, val) 255 | self.tb.logkv("FPS", int( 256 | (step*self.envs.num_envs)/(time.time()-start))) 257 | self.tb.logkv("EpisodeScores100", self.envs.get_end_scores().mean()) 258 | self.tb.logkv('MaxScore', max_score) 259 | 260 | if self.use_action_model: 261 | Ngram.log_metrics(self) 262 | 263 | self.tb.dumpkvs() 264 | -------------------------------------------------------------------------------- /trainers/drrn/drrn_trainer.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | import time 3 | import heapq as pq 4 | import statistics as stats 5 | import random 6 | import copy 7 | from typing import Callable, List, Dict, Union 8 | 9 | # Libraries 10 | import wandb 11 | import torch 12 | from jericho.util import clean 13 | 14 | # Custom imports 15 | from trainers import Trainer 16 | 17 | from utils.util import process_action, check_exists, load_object, save_object 18 | from utils.env import JerichoEnv 19 | from utils.vec_env import VecEnv 20 | from utils.memory import Transition 21 | import utils.logger as logger 22 | import utils.ngram as Ngram 23 | 24 | 25 | OBJECTS_DIR = './saved_objects' 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | 29 | 30 | class DrrnTrainer(Trainer): 31 | def __init__( 32 | self, 33 | tb: logger.Logger, 34 | log: Callable[..., None], 35 | agent, 36 | envs: VecEnv, 37 | eval_env: JerichoEnv, 38 | args: Dict[str, Union[str, int, float]] 39 | ): 40 | super().__init__(tb, log, agent, envs, eval_env, args) 41 | 42 | # Action model settings 43 | self.use_action_model = args.use_action_model 44 | self.rotating_temp = args.rotating_temp 45 | if self.use_action_model: 46 | Ngram.init_trainer(self, args) 47 | 48 | self.collected_trajs = [] 49 | self.full_traj_folder = args.output_dir.split('/')[-1][:-3] 50 | self.dump_traj_freq = args.dump_traj_freq 51 | 52 | def setup_env(self, envs): 53 | """ 54 | Setup the environment. 
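Resets all parallel envs and returns the initial obs, infos, encoded states, valid-action encodings, and empty per-env transition buffers.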
55 | """ 56 | obs, infos = envs.reset() 57 | if self.use_action_model: 58 | states = self.agent.build_states( 59 | obs, infos, ['reset'] * 8, [[]] * 8) 60 | else: 61 | states = self.agent.build_states(obs, infos) 62 | valid_ids = [self.agent.encode(info['valid']) for info in infos] 63 | transitions = [[] for info in infos] 64 | 65 | return obs, infos, states, valid_ids, transitions 66 | 67 | def update_envs(self, action_strs, action_ids, states, max_score: int, 68 | transitions, obs, infos): 69 | """ 70 | TODO 71 | """ 72 | next_obs, next_rewards, next_dones, next_infos = self.envs.step( 73 | action_strs) 74 | 75 | if self.use_action_model: 76 | next_states = self.agent.build_states( 77 | next_obs, next_infos, action_strs, [state.acts for state in states]) 78 | else: 79 | next_states = self.agent.build_states(next_obs, next_infos) 80 | 81 | next_valids = [self.agent.encode(next_info['valid']) 82 | for next_info in next_infos] 83 | 84 | self.envs.add_full_traj( 85 | [ 86 | (ob, info['look'], info['inv'], act, r) for ob, info, act, r in zip(obs, infos, action_strs, next_rewards) 87 | ] 88 | ) 89 | 90 | if self.use_action_model: 91 | # Add to environment trajectory 92 | trajs = self.envs.add_traj( 93 | list(map(lambda x: process_action(x), action_strs))) 94 | 95 | for next_reward, next_done, next_info, traj in zip(next_rewards, next_dones, next_infos, trajs): 96 | # Push to trajectory memory if reward was positive and the episode didn't end yet 97 | if next_reward > 0: 98 | Ngram.push_to_traj_mem(self, next_info, traj) 99 | 100 | for i, (next_ob, next_reward, next_done, next_info, state, next_state, next_action_str) in enumerate(zip(next_obs, next_rewards, next_dones, next_infos, states, next_states, action_strs)): 101 | # Log 102 | self.log('Action_{}: {}'.format( 103 | self.steps, next_action_str), condition=(i == 0)) 104 | self.log("Reward{}: {}, Score {}, Done {}".format( 105 | self.steps, next_reward, next_info['score'], next_done), condition=(i == 0)) 106 | self.log('Obs{}: {} Inv: {} Desc: {}'.format( 107 | self.steps, clean(next_ob), clean(next_info['inv']), 108 | clean(next_info['look'])), condition=(i == 0)) 109 | 110 | transition = Transition( 111 | state, action_ids[i], next_reward, next_state, next_valids[i], next_done) 112 | transitions[i].append(transition) 113 | self.agent.observe(transition) 114 | 115 | if next_done: 116 | self.tb.logkv_mean('EpisodeScore', next_info['score']) 117 | if next_info['score'] >= max_score: # put in alpha queue 118 | if next_info['score'] > max_score: 119 | self.agent.memory.clear_alpha() 120 | max_score = next_info['score'] 121 | for transition in transitions[i]: 122 | self.agent.observe(transition, is_prior=True) 123 | transitions[i] = [] 124 | 125 | if self.use_action_model: 126 | Ngram.log_recovery_metrics(self, i) 127 | 128 | if self.envs.get_ngram_needs_update(i): 129 | Ngram.update_ngram(self, i) 130 | 131 | if self.rotating_temp: 132 | self.agent.network.T[i] = random.choice([1.0, 2.0, 3.0]) 133 | 134 | next_infos = list(next_infos) 135 | # add finished to trajectory to collection 136 | traj = self.envs.add_full_traj_i( 137 | i, (next_obs[i], next_infos[i]['look'], next_infos[i]['inv'])) 138 | self.collected_trajs.append(traj) 139 | 140 | next_obs[i], next_infos[i] = self.envs.reset_one(i) 141 | 142 | if self.use_action_model: 143 | next_states[i] = self.agent.build_skip_state( 144 | next_obs[i], next_infos[i], 'reset', []) 145 | else: 146 | next_states[i] = self.agent.build_state( 147 | next_obs[i], next_infos[i]) 148 | 149 | 
next_valids[i] = self.agent.encode(next_infos[i]['valid']) 150 | 151 | return next_infos, next_states, next_valids, max_score, next_obs 152 | 153 | def _wrap_up_episode(self, info, env, max_score, transitions, i): 154 | """ 155 | Perform final logging, updating, and building for next episode. 156 | """ 157 | # Logging & update 158 | self.tb.logkv_mean('EpisodeScore', info['score']) 159 | if env.max_score >= max_score: 160 | for t in transitions[i]: 161 | self.agent.observe(t, is_prior=True) 162 | transitions[i] = [] 163 | self.env_steps += info["moves"] 164 | 165 | # Build ingredients for next step 166 | next_ob, next_info = env.reset() 167 | if self.use_action_model: 168 | next_state = self.agent.build_skip_state( 169 | next_ob, next_info, [], 'reset') 170 | else: 171 | next_state = self.agent.build_state(next_ob, next_info) 172 | next_valid = self.agent.encode(next_info['valid']) 173 | 174 | return next_state, next_valid, next_info 175 | 176 | def train(self): 177 | """ 178 | Train the agent. 179 | """ 180 | start = time.time() 181 | max_score, max_eval, self.env_steps = 0, 0, 0 182 | obs, infos, states, valid_ids, transitions = self.setup_env(self.envs) 183 | 184 | for step in range(1, self.max_steps + 1): 185 | print(self.envs.get_cache_size()) 186 | self.steps = step 187 | self.log("Step {}".format(step)) 188 | action_ids, action_idxs, action_qvals = self.agent.act(states, 189 | valid_ids, 190 | [info['valid'] 191 | for info in infos], 192 | sample=True) 193 | 194 | # Get the actual next action string for each env 195 | action_strs = [ 196 | info['valid'][idx] for info, idx in zip(infos, action_idxs) 197 | ] 198 | 199 | # Log envs[0] 200 | s = '' 201 | for idx, (act, val) in enumerate( 202 | sorted(zip(infos[0]['valid'], action_qvals[0]), 203 | key=lambda x: x[1], 204 | reverse=True), 1): 205 | s += "{}){:.2f} {} ".format(idx, val.item(), act) 206 | self.log('Q-Values: {}'.format(s)) 207 | 208 | # Update all envs 209 | infos, next_states, next_valids, max_score, obs = self.update_envs( 210 | action_strs, action_ids, states, max_score, transitions, obs, infos) 211 | states, valid_ids = next_states, next_valids 212 | 213 | self.end_step(step, start, max_score, action_qvals, max_eval) 214 | 215 | def end_step(self, step: int, start, max_score: int, action_qvals, 216 | max_eval: int): 217 | """ 218 | TODO 219 | """ 220 | if step % self.q_update_freq == 0: 221 | self.update_agent() 222 | 223 | if step % self.target_update_freq == 0: 224 | self.agent.transfer_weights() 225 | 226 | if step % self.log_freq == 0: 227 | # rank_metrics = self.evaluate_optimal() 228 | rank_metrics = dict() 229 | self.write_to_logs(step, start, self.envs, max_score, action_qvals, 230 | rank_metrics) 231 | 232 | # Save model weights etc. 233 | if step % self.checkpoint_freq == 0: 234 | self.agent.save(int(step / self.checkpoint_freq), 235 | self.top_k_traj if self.use_action_model else None) 236 | 237 | # Evaluate agent across several runs 238 | if step % self.eval_freq == 0: 239 | eval_score = self.evaluate(nb_episodes=10) 240 | wandb.log({ 241 | 'EvalScore': eval_score, 242 | 'Step': step, 243 | "Env Steps": self.env_steps 244 | }) 245 | if eval_score >= max_eval: 246 | max_eval = eval_score 247 | self.agent.save(step, is_best=True) 248 | 249 | if self.use_action_model: 250 | Ngram.end_step(self, step) 251 | 252 | def write_to_logs(self, step, start, envs, max_score, qvals, rank_metrics, 253 | *args): 254 | """ 255 | Log any relevant metrics. 
256 | """ 257 | self.tb.logkv('Step', step) 258 | self.tb.logkv('Env Steps', self.env_steps) 259 | # self.tb.logkv('Beta', self.agent.network.beta) 260 | for key, val in rank_metrics.items(): 261 | self.tb.logkv(key, val) 262 | self.tb.logkv("FPS", int((step * len(envs)) / (time.time() - start))) 263 | self.tb.logkv("EpisodeScores100", self.envs.get_end_scores().mean()) 264 | self.tb.logkv('MaxScore', max_score) 265 | self.tb.logkv('#UniqueActs', self.envs.get_unique_acts()) 266 | self.tb.logkv('#CacheEntries', self.envs.get_cache_size()) 267 | 268 | if self.use_action_model: 269 | Ngram.log_metrics(self) 270 | 271 | self.tb.dumpkvs() 272 | -------------------------------------------------------------------------------- /trainers/trainer.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | from typing import Union, List 3 | 4 | # Libraries 5 | 6 | # Custom imports 7 | from utils.util import get_name_from_path 8 | from utils.env import JerichoEnv 9 | from utils.vec_env import VecEnv 10 | 11 | OBJECTS_DIR = './saved_objects' 12 | 13 | 14 | class Trainer: 15 | """General trainer class. 16 | """ 17 | 18 | def __init__(self, tb, log, agent, envs, eval_env, args): 19 | self.tb = tb 20 | self.log = log 21 | self.agent = agent 22 | self.envs = envs 23 | self.eval_env = eval_env 24 | 25 | self.max_steps = args.max_steps 26 | self.log_freq = args.log_freq 27 | self.target_update_freq = args.target_update_freq 28 | self.q_update_freq = args.q_update_freq 29 | self.checkpoint_freq = args.checkpoint_freq 30 | self.eval_freq = args.eval_freq 31 | self.batch_size = args.batch_size 32 | self.game = get_name_from_path(args.rom_path) 33 | self.eps = args.eps 34 | self.eps_type = args.eps_type 35 | self.dynamic_episode_length = args.dynamic_episode_length 36 | 37 | self.steps = 0 38 | 39 | def train(self): 40 | """Trains the agent. 41 | 42 | Raises: 43 | NotImplementedError: implemented by child class 44 | """ 45 | raise NotImplementedError 46 | 47 | def setup_env(self): 48 | """Setup the environment. 49 | 50 | Raises: 51 | NotImplementedError: implemented by child class 52 | """ 53 | raise NotImplementedError 54 | 55 | def update_envs(self): 56 | """Step through all the envs. 57 | 58 | Raises: 59 | NotImplementedError: implemented by child class 60 | """ 61 | raise NotImplementedError 62 | 63 | def end_step(self, step: int): 64 | """Perform any logging, saving, evaluation, etc. 65 | that happens at the end of each step. 66 | 67 | Args: 68 | step (int): the current step number 69 | 70 | Raises: 71 | NotImplementedError: implemented by child class 72 | """ 73 | raise NotImplementedError 74 | 75 | def update_agent(self): 76 | """Update the agent with gradient descent. 77 | """ 78 | # Update 79 | loss = self.agent.update() 80 | 81 | # Log the loss 82 | if loss is not None: 83 | self.tb.logkv_mean('Loss', loss) 84 | 85 | def evaluate(self, nb_episodes: int = 3): 86 | """Evaluate the agent on several runs of the episodes and return the average reward. 87 | 88 | Args: 89 | nb_episodes (int, optional): number of episodes to average over. Defaults to 3. 90 | 91 | Raises: 92 | NotImplementedError: implemented by child class 93 | """ 94 | raise NotImplementedError 95 | 96 | def write_to_logs(self, step: int, start: float, envs: Union[List[JerichoEnv], VecEnv], max_score: int): 97 | """Write to loggers. 
98 |
99 | Args:
100 | step (int): current step
101 | start (float): time at start of training
102 | envs (Union[List[JerichoEnv], VecEnv]): collection of environments
103 | max_score (int): maximum score seen so far
104 |
105 | Raises:
106 | NotImplementedError: implemented by child class
107 | """
108 | raise NotImplementedError
109 |
-------------------------------------------------------------------------------- /utils/drrn.py: --------------------------------------------------------------------------------
1 | # Built-in imports
2 | from typing import Dict, Union, List
3 |
4 |
5 |
6 | # Libraries
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.distributions import Categorical
11 |
12 | # Custom imports
13 | from utils import util
14 | import utils.ngram as Ngram
15 | from utils.memory import State, StateWithActs
16 |
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 |
19 |
20 | def setup_env(self, envs):
21 | """
22 | Setup the environment.
23 | """
24 | obs, infos = envs.reset()
25 | if self.use_action_model:
26 | states = self.agent.build_states(
27 | obs, infos, ['reset'] * 8, [[]] * 8)
28 | else:
29 | states = self.agent.build_states(obs, infos)
30 | valid_ids = [self.agent.encode(info['valid']) for info in infos]
31 | transitions = [[] for info in infos]
32 |
33 | return obs, infos, states, valid_ids, transitions
34 |
35 |
36 | def act(model,
37 | states: List[Union[State, StateWithActs]],
38 | valid_ids,
39 | valid_strs,
40 | log,
41 | graph_masks=None):
42 | """
43 | Returns, for each env, the index of the chosen action and the action values,
44 | optionally sampling from the softmax distribution over Q-values.
45 | """
46 | act_sizes = [len(valid) for valid in valid_ids]
47 |
48 | if model.sample_uniform:
49 | q_values = tuple(torch.ones(len(valid_id)).to(device)
50 | for valid_id in valid_ids)
51 | else:
52 | with torch.no_grad():
53 | q_values = model.forward(states, valid_ids)
54 | if torch.any(torch.isnan(q_values[0])):
55 | log(
56 | f"Encountered nan!! 
State: {states[0]} Valid: {valid_ids[0]}") 57 | 58 | if model.use_action_model: 59 | act_values, betas = Ngram.action_model_forward( 60 | model, states, valid_strs, q_values, act_sizes) 61 | else: 62 | act_values = q_values 63 | 64 | if graph_masks is not None: 65 | log(f"Using graph mask: {graph_masks[0]}") 66 | log(f"Q values: {q_values[0]}") 67 | act_values = [qvals + mask for qvals, 68 | mask in zip(q_values, graph_masks)] 69 | log(f"Act values: {act_values[0]}") 70 | 71 | if model.sample_argmax: 72 | act_idxs = [torch.argmax(vals, dim=0) for vals in act_values] 73 | else: 74 | probs = [F.softmax(vals/model.T[i], dim=0) 75 | for i, vals in enumerate(act_values)] 76 | 77 | if model.use_action_model: 78 | for i, beta in enumerate(betas): 79 | if i == 0: 80 | model.log('Q dist {}'.format(probs[i])) 81 | if not beta: 82 | model.tb.logkv_mean( 83 | 'Q Entropy', Categorical(probs[i]).entropy()) 84 | model.tb.logkv_mean('Uniform entropy', Categorical( 85 | torch.ones_like(probs[i])/len(probs[i])).entropy()) 86 | else: 87 | model.log('Q dist {}'.format(probs[0])) 88 | model.tb.logkv_mean('Q Entropy', Categorical(probs[0]).entropy()) 89 | model.tb.logkv_mean('Uniform entropy', Categorical( 90 | torch.ones_like(probs[0])/len(probs[0])).entropy()) 91 | 92 | act_idxs = [ 93 | torch.multinomial(dist, num_samples=1).item() for dist in probs 94 | ] 95 | 96 | return act_idxs, act_values 97 | 98 | 99 | def init_model(model, args: Dict[str, Union[str, int, float]], vocab_size: int, tokenizer): 100 | model.use_drrn_inv_look = args.use_drrn_inv_look 101 | 102 | model.embedding = nn.Embedding(vocab_size, args.drrn_embedding_dim) 103 | model.drrn_hidden_dim = args.drrn_hidden_dim 104 | model.tokenizer = tokenizer 105 | 106 | model.obs_encoder = nn.GRU( 107 | args.drrn_embedding_dim, args.drrn_hidden_dim) 108 | model.look_encoder = nn.GRU( 109 | args.drrn_embedding_dim, args.drrn_hidden_dim) 110 | model.inv_encoder = nn.GRU( 111 | args.drrn_embedding_dim, args.drrn_hidden_dim) 112 | model.act_encoder = nn.GRU( 113 | args.drrn_embedding_dim, args.drrn_hidden_dim) 114 | 115 | if model.use_drrn_inv_look: 116 | model.hidden = nn.Linear( 117 | 4 * args.drrn_hidden_dim, args.drrn_hidden_dim) 118 | else: 119 | model.hidden = nn.Linear( 120 | 2 * args.drrn_hidden_dim, args.drrn_hidden_dim) 121 | 122 | model.act_scorer = nn.Linear(args.drrn_hidden_dim, 1) 123 | 124 | model.T = [args.T for _ in range(args.num_envs)] 125 | 126 | model.augment_state_with_score = args.augment_state_with_score 127 | model.hash_rep = args.hash_rep 128 | model.hash_cache = {} 129 | 130 | 131 | def packed_hash(self, x): 132 | y = [] 133 | for data in x: 134 | data = hash(tuple(data)) 135 | if data in self.hash_cache: 136 | y.append(self.hash_cache[data]) 137 | else: 138 | a = torch.zeros(self.drrn_hidden_dim).normal_( 139 | generator=torch.random.manual_seed(data)) 140 | # torch.random.seed() 141 | y.append(a) 142 | self.hash_cache[data] = a 143 | y = torch.stack(y, dim=0).to(device) 144 | return y 145 | 146 | 147 | def packed_rnn(model, x, rnn): 148 | """ Runs the provided rnn on the input x. Takes care of packing/unpacking. 
149 | 150 | x: list of unpadded input sequences 151 | Returns a tensor of size: len(x) x hidden_dim 152 | """ 153 | if model.hash_rep: 154 | return packed_hash(model, x) 155 | 156 | lengths = torch.tensor([len(n) for n in x], 157 | dtype=torch.long, 158 | device=device) 159 | 160 | # Sort this batch in descending order by seq length 161 | lengths, idx_sort = torch.sort(lengths, dim=0, descending=True) 162 | _, idx_unsort = torch.sort(idx_sort, dim=0) 163 | idx_sort = torch.autograd.Variable(idx_sort) 164 | idx_unsort = torch.autograd.Variable(idx_unsort) 165 | 166 | # Pads to longest action 167 | padded_x = util.pad_sequences(x) 168 | # print("padded x", padded_x) 169 | # print("padded x shape", padded_x.shape) 170 | x_tt = torch.from_numpy(padded_x).type(torch.long).to(device) 171 | x_tt = x_tt.index_select(0, idx_sort) 172 | 173 | # Run the embedding layer 174 | embed = model.embedding(x_tt).permute(1, 0, 2) # Time x Batch x EncDim 175 | 176 | # Pack padded batch of sequences for RNN module 177 | packed = nn.utils.rnn.pack_padded_sequence(embed, lengths.cpu()) 178 | 179 | # Run the RNN 180 | out, _ = rnn(packed) 181 | 182 | # Unpack 183 | out, _ = nn.utils.rnn.pad_packed_sequence(out) 184 | # print("out", out) 185 | # print("out shape", out.shape) 186 | 187 | # Get the last step of each sequence 188 | # print("lengths", lengths) 189 | # print("lengths view", (lengths-1).view(-1, 1)) 190 | # print("out size 2", out.size(2)) 191 | # print("lengths view expanded", (lengths-1).view(-1,1).expand(len(lengths), out.size(2)).unsqueeze(0)) 192 | idx = (lengths - 1).view(-1, 1).expand(len(lengths), 193 | out.size(2)).unsqueeze(0) 194 | out = out.gather(0, idx).squeeze(0) 195 | 196 | # Unsort 197 | out = out.index_select(0, idx_unsort) 198 | return out 199 | -------------------------------------------------------------------------------- /utils/env.py: -------------------------------------------------------------------------------- 1 | # Built-in imports 2 | 3 | # Libraries 4 | from jericho import * 5 | from jericho.util import * 6 | from jericho.defines import * 7 | 8 | # Custom imports 9 | 10 | class JerichoEnv: 11 | ''' Returns valid actions at each step of the game. 
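Look, inventory, and valid-action queries are cached by world-state hash, so revisiting a state avoids repeated handicap calls.

Minimal usage sketch (assuming an argparse namespace like the one built in scripts/train_rl.py):
env = JerichoEnv('./games/zork1.z5', step_limit=100, cache=dict(), args=args)
ob, info = env.reset()
ob, reward, done, info = env.step(info['valid'][0])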
''' 12 | 13 | def __init__(self, 14 | rom_path, 15 | step_limit=None, 16 | get_valid=True, 17 | cache=None, 18 | seed=None, 19 | start_from_reward=0, 20 | start_from_wt=0, 21 | log=None, 22 | args=None): 23 | self.rom_path = rom_path 24 | self.env = FrotzEnv(rom_path, seed=seed) 25 | self.bindings = self.env.bindings 26 | self.steps = 0 27 | self.step_limit = step_limit 28 | self.get_valid = get_valid 29 | self.max_score = 0 30 | self.end_scores = [] 31 | self.cache = cache 32 | self.traj = [] 33 | self.full_traj = [] 34 | self.on_trajectory = True 35 | self.start_from_reward = start_from_reward 36 | self.start_from_wt = start_from_wt 37 | 38 | self.log = log 39 | self.cache_hits = 0 40 | self.ngram_hits = 0 41 | self.ngram_needs_update = False 42 | self.filter_drop_acts = args.filter_drop_acts 43 | self.args = args 44 | 45 | def get_objects(self): 46 | desc2objs = self.env._identify_interactive_objects( 47 | use_object_tree=False) 48 | obj_set = set() 49 | for objs in desc2objs.values(): 50 | for obj, pos, source in objs: 51 | if pos == 'ADJ': 52 | continue 53 | obj_set.add(obj) 54 | return list(obj_set) 55 | 56 | def _get_state_hash(self, ob): 57 | return self.env.get_world_state_hash() 58 | 59 | def step(self, action): 60 | ob, reward, done, info = self.env.step(action) 61 | # return ob, reward, done, info 62 | 63 | # Initialize with default values 64 | info['look'] = 'unknown' 65 | info['inv'] = 'unknown' 66 | info['valid'] = ['wait', 'yes', 'no'] 67 | if not done: 68 | save = self.env.get_state() 69 | hash_save = self._get_state_hash(ob) 70 | if self.cache is not None and hash_save in self.cache: 71 | info['look'], info['inv'], info['valid'] = self.cache[ 72 | hash_save] 73 | self.cache_hits += 1 74 | else: 75 | look, _, _, _ = self.env.step('look') 76 | info['look'] = look.lower() 77 | self.env.set_state(save) 78 | inv, _, _, _ = self.env.step('inventory') 79 | info['inv'] = inv.lower() 80 | self.env.set_state(save) 81 | if self.get_valid: 82 | valid = self.env.get_valid_actions() 83 | if len(valid) == 0: 84 | valid = ['wait', 'yes', 'no'] 85 | info['valid'] = valid 86 | if self.cache is not None: 87 | self.cache[hash_save] = info['look'], info['inv'], info[ 88 | 'valid'] 89 | 90 | self.steps += 1 91 | if self.step_limit and self.steps >= self.step_limit: 92 | done = True 93 | self.max_score = max(self.max_score, info['score']) 94 | if done: 95 | self.end_scores.append(info['score']) 96 | return ob, reward, done, info 97 | 98 | def reset(self): 99 | initial_ob, info = self.env.reset() 100 | 101 | rewards_encountered = 0 102 | walkthrough = self.env.get_walkthrough() 103 | 104 | for act in walkthrough: 105 | if rewards_encountered >= self.start_from_reward: 106 | break 107 | initial_ob, reward, _, info = self.env.step(act) 108 | if reward > 0: 109 | rewards_encountered += 1 110 | 111 | for act in walkthrough[:self.start_from_wt]: 112 | initial_ob, reward, _, info = self.env.step(act) 113 | 114 | save = self.env.get_state() 115 | look, _, _, _ = self.env.step('look') 116 | info['look'] = look 117 | self.env.set_state(save) 118 | inv, _, _, _ = self.env.step('inventory') 119 | info['inv'] = inv 120 | self.env.set_state(save) 121 | valid = self.env.get_valid_actions() 122 | info['valid'] = valid 123 | self.steps = 0 124 | self.max_score = 0 125 | self.ngram_hits = 0 126 | self.traj = [] 127 | self.full_traj = [] 128 | self.on_trajectory = True 129 | return initial_ob, info 130 | 131 | def turn_off_trajectory(self): 132 | self.on_trajectory = False 133 | 134 | def 
get_trajectory_state(self): 135 | return self.on_trajectory 136 | 137 | def get_dictionary(self): 138 | if not self.env: 139 | self.create() 140 | return self.env.get_dictionary() 141 | 142 | def get_action_set(self): 143 | return None 144 | 145 | def get_end_scores(self, last=1): 146 | last = min(last, len(self.end_scores)) 147 | return sum(self.end_scores[-last:]) / last if last else 0 148 | 149 | def close(self): 150 | self.env.close() 151 | 152 | def get_walkthrough(self): 153 | return self.env.get_walkthrough() 154 | 155 | def get_score(self): 156 | return self.env.get_score() 157 | -------------------------------------------------------------------------------- /utils/il_buffer.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | import time 3 | from typing import Dict, Union, Callable, List 4 | 5 | # Libraries 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | import wandb 10 | 11 | # Custom Imports 12 | from agents import DrrnInvDynAgent 13 | 14 | import utils.logger as logger 15 | from utils.vec_env import VecEnv 16 | from utils.util import process_action, convert_idxs_to_strs 17 | 18 | 19 | class ILBuffer(): 20 | def __init__(self, agent: DrrnInvDynAgent, args, log, tb): 21 | self.agent = agent 22 | self.traj_collection = [] 23 | self.graph_score_temp = args.graph_score_temp 24 | self.graph_q_temp = args.graph_q_temp 25 | self.log = log 26 | self.tb = tb 27 | 28 | def add_traj(self, traj): 29 | self.traj_collection.append(traj) 30 | 31 | def bin_traj_by_scores(self): 32 | unknown_tok = self.agent.encode(['unknown'])[0] 33 | score_bins = dict() 34 | total = 0 35 | for traj in self.traj_collection: 36 | visited = set() 37 | for i in range(len(traj)): 38 | desc = traj[i].next_state.description 39 | inv = traj[i].next_state.inventory 40 | score = traj[i].next_state.score 41 | if desc == unknown_tok or inv == unknown_tok: 42 | self.log('Encountered unknown token!') 43 | 44 | if score in score_bins and score not in visited and desc != unknown_tok and inv != unknown_tok: 45 | score_bins[traj[i].next_state.score].append( 46 | traj[:i + 1]) 47 | visited.add(score) 48 | elif score not in score_bins and desc != unknown_tok and inv != unknown_tok: 49 | score_bins[traj[i].next_state.score] = [traj[:i + 1]] 50 | visited.add(score) 51 | 52 | total += len(visited) 53 | 54 | assert total == sum(len(trajs) for trajs in score_bins.values()) 55 | 56 | return score_bins 57 | 58 | def sample_trajs(self, k=5): 59 | if len(self.traj_collection) == 0: 60 | return [[]] 61 | 62 | start = time.time() 63 | 64 | score_bins = self.bin_traj_by_scores() 65 | 66 | max_score = max(score_bins.keys()) 67 | 68 | if len(score_bins) == 0: 69 | return [[]] 70 | 71 | scores = torch.tensor(sorted(score_bins.keys())).type(torch.float32) 72 | m_score = scores.mean() 73 | std_score = torch.std(scores) if not torch.allclose( 74 | scores - scores[0], torch.zeros_like(scores)) else 1 75 | norm_scores = (scores - m_score)/std_score 76 | score_probs = F.softmax(self.graph_score_temp * norm_scores) 77 | score_idxs = torch.multinomial( 78 | score_probs, num_samples=k, replacement=True) 79 | sampled_scores = [int(scores[score_idx.item()].item()) 80 | for score_idx in score_idxs] 81 | 82 | traj_states = [] 83 | traj_acts = [] 84 | traj_lens = [] 85 | traj_norm_lens = [] 86 | for sampled_score in sampled_scores: 87 | all_lens = [] 88 | for traj in score_bins[sampled_score]: 89 | all_lens.append(len(traj)) 90 | 91 | all_lens = 
torch.tensor(all_lens).type(torch.float32) 92 | m_len = all_lens.mean() 93 | std_m_len = torch.std(all_lens) if not torch.allclose( 94 | all_lens - all_lens[0], torch.zeros_like(all_lens)) else 1 95 | norm_lens = (all_lens - m_len)/std_m_len 96 | 97 | probs = F.softmax(self.graph_q_temp * (-1) * norm_lens, dim=0) 98 | 99 | traj_idx = torch.multinomial(probs, num_samples=1).item() 100 | sampled_traj = score_bins[sampled_score][traj_idx] 101 | 102 | traj_states.append( 103 | [sampled_traj[0].state] + [transition.next_state for transition in sampled_traj]) 104 | traj_acts.append([transition.act for transition in sampled_traj]) 105 | traj_lens.append(all_lens[traj_idx]) 106 | traj_norm_lens.append(norm_lens[traj_idx]) 107 | 108 | max_len = max(traj_lens) 109 | 110 | for traj_state, traj_act, traj_len, traj_norm_len in zip(traj_states, traj_acts, traj_lens, traj_norm_lens): 111 | last_state = traj_state[-1] 112 | obs = convert_idxs_to_strs( 113 | [last_state.obs[1:-1]], self.agent.tokenizer)[0] 114 | self.log("Returning to:") 115 | self.log( 116 | f"Location: {convert_idxs_to_strs([last_state.description[1:-1]], self.agent.tokenizer)[0]}, \ 117 | Inventory: {convert_idxs_to_strs([last_state.inventory[1:-1]], self.agent.tokenizer)[0]}, \ 118 | Observation: {obs}, \ 119 | Score: {last_state.score}, Len: {traj_len}, NormLen: {traj_norm_len} \ 120 | Path: {convert_idxs_to_strs(list(map(lambda x: x[1:-1], traj_act)), self.agent.tokenizer)}" 121 | ) 122 | 123 | # assert "you have died" not in obs 124 | 125 | self.tb.logkv_mean("ILTrainDataMean", last_state.score) 126 | self.tb.logkv_mean( 127 | "Hit@Max", last_state.score == max_score) 128 | self.tb.logkv_mean("AvgILTrainDataLen", traj_len) 129 | self.tb.logkv_mean("MaxILTrainDataLen", max_len) 130 | 131 | end = time.time() 132 | 133 | self.tb.logkv_mean("SampleTrajTime", end - start) 134 | self.tb.logkv("TotalNumTrajectories", len(self.traj_collection)) 135 | 136 | return traj_states, traj_acts 137 | -------------------------------------------------------------------------------- /utils/inv_dyn.py: -------------------------------------------------------------------------------- 1 | # Library imports 2 | import itertools 3 | from typing import List 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | # Custom imports 10 | from utils import util 11 | from utils.memory import StateWithActs, State 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | 16 | def init(model, args, vocab_size): 17 | model.embedding = nn.Embedding(vocab_size, args.drrn_embedding_dim) 18 | model.obs_encoder = nn.GRU(args.drrn_embedding_dim, args.drrn_hidden_dim) 19 | model.look_encoder = nn.GRU(args.drrn_embedding_dim, args.drrn_hidden_dim) 20 | model.inv_encoder = nn.GRU(args.drrn_embedding_dim, args.drrn_hidden_dim) 21 | model.act_encoder = nn.GRU(args.drrn_embedding_dim, args.drrn_hidden_dim) 22 | model.act_scorer = nn.Linear(args.drrn_hidden_dim, 1) 23 | 24 | model.hidden_dim = args.drrn_hidden_dim 25 | model.hidden = nn.Linear(2 * args.drrn_hidden_dim, args.drrn_hidden_dim) 26 | # model.hidden = nn.Sequential(nn.Linear(2 * hidden_dim, 2 * hidden_dim), nn.Linear(2 * hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim)) 27 | 28 | model.state_encoder = nn.Linear( 29 | 3 * args.drrn_hidden_dim, args.drrn_hidden_dim) 30 | model.inverse_dynamics = nn.Sequential(nn.Linear( 31 | 2 * args.drrn_hidden_dim, 2 * args.drrn_hidden_dim), nn.ReLU(), nn.Linear(2 * args.drrn_hidden_dim, args.drrn_hidden_dim)) 
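# The inverse-dynamics head above maps (s, s' - s) to an action representation; the forward-dynamics head below maps (s, a) to a residual next-state representation.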
32 | model.forward_dynamics = nn.Sequential(nn.Linear(
33 | 2 * args.drrn_hidden_dim, 2 * args.drrn_hidden_dim), nn.ReLU(), nn.Linear(2 * args.drrn_hidden_dim, args.drrn_hidden_dim))
34 |
35 | model.act_decoder = nn.GRU(args.drrn_hidden_dim, args.drrn_embedding_dim)
36 | model.act_fc = nn.Linear(args.drrn_embedding_dim, vocab_size)
37 |
38 | model.obs_decoder = nn.GRU(args.drrn_hidden_dim, args.drrn_embedding_dim)
39 | model.obs_fc = nn.Linear(args.drrn_embedding_dim, vocab_size)
40 |
41 | model.fix_rep = args.fix_rep
42 | model.hash_rep = args.hash_rep
43 | model.act_obs = args.act_obs
44 | model.hash_cache = {}
45 |
46 |
47 | def packed_hash(self, x):
48 | # Map each token sequence to a deterministic random vector keyed by its hash
49 | # (cached); mirrors packed_hash in utils/drrn.py and is used when hash_rep is on.
50 | y = []
51 | for data in x:
52 | data = hash(tuple(data))
53 | if data in self.hash_cache:
54 | y.append(self.hash_cache[data])
55 | else:
56 | a = torch.zeros(self.hidden_dim).normal_(
57 | generator=torch.random.manual_seed(data))
58 | # torch.random.seed()
59 | y.append(a)
60 | self.hash_cache[data] = a
61 | y = torch.stack(y, dim=0).to(device)
62 | return y
63 |
64 |
65 | def packed_rnn(model, x, rnn):
66 | """ Runs the provided rnn on the input x. Takes care of packing/unpacking.
67 |
68 | x: list of unpadded input sequences
69 | Returns a tensor of size: len(x) x hidden_dim
70 | """
71 | if model.hash_rep:
72 | return packed_hash(model, x)
73 | lengths = torch.tensor([len(n) for n in x],
74 | dtype=torch.long,
75 | device=device)
76 |
77 | # Sort this batch in descending order by seq length
78 | lengths, idx_sort = torch.sort(lengths, dim=0, descending=True)
79 | _, idx_unsort = torch.sort(idx_sort, dim=0)
80 | idx_sort = torch.autograd.Variable(idx_sort)
81 | idx_unsort = torch.autograd.Variable(idx_unsort)
82 |
83 | # Pads to longest action
84 | padded_x = util.pad_sequences(x)
85 | # print("padded x", padded_x)
86 | # print("padded x shape", padded_x.shape)
87 | x_tt = torch.from_numpy(padded_x).type(torch.long).to(device)
88 | x_tt = x_tt.index_select(0, idx_sort)
89 |
90 | # Run the embedding layer
91 | embed = model.embedding(x_tt).permute(1, 0, 2) # Time x Batch x EncDim
92 |
93 | # Pack padded batch of sequences for RNN module
94 | packed = nn.utils.rnn.pack_padded_sequence(embed, lengths.cpu())
95 |
96 | # Run the RNN
97 | out, _ = rnn(packed)
98 |
99 | # Unpack
100 | out, _ = nn.utils.rnn.pad_packed_sequence(out)
101 | # print("out", out)
102 | # print("out shape", out.shape)
103 |
104 | # Get the last step of each sequence
105 | # print("lengths", lengths)
106 | # print("lengths view", (lengths-1).view(-1, 1))
107 | # print("out size 2", out.size(2))
108 | # print("lengths view expanded", (lengths-1).view(-1,1).expand(len(lengths), out.size(2)).unsqueeze(0))
109 | idx = (lengths - 1).view(-1, 1).expand(len(lengths),
110 | out.size(2)).unsqueeze(0)
111 | out = out.gather(0, idx).squeeze(0)
112 |
113 | # Unsort
114 | out = out.index_select(0, idx_unsort)
115 | return out
116 |
117 |
118 | def state_rep(model, state_batch: List[StateWithActs]):
119 | # Zip the state_batch into an easy access format
120 | class_name = util.get_class_name(model).lower()
121 |
122 | if 'drrn' in class_name and model.use_action_model:
123 | state = StateWithActs(*zip(*state_batch))
124 | elif 'drrn' in class_name and not model.use_action_model:
125 | state = State(*zip(*state_batch))
126 |
127 | # Encode the various aspects of the state
128 | with torch.set_grad_enabled(not model.fix_rep):
129 | obs_out = packed_rnn(model, state.obs, model.obs_encoder)
130 | if model.act_obs:
131 | return obs_out
132 | 
look_out = packed_rnn(model, state.description, model.look_encoder) 133 | inv_out = packed_rnn(model, state.inventory, model.inv_encoder) 134 | if model.augment_state_with_score: 135 | scores = torch.tensor(state.score).unsqueeze(1).to(device) 136 | state_out = model.state_encoder( 137 | torch.cat((obs_out, look_out, inv_out, scores), dim=1)) 138 | else: 139 | state_out = model.state_encoder( 140 | torch.cat((obs_out, look_out, inv_out), dim=1)) 141 | return state_out 142 | 143 | 144 | def act_rep(model, act_batch): 145 | # This is number of admissible commands in each element of the batch 146 | act_sizes = [len(a) for a in act_batch] 147 | # Combine next actions into one long list 148 | act_batch = list(itertools.chain.from_iterable(act_batch)) 149 | with torch.set_grad_enabled(not model.fix_rep): 150 | act_out = packed_rnn(model, act_batch, model.act_encoder) 151 | return act_sizes, act_out 152 | 153 | 154 | def for_predict(model, state_batch, acts): 155 | _, act_out = act_rep(model, acts) 156 | state_out = state_rep(model, state_batch) 157 | next_state_out = state_out + \ 158 | model.forward_dynamics(torch.cat((state_out, act_out), dim=1)) 159 | return next_state_out 160 | 161 | 162 | def inv_predict(model, state_batch, next_state_batch): 163 | state_out = state_rep(model, state_batch) 164 | next_state_out = state_rep(model, next_state_batch) 165 | act_out = model.inverse_dynamics( 166 | torch.cat((state_out, next_state_out - state_out), dim=1)) 167 | return act_out 168 | 169 | 170 | def inv_loss_l1(model, state_batch, next_state_batch, acts): 171 | _, act_out = act_rep(model, acts) 172 | act_out_hat = inv_predict(model, state_batch, next_state_batch) 173 | return F.l1_loss(act_out, act_out_hat) 174 | 175 | 176 | def inv_loss_l2(model, state_batch, next_state_batch, acts): 177 | _, act_out = act_rep(model, acts) 178 | act_out_hat = inv_predict(model, state_batch, next_state_batch) 179 | return F.mse_loss(act_out, act_out_hat) 180 | 181 | 182 | def inv_loss_ce(model, state_batch, next_state_batch, acts, valids, get_predict=False): 183 | act_sizes, valids_out = act_rep(model, valids) 184 | _, act_out = act_rep(model, acts) 185 | act_out_hat = inv_predict(model, state_batch, next_state_batch) 186 | now, loss, acc = 0, 0, 0 187 | if get_predict: 188 | predicts = [] 189 | for i, j in enumerate(act_sizes): 190 | valid_out = valids_out[now: now + j] 191 | now += j 192 | values = valid_out.matmul(act_out_hat[i]) 193 | label = valids[i].index(acts[i][0]) 194 | loss += F.cross_entropy(values.unsqueeze(0), 195 | torch.LongTensor([label]).to(device)) 196 | predict = values.argmax().item() 197 | acc += predict == label 198 | if get_predict: 199 | predicts.append(predict) 200 | return (loss / len(act_sizes), acc / len(act_sizes), predicts) if get_predict else (loss / len(act_sizes), acc / len(act_sizes)) 201 | 202 | 203 | def inv_loss_decode(model, state_batch, next_state_batch, acts, hat=True, reduction='mean'): 204 | # hat: use rep(o), rep(o'); not hat: use rep(a) 205 | _, act_out = act_rep(model, acts) 206 | act_out_hat = inv_predict(model, state_batch, next_state_batch) 207 | 208 | acts_pad = util.pad_sequences([act[0] for act in acts]) 209 | acts_tensor = torch.from_numpy(acts_pad).type( 210 | torch.long).to(device).transpose(0, 1) 211 | l, bs = acts_tensor.size() 212 | vocab = model.embedding.num_embeddings 213 | outputs = torch.zeros(l, bs, vocab).to(device) 214 | input, z = acts_tensor[0].unsqueeze( 215 | 0), (act_out_hat if hat else act_out).unsqueeze(0) 216 | for t in range(1, l): 217 | 
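        # Greedy decode: embed the previous token, step the GRU decoder, project to
        # the vocabulary, then feed the argmax token back in as the next input.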
input = model.embedding(input) 218 | output, z = model.act_decoder(input, z) 219 | output = model.act_fc(output) 220 | outputs[t] = output 221 | top = output.argmax(2) 222 | input = top 223 | outputs, acts_tensor = outputs[1:], acts_tensor[1:] 224 | loss = F.cross_entropy(outputs.reshape(-1, vocab), 225 | acts_tensor.reshape(-1), ignore_index=0, reduction=reduction) 226 | if reduction == 'none': # loss for each term in batch 227 | lens = [len(act[0]) - 1 for act in acts] 228 | loss = loss.reshape(-1, bs).sum(0).cpu() / torch.tensor(lens) 229 | nonzero = (acts_tensor > 0) 230 | same = (outputs.argmax(-1) == acts_tensor) 231 | acc_token = (same & nonzero).float().sum() / \ 232 | (nonzero).float().sum() # token accuracy 233 | acc_action = (same.int().sum(0) == nonzero.int().sum( 234 | 0)).float().sum() / same.size(1) # action accuracy 235 | return loss, acc_action 236 | 237 | 238 | def for_loss_l2(model, state_batch, next_state_batch, acts): 239 | next_state_out = state_rep(model, next_state_batch) 240 | next_state_out_hat = for_predict(model, state_batch, acts) 241 | return F.mse_loss(next_state_out, next_state_out_hat) # , reduction='sum') 242 | 243 | 244 | def for_loss_ce_batch(model, state_batch, next_state_batch, acts): 245 | # consider duplicates in next_state_batch 246 | next_states, labels = [], [] 247 | for next_state in next_state_batch: 248 | if next_state not in next_states: 249 | labels.append(len(next_states)) 250 | next_states.append(next_state) 251 | else: 252 | labels.append(next_states.index(next_state)) 253 | labels = torch.LongTensor(labels).to(device) 254 | next_state_out = state_rep(model, next_states) 255 | next_state_out_hat = for_predict(model, state_batch, acts) 256 | logits = next_state_out_hat.matmul(next_state_out.transpose(0, 1)) 257 | loss = F.cross_entropy(logits, labels) 258 | acc = (logits.argmax(1) == labels).float().sum() / len(labels) 259 | return loss, acc 260 | 261 | 262 | def for_loss_ce(model, state_batch, next_state_batch, acts, valids): 263 | # classify rep(o') from predict(o, a1), predict(o, a2), ... 
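    # For each state, score every valid action by the dot product between its
    # predicted next-state representation and the observed rep(o'), then apply a
    # cross-entropy loss with the action actually taken as the label.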
264 | act_sizes, valids_out = act_rep(model, valids) 265 | _, act_out = act_rep(model, acts) 266 | next_state_out = state_rep(model, next_state_batch) 267 | now, loss, acc = 0, 0, 0 268 | for i, j in enumerate(act_sizes): 269 | valid_out = valids_out[now: now + j] 270 | now += j 271 | next_states_out_hat = for_predict( 272 | model, [state_batch[i]] * j, [[_] for _ in valids[i]]) 273 | values = next_states_out_hat.matmul(next_state_out[i]) 274 | label = valids[i].index(acts[i][0]) 275 | loss += F.cross_entropy(values.unsqueeze(0), 276 | torch.LongTensor([label]).to(device)) 277 | predict = values.argmax().item() 278 | acc += predict == label 279 | return (loss / len(act_sizes), acc / len(act_sizes)) 280 | 281 | 282 | def for_loss_decode(model, state_batch, next_state_batch, acts, hat=True): 283 | # hat: use rep(o), rep(a); not hat: use rep(o') 284 | next_state_out = state_rep(model, next_state_batch) 285 | next_state_out_hat = for_predict(model, state_batch, acts) 286 | 287 | import pdb 288 | pdb.set_trace() 289 | next_state_pad = util.pad_sequences(next_state_batch) 290 | next_state_tensor = torch.from_numpy(next_state_batch).type( 291 | torch.long).to(device).transpose(0, 1) 292 | l, bs = next_state_tensor.size() 293 | vocab = model.embedding.num_embeddings 294 | outputs = torch.zeros(l, bs, vocab).to(device) 295 | input, z = next_state_tensor[0].unsqueeze( 296 | 0), (next_state_out_hat if hat else next_state_out).unsqueeze(0) 297 | for t in range(1, l): 298 | input = model.embedding(input) 299 | output, z = model.obs_decoder(input, z) 300 | output = model.obs_fc(output) 301 | outputs[t] = output 302 | top = output.argmax(2) 303 | input = top 304 | outputs, next_state_tensor = outputs[1:].reshape( 305 | -1, vocab), next_state_tensor[1:].reshape(-1) 306 | loss = F.cross_entropy(outputs, next_state_tensor, ignore_index=0) 307 | nonzero = (next_state_tensor > 0) 308 | same = (outputs.argmax(1) == next_state_tensor) 309 | acc = (same & nonzero).float().sum() / \ 310 | (nonzero).float().sum() # token accuracy 311 | return loss, acc 312 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | import os 3 | import sys 4 | import shutil 5 | import os.path as osp 6 | import json 7 | import time 8 | import datetime 9 | import tempfile 10 | from collections import defaultdict 11 | 12 | # Libraries 13 | import wandb 14 | 15 | DEBUG = 10 16 | INFO = 20 17 | WARN = 30 18 | ERROR = 40 19 | 20 | DISABLED = 50 21 | 22 | 23 | class KVWriter(object): 24 | def writekvs(self, kvs): 25 | raise NotImplementedError 26 | 27 | 28 | class SeqWriter(object): 29 | def writeseq(self, seq): 30 | raise NotImplementedError 31 | 32 | 33 | class HumanOutputFormat(KVWriter, SeqWriter): 34 | def __init__(self, filename_or_file): 35 | if isinstance(filename_or_file, str): 36 | self.file = open(filename_or_file, 'wt') 37 | self.own_file = True 38 | else: 39 | assert hasattr( 40 | filename_or_file, 'read'), 'expected file or str, got %s' % filename_or_file 41 | self.file = filename_or_file 42 | self.own_file = False 43 | 44 | def writekvs(self, kvs): 45 | # Create strings for printing 46 | key2str = {} 47 | for (key, val) in sorted(kvs.items()): 48 | if isinstance(val, float): 49 | valstr = '%-8.3g' % (val,) 50 | else: 51 | valstr = str(val) 52 | key2str[self._truncate(key)] = self._truncate(valstr) 53 | 54 | # Find max widths 55 | if len(key2str) == 0: 56 | print('WARNING: 
tried to write empty key-value dict') 57 | return 58 | else: 59 | keywidth = max(map(len, key2str.keys())) 60 | valwidth = max(map(len, key2str.values())) 61 | 62 | # Write out the data 63 | dashes = '-' * (keywidth + valwidth + 7) 64 | lines = [dashes] 65 | for (key, val) in sorted(key2str.items()): 66 | lines.append('| %s%s | %s%s |' % ( 67 | key, 68 | ' ' * (keywidth - len(key)), 69 | val, 70 | ' ' * (valwidth - len(val)), 71 | )) 72 | lines.append(dashes) 73 | self.file.write('\n'.join(lines) + '\n') 74 | 75 | # Flush the output to the file 76 | self.file.flush() 77 | 78 | def _truncate(self, s): 79 | return s[:20] + '...' if len(s) > 23 else s 80 | 81 | def writeseq(self, seq): 82 | seq = list(seq) 83 | for (i, elem) in enumerate(seq): 84 | self.file.write(elem) 85 | if i < len(seq) - 1: # add space unless this is the last one 86 | self.file.write(' ') 87 | self.file.write('\n') 88 | self.file.flush() 89 | 90 | def close(self): 91 | if self.own_file: 92 | self.file.close() 93 | 94 | 95 | class JSONOutputFormat(KVWriter): 96 | def __init__(self, filename): 97 | self.file = open(filename, 'wt') 98 | 99 | def writekvs(self, kvs): 100 | for k, v in sorted(kvs.items()): 101 | if hasattr(v, 'dtype'): 102 | v = v.tolist() 103 | kvs[k] = float(v) 104 | self.file.write(json.dumps(kvs) + '\n') 105 | self.file.flush() 106 | 107 | def close(self): 108 | self.file.close() 109 | 110 | 111 | class WandBOutputFormat(KVWriter): 112 | def __init__(self, filename, args): 113 | wandb.init(project=f'{args.project_name}', 114 | name=filename.split('/')[-1], anonymous="allow") 115 | wandb.config.update(args) 116 | wandb.run.tags = wandb.run.tags + (args.jericho_add_wt,) 117 | 118 | def writekvs(self, kvs): 119 | wandb.log(kvs) 120 | 121 | def close(self): 122 | pass 123 | 124 | 125 | class CSVOutputFormat(KVWriter): 126 | def __init__(self, filename): 127 | self.file = open(filename, 'w+t') 128 | self.keys = [] 129 | self.sep = ',' 130 | 131 | def writekvs(self, kvs): 132 | # Add our current row to the history 133 | extra_keys = kvs.keys() - self.keys 134 | if extra_keys: 135 | self.keys.extend(extra_keys) 136 | self.file.seek(0) 137 | lines = self.file.readlines() 138 | self.file.seek(0) 139 | for (i, k) in enumerate(self.keys): 140 | if i > 0: 141 | self.file.write(',') 142 | self.file.write(k) 143 | self.file.write('\n') 144 | for line in lines[1:]: 145 | self.file.write(line[:-1]) 146 | self.file.write(self.sep * len(extra_keys)) 147 | self.file.write('\n') 148 | for (i, k) in enumerate(self.keys): 149 | if i > 0: 150 | self.file.write(',') 151 | v = kvs.get(k) 152 | if v is not None: 153 | self.file.write(str(v)) 154 | self.file.write('\n') 155 | self.file.flush() 156 | 157 | def close(self): 158 | self.file.close() 159 | 160 | 161 | class TensorBoardOutputFormat(KVWriter): 162 | """ 163 | Dumps key/value pairs into TensorBoard's numeric format. 
164 | """ 165 | 166 | def __init__(self, dir): 167 | os.makedirs(dir, exist_ok=True) 168 | self.dir = dir 169 | self.step = 1 170 | prefix = 'events' 171 | path = osp.join(osp.abspath(dir), prefix) 172 | import tensorflow as tf 173 | from tensorflow.python import pywrap_tensorflow 174 | from tensorflow.core.util import event_pb2 175 | from tensorflow.python.util import compat 176 | self.tf = tf 177 | self.event_pb2 = event_pb2 178 | self.pywrap_tensorflow = pywrap_tensorflow 179 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 180 | 181 | def writekvs(self, kvs): 182 | def summary_val(k, v): 183 | kwargs = {'tag': k, 'simple_value': float(v)} 184 | return self.tf.Summary.Value(**kwargs) 185 | 186 | summary = self.tf.Summary( 187 | value=[summary_val(k, v) for k, v in kvs.items()]) 188 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 189 | event.step = self.step # is there any reason why you'd want to specify the step? 190 | self.writer.WriteEvent(event) 191 | self.writer.Flush() 192 | self.step += 1 193 | 194 | def close(self): 195 | if self.writer: 196 | self.writer.Close() 197 | self.writer = None 198 | 199 | 200 | def make_output_format(format, ev_dir, log_suffix='', args=None): 201 | os.makedirs(ev_dir, exist_ok=True) 202 | if format == 'stdout': 203 | return HumanOutputFormat(sys.stdout) 204 | elif format == 'log': 205 | return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix)) 206 | elif format == 'json': 207 | return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix)) 208 | elif format == 'csv': 209 | return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix)) 210 | elif format == 'tensorboard': 211 | return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix)) 212 | elif format == 'wandb': 213 | return WandBOutputFormat(ev_dir, args) 214 | else: 215 | raise ValueError('Unknown format specified: %s' % (format,)) 216 | 217 | 218 | # ================================================================ 219 | # API 220 | # ================================================================ 221 | 222 | def logkv(key, val): 223 | """ 224 | Log a value of some diagnostic 225 | Call this once for each diagnostic quantity, each iteration 226 | If called many times, last value will be used. 227 | """ 228 | Logger.CURRENT.logkv(key, val) 229 | 230 | 231 | def logkv_mean(key, val): 232 | """ 233 | The same as logkv(), but if called many times, values averaged. 234 | """ 235 | Logger.CURRENT.logkv_mean(key, val) 236 | 237 | 238 | def logkv_sum(key, val): 239 | """ 240 | """ 241 | Logger.CURRENT.logkv_sum(key, val) 242 | 243 | 244 | def logkvs(d): 245 | """ 246 | Log a dictionary of key-value pairs 247 | """ 248 | for (k, v) in d.items(): 249 | logkv(k, v) 250 | 251 | 252 | def dumpkvs(): 253 | """ 254 | Write all of the diagnostics from the current iteration 255 | 256 | level: int. (see logger.py docs) If the global logger level is higher than 257 | the level argument here, don't print to stdout. 258 | """ 259 | Logger.CURRENT.dumpkvs() 260 | 261 | 262 | def getkvs(): 263 | return Logger.CURRENT.name2val 264 | 265 | 266 | def log(*args, level=INFO, condition=True): 267 | """ 268 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 
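    Illustrative example:
        log("Logging to", "/tmp/run")   # emitted through every configured SeqWriter (e.g. stdout)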
269 | """ 270 | Logger.CURRENT.log(*args, level=level, condition=condition) 271 | 272 | 273 | def debug(*args): 274 | log(*args, level=DEBUG) 275 | 276 | 277 | def info(*args): 278 | log(*args, level=INFO) 279 | 280 | 281 | def warn(*args): 282 | log(*args, level=WARN) 283 | 284 | 285 | def error(*args): 286 | log(*args, level=ERROR) 287 | 288 | 289 | def set_level(level): 290 | """ 291 | Set logging threshold on current logger. 292 | """ 293 | Logger.CURRENT.set_level(level) 294 | 295 | 296 | def get_dir(): 297 | """ 298 | Get directory that log files are being written to. 299 | will be None if there is no output directory (i.e., if you didn't call start) 300 | """ 301 | return Logger.CURRENT.get_dir() 302 | 303 | 304 | record_tabular = logkv 305 | dump_tabular = dumpkvs 306 | 307 | 308 | class ProfileKV: 309 | """ 310 | Usage: 311 | with logger.ProfileKV("interesting_scope"): 312 | code 313 | """ 314 | 315 | def __init__(self, n): 316 | self.n = "wait_" + n 317 | 318 | def __enter__(self): 319 | self.t1 = time.time() 320 | 321 | def __exit__(self, type, value, traceback): 322 | Logger.CURRENT.name2val[self.n] += time.time() - self.t1 323 | 324 | 325 | def profile(n): 326 | """ 327 | Usage: 328 | @profile("my_func") 329 | def my_func(): code 330 | """ 331 | 332 | def decorator_with_name(func): 333 | def func_wrapper(*args, **kwargs): 334 | with ProfileKV(n): 335 | return func(*args, **kwargs) 336 | 337 | return func_wrapper 338 | 339 | return decorator_with_name 340 | 341 | 342 | # ================================================================ 343 | # Backend 344 | # ================================================================ 345 | 346 | class Logger(object): 347 | # A logger with no output files. (See right below class definition) 348 | DEFAULT = None 349 | # So that you can still log to the terminal without setting up any output files 350 | CURRENT = None # Current logger being used by the free functions above 351 | 352 | def __init__(self, dir, output_formats, off=False): 353 | self.name2val = defaultdict(float) # values this iteration 354 | self.name2cnt = defaultdict(int) 355 | self.level = INFO 356 | self.dir = dir 357 | self.output_formats = output_formats 358 | self.off = off 359 | 360 | # Logging API, forwarded 361 | # ---------------------------------------- 362 | def logkv(self, key, val): 363 | self.name2val[key] = val 364 | 365 | def logkv_mean(self, key, val): 366 | if val is None: 367 | self.name2val[key] = None 368 | return 369 | oldval, cnt = self.name2val[key], self.name2cnt[key] 370 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 371 | self.name2cnt[key] = cnt + 1 372 | 373 | def logkv_sum(self, key, val): 374 | self.name2val[key] += val 375 | 376 | def dumpkvs(self): 377 | if self.level == DISABLED: 378 | return 379 | for fmt in self.output_formats: 380 | if isinstance(fmt, KVWriter): 381 | fmt.writekvs(self.name2val) 382 | self.name2val.clear() 383 | self.name2cnt.clear() 384 | 385 | def log(self, *args, level=INFO, condition=True): 386 | if self.level <= level and condition and not self.off: 387 | self._do_log(args) 388 | 389 | # Configuration 390 | # ---------------------------------------- 391 | def set_level(self, level): 392 | self.level = level 393 | 394 | def get_dir(self): 395 | return self.dir 396 | 397 | def close(self): 398 | for fmt in self.output_formats: 399 | fmt.close() 400 | 401 | # Misc 402 | # ---------------------------------------- 403 | def _do_log(self, args): 404 | for fmt in self.output_formats: 405 | if 
isinstance(fmt, SeqWriter): 406 | fmt.writeseq(map(str, args)) 407 | 408 | 409 | def configure(dir=None, format_strs=None, off=False): 410 | if dir is None: 411 | dir = os.getenv('OPENAI_LOGDIR') 412 | if dir is None: 413 | dir = osp.join(tempfile.gettempdir(), 414 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")) 415 | assert isinstance(dir, str) 416 | os.makedirs(dir, exist_ok=True) 417 | 418 | log_suffix = '' 419 | rank = 0 420 | # check environment variables here instead of importing mpi4py 421 | # to avoid calling MPI_Init() when this module is imported 422 | for varname in ['PMI_RANK', 'OMPI_COMM_WORLD_RANK']: 423 | if varname in os.environ: 424 | rank = int(os.environ[varname]) 425 | if rank > 0: 426 | log_suffix = "-rank%03i" % rank 427 | 428 | if format_strs is None: 429 | if rank == 0: 430 | format_strs = os.getenv( 431 | 'OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',') 432 | else: 433 | format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',') 434 | format_strs = filter(None, format_strs) 435 | output_formats = [make_output_format( 436 | f, dir, log_suffix) for f in format_strs] 437 | print("output_formats", output_formats) 438 | 439 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, off=off) 440 | log('Logging to %s' % dir) 441 | 442 | 443 | def _configure_default_logger(): 444 | format_strs = None 445 | # keep the old default of only writing to stdout 446 | if 'OPENAI_LOG_FORMAT' not in os.environ: 447 | format_strs = ['stdout'] 448 | configure(format_strs=format_strs) 449 | Logger.DEFAULT = Logger.CURRENT 450 | 451 | 452 | def reset(): 453 | if Logger.CURRENT is not Logger.DEFAULT: 454 | Logger.CURRENT.close() 455 | Logger.CURRENT = Logger.DEFAULT 456 | log('Reset logger') 457 | 458 | 459 | class scoped_configure(object): 460 | def __init__(self, dir=None, format_strs=None): 461 | self.dir = dir 462 | self.format_strs = format_strs 463 | self.prevlogger = None 464 | 465 | def __enter__(self): 466 | self.prevlogger = Logger.CURRENT 467 | configure(dir=self.dir, format_strs=self.format_strs) 468 | 469 | def __exit__(self, *args): 470 | Logger.CURRENT.close() 471 | Logger.CURRENT = self.prevlogger 472 | 473 | 474 | # ================================================================ 475 | 476 | def _demo(): 477 | info("hi") 478 | debug("shouldn't appear") 479 | set_level(DEBUG) 480 | debug("should appear") 481 | dir = "/tmp/testlogging" 482 | if os.path.exists(dir): 483 | shutil.rmtree(dir) 484 | configure(dir=dir) 485 | logkv("a", 3) 486 | logkv("b", 2.5) 487 | dumpkvs() 488 | logkv("b", -2.5) 489 | logkv("a", 5.5) 490 | dumpkvs() 491 | info("^^^ should see a = 5.5") 492 | logkv_mean("b", -22.5) 493 | logkv_mean("b", -44.4) 494 | logkv("a", 5.5) 495 | dumpkvs() 496 | info("^^^ should see b = 33.3") 497 | 498 | logkv("b", -2.5) 499 | dumpkvs() 500 | 501 | logkv("a", "longasslongasslongasslongasslongasslongassvalue") 502 | dumpkvs() 503 | 504 | 505 | # ================================================================ 506 | # Readers 507 | # ================================================================ 508 | 509 | def read_json(fname): 510 | import pandas 511 | ds = [] 512 | with open(fname, 'rt') as fh: 513 | for line in fh: 514 | ds.append(json.loads(line)) 515 | return pandas.DataFrame(ds) 516 | 517 | 518 | def read_csv(fname): 519 | import pandas 520 | return pandas.read_csv(fname, index_col=None, comment='#') 521 | 522 | 523 | def read_tb(path): 524 | """ 525 | path : a tensorboard file OR a directory, where we will find all TB files 
526 | of the form events.* 527 | """ 528 | import pandas 529 | import numpy as np 530 | from glob import glob 531 | from collections import defaultdict 532 | import tensorflow as tf 533 | if osp.isdir(path): 534 | fnames = glob(osp.join(path, "events.*")) 535 | elif osp.basename(path).startswith("events."): 536 | fnames = [path] 537 | else: 538 | raise NotImplementedError( 539 | "Expected tensorboard file or directory containing them. Got %s" % path) 540 | tag2pairs = defaultdict(list) 541 | maxstep = 0 542 | for fname in fnames: 543 | for summary in tf.train.summary_iterator(fname): 544 | if summary.step > 0: 545 | for v in summary.summary.value: 546 | pair = (summary.step, v.simple_value) 547 | tag2pairs[v.tag].append(pair) 548 | maxstep = max(summary.step, maxstep) 549 | data = np.empty((maxstep, len(tag2pairs))) 550 | data[:] = np.nan 551 | tags = sorted(tag2pairs.keys()) 552 | for (colidx, tag) in enumerate(tags): 553 | pairs = tag2pairs[tag] 554 | for (step, value) in pairs: 555 | data[step - 1, colidx] = value 556 | return pandas.DataFrame(data, columns=tags) 557 | 558 | 559 | # configure the default logger on import 560 | # _configure_default_logger() 561 | 562 | if __name__ == "__main__": 563 | _demo() 564 | -------------------------------------------------------------------------------- /utils/memory.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import random 3 | 4 | State = namedtuple('State', ('obs', 'description', 'inventory', 'score'), defaults=[-1]) 5 | StateWithActs = namedtuple('StateWithActs', 6 | ('obs', 'description', 'inventory', 'acts', 'score'), defaults=[-1]) 7 | Transition = namedtuple( 8 | 'Transition', 9 | ('state', 'act', 'reward', 'next_state', 'next_acts', 'done')) 10 | 11 | 12 | class ReplayMemory(object): 13 | def __init__(self, capacity): 14 | self.capacity = capacity 15 | self.memory = [] 16 | self.position = 0 17 | 18 | def push(self, *args): 19 | if len(self.memory) < self.capacity: 20 | self.memory.append(None) 21 | self.memory[self.position] = Transition(*args) 22 | self.position = (self.position + 1) % self.capacity 23 | 24 | def sample(self, batch_size): 25 | return random.sample(self.memory, 26 | batch_size) # samples without replacement 27 | 28 | def __len__(self): 29 | return len(self.memory) 30 | 31 | 32 | class PrioritizedReplayMemory(object): 33 | def __init__(self, capacity=100000, priority_fraction=0.0): 34 | self.priority_fraction = priority_fraction 35 | self.alpha_capacity = int(capacity * priority_fraction) 36 | self.beta_capacity = capacity - self.alpha_capacity 37 | self.alpha_memory, self.beta_memory = [], [] 38 | self.alpha_position, self.beta_position = 0, 0 39 | 40 | def clear_alpha(self): 41 | """ 42 | """ 43 | self.alpha_memory = [] 44 | self.alpha_position = 0 45 | 46 | def push(self, transition, is_prior=False): 47 | """Saves a transition.""" 48 | if self.priority_fraction == 0.0: 49 | is_prior = False 50 | if is_prior: 51 | if len(self.alpha_memory) < self.alpha_capacity: 52 | self.alpha_memory.append(None) 53 | self.alpha_memory[self.alpha_position] = transition 54 | self.alpha_position = (self.alpha_position + 55 | 1) % self.alpha_capacity 56 | else: 57 | if len(self.beta_memory) < self.beta_capacity: 58 | self.beta_memory.append(None) 59 | self.beta_memory[self.beta_position] = transition 60 | self.beta_position = (self.beta_position + 1) % self.beta_capacity 61 | 62 | def sample(self, batch_size): 63 | if self.priority_fraction == 0.0: 64 | 
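            # With priority_fraction == 0 every transition is drawn uniformly from
            # the beta (non-prioritized) buffer.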
from_beta = min(batch_size, len(self.beta_memory)) 65 | res = random.sample(self.beta_memory, from_beta) 66 | else: 67 | from_alpha = min(int(self.priority_fraction * batch_size), 68 | len(self.alpha_memory)) 69 | from_beta = min( 70 | batch_size - int(self.priority_fraction * batch_size), 71 | len(self.beta_memory)) 72 | res = random.sample(self.alpha_memory, from_alpha) + random.sample( 73 | self.beta_memory, from_beta) 74 | random.shuffle(res) 75 | return res 76 | 77 | def __len__(self): 78 | return len(self.alpha_memory) + len(self.beta_memory) 79 | 80 | 81 | class ABReplayMemory(object): 82 | def __init__(self, capacity, priority_fraction): 83 | self.priority_fraction = priority_fraction 84 | self.alpha_capacity = int(capacity * priority_fraction) 85 | self.beta_capacity = capacity - self.alpha_capacity 86 | self.alpha_memory, self.beta_memory = [], [] 87 | self.alpha_position, self.beta_position = 0, 0 88 | 89 | def clear_alpha(self): 90 | self.alpha_memory = [] 91 | self.alpha_position = 0 92 | 93 | def push(self, transition, is_prior=False): 94 | """Saves a transition.""" 95 | if self.priority_fraction == 0.0: 96 | is_prior = False 97 | if is_prior: 98 | if len(self.alpha_memory) < self.alpha_capacity: 99 | self.alpha_memory.append(None) 100 | self.alpha_memory[self.alpha_position] = transition 101 | self.alpha_position = ( 102 | self.alpha_position + 1) % self.alpha_capacity 103 | else: 104 | if len(self.beta_memory) < self.beta_capacity: 105 | self.beta_memory.append(None) 106 | self.beta_memory[self.beta_position] = transition 107 | self.beta_position = (self.beta_position + 1) % self.beta_capacity 108 | 109 | def sample(self, batch_size): 110 | if self.priority_fraction == 0.0: 111 | from_beta = min(batch_size, len(self.beta_memory)) 112 | res = random.sample(self.beta_memory, from_beta) 113 | else: 114 | from_alpha = min(int(self.priority_fraction * 115 | batch_size), len(self.alpha_memory)) 116 | from_beta = min( 117 | batch_size - int(self.priority_fraction * batch_size), len(self.beta_memory)) 118 | res = random.sample(self.alpha_memory, from_alpha) + \ 119 | random.sample(self.beta_memory, from_beta) 120 | random.shuffle(res) 121 | return res 122 | 123 | def __len__(self): 124 | return len(self.alpha_memory) + len(self.beta_memory) 125 | -------------------------------------------------------------------------------- /utils/ngram.py: -------------------------------------------------------------------------------- 1 | # Built-in imports 2 | import time 3 | from typing import Dict, Union, List 4 | import heapq as pq 5 | import statistics as stats 6 | 7 | # Libraries 8 | import numpy as np 9 | 10 | from transformers import AdamW, TopPLogitsWarper 11 | 12 | import torch 13 | 14 | from torch.utils.data import DataLoader 15 | 16 | from jericho.util import clean 17 | 18 | from tqdm import tqdm 19 | 20 | # Custom imports 21 | import definitions.defs as defs 22 | 23 | from utils.memory import StateWithActs 24 | from utils.util import process_action, get_class_name, flatten_2d, pad_sequences 25 | from utils.vec_env import VecEnv 26 | 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | 29 | 30 | def init_model(model, action_models, args): 31 | model.action_models = action_models 32 | model.sample_action_argmax = args.sample_action_argmax 33 | model.il_max_context = args.il_max_context 34 | model.il_vocab_size = args.il_vocab_size 35 | model.max_acts = args.max_acts 36 | model.action_model_type = args.action_model_type 37 | model.score_thresholds = [0.0 for 
_ in range(len(model.envs))] 38 | model.traj_lens = [0.0 for _ in range(len(model.envs))] 39 | model.cut_beta_at_threshold = args.cut_beta_at_threshold 40 | model.turn_action_model_off_after_falling = args.turn_action_model_off_after_falling 41 | model.traj_dropout_prob = args.traj_dropout_prob 42 | model.use_action_model = args.use_action_model 43 | model.num_bins = args.num_bins 44 | model.init_bin_prob = args.init_bin_prob 45 | if model.num_bins > 0: 46 | model.binning_probs = np.array( 47 | [model.init_bin_prob for _ in range(model.num_bins)]) 48 | model.bins = np.array( 49 | [i/model.num_bins for i in range(model.num_bins + 1)]) 50 | 51 | # IL stuff 52 | model.token_remap_cache = dict() 53 | model.il_temp = args.il_temp 54 | model.use_il = args.use_il 55 | model.old_betas = torch.tensor([0. for i in range(len(model.envs))]) 56 | model.il_top_p = args.il_top_p 57 | model.il_use_dropout = args.il_use_dropout 58 | model.il_use_only_dropout = args.il_use_only_dropout 59 | 60 | model.top_p_warper = TopPLogitsWarper( 61 | top_p=model.il_top_p, filter_value=-1e9) 62 | 63 | 64 | def remap_token_idxs(model, token_idxs): 65 | result = [] 66 | for token_idx in token_idxs: 67 | if token_idx in model.token_remap_cache: 68 | result.append(model.token_remap_cache[token_idx]) 69 | else: 70 | cur = max(model.token_remap_cache.values()) if len( 71 | model.token_remap_cache) > 0 else 0 72 | if cur + 1 >= model.il_vocab_size: 73 | model.log("Vocab size exceeded during remapping!") 74 | assert cur + 1 < model.il_vocab_size, "Vocab size exceeded during remapping!" 75 | result.append(cur + 1) 76 | model.token_remap_cache[token_idx] = cur + 1 77 | 78 | return result 79 | 80 | 81 | def init_trainer(trainer, args): 82 | trainer.top_k_traj = [] 83 | trainer.traj_to_train = [] 84 | trainer.traj_k = args.traj_k 85 | trainer.last_ngram_update = 0 86 | trainer.cut_beta_at_threshold = args.cut_beta_at_threshold 87 | 88 | trainer.action_model_update_freq = args.action_model_update_freq 89 | trainer.action_model_scale_factor = args.action_model_update_freq / 100 90 | trainer.action_model_type = args.action_model_type 91 | trainer.use_multi_ngram = args.use_multi_ngram 92 | trainer.random_action_dropout = args.random_action_dropout 93 | trainer.max_acts = args.max_acts 94 | trainer.tf_num_epochs = args.tf_num_epochs 95 | 96 | trainer.init_bin_prob = args.init_bin_prob 97 | trainer.binning_prob_update_freq = args.binning_prob_update_freq 98 | trainer.num_bins = args.num_bins 99 | 100 | trainer.episode_ext_type = args.episode_ext_type 101 | 102 | # IL stuff 103 | trainer.il_batch_size = args.il_batch_size 104 | trainer.il_k = args.il_k 105 | trainer.il_lr = args.il_lr 106 | trainer.il_max_num_epochs = args.il_max_num_epochs 107 | trainer.il_num_eval_runs = args.il_num_eval_runs 108 | trainer.il_max_context = args.il_max_context 109 | trainer.il_eval_freq = args.il_eval_freq 110 | trainer.il_len_scale = args.il_len_scale 111 | 112 | 113 | def build_traj_state(agent, prev_act: str, traj_acts: List[str]): 114 | """Update the trajectory state. 
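    Illustrative example: with agent.max_acts == 2 and traj_acts == ['north', 'open door'],
    prev_act == 'e' yields ['open door', 'east']; prev_act == 'reset' clears the history.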
115 | Args: 116 | agent ([type]): the agent 117 | prev_act (str): the last act take by the agent 118 | traj_acts (List[str]): the past actions taken by agent 119 | """ 120 | if agent.max_acts == 0 or prev_act == 'reset': 121 | acts = [] 122 | elif len(traj_acts) == agent.max_acts: 123 | acts = traj_acts[1:] + [process_action(prev_act)] 124 | else: 125 | acts = traj_acts + [process_action(prev_act)] 126 | 127 | return acts 128 | 129 | 130 | def push_to_traj_mem(trainer, next_info: Dict[str, Union[List[str], str, int]], traj: List[str]): 131 | """Push current trajectory to the trajectory heap. 132 | Args: 133 | trainer (Trainer): trainer to act on 134 | next_info (Dict[str, Union[str, int]]): contains valid actions etc. 135 | env ([type]): [description] 136 | """ 137 | pq.heappush( 138 | trainer.top_k_traj, 139 | (next_info['score'], -1 * len(traj), traj.copy())) 140 | 141 | # pop if we have too many 142 | if len(trainer.top_k_traj) > trainer.traj_k: 143 | pq.heappop(trainer.top_k_traj) 144 | 145 | 146 | def get_bin_prob(model): 147 | """ 148 | """ 149 | x = [min(env.steps/model.traj_len, 1.0) for env in model.envs] 150 | idxs = np.digitize(x, model.bins, right=True) - 1 151 | model.log('x {}'.format(x)) 152 | model.log('bins {}'.format(model.bins)) 153 | model.log('indices {}'.format(idxs)) 154 | 155 | return [torch.bernoulli(torch.tensor(1. - model.binning_probs[idx])).item() for idx in idxs] 156 | 157 | 158 | def get_beta_from_lm_vals(model, vals): 159 | """ 160 | Compute beta parameter for each environment. 161 | Set beta = 0 if score threshold is reached, otherwise 1. 162 | """ 163 | assert model.use_il, 'Agent needs to use IL when using LM vals.' 164 | 165 | current_steps = model.envs.get_current_steps() 166 | current_scores = model.envs.get_current_scores() 167 | # determine whether to use dropout 168 | if model.il_use_dropout: 169 | traj_dropout_probs = [compute_dropout_prob( 170 | model, traj_len) for traj_len in model.traj_lens] 171 | model.log("Adjusted dropout prob: {}, env length {}".format( 172 | traj_dropout_probs, model.traj_lens)) 173 | dropout_results = [torch.bernoulli(torch.tensor( 174 | 1. - traj_dropout_probs[i])).item() for i in range(len(vals))] 175 | model.log('Traj dropout results: {}'.format(dropout_results)) 176 | 177 | betas = torch.tensor([ 178 | 1. if current_scores[i] < model.score_thresholds[i] and 179 | current_steps[i] < model.traj_lens[i] and 180 | dropout_results[i] 181 | else 0. for i, val in enumerate(vals) 182 | ], device=device) 183 | elif model.il_use_only_dropout: 184 | traj_dropout_probs = [ 185 | model.traj_dropout_prob for traj_len in model.traj_lens] 186 | model.log("Hard dropout prob: {}, env length {}".format( 187 | traj_dropout_probs, model.traj_lens)) 188 | dropout_results = [torch.bernoulli(torch.tensor( 189 | 1. - traj_dropout_probs[i])).item() for i in range(len(vals))] 190 | model.log('Traj dropout results: {}'.format(dropout_results)) 191 | 192 | betas = torch.tensor([ 193 | 1. if dropout_results[i] 194 | else 0. for i, val in enumerate(vals) 195 | ], device=device) 196 | else: 197 | betas = torch.tensor([ 198 | 1. if current_scores[i] < model.score_thresholds[i] and 199 | current_steps[i] < model.traj_lens[i] 200 | else 0. 
for i, val in enumerate(vals) 201 | ], device=device) 202 | 203 | model.old_betas = betas 204 | 205 | return betas 206 | 207 | 208 | def compute_dropout_prob(model, traj_len): 209 | """ 210 | """ 211 | return model.traj_dropout_prob * 100 * (1/traj_len) if traj_len > 0 else 1 212 | 213 | 214 | def action_model_forward(model, states, valid_strs, q_values, act_sizes, il_eval=False): 215 | """ 216 | """ 217 | # ** action model forward ** 218 | class_name = get_class_name(model).lower() 219 | if 'drrn' in class_name: 220 | past_acts = StateWithActs(*zip(*states)).acts 221 | 222 | # TODO: finish transformer case 223 | if model.action_model_type == defs.TRANSFORMER: 224 | cls_id = model.tokenizer.convert_tokens_to_ids(['[CLS]'])[0] 225 | sep_id = model.tokenizer.convert_tokens_to_ids(['[SEP]'])[0] 226 | 227 | input_ids = [] 228 | act_masks = [] 229 | att_masks = [] 230 | 231 | for state, acts, valid in zip(states, past_acts, valid_strs): 232 | context_acts = acts[-model.max_acts:] 233 | act_history = [model.tokenizer.encode( 234 | act) for act in context_acts] if len(acts) > 0 else [] 235 | context = [cls_id] + \ 236 | flatten_2d([act[1:-1] + [sep_id] 237 | for act in act_history]) + state.obs[1:-1] + state.description[1:-1] + state.inventory[1:-1] + [sep_id] 238 | for valid_act in valid: 239 | to_predict = model.tokenizer.encode(valid_act)[1:] 240 | ids = context + to_predict 241 | act_mask = np.zeros(len(ids)) 242 | act_mask[len(context):] = 1 243 | att_mask = np.ones(len(ids)) 244 | 245 | if len(ids) > model.il_max_context: 246 | input_ids.append(ids[-model.il_max_context:]) 247 | act_masks.append(act_mask[-model.il_max_context:]) 248 | att_masks.append(att_mask[-model.il_max_context:]) 249 | else: 250 | input_ids.append(ids) 251 | act_masks.append(act_mask) 252 | att_masks.append(att_mask) 253 | 254 | # remap 255 | input_ids = [remap_token_idxs(model, seq) for seq in input_ids] 256 | 257 | pad_len = max([len(ids) for ids in input_ids]) 258 | 259 | input_ids = torch.tensor(pad_sequences( 260 | input_ids, pad_len), dtype=torch.long).to(device) 261 | act_masks = torch.tensor(pad_sequences( 262 | act_masks, pad_len), dtype=torch.long) 263 | att_masks = torch.tensor(pad_sequences( 264 | att_masks, pad_len), dtype=torch.long).to(device) 265 | 266 | # Get lm values 267 | lm_values = [] 268 | model.action_models.eval() 269 | with torch.no_grad(): 270 | predictions = model.action_models( 271 | input_ids, attention_mask=att_masks)[0] 272 | for prediction, ids, act_mask in zip(predictions, input_ids, act_masks): 273 | prediction = prediction[np.argmax(act_mask)-1:-1] 274 | log_p = torch.nn.functional.log_softmax(prediction, dim=-1) 275 | score = 1./act_mask.sum() * log_p[range(act_mask.sum()), 276 | ids[act_mask == 1]].sum().item() 277 | lm_values.append(score) 278 | 279 | lm_values = (model.il_temp * torch.tensor(lm_values, 280 | device=device)).split(act_sizes) 281 | lm_values = [model.top_p_warper(None, scores.unsqueeze(0))[ 282 | 0] for scores in lm_values] 283 | 284 | model.log("ngram values: {}".format( 285 | list(map(lambda x: x.item(), lm_values[0])))) 286 | 287 | # Assert shape match 288 | assert len(q_values) == len(lm_values) 289 | for q_val, lm_val in zip(q_values, lm_values): 290 | assert len(q_val) == len(lm_val) 291 | 292 | beta_vec = get_beta_from_lm_vals(model, lm_values) 293 | 294 | # Turn off ngram if it has fallen off 295 | if model.turn_action_model_off_after_falling: 296 | on_traj = [] 297 | for i in range(len(model.envs)): 298 | if beta_vec[i] == 0: 299 | 
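                # Beta has dropped to 0 for this env: mark its trajectory as off so the
                # on_traj mask below keeps the action model disabled for it.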
model.envs.turn_off_trajectory(i) 300 | 301 | on_traj.append(int(model.envs.get_trajectory_state(i))) 302 | 303 | old_beta_vec = beta_vec.clone() 304 | model.log('Old beta {}'.format(old_beta_vec)) 305 | on_traj = torch.tensor(on_traj, device=device) 306 | model.log('On traj {}'.format(on_traj)) 307 | 308 | beta_vec *= on_traj 309 | 310 | model.tb.logkv_sum('Turned off action model count', 311 | (old_beta_vec - beta_vec).sum()) 312 | 313 | # log beta & whether trajectory was fully recovered 314 | model.log('beta {}'.format(beta_vec)) 315 | 316 | # update ngram hits 317 | if isinstance(model.envs, VecEnv): 318 | model.envs.update_ngram_hits(beta_vec.cpu().numpy()) 319 | else: 320 | for i in range(len(model.envs)): 321 | model.envs[i].ngram_hits += beta_vec[i] 322 | 323 | act_values = [ 324 | q_value * (1 - beta_vec[i]) + bert_value * beta_vec[i] 325 | for i, (q_value, bert_value) in enumerate(zip(q_values, lm_values)) 326 | ] 327 | 328 | return act_values, beta_vec 329 | 330 | 331 | def update_score_threshold(trainer, i: int): 332 | """Update score and length thresholds 333 | Args: 334 | trainer (Trainer): the trainer 335 | i (int): environment index 336 | """ 337 | # NOTE: this assumes 1 trajectory for now 338 | if len(trainer.traj_to_train) > 0: 339 | traj = trainer.traj_to_train[0] 340 | score = traj[0] 341 | traj_len = -1 * traj[1] 342 | 343 | trainer.agent.network.score_thresholds[i] = score 344 | trainer.agent.network.traj_lens[i] = traj_len 345 | 346 | 347 | def _build_traj_states_and_acts(trainer, trajs): 348 | # find first state 349 | first_state_id = trainer.agent.graphs[0].get_first_state_id() 350 | new_trajs = [] 351 | for i in range(len(trajs)): 352 | new_trajs.append([(None, first_state_id)] + trajs[i]) 353 | 354 | traj_states = [] 355 | traj_acts = [] 356 | for traj in new_trajs: 357 | traj_states.append([]) 358 | traj_acts.append([]) 359 | for t in traj: 360 | action, state_id = t 361 | if state_id is not None: 362 | node = trainer.agent.graphs[0][state_id] 363 | info = dict() 364 | info["look"] = node["loc"] 365 | info["inv"] = node["inv"] 366 | info["score"] = node["score"] 367 | state = trainer.agent.build_state(node["obs"], info) 368 | traj_states[-1].append(state) 369 | 370 | if action is not None: 371 | traj_acts[-1].append(trainer.agent.encode([action])[0]) 372 | 373 | return traj_states, traj_acts 374 | 375 | 376 | def _build_tf_input_elements(trainer, traj_states, traj_acts): 377 | input_ids = [] 378 | act_masks = [] 379 | att_masks = [] 380 | for states, acts in zip(traj_states, traj_acts): 381 | for i in range(len(acts)): 382 | act_history = acts[max(i - trainer.max_acts, 0): i] 383 | context = [101] + \ 384 | flatten_2d([act[1:-1] + [102] for act in act_history]) + \ 385 | states[i].obs[1: -1] + states[i].description[1: -1] + \ 386 | states[i].inventory[1: -1] + [102] 387 | to_predict = acts[i][1:] 388 | ids = context + to_predict 389 | act_mask = np.zeros(len(ids)) 390 | act_mask[len(context):] = 1 391 | att_mask = np.ones(len(ids)) 392 | 393 | if len(ids) > trainer.il_max_context: 394 | trainer.tb.logkv_sum('ExceededContext', 1) 395 | input_ids.append(ids[-trainer.il_max_context:]) 396 | act_masks.append(act_mask[-trainer.il_max_context:]) 397 | att_masks.append(att_mask[-trainer.il_max_context:]) 398 | else: 399 | trainer.tb.logkv_sum('ExceededContext', 0) 400 | input_ids.append(ids) 401 | act_masks.append(act_mask) 402 | att_masks.append(att_mask) 403 | 404 | return input_ids, act_masks, att_masks 405 | 406 | def my_collate(batch): 407 | input_ids = [el[0] 
for el in batch] 408 | act_masks = [el[1] for el in batch] 409 | att_masks = [el[2] for el in batch] 410 | 411 | pad_len = max([len(ids) for ids in input_ids]) 412 | 413 | input_ids = torch.tensor(pad_sequences( 414 | input_ids, pad_len), dtype=torch.long) 415 | act_masks = torch.tensor(pad_sequences( 416 | act_masks, pad_len), dtype=torch.long) 417 | att_masks = torch.tensor(pad_sequences( 418 | att_masks, pad_len), dtype=torch.long) 419 | 420 | return (input_ids, act_masks, att_masks) 421 | 422 | 423 | def update_il_threshold(trainer, score_threshold, len_threshold): 424 | trainer.agent.network.score_thresholds = [ 425 | score_threshold for _ in range(len(trainer.envs))] 426 | trainer.agent.network.traj_lens = [ 427 | len_threshold for _ in range(len(trainer.envs)) 428 | ] 429 | 430 | for i in range(len(trainer.envs)): 431 | trainer.envs.set_env_limit( 432 | len_threshold + trainer.graph_num_explore_steps, i) 433 | 434 | trainer.log(f'Setting score threshold: {score_threshold}') 435 | trainer.log(f'Setting len threshold: {len_threshold}') 436 | trainer.log( 437 | f'Setting env limit: {len_threshold + trainer.graph_num_explore_steps}') 438 | 439 | trainer.tb.logkv('ScoreThreshold', score_threshold) 440 | trainer.tb.logkv('LenThreshold', len_threshold) 441 | 442 | 443 | def end_step(trainer, step: int): 444 | """ 445 | Update action model if necessary. 446 | """ 447 | if (trainer.steps - trainer.last_ngram_update) == trainer.action_model_update_freq: 448 | trainer.last_ngram_update = trainer.steps 449 | if trainer.action_model_type == defs.TRANSFORMER: 450 | # get trajectories 451 | if trainer.use_il_graph_sampler: 452 | trainer.il_trajs = trainer.agent.graphs[0].get_graph_policy( 453 | k=trainer.il_k) 454 | 455 | state_trajs, act_trajs = _build_traj_states_and_acts( 456 | trainer, trainer.il_trajs) 457 | elif trainer.use_il_buffer_sampler: 458 | state_trajs, act_trajs = trainer.agent.il_buffer.sample_trajs( 459 | trainer.il_k) 460 | 461 | max_score = max( 462 | [state_traj[-1].score for state_traj in state_trajs]) 463 | max_traj_len = max([len(act_traj) for act_traj in act_trajs]) 464 | 465 | # update IL threshold + env len limit 466 | update_il_threshold(trainer, max_score, 467 | trainer.il_len_scale * max_traj_len) 468 | 469 | input_ids, act_masks, att_masks = _build_tf_input_elements( 470 | trainer, state_trajs, act_trajs) 471 | 472 | # remap 473 | input_ids = [remap_token_idxs( 474 | trainer.agent.network, seq) for seq in input_ids] 475 | 476 | X = [(ids, act_mask, att_mask) for ids, act_mask, 477 | att_mask in zip(input_ids, act_masks, att_masks)] 478 | data = DataLoader(X, batch_size=trainer.il_batch_size, 479 | shuffle=True, collate_fn=my_collate) 480 | 481 | lm = trainer.agent.action_models 482 | lm.train() 483 | 484 | # Train! 
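            # Fine-tune the transformer action model on the sampled trajectories:
            # context tokens are masked out of the loss with label -100, gradients are
            # clipped to norm 1.0, and AdamW steps once per batch.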
485 | optimizer = AdamW( 486 | lm.parameters(), lr=trainer.il_lr, eps=1e-8) 487 | start = time.time() 488 | for i in tqdm(range(trainer.il_max_num_epochs)): 489 | for batch in data: 490 | b_input_ids, b_act_masks, b_att_masks = batch 491 | b_input_ids = b_input_ids.to(device) 492 | b_act_masks = b_act_masks.to(device) 493 | b_att_masks = b_att_masks.to(device) 494 | 495 | b_labels = b_input_ids.clone() 496 | b_labels[b_act_masks == 0] = -100 497 | 498 | outputs = lm( 499 | b_input_ids, attention_mask=b_att_masks, labels=b_labels) 500 | 501 | loss = outputs[0] 502 | loss.backward() 503 | 504 | torch.nn.utils.clip_grad_norm_(lm.parameters(), 1.0) 505 | optimizer.step() 506 | lm.zero_grad() 507 | 508 | end = time.time() 509 | 510 | trainer.tb.logkv("IL-Loss", loss) 511 | trainer.tb.logkv("IL-TimeToFit", end - start) 512 | 513 | trainer.action_model_update_freq = ( 514 | max_traj_len + trainer.graph_num_explore_steps) * trainer.action_model_scale_factor 515 | 516 | trainer.tb.logkv('Action Model Update Freq', 517 | trainer.action_model_update_freq) 518 | 519 | else: 520 | raise Exception("Unrecognized action model!") 521 | 522 | 523 | def log_metrics(trainer): 524 | # Log min/max/median scores 525 | scores = [] 526 | lengths = [] 527 | for traj in trainer.traj_to_train: 528 | scores.append(traj[0]) 529 | lengths.append(traj[1] * -1) 530 | 531 | scores.sort() 532 | lengths.sort() 533 | 534 | trainer.tb.logkv('Max Traj Score', 535 | scores[-1] if len(scores) > 0 else 0) 536 | trainer.tb.logkv('Min Traj Score', 537 | scores[0] if len(scores) > 0 else 0) 538 | trainer.tb.logkv('Median Traj Score', stats.median( 539 | scores) if len(scores) > 0 else 0) 540 | 541 | # Log min/max/median length of saved trajectory 542 | trainer.tb.logkv('Max Traj Length', 543 | lengths[-1] if len(lengths) > 0 else 0) 544 | trainer.tb.logkv('Min Traj Length', 545 | lengths[0] if len(lengths) > 0 else 0) 546 | trainer.tb.logkv('Median Traj Length', stats.median( 547 | lengths) if len(lengths) > 0 else 0) 548 | 549 | # Log the score threshold 550 | trainer.tb.logkv('Score Threshold', 551 | max(trainer.agent.network.score_thresholds)) 552 | # Log the current episode length 553 | if isinstance(trainer.envs, VecEnv): 554 | current_limit = trainer.envs.get_env_limit() 555 | else: 556 | current_limit = trainer.envs[0].step_limit 557 | trainer.tb.logkv('Episode Length', current_limit) 558 | 559 | 560 | def log_recovery_metrics(trainer, i: int): 561 | # Log recovery metrics 562 | if isinstance(trainer.envs, VecEnv): 563 | ngram_hits = trainer.envs.get_ngram_hits(i) 564 | else: 565 | ngram_hits = trainer.envs[i].ngram_hits 566 | 567 | traj_lens = trainer.agent.network.traj_lens 568 | trainer.tb.logkv_mean( 569 | 'Traj Fully Recovered', 1 if (ngram_hits == traj_lens[i] and traj_lens[i] > 0) else 0) 570 | trainer.tb.logkv_mean('Traj Part Recovered', ngram_hits / 571 | traj_lens[i] if traj_lens[i] > 0 else 0) 572 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | # Built-in Imports 2 | import itertools 3 | import os 4 | import re 5 | from typing import Any, List 6 | from collections import deque 7 | 8 | # Libraries 9 | import numpy as np 10 | import pickle 11 | import torch 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | zork1_two_obj_acts = [ 16 | 'apply( \w+)+ to( \w+)+', 17 | 'tie( \w+)+ to( \w+)+', 18 | 'tie up( \w+)+ with( \w+)+', 19 | 'hit( \w+)+ with( 
\w+)+', 20 | 'break( \w+)+ with( \w+)+', 21 | 'break down( \w+)+ with( \w+)+', 22 | 'blow up( \w+)+ with( \w+)+', 23 | 'wave( \w+)+ at( \w+)+', 24 | 'clean( \w+)+ with( \w+)+', 25 | 'burn( \w+)+ with( \w+)+', 26 | 'burn down( \w+)+ with( \w+)+', 27 | 'take( \w+)+ from( \w+)+', 28 | 'take( \w+)+ off( \w+)+', 29 | 'take( \w+)+ out( \w+)+', 30 | 'throw( \w+)+( \w+)+', 31 | 'throw( \w+)+ at( \w+)+', 32 | 'throw( \w+)+ in( \w+)+', 33 | 'throw( \w+)+ off( \w+)+', 34 | 'throw( \w+)+ on( \w+)+', 35 | 'throw( \w+)+ over( \w+)+', 36 | 'throw( \w+)+ with( \w+)+', 37 | 'cut( \w+)+ with( \w+)+', 38 | 'dig( \w+)+ with( \w+)+', 39 | 'dig in( \w+)+ with( \w+)+', 40 | 'give( \w+)+( \w+)+', 41 | 'give( \w+)+ to( \w+)+', 42 | 'drop( \w+)+ down( \w+)+', 43 | 'put( \w+)+ on( \w+)+', 44 | 'put( \w+)+ in( \w+)+', 45 | 'touch( \w+)+ with( \w+)+', 46 | 'fill( \w+)+ with( \w+)+', 47 | 'plug( \w+)+ with( \w+)+', 48 | 'turn( \w+)+ for( \w+)+', 49 | 'turn( \w+)+ to( \w+)+', 50 | 'turn( \w+)+ with( \w+)+', 51 | 'turn on( \w+)+ with( \w+)+', 52 | 'untie( \w+)+ from( \w+)+', 53 | 'look at( \w+)+ with( \w+)+', 54 | 'oil( \w+)+ with( \w+)+', 55 | 'put( \w+)+ behind( \w+)+', 56 | 'put( \w+)+ under( \w+)+', 57 | 'inflat( \w+)+ with( \w+)+', 58 | 'is( \w+)+ in( \w+)+', 59 | 'is( \w+)+ on( \w+)+', 60 | 'light( \w+)+ with( \w+)+', 61 | 'melt( \w+)+ with( \w+)+', 62 | 'lock( \w+)+ with( \w+)+', 63 | 'push( \w+)+ with( \w+)+', 64 | 'open( \w+)+ with( \w+)+', 65 | 'ring( \w+)+ with( \w+)+', 66 | 'pick( \w+)+ with( \w+)+', 67 | 'poke( \w+)+ with( \w+)+', 68 | 'pour( \w+)+ from( \w+)+', 69 | 'pour( \w+)+ in( \w+)+', 70 | 'pour( \w+)+ on( \w+)+', 71 | 'push( \w+)+( \w+)+', 72 | 'push( \w+)+ to( \w+)+', 73 | 'push( \w+)+ under( \w+)+', 74 | 'pump up( \w+)+ with( \w+)+', 75 | 'read( \w+)+( \w+)+', 76 | 'read( \w+)+ with( \w+)+', 77 | 'spray( \w+)+ on( \w+)+', 78 | 'spray( \w+)+ with( \w+)+', 79 | 'squeez( \w+)+ on( \w+)+', 80 | 'strike( \w+)+ with( \w+)+', 81 | 'swing( \w+)+ at( \w+)+', 82 | 'unlock( \w+)+ with( \w+)+' 83 | ] 84 | 85 | ZORK1_TWO_OBJ_REGEX = [re.compile(regexp) for regexp in zork1_two_obj_acts] 86 | 87 | 88 | def obj_bfs(start, env): 89 | nodes = [] 90 | visited = set() 91 | q = deque() 92 | visited.add(start.num) 93 | q.append(start) 94 | 95 | while len(q) > 0: 96 | node = q.popleft() 97 | nodes.append(node) 98 | if node.child: 99 | visited.add(node.child) 100 | q.append(env.get_object(node.child)) 101 | if node.sibling: 102 | visited.add(node.sibling) 103 | q.append(env.get_object(node.sibling)) 104 | 105 | return nodes 106 | 107 | 108 | def extract_inventory(env): 109 | first_inv_num = env.get_player_object().child 110 | if not first_inv_num: 111 | return [] 112 | 113 | first_inv_node = env.get_object(first_inv_num) 114 | 115 | return obj_bfs(first_inv_node, env) 116 | 117 | # def extract_surrounding(env): 118 | # first_sur_num = env.get_player_location().child 119 | # if not first 120 | 121 | 122 | def filter_two_obj_acts(acts, game: str): 123 | if game == 'zork1': 124 | return list(filter(lambda x: any( 125 | p.match(x) is not None for p in ZORK1_TWO_OBJ_REGEX), acts)) 126 | else: 127 | raise NotImplementedError(f'Not implemented for game: {game}') 128 | 129 | 130 | class RunningMeanStd(object): 131 | # https://github.com/jcwleo/random-network-distillation-pytorch/blob/e383fb95177c50bfdcd81b43e37c443c8cde1d94/utils.py#L44 132 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 133 | def __init__(self, epsilon=1e-9, shape=()): 134 | self.mean = torch.zeros(shape, 
dtype=torch.float32).to(device) 135 | self.var = torch.ones(shape, dtype=torch.float32).to(device) 136 | self.count = epsilon 137 | 138 | def update(self, x): 139 | batch_mean = torch.mean(x, dim=0) 140 | batch_var = torch.var(x, dim=0) 141 | batch_count = x.shape[0] 142 | self.update_from_moments(batch_mean, batch_var, batch_count) 143 | 144 | def update_from_moments(self, batch_mean, batch_var, batch_count): 145 | delta = batch_mean - self.mean 146 | tot_count = self.count + batch_count 147 | 148 | new_mean = self.mean + delta * batch_count / tot_count 149 | m_a = self.var * (self.count) 150 | m_b = batch_var * (batch_count) 151 | M2 = m_a + m_b + \ 152 | torch.square(delta) * self.count * batch_count / \ 153 | (self.count + batch_count) 154 | new_var = M2 / (self.count + batch_count) 155 | 156 | new_count = batch_count + self.count 157 | 158 | self.mean = new_mean 159 | self.var = new_var 160 | self.count = new_count 161 | 162 | 163 | def get_class_name(obj: Any) -> str: 164 | """Get class name of object. 165 | 166 | Args: 167 | obj (Any): the object. 168 | 169 | Returns: 170 | [str]: the class name 171 | """ 172 | return type(obj).__name__ 173 | 174 | 175 | def get_name_from_path(path: str): 176 | """ 177 | Given a path string, extract the game name from it. 178 | """ 179 | return path.split('/')[-1].split('.')[0] 180 | 181 | 182 | def flatten_2d(l: list): 183 | return list(itertools.chain.from_iterable(l)) 184 | 185 | 186 | def load_object(path: str): 187 | """ 188 | Load the Python object at the given path. 189 | """ 190 | with open(path, "rb") as f: 191 | return pickle.load(f) 192 | 193 | 194 | def save_object(obj, path: str): 195 | """ 196 | """ 197 | with open(path, 'wb') as f: 198 | pickle.dump(obj, f) 199 | 200 | 201 | def check_exists(file_path: str): 202 | """ 203 | """ 204 | return os.path.exists(file_path) 205 | 206 | 207 | def setup_env(agent, envs): 208 | """ 209 | TODO 210 | """ 211 | _, infos = envs.reset() 212 | states, valid_ids = [], [] 213 | 214 | for info in infos: 215 | states, valid_ids = states + [[]], valid_ids + [ 216 | agent.encode(info['valid']) 217 | ] 218 | 219 | return infos, states, valid_ids 220 | 221 | 222 | def convert_idxs_to_strs(act_idxs, tokenizer): 223 | """ 224 | Given a list of action idxs, convert it to a list of action strings. 225 | """ 226 | return [ 227 | tokenizer.convert_tokens_to_string( 228 | tokenizer.convert_ids_to_tokens(act)).strip() for act in act_idxs 229 | ] 230 | 231 | 232 | def process_action(action: str): 233 | """ 234 | Transforms the action to lowercase and spells it 235 | out in case it was abbreviated. 
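    Illustrative example: process_action('NE ') returns 'northeast', while
    process_action('Open mailbox') is returned unchanged as 'open mailbox'.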
236 | """ 237 | abbr_map = { 238 | "w": "west", 239 | "n": "north", 240 | "e": "east", 241 | "s": "south", 242 | "se": "southeast", 243 | "sw": "southwest", 244 | "ne": "northeast", 245 | "nw": "northwest", 246 | "u": "up", 247 | "d": "down", 248 | "l": "look" 249 | } 250 | 251 | action = action.strip().lower() 252 | if action in abbr_map: 253 | return abbr_map[action] 254 | return action 255 | 256 | 257 | def inv_process_action(action: str): 258 | abbr_map = { 259 | "west": "w", 260 | "north": "n", 261 | "east": "e", 262 | "south": "s", 263 | "southeast": "se", 264 | "southwest": "sw", 265 | "northeast": "ne", 266 | "northwest": "nw", 267 | "up": "u", 268 | "down": "d", 269 | "look": "l" 270 | } 271 | 272 | action = action.strip().lower() 273 | if action in abbr_map: 274 | return abbr_map[action] 275 | return action 276 | 277 | 278 | def pad_sequences(sequences, maxlen=None, dtype='int32', value=0.): 279 | ''' 280 | Partially borrowed from Keras 281 | # Arguments 282 | sequences: list of lists where each element is a sequence 283 | maxlen: int, maximum length 284 | dtype: type to cast the resulting sequence. 285 | value: float, value to pad the sequences to the desired value. 286 | # Returns 287 | x: numpy array with dimensions (number_of_sequences, maxlen) 288 | ''' 289 | lengths = [len(s) for s in sequences] 290 | nb_samples = len(sequences) 291 | if maxlen is None: 292 | maxlen = np.max(lengths) 293 | # take the sample shape from the first non empty sequence 294 | # checking for consistency in the main loop below. 295 | sample_shape = tuple() 296 | for s in sequences: 297 | if len(s) > 0: 298 | sample_shape = np.asarray(s).shape[1:] 299 | break 300 | x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype) 301 | for idx, s in enumerate(sequences): 302 | if len(s) == 0: 303 | continue # empty list was found 304 | # pre truncating 305 | trunc = s[-maxlen:] 306 | # check `trunc` has expected shape 307 | trunc = np.asarray(trunc, dtype=dtype) 308 | if trunc.shape[1:] != sample_shape: 309 | raise ValueError( 310 | 'Shape of sample %s of sequence at position %s is different from expected shape %s' 311 | % (trunc.shape[1:], idx, sample_shape)) 312 | # post padding 313 | x[idx, :len(trunc)] = trunc 314 | return x 315 | 316 | 317 | def add_special_tok_and_pad(trajectories: List[int], special_begin=101): 318 | """ 319 | TODO 320 | """ 321 | # Pad to max length of the batch & convert to tensor 322 | trajectories = pad_sequences(trajectories) 323 | trajectories = torch.tensor(trajectories, dtype=torch.long, device=device) 324 | 325 | # Add ([CLS]) token 326 | sos_tokens = special_begin * torch.ones( 327 | (len(trajectories), 1), dtype=torch.long, device=device) 328 | trajectories = torch.cat((sos_tokens, trajectories), dim=1) 329 | 330 | return trajectories 331 | 332 | 333 | def create_trajectories(past_acts, 334 | acts, 335 | obs=None, 336 | desc=None, 337 | inv=None, 338 | sep_id: int = None, 339 | cls_id: int = None, 340 | do_pad_and_special_tok: bool = True): 341 | """ 342 | TODO 343 | """ 344 | act_sizes = [len(a) for a in acts] 345 | 346 | # 2D list of unrolled valid actions in the batch 347 | act_batch = list(itertools.chain.from_iterable(acts)) 348 | 349 | for act in act_batch: 350 | assert len(act) > 1, "empty action! {}".format(act) 351 | assert act[-1] == 50258, "not ending with sep!" 
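    # Each trajectory below is the encoded context (past actions + observation +
    # description + inventory + [SEP]) repeated once per valid action (when obs is
    # None, the past actions alone form the context), with a mask selecting only that
    # action's tokens (minus the trailing separator) for prediction.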
352 | 353 | if obs is not None: 354 | states = [] 355 | for i in range(len(past_acts)): 356 | states.append(past_acts[i] + obs[i] + desc[i] + inv[i] + [sep_id]) 357 | 358 | # Repeat state for each valid action in that state 359 | trajectories = [ 360 | states[i] + acts[i][idx] for i, j in enumerate(act_sizes) 361 | for idx in range(j) 362 | ] 363 | else: 364 | states = past_acts 365 | # Repeat state for each valid action in that state 366 | trajectories = [ 367 | states[i] + acts[i][idx] for i, j in enumerate(act_sizes) 368 | for idx in range(j) 369 | ] 370 | 371 | for i, size in enumerate(act_sizes): 372 | for idx in range(size): 373 | assert len(acts[i][idx]) > 1, "too short of an action! {}".format( 374 | acts[i][idx]) 375 | 376 | # Note we subtract one here to not count [SEP] token 377 | mask = [[0] * len(states[i]) + [1] * (len(acts[i][idx]) - 1) + [0] 378 | for i, size in enumerate(act_sizes) for idx in range(size)] 379 | assert len(trajectories) == len(act_batch) 380 | 381 | # only pad and add CLS if asked for 382 | if do_pad_and_special_tok: 383 | if cls_id is None: 384 | trajectories = add_special_tok_and_pad(trajectories) 385 | else: 386 | trajectories = add_special_tok_and_pad(trajectories, 387 | special_begin=cls_id) 388 | mask = add_special_tok_and_pad(mask, special_begin=0) 389 | 390 | # Make sure there is at least one element not masked out 391 | for el in mask: 392 | assert 1 in el, "mask consists of all zeros!" 393 | 394 | if hasattr(trajectories, 'shape'): 395 | assert trajectories.shape == mask.shape 396 | if hasattr(mask, 'cpu'): 397 | assert np.all( 398 | np.array(list(map(lambda x: len(x), act_batch))) == 399 | np.sum(mask.cpu().numpy(), axis=1) + 1) 400 | 401 | return trajectories, act_sizes, mask 402 | -------------------------------------------------------------------------------- /utils/vec_env.py: -------------------------------------------------------------------------------- 1 | # Built-in imports 2 | from multiprocessing import Process, Pipe, Manager 3 | 4 | # Libraries 5 | import numpy as np 6 | 7 | 8 | def worker(remote, parent_remote, env): 9 | parent_remote.close() 10 | try: 11 | done = False 12 | while True: 13 | cmd, data = remote.recv() 14 | if cmd == 'step': 15 | if done: 16 | ob, info = env.reset() 17 | reward = 0 18 | done = False 19 | else: 20 | ob, reward, done, info = env.step(data) 21 | remote.send((ob, reward, done, info)) 22 | elif cmd == 'reset': 23 | ob, info = env.reset() 24 | done = False 25 | remote.send((ob, info)) 26 | elif cmd == 'get_ngram_hits': 27 | remote.send(env.ngram_hits) 28 | elif cmd == 'get_end_scores': 29 | remote.send(env.get_end_scores(last=100)) 30 | elif cmd == 'get_current_score': 31 | remote.send(env.get_score()) 32 | elif cmd == 'get_current_step': 33 | remote.send(env.steps) 34 | elif cmd == 'add_traj': 35 | env.traj.append(data) 36 | remote.send(env.traj) 37 | elif cmd == 'add_full_traj': 38 | env.full_traj.append(data) 39 | remote.send(env.full_traj) 40 | elif cmd == 'update_ngram_hits': 41 | env.ngram_hits += data 42 | remote.send(env.ngram_hits) 43 | elif cmd == 'set_env_limit': 44 | env.step_limit = data 45 | remote.send(env.step_limit) 46 | elif cmd == 'turn_off_trajectory': 47 | env.turn_off_trajectory() 48 | remote.send(True) 49 | elif cmd == 'get_trajectory_state': 50 | traj_state = env.get_trajectory_state() 51 | remote.send(traj_state) 52 | elif cmd == 'get_env_limit': 53 | remote.send(env.step_limit) 54 | elif cmd == "get_traj": 55 | remote.send(env.traj) 56 | elif cmd == 'get_ngram_needs_update': 57 | 
remote.send(env.ngram_needs_update) 58 | elif cmd == 'set_ngram_needs_update': 59 | env.ngram_needs_update = data 60 | remote.send(env.ngram_needs_update) 61 | elif cmd == 'get_cache_size': 62 | remote.send(len(env.cache)) 63 | elif cmd == 'get_unique_acts_size': 64 | remote.send(len(env.unique_acts)) 65 | elif cmd == 'close': 66 | env.close() 67 | break 68 | else: 69 | raise NotImplementedError 70 | except KeyboardInterrupt: 71 | print('SubprocVecEnv worker: got KeyboardInterrupt') 72 | finally: 73 | env.close() 74 | 75 | 76 | class VecEnv: 77 | def __init__(self, num_envs, env): 78 | self.closed = False 79 | self.num_envs = num_envs 80 | self.remotes, self.work_remotes = zip( 81 | *[Pipe() for _ in range(num_envs)]) 82 | env.cache = Manager().dict() 83 | env.unique_acts = Manager().dict() 84 | self.ps = [Process(target=worker, args=(work_remote, remote, env)) 85 | for (work_remote, remote) in zip(self.work_remotes, self.remotes)] 86 | for p in self.ps: 87 | # p.daemon = True # if the main process crashes, we should not cause things to hang 88 | p.start() 89 | for remote in self.work_remotes: 90 | remote.close() 91 | 92 | def __len__(self): 93 | return self.num_envs 94 | 95 | def step(self, actions): 96 | self._assert_not_closed() 97 | assert len( 98 | actions) == self.num_envs, "Error: incorrect number of actions." 99 | for remote, action in zip(self.remotes, actions): 100 | remote.send(('step', action)) 101 | results = [remote.recv() for remote in self.remotes] 102 | self.waiting = False 103 | obs, rewards, dones, infos = zip(*results) 104 | return list(obs), np.stack(rewards), list(dones), infos 105 | 106 | def reset(self): 107 | self._assert_not_closed() 108 | for remote in self.remotes: 109 | remote.send(('reset', None)) 110 | results = [remote.recv() for remote in self.remotes] 111 | obs, infos = zip(*results) 112 | return np.stack(obs), infos 113 | 114 | def get_ngram_hits(self, i): 115 | self._assert_not_closed() 116 | self.remotes[i].send(('get_ngram_hits', None)) 117 | ngram_hits = self.remotes[i].recv() 118 | return ngram_hits 119 | 120 | def update_ngram_hits(self, beta_vec): 121 | self._assert_not_closed() 122 | for i, remote in enumerate(self.remotes): 123 | remote.send(('update_ngram_hits', beta_vec[i])) 124 | results = [remote.recv() for remote in self.remotes] 125 | return np.stack(results) 126 | 127 | def turn_off_trajectory(self, i: int): 128 | self._assert_not_closed() 129 | self.remotes[i].send(('turn_off_trajectory', None)) 130 | result = self.remotes[i].recv() 131 | return result 132 | 133 | def get_trajectory_state(self, i: int): 134 | self._assert_not_closed() 135 | self.remotes[i].send(('get_trajectory_state', None)) 136 | result = self.remotes[i].recv() 137 | return result 138 | 139 | def reset_one(self, i): 140 | self._assert_not_closed() 141 | self.remotes[i].send(('reset', None)) 142 | ob, info = self.remotes[i].recv() 143 | return ob, info 144 | 145 | def get_end_scores(self): 146 | self._assert_not_closed() 147 | for remote in self.remotes: 148 | remote.send(('get_end_scores', None)) 149 | results = [remote.recv() for remote in self.remotes] 150 | return np.stack(results) 151 | 152 | def get_current_scores(self): 153 | self._assert_not_closed() 154 | for remote in self.remotes: 155 | remote.send(('get_current_score', None)) 156 | results = [remote.recv() for remote in self.remotes] 157 | return np.stack(results) 158 | 159 | def set_env_limit(self, limit, i): 160 | self._assert_not_closed() 161 | self.remotes[i].send(('set_env_limit', limit)) 162 | result = 
self.remotes[i].recv() 163 | return result 164 | 165 | def get_env_limit(self): 166 | self._assert_not_closed() 167 | self.remotes[0].send(('get_env_limit', None)) 168 | limit = self.remotes[0].recv() 169 | return limit 170 | 171 | def get_ngram_needs_update(self, i: int): 172 | self._assert_not_closed() 173 | self.remotes[i].send(('get_ngram_needs_update', None)) 174 | result = self.remotes[i].recv() 175 | return result 176 | 177 | def set_ngram_needs_update(self, update): 178 | self._assert_not_closed() 179 | for remote in self.remotes: 180 | remote.send(('set_ngram_needs_update', update)) 181 | results = [remote.recv() for remote in self.remotes] 182 | return np.stack(results) 183 | 184 | def set_ngram_needs_update_i(self, update: bool, i: int): 185 | self._assert_not_closed() 186 | self.remotes[i].send(('set_ngram_needs_update', update)) 187 | result = self.remotes[i].recv() 188 | return result 189 | 190 | def get_current_steps(self): 191 | self._assert_not_closed() 192 | for remote in self.remotes: 193 | remote.send(('get_current_step', None)) 194 | results = [remote.recv() for remote in self.remotes] 195 | return np.stack(results) 196 | 197 | def get_cache_size(self): 198 | self._assert_not_closed() 199 | self.remotes[0].send(('get_cache_size', None)) 200 | result = self.remotes[0].recv() 201 | return result 202 | 203 | def get_unique_acts(self): 204 | self._assert_not_closed() 205 | self.remotes[0].send(('get_unique_acts_size', None)) 206 | result = self.remotes[0].recv() 207 | return result 208 | 209 | def add_traj(self, action_strs): 210 | self._assert_not_closed() 211 | for remote, action_str in zip(self.remotes, action_strs): 212 | remote.send(('add_traj', action_str)) 213 | results = [remote.recv() for remote in self.remotes] 214 | return results 215 | 216 | def add_full_traj(self, traj_steps): 217 | self._assert_not_closed() 218 | for remote, traj_step in zip(self.remotes, traj_steps): 219 | remote.send(('add_full_traj', traj_step)) 220 | results = [remote.recv() for remote in self.remotes] 221 | return results 222 | 223 | def add_full_traj_i(self, i: int, traj_step): 224 | self._assert_not_closed() 225 | self.remotes[i].send(('add_full_traj', traj_step)) 226 | result = self.remotes[i].recv() 227 | return result 228 | 229 | def get_traj_i(self, i: int): 230 | self._assert_not_closed() 231 | self.remotes[i].send(('get_traj', None)) 232 | result = self.remotes[i].recv() 233 | return result 234 | 235 | def close_extras(self): 236 | self.closed = True 237 | for remote in self.remotes: 238 | remote.send(('close', None)) 239 | for p in self.ps: 240 | p.join() 241 | 242 | def _assert_not_closed(self): 243 | assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()" 244 | -------------------------------------------------------------------------------- /yml_envs/jericho-no-wt.yml: -------------------------------------------------------------------------------- 1 | name: jericho-no-wt 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - aiohttp=3.7.4=py37h5e8e339_0 10 | - async-timeout=3.0.1=py_1000 11 | - attrs=21.2.0=pyhd8ed1ab_0 12 | - autopep8=1.5.7=pyhd3eb1b0_0 13 | - blas=1.0=mkl 14 | - boto=2.49.0=py_0 15 | - boto3=1.17.92=pyhd8ed1ab_0 16 | - botocore=1.20.92=pyhd8ed1ab_0 17 | - brotlipy=0.7.0=py37h5e8e339_1001 18 | - bz2file=0.98=py_0 19 | - c-ares=1.17.1=h7f98852_1 20 | - ca-certificates=2021.5.25=h06a4308_1 21 | - cachetools=4.2.2=pyhd8ed1ab_0 22 | - 
catalogue=2.0.4=py37h89c1867_0 23 | - certifi=2021.5.30=py37h06a4308_0 24 | - cffi=1.14.5=py37hc58025e_0 25 | - chardet=4.0.0=py37h89c1867_1 26 | - click=7.1.2=pyh9f0ad1d_0 27 | - colorama=0.4.4=pyh9f0ad1d_0 28 | - cryptography=3.4.7=py37h5d9358c_0 29 | - cudatoolkit=11.0.221=h6bb024c_0 30 | - cymem=2.0.5=py37hcd2ae1e_1 31 | - cython-blis=0.7.4=py37h902c9e0_0 32 | - dataclasses=0.8=pyhc8e2a94_1 33 | - freetype=2.10.4=h5ab3b9f_0 34 | - google-api-core=1.26.3=pyhd8ed1ab_0 35 | - google-auth=1.30.0=pyh44b312d_0 36 | - google-cloud-core=1.5.0=pyhd3deb0d_0 37 | - google-cloud-storage=1.19.0=py_0 38 | - google-crc32c=1.1.2=py37hab72019_0 39 | - google-resumable-media=1.2.0=pyhd3deb0d_0 40 | - googleapis-common-protos=1.53.0=py37h89c1867_0 41 | - grpcio=1.38.0=py37hb27c1af_0 42 | - idna=2.10=pyh9f0ad1d_0 43 | - intel-openmp=2021.2.0=h06a4308_610 44 | - jinja2=3.0.1=pyhd8ed1ab_0 45 | - jmespath=0.10.0=pyh9f0ad1d_0 46 | - jpeg=9b=h024ee3a_2 47 | - lcms2=2.12=h3be6417_0 48 | - ld_impl_linux-64=2.35.1=h7274673_9 49 | - libcrc32c=1.1.1=h9c3ff4c_2 50 | - libffi=3.3=he6710b0_2 51 | - libgcc-ng=9.3.0=h5101ec6_17 52 | - libgomp=9.3.0=h5101ec6_17 53 | - libpng=1.6.37=hbc83047_0 54 | - libprotobuf=3.17.2=h780b84a_0 55 | - libstdcxx-ng=9.3.0=hd4cf53a_17 56 | - libtiff=4.2.0=h85742a9_0 57 | - libuv=1.40.0=h7b6447c_0 58 | - libwebp-base=1.2.0=h27cfd23_0 59 | - lz4-c=1.9.3=h2531618_0 60 | - markupsafe=2.0.1=py37h5e8e339_0 61 | - mkl=2021.2.0=h06a4308_296 62 | - mkl-service=2.3.0=py37h27cfd23_1 63 | - mkl_fft=1.3.0=py37h42c9631_2 64 | - mkl_random=1.2.1=py37ha9443f7_2 65 | - multidict=5.1.0=py37h5e8e339_1 66 | - murmurhash=1.0.5=py37hcd2ae1e_0 67 | - ncurses=6.2=he6710b0_1 68 | - ninja=1.10.2=hff7bd54_1 69 | - numpy=1.20.2=py37h2d18471_0 70 | - numpy-base=1.20.2=py37hfae3a4d_0 71 | - olefile=0.46=py37_0 72 | - openssl=1.1.1k=h27cfd23_0 73 | - packaging=20.9=pyh44b312d_0 74 | - pathy=0.5.2=pyhd8ed1ab_0 75 | - pillow=8.2.0=py37he98fc37_0 76 | - pip=21.1.2=py37h06a4308_0 77 | - preshed=3.0.5=py37hcd2ae1e_0 78 | - pyasn1=0.4.8=py_0 79 | - pyasn1-modules=0.2.7=py_0 80 | - pycodestyle=2.7.0=pyhd3eb1b0_0 81 | - pycparser=2.20=pyh9f0ad1d_2 82 | - pyopenssl=20.0.1=pyhd8ed1ab_0 83 | - pyparsing=2.4.7=pyh9f0ad1d_0 84 | - pysocks=1.7.1=py37h89c1867_3 85 | - python=3.7.7=hcff3b4d_5 86 | - python-dateutil=2.8.1=py_0 87 | - python_abi=3.7=1_cp37m 88 | - pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0 89 | - pytz=2021.1=pyhd8ed1ab_0 90 | - readline=8.1=h27cfd23_0 91 | - requests=2.25.1=pyhd3deb0d_0 92 | - rsa=4.7.2=pyh44b312d_0 93 | - s3transfer=0.4.2=pyhd8ed1ab_0 94 | - setuptools=52.0.0=py37h06a4308_0 95 | - shellingham=1.4.0=pyh44b312d_0 96 | - six=1.15.0=py37h06a4308_0 97 | - smart_open=2.2.1=pyh9f0ad1d_0 98 | - spacy=3.0.6=py37hda21425_0 99 | - spacy-legacy=3.0.5=pyhd8ed1ab_0 100 | - sqlite=3.35.4=hdfb4753_0 101 | - srsly=2.4.1=py37hcd2ae1e_0 102 | - tk=8.6.10=hbc83047_0 103 | - toml=0.10.2=pyhd3eb1b0_0 104 | - torchaudio=0.7.2=py37 105 | - torchvision=0.8.2=py37_cu110 106 | - tqdm=4.61.0=pyhd8ed1ab_0 107 | - typer=0.3.2=pyhd8ed1ab_0 108 | - typing-extensions=3.10.0.0=hd8ed1ab_0 109 | - typing_extensions=3.10.0.0=pyha770c72_0 110 | - urllib3=1.26.5=pyhd8ed1ab_0 111 | - wasabi=0.8.2=pyh44b312d_0 112 | - wheel=0.36.2=pyhd3eb1b0_0 113 | - xz=5.2.5=h7b6447c_0 114 | - yarl=1.6.3=py37h5e8e339_1 115 | - zipp=3.4.1=pyhd8ed1ab_0 116 | - zlib=1.2.11=h7b6447c_3 117 | - zstd=1.4.9=haebb681_0 118 | - pip: 119 | - absl-py==0.12.0 120 | - astor==0.8.1 121 | - cached-property==1.5.2 122 | - configparser==5.0.2 123 | - docker-pycreds==0.4.0 
124 | - en-core-web-sm==3.0.0 125 | - fasttext==0.9.2 126 | - filelock==3.0.12 127 | - gast==0.4.0 128 | - gitdb==4.0.7 129 | - gitpython==3.1.17 130 | - h5py==3.2.1 131 | - importlib-metadata==4.5.0 132 | - jericho==3.1.0 133 | - joblib==1.0.1 134 | - keras-applications==1.0.8 135 | - keras-preprocessing==1.1.2 136 | - markdown==3.3.4 137 | - mock==4.0.3 138 | - pathtools==0.1.2 139 | - promise==2.3 140 | - protobuf==3.17.3 141 | - psutil==5.8.0 142 | - pybind11==2.6.2 143 | - pydantic==1.7.4 144 | - pyinstrument==4.0.3 145 | - pynvml==11.0.0 146 | - pyyaml==5.4.1 147 | - regex==2021.4.4 148 | - sacremoses==0.0.45 149 | - sentry-sdk==1.1.0 150 | - shortuuid==1.0.1 151 | - smart-open==3.0.0 152 | - smmap==4.0.0 153 | - subprocess32==3.5.4 154 | - tensorboard==1.13.1 155 | - tensorflow-estimator==1.13.0 156 | - tensorflow-gpu==1.13.1 157 | - termcolor==1.1.0 158 | - thinc==8.0.4 159 | - tokenizers==0.10.3 160 | - transformers==4.4.2 161 | - wandb==0.12.2 162 | - werkzeug==2.0.1 163 | - yaspin==2.1.0 -------------------------------------------------------------------------------- /yml_envs/jericho-wt.yml: -------------------------------------------------------------------------------- 1 | name: jericho-wt 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - aiohttp=3.7.4=py37h5e8e339_0 10 | - async-timeout=3.0.1=py_1000 11 | - attrs=21.2.0=pyhd8ed1ab_0 12 | - autopep8=1.5.7=pyhd3eb1b0_0 13 | - blas=1.0=mkl 14 | - boto=2.49.0=py_0 15 | - boto3=1.17.92=pyhd8ed1ab_0 16 | - botocore=1.20.92=pyhd8ed1ab_0 17 | - brotlipy=0.7.0=py37h5e8e339_1001 18 | - bz2file=0.98=py_0 19 | - c-ares=1.17.1=h7f98852_1 20 | - ca-certificates=2021.5.25=h06a4308_1 21 | - cachetools=4.2.2=pyhd8ed1ab_0 22 | - catalogue=2.0.4=py37h89c1867_0 23 | - certifi=2021.5.30=py37h06a4308_0 24 | - cffi=1.14.5=py37hc58025e_0 25 | - chardet=4.0.0=py37h89c1867_1 26 | - click=7.1.2=pyh9f0ad1d_0 27 | - colorama=0.4.4=pyh9f0ad1d_0 28 | - cryptography=3.4.7=py37h5d9358c_0 29 | - cudatoolkit=11.0.221=h6bb024c_0 30 | - cymem=2.0.5=py37hcd2ae1e_1 31 | - cython-blis=0.7.4=py37h902c9e0_0 32 | - dataclasses=0.8=pyhc8e2a94_1 33 | - freetype=2.10.4=h5ab3b9f_0 34 | - google-api-core=1.26.3=pyhd8ed1ab_0 35 | - google-auth=1.30.0=pyh44b312d_0 36 | - google-cloud-core=1.5.0=pyhd3deb0d_0 37 | - google-cloud-storage=1.19.0=py_0 38 | - google-crc32c=1.1.2=py37hab72019_0 39 | - google-resumable-media=1.2.0=pyhd3deb0d_0 40 | - googleapis-common-protos=1.53.0=py37h89c1867_0 41 | - grpcio=1.38.0=py37hb27c1af_0 42 | - idna=2.10=pyh9f0ad1d_0 43 | - intel-openmp=2021.2.0=h06a4308_610 44 | - jinja2=3.0.1=pyhd8ed1ab_0 45 | - jmespath=0.10.0=pyh9f0ad1d_0 46 | - jpeg=9b=h024ee3a_2 47 | - lcms2=2.12=h3be6417_0 48 | - ld_impl_linux-64=2.35.1=h7274673_9 49 | - libcrc32c=1.1.1=h9c3ff4c_2 50 | - libffi=3.3=he6710b0_2 51 | - libgcc-ng=9.3.0=h5101ec6_17 52 | - libgomp=9.3.0=h5101ec6_17 53 | - libpng=1.6.37=hbc83047_0 54 | - libprotobuf=3.17.2=h780b84a_0 55 | - libstdcxx-ng=9.3.0=hd4cf53a_17 56 | - libtiff=4.2.0=h85742a9_0 57 | - libuv=1.40.0=h7b6447c_0 58 | - libwebp-base=1.2.0=h27cfd23_0 59 | - lz4-c=1.9.3=h2531618_0 60 | - markupsafe=2.0.1=py37h5e8e339_0 61 | - mkl=2021.2.0=h06a4308_296 62 | - mkl-service=2.3.0=py37h27cfd23_1 63 | - mkl_fft=1.3.0=py37h42c9631_2 64 | - mkl_random=1.2.1=py37ha9443f7_2 65 | - multidict=5.1.0=py37h5e8e339_1 66 | - murmurhash=1.0.5=py37hcd2ae1e_0 67 | - ncurses=6.2=he6710b0_1 68 | - ninja=1.10.2=hff7bd54_1 69 | - numpy=1.20.2=py37h2d18471_0 70 | - 
numpy-base=1.20.2=py37hfae3a4d_0 71 | - olefile=0.46=py37_0 72 | - openssl=1.1.1k=h27cfd23_0 73 | - packaging=20.9=pyh44b312d_0 74 | - pathy=0.5.2=pyhd8ed1ab_0 75 | - pillow=8.2.0=py37he98fc37_0 76 | - pip=21.1.2=py37h06a4308_0 77 | - preshed=3.0.5=py37hcd2ae1e_0 78 | - pyasn1=0.4.8=py_0 79 | - pyasn1-modules=0.2.7=py_0 80 | - pycodestyle=2.7.0=pyhd3eb1b0_0 81 | - pycparser=2.20=pyh9f0ad1d_2 82 | - pyopenssl=20.0.1=pyhd8ed1ab_0 83 | - pyparsing=2.4.7=pyh9f0ad1d_0 84 | - pysocks=1.7.1=py37h89c1867_3 85 | - python=3.7.7=hcff3b4d_5 86 | - python-dateutil=2.8.1=py_0 87 | - python_abi=3.7=1_cp37m 88 | - pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0 89 | - pytz=2021.1=pyhd8ed1ab_0 90 | - readline=8.1=h27cfd23_0 91 | - requests=2.25.1=pyhd3deb0d_0 92 | - rsa=4.7.2=pyh44b312d_0 93 | - s3transfer=0.4.2=pyhd8ed1ab_0 94 | - setuptools=52.0.0=py37h06a4308_0 95 | - shellingham=1.4.0=pyh44b312d_0 96 | - six=1.15.0=py37h06a4308_0 97 | - smart_open=2.2.1=pyh9f0ad1d_0 98 | - spacy=3.0.6=py37hda21425_0 99 | - spacy-legacy=3.0.5=pyhd8ed1ab_0 100 | - sqlite=3.35.4=hdfb4753_0 101 | - srsly=2.4.1=py37hcd2ae1e_0 102 | - tk=8.6.10=hbc83047_0 103 | - toml=0.10.2=pyhd3eb1b0_0 104 | - torchaudio=0.7.2=py37 105 | - torchvision=0.8.2=py37_cu110 106 | - tqdm=4.61.0=pyhd8ed1ab_0 107 | - typer=0.3.2=pyhd8ed1ab_0 108 | - typing-extensions=3.10.0.0=hd8ed1ab_0 109 | - typing_extensions=3.10.0.0=pyha770c72_0 110 | - urllib3=1.26.5=pyhd8ed1ab_0 111 | - wasabi=0.8.2=pyh44b312d_0 112 | - wheel=0.36.2=pyhd3eb1b0_0 113 | - xz=5.2.5=h7b6447c_0 114 | - yarl=1.6.3=py37h5e8e339_1 115 | - zipp=3.4.1=pyhd8ed1ab_0 116 | - zlib=1.2.11=h7b6447c_3 117 | - zstd=1.4.9=haebb681_0 118 | - pip: 119 | - absl-py==0.12.0 120 | - astor==0.8.1 121 | - cached-property==1.5.2 122 | - configparser==5.0.2 123 | - docker-pycreds==0.4.0 124 | - en-core-web-sm==3.0.0 125 | - fasttext==0.9.2 126 | - filelock==3.0.12 127 | - gast==0.4.0 128 | - gitdb==4.0.7 129 | - gitpython==3.1.17 130 | - h5py==3.2.1 131 | - importlib-metadata==4.5.0 132 | - joblib==1.0.1 133 | - keras-applications==1.0.8 134 | - keras-preprocessing==1.1.2 135 | - markdown==3.3.4 136 | - mock==4.0.3 137 | - pathtools==0.1.2 138 | - promise==2.3 139 | - protobuf==3.17.3 140 | - psutil==5.8.0 141 | - pybind11==2.6.2 142 | - pydantic==1.7.4 143 | - pyinstrument==4.0.3 144 | - pynvml==11.0.0 145 | - pyyaml==5.4.1 146 | - regex==2021.4.4 147 | - sacremoses==0.0.45 148 | - sentry-sdk==1.1.0 149 | - shortuuid==1.0.1 150 | - smart-open==3.0.0 151 | - smmap==4.0.0 152 | - subprocess32==3.5.4 153 | - tensorboard==1.13.1 154 | - tensorflow-estimator==1.13.0 155 | - tensorflow-gpu==1.13.1 156 | - termcolor==1.1.0 157 | - thinc==8.0.4 158 | - tokenizers==0.10.3 159 | - transformers==4.4.2 160 | - wandb==0.12.2 161 | - werkzeug==2.0.1 162 | - yaspin==2.1.0 -------------------------------------------------------------------------------- /yml_envs/zork1-environment.yml: -------------------------------------------------------------------------------- 1 | name: jericho-zork1 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - autopep8=1.5.7=pyhd3eb1b0_0 9 | - blas=1.0=mkl 10 | - ca-certificates=2021.7.5=h06a4308_1 11 | - certifi=2021.5.30=py37h06a4308_0 12 | - cudatoolkit=11.0.221=h6bb024c_0 13 | - cycler=0.10.0=py37_0 14 | - dbus=1.13.18=hb2f20db_0 15 | - expat=2.4.1=h2531618_2 16 | - fontconfig=2.13.1=h6c09931_0 17 | - freetype=2.10.4=h5ab3b9f_0 18 | - glib=2.69.0=h5202010_0 19 | - gst-plugins-base=1.14.0=h8213a91_2 20 | - 
gstreamer=1.14.0=h28cd5cc_2 21 | - icu=58.2=he6710b0_3 22 | - intel-openmp=2021.2.0=h06a4308_610 23 | - jpeg=9b=h024ee3a_2 24 | - kiwisolver=1.3.1=py37h2531618_0 25 | - lcms2=2.12=h3be6417_0 26 | - ld_impl_linux-64=2.33.1=h53a641e_7 27 | - libffi=3.3=he6710b0_2 28 | - libgcc=7.2.0=h69d50b8_2 29 | - libgcc-ng=9.1.0=hdf63c60_0 30 | - libpng=1.6.37=hbc83047_0 31 | - libstdcxx-ng=9.1.0=hdf63c60_0 32 | - libtiff=4.1.0=h2733197_1 33 | - libuuid=1.0.3=h1bed415_2 34 | - libuv=1.40.0=h7b6447c_0 35 | - libxcb=1.14=h7b6447c_0 36 | - libxml2=2.9.10=hb55368b_3 37 | - lz4-c=1.9.3=h2531618_0 38 | - matplotlib=3.3.4=py37h06a4308_0 39 | - matplotlib-base=3.3.4=py37h62a2d02_0 40 | - mkl=2021.2.0=h06a4308_296 41 | - mkl-service=2.3.0=py37h27cfd23_1 42 | - mkl_fft=1.3.0=py37h42c9631_2 43 | - mkl_random=1.2.1=py37ha9443f7_2 44 | - ncurses=6.2=he6710b0_1 45 | - ninja=1.10.2=hff7bd54_1 46 | - nodejs=6.11.2=h3db8ef7_0 47 | - numpy-base=1.20.1=py37h7d8b39e_0 48 | - olefile=0.46=py37_0 49 | - openssl=1.1.1k=h27cfd23_0 50 | - pandas=1.2.4=py37h2531618_0 51 | - pcre=8.45=h295c915_0 52 | - pillow=8.2.0=py37he98fc37_0 53 | - pip=21.0.1=py37h06a4308_0 54 | - pycodestyle=2.7.0=pyhd3eb1b0_0 55 | - pyparsing=2.4.7=pyhd3eb1b0_0 56 | - pyqt=5.9.2=py37h05f1152_2 57 | - python=3.7.7=hcff3b4d_5 58 | - python_abi=3.7=2_cp37m 59 | - pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0 60 | - pytz=2021.1=pyhd3eb1b0_0 61 | - qt=5.9.7=h5867ecd_1 62 | - readline=8.1=h27cfd23_0 63 | - setuptools=52.0.0=py37h06a4308_0 64 | - sip=4.19.8=py37hf484d3e_0 65 | - sqlite=3.35.4=hdfb4753_0 66 | - tk=8.6.10=hbc83047_0 67 | - toml=0.10.2=pyhd3eb1b0_0 68 | - torchaudio=0.7.2=py37 69 | - torchvision=0.8.2=py37_cu110 70 | - tornado=6.1=py37h27cfd23_0 71 | - typing_extensions=3.7.4.3=pyha847dfd_0 72 | - wheel=0.36.2=pyhd3eb1b0_0 73 | - xz=5.2.5=h7b6447c_0 74 | - zlib=1.2.11=h7b6447c_3 75 | - zstd=1.4.9=haebb681_0 76 | - pip: 77 | - attrs==21.2.0 78 | - blis==0.2.4 79 | - catalogue==2.0.4 80 | - chardet==4.0.0 81 | - click==7.1.2 82 | - cloudpickle==1.6.0 83 | - colorama==0.4.4 84 | - commonmark==0.9.1 85 | - configparser==5.0.2 86 | - cymem==2.0.5 87 | - docker-pycreds==0.4.0 88 | - en-core-web-sm==2.1.0 89 | - fasttext==0.9.2 90 | - filelock==3.0.12 91 | - gitdb==4.0.7 92 | - gitpython==3.1.17 93 | - gql==0.2.0 94 | - graphql-core==1.1 95 | - idna==2.10 96 | - importlib-metadata==4.0.1 97 | - iniconfig==1.1.1 98 | - jericho==3.0.2 99 | - jinja2==3.0.0 100 | - joblib==1.0.1 101 | - jsonschema==2.6.0 102 | - markupsafe==2.0.0 103 | - murmurhash==1.0.5 104 | - numpy==1.20.3 105 | - nvidia-ml-py==11.450.51 106 | - nvidia-ml-py3==7.352.0 107 | - packaging==20.9 108 | - pathtools==0.1.2 109 | - pathy==0.5.2 110 | - plac==0.9.6 111 | - pluggy==0.13.1 112 | - preshed==2.0.1 113 | - promise==2.3 114 | - protobuf==3.17.1 115 | - psutil==5.8.0 116 | - py==1.10.0 117 | - pybind11==2.6.2 118 | - pydantic==1.7.4 119 | - pygments==2.9.0 120 | - pyinstrument==3.4.2 121 | - pyinstrument-cext==0.2.4 122 | - pytest==6.2.4 123 | - python-dateutil==2.8.1 124 | - pyyaml==5.4.1 125 | - regex==2021.4.4 126 | - requests==2.25.1 127 | - rich==10.2.0 128 | - sacremoses==0.0.45 129 | - sentencepiece==0.1.91 130 | - sentry-sdk==1.1.0 131 | - shortuuid==1.0.1 132 | - six==1.16.0 133 | - smart-open==3.0.0 134 | - smmap==4.0.0 135 | - spacy==2.1.3 136 | - spacy-legacy==3.0.5 137 | - srsly==1.0.5 138 | - subprocess32==3.5.4 139 | - thinc==7.0.8 140 | - tokenizers==0.10.2 141 | - tqdm==4.60.0 142 | - transformers==4.4.2 143 | - typer==0.3.2 144 | - typing-extensions==3.10.0.0 145 | - 
urllib3==1.26.6 146 | - wandb==0.12.0 147 | - wasabi==0.8.2 148 | - watchdog==2.1.1 149 | - zipp==3.4.1 150 | --------------------------------------------------------------------------------
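
For readers skimming `utils/vec_env.py` above, the sketch below shows how its command protocol can be exercised end to end. This is not part of the repo: `ToyEnv` is a hypothetical stand-in for the Jericho wrapper in `utils/env.py`, and only the `reset`, `step`, and `close` commands handled by `worker()` are used.

```python
# Minimal usage sketch for utils/vec_env.py (illustrative only, not in the repo).
# ToyEnv is a hypothetical stand-in for the Jericho wrapper; it implements just
# the pieces that the 'reset', 'step', and 'close' worker commands rely on.
from utils.vec_env import VecEnv


class ToyEnv:
    def __init__(self):
        self.steps = 0

    def reset(self):
        self.steps = 0
        return "start", {"valid": ["look", "wait"]}

    def step(self, action):
        self.steps += 1
        done = self.steps >= 3  # pretend the episode ends after 3 steps
        return "obs after '{}'".format(action), 0, done, {"valid": ["look", "wait"]}

    def close(self):
        pass


if __name__ == "__main__":
    envs = VecEnv(num_envs=2, env=ToyEnv())    # starts 2 worker processes, one pipe each
    obs, infos = envs.reset()                  # broadcasts 'reset' to every worker
    obs, rewards, dones, infos = envs.step(["look", "wait"])
    # A worker whose env reported done=True auto-resets on its next 'step' command.
    envs.close_extras()                        # sends 'close' and joins the workers
```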