├── .gitignore
├── README.md
├── gym_wikinav
│   ├── __init__.py
│   └── envs
│       ├── __init__.py
│       └── wikinav_env
│           ├── __init__.py
│           ├── environment.py
│           ├── util.py
│           └── web_graph.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
*.pkl
*.npz

*.pyc
.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
**Status:** Archive (code is provided as-is, no updates expected)
--------------------------------------------------------------------------------
/gym_wikinav/__init__.py:
--------------------------------------------------------------------------------
from gym.envs.registration import register


register(
    id="wikinav-v0",
    entry_point="gym_wikinav.envs:EmbeddingWikiNavEnv",
    timestep_limit=50)
--------------------------------------------------------------------------------
/gym_wikinav/envs/__init__.py:
--------------------------------------------------------------------------------
from gym_wikinav.envs.wikinav_env import WikiNavEnv, EmbeddingWikiNavEnv
--------------------------------------------------------------------------------
/gym_wikinav/envs/wikinav_env/__init__.py:
--------------------------------------------------------------------------------
from gym_wikinav.envs.wikinav_env.environment import WikiNavEnv, EmbeddingWikiNavEnv
--------------------------------------------------------------------------------
/gym_wikinav/envs/wikinav_env/environment.py:
--------------------------------------------------------------------------------
from io import StringIO
import sys

import gym
from gym import error, spaces, utils
from gym.utils import seeding
import numpy as np

from gym_wikinav.envs.wikinav_env import web_graph


class WikiNavEnv(gym.Env):

    metadata = {"render.modes": ["human", "ansi"]}

    def __init__(self, beam_size=32, graph=None, goal_reward=10.0):
        """
        Args:
            beam_size: Number of candidates to present as actions at each
                timestep
            graph: `EmbeddedWebGraph` instance to navigate. Defaults to the
                bundled Wikispeedia graph, which is downloaded on first use.
            goal_reward: Reward for stopping on the target article; also
                scales the per-step word-overlap shaping reward.
        """
        super(WikiNavEnv, self).__init__()

        if graph is None:
            graph = web_graph.EmbeddedWikispeediaGraph.get_default_graph()
        self.graph = graph

        # TODO verify beam size

        self.beam_size = beam_size
        self.goal_reward = goal_reward

        self.path_length = self.graph.path_length

        self.navigator = web_graph.Navigator(self.graph, self.beam_size,
                                             self.path_length)

        self._action_space = spaces.Discrete(self.beam_size)

        self._just_reset = False

    @property
    def action_space(self):
        return self._action_space

    @property
    def observation_space(self):
        # abstract
        raise NotImplementedError

    @property
    def cur_article_id(self):
        return self.navigator.cur_article_id

    @property
    def gold_path_length(self):
        return self.navigator.gold_path_length

    def get_article_for_action(self, action):
        return self.navigator.get_article_for_action(action)

    def _step(self, action):
        reward = self._reward(action)
        self.navigator.step(action)

        obs = self._observe()
        done = self.navigator.done
        info = {}

        return obs, reward, done, info

    def _reset(self):
        self.navigator.reset()
        self._just_reset = True
        obs = self._observe()
        self._just_reset = False
        return obs

    def _observe(self):
        # abstract
        raise NotImplementedError

    def _reward(self, action):
        """
        Compute single-timestep reward after having taken the action specified
        by `action`.
        """
        # abstract
        raise NotImplementedError

    def _render(self, mode="human", close=False):
        if close: return

        outfile = StringIO() if mode == "ansi" else sys.stdout

        cur_page = self.graph.get_article_title(self.cur_article_id)
        outfile.write("%s\n" % cur_page)
        return outfile


class EmbeddingWikiNavEnv(WikiNavEnv):

    """
    WikiNavEnv which represents articles with embeddings.
    """

    def __init__(self, *args, **kwargs):
        super(EmbeddingWikiNavEnv, self).__init__(*args, **kwargs)

        self.embedding_dim = self.graph.embedding_dim

        self._query_embedding = None

    @property
    def observation_space(self):
        # 2 embeddings (query and current page) plus the embeddings of
        # articles on the beam
        return spaces.Box(low=-np.inf, high=np.inf,
                          shape=(2 + self.beam_size, self.embedding_dim))

    def _observe(self):
        if self._just_reset:
            self._query_embedding = \
                self.graph.get_query_embeddings([self.navigator._path])[0]

        current_page_embedding = \
            self.graph.get_article_embeddings([self.cur_article_id])[0]
        beam_embeddings = self.graph.get_article_embeddings(self.navigator.beam)

        return self._query_embedding, current_page_embedding, beam_embeddings

    def _reward(self, idx):
        # `idx` is a position on the beam; resolve it to the selected article
        # before comparing against the stop sentinel.
        next_page = self.navigator.get_article_for_action(idx)
        if next_page == self.graph.stop_sentinel:
            if self.navigator.on_target or self.navigator.done:
                # Return goal reward when first stopping on target and also at
                # every subsequent timestep.
                return self.goal_reward
            else:
                # Penalize for stopping on wrong page.
                return -self.goal_reward

        overlap = self.graph.get_relative_word_overlap(next_page,
                                                       self.navigator.target_id)
        return overlap * self.goal_reward
--------------------------------------------------------------------------------
/gym_wikinav/envs/wikinav_env/util.py:
--------------------------------------------------------------------------------
import math

import requests
from tqdm import trange


def download_file(url, destination=None, chunk_size=1024):
    if destination is None:
        destination = url.split("/")[-1]
    r = requests.get(url, stream=True)
    with open(destination, "wb") as f:
        size = int(r.headers["content-length"])
        n_chunks = math.ceil(size / float(chunk_size))
        r_iter = r.iter_content(chunk_size=chunk_size)

        for _ in trange(n_chunks):
            chunk = next(r_iter)
            if chunk:
                f.write(chunk)

        # HACK: keep going in case we somehow missed chunks; maybe a wrong
        # header or the like
        for chunk in r_iter:
            if chunk:
                f.write(chunk)

    return destination
--------------------------------------------------------------------------------
/gym_wikinav/envs/wikinav_env/web_graph.py:
--------------------------------------------------------------------------------
"""
Defines a common web graph navigation interface to WikiNav, Wikispeedia, etc.
"""

from collections import namedtuple
import os
import random
import sys

import numpy as np

from gym_wikinav.envs.wikinav_env.util import download_file


EmbeddedArticle = namedtuple("EmbeddedArticle", ["title", "embedding", "text"])


class EmbeddedWebGraph(object):

    embedding_dim = 128

    def __init__(self, articles, datasets, path_length, stop_sentinel=None):
        self.articles = articles
        self.datasets = {name: (all_paths, np.array(lengths))
                         for name, (all_paths, lengths) in datasets.items()}
        self.path_length = path_length

        assert "train" in self.datasets
        assert "valid" in self.datasets

        # Hack: use a random page as the "STOP" sentinel.
        # Works in expectation. :)
        if stop_sentinel is None:
            stop_sentinel = np.random.choice(len(self.articles))
        self.stop_sentinel = stop_sentinel
        print("Stop sentinel: ", self.stop_sentinel,
              self.articles[self.stop_sentinel].title)

        self._eval_cursor = 0

    def sample_paths(self, batch_size, is_training=True):
        all_paths, lengths = self.datasets["train" if is_training else "valid"]

        if is_training:
            ids = np.random.choice(len(all_paths), size=batch_size)
        else:
            if self._eval_cursor >= len(all_paths) - 1:
                self._eval_cursor = 0
            ids = np.arange(self._eval_cursor,
                            min(len(all_paths),
                                self._eval_cursor + batch_size))
            self._eval_cursor += batch_size

        paths = [self._prepare_path(all_paths[idx]) for idx in ids]
        return ids, paths, lengths[ids]

    def get_num_paths(self, is_training=True):
        return len(self.datasets["train" if is_training else "valid"][0])

    def get_article_links(self, article_idx):
        raise NotImplementedError

    def get_article_title(self, article_idx):
        if article_idx == self.stop_sentinel:
            return ""
        return self.articles[article_idx].title

    def get_relative_word_overlap(self, article1_idx, article2_idx):
        """
        Get the proportion of words in `article1` that are also in `article2`.
        """
        article1 = self.articles[article1_idx]
        article2 = self.articles[article2_idx]

        article1_types = set(article1.text)
        if len(article1_types) == 0:
            return 0.0

        article2_types = set(article2.text)
        return len(article1_types & article2_types) / float(len(article1_types))

    def get_query_embeddings(self, path_ids):
        raise NotImplementedError

    def get_article_embeddings(self, article_ids):
        raise NotImplementedError

    def _prepare_path(self, path):
        raise NotImplementedError


class EmbeddedWikispeediaGraph(EmbeddedWebGraph):

    def __init__(self, data_path, path_length, emb_paths=None):
        try:
            import cPickle as pickle
        except ImportError:
            import pickle

        with open(data_path, "rb") as data_f:
            data = pickle.load(data_f)
        self._data = data

        if emb_paths is not None:
            embeddings = [np.load(emb_path)["arr_0"] for emb_path in emb_paths]
            self.embedding_dim = embeddings[0].shape[1]
            for other_embeddings in embeddings:
                assert other_embeddings.shape == embeddings[0].shape
            self.embeddings = embeddings
        else:
            print("=====================================================\n"
                  "WARNING: Using randomly generated article embeddings.\n"
                  "=====================================================",
                  file=sys.stderr)
            # Random embeddings.
            self.embedding_dim = 128  # fixed for now
            shape = (len(data["articles"]), self.embedding_dim)
            # Match Wikispeedia embedding distribution
            embeddings = np.random.normal(scale=0.15, size=shape)
            self.embeddings = [embeddings]

        articles = [EmbeddedArticle(
            article["name"], self.embeddings[0][i],
            set(token.lower() for token in article["lead_tokens"]))
            for i, article in enumerate(data["articles"])]

        assert articles[0].title == "_Stop"
        assert articles[1].title == "_Dummy"
        stop_sentinel = 0

        datasets = {}
        for dataset_name, dataset in data["paths"].items():
            paths, original_lengths, n_skipped = [], [], 0
            for path in dataset:
                if len(path["articles"]) > path_length - 1:
                    n_skipped += 1
                    continue

                # Pad with STOP sentinel (every path gets at least one)
                pad_length = max(0, path_length + 1 - len(path["articles"]))
                original_length = len(path["articles"]) + 1
                path = path["articles"] + [stop_sentinel] * pad_length

                paths.append(path)
                original_lengths.append(original_length)

            print("%s set: skipped %i of %i paths due to length limit"
                  % (dataset_name, n_skipped, len(dataset)))
            datasets[dataset_name] = (paths, np.array(original_lengths))

        super(EmbeddedWikispeediaGraph, self).__init__(articles, datasets,
                                                       path_length,
                                                       stop_sentinel=stop_sentinel)

    def get_article_links(self, article_idx):
        return self._data["links"].get(article_idx, [self.stop_sentinel])

    def get_query_embeddings(self, paths, embedding_set=0):
        # Get the last non-STOP page in each corresponding path.
        last_pages = [[idx for idx in path if idx != self.stop_sentinel][-1]
                      for path in paths]
        return self.get_article_embeddings(last_pages,
                                           embedding_set=embedding_set)

    def get_article_embeddings(self, article_ids, embedding_set=0):
        return self.embeddings[embedding_set][article_ids]

    def _prepare_path(self, path):
        return path

    LOCAL_GRAPH_PATH = "wikispeedia.pkl"
    LOCAL_EMBEDDINGS_PATH = "wikispeedia_embeddings.npz"
    REMOTE_GRAPH_URL = "https://github.com/hans/wikispeedia/raw/master/data/wikispeedia.pkl"
    REMOTE_EMBEDDINGS_URL = "https://github.com/hans/wikispeedia/raw/master/data/wikispeedia_embeddings.npz"

    @classmethod
    def get_default_graph(cls, path_length=10):
        if hasattr(cls, "_default_graph"):
            return cls._default_graph

        # Load the built-in graph data, downloading if necessary.
        script_dir = os.path.dirname(os.path.realpath(__file__))
        graph_path = os.path.join(script_dir, cls.LOCAL_GRAPH_PATH)
        if not os.path.exists(graph_path):
            print("Downloading default Wikispeedia graph.", file=sys.stderr)
            download_file(cls.REMOTE_GRAPH_URL, graph_path)
        emb_path = os.path.join(script_dir, cls.LOCAL_EMBEDDINGS_PATH)
        if not os.path.exists(emb_path):
            print("Downloading default Wikispeedia embeddings.", file=sys.stderr)
            download_file(cls.REMOTE_EMBEDDINGS_URL, emb_path)

        graph = cls(graph_path, path_length, emb_paths=[emb_path])
        cls._default_graph = graph

        return graph


class Navigator(object):

    def __init__(self, graph, beam_size, path_length):
        self.graph = graph
        self.beam_size = beam_size
        self.path_length = path_length

        assert self.graph.articles[1].title == "_Dummy", \
            "Graph must have articles[1] == dummy article"
        self._dummy_page = 1
        print("Dummy page: ", self._dummy_page,
              self.graph.get_article_title(self._dummy_page))

        self._id, self._path, self._length = None, None, None
        self.beam = None

    def reset(self, is_training=True):
        """
        Prepare a new navigation rollout.
        """
        # TODO: Sample outside of the training set.
        ids, paths, lengths = self.graph.sample_paths(1, is_training)
        self._id, self._path, self._length = ids[0], paths[0], lengths[0]
        self._cur_article_id = self._path[0]

        self._target_id = self._path[self._length - 2]
        self._on_target = False
        self._success, self._stopped = False, False

        self._num_steps = 0
        self._reset(is_training)
        self._prepare()

    def _reset(self, is_training):
        # For subclasses.
        pass

    def step(self, action):
        """
        Make a navigation step with the given action.
        """
        self._step(action)
        # Now cur_article_id contains the result of taking the action
        # specified.

        stopped_now = self.cur_article_id == self.graph.stop_sentinel
        self._stopped = self._stopped or stopped_now

        # Did we just stop at the target page? (Use previous self._on_target
        # before updating `on_target`)
        success_now = self._on_target and stopped_now
        self._success = self._success or success_now
        self._on_target = self.cur_article_id == self._target_id

        self._num_steps += 1
        self._prepare()

    def _step(self, action):
        """
        For subclasses. Modify state using `action`.
        Metadata handled by this superclass.
        """
        self._cur_article_id = self.get_article_for_action(action)

    @property
    def cur_article_id(self):
        return self._cur_article_id

    @property
    def gold_action(self):
        """
        Return the gold navigation action for the current state.
        """
        raise RuntimeError("Gold actions not defined for this navigator!")

    @property
    def target_id(self):
        """
        Return target article ID.
        """
        return self._target_id

    @property
    def on_target(self):
        """
        Return True iff we are currently on the target page.
        """
        return self.cur_article_id == self.target_id

    @property
    def gold_path_length(self):
        """
        Return length of un-padded version of gold path (including stop
        sentinel).
        """
        raise RuntimeError("Gold paths not defined for this navigator!")

    @property
    def done(self):
        """
        `True` if the traversal was manually stopped or if the path length has
        been reached.
        """
        return self._stopped or self._num_steps > self.path_length

    @property
    def success(self):
        """
        `True` when the traversal has successfully reached the target.
        """
        return self._success

    def get_article_for_action(self, action):
        """
        Get the article ID corresponding to an action ID on the beam.
        """
        return self.beam[action]

    def _get_candidates(self):
        """
        Build a beam of candidate next-page IDs consisting of available links
        on the current article.

        NB: The candidate list returned may have a regular pattern, e.g. the
        stop sentinel / filler candidates (for candidate lists which are smaller
        than the beam size) may always be in the same position in the list.
        Make sure to not build models (e.g. ones with output biases) that might
        capitalize on this pattern.

        Returns:
            candidates: List of article IDs of length `self.beam_size`.
        """
        all_links = self.graph.get_article_links(self.cur_article_id)

        # Sample `beam_size - 1`; add the STOP sentinel
        candidates = random.sample(all_links, min(self.beam_size - 1,
                                                  len(all_links)))
        candidates.append(self.graph.stop_sentinel)

        if len(candidates) < self.beam_size:
            padding = [self._dummy_page] * (self.beam_size - len(candidates))
            candidates.extend(padding)

        return candidates

    def _prepare(self):
        """
        Prepare/update information about the current navigator state.
        Should be called after reset / steps are taken.
        """
        self.beam = self._get_candidates()


class OracleNavigator(Navigator):

    def _reset(self, is_training):
        self._cursor = 0

    def _step(self, action):
        # Ignore the action; we are following gold paths.
        self._cursor += 1

    @property
    def cur_article_id(self):
        if self._cursor < self._length:
            return self._path[self._cursor]
        return self.graph.stop_sentinel

    @property
    def gold_action(self):
        return self._gold_action

    @property
    def gold_path_length(self):
        return self._length

    @property
    def done(self):
        return self._cursor >= self._length

    def _get_candidates(self):
        """
        Build a beam of candidate next-page IDs consisting of the valid
        solution and other negatively-sampled candidate links on the page.

        NB: The candidate list returned may have a regular pattern, e.g. the
        stop sentinel / filler candidates (for candidate lists which are smaller
        than the beam size) may always be in the same position in the list.
        Make sure to not build models (e.g. ones with output biases) that might
        capitalize on this pattern.

        Returns:
            candidates: List of article IDs of length `self.beam_size`.
                The list is guaranteed to contain 1) the gold next page
                according to the oracle trajectory and 2) the stop sentinel.
                (Note that these two will make up just one candidate if the
                valid next action is to stop.)
        """
        # Retrieve gold next-page choice for this example
        try:
            gold_next_id = self._path[self._cursor + 1]
        except IndexError:
            # We are at the end of this path and ready to quit. Prepare a
            # dummy beam that won't have any effect.
            candidates = [self._dummy_page] * self.beam_size
            self._gold_action = 0
            return candidates

        ids = self.graph.get_article_links(self.cur_article_id)
        ids = [int(x) for x in ids if x != gold_next_id]

        # Beam must be large enough to hold gold + STOP + a distractor
        assert self.beam_size >= 3
        gold_is_stop = gold_next_id == self.graph.stop_sentinel

        # Number of distractors to sample
        sample_size = self.beam_size - 1 if gold_is_stop \
            else self.beam_size - 2

        if len(ids) > sample_size:
            ids = random.sample(ids, sample_size)
        if len(ids) < sample_size:
            ids += [self._dummy_page] * (sample_size - len(ids))

        # Add the gold page.
        ids = [gold_next_id] + ids
        if not gold_is_stop:
            ids += [self.graph.stop_sentinel]
        random.shuffle(ids)

        assert len(ids) == self.beam_size

        self._gold_action = gold_next_id
        return ids
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

setup(name="gym_wikinav",
      version="0.0.1",
      install_requires=["gym"])
--------------------------------------------------------------------------------
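Example usage (an illustrative sketch, not a file from the repository): the snippet below assumes a legacy `gym` release that still supports the underscore `_step`/`_reset`/`_render` hooks and the `timestep_limit` registration argument used above, with `step()` returning the four-tuple `(obs, reward, done, info)`. Constructing the environment for the first time downloads the bundled Wikispeedia graph and embeddings next to `web_graph.py`, so `requests` and `tqdm` (imported by `util.py`) must be installed alongside the `gym` dependency declared in `setup.py`. Observations are the `(query_embedding, current_page_embedding, beam_embeddings)` tuple built by `EmbeddingWikiNavEnv._observe`.

import gym
import gym_wikinav  # noqa: F401 -- importing the package registers "wikinav-v0"

env = gym.make("wikinav-v0")

obs = env.reset()  # (query_embedding, current_page_embedding, beam_embeddings)
done, episode_reward = False, 0.0
while not done:
    action = env.action_space.sample()  # pick a random slot on the current beam
    obs, reward, done, info = env.step(action)
    episode_reward += reward
    env.render()  # prints the title of the current article

print("episode reward:", episode_reward)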