├── .gitignore ├── README.md ├── drrn.py ├── env.py ├── logger.py ├── memory.py ├── model.py ├── train.py ├── unigram_8k.model ├── util.py ├── vec_env.py └── zork1.z5 /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.pdf 3 | *.rdb 4 | *.sh 5 | *.out 6 | run.slurm 7 | wandb/ 8 | *.swp 9 | logs/ 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DRRN Model Variants (Hash input, Inverse dynamics) on Text Games 2 | 3 | Code for NAACL 2021 paper [Reading and Acting while Blindfolded: The Need for Semantics in Text Game Agents](https://arxiv.org/abs/2103.13552). 4 | 5 | Project site: https://blindfolded.cs.princeton.edu 6 | 7 | ## Getting Started 8 | 9 | - Install dependencies: 10 | ```bash 11 | pip install jericho fasttext 12 | ``` 13 | - Run baseline DRRN: 14 | ```python 15 | python train.py 16 | ``` 17 | 18 | - Run DRRN (hash): 19 | ```python 20 | python train.py --hash_rep 1 21 | ``` 22 | 23 | - Run DRRN (inv-dy): 24 | ```python 25 | python train.py --w_inv 1 --w_act 1 --r_for 1 26 | ``` 27 | 28 | Use ``--seed`` to specify game random seed. ``-1`` means episode-varying seeds (stochastic game mode), otherwise game mode is deterministic. 29 | 30 | Zork I is played by default. More games are [here](https://github.com/princeton-nlp/calm-textgame/tree/master/games) and use ``--rom_path`` to specify which game to play. 
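- Example: run DRRN (inv-dy) on a different game with episode-varying seeds (a minimal sketch; the ROM filename is illustrative and should first be downloaded from the games link above):
```bash
python train.py --rom_path detective.z5 --seed -1 --w_inv 1 --w_act 1 --r_for 1 --output_dir logs/detective-inv-dy
```

Note that besides ``jericho`` and ``fasttext``, the scripts also import ``torch``, ``transformers``, ``wandb``, and ``spacy`` (see ``train.py``, ``drrn.py``, and ``logger.py``), so install those as well.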
31 | 32 | ## Citation 33 | ``` 34 | @inproceedings{yao2021blindfolded, 35 | title={Reading and Acting while Blindfolded: The Need for Semantics in Text Game Agents}, 36 | author={Yao, Shunyu and Narasimhan, Karthik and Hausknecht, Matthew}, 37 | booktitle={North American Association for Computational Linguistics (NAACL)}, 38 | year={2021} 39 | } 40 | ``` 41 | ## Acknowledgements 42 | The code borrows from [TDQN](https://github.com/microsoft/tdqn). 43 | 44 | For any questions please contact Shunyu Yao ``. 45 | -------------------------------------------------------------------------------- /drrn.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import os 6 | from os.path import join as pjoin 7 | from memory import * 8 | from model import DRRN 9 | from util import * 10 | import logger 11 | from transformers import BertTokenizer 12 | import numpy as np 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | class DRRN_Agent: 17 | def __init__(self, args): 18 | self.gamma = args.gamma 19 | self.batch_size = args.batch_size 20 | self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") 21 | self.network = DRRN(len(self.tokenizer), args.embedding_dim, args.hidden_dim, args.fix_rep, args.hash_rep, args.act_obs).to(device) 22 | self.network.tokenizer = self.tokenizer 23 | self.memory = ABReplayMemory(args.memory_size, args.memory_alpha) 24 | self.save_path = args.output_dir 25 | self.clip = args.clip 26 | self.optimizer = torch.optim.Adam(self.network.parameters(), lr=args.learning_rate) 27 | 28 | self.type_inv = args.type_inv 29 | self.type_for = args.type_for 30 | self.w_inv = args.w_inv 31 | self.w_for = args.w_for 32 | self.w_act = args.w_act 33 | self.perturb = args.perturb 34 | 35 | self.act_obs = args.act_obs 36 | 37 | def observe(self, transition, is_prior=False): 38 | self.memory.push(transition, is_prior) 39 | 40 | 41 | def build_state(self, ob, info): 42 | """ Returns a state representation built from various info sources. """ 43 | if self.act_obs: 44 | acts = self.encode(info['valid']) 45 | obs_ids, look_ids, inv_ids = [], [], [] 46 | for act in acts: obs_ids += act 47 | return State(obs_ids, look_ids, inv_ids) 48 | obs_ids = self.tokenizer.encode(ob) 49 | look_ids = self.tokenizer.encode(info['look']) 50 | inv_ids = self.tokenizer.encode(info['inv']) 51 | return State(obs_ids, look_ids, inv_ids) 52 | 53 | 54 | def build_states(self, obs, infos): 55 | return [self.build_state(ob, info) for ob, info in zip(obs, infos)] 56 | 57 | 58 | def encode(self, obs_list): 59 | """ Encode a list of observations """ 60 | return [self.tokenizer.encode(o) for o in obs_list] 61 | 62 | 63 | def act(self, states, poss_acts, sample=True, eps=0.1): 64 | """ Returns a string action from poss_acts. """ 65 | idxs, values = self.network.act(states, poss_acts, sample, eps=eps) 66 | act_ids = [poss_acts[batch][idx] for batch, idx in enumerate(idxs)] 67 | return act_ids, idxs, values 68 | 69 | 70 | def q_loss(self, transitions, need_qvals=False): 71 | batch = Transition(*zip(*transitions)) 72 | 73 | # Compute Q(s', a') for all a' 74 | # TODO: Use a target network??? 
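# (The following lines compute the standard one-step TD target,
#    target = r + gamma * max_a' Q(s', a') * (1 - done),
#  bootstrapping with the online network itself; as the TODO above notes,
#  no separate target network is used.)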
75 | next_qvals = self.network(batch.next_state, batch.next_acts) 76 | # Take the max over next q-values 77 | next_qvals = torch.tensor([vals.max() for vals in next_qvals], device=device) 78 | # Zero all the next_qvals that are done 79 | next_qvals = next_qvals * (1-torch.tensor(batch.done, dtype=torch.float, device=device)) 80 | targets = torch.tensor(batch.reward, dtype=torch.float, device=device) + self.gamma * next_qvals 81 | 82 | # Next compute Q(s, a) 83 | # Nest each action in a list - so that it becomes the only admissible cmd 84 | nested_acts = tuple([[a] for a in batch.act]) 85 | qvals = self.network(batch.state, nested_acts) 86 | # Combine the qvals: Maybe just do a greedy max for generality 87 | qvals = torch.cat(qvals) 88 | loss = F.smooth_l1_loss(qvals, targets.detach()) 89 | 90 | return (loss, qvals) if need_qvals else loss 91 | 92 | def update(self): 93 | if len(self.memory) < self.batch_size: 94 | return None 95 | 96 | transitions = self.memory.sample(self.batch_size) 97 | batch = Transition(*zip(*transitions)) 98 | nested_acts = tuple([[a] for a in batch.act]) 99 | terms, loss = {}, 0 100 | 101 | # Compute Q learning Huber loss 102 | terms['Loss_q'], qvals = self.q_loss(transitions, need_qvals=True) 103 | loss += terms['Loss_q'] 104 | 105 | # Compute Inverse dynamics loss 106 | if self.w_inv > 0: 107 | if self.type_inv == 'decode': 108 | terms['Loss_id'], terms['Acc_id'] = self.network.inv_loss_decode(batch.state, batch.next_state, nested_acts, hat=True) 109 | elif self.type_inv == 'ce': 110 | terms['Loss_id'], terms['Acc_id'] = self.network.inv_loss_ce(batch.state, batch.next_state, nested_acts, batch.acts) 111 | else: 112 | raise NotImplementedError 113 | loss += self.w_inv * terms['Loss_id'] 114 | 115 | # Compute Act reconstruction loss 116 | if self.w_act > 0: 117 | terms['Loss_act'], terms['Acc_act'] = self.network.inv_loss_decode(batch.state, batch.next_state, nested_acts, hat=False) 118 | loss += self.w_act * terms['Loss_act'] 119 | 120 | # Compute Forward dynamics loss 121 | if self.w_for > 0: 122 | if self.type_for == 'l2': 123 | terms['Loss_fd'] = self.network.for_loss_l2(batch.state, batch.next_state, nested_acts) 124 | elif self.type_for == 'ce': 125 | terms['Loss_fd'], terms['Acc_fd'] = self.network.for_loss_ce(batch.state, batch.next_state, nested_acts, batch.acts) 126 | elif self.type_for == 'decode': 127 | terms['Loss_fd'], terms['Acc_fd'] = self.network.for_loss_decode(batch.state, batch.next_state, nested_acts, hat=True) 128 | elif self.type_for == 'decode_obs': 129 | terms['Loss_fd'], terms['Acc_fd'] = self.network.for_loss_decode(batch.state, batch.next_state, nested_acts, hat=False) 130 | 131 | loss += self.w_for * terms['Loss_fd'] 132 | 133 | # Backward 134 | terms.update({'Loss': loss, 'Q': qvals.mean()}) 135 | self.optimizer.zero_grad() 136 | loss.backward() 137 | nn.utils.clip_grad_norm_(self.network.parameters(), self.clip) 138 | self.optimizer.step() 139 | return {k: float(v) for k, v in terms.items()} 140 | 141 | 142 | def load(self, path=None): 143 | if path is None: 144 | return 145 | try: 146 | # self.memory = pickle.load(open(pjoin(path, 'memory.pkl'), 'rb')) 147 | network = torch.load(pjoin(path, 'model.pt')) 148 | parts = ['embedding', 'encoder'] # , 'hidden', 'act_scorer'] 149 | state_dict = network.state_dict() 150 | state_dict = {k: v for k, v in state_dict.items() if any(part in k for part in parts)} 151 | # print(state_dict.keys()) 152 | self.network.load_state_dict(state_dict, strict=False) 153 | 154 | except Exception as e: 155 | 
print("Error saving model.") 156 | logging.error(traceback.format_exc()) 157 | 158 | 159 | def save(self, step=''): 160 | try: 161 | os.makedirs(pjoin(self.save_path, step), exist_ok=True) 162 | pickle.dump(self.memory, open(pjoin(self.save_path, step, 'memory.pkl'), 'wb')) 163 | torch.save(self.network, pjoin(self.save_path, step, 'model.pt')) 164 | except Exception as e: 165 | print("Error saving model.") 166 | logging.error(traceback.format_exc()) 167 | -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | from jericho import * 2 | from jericho.util import * 3 | from jericho.defines import * 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def modify(info): 9 | info['look'] = clean(info['look']) 10 | 11 | # shuffle inventory 12 | info['inv'] = clean(info['inv']) 13 | invs = info['inv'].split(' ') 14 | if len(invs) > 1: 15 | head, invs = invs[0], invs[1:] 16 | np.random.shuffle(invs) 17 | info['inv'] = ' '.join([head] + invs) 18 | 19 | # action switch 20 | subs = [['take', 'grab'], ['drop', 'put down'], ['turn on', 'open'], ['turn off', 'close'], ['get in', 'enter']] 21 | navis = 'north/south/west/east/northwest/southwest/northeast/southeast/up/down'.split('/') 22 | navi_subs = [[w, 'go ' + w] for w in navis] 23 | def sub(act): 24 | for a, b in subs: 25 | act = act.replace(a, b) 26 | for a, b in navi_subs: 27 | if act == a: act = b 28 | return act 29 | info['valid'] = [sub(act) for act in info['valid']] 30 | 31 | return info 32 | 33 | class JerichoEnv: 34 | ''' Returns valid actions at each step of the game. ''' 35 | 36 | def __init__(self, rom_path, seed, step_limit=None, get_valid=True, cache=None, args=None): 37 | self.rom_path = rom_path 38 | self.env = FrotzEnv(rom_path, seed=seed) 39 | self.bindings = self.env.bindings 40 | self.seed = seed 41 | self.steps = 0 42 | self.step_limit = step_limit 43 | self.get_valid = get_valid 44 | self.max_score = 0 45 | self.end_scores = [] 46 | self.cache = cache 47 | self.nor = args.nor 48 | self.randr = args.randr 49 | np.random.seed(max(seed, 0)) 50 | self.random_rewards = (np.random.rand(10000) - .5) * 10. 51 | self.objs = set() 52 | self.perturb = args.perturb 53 | if self.perturb: 54 | self.en2de = args.en2de 55 | self.de2en = args.de2en 56 | self.perturb_dict = args.perturb_dict 57 | 58 | def paraphrase(self, s): 59 | if s in self.perturb_dict: return self.perturb_dict[s] 60 | with torch.no_grad(): 61 | p = self.de2en.translate(self.en2de.translate(s)) 62 | if p == '': p = '.' 
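# cache the back-translation so the same string is always paraphrased identically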
63 | self.perturb_dict[s] = p 64 | return p 65 | 66 | def get_objects(self): 67 | desc2objs = self.env._identify_interactive_objects(use_object_tree=False) 68 | obj_set = set() 69 | for objs in desc2objs.values(): 70 | for obj, pos, source in objs: 71 | if pos == 'ADJ': continue 72 | obj_set.add(obj) 73 | return list(obj_set) 74 | 75 | def step(self, action): 76 | ob, reward, done, info = self.env.step(action) 77 | # if self.cache is not None: 78 | # self.cache['loc'].add(self.env.get_player_location().num) 79 | if self.nor: reward = 0 80 | # random reward 81 | if self.randr: 82 | reward = 0 83 | objs = [self.env.get_player_location()] + self.env.get_inventory() 84 | for obj in objs: 85 | obj = obj.num 86 | if obj not in self.objs: 87 | self.objs.add(obj) 88 | reward += self.random_rewards[obj] 89 | info['score'] = sum(self.random_rewards[obj] for obj in self.objs) 90 | 91 | # Initialize with default values 92 | info['look'] = 'unknown' 93 | info['inv'] = 'unknown' 94 | info['valid'] = ['wait', 'yes', 'no'] 95 | if not done: 96 | save = self.env.get_state() 97 | hash_save = self.env.get_world_state_hash() 98 | if self.cache is not None and hash_save in self.cache: 99 | info['look'], info['inv'], info['valid'] = self.cache[hash_save] 100 | else: 101 | look, _, _, _ = self.env.step('look') 102 | info['look'] = look.lower() 103 | self.env.set_state(save) 104 | inv, _, _, _ = self.env.step('inventory') 105 | info['inv'] = inv.lower() 106 | self.env.set_state(save) 107 | if self.get_valid: 108 | valid = self.env.get_valid_actions() 109 | if len(valid) == 0: 110 | valid = ['wait', 'yes', 'no'] 111 | info['valid'] = valid 112 | if self.cache is not None: 113 | self.cache[hash_save] = info['look'], info['inv'], info['valid'] 114 | 115 | self.steps += 1 116 | if self.step_limit and self.steps >= self.step_limit: 117 | done = True 118 | self.max_score = max(self.max_score, info['score']) 119 | if done: self.end_scores.append(info['score']) 120 | if self.perturb: 121 | ob = self.paraphrase(ob) 122 | info['look'] = self.paraphrase(info['look']) 123 | info['inv'] = self.paraphrase(info['inv']) 124 | return ob, reward, done, info 125 | 126 | def reset(self): 127 | initial_ob, info = self.env.reset() 128 | save = self.env.get_state() 129 | look, _, _, _ = self.env.step('look') 130 | info['look'] = look 131 | self.env.set_state(save) 132 | inv, _, _, _ = self.env.step('inventory') 133 | info['inv'] = inv 134 | self.env.set_state(save) 135 | valid = self.env.get_valid_actions() 136 | info['valid'] = valid 137 | self.steps = 0 138 | self.max_score = 0 139 | self.objs = set() 140 | return initial_ob, info 141 | 142 | def get_dictionary(self): 143 | if not self.env: 144 | self.create() 145 | return self.env.get_dictionary() 146 | 147 | def get_action_set(self): 148 | return None 149 | 150 | def get_end_scores(self, last=1): 151 | last = min(last, len(self.end_scores)) 152 | return sum(self.end_scores[-last:]) / last if last else 0 153 | 154 | def close(self): 155 | self.env.close() 156 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import os.path as osp 5 | import json 6 | import time 7 | import datetime 8 | import tempfile 9 | from collections import defaultdict 10 | import wandb 11 | 12 | DEBUG = 10 13 | INFO = 20 14 | WARN = 30 15 | ERROR = 40 16 | 17 | DISABLED = 50 18 | 19 | 20 | class KVWriter(object): 21 | def writekvs(self, 
kvs): 22 | raise NotImplementedError 23 | 24 | 25 | class SeqWriter(object): 26 | def writeseq(self, seq): 27 | raise NotImplementedError 28 | 29 | 30 | class HumanOutputFormat(KVWriter, SeqWriter): 31 | def __init__(self, filename_or_file): 32 | if isinstance(filename_or_file, str): 33 | self.file = open(filename_or_file, 'wt') 34 | self.own_file = True 35 | else: 36 | assert hasattr(filename_or_file, 'read'), 'expected file or str, got %s' % filename_or_file 37 | self.file = filename_or_file 38 | self.own_file = False 39 | 40 | def writekvs(self, kvs): 41 | # Create strings for printing 42 | key2str = {} 43 | for (key, val) in sorted(kvs.items()): 44 | if isinstance(val, float): 45 | valstr = '%-8.3g' % (val,) 46 | else: 47 | valstr = str(val) 48 | key2str[self._truncate(key)] = self._truncate(valstr) 49 | 50 | # Find max widths 51 | if len(key2str) == 0: 52 | print('WARNING: tried to write empty key-value dict') 53 | return 54 | else: 55 | keywidth = max(map(len, key2str.keys())) 56 | valwidth = max(map(len, key2str.values())) 57 | 58 | # Write out the data 59 | dashes = '-' * (keywidth + valwidth + 7) 60 | lines = [dashes] 61 | for (key, val) in sorted(key2str.items()): 62 | lines.append('| %s%s | %s%s |' % ( 63 | key, 64 | ' ' * (keywidth - len(key)), 65 | val, 66 | ' ' * (valwidth - len(val)), 67 | )) 68 | lines.append(dashes) 69 | self.file.write('\n'.join(lines) + '\n') 70 | 71 | # Flush the output to the file 72 | self.file.flush() 73 | 74 | def _truncate(self, s): 75 | return s[:20] + '...' if len(s) > 23 else s 76 | 77 | def writeseq(self, seq): 78 | seq = list(seq) 79 | for (i, elem) in enumerate(seq): 80 | self.file.write(elem) 81 | if i < len(seq) - 1: # add space unless this is the last one 82 | self.file.write(' ') 83 | self.file.write('\n') 84 | self.file.flush() 85 | 86 | def close(self): 87 | if self.own_file: 88 | self.file.close() 89 | 90 | 91 | class JSONOutputFormat(KVWriter): 92 | def __init__(self, filename): 93 | self.file = open(filename, 'wt') 94 | 95 | def writekvs(self, kvs): 96 | for k, v in sorted(kvs.items()): 97 | if hasattr(v, 'dtype'): 98 | v = v.tolist() 99 | kvs[k] = float(v) 100 | self.file.write(json.dumps(kvs) + '\n') 101 | self.file.flush() 102 | 103 | def close(self): 104 | self.file.close() 105 | 106 | 107 | class WandBOutputFormat(KVWriter): 108 | def __init__(self, filename): 109 | wandb.init(project='drrn_naacl', name=filename.split('/')[-1]) 110 | 111 | def writekvs(self, kvs): 112 | wandb.log(kvs) 113 | 114 | def close(self): 115 | pass 116 | 117 | 118 | class CSVOutputFormat(KVWriter): 119 | def __init__(self, filename): 120 | self.file = open(filename, 'w+t') 121 | self.keys = [] 122 | self.sep = ',' 123 | 124 | def writekvs(self, kvs): 125 | # Add our current row to the history 126 | extra_keys = kvs.keys() - self.keys 127 | if extra_keys: 128 | self.keys.extend(extra_keys) 129 | self.file.seek(0) 130 | lines = self.file.readlines() 131 | self.file.seek(0) 132 | for (i, k) in enumerate(self.keys): 133 | if i > 0: 134 | self.file.write(',') 135 | self.file.write(k) 136 | self.file.write('\n') 137 | for line in lines[1:]: 138 | self.file.write(line[:-1]) 139 | self.file.write(self.sep * len(extra_keys)) 140 | self.file.write('\n') 141 | for (i, k) in enumerate(self.keys): 142 | if i > 0: 143 | self.file.write(',') 144 | v = kvs.get(k) 145 | if v is not None: 146 | self.file.write(str(v)) 147 | self.file.write('\n') 148 | self.file.flush() 149 | 150 | def close(self): 151 | self.file.close() 152 | 153 | 154 | class 
TensorBoardOutputFormat(KVWriter): 155 | """ 156 | Dumps key/value pairs into TensorBoard's numeric format. 157 | """ 158 | 159 | def __init__(self, dir): 160 | os.makedirs(dir, exist_ok=True) 161 | self.dir = dir 162 | self.step = 1 163 | prefix = 'events' 164 | path = osp.join(osp.abspath(dir), prefix) 165 | import tensorflow as tf 166 | from tensorflow.python import pywrap_tensorflow 167 | from tensorflow.core.util import event_pb2 168 | from tensorflow.python.util import compat 169 | self.tf = tf 170 | self.event_pb2 = event_pb2 171 | self.pywrap_tensorflow = pywrap_tensorflow 172 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 173 | 174 | def writekvs(self, kvs): 175 | def summary_val(k, v): 176 | kwargs = {'tag': k, 'simple_value': float(v)} 177 | return self.tf.Summary.Value(**kwargs) 178 | 179 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 180 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 181 | event.step = self.step # is there any reason why you'd want to specify the step? 182 | self.writer.WriteEvent(event) 183 | self.writer.Flush() 184 | self.step += 1 185 | 186 | def close(self): 187 | if self.writer: 188 | self.writer.Close() 189 | self.writer = None 190 | 191 | 192 | def make_output_format(format, ev_dir, log_suffix='', args=None): 193 | os.makedirs(ev_dir, exist_ok=True) 194 | if format == 'stdout': 195 | return HumanOutputFormat(sys.stdout) 196 | elif format == 'log': 197 | return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix)) 198 | elif format == 'json': 199 | return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix)) 200 | elif format == 'csv': 201 | return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix)) 202 | elif format == 'tensorboard': 203 | return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix)) 204 | elif format == 'wandb': 205 | return WandBOutputFormat(ev_dir) 206 | else: 207 | raise ValueError('Unknown format specified: %s' % (format,)) 208 | 209 | 210 | # ================================================================ 211 | # API 212 | # ================================================================ 213 | 214 | def logkv(key, val): 215 | """ 216 | Log a value of some diagnostic 217 | Call this once for each diagnostic quantity, each iteration 218 | If called many times, last value will be used. 219 | """ 220 | Logger.CURRENT.logkv(key, val) 221 | 222 | 223 | def logkv_mean(key, val): 224 | """ 225 | The same as logkv(), but if called many times, values averaged. 226 | """ 227 | Logger.CURRENT.logkv_mean(key, val) 228 | 229 | 230 | def logkvs(d): 231 | """ 232 | Log a dictionary of key-value pairs 233 | """ 234 | for (k, v) in d.items(): 235 | logkv(k, v) 236 | 237 | 238 | def dumpkvs(): 239 | """ 240 | Write all of the diagnostics from the current iteration 241 | 242 | level: int. (see logger.py docs) If the global logger level is higher than 243 | the level argument here, don't print to stdout. 244 | """ 245 | Logger.CURRENT.dumpkvs() 246 | 247 | 248 | def getkvs(): 249 | return Logger.CURRENT.name2val 250 | 251 | 252 | def log(*args, level=INFO): 253 | """ 254 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 
255 | """ 256 | Logger.CURRENT.log(*args, level=level) 257 | 258 | 259 | def debug(*args): 260 | log(*args, level=DEBUG) 261 | 262 | 263 | def info(*args): 264 | log(*args, level=INFO) 265 | 266 | 267 | def warn(*args): 268 | log(*args, level=WARN) 269 | 270 | 271 | def error(*args): 272 | log(*args, level=ERROR) 273 | 274 | 275 | def set_level(level): 276 | """ 277 | Set logging threshold on current logger. 278 | """ 279 | Logger.CURRENT.set_level(level) 280 | 281 | 282 | def get_dir(): 283 | """ 284 | Get directory that log files are being written to. 285 | will be None if there is no output directory (i.e., if you didn't call start) 286 | """ 287 | return Logger.CURRENT.get_dir() 288 | 289 | 290 | record_tabular = logkv 291 | dump_tabular = dumpkvs 292 | 293 | 294 | class ProfileKV: 295 | """ 296 | Usage: 297 | with logger.ProfileKV("interesting_scope"): 298 | code 299 | """ 300 | 301 | def __init__(self, n): 302 | self.n = "wait_" + n 303 | 304 | def __enter__(self): 305 | self.t1 = time.time() 306 | 307 | def __exit__(self, type, value, traceback): 308 | Logger.CURRENT.name2val[self.n] += time.time() - self.t1 309 | 310 | 311 | def profile(n): 312 | """ 313 | Usage: 314 | @profile("my_func") 315 | def my_func(): code 316 | """ 317 | 318 | def decorator_with_name(func): 319 | def func_wrapper(*args, **kwargs): 320 | with ProfileKV(n): 321 | return func(*args, **kwargs) 322 | 323 | return func_wrapper 324 | 325 | return decorator_with_name 326 | 327 | 328 | # ================================================================ 329 | # Backend 330 | # ================================================================ 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | 344 | # Logging API, forwarded 345 | # ---------------------------------------- 346 | def logkv(self, key, val): 347 | self.name2val[key] = val 348 | 349 | def logkv_mean(self, key, val): 350 | if val is None: 351 | self.name2val[key] = None 352 | return 353 | oldval, cnt = self.name2val[key], self.name2cnt[key] 354 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 355 | self.name2cnt[key] = cnt + 1 356 | 357 | def dumpkvs(self): 358 | if self.level == DISABLED: return 359 | for fmt in self.output_formats: 360 | if isinstance(fmt, KVWriter): 361 | fmt.writekvs(self.name2val) 362 | self.name2val.clear() 363 | self.name2cnt.clear() 364 | 365 | def log(self, *args, level=INFO): 366 | if self.level <= level: 367 | self._do_log(args) 368 | 369 | # Configuration 370 | # ---------------------------------------- 371 | def set_level(self, level): 372 | self.level = level 373 | 374 | def get_dir(self): 375 | return self.dir 376 | 377 | def close(self): 378 | for fmt in self.output_formats: 379 | fmt.close() 380 | 381 | # Misc 382 | # ---------------------------------------- 383 | def _do_log(self, args): 384 | for fmt in self.output_formats: 385 | if isinstance(fmt, SeqWriter): 386 | fmt.writeseq(map(str, args)) 387 | 388 | 389 | def configure(dir=None, format_strs=None): 390 | if dir is None: 391 | dir = os.getenv('OPENAI_LOGDIR') 392 | if dir is 
None: 393 | dir = osp.join(tempfile.gettempdir(), 394 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")) 395 | assert isinstance(dir, str) 396 | os.makedirs(dir, exist_ok=True) 397 | 398 | log_suffix = '' 399 | rank = 0 400 | # check environment variables here instead of importing mpi4py 401 | # to avoid calling MPI_Init() when this module is imported 402 | for varname in ['PMI_RANK', 'OMPI_COMM_WORLD_RANK']: 403 | if varname in os.environ: 404 | rank = int(os.environ[varname]) 405 | if rank > 0: 406 | log_suffix = "-rank%03i" % rank 407 | 408 | if format_strs is None: 409 | if rank == 0: 410 | format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',') 411 | else: 412 | format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',') 413 | format_strs = filter(None, format_strs) 414 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 415 | 416 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats) 417 | log('Logging to %s' % dir) 418 | 419 | 420 | def _configure_default_logger(): 421 | format_strs = None 422 | # keep the old default of only writing to stdout 423 | if 'OPENAI_LOG_FORMAT' not in os.environ: 424 | format_strs = ['stdout'] 425 | configure(format_strs=format_strs) 426 | Logger.DEFAULT = Logger.CURRENT 427 | 428 | 429 | def reset(): 430 | if Logger.CURRENT is not Logger.DEFAULT: 431 | Logger.CURRENT.close() 432 | Logger.CURRENT = Logger.DEFAULT 433 | log('Reset logger') 434 | 435 | 436 | class scoped_configure(object): 437 | def __init__(self, dir=None, format_strs=None): 438 | self.dir = dir 439 | self.format_strs = format_strs 440 | self.prevlogger = None 441 | 442 | def __enter__(self): 443 | self.prevlogger = Logger.CURRENT 444 | configure(dir=self.dir, format_strs=self.format_strs) 445 | 446 | def __exit__(self, *args): 447 | Logger.CURRENT.close() 448 | Logger.CURRENT = self.prevlogger 449 | 450 | 451 | # ================================================================ 452 | 453 | def _demo(): 454 | info("hi") 455 | debug("shouldn't appear") 456 | set_level(DEBUG) 457 | debug("should appear") 458 | dir = "/tmp/testlogging" 459 | if os.path.exists(dir): 460 | shutil.rmtree(dir) 461 | configure(dir=dir) 462 | logkv("a", 3) 463 | logkv("b", 2.5) 464 | dumpkvs() 465 | logkv("b", -2.5) 466 | logkv("a", 5.5) 467 | dumpkvs() 468 | info("^^^ should see a = 5.5") 469 | logkv_mean("b", -22.5) 470 | logkv_mean("b", -44.4) 471 | logkv("a", 5.5) 472 | dumpkvs() 473 | info("^^^ should see b = 33.3") 474 | 475 | logkv("b", -2.5) 476 | dumpkvs() 477 | 478 | logkv("a", "longasslongasslongasslongasslongasslongassvalue") 479 | dumpkvs() 480 | 481 | 482 | # ================================================================ 483 | # Readers 484 | # ================================================================ 485 | 486 | def read_json(fname): 487 | import pandas 488 | ds = [] 489 | with open(fname, 'rt') as fh: 490 | for line in fh: 491 | ds.append(json.loads(line)) 492 | return pandas.DataFrame(ds) 493 | 494 | 495 | def read_csv(fname): 496 | import pandas 497 | return pandas.read_csv(fname, index_col=None, comment='#') 498 | 499 | 500 | def read_tb(path): 501 | """ 502 | path : a tensorboard file OR a directory, where we will find all TB files 503 | of the form events.* 504 | """ 505 | import pandas 506 | import numpy as np 507 | from glob import glob 508 | from collections import defaultdict 509 | import tensorflow as tf 510 | if osp.isdir(path): 511 | fnames = glob(osp.join(path, "events.*")) 512 | elif 
osp.basename(path).startswith("events."): 513 | fnames = [path] 514 | else: 515 | raise NotImplementedError("Expected tensorboard file or directory containing them. Got %s" % path) 516 | tag2pairs = defaultdict(list) 517 | maxstep = 0 518 | for fname in fnames: 519 | for summary in tf.train.summary_iterator(fname): 520 | if summary.step > 0: 521 | for v in summary.summary.value: 522 | pair = (summary.step, v.simple_value) 523 | tag2pairs[v.tag].append(pair) 524 | maxstep = max(summary.step, maxstep) 525 | data = np.empty((maxstep, len(tag2pairs))) 526 | data[:] = np.nan 527 | tags = sorted(tag2pairs.keys()) 528 | for (colidx, tag) in enumerate(tags): 529 | pairs = tag2pairs[tag] 530 | for (step, value) in pairs: 531 | data[step - 1, colidx] = value 532 | return pandas.DataFrame(data, columns=tags) 533 | 534 | 535 | # configure the default logger on import 536 | # _configure_default_logger() 537 | 538 | if __name__ == "__main__": 539 | _demo() 540 | -------------------------------------------------------------------------------- /memory.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import numpy as np 3 | import random 4 | 5 | State = namedtuple('State', ('obs', 'description', 'inventory')) 6 | Transition = namedtuple('Transition', ('state', 'act', 'reward', 'next_state', 'next_acts', 'done', 'acts')) 7 | 8 | 9 | class ReplayMemory(object): 10 | def __init__(self, capacity): 11 | self.capacity = capacity 12 | self.memory = [] 13 | self.position = 0 14 | 15 | def push(self, transition): 16 | if len(self.memory) < self.capacity: 17 | self.memory.append(None) 18 | self.memory[self.position] = transition 19 | self.position = (self.position + 1) % self.capacity 20 | 21 | def sample(self, batch_size): 22 | return random.sample(self.memory, batch_size) 23 | 24 | def __len__(self): 25 | return len(self.memory) 26 | 27 | 28 | class PrioritizedReplayMemory(object): 29 | def __init__(self, capacity, alpha): 30 | self.capacity = capacity 31 | self.alpha= alpha 32 | self.memory = [] 33 | self.priorities = [] 34 | self.position = 0 35 | 36 | def push(self, transition, priority): 37 | if len(self.memory) < self.capacity: 38 | self.memory.append(None) 39 | self.priorities.append(None) 40 | self.memory[self.position] = transition 41 | self.priorities[self.position] = priority 42 | self.position = (self.position + 1) % self.capacity 43 | 44 | def sample(self, batch_size): 45 | priorities = np.array(self.priorities) 46 | priorities = np.power(priorities + 1e-5, self.alpha) 47 | p = priorities / np.sum(priorities) 48 | idxs = np.random.choice(np.arange(len(p)), size=batch_size, p=p) 49 | return [self.memory[i] for i in idxs] 50 | 51 | def update(self, idxs, priorities): 52 | for i, priority in zip(idxs, priorities): 53 | self.priorities[i] = priority 54 | 55 | def __len__(self): 56 | return len(self.memory) 57 | 58 | 59 | class ABReplayMemory(object): 60 | def __init__(self, capacity, priority_fraction): 61 | self.priority_fraction = priority_fraction 62 | self.alpha_capacity = int(capacity * priority_fraction) 63 | self.beta_capacity = capacity - self.alpha_capacity 64 | self.alpha_memory, self.beta_memory = [], [] 65 | self.alpha_position, self.beta_position = 0, 0 66 | 67 | def clear_alpha(self): 68 | self.alpha_memory = [] 69 | self.alpha_position = 0 70 | 71 | def push(self, transition, is_prior=False): 72 | """Saves a transition.""" 73 | if self.priority_fraction == 0.0: 74 | is_prior = False 75 | if is_prior: 76 | if 
len(self.alpha_memory) < self.alpha_capacity: 77 | self.alpha_memory.append(None) 78 | self.alpha_memory[self.alpha_position] = transition 79 | self.alpha_position = (self.alpha_position + 1) % self.alpha_capacity 80 | else: 81 | if len(self.beta_memory) < self.beta_capacity: 82 | self.beta_memory.append(None) 83 | self.beta_memory[self.beta_position] = transition 84 | self.beta_position = (self.beta_position + 1) % self.beta_capacity 85 | 86 | def sample(self, batch_size): 87 | if self.priority_fraction == 0.0: 88 | from_beta = min(batch_size, len(self.beta_memory)) 89 | res = random.sample(self.beta_memory, from_beta) 90 | else: 91 | from_alpha = min(int(self.priority_fraction * batch_size), len(self.alpha_memory)) 92 | from_beta = min(batch_size - int(self.priority_fraction * batch_size), len(self.beta_memory)) 93 | res = random.sample(self.alpha_memory, from_alpha) + random.sample(self.beta_memory, from_beta) 94 | random.shuffle(res) 95 | return res 96 | 97 | def __len__(self): 98 | return len(self.alpha_memory) + len(self.beta_memory) 99 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import random 6 | import itertools 7 | from util import pad_sequences 8 | from memory import State 9 | 10 | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | class DRRN(torch.nn.Module): 15 | """ 16 | Deep Reinforcement Relevance Network - He et al. '16 17 | 18 | """ 19 | def __init__(self, vocab_size, embedding_dim, hidden_dim, fix_rep=0, hash_rep=0, act_obs=0): 20 | super(DRRN, self).__init__() 21 | self.hidden_dim = hidden_dim 22 | self.embedding = nn.Embedding(vocab_size, embedding_dim) 23 | self.obs_encoder = nn.GRU(embedding_dim, hidden_dim) 24 | self.look_encoder = nn.GRU(embedding_dim, hidden_dim) 25 | self.inv_encoder = nn.GRU(embedding_dim, hidden_dim) 26 | self.act_encoder = nn.GRU(embedding_dim, hidden_dim) 27 | self.hidden = nn.Linear(2 * hidden_dim, hidden_dim) 28 | # self.hidden = nn.Sequential(nn.Linear(2 * hidden_dim, 2 * hidden_dim), nn.Linear(2 * hidden_dim, hidden_dim), nn.Linear(hidden_dim, hidden_dim)) 29 | self.act_scorer = nn.Linear(hidden_dim, 1) 30 | 31 | self.state_encoder = nn.Linear(3 * hidden_dim, hidden_dim) 32 | self.inverse_dynamics = nn.Sequential(nn.Linear(2 * hidden_dim, 2 * hidden_dim), nn.ReLU(), nn.Linear(2 * hidden_dim, hidden_dim)) 33 | self.forward_dynamics = nn.Sequential(nn.Linear(2 * hidden_dim, 2 * hidden_dim), nn.ReLU(), nn.Linear(2 * hidden_dim, hidden_dim)) 34 | 35 | self.act_decoder = nn.GRU(hidden_dim, embedding_dim) 36 | self.act_fc = nn.Linear(embedding_dim, vocab_size) 37 | 38 | self.obs_decoder = nn.GRU(hidden_dim, embedding_dim) 39 | self.obs_fc = nn.Linear(embedding_dim, vocab_size) 40 | 41 | self.fix_rep = fix_rep 42 | self.hash_rep = hash_rep 43 | self.act_obs = act_obs 44 | self.hash_cache = {} 45 | 46 | def packed_hash(self, x): 47 | y = [] 48 | for data in x: 49 | data = hash(tuple(data)) 50 | if data in self.hash_cache: 51 | y.append(self.hash_cache[data]) 52 | else: 53 | a = torch.zeros(self.hidden_dim).normal_(generator=torch.random.manual_seed(data)) 54 | # torch.random.seed() 55 | y.append(a) 56 | self.hash_cache[data] = a 57 | y = torch.stack(y, dim=0).to(device) 58 | return y 59 | 60 | def packed_rnn(self, x, rnn): 61 | """ Runs the provided rnn on the input x. 
Takes care of packing/unpacking. 62 | 63 | x: list of unpadded input sequences 64 | Returns a tensor of size: len(x) x hidden_dim 65 | """ 66 | if self.hash_rep: return self.packed_hash(x) 67 | lengths = torch.tensor([len(n) for n in x], dtype=torch.long, device=device) 68 | # Sort this batch in descending order by seq length 69 | lengths, idx_sort = torch.sort(lengths, dim=0, descending=True) 70 | _, idx_unsort = torch.sort(idx_sort, dim=0) 71 | idx_sort = torch.autograd.Variable(idx_sort) 72 | idx_unsort = torch.autograd.Variable(idx_unsort) 73 | padded_x = pad_sequences(x) 74 | x_tt = torch.from_numpy(padded_x).type(torch.long).to(device) 75 | x_tt = x_tt.index_select(0, idx_sort) 76 | # Run the embedding layer 77 | embed = self.embedding(x_tt).permute(1,0,2) # Time x Batch x EncDim 78 | # Pack padded batch of sequences for RNN module 79 | packed = nn.utils.rnn.pack_padded_sequence(embed, lengths) 80 | # Run the RNN 81 | out, _ = rnn(packed) 82 | # Unpack 83 | out, _ = nn.utils.rnn.pad_packed_sequence(out) 84 | # Get the last step of each sequence 85 | idx = (lengths-1).view(-1,1).expand(len(lengths), out.size(2)).unsqueeze(0) 86 | out = out.gather(0, idx).squeeze(0) 87 | # Unsort 88 | out = out.index_select(0, idx_unsort) 89 | return out 90 | 91 | 92 | def state_rep(self, state_batch): 93 | # Zip the state_batch into an easy access format 94 | state = State(*zip(*state_batch)) 95 | # Encode the various aspects of the state 96 | with torch.set_grad_enabled(not self.fix_rep): 97 | obs_out = self.packed_rnn(state.obs, self.obs_encoder) 98 | if self.act_obs: return obs_out 99 | look_out = self.packed_rnn(state.description, self.look_encoder) 100 | inv_out = self.packed_rnn(state.inventory, self.inv_encoder) 101 | state_out = self.state_encoder(torch.cat((obs_out, look_out, inv_out), dim=1)) 102 | return state_out 103 | 104 | 105 | def act_rep(self, act_batch): 106 | # This is number of admissible commands in each element of the batch 107 | act_sizes = [len(a) for a in act_batch] 108 | # Combine next actions into one long list 109 | act_batch = list(itertools.chain.from_iterable(act_batch)) 110 | with torch.set_grad_enabled(not self.fix_rep): 111 | act_out = self.packed_rnn(act_batch, self.act_encoder) 112 | return act_sizes, act_out 113 | 114 | 115 | def for_predict(self, state_batch, acts): 116 | _, act_out = self.act_rep(acts) 117 | state_out = self.state_rep(state_batch) 118 | next_state_out = state_out + self.forward_dynamics(torch.cat((state_out, act_out), dim=1)) 119 | return next_state_out 120 | 121 | 122 | def inv_predict(self, state_batch, next_state_batch): 123 | state_out = self.state_rep(state_batch) 124 | next_state_out = self.state_rep(next_state_batch) 125 | act_out = self.inverse_dynamics(torch.cat((state_out, next_state_out - state_out), dim=1)) 126 | return act_out 127 | 128 | 129 | def inv_loss_l1(self, state_batch, next_state_batch, acts): 130 | _, act_out = self.act_rep(acts) 131 | act_out_hat = self.inv_predict(state_batch, next_state_batch) 132 | return F.l1_loss(act_out, act_out_hat) 133 | 134 | 135 | def inv_loss_l2(self, state_batch, next_state_batch, acts): 136 | _, act_out = self.act_rep(acts) 137 | act_out_hat = self.inv_predict(state_batch, next_state_batch) 138 | return F.mse_loss(act_out, act_out_hat) 139 | 140 | 141 | def inv_loss_ce(self, state_batch, next_state_batch, acts, valids, get_predict=False): 142 | act_sizes, valids_out = self.act_rep(valids) 143 | _, act_out = self.act_rep(acts) 144 | act_out_hat = self.inv_predict(state_batch, next_state_batch) 
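# For each sample, score every admissible action embedding against the
# predicted action embedding (dot product), then apply a cross-entropy loss
# with the action actually taken as the label.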
145 | now, loss, acc = 0, 0, 0 146 | if get_predict: predicts = [] 147 | for i, j in enumerate(act_sizes): 148 | valid_out = valids_out[now: now + j] 149 | now += j 150 | values = valid_out.matmul(act_out_hat[i]) 151 | label = valids[i].index(acts[i][0]) 152 | loss += F.cross_entropy(values.unsqueeze(0), torch.LongTensor([label]).to(device)) 153 | predict = values.argmax().item() 154 | acc += predict == label 155 | if get_predict: predicts.append(predict) 156 | return (loss / len(act_sizes), acc / len(act_sizes), predicts) if get_predict else (loss / len(act_sizes), acc / len(act_sizes)) 157 | 158 | 159 | def inv_loss_decode(self, state_batch, next_state_batch, acts, hat=True, reduction='mean'): 160 | # hat: use rep(o), rep(o'); not hat: use rep(a) 161 | _, act_out = self.act_rep(acts) 162 | act_out_hat = self.inv_predict(state_batch, next_state_batch) 163 | 164 | acts_pad = pad_sequences([act[0] for act in acts]) 165 | acts_tensor = torch.from_numpy(acts_pad).type(torch.long).to(device).transpose(0, 1) 166 | l, bs = acts_tensor.size() 167 | vocab = self.embedding.num_embeddings 168 | outputs = torch.zeros(l, bs, vocab).to(device) 169 | input, z = acts_tensor[0].unsqueeze(0), (act_out_hat if hat else act_out).unsqueeze(0) 170 | for t in range(1, l): 171 | input = self.embedding(input) 172 | output, z = self.act_decoder(input, z) 173 | output = self.act_fc(output) 174 | outputs[t] = output 175 | top = output.argmax(2) 176 | input = top 177 | outputs, acts_tensor = outputs[1:], acts_tensor[1:] 178 | loss = F.cross_entropy(outputs.reshape(-1, vocab), acts_tensor.reshape(-1), ignore_index=0, reduction=reduction) 179 | if reduction == 'none': # loss for each term in batch 180 | lens = [len(act[0]) - 1 for act in acts] 181 | loss = loss.reshape(-1, bs).sum(0).cpu() / torch.tensor(lens) 182 | nonzero = (acts_tensor > 0) 183 | same = (outputs.argmax(-1) == acts_tensor) 184 | acc_token = (same & nonzero).float().sum() / (nonzero).float().sum() # token accuracy 185 | acc_action = (same.int().sum(0) == nonzero.int().sum(0)).float().sum() / same.size(1) # action accuracy 186 | return loss, acc_action 187 | 188 | 189 | def for_loss_l2(self, state_batch, next_state_batch, acts): 190 | next_state_out = self.state_rep(next_state_batch) 191 | next_state_out_hat = self.for_predict(state_batch, acts) 192 | return F.mse_loss(next_state_out, next_state_out_hat) # , reduction='sum') 193 | 194 | 195 | def for_loss_ce_batch(self, state_batch, next_state_batch, acts): 196 | # consider duplicates in next_state_batch 197 | next_states, labels = [], [] 198 | for next_state in next_state_batch: 199 | if next_state not in next_states: 200 | labels.append(len(next_states)) 201 | next_states.append(next_state) 202 | else: 203 | labels.append(next_states.index(next_state)) 204 | labels = torch.LongTensor(labels).to(device) 205 | next_state_out = self.state_rep(next_states) 206 | next_state_out_hat = self.for_predict(state_batch, acts) 207 | logits = next_state_out_hat.matmul(next_state_out.transpose(0, 1)) 208 | loss = F.cross_entropy(logits, labels) 209 | acc = (logits.argmax(1) == labels).float().sum() / len(labels) 210 | return loss, acc 211 | 212 | 213 | def for_loss_ce(self, state_batch, next_state_batch, acts, valids): 214 | # classify rep(o') from predict(o, a1), predict(o, a2), ... 
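# i.e. for each sample, run the forward model over every admissible action,
# score each predicted next-state representation against the encoded true
# next state (dot product), and apply cross-entropy with the taken action as the label.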
215 | act_sizes, valids_out = self.act_rep(valids) 216 | _, act_out = self.act_rep(acts) 217 | next_state_out = self.state_rep(next_state_batch) 218 | now, loss, acc = 0, 0, 0 219 | for i, j in enumerate(act_sizes): 220 | valid_out = valids_out[now: now + j] 221 | now += j 222 | next_states_out_hat = self.for_predict([state_batch[i]] * j, [[_] for _ in valids[i]]) 223 | values = next_states_out_hat.matmul(next_state_out[i]) 224 | label = valids[i].index(acts[i][0]) 225 | loss += F.cross_entropy(values.unsqueeze(0), torch.LongTensor([label]).to(device)) 226 | predict = values.argmax().item() 227 | acc += predict == label 228 | return (loss / len(act_sizes), acc / len(act_sizes)) 229 | 230 | 231 | def for_loss_decode(self, state_batch, next_state_batch, acts, hat=True): 232 | # hat: use rep(o), rep(a); not hat: use rep(o') 233 | next_state_out = self.state_rep(next_state_batch) 234 | next_state_out_hat = self.for_predict(state_batch, acts) 235 | 236 | 237 | next_state_pad = pad_sequences([state.obs for state in next_state_batch]) # assumed decoding target: next-observation token ids (cf. inv_loss_decode) 238 | next_state_tensor = torch.from_numpy(next_state_pad).type(torch.long).to(device).transpose(0, 1) 239 | l, bs = next_state_tensor.size() 240 | vocab = self.embedding.num_embeddings 241 | outputs = torch.zeros(l, bs, vocab).to(device) 242 | input, z = next_state_tensor[0].unsqueeze(0), (next_state_out_hat if hat else next_state_out).unsqueeze(0) 243 | for t in range(1, l): 244 | input = self.embedding(input) 245 | output, z = self.obs_decoder(input, z) 246 | output = self.obs_fc(output) 247 | outputs[t] = output 248 | top = output.argmax(2) 249 | input = top 250 | outputs, next_state_tensor = outputs[1:].reshape(-1, vocab), next_state_tensor[1:].reshape(-1) 251 | loss = F.cross_entropy(outputs, next_state_tensor, ignore_index=0) 252 | nonzero = (next_state_tensor > 0) 253 | same = (outputs.argmax(1) == next_state_tensor) 254 | acc = (same & nonzero).float().sum() / (nonzero).float().sum() # token accuracy 255 | return loss, acc 256 | 257 | 258 | def forward(self, state_batch, act_batch): 259 | """ 260 | Batched forward pass. 261 | obs_id_batch: iterable of unpadded sequence ids 262 | act_batch: iterable of lists of unpadded admissible command ids 263 | 264 | Returns a tuple of tensors containing q-values for each item in the batch 265 | """ 266 | state_out = self.state_rep(state_batch) 267 | act_sizes, act_out = self.act_rep(act_batch) 268 | # Expand the state to match the batches of actions 269 | state_out = torch.cat([state_out[i].repeat(j,1) for i,j in enumerate(act_sizes)], dim=0) 270 | z = torch.cat((state_out, act_out), dim=1) # Concat along hidden_dim 271 | z = F.relu(self.hidden(z)) 272 | act_values = self.act_scorer(z).squeeze(-1) 273 | # Split up the q-values by batch 274 | return act_values.split(act_sizes) 275 | 276 | 277 | def act(self, states, act_ids, sample=True, eps=0.1): 278 | """ Returns an action-string, optionally sampling from the distribution 279 | of Q-Values.
280 | """ 281 | act_values = self.forward(states, act_ids) 282 | if sample: 283 | act_probs = [F.softmax(vals, dim=0) for vals in act_values] 284 | act_idxs = [torch.multinomial(probs, num_samples=1).item() \ 285 | for probs in act_probs] 286 | else: 287 | act_idxs = [vals.argmax(dim=0).item() if np.random.rand() > eps else np.random.randint(len(vals)) for vals in act_values] 288 | return act_idxs, act_values 289 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import time 3 | import os 4 | import torch 5 | import logger 6 | import argparse 7 | import jericho 8 | import logging 9 | import json 10 | from os.path import basename, dirname 11 | from drrn import * 12 | from env import JerichoEnv 13 | from jericho.util import clean 14 | from copy import deepcopy 15 | from vec_env import VecEnv 16 | 17 | 18 | logging.getLogger().setLevel(logging.CRITICAL) 19 | subprocess.run("python -m spacy download en_core_web_sm".split()) 20 | 21 | def configure_logger(log_dir, wandb): 22 | logger.configure(log_dir, format_strs=['log']) 23 | global tb 24 | type_strs = ['json', 'stdout'] 25 | if wandb and log_dir != 'logs': type_strs += ['wandb'] 26 | tb = logger.Logger(log_dir, [logger.make_output_format(type_str, log_dir) for type_str in type_strs]) 27 | global log 28 | log = logger.log 29 | 30 | 31 | def evaluate(agent, env, nb_episodes=1): 32 | with torch.no_grad(): 33 | total_score = 0 34 | for ep in range(nb_episodes): 35 | log("Starting evaluation episode {}".format(ep)) 36 | score = evaluate_episode(agent, env) 37 | log("Evaluation episode {} ended with score {}\n\n".format(ep, score)) 38 | total_score += score 39 | avg_score = total_score / nb_episodes 40 | return avg_score 41 | 42 | 43 | def evaluate_episode(agent, env): 44 | step = 0 45 | done = False 46 | ob, info = env.reset() 47 | state = agent.build_state(ob, info) 48 | log('Obs{}: {} Inv: {} Desc: {}'.format(step, clean(ob), clean(info['inv']), clean(info['look']))) 49 | while not done: 50 | valid_acts = info['valid'] 51 | valid_ids = agent.encode(valid_acts) 52 | _, action_idx, action_values = agent.act([state], [valid_ids], sample=False) 53 | action_idx = action_idx[0] 54 | action_values = action_values[0] 55 | action_str = valid_acts[action_idx] 56 | log('Action{}: {}, Q-Value {:.2f}'.format(step, action_str, action_values[action_idx].item())) 57 | s = '' 58 | for idx, (act, val) in enumerate(sorted(zip(valid_acts, action_values), key=lambda x: x[1], reverse=True), 1): 59 | s += "{}){:.2f} {} ".format(idx, val.item(), act) 60 | log('Q-Values: {}'.format(s)) 61 | ob, rew, done, info = env.step(action_str) 62 | log("Reward{}: {}, Score {}, Done {}".format(step, rew, info['score'], done)) 63 | step += 1 64 | log('Obs{}: {} Inv: {} Desc: {}'.format(step, clean(ob), clean(info['inv']), clean(info['look']))) 65 | state = agent.build_state(ob, info) 66 | return info['score'] 67 | 68 | 69 | def train(agent, eval_env, envs, max_steps, update_freq, eval_freq, checkpoint_freq, log_freq, r_for): 70 | start, max_score, max_reward = time.time(), 0, 0 71 | obs, infos = envs.reset() 72 | states = agent.build_states(obs, infos) 73 | valid_ids = [agent.encode(info['valid']) for info in infos] 74 | transitions = [[] for info in infos] 75 | for step in range(1, max_steps+1): 76 | action_ids, action_idxs, action_values = agent.act(states, valid_ids, sample=True, eps=0.05 ** (step / max_steps)) 77 | action_strs = 
[info['valid'][idx] for info, idx in zip(infos, action_idxs)] 78 | 79 | # log envs[0] 80 | examples = [(action, value) for action, value in zip(infos[0]['valid'], action_values[0].tolist())] 81 | examples = sorted(examples, key=lambda x: -x[1]) 82 | log('State {}: {}'.format(step, clean(obs[0] + infos[0]['inv'] + infos[0]['look']))) 83 | log('Actions{}: {}'.format(step, [action for action, _ in examples])) 84 | log('Qvalues{}: {}'.format(step, [round(value, 2) for _, value in examples])) 85 | log('>> Action{}: {}'.format(step, action_strs[0])) 86 | 87 | # step 88 | obs, rewards, dones, infos = envs.step(action_strs) 89 | next_states = agent.build_states(obs, infos) 90 | next_valids = [agent.encode(info['valid']) for info in infos] 91 | if r_for > 0: 92 | reward_curiosity, _ = agent.network.inv_loss_decode(states, next_states, [[a] for a in action_ids], hat=True, reduction='none') 93 | rewards = rewards + reward_curiosity.detach().numpy() * r_for 94 | tb.logkv_mean('Curiosity', reward_curiosity.mean().item()) 95 | 96 | for i, (ob, reward, done, info, state, next_state) in enumerate(zip(obs, rewards, dones, infos, states, next_states)): 97 | transition = Transition(state, action_ids[i], reward, next_state, next_valids[i], done, valid_ids[i]) 98 | transitions[i].append(transition) 99 | agent.observe(transition) 100 | if i == 0: 101 | log("Reward{}: {}, Score {}, Done {}\n".format(step, reward, info['score'], done)) 102 | if done: 103 | tb.logkv_mean('EpisodeScore', info['score']) 104 | # obs[i], infos[i] = env.reset() 105 | # next_states[i] = agent.build_state(obs[i], infos[i]) 106 | # next_valids[i] = agent.encode(infos[i]['valid']) 107 | if info['score'] >= max_score: # put in alpha queue 108 | if info['score'] > max_score: 109 | agent.memory.clear_alpha() 110 | max_score = info['score'] 111 | for transition in transitions[i]: 112 | agent.observe(transition, is_prior=True) 113 | transitions[i] = [] 114 | 115 | states, valid_ids = next_states, next_valids 116 | if step % log_freq == 0: 117 | tb.logkv('Step', step) 118 | tb.logkv("FPS", int((step*envs.num_envs)/(time.time()-start))) 119 | tb.logkv("EpisodeScores100", envs.get_end_scores().mean()) 120 | tb.logkv('MaxScore', max_score) 121 | tb.logkv('Step', step) 122 | # if envs[0].cache is not None: 123 | # tb.logkv('#dict', len(envs[0].cache)) 124 | # tb.logkv('#locs', len(envs[0].cache['loc'])) 125 | tb.dumpkvs() 126 | if step % update_freq == 0: 127 | res = agent.update() 128 | if res is not None: 129 | for k, v in res.items(): 130 | tb.logkv_mean(k, v) 131 | if step % checkpoint_freq == 0: 132 | agent.save(str(step)) 133 | # json_path = envs[0].rom_path.replace('.z5', '.json') 134 | # if os.path.exists(json_path): 135 | # envs[0].cache.update(json.load(open(json_path))) 136 | # json.dump(envs[0].cache, open(json_path, 'w')) 137 | if step % eval_freq == 0: 138 | eval_score = evaluate(agent, eval_env) 139 | tb.logkv('EvalScore', eval_score) 140 | tb.dumpkvs() 141 | 142 | 143 | def parse_args(): 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument('--output_dir', default='logs') 146 | parser.add_argument('--load', default=None) 147 | parser.add_argument('--spm_path', default='unigram_8k.model') 148 | parser.add_argument('--rom_path', default='zork1.z5') 149 | parser.add_argument('--env_step_limit', default=100, type=int) 150 | parser.add_argument('--seed', default=0, type=int) 151 | parser.add_argument('--num_envs', default=8, type=int) 152 | parser.add_argument('--max_steps', default=100000, type=int) 153 | 
parser.add_argument('--update_freq', default=1, type=int) 154 | parser.add_argument('--checkpoint_freq', default=10000, type=int) 155 | parser.add_argument('--eval_freq', default=5000, type=int) 156 | parser.add_argument('--log_freq', default=100, type=int) 157 | parser.add_argument('--memory_size', default=10000, type=int) 158 | parser.add_argument('--memory_alpha', default=.4, type=float) 159 | parser.add_argument('--batch_size', default=64, type=int) 160 | parser.add_argument('--gamma', default=.9, type=float) 161 | parser.add_argument('--learning_rate', default=0.0001, type=float) 162 | parser.add_argument('--clip', default=5, type=float) 163 | parser.add_argument('--embedding_dim', default=128, type=int) 164 | parser.add_argument('--hidden_dim', default=128, type=int) 165 | 166 | parser.add_argument('--wandb', default=1, type=int) 167 | 168 | parser.add_argument('--type_inv', default='decode') 169 | parser.add_argument('--type_for', default='ce') 170 | parser.add_argument('--w_inv', default=0, type=float) 171 | parser.add_argument('--w_for', default=0, type=float) 172 | parser.add_argument('--w_act', default=0, type=float) 173 | parser.add_argument('--r_for', default=0, type=float) 174 | 175 | parser.add_argument('--nor', default=0, type=int, help='no game reward') 176 | parser.add_argument('--randr', default=0, type=int, help='random game reward by objects and locations within episode') 177 | parser.add_argument('--perturb', default=0, type=int, help='perturb state and action') 178 | 179 | parser.add_argument('--hash_rep', default=0, type=int, help='hash for representation') 180 | parser.add_argument('--act_obs', default=0, type=int, help='action set as state representation') 181 | parser.add_argument('--fix_rep', default=0, type=int, help='fix representation') 182 | return parser.parse_known_args()[0] 183 | 184 | 185 | def main(): 186 | args = parse_args() 187 | print(args) 188 | configure_logger(args.output_dir, args.wandb) 189 | agent = DRRN_Agent(args) 190 | agent.load(args.load) 191 | # cache = {'loc': set()} 192 | cache = None 193 | if args.perturb: 194 | args.en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model') 195 | args.de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model') 196 | args.en2de.eval() 197 | args.de2en.eval() 198 | args.en2de.cuda() 199 | args.de2en.cuda() 200 | args.perturb_dict = {} 201 | 202 | env = JerichoEnv(args.rom_path, args.seed, args.env_step_limit, get_valid=True, cache=cache, args=args) 203 | # envs = [JerichoEnv(args.rom_path, args.seed, args.env_step_limit, get_valid=True, cache=cache, args=args) for _ in range(args.num_envs)] 204 | envs = VecEnv(args.num_envs, env) 205 | train(agent, env, envs, args.max_steps, args.update_freq, args.eval_freq, args.checkpoint_freq, args.log_freq, args.r_for) 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /unigram_8k.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/blindfold-textgame/f0dbf32cb76563982291c51d6db9d6691889c55d/unigram_8k.model -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def pad_sequences(sequences, maxlen=None, dtype='int32', value=0.): 5 | ''' 6 | Partially borrowed from Keras 7 | # Arguments 8 | 
sequences: list of lists where each element is a sequence 9 | maxlen: int, maximum length 10 | dtype: type to cast the resulting sequence. 11 | value: float, value to pad the sequences to the desired value. 12 | # Returns 13 | x: numpy array with dimensions (number_of_sequences, maxlen) 14 | ''' 15 | lengths = [len(s) for s in sequences] 16 | nb_samples = len(sequences) 17 | if maxlen is None: 18 | maxlen = np.max(lengths) 19 | # take the sample shape from the first non empty sequence 20 | # checking for consistency in the main loop below. 21 | sample_shape = tuple() 22 | for s in sequences: 23 | if len(s) > 0: 24 | sample_shape = np.asarray(s).shape[1:] 25 | break 26 | x = (np.ones((nb_samples, maxlen) + sample_shape) * value).astype(dtype) 27 | for idx, s in enumerate(sequences): 28 | if len(s) == 0: 29 | continue # empty list was found 30 | # pre truncating 31 | trunc = s[-maxlen:] 32 | # check `trunc` has expected shape 33 | trunc = np.asarray(trunc, dtype=dtype) 34 | if trunc.shape[1:] != sample_shape: 35 | raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' % 36 | (trunc.shape[1:], idx, sample_shape)) 37 | # post padding 38 | x[idx, :len(trunc)] = trunc 39 | return x 40 | -------------------------------------------------------------------------------- /vec_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiprocessing import Process, Pipe 3 | 4 | def worker(remote, parent_remote, env): 5 | parent_remote.close() 6 | try: 7 | done = False 8 | while True: 9 | cmd, data = remote.recv() 10 | if cmd == 'step': 11 | if done: 12 | ob, info = env.reset() 13 | reward = 0 14 | done = False 15 | else: 16 | ob, reward, done, info = env.step(data) 17 | remote.send((ob, reward, done, info)) 18 | elif cmd == 'reset': 19 | ob, info = env.reset() 20 | remote.send((ob, info)) 21 | elif cmd == 'get_end_scores': 22 | remote.send(env.get_end_scores(last=100)) 23 | elif cmd == 'close': 24 | env.close() 25 | break 26 | else: 27 | raise NotImplementedError 28 | except KeyboardInterrupt: 29 | print('SubprocVecEnv worker: got KeyboardInterrupt') 30 | finally: 31 | env.close() 32 | 33 | 34 | class VecEnv: 35 | def __init__(self, num_envs, env): 36 | self.closed = False 37 | self.num_envs = num_envs 38 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(num_envs)]) 39 | self.ps = [Process(target=worker, args=(work_remote, remote, env)) 40 | for (work_remote, remote) in zip(self.work_remotes, self.remotes)] 41 | for p in self.ps: 42 | p.daemon = True # if the main process crashes, we should not cause things to hang 43 | p.start() 44 | for remote in self.work_remotes: 45 | remote.close() 46 | 47 | def step(self, actions): 48 | self._assert_not_closed() 49 | assert len(actions) == self.num_envs, "Error: incorrect number of actions." 
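# scatter one action to each worker over its pipe, then gather the
# (ob, reward, done, info) tuples back in the same order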
50 | for remote, action in zip(self.remotes, actions): 51 | remote.send(('step', action)) 52 | results = [remote.recv() for remote in self.remotes] 53 | self.waiting = False 54 | obs, rewards, dones, infos = zip(*results) 55 | return np.stack(obs), np.stack(rewards), np.stack(dones), infos 56 | 57 | def reset(self): 58 | self._assert_not_closed() 59 | for remote in self.remotes: 60 | remote.send(('reset', None)) 61 | results = [remote.recv() for remote in self.remotes] 62 | obs, infos = zip(*results) 63 | return np.stack(obs), infos 64 | 65 | def get_end_scores(self): 66 | self._assert_not_closed() 67 | for remote in self.remotes: 68 | remote.send(('get_end_scores', None)) 69 | results = [remote.recv() for remote in self.remotes] 70 | return np.stack(results) 71 | 72 | def close_extras(self): 73 | self.closed = True 74 | for remote in self.remotes: 75 | remote.send(('close', None)) 76 | for p in self.ps: 77 | p.join() 78 | 79 | def _assert_not_closed(self): 80 | assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()" 81 | -------------------------------------------------------------------------------- /zork1.z5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/blindfold-textgame/f0dbf32cb76563982291c51d6db9d6691889c55d/zork1.z5 --------------------------------------------------------------------------------