├── .gitignore ├── LICENSE ├── README.md ├── baselines ├── 0001-Set-up-for-running-deits-solver-with-timeout.patch ├── baseline_deits.ipynb ├── baseline_knn.ipynb ├── baseline_t5.ipynb └── baseline_wn.ipynb ├── data ├── disjoint.json.zip ├── disjoint_word_init.json.zip ├── generated │ └── twl_dict.txt ├── guardian_2020_10_08.json.zip ├── naive_random.json.zip └── original │ ├── deits_anag_indic │ └── ana_ │ ├── names │ ├── README.txt │ ├── all.txt │ ├── boys.txt │ └── girls.txt │ └── us │ ├── US.dic │ └── US.txt ├── decrypt ├── __init__.py ├── common │ ├── __init__.py │ ├── anagrammer.py │ ├── label_anagrams.py │ ├── puzzle_clue.py │ ├── substitution.py │ ├── util_data.py │ ├── util_spellchecker.py │ ├── util_wordnet.py │ └── validation_tools.py ├── config.py └── scrape_parse │ ├── __init__.py │ ├── acw_load.py │ ├── guardian_load.py │ ├── guardian_scrape.py │ ├── make_public_data.ipynb │ ├── make_public_data_datasets.ipynb │ └── util.py ├── experiments ├── curricular.ipynb └── model_analysis.ipynb ├── requirements.txt └── seq2seq ├── args_cryptic.py ├── common_seq ├── __init__.py ├── collate_fns.py ├── types.py ├── util.py ├── util_checkpoint.py ├── util_dataloader.py ├── util_dataloader_batch.py ├── util_metrics.py └── util_multiloader.py ├── model_runner.py ├── multitask_config.py ├── train_abc.py ├── train_clues.py └── train_descramble.py /.gitignore: -------------------------------------------------------------------------------- 1 | # data folders but not files 2 | data/* 3 | !data/guardian_2020_10_08.json.zip 4 | nocommit/ 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Shield: [![CC BY 4.0][cc-by-shield]][cc-by] 2 | 3 | This work is licensed under a 4 | [Creative Commons Attribution 4.0 International License][cc-by]. 5 | 6 | [![CC BY 4.0][cc-by-image]][cc-by] 7 | 8 | [cc-by]: http://creativecommons.org/licenses/by/4.0/ 9 | [cc-by-image]: https://i.creativecommons.org/l/by/4.0/88x31.png 10 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg 11 | 12 | Please cite 13 | https://proceedings.neurips.cc/paper/2021/hash/5f1d3986fae10ed2994d14ecd89892d7-Abstract.html 14 | https://arxiv.org/abs/2104.08620 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [comment]: <> (adapted from https://github.com/paperswithcode/releasing-research-code) 2 | 3 | # Decrypting Cryptic Crosswords: Semantically Complex Wordplay Puzzles as a Target for NLP 4 | 5 | This repository is the official implementation of 6 | [Decrypting Cryptic Crosswords: Semantically Complex Wordplay 7 | Puzzles as a Target for NLP](https://arxiv.org/abs/2104.08620). 8 | Please cite arxiv or [Neurips 2021 version](https://proceedings.neurips.cc/paper/2021/hash/5f1d3986fae10ed2994d14ecd89892d7-Abstract.html) 9 | 10 | [comment]: <> (>📋 todo Optional: include a graphic explaining your approach/main result, bibtex entry, link to demos, blog posts and tutorials) 11 | The dataset is also available at https://doi.org/10.5061/dryad.n02v6wwzp 12 | 13 | ## Requirements 14 | 15 | This will enable you to download and replicate the datasplits, but it has not been updated 16 | to include all requirements to run the (baselines and experiments notebooks). 
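You may want to install into a fresh virtual environment first. A minimal sketch, assuming Python 3 with the built-in `venv` module (the environment name `env` is just an example):

```setup
python3 -m venv env
source env/bin/activate
```

Then install the pinned dependencies: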
17 | ```setup
18 | pip install -r requirements.txt
19 | ```
20 | 
21 | ## Preparing data
22 | ```setup
23 | git clone  # if using code supplement, just unzip
24 | cd decrypt
25 | pushd ./data && unzip "*.json.zip" && popd
26 | ```
27 | 
28 | ### Download data (can safely be ignored)
29 | If you want to download the data yourself from the web (you probably don't want to):
30 | ```setup
31 | git clone  # if using code supplement, just unzip
32 | cd decrypt
33 | mkdir -p './data/puzzles'
34 | python decrypt/scrape_parse/guardian_scrape.py --save_directory="./data/puzzles"
35 | ```
36 | Then, when you run `load_guardian_splits`, call it as
37 | `load_guardian_splits("./data/puzzles", load_from_files=True, use_premade_json=False)`
38 | 
39 | 
40 | ## Reproducing our splits
41 | ```python
42 | from decrypt.scrape_parse import (
43 |     load_guardian_splits,               # naive random split
44 |     load_guardian_splits_disjoint,      # answer-disjoint split
45 |     load_guardian_splits_disjoint_hash  # word-initial disjoint split
46 | )
47 | from decrypt.scrape_parse.guardian_load import SplitReturn
48 | """
49 | Each of these methods returns a tuple of `SplitReturn`:
50 | - soln to clue map (string to list of clues mapping to that soln): Dict[str, List[BaseClue]]
51 |   this enables seeing all clues associated with a given answer word
52 | - list of all clues (List[BaseClue])
53 | - tuple of three lists (the train, val, test splits), each a List[BaseClue]
54 | 
55 | Note that
56 | load_guardian_splits() will verify that:
57 | - total glob length matches the one in the paper (i.e. the number of puzzles downloaded matches)
58 | - total clue set length matches the one in the paper (i.e. filtering is the same)
59 | - one of the clues in our train set matches the expected clue (i.e. a single clue
60 |   spot check for randomness)
61 | If you get an assertion error or an exception during load, please file an
62 | issue, since the splits should be identical.
63 | Alternatively, if you don't care, you can pass `verify=False` to
64 | `load_guardian_splits`.
65 | """
66 | 
67 | soln_to_clue_map, all_clues_list, (train, val, test) = load_guardian_splits()
68 | ```
69 | 
70 | ## Replicating our work
71 | We make code available to replicate the entire paper.
72 | 
73 | Note that the directory structure is specified in `decrypt/config.py`. You can change it if you would like.
74 | Most references use this file, but run commands (i.e. `python ...`) assume that the directories are unchanged
75 | from the original `config.py`.
76 | 
77 | ### Datasets and task (Section 3)
78 | - The splits are replicated as above using the load methods
79 | - The task is replicated in the following sections
80 | - We provide code to replicate the metric analysis. See the implementation in the jupyter notebooks below
81 | 
82 | To run the notebooks, start your jupyter server from the top-level `decrypt` directory.
83 | (The notebooks were run using PyCharm opened from the top-level `decrypt` directory.)
84 | If you experience import errors, it is likely because you are not running from the top level.
85 | 
86 | ### Baselines (Section 4)
87 | Notebooks to replicate the four baselines are in the `baselines` directory.
88 | Note that the patch `0001-Set-up-for-running-deits-solver-with-timeout.patch` will need to be applied for the Deits solver to work.
89 | 
90 | 
91 | ### Curriculum Learning (Section 5)
92 | See `experiments/curricular.ipynb`
93 | 
94 | ### Model Analysis
95 | See `experiments/model_analysis.ipynb`
96 | 
97 | 
98 | ### Misc
99 | Note that details of training and evaluating the models are available in the relevant jupyter
100 | notebooks. 
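For example, to open one of the baseline notebooks (a sketch, assuming `jupyter` is installed in your environment; substitute any notebook under `baselines/` or `experiments/`):

```setup
# run from the top-level decrypt directory so that `from decrypt import ...` resolves
jupyter notebook baselines/baseline_t5.ipynb
```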
101 | 102 | [comment]: <> ([comment]: <> TODO (## Pre-trained Models)) 103 | 104 | [comment]: <> (You can download pretrained models here:) 105 | 106 | [comment]: <> (- [My awesome model](https://drive.google.com/mymodel.pth) trained on ImageNet using parameters x,y,z. ) 107 | 108 | [comment]: <> (>📋 Give a link to where/how the pretrained models can be downloaded and how they were trained (if applicable). Alternatively you can have an additional column in your results table with a link to the models.) 109 | 110 | [comment]: <> (## Results) 111 | 112 | [comment]: <> (Our model achieves the following performance on :) 113 | 114 | [comment]: <> (### [Image Classification on ImageNet](https://paperswithcode.com/sota/image-classification-on-imagenet)) 115 | 116 | [comment]: <> (| Model name | Top 1 Accuracy | Top 5 Accuracy |) 117 | 118 | [comment]: <> (| ------------------ |---------------- | -------------- |) 119 | 120 | [comment]: <> (| My awesome model | 85% | 95% |) 121 | 122 | [comment]: <> (>📋 Include a table of results from your paper, and link back to the leaderboard for clarity and context. If your main result is a figure, include that figure and link to the command or notebook to reproduce it. ) 123 | 124 | 125 | [comment]: <> (## Contributing) 126 | 127 | [comment]: <> (>📋 Pick a licence and describe how to contribute to your code repository. ) 128 | 129 | 130 | -------------------------------------------------------------------------------- /baselines/0001-Set-up-for-running-deits-solver-with-timeout.patch: -------------------------------------------------------------------------------- 1 | From ab1dd26b2b854c5bb8ed4c99f9b7ccea423a4462 Mon Sep 17 00:00:00 2001 2 | Date: Thu, 27 May 2021 18:04:51 -0700 3 | 4 | --- 5 | .gitignore | 6 ++ 6 | pycryptics/data_generators/generate_synonyms.py | 2 +- 7 | pycryptics/solve_clue.py | 23 +++--- 8 | timeout.py | 19 +++++ 9 | validate_cryptics.py | 99 +++++++++++++++++++++++++ 10 | 5 files changed, 139 insertions(+), 10 deletions(-) 11 | create mode 100644 timeout.py 12 | create mode 100644 validate_cryptics.py 13 | 14 | diff --git a/.gitignore b/.gitignore 15 | index a9aaf4c..d9a9b6c 100644 16 | --- a/.gitignore 17 | +++ b/.gitignore 18 | @@ -6,3 +6,9 @@ data 19 | app_build 20 | .ipynb_checkpoints 21 | env/* 22 | + 23 | +*.out 24 | +clues/* 25 | +outputs*/ 26 | +raw_data/* 27 | +nltk* 28 | diff --git a/pycryptics/data_generators/generate_synonyms.py b/pycryptics/data_generators/generate_synonyms.py 29 | index 0207e16..3cb780e 100644 30 | --- a/pycryptics/data_generators/generate_synonyms.py 31 | +++ b/pycryptics/data_generators/generate_synonyms.py 32 | @@ -66,7 +66,7 @@ def main(): 33 | i = 0 34 | for word in WORDS: 35 | if i % 1000 == 0: 36 | - print i, "/", len(WORDS) 37 | + print(i, "/", len(WORDS)) 38 | i += 1 39 | word = word.lower() 40 | syns = map(cleanup, list(synonyms(word))) 41 | diff --git a/pycryptics/solve_clue.py b/pycryptics/solve_clue.py 42 | index 1dc76d9..c86a57c 100644 43 | --- a/pycryptics/solve_clue.py 44 | +++ b/pycryptics/solve_clue.py 45 | @@ -7,6 +7,7 @@ from pycryptics.grammar.clue_tree import ClueUnsolvableError 46 | from collections import namedtuple 47 | import re 48 | 49 | +from timeout import TimeoutError 50 | 51 | Constraints = namedtuple('Constraints', 'phrases lengths pattern known_answer') 52 | 53 | @@ -102,13 +103,16 @@ class CrypticClueSolver(object): 54 | 55 | self.answers_with_clues = [] 56 | 57 | - for p in all_phrasings: 58 | - constraints = constraints._replace(phrases=p) 59 | - # constraints = 
Constraints(p, lengths, pattern, answer) 60 | - if not self.quiet: 61 | - print p 62 | - for ann_ans in self.solve_constraints(constraints): 63 | - self.answers_with_clues.append(ann_ans) 64 | + try: 65 | + for p in all_phrasings: 66 | + constraints = constraints._replace(phrases=p) 67 | + # constraints = Constraints(p, lengths, pattern, answer) 68 | + if not self.quiet: 69 | + print p 70 | + for ann_ans in self.solve_constraints(constraints): 71 | + self.answers_with_clues.append(ann_ans) 72 | + except TimeoutError: 73 | + pass 74 | if len(self.answers_with_clues) == 0 and constraints.pattern.replace('.', '') != "": 75 | self.answers_with_clues = [PatternAnswer(x, all_phrasings[0]) for x in SYNONYMS if matches_pattern(x, constraints.pattern, constraints.lengths)] 76 | self.answers_with_clues.sort(reverse=True) 77 | @@ -119,7 +123,6 @@ class CrypticClueSolver(object): 78 | possible_clues = generate_clues(constraints) 79 | 80 | for i, clue in enumerate(possible_clues): 81 | - # print "solving:", clue 82 | try: 83 | answers = clue.answers 84 | except ClueUnsolvableError: 85 | @@ -137,9 +140,11 @@ def matches_pattern(word, pattern, lengths): 86 | return (tuple(len(x) for x in word.split('_')) == lengths) and all(c == pattern[i] or pattern[i] == '.' for i, c in enumerate(word.replace('_', ''))) 87 | 88 | 89 | -def split_clue_text(clue_text): 90 | +def split_clue_text(clue_text, assert_has_answer=False): 91 | clue_text = clue_text.encode('ascii', 'ignore') 92 | if '|' not in clue_text: 93 | + if assert_has_answer: 94 | + raise ValueError('missing an answer for ', clue_text) 95 | clue_text += ' |' 96 | clue_text = clue_text.lower() 97 | clue, paren, rest = clue_text.rpartition('(') 98 | diff --git a/timeout.py b/timeout.py 99 | new file mode 100644 100 | index 0000000..ad44652 101 | --- /dev/null 102 | +++ b/timeout.py 103 | @@ -0,0 +1,19 @@ 104 | +import signal 105 | + 106 | +class TimeoutError(Exception): 107 | + pass 108 | + 109 | +def timeout(func, args, timeout_duration, default=None): 110 | + def handler(signum, frame): 111 | + raise TimeoutError() 112 | + 113 | + # set the timeout handler 114 | + signal.signal(signal.SIGALRM, handler) 115 | + signal.alarm(timeout_duration) 116 | + try: 117 | + result = func(*args) 118 | + except TimeoutError: 119 | + result = default 120 | + finally: 121 | + signal.alarm(0) 122 | + return result 123 | diff --git a/validate_cryptics.py b/validate_cryptics.py 124 | new file mode 100644 125 | index 0000000..8978f16 126 | --- /dev/null 127 | +++ b/validate_cryptics.py 128 | @@ -0,0 +1,99 @@ 129 | +from pycryptics.solve_clue import CrypticClueSolver, split_clue_text 130 | +from collections import Counter 131 | +from tqdm import tqdm 132 | +from timeout import timeout 133 | + 134 | +#k_file = 'clues/guardian_disj2_test.txt' 135 | + 136 | +def normalize_output(s): 137 | + return s.replace("_"," ").lower() # handle treatment of underscores 138 | + 139 | +def dump_json(idx, obj): 140 | + print('dumping json at ', idx) 141 | + #name = "./outputs/" + k_name + "_" + str(k_start) + "_" + str(idx) + ".json" 142 | + name = "./" + out_dir + "/" + k_name + "_" + str(k_start) + "_" + str(idx) + ".json" 143 | + with open(name, 'w') as f: 144 | + json.dump(obj, f) 145 | + 146 | +def get_answers(solver, clue_text, timeout_len): 147 | + phrases, lengths, pattern, known_answer = split_clue_text(clue_text, assert_has_answer=True) 148 | + 149 | + tgt = known_answer.lower().strip() 150 | + solver.setup(clue_text) 151 | + answers = timeout(solver.run, args=[], 
timeout_duration=timeout_len) 152 | + 153 | + if answers is None: # timed out 154 | + return -2, [] 155 | + 156 | + return 0, answers, tgt 157 | + 158 | +def validate(timeout_len=20, start_idx=0, end_idx=None): 159 | + with CrypticClueSolver() as solver: 160 | + solver.quiet = True 161 | + 162 | + ct_timed_out = 0 163 | + ct_error = 0 164 | + ct_top = 0 165 | + ct_any_10 = 0 166 | + num_answers_ctr = Counter() 167 | + 168 | + output_set = [] # input, tgt, greedy (empty), sampled, timeout, error 169 | + 170 | + with open(k_file, 'r') as f: 171 | + for clue_idx, clue_text in enumerate(tqdm(f.readlines())): 172 | + timeout = False 173 | + error = False 174 | + 175 | + if clue_idx < start_idx: 176 | + continue 177 | + if end_idx is not None and clue_idx == end_idx: 178 | + break 179 | + try: 180 | + ret_code, answers, tgt = get_answers(solver, clue_text, timeout_len) 181 | + 182 | + # errors 183 | + if ret_code != 0: 184 | + answers = [] 185 | + if ret_code == -2: 186 | + ct_timed_out += 1 187 | + timeout = True 188 | + print('timed out') 189 | + 190 | + 191 | + answers_text = [normalize_output(a.answer) for a in answers] 192 | + num_answers_ctr[len(answers_text)] += 1 193 | + 194 | + for idx, a in enumerate(answers_text): 195 | + if len(a) != len(tgt): 196 | + print('invalid length', clue_text) 197 | + if a == tgt: 198 | + if idx == 0: 199 | + ct_top += 1 200 | + if idx <= 9: 201 | + ct_any_10 += 1 202 | + break 203 | + output_set.append((clue_idx, clue_text.strip(), tgt, "", answers_text, timeout, error)) 204 | + 205 | + except Exception: 206 | + print('got error') 207 | + ct_error += 1 208 | + error = True 209 | + output_set.append((clue_idx, "", "", "", [], timeout, error)) 210 | + 211 | + print(ct_timed_out, ct_error, ct_top, ct_any_10, num_answers_ctr) 212 | + return output_set 213 | + 214 | +import json 215 | +import sys 216 | +if __name__ == "__main__": 217 | + k_name = sys.argv[1] 218 | + time_len = int(sys.argv[2]) 219 | + k_start = int(sys.argv[3]) 220 | + k_end = int(sys.argv[4]) 221 | + k_file = sys.argv[5] 222 | + out_dir = sys.argv[6] 223 | + 224 | + print(k_file) 225 | + print(k_name, time_len, k_start, k_end) 226 | + res = validate(time_len, k_start, k_end) 227 | + dump_json(-1, res) 228 | -- 229 | 2.7.4 230 | 231 | -------------------------------------------------------------------------------- /baselines/baseline_deits.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# how to run deits solver\n", 8 | "1. clone the deits repository (https://github.com/rdeits/cryptics) into './deits/; Note that the julia solver is better but for our research we used python version.\n", 9 | "2. checkout commit 402579 (in deits repo)\n", 10 | "3. apply patch 0001-Set-up... (in this directory)\n", 11 | "4. Use this notebook to set up the deits input clue file\n", 12 | "5. run validate_cryptics.py in the deits directory.\n", 13 | " - this file will be created by patch application.\n", 14 | " (see bottom of validate_cryptics.py for the command line arguments that should be included)\n", 15 | " - you will need to specify an output file. Use the abs path of '../deits/clues/'\n", 16 | " - you will need to specify input clue file. that should be the one generated in this nb ('../deits/clues/*')\n", 17 | "6. 
Use this notebook to run eval (this file)\n" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": { 24 | "pycharm": { 25 | "name": "#%%\n" 26 | } 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "%load_ext autoreload\n", 31 | "%autoreload 2" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "pycharm": { 39 | "name": "#%%\n" 40 | } 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "from decrypt import config\n", 45 | "k_json_folder = config.DataDirs.Guardian.json_folder\n", 46 | "k_deits_clue_folder = config.DataDirs.Deits.k_deits_clues\n", 47 | "k_deits_output_folder = config.DataDirs.Deits.k_deits_outputs" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "pycharm": { 55 | "name": "#%%\n" 56 | } 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "from decrypt.scrape_parse import (\n", 61 | " load_guardian_splits,\n", 62 | " load_guardian_splits_disjoint_hash\n", 63 | ")\n", 64 | "from decrypt.common.validation_tools import (\n", 65 | " load_deits,\n", 66 | " all_aggregate,\n", 67 | ")\n", 68 | "from decrypt.common.puzzle_clue import GuardianClue" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": { 74 | "pycharm": { 75 | "name": "#%% md\n" 76 | } 77 | }, 78 | "source": [ 79 | "Setting up the datasets" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": { 86 | "pycharm": { 87 | "name": "#%%\n" 88 | } 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "## produce deits\n", 93 | "\n", 94 | "# produce test set for cdeits\n", 95 | "def make_deits_format(gc: GuardianClue):\n", 96 | " len_str = \"(\" + \",\".join(map(str, gc.lengths)) + \")\"\n", 97 | " clue_str = gc.clue + \" \" + len_str\n", 98 | " final = clue_str + \" | \" + gc.soln_with_spaces\n", 99 | " return final\n", 100 | "\n", 101 | "def make_deits(fname, clueset):\n", 102 | " with open(fname, \"w\") as f:\n", 103 | " for c in clueset:\n", 104 | " f.write(make_deits_format(c) + \"\\n\")\n", 105 | "\n", 106 | "\n", 107 | "def prep_deits(val_or_test: str,\n", 108 | " naive_or_disj: str):\n", 109 | " assert val_or_test in [\"val\", \"test\"]\n", 110 | " assert naive_or_disj in [\"naive\", \"disj\"]\n", 111 | " if naive_or_disj == \"naive\":\n", 112 | " load_fn = load_guardian_splits\n", 113 | " else:\n", 114 | " load_fn = load_guardian_splits_disjoint_hash\n", 115 | "\n", 116 | " _, _, (train_local, val, test) = load_fn(k_json_folder)\n", 117 | " if val_or_test == \"val\":\n", 118 | " val_local = val\n", 119 | " else:\n", 120 | " val_local = test\n", 121 | "\n", 122 | " append_name = f'guardian_{naive_or_disj}_{val_or_test}'\n", 123 | " fname = f'{k_deits_clue_folder}{append_name}.txt'\n", 124 | " output_folder = f'{k_deits_output_folder}{append_name}/' # outputs will be like output/{append_name}/*-1.json\n", 125 | " return val_local, fname, output_folder\n", 126 | "\n", 127 | "def make_deits_clues(val_or_test,\n", 128 | " naive_or_disj):\n", 129 | " val_local, fname, _ = prep_deits(val_or_test, naive_or_disj)\n", 130 | " make_deits(fname, val_local)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": { 136 | "pycharm": { 137 | "name": "#%% md\n" 138 | } 139 | }, 140 | "source": [ 141 | "This will produce a clue file in `k_deits_clue_folder`\n", 142 | "For example, for disjoint val set:\n", 143 | "make_deits_clues('val', 'disj') # produces a clue file in k_deits_clue_folder\n", 144 | "\n", 145 | "To fully 
replicate results, also run\n", 146 | "- `make_deits_clues('test', 'disj')\n", 147 | "- `make_deits_clues('val', 'naive')\n", 148 | "- `make_deits_clues('test', 'naive')" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": { 154 | "pycharm": { 155 | "name": "#%% md\n" 156 | } 157 | }, 158 | "source": [ 159 | "Now we run `deits/validate_cryptics.py` to generate outputs from the model.\n", 160 | "For example,\n", 161 | "`python validate_cryptics.py `\n", 162 | "\n", 163 | "To run all clues, e.g.,\n", 164 | "- determine the last index of the clues in the set on which you are evaluating. If it is 28442 then\n", 165 | "`python validate_cryptics.py out.json 120 0 28443 k_deits_clue_folder/guardian_naive_val.txt outputs/naive_val/`\n", 166 | "\n", 167 | "To actually replicate this, it is best to split up the runs into sets of, e.g., 100 clues and parallelize.\n", 168 | "\n", 169 | "It is recommended to use the Julia rather than the Deits solver since this one is incredibly slow.\n", 170 | "\n", 171 | "Finally, we run the evaluation code:\n", 172 | "Below we provide code to run the two models to produce row 1 of Main Results\n", 173 | "Table 2 in the paper.\n", 174 | "\n", 175 | "Note that for the Main Results Table 2, the metrics we include in the table correspond to\n", 176 | "- `agg_top_match`\n", 177 | "- `agg_top_10_after_filter`\n", 178 | "\n", 179 | "More details of these metric calculations can be found in `decrypt.common.validation_tools`" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "metadata": { 186 | "pycharm": { 187 | "name": "#%%\n" 188 | } 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "def eval_deits_clues(val_or_test,\n", 193 | " naive_or_disj):\n", 194 | " val_local, fname, output_folder = prep_deits(val_or_test, naive_or_disj)\n", 195 | " deits_outputs_glob = output_folder + \"*-1.json\"\n", 196 | " model_outputs = load_deits(val_local, deits_outputs_glob)\n", 197 | " all_aggregate(model_outputs)\n" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": { 204 | "pycharm": { 205 | "name": "#%%\n" 206 | } 207 | }, 208 | "outputs": [], 209 | "source": [ 210 | "# for example\n", 211 | "eval_deits_clues('val', 'disj') # will evaluate the outputs\n", 212 | "\n" 213 | ] 214 | } 215 | ], 216 | "metadata": { 217 | "kernelspec": { 218 | "display_name": "Python 3", 219 | "language": "python", 220 | "name": "python3" 221 | }, 222 | "language_info": { 223 | "codemirror_mode": { 224 | "name": "ipython", 225 | "version": 2 226 | }, 227 | "file_extension": ".py", 228 | "mimetype": "text/x-python", 229 | "name": "python", 230 | "nbconvert_exporter": "python", 231 | "pygments_lexer": "ipython2", 232 | "version": "2.7.6" 233 | } 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 0 237 | } 238 | -------------------------------------------------------------------------------- /baselines/baseline_knn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "pycharm": { 18 | "name": "#%%\n" 19 | } 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "from decrypt import config\n", 24 | "k_json_folder = config.DataDirs.Guardian.json_folder\n", 25 | "\n", 26 | "from 
decrypt.scrape_parse import (\n", 27 | " load_guardian_splits,\n", 28 | " load_guardian_splits_disjoint_hash\n", 29 | ")\n", 30 | "\n", 31 | "from sklearn.neighbors import KNeighborsClassifier\n", 32 | "from decrypt.common.puzzle_clue import GuardianClue\n", 33 | "\n", 34 | "from sklearn.feature_extraction.text import CountVectorizer\n", 35 | "from typing import *\n", 36 | "from tqdm import tqdm\n", 37 | "\n", 38 | "from decrypt.common import validation_tools as vt" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "pycharm": { 46 | "name": "#%%\n" 47 | } 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "# no need to lowercase - countvectorizer does this\n", 52 | "def load_data(clue_list: List[GuardianClue], add_lens=False):\n", 53 | " \"\"\"\n", 54 | " Take clue_list and return the X and Y data\n", 55 | " \"\"\"\n", 56 | " def iter_fcn(clue: GuardianClue):\n", 57 | " if add_lens:\n", 58 | " ret = clue.clue_with_lengths(\"|\")\n", 59 | " else:\n", 60 | " ret = clue.clue\n", 61 | " return ret\n", 62 | "\n", 63 | " X = [iter_fcn(c) for c in clue_list]\n", 64 | " Y = [c.soln_with_spaces.lower() for c in clue_list]\n", 65 | " return X, Y\n", 66 | "\n", 67 | "def knn_eval(train, val, add_lens, knn_neighbors: int, verify=False):\n", 68 | " \"\"\"\n", 69 | " :param add_lens: whether to add lengths to the input clues\n", 70 | " :param knn_neighbors: number of neighbors to use when doing \"beam search\". None => no beam search\n", 71 | " \"\"\"\n", 72 | " # load data\n", 73 | " train_inputs, train_targets = load_data(train, add_lens)\n", 74 | " test_inputs, test_targets = load_data(val, add_lens)\n", 75 | " print(train_inputs[:2])\n", 76 | "\n", 77 | " # set up the bag-of-words vectorizer\n", 78 | " # token patter needed for the length specification\n", 79 | " bow_vectorizer = CountVectorizer(token_pattern='[a-z\\d()|]+',\n", 80 | " ngram_range=(1,1)) # further ngrams degrade performance\n", 81 | " bowVect = bow_vectorizer.fit(train_inputs)\n", 82 | "\n", 83 | " # show that everything was vectorized correctly\n", 84 | " print(len(bowVect.vocabulary_))\n", 85 | " if verify:\n", 86 | " for w in train_inputs[0].replace(\",\",\" \").lower().split(\" \"):\n", 87 | " if w == '': continue\n", 88 | " print(bowVect.vocabulary_[w])\n", 89 | "\n", 90 | " bowTrain = bowVect.transform(train_inputs)\n", 91 | " bowTest = bowVect.transform(test_inputs)\n", 92 | "\n", 93 | " # fit KNN\n", 94 | " # neighbor setting here doesn't matter; can put in call to knn.kneighbors\n", 95 | " knn = KNeighborsClassifier()\n", 96 | " knn.fit(bowTrain, train_targets)\n", 97 | "\n", 98 | " # predict (runs long)\n", 99 | " # get the nearest neighbors (beam search)\n", 100 | " # this returns in sorted order, so commented code not needed\n", 101 | " # nn_dist, nn_idx = knn.kneighbors(bowTest, n_neighbors=knn_neighbors, return_distance=True)\n", 102 | " # nn_dist_and_idx = zip(nn_dist, nn_idx)\n", 103 | " nn = knn.kneighbors(bowTest, n_neighbors=knn_neighbors, return_distance=False)\n", 104 | "\n", 105 | " # get the predictions (\"greedy\")\n", 106 | " pred = knn.predict(bowTest)\n", 107 | " return train_targets, test_targets, nn, pred" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": { 114 | "pycharm": { 115 | "name": "#%%\n" 116 | } 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "def eval_knn(val_set: List[GuardianClue],\n", 121 | " train_targets,\n", 122 | " test_targets,\n", 123 | " nn,\n", 124 | " pred):\n", 125 | " 
model_outputs = []\n", 126 | "\n", 127 | " # don't need to check idx set since we have a 1:1 of val_gc to test_tgt\n", 128 | " # nn_list is already sorted\n", 129 | " for val_gc, test_tgt, nn_list, greedy_pred in tqdm(zip(val_set, test_targets, nn, pred)):\n", 130 | " assert val_gc.soln_with_spaces == test_tgt\n", 131 | " neighbor_solns = [train_targets[n] for n in nn_list]\n", 132 | "\n", 133 | " # nbr set is the list of indices of nearest neighbor\n", 134 | " # we retrieve all the solns for those neighbors (y_train[i])\n", 135 | " mp = vt.ModelPrediction(idx=val_gc.idx,\n", 136 | " input=val_gc.clue_with_lengths(punct=\"|\"),\n", 137 | " target=test_tgt,\n", 138 | " greedy=greedy_pred,\n", 139 | " sampled=neighbor_solns)\n", 140 | "\n", 141 | " mp.model_eval = vt.eval(mp)\n", 142 | " model_outputs.append(mp)\n", 143 | "\n", 144 | " return model_outputs\n", 145 | "\n", 146 | "def aggregate(val, output_tuple):\n", 147 | " model_out = eval_knn(val, *output_tuple)\n", 148 | " vt.all_aggregate(model_out)\n" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": { 155 | "pycharm": { 156 | "name": "#%%\n" 157 | } 158 | }, 159 | "outputs": [], 160 | "source": [ 161 | "# eval on naive val set (with/without lens)\n", 162 | "def run_eval_knn(val_or_test: str, naive_or_disj: str, nn=3000):\n", 163 | " assert val_or_test in [\"val\", \"test\"]\n", 164 | " assert naive_or_disj in [\"naive\", \"disj\"]\n", 165 | " if naive_or_disj == \"naive\":\n", 166 | " load_fn = load_guardian_splits\n", 167 | " else:\n", 168 | " load_fn = load_guardian_splits_disjoint_hash\n", 169 | " _, _, (train_local, val, test) = load_fn(k_json_folder)\n", 170 | " if val_or_test == \"val\":\n", 171 | " val_local = val\n", 172 | " else:\n", 173 | " val_local = test\n", 174 | "\n", 175 | " knn_tuple_random_val_nolens = knn_eval(train_local, val_local, add_lens=False, knn_neighbors=nn)\n", 176 | " aggregate(val_local, knn_tuple_random_val_nolens)\n", 177 | "\n", 178 | " knn_tuple_random_val_lens = knn_eval(train_local, val_local, add_lens=True, knn_neighbors=nn)\n", 179 | " aggregate(val_local, knn_tuple_random_val_lens)\n", 180 | "\n", 181 | " return knn_tuple_random_val_nolens, knn_tuple_random_val_lens" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": { 188 | "pycharm": { 189 | "name": "#%%\n" 190 | } 191 | }, 192 | "outputs": [], 193 | "source": [ 194 | "# run with nn=3000 to replicate research\n", 195 | "knn_tuple_random_val_nolens, knn_tuple_random_val_lens = run_eval_knn(val_or_test=\"val\",\n", 196 | " naive_or_disj=\"naive\",\n", 197 | " nn=3000)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "markdown", 202 | "metadata": { 203 | "pycharm": { 204 | "name": "#%% md\n" 205 | } 206 | }, 207 | "source": [ 208 | "To reproduce the two rows corresponding to KNN in Main Results, also run with\n", 209 | "- val_or_test = \"test\"\n", 210 | "- naive_or_disj=\"disj\"\n", 211 | "\n", 212 | "Note that for the Main Results Table 2, the metrics we include in the table correspond to\n", 213 | "- `agg_top_match`\n", 214 | "- `agg_top_10_after_filter`\n", 215 | "\n", 216 | "More details of these metric calculations can be found in `decrypt.common.validation_tools`\n", 217 | "\n", 218 | "\n" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": { 225 | "pycharm": { 226 | "name": "#%%\n" 227 | } 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "####\n", 232 | "# Supplementary to verify KNN works 
the way we expect\n", 233 | "###\n", 234 | "\n", 235 | "## verify how KNN does tokenization\n", 236 | "train_inputs, train_targets = load_data(train, True)\n", 237 | "print(train_inputs[:2])\n", 238 | "\n", 239 | "# set up the bag-of-words vectorizer\n", 240 | "# token patter needed for the length specification\n", 241 | "# punctuation not included in token pattern (e.g. , or ') will be split and treated as space\n", 242 | "# see below experiment\n", 243 | "bow_vectorizer = CountVectorizer(token_pattern='[a-z\\d()|]+',\n", 244 | " ngram_range=(1,1)) # further ngrams degrade performance\n", 245 | "bowVect = bow_vectorizer.fit(train_inputs)\n", 246 | "\n", 247 | "# show that everything was vectorized correctly\n", 248 | "print(len(bowVect.vocabulary_))\n", 249 | "print()\n", 250 | "# for w in train_inputs[0].replace(\",\",\" \").lower().split(\" \"):\n", 251 | "# need to replace any punct that occurs\n", 252 | "all = []\n", 253 | "for idx, w in enumerate(train_inputs[12].replace(\"'\", \"\").lower().split(\" \")):\n", 254 | " print(w)\n", 255 | " try:\n", 256 | " val = bowVect.vocabulary_[w]\n", 257 | " print(val)\n", 258 | " all.append(val)\n", 259 | " except:\n", 260 | " pass\n", 261 | "print(sorted(all))\n", 262 | "\n", 263 | "print(train_inputs[12])" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": { 270 | "pycharm": { 271 | "name": "#%%\n" 272 | } 273 | }, 274 | "outputs": [], 275 | "source": [ 276 | "matrix = bow_vectorizer.transform([train_inputs[12]])\n", 277 | "print(matrix)\n", 278 | "for i,j in matrix:\n", 279 | " print(bowVect.vocabulary_[])\n", 280 | "\n", 281 | "bow_vectorizer.inverse_transform(matrix)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": { 288 | "pycharm": { 289 | "name": "#%%\n" 290 | } 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "import string\n", 295 | "print(string.punctuation)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": { 302 | "pycharm": { 303 | "name": "#%%\n" 304 | } 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "for i, c in enumerate(train):\n", 309 | " if len(c.lengths) > 1:\n", 310 | " print(c.idx)\n", 311 | " print(i)\n", 312 | " break\n", 313 | "print(train_inputs[12])" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": { 320 | "pycharm": { 321 | "name": "#%%\n" 322 | } 323 | }, 324 | "outputs": [], 325 | "source": [] 326 | } 327 | ], 328 | "metadata": { 329 | "kernelspec": { 330 | "display_name": "Python 3", 331 | "language": "python", 332 | "name": "python3" 333 | }, 334 | "language_info": { 335 | "codemirror_mode": { 336 | "name": "ipython", 337 | "version": 2 338 | }, 339 | "file_extension": ".py", 340 | "mimetype": "text/x-python", 341 | "name": "python", 342 | "nbconvert_exporter": "python", 343 | "pygments_lexer": "ipython2", 344 | "version": "2.7.6" 345 | } 346 | }, 347 | "nbformat": 4, 348 | "nbformat_minor": 0 349 | } 350 | -------------------------------------------------------------------------------- /baselines/baseline_t5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "## For running the T5 vanilla baseline\n", 12 | "Here we provide code to\n", 13 | "1. Setup datafiles for running t5 (i.e. produce json files)\n", 14 | "1. 
Run the seq2seq model to produce outputs, saving in a directory that matches k_t5_outputs below\n", 15 | " - Train model (produces saved checkpoints)\n", 16 | " - Eval top performing model (load from top checkpoint, produces json outputs)\n", 17 | "\n", 18 | "1. run eval code using the json output from model eval" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": { 25 | "pycharm": { 26 | "name": "#%%\n" 27 | } 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "%load_ext autoreload\n", 32 | "%autoreload 2\n", 33 | "\n", 34 | "from decrypt.scrape_parse import (\n", 35 | " load_guardian_splits,\n", 36 | " load_guardian_splits_disjoint,\n", 37 | " load_guardian_splits_disjoint_hash\n", 38 | ")\n", 39 | "\n", 40 | "import os\n", 41 | "from decrypt import config\n", 42 | "from decrypt.common import validation_tools as vt\n", 43 | "from decrypt.common.util_data import clue_list_tuple_to_train_split_json\n", 44 | "import logging\n", 45 | "logging.getLogger(__name__)\n", 46 | "\n", 47 | "\n", 48 | "k_json_folder = config.DataDirs.Guardian.json_folder" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": { 54 | "pycharm": { 55 | "name": "#%% md\n" 56 | } 57 | }, 58 | "source": [ 59 | "## 1 Produce datasets" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "metadata": { 66 | "pycharm": { 67 | "name": "#%%\n" 68 | } 69 | }, 70 | "outputs": [], 71 | "source": [ 72 | "def make_dataset(split_type: str, overwrite=False):\n", 73 | " assert split_type in ['naive_random', 'naive_disjoint', 'word_init_disjoint']\n", 74 | " if split_type == 'naive_random':\n", 75 | " load_fn = load_guardian_splits\n", 76 | " tgt_dir = config.DataDirs.DataExport.guardian_naive_random_split\n", 77 | " elif split_type == 'naive_disjoint':\n", 78 | " load_fn = load_guardian_splits_disjoint\n", 79 | " tgt_dir = config.DataDirs.DataExport.guardian_naive_disjoint_split\n", 80 | " else:\n", 81 | " load_fn = load_guardian_splits_disjoint_hash\n", 82 | " tgt_dir = config.DataDirs.DataExport.guardian_word_init_disjoint_split\n", 83 | "\n", 84 | " _, _, (train, val, test) = load_fn(k_json_folder)\n", 85 | "\n", 86 | " os.makedirs(tgt_dir, exist_ok=True)\n", 87 | " # write the output as json\n", 88 | " try:\n", 89 | " clue_list_tuple_to_train_split_json((train, val, test),\n", 90 | " comment=f'Guardian data. Split: {split_type}',\n", 91 | " export_dir=tgt_dir,\n", 92 | " overwrite=overwrite)\n", 93 | " except FileExistsError:\n", 94 | " logging.warning(f'You have already generated the {split_type} dataset.\\n'\n", 95 | " f'It is located at {tgt_dir}\\n'\n", 96 | " f'To regenerate, pass overwrite=True or delete it\\n')\n", 97 | "\n", 98 | "\n", 99 | "make_dataset('naive_random')\n", 100 | "make_dataset('word_init_disjoint')\n", 101 | "# you can also make_dataset('naive_disjoint')" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": { 107 | "pycharm": { 108 | "name": "#%% md\n" 109 | } 110 | }, 111 | "source": [ 112 | "## 2 Running (training) the model\n", 113 | "1. Setup environment\n", 114 | " 1. You should setup wandb for logging (that's where metrics will show up).\n", 115 | " If you try to run without wandb, then wandb will tell you what you need to do to initialize\n", 116 | "\n", 117 | " 1. 
The relevant libraries used for our runs are\n", 118 | " - transformers==4.4.2\n", 119 | " - wandb==0.10.13 # this can probably be updated\n", 120 | " - torch==1.7.1+cu110\n", 121 | " - torchvision==0.8.2+cu110\n", 122 | " - Choose place for your wandb dir, e.g., `'./wandb' `\n", 123 | "1. Train the model\n", 124 | " 1. Note that the default arguments are given in args_cryptic. See `--default_train` and `--default_val`\n", 125 | " - Note that, when looking at logging messages or wandb, it will appear that epochs start at 11.\n", 126 | " This is done so that we have \"space\" for 10 \"warmup\" epochs for curricular training.\n", 127 | " This space causes all plots in wandb to line up.\n", 128 | " 1. from directory seq2seq, run the commands in the box below.\n", 129 | " This will produce model checkpoints that can then be used for evaluation.\n", 130 | "\n", 131 | "\n", 132 | "\n", 133 | "Baseline naive\n", 134 | "```python\n", 135 | "python train_clues.py --default_train=base --name=baseline_naive --project=baseline --wandb_dir='./wandb' --data_dir='../data/clue_json/guardian/naive_random'\n", 136 | "```\n", 137 | "Baseline (naive split), without lengths\n", 138 | "```python\n", 139 | "python train_clues.py --default_train=base --name=baseline_naive_nolens --project=baseline --wandb_dir='./wandb' --data_dir='../data/clue_json/guardian/word_initial_disjoint' --special=no_lens\n", 140 | "```\n", 141 | "Baseline disjoint (word initial disjoint)\n", 142 | "```python\n", 143 | "python train_clues.py --default_train=base --name=baseline_disj --project=baseline --wandb_dir='./wandb' --data_dir='../data/clue_json/guardian/word_initial_disjoint'\n", 144 | "```\n", 145 | "Baseline disjoint (word initial disjoint), without lengths\n", 146 | "```python\n", 147 | "python train_clues.py --default_train=base --name=baseline_disj --project=baseline --wandb_dir='./wandb' --data_dir='../data/clue_json/guardian/word_initial_disjoint' --special=no_lens\n", 148 | "```" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": { 154 | "pycharm": { 155 | "name": "#%% md\n" 156 | } 157 | }, 158 | "source": [ 159 | "## 3 Evaluating the model\n", 160 | "During training we generate only 5 beams for efficiency. For eval we generate 100.\n", 161 | "1. Select the best model based on num_match_top_sampled. There should be\n", 162 | "a logging statement at the end of the run that prints the location\n", 163 | "of the best model checkpoint.\n", 164 | "You can also find it by matching the peak in the wandb.ai metrics graph\n", 165 | "to the appropriate model save.\n", 166 | "2. 
Run eval using that model (see commands below), which will\n", 167 | "produce a file in a (new, different) wandb directory that looks like `epoch_11.pth.tar.preds.json` (i.e only a single epoch)\n", 168 | "\n", 169 | "For example,\n", 170 | "\n", 171 | "Baseline naive, if epoch 20 is best (you'll need to set the run_name)\n", 172 | "This produces generations for the validation set\n", 173 | "```python\n", 174 | "python train_clues.py --default_val=base --name=baseline_naive_val --project=baseline --data_dir='../data/clue_json/guardian/naive_random' --ckpt_path='./wandb/run_name/files/epoch_20.pth.tar\n", 175 | "```\n", 176 | "\n", 177 | "To produce generations for the test set,\n", 178 | "```python\n", 179 | "python train_clues.py --default_val=base --name=baseline_naive_val --project=baseline --data_dir='../data/clue_json/guardian/naive_random' --ckpt_path='./wandb/run_name/files/epoch_10.pth.tar --test\n", 180 | "```\n", 181 | "\n", 182 | "This should also be run for the no-lengths versions if you want to replicate those results.\n" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": { 188 | "pycharm": { 189 | "name": "#%% md\n" 190 | } 191 | }, 192 | "source": [ 193 | "Now we produce metrics by evaluating the json that was produced\n", 194 | "1. Change the k_t5_outputs_dir value to the location where you have saved the json files. \n", 195 | " - Recommend copying all of the preds.json files into a common directory and working from that.\n", 196 | " - Alternatively you could modify the code below and pass in a full path name to each of the json outputs (using the wandb directory path)\n", 197 | "1. For each t5 model eval (above) that you ran (each of which produced some `..preds.json` file, run `load_and_run()` to get metrics for those outputs\n", 198 | "1. The resulting outputs are the values we report in the tables. See `decrypt/common/validation_tools.ModelEval` for more details about the numbers that are produced. 
Percentages are prefixed by agg_\n", 199 | "\n", 200 | "Note that for the Main Results Table 2, the metrics we include in the table correspond to\n", 201 | "- `agg_top_match`\n", 202 | "- `agg_top_10_after_filter`\n", 203 | "\n", 204 | "More details of these metric calculations can be found in `decrypt.common.validation_tools`" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": { 211 | "pycharm": { 212 | "name": "#%%\n" 213 | } 214 | }, 215 | "outputs": [], 216 | "source": [ 217 | "# for example, if your output files are in\n", 218 | "# 'decrypt/t5_outputs/'\n", 219 | "# and you will run the below, e.g., if you have named the files\n", 220 | "# baseline_naive_e12_test.json\n", 221 | "# (.json will be appended for you by the load_and_run_t5 function)\n", 222 | "# a better name for load_and_run is load_and_eval\n", 223 | "\n", 224 | "# for example\n", 225 | "### primary - test\n", 226 | "vt.load_and_run_t5('decrypt/t5_outputs/baseline_naive_e12_test')\n", 227 | "vt.load_and_run_t5('decrypt/t5_outputs/baseline_naive_nolens_e15_test') # test set\n", 228 | "\n", 229 | "## primary val\n", 230 | "vt.load_and_run_t5('decrypt/t5_outputs/baseline_naive_e12_val')\n", 231 | "vt.load_and_run_t5('decrypt/t5_outputs/baseline_naive_nolens_e15_val')\n", 232 | "\n" 233 | ] 234 | } 235 | ], 236 | "metadata": { 237 | "kernelspec": { 238 | "display_name": "Python 3", 239 | "language": "python", 240 | "name": "python3" 241 | }, 242 | "language_info": { 243 | "codemirror_mode": { 244 | "name": "ipython", 245 | "version": 2 246 | }, 247 | "file_extension": ".py", 248 | "mimetype": "text/x-python", 249 | "name": "python", 250 | "nbconvert_exporter": "python", 251 | "pygments_lexer": "ipython2", 252 | "version": "2.7.6" 253 | } 254 | }, 255 | "nbformat": 4, 256 | "nbformat_minor": 0 257 | } 258 | -------------------------------------------------------------------------------- /baselines/baseline_wn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "pycharm": { 8 | "name": "#%%\n" 9 | } 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "\"\"\"Heuristic wordnet baseline\"\"\"\n", 14 | "###" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "pycharm": { 22 | "name": "#%%\n" 23 | } 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "%load_ext autoreload\n", 28 | "%autoreload 2" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": { 35 | "pycharm": { 36 | "name": "#%%\n" 37 | } 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "k_json_folder = '../puzzles/'" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "pycharm": { 49 | "name": "#%%\n" 50 | } 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "from decrypt.scrape_parse import (\n", 55 | " load_guardian_splits,\n", 56 | " load_guardian_splits_disjoint_hash\n", 57 | ")\n", 58 | "\n", 59 | "import random\n", 60 | "from typing import *\n", 61 | "\n", 62 | "import jellyfish\n", 63 | "\n", 64 | "from multiset import Multiset\n", 65 | "from nltk.corpus import wordnet as wn\n", 66 | "from tqdm import tqdm\n", 67 | "\n", 68 | "from decrypt.common.puzzle_clue import GuardianClue\n", 69 | "from decrypt.common.util_wordnet import all_inflect\n", 70 | "from decrypt.common import validation_tools as vt" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 
75 | "execution_count": null, 76 | "metadata": { 77 | "pycharm": { 78 | "name": "#%%\n" 79 | } 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "# Wordnet functions to produce reverse dictionary sets\n", 84 | "\n", 85 | "def normalize(lemma):\n", 86 | " \"\"\"Wordnet returns words with underscores and hyphens. We replace them with spaces. This possibly does not work well with lemminflect.\"\"\"\n", 87 | " return lemma.replace(\"_\",\" \").replace(\"-\",\" \")\n", 88 | "\n", 89 | "def get_syns(w: str) -> Set[str]:\n", 90 | " \"\"\"\n", 91 | " Get all synonyms of w\n", 92 | " \"\"\"\n", 93 | " ret = set()\n", 94 | " for ss in wn.synsets(w):\n", 95 | " for l in ss.lemma_names():\n", 96 | " ret.add(normalize(l))\n", 97 | " return ret\n", 98 | "\n", 99 | "def get_syns_hypo1(w: str) -> Set[str]:\n", 100 | " \"\"\"\n", 101 | " Get all synonyms and hyponyms to depth 1\n", 102 | " \"\"\"\n", 103 | " ret = set()\n", 104 | " for ss in wn.synsets(w):\n", 105 | " for l in ss.lemma_names():\n", 106 | " ret.add(normalize(l))\n", 107 | " for rel_ss in ss.hyponyms():\n", 108 | " for l in rel_ss.lemma_names():\n", 109 | " ret.add(normalize(l))\n", 110 | " return ret\n", 111 | "\n", 112 | "def get_syns_hypo_all(w: str, include_hyper=False, depth=3) -> Set[str]:\n", 113 | " \"\"\"\n", 114 | " Get all synonyms; hyponyms to depth, depth; and hypernyms to depth, depth,\n", 115 | " if include_hyper is True\n", 116 | "\n", 117 | " :param w: word to lookup\n", 118 | " :param include_hyper: whether to do hypernym lookup\n", 119 | " :param depth: how far to go in hyponym / hypernym traversal\n", 120 | " \"\"\"\n", 121 | " ret = set()\n", 122 | " for ss in wn.synsets(w):\n", 123 | " for l in ss.lemma_names():\n", 124 | " ret.add(normalize(l))\n", 125 | " if include_hyper:\n", 126 | " for rel_ss in ss.closure(lambda s: s.hypernyms(), depth=depth):\n", 127 | " for l in rel_ss.lemma_names():\n", 128 | " ret.add(normalize(l))\n", 129 | " for rel_ss in ss.closure(lambda s: s.hyponyms(), depth=depth):\n", 130 | " for l in rel_ss.lemma_names():\n", 131 | " ret.add(normalize(l))\n", 132 | " return ret\n", 133 | "\n", 134 | "def get_first_and_last_word(c: GuardianClue):\n", 135 | " clue_words = c.clue.split(\" \")\n", 136 | " return clue_words[0], clue_words[-1]\n" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": { 143 | "pycharm": { 144 | "name": "#%%\n" 145 | } 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "def pct_sim(str1, str2):\n", 150 | " max_len = max(len(str1), len(str2))\n", 151 | " lev = jellyfish.levenshtein_distance(str1, str2)\n", 152 | " return 1.0 - lev/max_len\n", 153 | "\n", 154 | "def eval_wn(val_set: List[GuardianClue],\n", 155 | " fcn: Callable,\n", 156 | " do_fuzzy: bool,\n", 157 | " do_rank: bool = False,\n", 158 | " **fcn_kwargs):\n", 159 | " \"\"\"\n", 160 | " :param val_set:\n", 161 | " :param fcn:\n", 162 | " :param do_fuzzy:\n", 163 | " :param fcn_kwargs:\n", 164 | " :return:\n", 165 | " \"\"\"\n", 166 | " rng = random.Random()\n", 167 | " rng.seed(42)\n", 168 | "\n", 169 | " model_outputs = []\n", 170 | " for val_gc in tqdm(val_set):\n", 171 | " all_possible = set()\n", 172 | "\n", 173 | " # add the direct synonyms\n", 174 | " for w in get_first_and_last_word(val_gc):\n", 175 | " all_possible.update(list(fcn(w.lower(), **fcn_kwargs)))\n", 176 | "\n", 177 | " # potentially add lemmas\n", 178 | " if do_fuzzy:\n", 179 | " orig = all_possible.copy()\n", 180 | " for w in orig:\n", 181 | " all_possible.update(all_inflect(w, None))\n", 182 | 
"\n", 183 | " _, filtered = vt.filter_to_len(val_gc.soln_with_spaces, all_possible)\n", 184 | " filtered_final = [x[0] for x in filtered] # go back to with spaces\n", 185 | "\n", 186 | " # jellyfish score\n", 187 | " # # if do_rank:\n", 188 | " # # list_with_rank = []\n", 189 | " # # for out in filtered_final:\n", 190 | " # # score = pct_sim(out, val_gc.clue)\n", 191 | " # # list_with_rank.append((out, score))\n", 192 | " # # # sort\n", 193 | " # # list_sorted = sorted(list_with_rank, key=lambda x: x[1], reverse=True)\n", 194 | " # # # take the word not the score\n", 195 | " # filtered_final = [x[0] for x in list_sorted]\n", 196 | "\n", 197 | " # simple character overlap\n", 198 | " if do_rank:\n", 199 | " list_with_rank = []\n", 200 | " mset = Multiset(val_gc.clue)\n", 201 | " for out in filtered_final:\n", 202 | " score = len(mset.intersection(Multiset(out)))\n", 203 | " list_with_rank.append((out, score))\n", 204 | " # sort\n", 205 | " list_sorted = sorted(list_with_rank, key=lambda x: x[1], reverse=True)\n", 206 | " # take the word not the score\n", 207 | " filtered_final = [x[0] for x in list_sorted]\n", 208 | " else:\n", 209 | " rng.shuffle(filtered_final)\n", 210 | "\n", 211 | " mp = vt.ModelPrediction(\n", 212 | " idx=val_gc.idx,\n", 213 | " input=val_gc.clue_with_lengths(),\n", 214 | " target=val_gc.soln_with_spaces,\n", 215 | " greedy=\"\",\n", 216 | " sampled=filtered_final)\n", 217 | "\n", 218 | " mp.model_eval = vt.eval(mp)\n", 219 | " model_outputs.append(mp)\n", 220 | "\n", 221 | " return model_outputs\n", 222 | "\n", 223 | "\n", 224 | "\n" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "metadata": { 230 | "pycharm": { 231 | "name": "#%% md\n" 232 | } 233 | }, 234 | "source": [ 235 | "Below we provide code to run the two models to produce row 1 of Main Results\n", 236 | "Table 2 in the paper.\n", 237 | "\n", 238 | "Other combinations of (unreported) hyperparameters can be tested by changing\n", 239 | "- the fcn passed to eval_wn\n", 240 | "- do_fuzzy\n", 241 | "- do_rank (or changing how ranking is computed -- uncomment the jellyfish code above)\n", 242 | "\n", 243 | "Note that for the Main Results Table 2, the metrics we include in the table correspond to\n", 244 | "- `agg_top_match`\n", 245 | "- `agg_top_10_after_filter`\n", 246 | "\n", 247 | "More details of these metric calculations can be found in `decrypt.common.validation_tools`\n", 248 | "\n" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": { 255 | "pycharm": { 256 | "name": "#%%\n" 257 | } 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "#################\n", 262 | "# this is the primary baseline\n", 263 | "######################\n", 264 | "\n", 265 | "# naive set\n", 266 | "def run_primary_wn_naive():\n", 267 | " _, _, (_, val_orig, test_orig) = load_guardian_splits(k_json_folder)\n", 268 | " out1 = eval_wn(val_orig, fcn=get_syns_hypo1, do_fuzzy=False, do_rank=True) # 1711\n", 269 | " print('val results')\n", 270 | " vt.all_aggregate(out1, label='syns,hypo1; no fuzzy, ranked by char overlap')\n", 271 | "\n", 272 | " print('test results')\n", 273 | " out2 = eval_wn(test_orig, fcn=get_syns_hypo1, do_fuzzy=False, do_rank=True) # 1711\n", 274 | " vt.all_aggregate(out2, label='syns,hypo1; no fuzzy, ranked by char overlap')\n", 275 | "\n", 276 | "run_primary_wn_naive()\n" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "pycharm": { 284 | "name": "#%%\n" 285 | } 286 | }, 287 | "outputs": 
[], 288 | "source": [ 289 | "##\n", 290 | "# run on disjoint set\n", 291 | "##\n", 292 | "\n", 293 | "def run_primary_wn_disj2():\n", 294 | " _, _, (_, val_orig, test_orig) = load_guardian_splits_disjoint_hash(k_json_folder)\n", 295 | " print('val results')\n", 296 | " out1 = eval_wn(val_orig, fcn=get_syns_hypo1, do_fuzzy=False, do_rank=True) # 1711\n", 297 | " vt.all_aggregate(out1, label='syns,hypo1; no fuzzy, ranked by char overlap')\n", 298 | "\n", 299 | " print('test results')\n", 300 | " out2 = eval_wn(test_orig, fcn=get_syns_hypo1, do_fuzzy=False, do_rank=True) # 1711\n", 301 | " vt.all_aggregate(out2, label='syns,hypo1; no fuzzy, ranked by char overlap')\n", 302 | "\n", 303 | "run_primary_wn_disj2()\n", 304 | "\n" 305 | ] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": "Python 3", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 2 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython2", 324 | "version": "2.7.6" 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 0 329 | } 330 | -------------------------------------------------------------------------------- /data/disjoint.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrozner/decrypt/71679f76ce19cd098b0988d0255b03d7e47f0c05/data/disjoint.json.zip -------------------------------------------------------------------------------- /data/disjoint_word_init.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrozner/decrypt/71679f76ce19cd098b0988d0255b03d7e47f0c05/data/disjoint_word_init.json.zip -------------------------------------------------------------------------------- /data/guardian_2020_10_08.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrozner/decrypt/71679f76ce19cd098b0988d0255b03d7e47f0c05/data/guardian_2020_10_08.json.zip -------------------------------------------------------------------------------- /data/naive_random.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrozner/decrypt/71679f76ce19cd098b0988d0255b03d7e47f0c05/data/naive_random.json.zip -------------------------------------------------------------------------------- /data/original/names/README.txt: -------------------------------------------------------------------------------- 1 | Downloaded from 2 | https://www.ons.gov.uk/peoplepopulationandcommunity/birthsdeathsandmarriages/livebirths/bulletins/babynamesenglandandwales/2018 3 | https://www.usna.edu/Users/cs/roche/courses/s15si335/proj1/files.php%3Ff=names.txt&downloadcode=yes 4 | -------------------------------------------------------------------------------- /data/original/names/boys.txt: -------------------------------------------------------------------------------- 1 | Aaron 2 | Adam 3 | Aidan 4 | Aiden 5 | Albert 6 | Albie 7 | Alex 8 | Alexander 9 | Alfie 10 | Andrew 11 | Anthony 12 | Archie 13 | Arlo 14 | Arthur 15 | Ashley 16 | Ashton 17 | Austin 18 | Bailey 19 | Ben 20 | Benjamin 21 | Billy 22 | Blake 23 | Bobby 24 | Bradley 25 | Brandon 26 | Caleb 27 | Callum 28 | Cameron 29 | Carter 30 | Charles 31 | Charlie 32 | Christopher 33 | 
Connor 34 | Conor 35 | Corey 36 | Daniel 37 | David 38 | Declan 39 | Dexter 40 | Dominic 41 | Dylan 42 | Edward 43 | Elijah 44 | Elliot 45 | Elliott 46 | Ellis 47 | Ethan 48 | Evan 49 | Ewan 50 | Ezra 51 | Felix 52 | Finlay 53 | Finley 54 | Finn 55 | Frankie 56 | Freddie 57 | Frederick 58 | Gabriel 59 | George 60 | Grayson 61 | Harley 62 | Harrison 63 | Harry 64 | Harvey 65 | Hayden 66 | Henry 67 | Hugo 68 | Hunter 69 | Ibrahim 70 | Isaac 71 | Jack 72 | Jackson 73 | Jacob 74 | Jake 75 | James 76 | Jamie 77 | Jason 78 | Jasper 79 | Jaxon 80 | Jay 81 | Jayden 82 | Jenson 83 | Jesse 84 | Joe 85 | Joel 86 | Joey 87 | John 88 | Jonathan 89 | Jordan 90 | Joseph 91 | Josh 92 | Joshua 93 | Jude 94 | Kai 95 | Kane 96 | Kayden 97 | Kian 98 | Kieran 99 | Kieron 100 | Kyle 101 | Lee 102 | Leo 103 | Leon 104 | Lewis 105 | Liam 106 | Logan 107 | Louie 108 | Louis 109 | Luca 110 | Lucas 111 | Luke 112 | Marcus 113 | Mark 114 | Mason 115 | Matthew 116 | Max 117 | Michael 118 | Mitchell 119 | Mohammad 120 | Mohammed 121 | Morgan 122 | Muhammad 123 | Nathan 124 | Nicholas 125 | Noah 126 | Oliver 127 | Ollie 128 | Oscar 129 | Owen 130 | Patrick 131 | Peter 132 | Ralph 133 | Reece 134 | Reggie 135 | Reuben 136 | Rhys 137 | Riley 138 | Robert 139 | Roman 140 | Ronnie 141 | Rory 142 | Ross 143 | Rowan 144 | Ryan 145 | Sam 146 | Samuel 147 | Scott 148 | Sean 149 | Sebastian 150 | Seth 151 | Sonny 152 | Spencer 153 | Stanley 154 | Stephen 155 | Taylor 156 | Teddy 157 | Theo 158 | Theodore 159 | Thomas 160 | Tobias 161 | Toby 162 | Tom 163 | Tommy 164 | Tyler -------------------------------------------------------------------------------- /data/original/names/girls.txt: -------------------------------------------------------------------------------- 1 | Aaliyah 2 | Abbie 3 | Abby 4 | Abigail 5 | Ada 6 | Aimee 7 | Aisha 8 | Alexandra 9 | Alice 10 | Alicia 11 | Alisha 12 | Amber 13 | Amelia 14 | Amelie 15 | Amy 16 | Anna 17 | Annabelle 18 | Arabella 19 | Aria 20 | Aurora 21 | Ava 22 | Ayla 23 | Beatrice 24 | Bella 25 | Bethan 26 | Bethany 27 | Bonnie 28 | Brooke 29 | Caitlin 30 | Catherine 31 | Cerys 32 | Charlie 33 | Charlotte 34 | Chelsea 35 | Chloe 36 | Clara 37 | Courtney 38 | Daisy 39 | Danielle 40 | Darcey 41 | Darcie 42 | Darcy 43 | Delilah 44 | Demi 45 | Edith 46 | Eleanor 47 | Elise 48 | Eliza 49 | Elizabeth 50 | Ella 51 | Elle 52 | Ellen 53 | Ellie 54 | Eloise 55 | Elsie 56 | Emilia 57 | Emily 58 | Emma 59 | Erin 60 | Esme 61 | Eva 62 | Eve 63 | Evelyn 64 | Evie 65 | Faith 66 | Felicity 67 | Florence 68 | Francesca 69 | Freya 70 | Gabrielle 71 | Gemma 72 | Georgia 73 | Georgina 74 | Grace 75 | Gracie 76 | Hallie 77 | Hannah 78 | Harper 79 | Harriet 80 | Heidi 81 | Hollie 82 | Holly 83 | Imogen 84 | Iris 85 | Isabel 86 | Isabella 87 | Isabelle 88 | Isla 89 | Isobel 90 | Ivy 91 | Jade 92 | Jasmine 93 | Jennifer 94 | Jessica 95 | Jodie 96 | Jordan 97 | Julia 98 | Kate 99 | Katherine 100 | Katie 101 | Kayla 102 | Kayleigh 103 | Keira 104 | Kiera 105 | Kirsty 106 | Lacey 107 | Laila 108 | Lara 109 | Laura 110 | Lauren 111 | Layla 112 | Leah 113 | Lexi 114 | Lexie 115 | Libby 116 | Lilly 117 | Lily 118 | Lola 119 | Lottie 120 | Louise 121 | Lucy 122 | Luna 123 | Lydia 124 | Lyla 125 | Maddison 126 | Madeleine 127 | Madison 128 | Maisie 129 | Maisy 130 | Margot 131 | Maria 132 | Martha 133 | Maryam 134 | Matilda 135 | Maya 136 | Megan 137 | Melissa 138 | Mia 139 | Mila 140 | Millie 141 | Mollie 142 | Molly 143 | Morgan 144 | Mya 145 | Nancy 146 | Naomi 147 | Natalie 148 | Natasha 149 | Niamh 150 | Nicole 151 | 
Olivia 152 | Orla 153 | Paige 154 | Penelope 155 | Phoebe 156 | Poppy 157 | Rachel 158 | Rebecca 159 | Rhiannon 160 | Robyn 161 | Rose 162 | Rosie 163 | Ruby 164 | Samantha 165 | Sara 166 | Sarah 167 | Scarlett 168 | Shannon 169 | Sienna 170 | Skye 171 | Sofia 172 | Sophia 173 | Sophie 174 | Stephanie 175 | Summer 176 | Tegan 177 | Thea 178 | Tia 179 | Tilly 180 | Victoria 181 | Violet 182 | Willow 183 | Yasmin 184 | Zara 185 | Zoe -------------------------------------------------------------------------------- /data/original/us/US.dic: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrozner/decrypt/71679f76ce19cd098b0988d0255b03d7e47f0c05/data/original/us/US.dic -------------------------------------------------------------------------------- /data/original/us/US.txt: -------------------------------------------------------------------------------- 1 | This dictionary was originally compiled from public domain sources 2 | for the amSpell spell-checker by 3 | 4 | Erik Frambach (e-mail: e.h.m.frambach@eco.rug.nl). 5 | 6 | 7 | I have further modified this dictionary for use with WinEdt: 8 | The dictionary is decompressed, translated from OEM to Windows 9 | Character set, and sorted by WinEdt's Dictionary Manager. 10 | 11 | alex 12 | -------------------------------------------------------------------------------- /decrypt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrozner/decrypt/71679f76ce19cd098b0988d0255b03d7e47f0c05/decrypt/__init__.py -------------------------------------------------------------------------------- /decrypt/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrozner/decrypt/71679f76ce19cd098b0988d0255b03d7e47f0c05/decrypt/common/__init__.py -------------------------------------------------------------------------------- /decrypt/common/anagrammer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from origcrypt/anagrammer 3 | 4 | An Anagrammer has a dictionary mapping (sorted letters) => AnagramSet 5 | 6 | AnagramSets are used for reading in the various dictionaries to generate lists of potential anagrams. 7 | An AnagramSet corresponds to a set of sorted letters. It has two dictionaries: 8 | For each of these, a sort of the unsorted letters gives the sorted letters 9 | - map (unsorted letters) => [One word anagram] 10 | - map (unsorted letters) => [multi-word-anagram] 11 | In both cases we represent the anagrammable as a list of words. 
For a single word anagram the list has len == 1 12 | """ 13 | import logging 14 | import random 15 | import shelve 16 | import string 17 | from collections import Counter 18 | from collections import defaultdict 19 | from os import path 20 | from typing import * 21 | 22 | from tqdm import tqdm 23 | 24 | import decrypt.config as config 25 | from .util_spellchecker import ( 26 | line_parser_US_dic, 27 | SpellChecker, 28 | get_shelve_dbhandler_open_flag 29 | ) 30 | 31 | logging.getLogger(__name__) 32 | 33 | k_default_base_input_file_name = str(config.DataDirs.OriginalData.k_US_dic) 34 | k_default_output_file_name = str(config.DataDirs.Generated.anagram_db) 35 | 36 | class AnagramSet: 37 | """ 38 | Assumes that a given ordering of letters has only one valid parse into words 39 | 40 | """ 41 | def __init__(self, new_item: List[str], new_item_ltrs: str): 42 | # For one-word-anagrams, each list will have only one word 43 | # But this is done so that both dicts look the same 44 | self.one_word_anagrams: Dict[str, List[str]] = defaultdict() 45 | self.multi_word_anagrams: Dict[str, List[str]] = defaultdict() 46 | self._num_anagrams = 0 47 | self._num_one_word_anagrams = 0 48 | 49 | self.add_to_anag_set(new_item, new_item_ltrs) 50 | 51 | def add_to_anag_set(self, item: List[str], item_ltrs: str, 52 | log_errors = True) -> bool: 53 | if len(item) == 1: 54 | existing = self.one_word_anagrams.get(item_ltrs) 55 | if existing is not None: 56 | if log_errors: 57 | logging.error(f"Double inserting {existing}, {item}") 58 | return False 59 | self.one_word_anagrams[item_ltrs] = item 60 | self._num_one_word_anagrams += 1 61 | else: 62 | existing = self.multi_word_anagrams.get(item_ltrs) 63 | if existing is not None: 64 | if log_errors: 65 | logging.error(f"Double inserting {existing}, {item}") 66 | return False 67 | self.multi_word_anagrams[item_ltrs] = item 68 | 69 | self._num_anagrams += 1 70 | return True 71 | 72 | # todo: is this necessary? need to make sure we don't modify the lookup 73 | def get_lists(self) -> Tuple[List[List[str]], List[List[str]]]: 74 | """ 75 | Returns: Tuple: (one word anagrams, multi-word-anagrams) 76 | where each is a List of Lists 77 | 78 | """ 79 | # list and tuple are equivalent 80 | return list(self.one_word_anagrams.values()), list(self.multi_word_anagrams.values()) 81 | 82 | 83 | 84 | class Anagrammer(): 85 | """ Anagram database 86 | 87 | Attributes: 88 | db: opened anagram database. See genanagrams for details of the structure. 
89 | db is a map from str => AnagramSet 90 | 91 | """ 92 | def __init__(self, anagram_database): 93 | # logging.info(f"Initializing Singleton Anagrammer from {anagram_database}") 94 | logging.info(f"Initializing (non-singleton) Anagrammer from {anagram_database}") 95 | self._translation_table = str.maketrans('','',string.punctuation) 96 | self.db = self.__init_db(anagram_database) 97 | 98 | self._possible_anagrams: Optional[List[AnagramSet]] = None 99 | logging.info(f"DONE: Initialized Anagrammer from {anagram_database}") 100 | 101 | def __init_db(self, anagram_database): 102 | if not path.exists(anagram_database + ".db"): 103 | logging.exception(f'Given anagram database file {anagram_database + ".db"} does not exist') 104 | raise Exception(f'Given anagram database file {anagram_database} does not exist') 105 | try: 106 | db = shelve.open(anagram_database, flag='r') 107 | logging.debug("Opened anagram database successfully") 108 | return db 109 | except Exception as e: 110 | logging.exception(f'While trying to open {anagram_database}, exception:') 111 | print(f"path is: {path.curdir}") 112 | 113 | raise e 114 | 115 | def __look_up(self, char_string: str) -> Optional[AnagramSet]: 116 | """ Perform a lookup on a set of characters. Internal method. 117 | 118 | :param str char_string: letters to use in lookup 119 | :return: Valid anagrams 120 | :rtype: list[str] 121 | """ 122 | 123 | chars_no_punct = char_string.translate(self._translation_table) 124 | lookup = "".join(sorted(chars_no_punct)) # Sort for hashing, essentially 125 | if lookup in self.db: 126 | lookup_result: AnagramSet = self.db[lookup] 127 | return lookup_result 128 | else: 129 | return None 130 | 131 | def get_anagrams(self, letters: str, 132 | remove_letters=False, 133 | include_multi_word_anagrams=False) -> List[List[str]]: 134 | """ 135 | """ 136 | logging.debug("Looking up in anagram db: " + letters) 137 | result_anag_set = self.__look_up(letters) 138 | 139 | if result_anag_set is None: 140 | return [] 141 | 142 | one_word_anagrams, multi_word_anagrams = result_anag_set.get_lists() 143 | # todo: should assert something in case we have one word 144 | 145 | # don't return the letters themselves 146 | if remove_letters and [letters] in one_word_anagrams: 147 | one_word_anagrams.remove([letters]) 148 | 149 | results = one_word_anagrams 150 | if include_multi_word_anagrams: 151 | results.extend(multi_word_anagrams) 152 | return results 153 | 154 | def get_anagrams_flat(self, letters: str, **kwargs) -> List[str]: 155 | """ 156 | Return flat list of anagram outputs. 
157 | 158 | Signature same as get_anagrams 159 | """ 160 | list_of_lists = self.get_anagrams(letters, **kwargs) 161 | return ["".join(x) for x in list_of_lists] 162 | 163 | def is_word(self, word): 164 | """ Check if word is present in our anagram dictionary 165 | 166 | :param str word: 167 | :return: True if present in anagram dictionary 168 | :rtype: bool 169 | """ 170 | result = self.__look_up(word) 171 | if result is not None: 172 | return word in result.one_word_anagrams 173 | 174 | def get_random_anag_sample(self, sample_count: int = 20, 175 | return_set = "both") -> List[List[str]]: 176 | def rand_samp() -> Optional[List[List[str]]]: 177 | poss_anags : List[AnagramSet] = random.sample(self._possible_anagrams, sample_count) 178 | for anag_set in poss_anags: 179 | # multiword only 180 | # if return_set == "require_mu": 181 | # if anag_set._num_anagrams - anag_set._num_one_word_anagrams < 2: 182 | # continue 183 | # else: 184 | # return anag_set.get_lists()[1] 185 | # 186 | # both 187 | if return_set == "both": # we already know that this is a valid set, since we filtered 188 | one_word, multi_word = anag_set.get_lists() 189 | one_word.extend(multi_word) 190 | return one_word # we know there are sufficient num anags 191 | 192 | else: # return_set == "single" 193 | # otherwise one-word only 194 | if anag_set._num_one_word_anagrams < 2: 195 | continue 196 | else: 197 | return anag_set.get_lists()[0] 198 | return None 199 | 200 | if self._possible_anagrams is None: 201 | self._populate_possible_anagrams() 202 | res = None 203 | while res is None: 204 | res = rand_samp() 205 | return res 206 | 207 | 208 | def _populate_possible_anagrams(self): 209 | self._possible_anagrams = [] 210 | for anag_set in tqdm(self.db.values()): 211 | if anag_set._num_anagrams >=2: 212 | self._possible_anagrams.append(anag_set) 213 | logging.info(f"Total anagramable: {len(self._possible_anagrams)}") 214 | 215 | 216 | 217 | 218 | def gen_db_with_both_inputs(output_filename=k_default_output_file_name, 219 | update_flag: str=""): 220 | def add_word_to_anagram_db(x: List[str], db: shelve.DbfilenameShelf) -> bool: 221 | x_ltrs_unsorted = "".join(x) 222 | x_ltrs_sorted = "".join(sorted(x_ltrs_unsorted)) 223 | # previously x was just a string 224 | 225 | is_new_word = True 226 | if x_ltrs_sorted in db: 227 | temp: AnagramSet = db[x_ltrs_sorted] 228 | if not temp.add_to_anag_set(x, x_ltrs_unsorted, log_errors=False): 229 | is_new_word = False 230 | 231 | # we re-add back whether or not it's different. this seems to speed up the processing 232 | db[x_ltrs_sorted] = temp 233 | else: 234 | db[x_ltrs_sorted] = AnagramSet(x, x_ltrs_unsorted) # ow. 
insert as List 235 | 236 | return is_new_word 237 | 238 | 239 | # todo: write into shelf which files have been added as a config variable 240 | def add_file_to_database(fh, 241 | db: shelve.DbfilenameShelf, 242 | line_parser_fcn: Union[Callable[[str], List[str]], 243 | Callable[[bytes], List[str]]]) -> NoReturn: 244 | """ 245 | Generates a database mapping: 246 | => [List[one word anagrams] 247 | List[multi word anagrams represented as lists]] 248 | """ 249 | ctr = Counter() 250 | for l in tqdm(fh): 251 | x = line_parser_fcn(l) # x is a List of strs 252 | if not x: # if empty list or returns None 253 | continue 254 | ctr[len(x)] += 1 255 | 256 | if not add_word_to_anagram_db(x, db): # returns true if it's a new word 257 | ctr["dupes"] += 1 258 | 259 | print(ctr) 260 | print(f'Done.') 261 | 262 | 263 | # verify the db flags 264 | dbhandler_flag = get_shelve_dbhandler_open_flag(output_filename, update_flag=update_flag) 265 | if not dbhandler_flag: 266 | return 267 | 268 | logging.info(f"Adding to db {output_filename} with updateflag {update_flag}") 269 | with shelve.open(output_filename, flag=dbhandler_flag) as db: 270 | with open(k_default_base_input_file_name, "rb") as fh: 271 | add_file_to_database(fh, db, line_parser_fcn=line_parser_US_dic) 272 | 273 | 274 | -------------------------------------------------------------------------------- /decrypt/common/label_anagrams.py: -------------------------------------------------------------------------------- 1 | import string 2 | from collections import defaultdict 3 | from typing import * 4 | 5 | from tqdm import tqdm 6 | 7 | import decrypt.config as config 8 | from decrypt.scrape_parse.guardian_load import load_guardian_splits 9 | 10 | 11 | def make_label_set(): 12 | _, all_clues, (_, _, _) = load_guardian_splits(config.DataDirs.Guardian.json_folder, verify=True) 13 | labels: Dict[str, Set[int]] = defaultdict(set) # set of the indices for this type 14 | any_label = set() 15 | def add_to_labels(name, idx, verify=True): 16 | if verify: 17 | assert idx not in any_label 18 | any_label.add(idx) 19 | labels[name].add(idx) 20 | 21 | class PunctStripper: 22 | """ 23 | use to strip punctuation from clues (since punct is not part of outputs) 24 | """ 25 | def __init__(self): 26 | self.table_spaces = str.maketrans('','',string.punctuation + " ") # map punct and space to '' 27 | self.punct_to_space_table = str.maketrans(string.punctuation,' '*len(string.punctuation)) # map punct to space 28 | def strip(self, s: str, strip_spaces=True): 29 | """ 30 | :param s: 31 | :param strip_spaces: if true, will remove spaces; otherwise all punct will be substituted 32 | with a space, which is important for generating anagram outputs 33 | :return: 34 | """ 35 | if strip_spaces: 36 | return s.translate(self.table_spaces) 37 | else: 38 | return s.translate(self.punct_to_space_table) 39 | ps = PunctStripper() 40 | 41 | # will find hiddens / reversals (which are either direct, or direct reverse) 42 | # the anagrams that result potentially take single letters from the start or end of another word 43 | 44 | for sc in tqdm(all_clues): 45 | c = ps.strip(sc.clue).lower() 46 | s = sorted(sc.soln.lower()) 47 | tgt_len = len(s) 48 | for idx in range(0, len(c) - tgt_len + 1): 49 | sub_part = c[idx:idx+tgt_len] 50 | # hidden if directly occurs 51 | if sub_part == sc.soln.lower(): 52 | add_to_labels('hidden', sc.idx) 53 | break 54 | # reverse if occurs backward 55 | if sub_part == sc.soln.lower()[::-1]: 56 | add_to_labels('reverse', sc.idx) 57 | break 58 | # direct anagram if occurs 
directly in clue once spaces and punct removed 59 | if sorted(sub_part) == s: 60 | add_to_labels('anag_direct', sc.idx) 61 | break 62 | 63 | return labels 64 | -------------------------------------------------------------------------------- /decrypt/common/puzzle_clue.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import logging 5 | import re 6 | import string 7 | from collections import defaultdict 8 | from dataclasses import dataclass, field 9 | from typing import * 10 | 11 | from tqdm import tqdm 12 | 13 | logging.getLogger(__name__) 14 | 15 | 16 | ################## 17 | # Puzzle and Clue related classes / datastructures for easy manipulation ### 18 | ################## 19 | @dataclass 20 | class BaseClue: 21 | clue: str 22 | lengths: List[int] 23 | soln: str 24 | soln_with_spaces: str = field(init=False) # soln string, but has spaces between words for multi-word answers 25 | idx: int = field(init=False) # unique index in the set 26 | dataset: str = field(init=False) # source dataset 27 | 28 | def __post_init__(self): 29 | self.soln = self.soln.lower() 30 | self.__populate_soln_with_spaces() 31 | self.idx = -1 # initially set to -1; but will be set in get_clean_clues() 32 | self.dataset = "" 33 | 34 | def __populate_soln_with_spaces(self): 35 | soln_with_spaces = "" 36 | idx = 0 37 | if len(self.lengths) > 1: 38 | for l in self.lengths: 39 | soln_with_spaces += self.soln[idx: idx + l] + " " 40 | idx += l 41 | soln_with_spaces = soln_with_spaces.strip() 42 | else: 43 | soln_with_spaces = self.soln 44 | self.soln_with_spaces = soln_with_spaces 45 | 46 | def clue_with_lengths(self, punct=","): 47 | return f'{self.clue} ({punct.join(map(str, self.lengths))})' 48 | 49 | @classmethod 50 | def from_clue_and_one_word_soln(cls, clue: str, soln: str): 51 | return cls(clue=clue, 52 | lengths=[len(soln)], 53 | soln=soln) 54 | 55 | @classmethod 56 | def from_clue_and_soln(cls, clue: str, soln: str): 57 | splits = soln.split(' ') 58 | lengths = list(map(lambda x: len(x.strip()), splits)) 59 | return cls(clue=clue, 60 | lengths=lengths, 61 | soln=soln) 62 | 63 | @classmethod 64 | def from_json(cls, json_obj: Dict) -> BaseClue: 65 | json_obj_no_soln_with_spaces = json_obj.copy() # copy bc we will modify 66 | json_obj_no_soln_with_spaces.pop('soln_with_spaces') 67 | return cls(**json_obj_no_soln_with_spaces) 68 | 69 | 70 | @dataclass 71 | class ClueWithGridInfo(BaseClue): 72 | across_or_down: str # "across" or "down" 73 | pos: Tuple[int, int] # row, col 74 | 75 | @classmethod 76 | def from_json(cls, json_list: List): 77 | return cls(*json_list) 78 | 79 | 80 | @dataclass(order=True) 81 | class GuardianClue(ClueWithGridInfo): 82 | # identifiers 83 | unique_clue_id: str # puzzleid_clue_id.. of form: 21465_1-across, e.g. 84 | type: str # "cryptic" or "quiptic" 85 | number: int # should be the end of id; a unique ID 86 | id: str # e.g. 
crosswords/cryptics/21465 87 | # extra metadata 88 | creator: Optional[str] # json -> creator -> name 89 | orig_lengths: str # sometimes there is a dash instead of a comma separator 90 | lengths_punctuation: Set[str] # the punctuation it contains 91 | 92 | @classmethod 93 | def to_json_dict(cls, gc: GuardianClue) -> Dict: 94 | return dict( 95 | clue=gc.clue, 96 | soln=gc.soln, 97 | soln_with_spaces=gc.soln_with_spaces, 98 | lengths=gc.lengths 99 | ) 100 | 101 | # In order to anonymize the dataset 102 | @dataclass 103 | class CleanGuardianClue(GuardianClue): 104 | def __post_init__(self): 105 | super().__post_init__() 106 | # ClueWithGridInfo 107 | self.across_or_down = "" 108 | self.pos = (0,0) 109 | 110 | self.unique_clue_id = "" 111 | self.number = 0 112 | self.id = "" 113 | self.dataset = "" 114 | 115 | @classmethod 116 | def from_json(cls, json_obj: Dict) -> CleanGuardianClue: 117 | # duplicates from_json in BaseClue 118 | # we need to pop soln_with_spaces bc post_init will be called (todo: do we?) 119 | json_clean = json_obj.copy() # copy bc we will modify 120 | for k in ['soln_with_spaces', 'idx', 'dataset']: 121 | json_clean.pop(k) 122 | json_clean['lengths_punctuation'] = set(json_clean['lengths_punctuation']) 123 | return cls(**json_clean) 124 | 125 | 126 | ### 127 | # for seq2seq 128 | ### 129 | @dataclass 130 | class Seq2seqDataEntry: 131 | idx: int # unique index for this element in the dataset 132 | input: str 133 | target: str 134 | 135 | # labels - if needed during run 136 | # any additional info? -> can be joined in with a join on idx 137 | 138 | @classmethod 139 | def from_base_clue(cls, 140 | gc: BaseClue, 141 | mod_fn: Optional[Callable] = None) -> Seq2seqDataEntry: 142 | if mod_fn is None: 143 | input = gc.clue_with_lengths() 144 | else: 145 | input = mod_fn(gc) 146 | 147 | return Seq2seqDataEntry(idx=gc.idx, 148 | input=input, 149 | target=gc.soln_with_spaces) 150 | 151 | @classmethod 152 | def to_json_dict(cls, entry: Seq2seqDataEntry): 153 | return entry.__dict__.copy() 154 | 155 | 156 | class ClueEncoder(json.JSONEncoder): 157 | def default(self, o): 158 | return o.__dict__ 159 | #### 160 | # Functions to go from guardian clue json to a filtered list for use in datasetes 161 | # todo: roughly copied from guardian_scrape > __main__ ; clean up; i.e. 
remove from that location 162 | # This code originally in data_util/guardian_gendatasets 163 | # moved here so that could be shared with cryptics_parsing for identifying clue types 164 | #### 165 | def make_stc_map(clue_list: List[BaseClue]) -> DefaultDict[str, List[BaseClue]]: 166 | soln_to_clue_map: defaultdict[str, List[GuardianClue]] = defaultdict(list) 167 | for c in tqdm(clue_list): 168 | soln_to_clue_map[c.soln].append(c) 169 | return soln_to_clue_map 170 | 171 | 172 | # Use to find duplicate clues 173 | def normalize(s): 174 | """Convert to lowercase and remove punctuation, articles and extra whitespace.""" 175 | 176 | def remove_articles(text): 177 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 178 | return re.sub(regex, ' ', text) 179 | 180 | def white_space_fix(text): 181 | return ''.join(text.split()) 182 | 183 | def remove_punc(text): 184 | exclude = set(string.punctuation) 185 | return ''.join(ch for ch in text if ch not in exclude) 186 | 187 | def lower(text): 188 | return text.lower() 189 | 190 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 191 | 192 | 193 | def filter_clues(soln_to_clue_map: defaultdict[str, List[BaseClue]]) \ 194 | -> Tuple[Dict[str, List[BaseClue]], List[BaseClue]]: 195 | # Remove anything that is exactly the same up to small diffs 196 | # removes 1610 normalized clues 197 | soln_to_clue_map_clean: Dict[str, List[BaseClue]] = defaultdict(list) 198 | all_clues_clean = [] 199 | count_removed = 0 200 | output_count = 0 201 | for k, v in tqdm(soln_to_clue_map.items()): 202 | # each v is a list, so we compare the clues that have the same soln 203 | set_of_clues: Set[str] = set() 204 | clean_list = [] 205 | 206 | for gc in v: 207 | norm_clue = normalize(gc.clue) 208 | if norm_clue in set_of_clues: 209 | count_removed += 1 210 | continue 211 | else: 212 | set_of_clues.add(norm_clue) 213 | clean_list.append(gc) 214 | if len(clean_list) > 0: 215 | soln_to_clue_map_clean[k] = clean_list 216 | output_count += len(clean_list) 217 | all_clues_clean.extend(clean_list) 218 | 219 | print(f'removed {count_removed} exact dupes') 220 | print(output_count) 221 | 222 | return soln_to_clue_map_clean, all_clues_clean -------------------------------------------------------------------------------- /decrypt/common/substitution.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from dataclasses_json import dataclass_json 3 | from typing import List 4 | 5 | @dataclass_json 6 | @dataclass 7 | class Substitution: 8 | new_clue_str: str 9 | substituted_word: str 10 | 11 | @dataclass_json 12 | @dataclass 13 | class ClueWithSubstitutions: 14 | orig_input: str 15 | word_to_be_swapped: str # anagram substrate 16 | target: str 17 | 18 | substitutions: List[Substitution] 19 | -------------------------------------------------------------------------------- /decrypt/common/util_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils to 3 | - dump obj to file (or retrieve from file) 4 | - train/test split function for cryptics 5 | - cryptic Parsed Examples => Data for training models 6 | """ 7 | import json 8 | import logging 9 | import os.path 10 | from pprint import pformat 11 | from typing import * 12 | import decrypt.config as config 13 | 14 | from tqdm import tqdm 15 | 16 | from decrypt.common.puzzle_clue import ( 17 | Seq2seqDataEntry, 18 | BaseClue 19 | ) 20 | from .anagrammer import Anagrammer 21 | 22 | log = logging.getLogger(__name__) 
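# Usage sketch (illustrative, not part of the original flow): these utils are normally driven
# from the notebooks, but a minimal end-to-end call with the default paths from decrypt/config.py
# would look roughly like the following (assumes the Guardian clue json has been unzipped into
# data/ as in the repo setup; `config` is imported at the top of this module):
#
#   from decrypt.scrape_parse import load_guardian_splits
#
#   _, _, (train, val, test) = load_guardian_splits(config.DataDirs.Guardian.json_folder)
#   clue_list_tuple_to_train_split_json(
#       (train, val, test),
#       comment="guardian naive random split",
#       export_dir=config.DataDirs.DataExport.guardian_naive_random_split,
#       overwrite=False)   # writes train/val/test .json plus a README.txt into export_dir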
23 | 24 | k_data_names = ["train", "val", "test"] 25 | 26 | 27 | def _check_overwrite(filename): 28 | if os.path.isfile(filename): 29 | raise FileExistsError("Cannot write since file_name already exists and overwrite not specified") 30 | ####### 31 | # Dataset generation 32 | ####### 33 | 34 | def write_json_tuple(json_tuple: List[List], 35 | comment: str, 36 | export_dir, 37 | overwrite:bool = False, 38 | mod_fn: Optional[Callable] = None): 39 | assert 1 <= len(json_tuple) <= 3, len(json_tuple) 40 | 41 | def write_json_to_file(json_dict: List, path): 42 | if not overwrite: 43 | _check_overwrite(path) 44 | with open(path, 'w') as fh: 45 | json.dump(json_dict, fh) 46 | 47 | os.makedirs(export_dir, exist_ok=True) 48 | for json_out, filename in zip(json_tuple, k_data_names): 49 | tgt_path = os.path.join(export_dir, filename + ".json") 50 | write_json_to_file(json_out, tgt_path) 51 | 52 | # write a description 53 | file_path = os.path.join(export_dir, "README.txt") 54 | with open(file_path, "w") as f: 55 | f.write(comment) 56 | lengths = list(map(len, json_tuple)) 57 | f.write(f'\nTotal: {sum(lengths)}\n' 58 | f'splits: {lengths}') 59 | f.write("\n\n") 60 | sample_set = "" 61 | for entry in json_tuple[0][:3]: 62 | sample_set += f'{pformat(entry)}\n' 63 | f.write(sample_set + "\n\n") 64 | print(sample_set) 65 | 66 | if mod_fn is not None: 67 | f.write(f'\nMod fcn applied: {mod_fn.__name__}') 68 | 69 | log.info('Finished writing all files') 70 | 71 | 72 | def clue_list_tuple_to_train_split_json( 73 | clue_list_tuple: Tuple[List[BaseClue], ...], # train, val, test 74 | comment: str, 75 | export_dir, 76 | mod_fn: Optional[Callable] = None, 77 | overwrite: bool = False): 78 | assert 1 <= len(clue_list_tuple) <= 3, len(clue_list_tuple) 79 | 80 | def make_json_list(l: List[BaseClue]): 81 | out_list = [] 82 | for bc in tqdm(l): 83 | data_entry = Seq2seqDataEntry.from_base_clue(bc, mod_fn=mod_fn) 84 | json_entry = Seq2seqDataEntry.to_json_dict(data_entry) 85 | out_list.append(json_entry) 86 | return out_list 87 | 88 | json_output_tuple = list(map(make_json_list, clue_list_tuple)) 89 | 90 | out_ex = json_output_tuple[0][0] 91 | log.info(f'Source target mapping:\n' 92 | f'\t{out_ex["input"]} => {out_ex["target"]}\n') 93 | 94 | write_json_tuple(json_output_tuple, comment, export_dir, overwrite=overwrite, mod_fn=mod_fn) 95 | 96 | ### 97 | # Anagrams 98 | ### 99 | 100 | def get_anags(max_num_words=1) -> List[List[str]]: 101 | """ 102 | Return List where each element is a list of words that map to the same set of letters 103 | """ 104 | anag = Anagrammer(str(config.DataDirs.Generated.anagram_db)) # system autoappends db 105 | # First populate the anagrams 106 | anag._populate_possible_anagrams() 107 | 108 | ret_anags = [] 109 | for anag_set in anag._possible_anagrams: 110 | one_word_anags, multi_word_anags = anag_set.get_lists() # get one words only 111 | all_anags = one_word_anags + multi_word_anags 112 | if max_num_words > 0: 113 | # list of lists; num lists is the number of words in the anag set 114 | all_anags = filter(lambda x: len(x) <= max_num_words, all_anags) 115 | 116 | # for multi-word anags, we join them together 117 | all_anags = list(map(lambda x: " ".join(x), all_anags)) 118 | if len(all_anags) > 1: # make sure there are at least two realizations of the letter set 119 | # flattened = [w for realizations in all_anags for w in realizations] # flatten 120 | ret_anags.append(all_anags) 121 | 122 | print(len(ret_anags)) # unique sets of letters that produce more than one realized anagram 123 | 
print(len(anag._possible_anagrams)) # unique sets of letters that produce a single or multiword anagram 124 | print(sum(map(lambda x: len(x), ret_anags))) # all possible words that have at least one other one word anag 125 | print(ret_anags[0]) # example 126 | 127 | return ret_anags 128 | -------------------------------------------------------------------------------- /decrypt/common/util_spellchecker.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | from typing import * 6 | 7 | import enchant 8 | from tqdm import tqdm 9 | 10 | import decrypt.config as config 11 | 12 | logging.getLogger(__name__) 13 | 14 | def get_shelve_dbhandler_open_flag(output_filename: str, update_flag: str = "") -> Optional[str]: 15 | flag = "" 16 | if update_flag == "new": # generate new, don't overwrite 17 | if os.path.isfile(output_filename + ".db"): 18 | logging.warning(f"File already exists. Use other update_type flag") 19 | return None 20 | flag = "n" 21 | elif update_flag == "update": # update: 22 | if not os.path.isfile(output_filename + ".db"): 23 | logging.warning(f"Attempting to update a database that does not exist. Failed") 24 | return None 25 | logging.info(f"Updating database at {output_filename}") 26 | flag = "w" 27 | elif update_flag == "overwrite": # overwrite 28 | logging.info(f"Overwriting database at {output_filename}") 29 | flag = "n" 30 | else: 31 | logging.warning(f"Invalid flag. Failed") 32 | return None 33 | 34 | return flag 35 | 36 | def line_parser_US_dic(input_line: bytes, log_errors=False) -> Optional[List[str]]: 37 | try: 38 | x = input_line.decode("utf-8") 39 | x = x.strip() 40 | return [x] 41 | except UnicodeDecodeError: 42 | if log_errors: 43 | print(f"unicode decode fail: {repr(input_line)}") 44 | return None 45 | 46 | # todo: enchant is no longer maintained and double checking is inefficient 47 | class SpellChecker: 48 | def __init__(self, 49 | dict_files: List[Tuple[str, bool]] = None, 50 | init_enchant_dict=True, 51 | init_twl_dict=True, 52 | log_init_errors=False): 53 | """ 54 | 55 | Args: 56 | dict_files: List of tuples of 57 | """ 58 | print("Initialized a spellchecker") 59 | self.dict = set() 60 | self.enchant_dict = None 61 | self.twl_short_word_dict = set() 62 | if init_enchant_dict: 63 | self.enchant_dict = enchant.Dict("en_US") 64 | if init_twl_dict: 65 | self.__add_twl_contents_to_dict(config.DataDirs.Generated.twl_tex_dict) 66 | 67 | if dict_files is None: 68 | dict_files = [(config.DataDirs.OriginalData.k_US_dic, True)] 69 | for df in dict_files: 70 | self.__add_file_contents_to_dict(df, log_init_errors) 71 | logging.info("Done setting up spellchecker") 72 | 73 | def __del__(self): 74 | print("DEL called for spellchecker") 75 | 76 | def __add_twl_contents_to_dict(self, file: str): 77 | logging.info(f'Reading file into dict: {file}') 78 | print(f"This will fail if you have not downloaded or generated twl_dict.txt") 79 | with open(file, 'r') as f: 80 | for input_line in tqdm(f): 81 | word = input_line.strip() 82 | if word != "": 83 | if len(word) < 3: 84 | self.twl_short_word_dict.add(word.lower()) 85 | else: 86 | self.dict.add(word.lower()) 87 | 88 | logging.info(f'Done reading file: {file}') 89 | 90 | def __add_file_contents_to_dict(self, file: Tuple[str, bool], log_errors): 91 | logging.info(f'Reading file into dict: {file[0]}') 92 | if file[1]: # bytes 93 | with open(file[0], 'rb') as f: 94 | for input_line in tqdm(f): 95 | word_list = 
line_parser_US_dic(input_line, log_errors=log_errors) 96 | if word_list is not None and len(word_list) > 0 and word_list[0] != "": 97 | self.dict.add(word_list[0].lower()) 98 | else: # not bytes 99 | with open(file[0], 'r') as f: 100 | for input_line in tqdm(f): 101 | word = input_line.strip() 102 | if word != "": 103 | self.dict.add(word.lower()) 104 | 105 | logging.info(f'Done reading file: {file[0]}') 106 | 107 | def check_word(self, w: str, 108 | lower_case: bool = True, 109 | special_handle_short_words: bool = False, 110 | check_twl_short_dict: bool = True, 111 | check_enchant_dict: bool = True, 112 | print_info: bool = False, 113 | use_base_dict=True) -> bool: 114 | if lower_case: 115 | w = w.lower() 116 | 117 | one_letter_words = ["a", "i"] 118 | two_letter_words = ["ad", "am", "an", "as", "at", 119 | "do", "go", "he", "hi", "if", "in", 120 | "is", "it", "me", "my", "no", "of", "on", "or", 121 | "so", "to", "up", "us"] 122 | in_dict = w in self.dict 123 | in_short_words = w in one_letter_words or w in two_letter_words 124 | in_twl_short = w in self.twl_short_word_dict 125 | in_enchant_lower = self.enchant_dict is not None and self.enchant_dict.check(w) 126 | in_enchant_upper = self.enchant_dict is not None and self.enchant_dict.check(w.capitalize()) 127 | 128 | if print_info: 129 | print(f'dict: {in_dict}\t twl_short: {in_twl_short}\t short_word: {in_short_words}\n' 130 | f'enchant_lower: {in_enchant_lower}\t enchant_upper: {in_enchant_upper}') 131 | 132 | # Some heuristics to fix problems with short words pre-empting the backtracking alg 133 | if special_handle_short_words and len(w) <= 3: 134 | if len(w) < 3: 135 | return in_short_words 136 | else: # len == 3 137 | return in_dict and (in_enchant_lower or in_enchant_upper) 138 | 139 | 140 | # Otherwise, successively check dicts 141 | if use_base_dict and in_dict: 142 | return True 143 | elif check_twl_short_dict and in_twl_short: 144 | return True 145 | elif check_enchant_dict and (in_enchant_lower or in_enchant_upper): 146 | return True 147 | else: 148 | return False 149 | 150 | def split_mixed_word(self, input_word: str) -> Optional[List[str]]: 151 | """ 152 | Recursive backtracking, (greedy) algorithm for determining the set of words 153 | in a word without spaces. 154 | 155 | # todo: some three letter words will (still?) cause a problem 156 | 157 | Returns: List of words (str) that compose the input string 158 | None: if no valid split found 159 | """ 160 | # Don't pass around the spell_chkr 161 | wlen = len(input_word) 162 | for end_idx in range(wlen, 0, -1): 163 | w = input_word[0:end_idx] 164 | if self.check_word(w, special_handle_short_words=True): 165 | if end_idx == wlen: # base case, we are done 166 | return [w] 167 | # otherwise, need to compute possibly terminating words 168 | next = input_word[end_idx:] 169 | next_result = self.split_mixed_word(next) 170 | if next_result is None: 171 | continue 172 | else: 173 | ret = [w] 174 | ret.extend(next_result) 175 | return ret 176 | return None 177 | -------------------------------------------------------------------------------- /decrypt/common/util_wordnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used for cryptic baselines evaluation. 
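Illustrative calls (a sketch, assuming the nltk WordNet corpus and lemminflect data are
installed; 'bird'/'hoopoe' follow the example in in_defn's docstring below, while
'animal'/'dog' is just an arbitrary pair):

    all_lemmas_for_word('bird', min_word_len=3)   # lemmas of 'bird' across all of its synsets
    in_defn('bird', 'hoopoe')                     # whether a lemma of the clue word appears among
                                                  # the lemmas of words in a definition of the answer
    in_closure_set('animal', 'dog', max_depth=3)  # whether w1 (or a synonym) appears in the
                                                  # hypernym closure of w2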
3 | """ 4 | from lemminflect import getAllInflections 5 | from nltk.corpus import wordnet as wn 6 | from typing import Set 7 | from pprint import pprint as pp 8 | 9 | def all_inflect(w, word_len): 10 | out = set() 11 | for k, v in getAllInflections(w).items(): 12 | if word_len is not None: 13 | out.update(filter(lambda x: len(x) == word_len, v)) 14 | else: 15 | out.update(v) 16 | return out 17 | 18 | def all_lemmas_for_word(w: str, 19 | min_word_len, 20 | remove_word=False) -> Set[str]: 21 | """ 22 | Args: 23 | w: word 24 | min_word_len: min_word_len of a synonym to be included in the set 25 | 26 | Returns: Set of all synonyms (lemmas) for the word, in all of its senses (synsets) 27 | """ 28 | # For every synset, for every synonym (lemma), if the lemma is > min_word_len 29 | synsets = wn.synsets(w) 30 | ret = set([lem for ss in synsets for lem in ss.lemma_names() if len(lem) >= min_word_len]) 31 | if remove_word: 32 | ret = [x for x in ret if w[:-1] not in x] 33 | 34 | return ret 35 | 36 | 37 | 38 | # todo: we should only do matching parts of speech, and we should only consider central nouns in the 39 | # noun phrase, e.g. via the parse tree 40 | def in_defn(clue: str, answer: str, min_word_len=3): 41 | """ 42 | Check whether any of the lemma names for the clue synsets are in any of the 43 | lemma sets of the definition of the answer 44 | 45 | The answer often entails the clue word. For example, for a clue of 'bird' with answer 46 | 'hoopoes' we have defn of hoopoes: any of several crested Old World birds 47 | with a slender downward-curved bill 48 | which contains bird. 49 | """ 50 | clue_lemmas = all_lemmas_for_word(clue, min_word_len) 51 | 52 | overlap_ct = 0 # number of times we find a clue lemma in the definition set 53 | for answer_ss in wn.synsets(answer): # for synset in answer 54 | defn_words = answer_ss.definition().split(" ") 55 | defn_lemma_set = set() # set of all lemmas for the definition words 56 | for w in defn_words: 57 | if len(w) >= min_word_len: 58 | defn_lemma_set |= all_lemmas_for_word(w, min_word_len) 59 | 60 | # check whether any clue lemma is in any definition lemma (i.e. 
is a synonym 61 | # of any of the definition words) for this synset 62 | for cw in clue_lemmas: 63 | if cw in defn_lemma_set: 64 | #print(f"found {cw} in defn_word_set") 65 | overlap_ct += 1 66 | 67 | return overlap_ct > 0 68 | 69 | # def in_closure_set(w1: str, w2: str, max_depth:int=3, print=False, 70 | # closure_fn=lambda x: x.hypernyms()) -> bool: 71 | # """ 72 | # Check whether w1 is in closure set (default hypernym) of w2 73 | # by computing they hypernym set of w2 (for each of its possible synsets) 74 | # and then checking whether the w1 (or any of its synonyms) matches one of those hypernyms 75 | # Args: 76 | # clue: 77 | # answer: 78 | # max_depth: how deep to go in hypernym tree 79 | # 80 | # Returns: 81 | # 82 | # """ 83 | # lookup_syns = set(map(lambda x: x.name(), 84 | # wn.synsets(w1))) 85 | # if print: pp(lookup_syns) 86 | # 87 | # all_closures = set([ss.closure(closure_fn, depth=max_depth) for ss in wn.synsets(w2)]) 88 | # all_closure_names 89 | # closure_hyp_set = set(map(lambda hyp: hyp.name(), 90 | # ss.closure(closure_fn, depth=max_depth))) 91 | # closure_hyp_set |= {ss.name()} # also include "same-level" synonyms 92 | # if print: pp(closure_hyp_set) 93 | # if not lookup_syns.isdisjoint(closure_hyp_set): 94 | # return True 95 | # 96 | # return False 97 | 98 | 99 | def in_closure_set(w1: str, w2: str, max_depth:int=3, print=False, 100 | closure_fn=lambda x: x.hypernyms()) -> bool: 101 | """ 102 | Check whether w1 is in closure set (default hypernym) of w2 103 | by computing they hypernym set of w2 (for each of its possible synsets) 104 | and then checking whether the w1 (or any of its synonyms) matches one of those hypernyms 105 | Args: 106 | clue: 107 | answer: 108 | max_depth: how deep to go in hypernym tree 109 | 110 | Returns: 111 | 112 | """ 113 | # Get all synonyms of the lookup word 114 | lookup_syns = set(wn.synsets(w1)) 115 | if print: pp(lookup_syns) 116 | 117 | # compute the closure sets of the closure word 118 | synsets = wn.synsets(w2) 119 | all_closures = set(c for ss in synsets for c in ss.closure(closure_fn, depth=max_depth)) 120 | all_closures |= set(ss for ss in synsets) # add depth 0 names 121 | if print: pp(all_closures) 122 | 123 | if not lookup_syns.isdisjoint(all_closures): 124 | return True 125 | 126 | return False -------------------------------------------------------------------------------- /decrypt/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | ##### 5 | # File locations 6 | ##### 7 | 8 | # parse root directory 9 | k_dir = Path(os.path.abspath(__file__)).parent.parent 10 | 11 | 12 | # data dirs 13 | class DataDirs: 14 | 15 | class Deits: 16 | k_deits_main = k_dir / 'deits' 17 | k_deits_clues = k_deits_main / 'clues' 18 | k_deits_outputs = k_deits_main / 'outputs' 19 | 20 | class Guardian: 21 | # json_folder = k_dir / "data/puzzles/" 22 | json_folder = k_dir / "data/guardian_2020_10_08.json" 23 | 24 | # splits json 25 | naive_random = k_dir / "data/naive_random.json" 26 | disjoint = k_dir / "data/disjoint.json" 27 | disjoint_word_init = k_dir / "data/disjoint_word_init.json" 28 | 29 | class OriginalData: 30 | k_xd_cw = k_dir / "data/original/xd/clues.tsv" 31 | k_US_dic = k_dir / "data/original/us/US.dic" 32 | 33 | # copied from deits directory 34 | k_deits_anagram_list = k_dir / "data/original/deits_anag_indic/ana_" 35 | 36 | k_names = k_dir / "data/original/names/" 37 | 38 | # cryptonite 39 | _cryptonite_original_data_dir = k_dir / 
"data/original/cryptonite" 40 | k_cryptonite_offical = _cryptonite_original_data_dir / "cryptonite-official-split" 41 | k_cryptonite_naive = _cryptonite_original_data_dir / "cryptonite-naive-split" 42 | 43 | 44 | # our generated files that are not model inputs 45 | class Generated: 46 | xd_cw_clean_json = k_dir / "data/generated/xd_clean.json" 47 | 48 | # generated from TWL06 49 | # https://github.com/fogleman/TWL06 (no license) 50 | twl_tex_dict = k_dir / "data/generated/twl_dict.txt" 51 | 52 | # anagrams 53 | anagram_db = k_dir / "data/generated/anag_db" 54 | 55 | class DataExport: 56 | _base = k_dir / "data/clue_json/" 57 | 58 | # guardian 59 | _guardian_base_dir = _base / "guardian" 60 | guardian_naive_random_split = _guardian_base_dir / "naive_random" 61 | guardian_naive_disjoint_split = _guardian_base_dir / "naive_disjoint" 62 | guardian_word_init_disjoint_split = _guardian_base_dir / "word_init_disjoint" 63 | 64 | #### 65 | # curricular 66 | _curricular = _base / "curricular" 67 | 68 | # ACW 69 | _ACW_sub_dir = "ACW_data" 70 | xd_cw_json = _curricular / _ACW_sub_dir 71 | 72 | # anagramming 73 | _anag_sub_dir = "anagram" 74 | anag_dir = _curricular / _anag_sub_dir # anagrams (train.json) and anag_indics 75 | anag_indics = anag_dir / "anag_indics.json" 76 | #### 77 | 78 | # descrambling 79 | _descramble_dir = _base / "descramble" 80 | descramble_random = _descramble_dir / "random_split" 81 | descramble_word_init_disjoint = _descramble_dir / "word_initial" 82 | 83 | # 6.4 wordplay 84 | wordplay_dir = _base / "wordplay" 85 | 86 | # 6.5 cryptonite 87 | _crypto_dir = _base / "cryptonite" 88 | crypto_naive = _crypto_dir / "naive" 89 | crypto_naive_disjoint = _crypto_dir / "official_theirs" 90 | crypto_word_init_disjoint = _crypto_dir / "word_init_disjoint" 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /decrypt/scrape_parse/__init__.py: -------------------------------------------------------------------------------- 1 | from .guardian_load import load_guardian_splits, load_guardian_splits_disjoint, load_guardian_splits_disjoint_hash -------------------------------------------------------------------------------- /decrypt/scrape_parse/acw_load.py: -------------------------------------------------------------------------------- 1 | """ 2 | for XD clue set (i.e. 
ACW, american crossword) 3 | """ 4 | 5 | import csv 6 | import logging 7 | from collections import Counter 8 | from typing import * 9 | 10 | from tqdm import tqdm 11 | 12 | from decrypt.common.puzzle_clue import BaseClue, make_stc_map, filter_clues 13 | from decrypt.common.util_spellchecker import SpellChecker 14 | 15 | logging.getLogger(__name__) 16 | 17 | 18 | def xd_load_and_filter_clues(filename, 19 | remove_if_not_in_dict=False, 20 | strip_trailing_period=True, 21 | remove_questions=True, 22 | remove_likely_abbreviations=True, 23 | remove_fillin=True, 24 | # try_word_split=False, 25 | ) -> List[BaseClue]: 26 | 27 | with open(filename, "r") as f: 28 | if remove_if_not_in_dict: 29 | sc = SpellChecker(init_twl_dict=True, 30 | init_enchant_dict=True) 31 | else: 32 | sc = None 33 | 34 | rd = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE) 35 | _ = next(rd) # skip first header line 36 | ctr = Counter() 37 | clue_list = [] 38 | 39 | for row in tqdm(rd): 40 | try: 41 | answer, clue = row[2], row[3] 42 | except: 43 | continue 44 | 45 | if answer == "" or clue == "" or len(clue) < 3: 46 | ctr["empty"] += 1 47 | continue 48 | 49 | if remove_fillin and ("_" in clue or "--" in clue): 50 | ctr["fillin"] += 1 51 | continue 52 | 53 | if remove_if_not_in_dict and not sc.check_word(answer): 54 | ctr["not_in_dict"] += 1 55 | continue 56 | 57 | # if "-Across" in c.clue or "-Down" in c.clue: 58 | if "Across" in clue or "Down" in clue: 59 | ctr["ref"] += 1 60 | continue 61 | 62 | if remove_likely_abbreviations: 63 | if clue[-1] == "." and len(answer) < 4: 64 | ctr["removed_likely_abbrev"] += 1 65 | continue 66 | 67 | if remove_questions and clue[-1] == "?": 68 | ctr["question word"] += 1 69 | continue 70 | 71 | if strip_trailing_period and clue[-1] == ".": 72 | # this was implemented wrong originally - truncated the answer rather than clue 73 | ctr["removed_trailing_period"] += 1 74 | clue = clue[:-1] 75 | # answer = answer[:-1] 76 | 77 | # should have no spaces 78 | answer = answer.replace(' ', '') 79 | 80 | c = BaseClue.from_clue_and_soln(clue, answer) 81 | clue_list.append(c) 82 | 83 | logging.info(ctr) 84 | print(ctr) 85 | logging.info(f'Filtered to {len(clue_list)} clues') 86 | return clue_list 87 | 88 | 89 | # modeled after get_clean_clues (for guardian) 90 | def get_clean_xd_clues(filename, 91 | remove_if_not_in_dict=True, 92 | do_filter_dupes=True) \ 93 | -> Tuple[Dict[str, List[BaseClue]], List[BaseClue]]: 94 | 95 | logging.info(f'loading xd (ACW) set from {filename}') 96 | all_clue_list = xd_load_and_filter_clues(filename, 97 | remove_if_not_in_dict=remove_if_not_in_dict, 98 | strip_trailing_period=True, 99 | remove_questions=True, 100 | remove_likely_abbreviations=True, 101 | remove_fillin=True) 102 | 103 | # generate soln to clue map 104 | # soln:str -> List[gc] 105 | soln_to_clue_map = make_stc_map(all_clue_list) 106 | 107 | # Remove anything that is exactly the same up to small diffs 108 | # removes 1610 normalized clues 109 | if do_filter_dupes: 110 | soln_to_clue_map, all_clue_list = filter_clues(soln_to_clue_map) 111 | 112 | # add indices and a note about dataset 113 | for idx, c in enumerate(all_clue_list): 114 | c.idx = idx 115 | c.dataset = filename 116 | 117 | # print the distribution 118 | ctr = Counter() 119 | for c in all_clue_list: 120 | ctr[len(c.lengths)] += 1 121 | logging.info(ctr) 122 | 123 | # Verify same length 124 | assert sum(map(len, soln_to_clue_map.values())) == len(all_clue_list) 125 | 126 | return soln_to_clue_map, all_clue_list 127 | 
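# Usage sketch (illustrative): load and filter the xd/ACW clue set, assuming clues.tsv has been
# placed at the path configured in decrypt/config.py (DataDirs.OriginalData.k_xd_cw) and that the
# spellchecker dictionaries (TWL / enchant) are available:
#
#   import decrypt.config as config
#
#   stc_map, clues = get_clean_xd_clues(str(config.DataDirs.OriginalData.k_xd_cw),
#                                       remove_if_not_in_dict=True,
#                                       do_filter_dupes=True)
#   print(len(clues))                      # clue count after filtering and dedup
#   print(clues[0].clue_with_lengths())    # clue text plus its length spec, e.g. "... (5)"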
-------------------------------------------------------------------------------- /decrypt/scrape_parse/guardian_scrape.py: -------------------------------------------------------------------------------- 1 | """ 2 | For scraping the guardian cryptic crosswords. 3 | 4 | # Notes on the Decrypting paper dataset 5 | # 0 -> 10000: had 11 puzzles; omitted from dataset 6 | # 10000-20000: not scraped 7 | # 20000: 28259: done (present day as of the date we ran it) 8 | """ 9 | import argparse 10 | import glob 11 | import json 12 | import logging 13 | import os 14 | import time 15 | from collections import Counter 16 | from collections import defaultdict 17 | from typing import * 18 | from typing import IO, Optional, Tuple, Dict 19 | 20 | import urllib3 21 | from bs4 import BeautifulSoup 22 | from tqdm import tqdm 23 | 24 | from decrypt.scrape_parse.util import _gen_filename 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | log = logging.getLogger(__name__) 28 | 29 | min_size = 100 * 1024 # require that webpages are at least this big (otherwise no puzzle) 30 | file_end = ".html" 31 | BASE_URL = "https://www.theguardian.com/crosswords/" 32 | k_subsite = "cryptic" 33 | k_start_idx = 21465 34 | k_end_idx = 28259 35 | 36 | 37 | ### 38 | # Fetching methods 39 | ### 40 | def _puzzle_json_from_file_or_str(input_data: Union[IO, str], ctr: Counter) -> Optional[Dict]: 41 | """ 42 | Take a file handler, fh, representing HTML and return a dict from the parsed json 43 | Returns: Optional[Dict] if the html was successfully parsed 44 | """ 45 | try: 46 | soup = BeautifulSoup(input_data, 'html.parser') 47 | cw_json_class = soup.findAll('div', {'class': 'js-crossword', 'data-crossword-data': True}) 48 | cw_json_data = cw_json_class[0]['data-crossword-data'] 49 | except Exception as e: 50 | print(e) 51 | ctr['html_fail'] += 1 52 | return None 53 | 54 | json_data = json.loads(cw_json_data) 55 | puzzle_entries = json_data.get('entries') 56 | 57 | if puzzle_entries is None: 58 | ctr['json_fail'] += 1 59 | return None 60 | 61 | return json_data 62 | 63 | 64 | def _fetch(url, ctr: Counter, sleep_time=0.2, debug=False) -> Tuple[Optional[urllib3.HTTPResponse], int]: 65 | """ 66 | Fetch from url, with retry if 429 response by sleeping in doubling increments of 0.2 67 | Returns: tuple: 68 | - resp if stat == 200, None otherwise 69 | - response code (-1 : failed to get any response, stat otherwise) 70 | """ 71 | if debug: 72 | print(f"fetching {url}") 73 | if sleep_time > 10: # base case 74 | print(f"sleeptime: {sleep_time} too large, skipping {url}") 75 | return None, -1 76 | 77 | resp = http.request("GET", url) 78 | stat = resp.status 79 | ctr[stat] += 1 80 | 81 | if resp.status == 429: # rate limit 82 | time.sleep(sleep_time) 83 | _fetch(url, ctr, sleep_time * 2, debug=debug) # recurse with higher wait 84 | 85 | if stat == 200: 86 | return resp, stat 87 | 88 | return None, stat # failed, so just return the stat 89 | 90 | 91 | def fetch_and_store_set(base_url: str, 92 | subsite: str, 93 | json_output_dir: str, 94 | db: Dict[str, int], 95 | start_idx: int, 96 | stop_idx: int, 97 | html_output_dir=None): 98 | """ 99 | Fetch all puzzles in base_url/subsite/[start_idx-stop-idx] and write the puzzle json to output_dir 100 | If html_output_dir is not None, then also store a copy of the html (not just json) 101 | """ 102 | 103 | def _fetch_and_store(idx: int) -> NoReturn: 104 | """ 105 | Fetch from BASE_URL/subsite/idx and write the html to json_files_dir/.html 106 | """ 107 | key = subsite + "/" + str(idx) 108 | url = base_url 
+ key 109 | 110 | # this would require the database to be populated. will never trigger as currently written 111 | if db is not None and db.get(key) is not None: # already checked url 112 | ctr['skipped: in db'] += 1 113 | return 114 | 115 | # check if file already exists 116 | json_outfile = _gen_filename(json_output_dir, subsite, ext=".json", idx=idx) 117 | if json_outfile in json_exists_set: 118 | ctr['skipped: json exists'] += 1 119 | return 120 | 121 | # otherwise proceed with _fetch 122 | resp, stat = _fetch(url, ctr) 123 | 124 | if resp is not None: # stat == 200 125 | # parse the json 126 | data: str = resp.data.decode('utf-8') 127 | puz_json = _puzzle_json_from_file_or_str(data, ctr) 128 | if puz_json is None: 129 | print(f"error for {url}") 130 | return 131 | 132 | # write json if valid 133 | # json_outfile = _gen_filename(json_output_dir, subsite, ext=".json", idx=idx) 134 | with open(json_outfile, "w") as f: 135 | json.dump(puz_json, f) 136 | 137 | # write html files if required 138 | if html_output_dir is not None: 139 | outfile = _gen_filename(html_output_dir, subsite, ext=".html", idx=idx) 140 | with open(outfile, "w") as f: 141 | print(data, file=f) 142 | 143 | ctr['success'] += 1 144 | ctr['totalclues'] += len(puz_json['entries']) # we verified that entires is present 145 | 146 | # record status after finish processing 147 | if db is not None: 148 | db[key] = stat 149 | 150 | # first check for json that are already downloaded 151 | file_glob_path = _gen_filename(json_output_dir, subsite=k_subsite, ext=".json", return_glob=True) 152 | log.info(f'Using file glob at {file_glob_path}') 153 | file_glob = glob.glob(file_glob_path) 154 | log.info(f'Some files already present in the output directory; these will be skipped {len(file_glob)}') 155 | json_exists_set = set(file_glob) 156 | 157 | ctr = Counter() 158 | pbar = tqdm(range(start_idx, stop_idx + 1)) 159 | for i in pbar: 160 | # pbar.set_description(f"Succ: {ctr[200]}\t Fail: {ctr[404]}\t Skip: {ctr['skipped: json exists']}\t") 161 | pbar.set_description(f"Succ: {ctr[200]}\t Fail: {ctr[404]}") 162 | _fetch_and_store(i) 163 | print(ctr) 164 | 165 | 166 | def parse_args(): 167 | parser = argparse.ArgumentParser('Guardian Scrape') 168 | 169 | parser.add_argument('--save_directory', 170 | type=str, 171 | required=True, 172 | help='Where to save the downloaded json files') 173 | 174 | return parser.parse_args() 175 | 176 | 177 | def main(): 178 | ####################################### 179 | # setup 180 | ####################################### 181 | # keep track of what has already been fetched; useful if run is stopped 182 | # todo: this would need to be written / loaded to be useful; at present does nothing 183 | # instead we check whether a given json file exists; this works as long as the URL actually exists 184 | # for invalid URLs (i.e. 
invalid puzzle indices), we will retry download each time this is run 185 | url_db = defaultdict(None) 186 | 187 | # fetch website and store all json files 188 | # if you want the HTML, you can pass html_output_dir = html_output_dir 189 | log.info(f'Fetching puzzles from indexes {k_start_idx} to {k_end_idx} inclusive') 190 | fetch_and_store_set(BASE_URL, 191 | subsite=k_subsite, 192 | json_output_dir=parsed_args.save_directory, 193 | db=url_db, start_idx=k_start_idx, stop_idx=k_end_idx, 194 | html_output_dir=None) 195 | 196 | 197 | if __name__ == "__main__": 198 | parsed_args = parse_args() 199 | 200 | if not os.path.isdir(parsed_args.save_directory): 201 | raise NotImplemented(f'Save dir {parsed_args.save_directory} does not exist') 202 | 203 | http = urllib3.PoolManager() 204 | 205 | main() 206 | -------------------------------------------------------------------------------- /decrypt/scrape_parse/make_public_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 5, 6 | "metadata": { 7 | "collapsed": true, 8 | "pycharm": { 9 | "name": "#%%\n" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stdout", 15 | "output_type": "stream", 16 | "text": [ 17 | "The autoreload extension is already loaded. To reload it, use:\n", 18 | " %reload_ext autoreload\n" 19 | ] 20 | } 21 | ], 22 | "source": [ 23 | "%load_ext autoreload\n", 24 | "%autoreload 2" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 66, 30 | "outputs": [], 31 | "source": [ 32 | "from decrypt.scrape_parse.guardian_load import orig_get_clean_clues, get_clean_clues, load_guardian_splits, load_guardian_splits_disjoint, load_guardian_splits_disjoint_hash" 33 | ], 34 | "metadata": { 35 | "collapsed": false, 36 | "pycharm": { 37 | "name": "#%%\n" 38 | } 39 | } 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "outputs": [], 45 | "source": [ 46 | "k_my_dir=\"/Users/jsrozner/JROZ/Programming/cryptic/cryptic-data/data (symlink)/puzzles/guardian_data/guardian_2020_10_08_json\"\n", 47 | "soln_to_clue_map, all_clues = orig_get_clean_clues(k_my_dir, strip_identifying_info=True)\n" 48 | ], 49 | "metadata": { 50 | "collapsed": false, 51 | "pycharm": { 52 | "name": "#%%\n" 53 | } 54 | } 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 15, 59 | "outputs": [], 60 | "source": [ 61 | "import json\n", 62 | "import dataclasses" 63 | ], 64 | "metadata": { 65 | "collapsed": false, 66 | "pycharm": { 67 | "name": "#%%\n" 68 | } 69 | } 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 39, 74 | "outputs": [], 75 | "source": [ 76 | "json_list = list(map(dataclasses.asdict, all_clues))" 77 | ], 78 | "metadata": { 79 | "collapsed": false, 80 | "pycharm": { 81 | "name": "#%%\n" 82 | } 83 | } 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 44, 88 | "outputs": [], 89 | "source": [ 90 | "class SetEncoder(json.JSONEncoder):\n", 91 | " def default(self, obj):\n", 92 | " if isinstance(obj, set):\n", 93 | " return list(obj)\n", 94 | " return json.JSONEncoder.default(self, obj)" 95 | ], 96 | "metadata": { 97 | "collapsed": false, 98 | "pycharm": { 99 | "name": "#%%\n" 100 | } 101 | } 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 48, 106 | "outputs": [], 107 | "source": [ 108 | "k_json_file=\"/Users/jsrozner/JROZ/Programming/cryptic/cryptic-data/data (symlink)/puzzles/guardian_data/guardian_2020_10_08.json\"" 109 | ], 110 | "metadata": { 111 | 
"collapsed": false, 112 | "pycharm": { 113 | "name": "#%%\n" 114 | } 115 | } 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 49, 120 | "outputs": [], 121 | "source": [ 122 | "with open(k_json_file, 'w') as f:\n", 123 | " json.dump(json_list, f, cls=SetEncoder)" 124 | ], 125 | "metadata": { 126 | "collapsed": false, 127 | "pycharm": { 128 | "name": "#%%\n" 129 | } 130 | } 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 61, 135 | "outputs": [ 136 | { 137 | "name": "stderr", 138 | "output_type": "stream", 139 | "text": [ 140 | "100%|██████████| 142380/142380 [00:00<00:00, 674483.51it/s]\n", 141 | "INFO:decrypt.scrape_parse.guardian_load:Counter({1: 118540, 2: 20105, 3: 2929, 4: 686, 5: 112, 6: 8})\n", 142 | "INFO:decrypt.scrape_parse.guardian_load:Clue list length matches Decrypting paper expected length\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "## testing: verify loading\n", 148 | "stc_map, all_clues2 = get_clean_clues(k_json_file)\n", 149 | "for idx in range(len(all_clues)):\n", 150 | " for k, v in dataclasses.asdict(all_clues[idx]).items():\n", 151 | " if k == 'idx': continue\n", 152 | " assert v == all_clues2[idx].__getattribute__(k), v\n", 153 | "\n", 154 | "print('all match!')" 155 | ], 156 | "metadata": { 157 | "collapsed": false, 158 | "pycharm": { 159 | "name": "#%%\n" 160 | } 161 | } 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 67, 166 | "outputs": [ 167 | { 168 | "name": "stderr", 169 | "output_type": "stream", 170 | "text": [ 171 | "100%|██████████| 142380/142380 [00:00<00:00, 1093641.27it/s]\n", 172 | "INFO:decrypt.scrape_parse.guardian_load:Counter({1: 118540, 2: 20105, 3: 2929, 4: 686, 5: 112, 6: 8})\n", 173 | "INFO:decrypt.scrape_parse.guardian_load:Clue list length matches Decrypting paper expected length\n", 174 | "INFO:decrypt.scrape_parse.guardian_load:Got splits of lenghts [85428, 28476, 28476]\n", 175 | "INFO:decrypt.scrape_parse.guardian_load:First three clues of train set:\n", 176 | "\t[CleanGuardianClue(clue='Suffering to grasp edge of plant', lengths=[8], soln='agrimony', soln_with_spaces='agrimony', idx=85002, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Chifonie', orig_lengths='8', lengths_punctuation=set()), CleanGuardianClue(clue='Honour Ben and Noel with new order', lengths=[7], soln='ennoble', soln_with_spaces='ennoble', idx=3432, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Rufus', orig_lengths='7', lengths_punctuation=set()), CleanGuardianClue(clue='Bit the royal we love? 
Cheers!', lengths=[4], soln='iota', soln_with_spaces='iota', idx=25530, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Screw', orig_lengths='4', lengths_punctuation=set())]\n", 177 | "INFO:decrypt.scrape_parse.guardian_load:Verifying splits match Decrypting paper: Spot test clue 5111 has correct text\n", 178 | "100%|██████████| 142380/142380 [00:00<00:00, 1156593.35it/s]\n", 179 | "INFO:decrypt.scrape_parse.guardian_load:Counter({1: 118540, 2: 20105, 3: 2929, 4: 686, 5: 112, 6: 8})\n", 180 | "INFO:decrypt.scrape_parse.guardian_load:Clue list length matches Decrypting paper expected length\n", 181 | "INFO:decrypt.scrape_parse.guardian_load:Got splits of lenghts [85149, 28710, 28521]\n", 182 | "INFO:decrypt.scrape_parse.guardian_load:First three clues of train set:\n", 183 | "\t[CleanGuardianClue(clue='Unrivalled heavyweight caught in first moment?', lengths=[6, 2, 4], soln='secondtonone', soln_with_spaces='second to none', idx=101709, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Picaroon', orig_lengths='6,2,4', lengths_punctuation={','}), CleanGuardianClue(clue='Real sexist entertainer on the inside', lengths=[8], soln='existent', soln_with_spaces='existent', idx=48990, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Brummie', orig_lengths='8', lengths_punctuation=set()), CleanGuardianClue(clue='Assumes toads move around quietly', lengths=[6], soln='adopts', soln_with_spaces='adopts', idx=54280, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Rufus', orig_lengths='6', lengths_punctuation=set())]\n", 184 | "100%|██████████| 142380/142380 [00:00<00:00, 194191.48it/s]\n", 185 | "INFO:decrypt.scrape_parse.guardian_load:Counter({1: 118540, 2: 20105, 3: 2929, 4: 686, 5: 112, 6: 8})\n", 186 | "INFO:decrypt.scrape_parse.guardian_load:Clue list length matches Decrypting paper expected length\n", 187 | "100%|██████████| 142380/142380 [00:00<00:00, 242245.08it/s]\n", 188 | "INFO:decrypt.scrape_parse.guardian_load:Got splits of lenghts [75847, 32628, 33905]\n", 189 | "INFO:decrypt.scrape_parse.guardian_load:First three clues of train set:\n", 190 | "\t[CleanGuardianClue(clue='Sailor boy in his hammock', lengths=[4], soln='abed', soln_with_spaces='abed', idx=34809, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Rufus', orig_lengths='4', lengths_punctuation=set()), CleanGuardianClue(clue='With a degree, I leave this subject', lengths=[5], soln='maths', soln_with_spaces='maths', idx=412, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Rufus', orig_lengths='5', lengths_punctuation=set()), CleanGuardianClue(clue='Burrow to cure limb and make sure one gets up', lengths=[3, 3, 5], soln='setthealarm', soln_with_spaces='set the alarm', idx=116809, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Araucaria', orig_lengths='3,3,5', lengths_punctuation={','})]\n" 191 | ] 192 | } 193 | ], 194 | "source": [ 195 | "## testing: extra verify step\n", 196 | "tup = load_guardian_splits(k_json_file)\n", 197 | "tup = load_guardian_splits_disjoint(k_json_file)\n", 198 | "tup = load_guardian_splits_disjoint_hash(k_json_file)\n", 199 | "\n", 200 | "\n", 201 | "\n" 202 | ], 203 | "metadata": { 204 | "collapsed": false, 205 | "pycharm": { 
206 | "name": "#%%\n" 207 | } 208 | } 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "outputs": [], 214 | "source": [], 215 | "metadata": { 216 | "collapsed": false, 217 | "pycharm": { 218 | "name": "#%%\n" 219 | } 220 | } 221 | } 222 | ], 223 | "metadata": { 224 | "kernelspec": { 225 | "display_name": "Python 3", 226 | "language": "python", 227 | "name": "python3" 228 | }, 229 | "language_info": { 230 | "codemirror_mode": { 231 | "name": "ipython", 232 | "version": 2 233 | }, 234 | "file_extension": ".py", 235 | "mimetype": "text/x-python", 236 | "name": "python", 237 | "nbconvert_exporter": "python", 238 | "pygments_lexer": "ipython2", 239 | "version": "2.7.6" 240 | } 241 | }, 242 | "nbformat": 4, 243 | "nbformat_minor": 0 244 | } -------------------------------------------------------------------------------- /decrypt/scrape_parse/make_public_data_datasets.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "- This file created after make_public_data\n", 7 | "- Use to produce final datasets" 8 | ], 9 | "metadata": { 10 | "collapsed": false, 11 | "pycharm": { 12 | "name": "#%% md\n" 13 | } 14 | } 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": { 20 | "collapsed": true, 21 | "pycharm": { 22 | "name": "#%%\n" 23 | } 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "%load_ext autoreload\n", 28 | "%autoreload 2" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "outputs": [], 35 | "source": [ 36 | "from decrypt.scrape_parse.guardian_load import orig_get_clean_clues, get_clean_clues, load_guardian_splits, load_guardian_splits_disjoint, load_guardian_splits_disjoint_hash" 37 | ], 38 | "metadata": { 39 | "collapsed": false, 40 | "pycharm": { 41 | "name": "#%%\n" 42 | } 43 | } 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "outputs": [], 49 | "source": [ 50 | "import json\n", 51 | "import dataclasses" 52 | ], 53 | "metadata": { 54 | "collapsed": false, 55 | "pycharm": { 56 | "name": "#%%\n" 57 | } 58 | } 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "outputs": [], 64 | "source": [ 65 | "class SetEncoder(json.JSONEncoder):\n", 66 | " def default(self, obj):\n", 67 | " if isinstance(obj, set):\n", 68 | " return list(obj)\n", 69 | " return json.JSONEncoder.default(self, obj)" 70 | ], 71 | "metadata": { 72 | "collapsed": false, 73 | "pycharm": { 74 | "name": "#%%\n" 75 | } 76 | } 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 5, 81 | "outputs": [], 82 | "source": [ 83 | "k_data_dir =\"/Users/jsrozner/JROZ/Programming/cryptic/decrypt/data/\"\n", 84 | "k_json_file = k_data_dir + \"guardian_2020_10_08.json\"" 85 | ], 86 | "metadata": { 87 | "collapsed": false, 88 | "pycharm": { 89 | "name": "#%%\n" 90 | } 91 | } 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 21, 96 | "outputs": [], 97 | "source": [ 98 | "def to_dict(json_blob):\n", 99 | " return list(map(dataclasses.asdict, json_blob))\n", 100 | "\n", 101 | "def write_split_to_json(fname, fcn):\n", 102 | " split_tuple = fcn(k_json_file)\n", 103 | " train, val, test = split_tuple[2]\n", 104 | " json_dict = dict(\n", 105 | " train = to_dict(train),\n", 106 | " val = to_dict(val),\n", 107 | " test = to_dict(test)\n", 108 | " )\n", 109 | " with open(k_data_dir + fname, 'w') as f:\n", 110 | " json.dump(json_dict, f, cls=SetEncoder)" 111 | ], 112 | "metadata": 
{ 113 | "collapsed": false, 114 | "pycharm": { 115 | "name": "#%%\n" 116 | } 117 | } 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "outputs": [], 123 | "source": [ 124 | "write_split_to_json('naive_random.json', load_guardian_splits)\n", 125 | "write_split_to_json('disjoint.json', load_guardian_splits_disjoint)\n", 126 | "write_split_to_json('disjoint_word_init.json', load_guardian_splits_disjoint_hash)\n", 127 | "\n" 128 | ], 129 | "metadata": { 130 | "collapsed": false, 131 | "pycharm": { 132 | "name": "#%%\n" 133 | } 134 | } 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 24, 139 | "outputs": [ 140 | { 141 | "name": "stderr", 142 | "output_type": "stream", 143 | "text": [ 144 | "100%|██████████| 142380/142380 [00:00<00:00, 846181.54it/s]\n", 145 | "INFO:decrypt.scrape_parse.guardian_load:Counter({1: 118540, 2: 20105, 3: 2929, 4: 686, 5: 112, 6: 8})\n", 146 | "INFO:decrypt.scrape_parse.guardian_load:Clue list length matches Decrypting paper expected length\n", 147 | "INFO:decrypt.scrape_parse.guardian_load:Got splits of lenghts [85428, 28476, 28476]\n", 148 | "INFO:decrypt.scrape_parse.guardian_load:First three clues of train set:\n", 149 | "\t[CleanGuardianClue(clue='Suffering to grasp edge of plant', lengths=[8], soln='agrimony', soln_with_spaces='agrimony', idx=-1, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Chifonie', orig_lengths='8', lengths_punctuation=set()), CleanGuardianClue(clue='Honour Ben and Noel with new order', lengths=[7], soln='ennoble', soln_with_spaces='ennoble', idx=-1, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Rufus', orig_lengths='7', lengths_punctuation=set()), CleanGuardianClue(clue='Bit the royal we love? 
Cheers!', lengths=[4], soln='iota', soln_with_spaces='iota', idx=-1, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Screw', orig_lengths='4', lengths_punctuation=set())]\n", 150 | "100%|██████████| 142380/142380 [00:00<00:00, 898765.90it/s]\n", 151 | "INFO:decrypt.scrape_parse.guardian_load:Counter({1: 118540, 2: 20105, 3: 2929, 4: 686, 5: 112, 6: 8})\n", 152 | "INFO:decrypt.scrape_parse.guardian_load:Clue list length matches Decrypting paper expected length\n", 153 | "INFO:decrypt.scrape_parse.guardian_load:Got splits of lenghts [85149, 28710, 28521]\n", 154 | "INFO:decrypt.scrape_parse.guardian_load:First three clues of train set:\n", 155 | "\t[CleanGuardianClue(clue='Unrivalled heavyweight caught in first moment?', lengths=[6, 2, 4], soln='secondtonone', soln_with_spaces='second to none', idx=-1, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Picaroon', orig_lengths='6,2,4', lengths_punctuation={','}), CleanGuardianClue(clue='Real sexist entertainer on the inside', lengths=[8], soln='existent', soln_with_spaces='existent', idx=-1, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Brummie', orig_lengths='8', lengths_punctuation=set()), CleanGuardianClue(clue='Assumes toads move around quietly', lengths=[6], soln='adopts', soln_with_spaces='adopts', idx=-1, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Rufus', orig_lengths='6', lengths_punctuation=set())]\n", 156 | "100%|██████████| 142380/142380 [00:00<00:00, 606420.18it/s]\n", 157 | "INFO:decrypt.scrape_parse.guardian_load:Counter({1: 118540, 2: 20105, 3: 2929, 4: 686, 5: 112, 6: 8})\n", 158 | "INFO:decrypt.scrape_parse.guardian_load:Clue list length matches Decrypting paper expected length\n", 159 | "INFO:decrypt.scrape_parse.guardian_load:Got splits of lenghts [75847, 32628, 33905]\n", 160 | "INFO:decrypt.scrape_parse.guardian_load:First three clues of train set:\n", 161 | "\t[CleanGuardianClue(clue='Sailor boy in his hammock', lengths=[4], soln='abed', soln_with_spaces='abed', idx=-1, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Rufus', orig_lengths='4', lengths_punctuation=set()), CleanGuardianClue(clue='With a degree, I leave this subject', lengths=[5], soln='maths', soln_with_spaces='maths', idx=-1, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Rufus', orig_lengths='5', lengths_punctuation=set()), CleanGuardianClue(clue='Burrow to cure limb and make sure one gets up', lengths=[3, 3, 5], soln='setthealarm', soln_with_spaces='set the alarm', idx=-1, dataset='', across_or_down='', pos=(0, 0), unique_clue_id='', type='cryptic', number=0, id='', creator='Araucaria', orig_lengths='3,3,5', lengths_punctuation={','})]\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "tup = load_guardian_splits(None)\n", 167 | "tup = load_guardian_splits_disjoint(None)\n", 168 | "tup = load_guardian_splits_disjoint_hash(None)\n", 169 | "\n" 170 | ], 171 | "metadata": { 172 | "collapsed": false, 173 | "pycharm": { 174 | "name": "#%%\n" 175 | } 176 | } 177 | } 178 | ], 179 | "metadata": { 180 | "kernelspec": { 181 | "display_name": "Python 3", 182 | "language": "python", 183 | "name": "python3" 184 | }, 185 | "language_info": { 186 | "codemirror_mode": { 187 | "name": "ipython", 188 | "version": 2 189 | }, 190 | 
"file_extension": ".py", 191 | "mimetype": "text/x-python", 192 | "name": "python", 193 | "nbconvert_exporter": "python", 194 | "pygments_lexer": "ipython2", 195 | "version": "2.7.6" 196 | } 197 | }, 198 | "nbformat": 4, 199 | "nbformat_minor": 0 200 | } -------------------------------------------------------------------------------- /decrypt/scrape_parse/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | 4 | 5 | def _gen_filename(base_dir: str, subsite: str, ext: str, idx: int = None, return_glob=False): 6 | """ 7 | :param base_dir: 8 | :param subsite: 9 | :param ext: 10 | :param idx: 11 | :param return_glob: 12 | :return: 13 | """ 14 | filename = os.path.join(base_dir, subsite) 15 | if return_glob: 16 | filename += "*" + ext 17 | else: 18 | filename += str(idx) + ext 19 | return filename 20 | 21 | # normal hash function is not deterministic across python runs 22 | def hash(input: str): 23 | hash_obj = hashlib.md5(input.encode()) 24 | return hash_obj.hexdigest() 25 | 26 | def str_hash(input: str) -> int: 27 | hex_hash = hash(input) 28 | return int(hex_hash, 16) 29 | 30 | 31 | -------------------------------------------------------------------------------- /experiments/curricular.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "## Setting up for curricular experiment\n", 12 | "\n", 13 | "This assumes you have already followed the instructions in `baselines/baseline_t5`, which will set up the baseline clue files for model input\n", 14 | "\n", 15 | "### Datasets\n", 16 | "1. Download and unzip the xd cw crossword set from http://xd.saul.pw/xd-clues.zip.\n", 17 | " - Save it as './data/original/xd/clues.tsv'\n", 18 | "2. Preprocess the dataset using this notebook\n", 19 | "3. The dataset will be saved to k_acw_export_dir (as a single train.json file)\n", 20 | "4. 
We will also produce the anagram dataset\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": { 27 | "pycharm": { 28 | "name": "#%%\n" 29 | } 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "%load_ext autoreload\n", 34 | "%autoreload 2\n", 35 | "\n", 36 | "from decrypt.scrape_parse.acw_load import get_clean_xd_clues\n", 37 | "from decrypt import config\n", 38 | "from decrypt.common.util_data import clue_list_tuple_to_train_split_json\n", 39 | "from decrypt.common import validation_tools as vt\n", 40 | "\n", 41 | "k_xd_orig_tsv = config.DataDirs.OriginalData.k_xd_cw # ./data/original/xd/clues.tsv\n", 42 | "k_acw_export_dir = config.DataDirs.DataExport.xd_cw_json" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "pycharm": { 50 | "name": "#%%\n" 51 | } 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "# defaults to strip periods, remove questions, remove abbrevs, remove fillin\n", 56 | "stc_map, all_clues = get_clean_xd_clues(k_xd_orig_tsv,\n", 57 | " remove_if_not_in_dict=False,\n", 58 | " do_filter_dupes=True)\n", 59 | "clue_list_tuple_to_train_split_json((all_clues,),\n", 60 | " comment='ACW set; xd cw set, all',\n", 61 | " export_dir=k_acw_export_dir,\n", 62 | " overwrite=False)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": { 69 | "pycharm": { 70 | "name": "#%%\n" 71 | } 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "# produce anagram datasets\n", 76 | "# roughly 3 minutes to complete\n", 77 | "from decrypt.common import anagrammer\n", 78 | "anagrammer.gen_db_with_both_inputs(update_flag=\"overwrite\")\n", 79 | "\n", 80 | "from decrypt.common.util_data import (\n", 81 | " get_anags,\n", 82 | " write_json_tuple\n", 83 | ")\n", 84 | "import json\n", 85 | "import os" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "pycharm": { 93 | "name": "#%%\n" 94 | } 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "def make_anag_sets_json():\n", 99 | " all_anags = get_anags(max_num_words=-1)\n", 100 | " json_list = []\n", 101 | " for idx, a_list in enumerate(all_anags):\n", 102 | " json_list.append(dict(idx=idx,\n", 103 | " anag_list=a_list))\n", 104 | " print(json_list[0])\n", 105 | "\n", 106 | " # normally would be (idx, input, tgt)\n", 107 | " output_tuple = [json_list,]\n", 108 | "\n", 109 | " os.makedirs(config.DataDirs.DataExport.anag_dir)\n", 110 | " write_json_tuple(output_tuple,\n", 111 | " comment=\"List of all anagram groupings\",\n", 112 | " export_dir=config.DataDirs.DataExport.anag_dir,\n", 113 | " overwrite=False)\n", 114 | "\n", 115 | "def make_anag_indic_list_json():\n", 116 | " # make the indicator list\n", 117 | " with open(config.DataDirs.OriginalData.k_deits_anagram_list, 'r') as f:\n", 118 | " all_anag_indicators = f.readlines()\n", 119 | " print(len(all_anag_indicators))\n", 120 | "\n", 121 | " final_indic_list = []\n", 122 | " for a in all_anag_indicators:\n", 123 | " final_indic_list.append(a.replace('_', \" \").strip())\n", 124 | " with open(config.DataDirs.DataExport.anag_indics, 'w') as f:\n", 125 | " json.dump(final_indic_list,f)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "pycharm": { 133 | "name": "#%%\n" 134 | } 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "make_anag_sets_json()" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | 
"pycharm": { 146 | "name": "#%%\n" 147 | } 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "make_anag_indic_list_json()\n", 152 | "\n" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": { 158 | "pycharm": { 159 | "name": "#%% md\n" 160 | } 161 | }, 162 | "source": [ 163 | "## Curricular training\n", 164 | "1. At this point you should have a files at\n", 165 | " - `./data/clue_json/curricular/ACW/train.json`\n", 166 | " - `./data/clue_json/curricular/anagram/[train.json, anag_indics.json]`\n", 167 | "\n", 168 | "2. Running curricular training is the same as running main t5 vanilla train, except that we pass an extra multitask flag, which specifies the curriculum to use. See `seq2seq/multitask_config`. You should pass one of the names from `multi_config` dict in that file\n", 169 | "\n", 170 | "For example, to train the naive split with the top performing curricular approach (i.e. the result in table 3 that is ACW + ACW-descramble)\n", 171 | "```python\n", 172 | "python train_clues.py --default_train=base --name=naive_top_curricular --project=curricular --wandb_dir='./wandb' --data_dir='../data/clue_json/guardian/naive_random' --multitask=ACW__ACW_descramble\n", 173 | "```\n", 174 | "\n", 175 | "Note that the modifications on the dataset are done at the\n", 176 | "\n", 177 | "3. To produce Table 3 of the results\n", 178 | " - we don't need to do a model_eval run since the outputted predictions have 5 generations\n", 179 | " (which is all we report for that table (for faster experimental iteration).\n", 180 | " - we need to run `load_and_run_t5` on all outputs (column 1) and on the anagram subset (column 2)\n", 181 | " See below for how we do this.\n", 182 | "\n", 183 | "4. For our top result in Table 2 (main resuls) we\n", 184 | " 1. scale up the curricular period (to 4 total epochs)\n", 185 | "```python\n", 186 | "python train_clues.py --default_train=base --name=naive_top_curricular --project=curricular --wandb_dir='./wandb' --data_dir='../data/clue_json/guardian/naive_random' --multitask=final_top_result_scaled_up\n", 187 | "```\n", 188 | " 2. eval with full 100 generations, as before:\n", 189 | "e.g., if epoch 10 is best (you'll need to set the run_name)\n", 190 | "This runs the eval set (change the run_name)\n", 191 | "```python\n", 192 | "python train_clues.py --default_val=base --name=curricular_naive_top --project=curricular --data_dir='../data/clue_json/guardian/naive_random' --ckpt_path='./wandb/run_name/files/epoch_10.pth.tar\n", 193 | "```\n" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": { 200 | "pycharm": { 201 | "name": "#%%\n" 202 | } 203 | }, 204 | "outputs": [], 205 | "source": [ 206 | "from decrypt.common.label_anagrams import make_label_set\n", 207 | "\n", 208 | "labels = make_label_set()" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": { 215 | "pycharm": { 216 | "name": "#%%\n" 217 | } 218 | }, 219 | "outputs": [], 220 | "source": [ 221 | "# note that this should be run directly on the top model output from curricular training\n", 222 | "# otherwise (eg. 
if 100 beams were used), the top 5 output\n", 223 | "# sequences would be expected to change\n", 224 | "# remember not to append .json\n", 225 | "\n", 226 | "# eval on the full output (5 beams / 5 sequences)\n", 227 | "# this is column 1 of table 3\n", 228 | "vt.load_and_run_t5('outputs/model_output.preds',\n", 229 | " # pre_truncate=5, # should not be needed since we have only 5 outputs\n", 230 | " do_length_filter=True)\n", 231 | "\n", 232 | "# run on the anagram subset\n", 233 | "# this is column 2 of table 3\n", 234 | "vt.load_and_run_t5('outputs/model_output.preds',\n", 235 | " filter_fcn=vt.make_set_filter(labels, 'anag_direct'),\n", 236 | " # pre_truncate=5,\n", 237 | " do_length_filter=True)\n", 238 | "\n", 239 | "# we are looking at agg_top_match (which is after filter)" 240 | ] 241 | } 242 | ], 243 | "metadata": { 244 | "kernelspec": { 245 | "display_name": "Python 3", 246 | "language": "python", 247 | "name": "python3" 248 | }, 249 | "language_info": { 250 | "codemirror_mode": { 251 | "name": "ipython", 252 | "version": 2 253 | }, 254 | "file_extension": ".py", 255 | "mimetype": "text/x-python", 256 | "name": "python", 257 | "nbconvert_exporter": "python", 258 | "pygments_lexer": "ipython2", 259 | "version": "2.7.6" 260 | } 261 | }, 262 | "nbformat": 4, 263 | "nbformat_minor": 0 264 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bs4 2 | scikit-learn 3 | tqdm 4 | 5 | -------------------------------------------------------------------------------- /seq2seq/args_cryptic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Shortcut args for cryptics 3 | """ 4 | import argparse 5 | 6 | def add_args(parser: argparse.ArgumentParser): 7 | """ 8 | Always need to specify 9 | --name (run name) 10 | --project (for wandb) 11 | --data_dir (see config.py for where outputs are stored) 12 | --wandb_dir 13 | 14 | for curricular training, need to specify 15 | --multitask 16 | 17 | for eval (--default_eval) 18 | --ckpt_path - will be loaded 19 | --test - whether to run on the test set 20 | """ 21 | 22 | ## shortcuts 23 | parser.add_argument('--dev_run', 24 | action='store_true', 25 | help='sets wandb env mode to dry run, reduces number of train/val to 1000') 26 | 27 | parser.add_argument('--default_train', 28 | type=str, 29 | default=None, 30 | help='Default configurations:' 31 | 'Set to either base or cryptonite') 32 | parser.add_argument('--default_val', 33 | type=str, 34 | default=None, 35 | help='Various default configurations:' 36 | 'Set to either base or cryptonite') 37 | #### 38 | parser.add_argument('--name', 39 | type=str, 40 | required=True, 41 | help='Name to identify training or test run.' 42 | 'Will be used by wandb to identify run') 43 | parser.add_argument('--data_dir', 44 | type=str, 45 | required=True, 46 | help='Directory containing the train, val, test files') 47 | parser.add_argument('--ckpt_path', 48 | type=str, 49 | default=None, 50 | help='If given then will load from given model path. ' 51 | 'You should either set no_train (val only), ' 52 | 'or explicitly specify whether resuming train with --resume_train or ' 53 | '--no_resume_train') 54 | 55 | parser.add_argument('--resume_train', 56 | action='store_true', 57 | dest='resume_train', 58 | help='whether to resume train with optimizer and scheduler state.' 
59 | 'Need to pass a ckpt_path') 60 | parser.add_argument('--no_resume_train', 61 | action='store_false', 62 | dest='resume_train') 63 | parser.set_defaults(resume_train=None) # will be false for eval 64 | 65 | parser.add_argument('--no_train', 66 | action='store_true', 67 | help='Set to no_train when we want to do eval only.') 68 | 69 | parser.add_argument('--test', 70 | action='store_true', 71 | help='Eval will be done on test rather than val set') 72 | 73 | ## new 74 | parser.add_argument('--ada_constant', 75 | action='store_true', 76 | help='Whether to use constant LR with adafactor. Used for t5-large training') 77 | parser.add_argument('--multi_gpu', 78 | type=int, 79 | default=None, 80 | help='Whether to use dataparallel with multiple gpus.' 81 | 'Also need to set k_data_parallel=True in train_abc') 82 | 83 | # defaults that are auto set but that will vary / depend on --default_eval and 84 | # these will have defaults set (varies for train / eval) 85 | parser.add_argument('--project', 86 | type=str, 87 | required=True) 88 | parser.add_argument('--num_epochs', 89 | type=int, 90 | default=15, # set to 1 by default_eval 91 | help='Number of epochs for which to train. Negative means forever.' 92 | ' i.e. all training data this many times through') 93 | parser.add_argument('--generation_beams', 94 | type=int, 95 | default=5, # default eval 100 96 | help='Number of beams (and return sequences) to use in generation') 97 | parser.add_argument('--batch_size', 98 | type=int, 99 | default=128, 100 | help='Batch size per GPU. Scales automatically when \ 101 | multiple GPUs are available.') 102 | 103 | # defaults for both 104 | parser.add_argument('--model_name', 105 | type=str, 106 | default='t5-base', 107 | help='which t5 model to load, e.g. t5-small, t5-base') 108 | parser.add_argument('--wandb_dir', 109 | type=str, 110 | required=True, 111 | help='Directory in which to add folder wandb for storing all files') 112 | parser.set_defaults(do_sample=True) 113 | parser.add_argument('--no-save', dest='do_save', action='store_false') 114 | parser.set_defaults(do_save=True) # will be false for eval 115 | parser.set_defaults(batched_dl=True) 116 | parser.set_defaults(fast_tokenizer=True) 117 | parser.set_defaults(ada=True) 118 | parser.set_defaults(add_special_tokens=False) 119 | parser.set_defaults(use_json=True) 120 | 121 | # other multitask arguments are in multitaskconfig 122 | parser.add_argument('--multitask', 123 | type=str, 124 | help='Whether to do multitask training. To do multitask, provide a ' 125 | 'update cfg/multi_cfg with a new object. Specify that config by str') 126 | 127 | # don't modify these 128 | parser.add_argument('--num_train', 129 | type=int, 130 | default=-1, 131 | help='Number of train examples to consider. Will reduce dataset. Neg ignores') 132 | parser.add_argument('--num_val', 133 | type=int, 134 | default=-1, 135 | help='Number of val examples to consider. Will reduce dataset. Neg ignores') 136 | parser.add_argument('--multitask_num', 137 | type=int, 138 | default=-1, 139 | help='number of multitask examples to use in the dataloader; -1 is all.' 
140 | 'Generally used only by dev_run to speed up') 141 | parser.add_argument('--num_workers', 142 | type=int, 143 | default=4, 144 | help='Number of sub-processes to use per data loader.') 145 | parser.add_argument('--seed', 146 | type=int, 147 | default=42, 148 | help='Random seed for reproducibility.') 149 | parser.add_argument('--val_freq', 150 | type=int, 151 | default=None, 152 | help='How often to do validation in thousands of steps (must be fewer than # examples)') 153 | parser.add_argument('--early_stopping', 154 | type=str, 155 | default=None, 156 | help='Metric to track for early stopping, assumes lower is better') 157 | parser.add_argument('--grad_accum_steps', 158 | type=int, 159 | default=1, 160 | help='Number of batches to accumulate') 161 | # default to don't use 162 | parser.add_argument('--comment', 163 | type=str, 164 | default="", 165 | help='A comment to store with the run') 166 | 167 | 168 | def get_args(extra_args_fn=None): 169 | """Get arguments needed in train.py.""" 170 | 171 | # parse the args 172 | parser = argparse.ArgumentParser('Train T5 for decrypting cryptic crosswords') 173 | add_args(parser) 174 | if extra_args_fn is not None: 175 | extra_args_fn(parser) 176 | args = parser.parse_args() 177 | 178 | ### 179 | # misc validation 180 | if args.dev_run or args.no_train: 181 | args.do_save = False 182 | args.val_freq = None 183 | # note that example counts for --dev_run will be changed (reduced) in pre_setup function 184 | # see train_abc 185 | 186 | ### 187 | 188 | ### 189 | # default training / val 190 | assert args.default_train is None or args.default_val is None, \ 191 | f'Cannot do both default_train and default_val simultaneously' 192 | if args.default_train is not None: 193 | if args.default_train == 'base': 194 | # args.project = "cryptics_train" 195 | # ada_constant is False (i.e. 
we use relative step) 196 | args.generation_beams = 5 197 | args.batch_size = 256 # alternatively can do 128 and accum_steps=2 198 | # grad_accum_steps = 1 199 | # args.num_epochs = 15 200 | # default model is t5-base 201 | 202 | elif args.default_train == 'cryptonite': 203 | # args.project = "cryptonite" 204 | args.ada_constant = True 205 | args.generation_beams = 5 206 | args.batch_size = 64 207 | args.grad_accum_steps = 12 208 | # args.num_epochs = 15 # for the naive split, can train to 20 epochs 209 | args.model_name = 't5-large' 210 | # args.val_freq = 100 # set to 100 for the disjoint set 211 | else: 212 | raise NotImplemented 213 | 214 | if args.default_val is not None: 215 | assert args.ckpt_path is not None, \ 216 | f'To run default eval, need to pass a model checkpoint with --ckpt_path' 217 | args.no_train = True 218 | args.val_freq = None 219 | args.resume_train = False 220 | args.num_epochs = 1 # just needed for args validation 221 | args.do_save = False # no checkpointing 222 | 223 | if args.default_val == 'base': 224 | # args.project = "cryptics_val" 225 | args.generation_beams = 100 226 | args.batch_size = 16 # can change this depending on your GPU; doesn't affect results 227 | # default model is t5-base 228 | elif args.default_val == 'cryptonite': 229 | # default for cryptonite eval 230 | # args.project = "cryptonite" 231 | args.generation_beams = 5 # cryptonite originally used only 5 beams, copy their implementation 232 | args.batch_size = 64 # can change depending on GPU; doesn't affect val resulst 233 | args.model_name = 't5-large' 234 | else: 235 | raise NotImplemented(f'Invalid option {args.default_val} for --default_val ') 236 | ###### 237 | 238 | return args 239 | -------------------------------------------------------------------------------- /seq2seq/common_seq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsrozner/decrypt/71679f76ce19cd098b0988d0255b03d7e47f0c05/seq2seq/common_seq/__init__.py -------------------------------------------------------------------------------- /seq2seq/common_seq/collate_fns.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | 4 | from .types import * 5 | from .util_dataloader_batch import default_collate_fn_json 6 | 7 | log = logging.getLogger(__name__) 8 | 9 | 10 | # collate function factory 11 | def collate_fn_from_pretokenize(pretokenize_fn: Callable) -> collate_fn_type: 12 | def coll_fn(tokenizer: PreTrainedTokenizerFast, batch_list: List[Dict]) -> Dict: 13 | return default_collate_fn_json(tokenizer, batch_list, pre_tokenize_fn=pretokenize_fn) 14 | return coll_fn 15 | 16 | 17 | def _add_label(orig_input: str, label: str): 18 | return f'{label}: {orig_input}' 19 | 20 | 21 | def make_pretokenize_prepend_label(label: str) -> pretokenize_fn: 22 | def pre_tokenize_prepend_label(batch_list: List[Dict]) -> Tuple[List, ...]: 23 | src_text, tgt_text, idxs = [], [], [] 24 | for e in batch_list: 25 | input = e['input'] 26 | tgt = e['target'] 27 | idx = e['idx'] 28 | 29 | # input = f'{label}: {input}' 30 | input = _add_label(input, label) 31 | 32 | src_text.append(input) 33 | tgt_text.append(tgt) 34 | idxs.append(idx) 35 | 36 | return src_text, tgt_text, idxs 37 | 38 | return pre_tokenize_prepend_label 39 | 40 | 41 | 42 | # note that casing will be off slightly on this - it will always be lower case 43 | def make_pretokenize_descramble(label: Optional[str], word_only: bool = False): 44 | 
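# The pre-tokenize function returned below builds the descramble auxiliary task:
# each target's letters are shuffled with a fixed-seed RNG, the trailing "(lengths)"
# token is peeled off the clue and re-appended at the end, the clue's first character
# is lowercased, and the scrambled word is either used alone (word_only=True) or
# spliced before/after the clue text at random. The target itself is left unchanged,
# and an optional task label is prepended via _add_label.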
rng = random.Random(42) 45 | 46 | def randomize_letters(s: str) -> str: 47 | x = list(s) 48 | rng.shuffle(x) 49 | return "".join(x) 50 | 51 | def pre_tokenize_descramble(batch_list: List[Dict]): 52 | src_text, tgt_text, idxs = [], [], [] 53 | for e in batch_list: 54 | input = e['input'] 55 | tgt = e['target'] 56 | idx = e['idx'] 57 | 58 | tgt_scrambled = randomize_letters(tgt) 59 | 60 | # parse length string (and move to end) 61 | splits = input.split(' ') 62 | assert splits[-1][0] == '(' 63 | input_no_len = ' '.join(splits[:-1]) 64 | input_no_len_lower = input_no_len[0].lower() + input_no_len[1:] 65 | len_str = splits[-1] 66 | 67 | if word_only: 68 | input = f'{tgt_scrambled} {len_str}' 69 | else: 70 | if rng.randint(0, 1) == 0: 71 | input = f'{tgt_scrambled} {input_no_len_lower} {len_str}' 72 | else: 73 | input = f'{input_no_len_lower} {tgt_scrambled} {len_str}' 74 | 75 | # finalize 76 | if label is not None: 77 | input = _add_label(input, label) 78 | src_text.append(input) 79 | tgt_text.append(tgt) 80 | idxs.append(idx) 81 | 82 | return src_text, tgt_text, idxs 83 | return pre_tokenize_descramble 84 | 85 | ## for anagramming 86 | # note that casing will be off slightly on this - it will always be lower case 87 | import json 88 | def make_pretokenize_anagram(label: Optional[str], 89 | anag_indic_file: str): 90 | 91 | logging.info(f'Opening {anag_indic_file} for anag indicators') 92 | with open(anag_indic_file, 'r') as f: 93 | anag_indics = json.load(f) 94 | 95 | rng = random.Random(42) 96 | 97 | 98 | def pre_tokenize_descramble(batch_list: List[Dict]): 99 | src_text, tgt_text, idxs = [], [], [] 100 | anag_indic_sampled = random.choices(anag_indics, k=len(batch_list)) # with replacement 101 | 102 | for e, anag_indic in zip(batch_list, anag_indic_sampled): 103 | anag_list = e['anag_list'] # list of words mapping to common set of letters 104 | idx = e['idx'] 105 | 106 | choices = random.sample(anag_list, 2) # no replacement 107 | lhs, rhs = tuple(choices) 108 | 109 | # add lengths (target is the rhs, so lengths uses the rhs) 110 | lengths = list(map(lambda x: str(len(x)), 111 | rhs.split(' '))) 112 | len_str = f'({",".join(lengths)})' 113 | 114 | # input and target 115 | if rng.randint(0, 1) == 0: 116 | input = f'{lhs} {anag_indic} {len_str}' 117 | else: 118 | input = f'{anag_indic} {lhs} {len_str}' 119 | tgt = rhs 120 | 121 | # finalize 122 | if label is not None: 123 | input = _add_label(input, label) 124 | src_text.append(input) 125 | tgt_text.append(tgt) 126 | idxs.append(idx) 127 | 128 | return src_text, tgt_text, idxs 129 | return pre_tokenize_descramble 130 | 131 | -------------------------------------------------------------------------------- /seq2seq/common_seq/types.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | from transformers import PreTrainedTokenizerFast 4 | 5 | collate_fn_type = Callable[[PreTrainedTokenizerFast, List[Dict]], Dict] 6 | pretokenize_fn = Callable[[List[Dict]], Tuple[List, ...]] 7 | -------------------------------------------------------------------------------- /seq2seq/common_seq/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Substantially adapted from squad code 3 | 4 | """ 5 | from __future__ import annotations 6 | 7 | import logging 8 | import os 9 | import random 10 | import socket 11 | from dataclasses import dataclass 12 | from typing import Dict, Tuple, List, Optional 13 | 14 | import numpy as np 15 | import torch 16 | import 
tqdm 17 | import wandb 18 | 19 | log = logging.getLogger(__name__) 20 | 21 | @dataclass 22 | class ProcessedBatch: 23 | src_ids: torch.Tensor 24 | src_mask: torch.Tensor 25 | tgt_ids: torch.Tensor 26 | orig_text_input: List[str] 27 | orig_text_output: List[str] 28 | 29 | batch_size: int 30 | idxs: Optional[torch.Tensor] = None 31 | 32 | 33 | @dataclass 34 | class PerBatchValStep: 35 | """ 36 | Produced for each batch during a val call; Use to pass outputs around to metrics 37 | """ 38 | loss_val: Optional[float] = None 39 | outputs_greedy: Optional[List[str]] = None 40 | outputs_greedy_ids: Optional[torch.Tensor] = None 41 | outputs_sampled: Optional[List[List[str]]] = None 42 | 43 | 44 | ### 45 | # Useful misc utilities 46 | ### 47 | def symlink_dir(wb_run_obj, readable_name): 48 | path_spec = '{wandb_dir}/' + readable_name + '-{timespec}-{run_id}' 49 | sym_path = wb_run_obj._settings._path_convert(path_spec) 50 | 51 | # link {wandb_dir}/{run_mode}-{timespec}-{run_id}/files -> 52 | # {wandb_dir}/run_name-{timespec}-{run_id} 53 | os.symlink(wb_run_obj.dir, sym_path) 54 | log.info(f'sym: {sym_path} -> {wb_run_obj.dir}') 55 | 56 | def get_available_devices(assert_cuda=False) -> Tuple: 57 | """Get IDs of all available GPUs. 58 | 59 | Returns: 60 | device (torch.device): Main device (GPU 0 or CPU). 61 | gpu_ids (list): List of IDs of all GPUs that are available. 62 | """ 63 | gpu_ids = [] 64 | if torch.cuda.is_available(): 65 | gpu_ids += [gpu_id for gpu_id in range(torch.cuda.device_count())] 66 | if len(gpu_ids) > 1: 67 | log.warning("more than 1 gpu found") 68 | 69 | assert gpu_ids[0] == 0 70 | device = torch.device(f'cuda:{gpu_ids[0]}') # cuda:0 71 | 72 | # torch.cuda.set_device(device) # removed 4/15/2021 73 | else: 74 | if assert_cuda: 75 | raise ValueError('no cuda found') 76 | device = torch.device('cpu') 77 | 78 | log.info(f"Device: {device}\t GPU IDs: {gpu_ids}\t machine: {socket.gethostname()}\n") 79 | 80 | return device, gpu_ids 81 | 82 | 83 | def config_logger(logger, log_dir, log_level="debug", filename="log.txt"): 84 | """ 85 | 86 | :param logger: 87 | :return: 88 | """ 89 | if len(logger.handlers): 90 | log.warning(f'Logger had handlers already set WTF\n' 91 | f'..... CLEARING') 92 | logger.handlers.clear() 93 | 94 | class StreamHandlerWithTQDM(logging.Handler): 95 | """Let `logging` print without breaking `tqdm` progress bars. 
96 | 97 | See Also: 98 | > https://stackoverflow.com/questions/38543506 99 | """ 100 | def emit(self, record): 101 | try: 102 | msg = self.format(record) 103 | tqdm.tqdm.write(msg) 104 | self.flush() 105 | except (KeyboardInterrupt, SystemExit): 106 | raise 107 | except: 108 | self.handleError(record) 109 | 110 | # Create logger 111 | if log_level == "debug": 112 | logger.setLevel(logging.DEBUG) 113 | elif log_level == "info": 114 | logger.setLevel(logging.INFO) 115 | else: 116 | raise ValueError(f"Invalid log level {log_level}") 117 | 118 | # Log everything (i.e., DEBUG level and above) to a file 119 | log_path = os.path.join(log_dir, filename) 120 | file_handler = logging.FileHandler(log_path) 121 | file_handler.setLevel(logging.DEBUG) 122 | 123 | # Log everything except DEBUG level (i.e., INFO level and above) to console 124 | console_handler = StreamHandlerWithTQDM() 125 | console_handler.setLevel(logging.INFO) 126 | 127 | # Create format for the logs 128 | file_formatter = logging.Formatter('[%(asctime)s] %(message)s', 129 | datefmt='%m.%d.%y %H:%M:%S') 130 | file_handler.setFormatter(file_formatter) 131 | # console_formatter = logging.Formatter('[%(asctime)s] %(message)s', 132 | # datefmt='%m.%d.%y %H:%M:%S') 133 | console_formatter = logging.Formatter( 134 | '[%(asctime)s] [%(filename)s:%(lineno)s - %(funcName)s()]\t %(message)s', 135 | datefmt='%m.%d %H:%M:%S') 136 | console_handler.setFormatter(console_formatter) 137 | 138 | # add the handlers to the logger 139 | logger.addHandler(file_handler) 140 | logger.addHandler(console_handler) 141 | 142 | 143 | def get_logger(log_dir, name, log_level="debug", filename="log.txt"): 144 | """Get a `logging.Logger` instance that prints to the console 145 | and an auxiliary file. 146 | 147 | Args: 148 | log_dir (str): Directory in which to create the log file. 149 | name (str): Name to identify the logs. 150 | 151 | Returns: 152 | logger (logging.Logger): Logger instance for logging events. 153 | """ 154 | logger = logging.getLogger(name) 155 | config_logger(logger, log_dir, log_level, filename) 156 | return logger 157 | 158 | 159 | def set_seed(seed=42): 160 | log.info("Setting seed") 161 | random.seed(seed) 162 | np.random.seed(seed) 163 | torch.manual_seed(seed) 164 | if torch.cuda.is_available(): 165 | torch.cuda.manual_seed_all(seed) 166 | # todo: cuda deterministic? 167 | 168 | 169 | class AverageMeter: 170 | """Keep track of average values over time. 171 | 172 | Adapted from: 173 | > https://github.com/pytorch/examples/blob/master/imagenet/main.py 174 | """ 175 | def __init__(self): 176 | self.avg = 0 177 | self.sum = 0 178 | self.count = 0 179 | 180 | def reset(self): 181 | """Reset meter.""" 182 | self.__init__() 183 | 184 | def update_sum_direct(self, num_succ, num_samples): 185 | self.count += num_samples 186 | self.sum += num_succ 187 | self.avg = self.sum/self.count 188 | 189 | def update(self, val: float, num_samples=1): 190 | """Update meter with new value `val`, the average of `num` samples. 191 | 192 | Args: 193 | val (float): Average value to update the meter with. 194 | num_samples (int): Number of samples that were averaged to 195 | produce `val`. 196 | """ 197 | self.count += num_samples 198 | self.sum += val * num_samples 199 | self.avg = self.sum / self.count 200 | 201 | ### 202 | # Tensorboard and other save functions 203 | ### 204 | 205 | def log_scalar(name: str, value, step=None, 206 | log_wandb=True): 207 | """ 208 | 1/13:2021: some calls use actual step; some calls use epoch. 
verify that wandb can handle this 209 | # this might fail with tbx since the step will vary considerably across runs 210 | """ 211 | if log_wandb: 212 | wandb.log({name: value}, step=step) 213 | 214 | def log_wandb_new(log_dict: Dict, 215 | use_step_for_logging: bool, 216 | step: int, 217 | epoch: int): 218 | """ 219 | 1/13:2021: some calls use actual step; some calls use epoch. verify that wandb can handle this 220 | # this might fail with tbx since the step will vary considerably across runs 221 | """ 222 | # todo: switch over to this once we are done with this project 223 | if use_step_for_logging: 224 | step_for_logging = step 225 | log_dict['epoch'] = epoch 226 | else: 227 | step_for_logging = epoch 228 | log_dict['all_step'] = step 229 | wandb.log(log_dict, step=step_for_logging) 230 | 231 | # step_for_logging = epoch 232 | # log_dict['true_step'] = step 233 | # wandb.log(log_dict, step=step_for_logging) 234 | 235 | def log_wandb(log_dict: Dict, 236 | step: Optional[int]=None): 237 | """ 238 | 1/13:2021: some calls use actual step; some calls use epoch. verify that wandb can handle this 239 | # this might fail with tbx since the step will vary considerably across runs 240 | """ 241 | if step is not None: 242 | wandb.log(log_dict, step=step) 243 | else: 244 | wandb.log(log_dict) 245 | 246 | def save_preds(preds: List[Tuple[str,str,str]], save_dir, file_name, epoch): 247 | """Save predictions `preds` to a CSV file named `file_name` in `save_dir`. 248 | 249 | Args: 250 | preds (list): List of predictions each of the form (source, target, actual), 251 | save_dir (str): Directory in which to save the predictions file. 252 | file_name (str): File name for the CSV file. 253 | 254 | Returns: 255 | save_path (str): Path where CSV file was saved. 256 | """ 257 | # Validate format 258 | # if (not isinstance(preds, list) 259 | # or any(not isinstance(p, tuple) or len(p) != 3 for p in preds)): 260 | # raise ValueError('preds must be a list of tuples (id, start, end)') 261 | 262 | # Make sure predictions are sorted by ID 263 | # preds = sorted(preds, key=lambda p: p[0]) 264 | 265 | # Save to a CSV file 266 | save_path = os.path.join(save_dir, f'{file_name}_{epoch}.csv') 267 | np.savetxt(save_path, np.array(preds), delimiter='\t', fmt='%s') 268 | 269 | return save_path 270 | -------------------------------------------------------------------------------- /seq2seq/common_seq/util_checkpoint.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from json import JSONEncoder 5 | from typing import TypedDict, Dict, Optional, List, Tuple, Union 6 | 7 | import numpy 8 | import torch 9 | from transformers import PreTrainedModel 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | class NumpyArrayEncoder(JSONEncoder): 14 | def default(self, obj): 15 | if isinstance(obj, numpy.ndarray): 16 | return obj.tolist() 17 | return JSONEncoder.default(self, obj) 18 | 19 | # todo: should this be a dataclass 20 | class CheckpointDict(TypedDict): 21 | model_state: Dict 22 | optimizer: Dict 23 | scheduler: Optional[Dict] 24 | config: Dict 25 | step: int 26 | epoch: int 27 | 28 | 29 | def load_ckpt(path: str, model: PreTrainedModel, map_location=None, 30 | log_info=True) -> CheckpointDict: 31 | ckpt_dict: CheckpointDict = torch.load(path, map_location=map_location) 32 | model.load_state_dict(ckpt_dict['model_state']) 33 | if log_info: 34 | try: 35 | with open(f'{path}.txt') as f: 36 | log.info(f'Loading model from 
{path}:\n{f.readlines()}') 37 | except: 38 | pass 39 | return ckpt_dict 40 | 41 | 42 | class CheckpointSaver: 43 | """Class to save and load model checkpoints. 44 | 45 | Save the best checkpoints as measured by a metric value passed into the 46 | `save` method. Overwrite checkpoints with better checkpoints once 47 | `max_checkpoints` have been saved. 48 | 49 | Args: 50 | """ 51 | 52 | def __init__(self, save_dir, 53 | metrics_to_track: List[Tuple[str, bool]], 54 | save_most_recent=True): 55 | if save_most_recent: 56 | metrics_to_track.append(('epoch', True)) 57 | 58 | self.save_dir = save_dir 59 | # Maps metric_name => (maximize_metric, best_val, saved_path, symlink) 60 | self.best_vals: Dict[str, Tuple[bool, Optional[int], Optional[str], Optional[str]]] = \ 61 | self._init_best_vals(metrics_to_track) 62 | # Maps path to number of pointers (i.e. the number of metrics that still have this as best model 63 | self.saved_models: Dict[str, int] = dict() 64 | 65 | log.info(f'Saver will track (metric, maximize?)\n {metrics_to_track}') 66 | 67 | def _init_best_vals(self, metrics_to_track: List[Tuple[str, bool]]): 68 | best_vals = dict() 69 | for metric, maximize_metric in metrics_to_track: 70 | best_vals[metric] = (maximize_metric, None, None, None) 71 | return best_vals 72 | 73 | def _dump_json(self, filename, object): 74 | with open(filename, 'w') as fp: 75 | json.dump(object, fp, cls=NumpyArrayEncoder) 76 | 77 | def save_if_best(self, 78 | epoch: float, 79 | trainer, 80 | metric_dict: Dict[str, Union[int, float]], 81 | preds: Optional[List[Tuple]] = None, 82 | save_model: bool = True): 83 | """ 84 | Save model and outputs 85 | 86 | When a single epoch / model checkpoint maximizes multiple metrics, we will save only one time, 87 | keeping pointers and garbage collecting when that epoch no longer maximizes any metric 88 | 89 | :param epoch: 90 | :param trainer: 91 | :param metric_dict: 92 | :param preds: 93 | :param save_model: whether to save the actual model (otherwise saves only outputs) 94 | :return: Noreturn 95 | """ 96 | 97 | if preds is None and not save_model: 98 | log.warning(f'Nothing to save (no preds and not saving model)') 99 | return 100 | 101 | # file extensions that will be generated; used for autoremoval 102 | file_list = [".txt", ".preds.json"] 103 | if save_model: 104 | file_list.append("") # the base file is also written 105 | 106 | metric_dict.update({"epoch": epoch}) # always saves on most recent epoch 107 | 108 | def save_most_recent(): 109 | """ 110 | Actually save the model and write all the files 111 | """ 112 | checkpoint_path = os.path.join(self.save_dir, f'epoch_{epoch}.pth.tar') 113 | log.info(f'Saving most recent model at epoch={epoch} to {checkpoint_path}') 114 | self.saved_models[checkpoint_path] = 0 # record that we are tracking this path 115 | 116 | # do the actual saving 117 | readme_dict = metric_dict.copy() 118 | readme_dict.update(dict(name=trainer.config.name)) 119 | self._dump_json(checkpoint_path + ".txt", readme_dict) 120 | if save_model: 121 | ckpt_dict = trainer.make_ckpt_dict() 122 | torch.save(ckpt_dict, checkpoint_path) 123 | if preds is not None: 124 | self._dump_json(checkpoint_path + ".preds.json", preds) 125 | return checkpoint_path 126 | 127 | model_save_path = None # we will only save once per checkpoint call 128 | # but potentially multiple metrics will point to this 129 | 130 | for metric_name, (maximize_metric, best_val, prev_path, prev_sym_path) in self.best_vals.items(): 131 | # prev_path tracks the reference that we are using for this 
metric 132 | new_val = metric_dict.get(metric_name, None) 133 | if new_val is None: # nothing to save since metric was not reported 134 | continue 135 | if best_val is not None: # then need to compare 136 | if maximize_metric and not new_val > best_val: 137 | continue 138 | if not maximize_metric and not new_val < best_val: 139 | continue 140 | 141 | # we should save 142 | log.info(f"Best metric for {metric_name} = {new_val} at epoch={epoch}") 143 | if prev_path is not None: 144 | self.saved_models[prev_path] -= 1 # first decrement pointer to previous 145 | # remove the symlinks 146 | for ext in file_list: 147 | try: os.remove(prev_sym_path + ext) 148 | except OSError: pass 149 | 150 | # if this is the first metric maximized by this, then we need to actually save 151 | # otherwise, we're just adding another pointer 152 | if model_save_path is None: 153 | model_save_path = save_most_recent() 154 | 155 | # increment tracking on this path 156 | self.saved_models[model_save_path] += 1 # increment the counter to new one 157 | metric_name_safe = metric_name.replace("/", "_") 158 | 159 | # prepare the sym link 160 | sym_path = os.path.join(self.save_dir, f'ckpt_{metric_name_safe}_{new_val:.2f}_{epoch}.pth.tar') 161 | 162 | # record that we saved this 163 | self.best_vals[metric_name] = (maximize_metric, new_val, model_save_path, sym_path) 164 | 165 | # actually make the symlinks 166 | for ext in file_list: 167 | try: 168 | if os.path.isfile(model_save_path + ext): # e.g. if we didn't save preds 169 | os.symlink(model_save_path + ext, sym_path + ext) 170 | except: pass 171 | # todo: i think we were saving twice before 172 | # self._dump_json(sym_path + ".txt", metric_dict) 173 | # if preds is not None: 174 | # self._dump_json(sym_path + ".preds.json", preds) 175 | 176 | # Garbage collect any stale model pointers 177 | # note that we previously removed the symlinks above, this is just the actual files now 178 | for k in list(self.saved_models.keys()): 179 | v = self.saved_models[k] 180 | if v == 0: 181 | for ext in file_list: 182 | try: os.remove(k + ext) 183 | except OSError: 184 | log.warning(f'Failed to remove checkpoint {k}') 185 | self.saved_models.pop(k) 186 | -------------------------------------------------------------------------------- /seq2seq/common_seq/util_dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import dataclasses 4 | import json 5 | import logging 6 | from dataclasses import dataclass 7 | 8 | from transformers import PreTrainedTokenizer 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | # todo: we should not doubly track the warmup state both in multitask trainer and in trainer 13 | 14 | 15 | @dataclass 16 | class DataLoaderConfig: 17 | """ 18 | Config to be used for setting up a DataLoader 19 | """ 20 | shuffle: bool = True 21 | batch_size: int = 64 22 | num_workers: int = 4 23 | 24 | use_json: bool = False 25 | 26 | 27 | @dataclass 28 | class DatasetConfig: 29 | """ 30 | Config to be used for setting up DataSet 31 | """ 32 | tokenizer: PreTrainedTokenizer 33 | max_examples: int = -1 # if not -1, will truncate 34 | # src_len: int = 100 35 | # tgt_len: int = 20 36 | 37 | 38 | # support json encoding of dataclass 39 | class EnhancedJSONEncoder(json.JSONEncoder): 40 | def default(self, o): 41 | if dataclasses.is_dataclass(o): 42 | return dataclasses.asdict(o) 43 | elif callable(o): 44 | return o.__name__ 45 | return super().default(o) 46 | 
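`EnhancedJSONEncoder` lets the dataclass configs above be dumped straight to JSON (dataclasses become dicts; callables are recorded by `__name__`), e.g. when logging a run's configuration. A minimal sketch, assuming `DataLoaderConfig` and `EnhancedJSONEncoder` are in scope (the exact import path depends on how the code is launched):

```python
import json

# Serialize a dataloader config; the encoder converts the dataclass to a dict
# and would record any callable field by its __name__.
cfg = DataLoaderConfig(shuffle=False, batch_size=16, num_workers=0, use_json=True)
print(json.dumps(cfg, cls=EnhancedJSONEncoder))
# -> {"shuffle": false, "batch_size": 16, "num_workers": 0, "use_json": true}
```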
-------------------------------------------------------------------------------- /seq2seq/common_seq/util_dataloader_batch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import logging 5 | import os 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from .types import * 12 | from .util_dataloader import DatasetConfig, DataLoaderConfig 13 | 14 | log = logging.getLogger(__name__) 15 | 16 | 17 | class ClueDataLoaderBatched(DataLoader): 18 | dataset: ClueDatasetBatched 19 | 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.__post_init_check() 23 | 24 | # @property 25 | def num_examples(self): 26 | return len(self.dataset) 27 | 28 | def __post_init_check(self): 29 | for idx, batch in enumerate(self): 30 | if idx > 0: 31 | break 32 | inputs = batch["source_text"] 33 | targets = batch["target_text"] 34 | log.info(f'Dataloader:\n\t' 35 | f'{inputs[0]} => {targets[0]}') 36 | 37 | 38 | @dataclass 39 | class DataSetEntry: 40 | src: str 41 | tgt: str 42 | idx: Optional[int] 43 | 44 | 45 | class ClueDatasetBatched(Dataset): 46 | def __init__(self, 47 | dataset_config: DatasetConfig, 48 | data_dir: str, 49 | type_path): 50 | """ 51 | max_examples: if > 0 then will load only max_examples into the dataset 52 | """ 53 | self.use_json = False 54 | valid_type_paths = ["test", "train", "val"] 55 | assert type_path in valid_type_paths, f"Type path must be one of {valid_type_paths}" 56 | log.info(f'Loading cluedatasetbatched of type {type_path}') 57 | 58 | self.example_path = Path(data_dir) / type_path 59 | self.max_examples = dataset_config.max_examples 60 | 61 | # metric scoring additions 62 | # self.orig_clue_idxs = None # the original indices of the examples; set in _build 63 | 64 | # populated later 65 | self._len = None # the total number of examples 66 | 67 | self.data_list: Optional[List[DataSetEntry]] = None 68 | 69 | if type_path == "train": # hacky way to print only once per dataset, since we always have train 70 | try: 71 | with open(Path(data_dir) / "README.txt") as f: 72 | log.info('For dataset, found readme: ') 73 | log.info(f.readlines()) 74 | except: 75 | log.info('No readme for dataset') 76 | pass 77 | self._build() # fill inputs, targets, max_lens 78 | 79 | def __len__(self): 80 | return self._len 81 | 82 | def __getitem__(self, index): 83 | return self.data_list[index] 84 | 85 | def _build_from_json(self): 86 | path = self.example_path.with_suffix(".json") 87 | 88 | with open(path, 'r') as f: 89 | all_json = json.load(f) 90 | ex_ct = len(all_json) 91 | if self.max_examples > 0: 92 | ex_ct = min(self.max_examples, ex_ct) 93 | 94 | self._len = ex_ct 95 | self.data_list = all_json[:ex_ct] 96 | 97 | def _build(self): 98 | if os.path.isfile(self.example_path.with_suffix(".json")): 99 | # log.info('Json files found, so using them') 100 | self.use_json = True 101 | self._build_from_json() 102 | else: 103 | raise NotImplementedError(f'No json files found at {self.example_path}') 104 | 105 | @classmethod 106 | def from_config(cls, cfg: DatasetConfig, 107 | data_dir: str, 108 | type_path: str): 109 | return cls(dataset_config=cfg, 110 | data_dir=data_dir, 111 | type_path=type_path) 112 | 113 | pretokenize_fn = Callable[[List[Dict]], Tuple[List,...]] 114 | 115 | def default_pretokenize(batch_list: List[Dict]) -> Tuple[List,...]: 116 | src_text = [e['input'] for e in batch_list] 117 
| tgt_text = [e['target'] for e in batch_list] 118 | idxs = [e['idx'] for e in batch_list] 119 | return src_text, tgt_text, idxs 120 | 121 | def default_collate_fn_json(tokenizer: PreTrainedTokenizerFast, batch_list: List[Dict], 122 | pre_tokenize_fn: pretokenize_fn = None) -> Dict: 123 | if pre_tokenize_fn is not None: 124 | src_text, tgt_text, idxs = pre_tokenize_fn(batch_list) 125 | else: 126 | src_text, tgt_text, idxs = default_pretokenize(batch_list) 127 | 128 | tokenized_inputs = tokenizer(src_text, padding='longest', return_tensors='pt') 129 | tokenized_outputs = tokenizer(tgt_text, padding='longest', return_tensors='pt') 130 | 131 | source_ids = tokenized_inputs["input_ids"] 132 | target_ids = tokenized_outputs["input_ids"] 133 | src_mask = tokenized_inputs["attention_mask"] # might need to squeeze 134 | target_mask = tokenized_outputs["attention_mask"] # might need to squeeze 135 | 136 | # We cast these to torch.long in preprocess batch in trainer (# todo: is this right?) 137 | ret = {"source_ids": source_ids, 138 | "source_mask": src_mask, 139 | "target_ids": target_ids, 140 | "target_mask": target_mask, 141 | "source_text": src_text, 142 | "target_text": tgt_text, 143 | "idxs": idxs} 144 | 145 | return ret 146 | 147 | def _get_dataloader_from_dataset( 148 | tokenizer, 149 | dataset: ClueDatasetBatched, 150 | dl_config: DataLoaderConfig, 151 | inputted_collate_fn: collate_fn_type) \ 152 | -> ClueDataLoaderBatched: 153 | # inputted_collate_fn: Optional[Callable[[PreTrainedTokenizerFast, List[DataSetEntry]], Dict]] = None) \ 154 | 155 | # take care of currying the appropriate collation function 156 | if dl_config.use_json: 157 | default_coll = default_collate_fn_json 158 | else: 159 | raise NotImplemented 160 | # default_coll = default_collate_fn 161 | if inputted_collate_fn is not None: 162 | def curried_collate_fn(input_list) -> Dict: 163 | return inputted_collate_fn(tokenizer, input_list) 164 | else: 165 | def curried_collate_fn(input_list) -> Dict: 166 | return default_coll(tokenizer, input_list) 167 | collate_fn = curried_collate_fn 168 | 169 | dataloader = ClueDataLoaderBatched(dataset, 170 | batch_size=dl_config.batch_size, 171 | shuffle=dl_config.shuffle, 172 | num_workers=dl_config.num_workers, 173 | collate_fn=collate_fn) 174 | log.info(f'Dataloader loaded from dataset') 175 | return dataloader 176 | 177 | 178 | def _get_dataloader_batched( 179 | tokenizer, 180 | dataset_config: DatasetConfig, 181 | dl_config: DataLoaderConfig, 182 | data_dir, 183 | type_path: str, 184 | label_fn: Optional[Callable] = None, 185 | clue_to_idx_map=None, 186 | inputted_collate_fn: Optional[Callable[[PreTrainedTokenizerFast, List[DataSetEntry]], Dict]] = None) \ 187 | -> ClueDataLoaderBatched: 188 | 189 | if label_fn is not None: 190 | # needed because we need the token offsets 191 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 192 | 193 | # set up dataset 194 | data_set = ClueDatasetBatched(dataset_config, 195 | data_dir=data_dir, 196 | type_path=type_path) 197 | log.info(f'Dataset {type_path} loaded with size: {len(data_set)}') 198 | 199 | # setup dataloader 200 | return _get_dataloader_from_dataset(tokenizer, data_set, dl_config, inputted_collate_fn) 201 | 202 | 203 | def get_dataloaders_batched(tokenizer, 204 | dataset_config_train: DatasetConfig, 205 | dataset_config_val: DatasetConfig, 206 | dl_config_train: DataLoaderConfig, 207 | dl_config_val: DataLoaderConfig, 208 | data_dir, 209 | label_fn: Optional[Callable] = None, 210 | clue_to_idx_map: Optional[Dict[str, int]] = None, 
211 | collate_fns: Optional[List[Callable]] = None, 212 | use_test_set: bool = False) -> Tuple[ClueDataLoaderBatched, ClueDataLoaderBatched]: 213 | """ 214 | 215 | :param tokenizer: 216 | :param dataset_config: 217 | :param dl_config_train: 218 | :param dl_config_val: 219 | :param data_dir: 220 | :param label_fn: 221 | :param clue_to_idx_map: 222 | :param collate_fns: Two collate functions, one for each of the dataloaders 223 | :return: 224 | """ 225 | if collate_fns is None: 226 | collate_fns = [None, None] 227 | assert len(collate_fns) == 2 228 | train_loader = _get_dataloader_batched(tokenizer, 229 | dataset_config_train, 230 | dl_config_train, 231 | data_dir, 232 | type_path="train", 233 | label_fn=label_fn, 234 | clue_to_idx_map=clue_to_idx_map, 235 | inputted_collate_fn=collate_fns[0]) 236 | if use_test_set: 237 | val_path = "test" 238 | else: 239 | val_path = "val" 240 | eval_loader = _get_dataloader_batched(tokenizer, 241 | dataset_config_val, 242 | dl_config_val, 243 | data_dir, 244 | type_path=val_path, 245 | label_fn=label_fn, 246 | clue_to_idx_map=clue_to_idx_map, 247 | inputted_collate_fn=collate_fns[1]) 248 | 249 | return train_loader, eval_loader 250 | -------------------------------------------------------------------------------- /seq2seq/common_seq/util_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from collections import Counter 5 | from typing import Tuple, List, Dict, Set, Callable, Optional, Union, Any 6 | 7 | import torch 8 | 9 | from .util import ProcessedBatch, PerBatchValStep 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | 14 | # todo: metrics should be callable metaclass function 15 | class MetricsDict: 16 | def __init__(self, avg_metrics=None, no_avg_metrics=None): 17 | if avg_metrics is None: 18 | avg_metrics = dict() 19 | if no_avg_metrics is None: 20 | no_avg_metrics = dict() 21 | self.avg_metrics = avg_metrics 22 | self.no_avg_metrics = no_avg_metrics 23 | 24 | 25 | MetricFcn = Callable[[PerBatchValStep, ProcessedBatch], 26 | Union[MetricsDict, 27 | Tuple[MetricsDict, torch.Tensor]]] 28 | 29 | 30 | class MetricsPredsWrapper: 31 | """ 32 | If label is given, all metrics will be prefixed with the label 33 | 34 | For individual metrics, label should be tied to the val set and passed with update 35 | 36 | """ 37 | 38 | def __init__(self, metrics_dict: Optional[MetricsDict] = None, 39 | label: str = "", 40 | avg_divisor: Optional[int] = None): 41 | if metrics_dict is None: 42 | metrics_dict = MetricsDict() 43 | self.md = metrics_dict # should be accessed only via get_all_metrics() 44 | 45 | self.preds: List[Tuple[str, str, str, Any, ...]] = [] # input, target, greedy, sampled 46 | self.label = "" 47 | if label != "": 48 | self.label = label + "/" 49 | 50 | self.avg_divisor = avg_divisor 51 | 52 | def get_all_metrics(self, avg_divisor: Optional[int] = None) -> Tuple[str, float, float]: 53 | """ 54 | Return the k, value (averaged if necessary) and the original value 55 | """ 56 | assert avg_divisor is not None or self.avg_divisor is not None 57 | if avg_divisor is None: 58 | avg_divisor = self.avg_divisor 59 | 60 | for k, v in self.md.avg_metrics.items(): 61 | yield self.label + k, v / avg_divisor, v 62 | for k, v in self.md.no_avg_metrics.items(): 63 | yield self.label + k, v, v 64 | 65 | def get_all_metrics_dict(self) -> Dict[str, float]: 66 | ret_dict = dict() 67 | for k, _, v in self.get_all_metrics(): 68 | ret_dict[k] = v 69 | return ret_dict 
70 | 71 | def update_for_batch(self, 72 | metric_fcns: List[MetricFcn], 73 | valstep_batch: PerBatchValStep, 74 | pbatch: ProcessedBatch, 75 | metric_label: str = ""): 76 | if metric_label != "": 77 | metric_label = metric_label + "/" 78 | 79 | # update predictions 80 | if pbatch.idxs is not None: 81 | preds = list(zip(pbatch.idxs, # will be populated for json DL, or for idxs provided file 82 | pbatch.orig_text_input, 83 | pbatch.orig_text_output, 84 | valstep_batch.outputs_greedy, 85 | valstep_batch.outputs_sampled)) 86 | else: # todo(json): deprecate this 87 | raise NotImplementedError 88 | # preds = list(zip(pbatch.orig_text_input, 89 | # pbatch.orig_text_output, 90 | # valstep_batch.outputs_greedy, 91 | # valstep_batch.outputs_sampled)) 92 | self.preds.extend(preds) 93 | 94 | # update metrics 95 | for f in metric_fcns: 96 | result = f(valstep_batch, pbatch) 97 | if type(result) == tuple: 98 | raise NotImplementedError('no longer support result tuple') 99 | # new_metrics_dict, correct_indices = result 100 | # if correct_indices is not None: 101 | # preds_correct = [self.preds['greedy'][i] for i in correct_indices.tolist()] 102 | # self.preds[f'{f.__name__}_correct'].extend(preds_correct) 103 | else: 104 | new_metrics_dict = result 105 | 106 | self.update(new_metrics_dict, metric_label) 107 | 108 | # todo: should be internal only method 109 | def update(self, new_dict: MetricsDict, label=""): 110 | # need to iterate since there are floats (otherwise use counter) 111 | for k, v in new_dict.avg_metrics.items(): 112 | self.md.avg_metrics[label + k] = self.md.avg_metrics.get(label + k, 0) + v 113 | for k, v in new_dict.no_avg_metrics.items(): 114 | self.md.no_avg_metrics[label + k] = self.md.no_avg_metrics.get(label + k, 0) + v 115 | 116 | def add_val(self, key, val, avg: bool, label=""): 117 | if label != "": 118 | label = label + "/" 119 | if avg: 120 | self.md.avg_metrics[label + key] = val 121 | else: 122 | self.md.no_avg_metrics[label + key] = val 123 | 124 | 125 | # see test_util for test verification 126 | 127 | # from https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Dice%27s_coefficient#Python 128 | 129 | def _remove_spaces(s: str) -> str: 130 | return s.replace(' ','').strip() 131 | 132 | def compute_metrics_sampled_primary(*args) -> MetricsDict: 133 | return compute_metrics_sampled(*args, 134 | primary_only=True) 135 | 136 | def compute_metrics_sampled(valstep_batch: PerBatchValStep, 137 | pbatch: ProcessedBatch, 138 | label_sets: Optional[List[Tuple[str, Set[int]]]] = None, 139 | primary_only: bool = False) -> MetricsDict: 140 | sampled_outputs: List[List[str]] = valstep_batch.outputs_sampled 141 | assert sampled_outputs is not None 142 | greedy_outputs: List[str] = valstep_batch.outputs_greedy 143 | tgt_outputs: List[str] = pbatch.orig_text_output 144 | idxs = pbatch.idxs # optional tensor 145 | 146 | ######### 147 | # special checks on idxs and label_sets when we are doing label metrics 148 | if label_sets is not None: 149 | assert idxs is not None 150 | assert label_sets is not None and len(label_sets) > 0 151 | assert type(idxs[0].item()) is int, f'{idxs[:10]}' 152 | # label_list = list(label_set) 153 | # assert type(label_list[0]) is int, f'{label_list[:10]}' 154 | label_counters = Counter() 155 | idxs = idxs.tolist() 156 | # do this so that we can zip it in the for loop, even if we're not doing it 157 | if idxs is None: 158 | idxs = [-1] * len(greedy_outputs) 159 | ########### 160 | 161 | # aggregates 162 | ct_greedy = 0 # num times tgt matches greedy (after lowercase) 163 |
ct_sampled = 0 # num times tgt in the sample set 164 | ct_top_sampled = 0 # num where, after length filter, the top answer is correct 165 | ct_top_5_sampled = 0 166 | # num_lost = 0 # when greedy gets it but sample doesn't 167 | 168 | cum_num_correct_length = 0.0 # will turn into pct correct by divide by num_seq 169 | cum_num_words_correct = 0.0 170 | 171 | num_sampled = len(sampled_outputs[0]) 172 | 173 | # iterate through batch 174 | for g, sample_list, t, clue_idx in zip(greedy_outputs, sampled_outputs, tgt_outputs, idxs): 175 | # lower case, remove spaces, strip 176 | sample_list: List[str] = list(map(lambda x: x.lower(), sample_list)) 177 | sample_list_no_spaces: List[str] = list(map(_remove_spaces, sample_list)) 178 | g = _remove_spaces(g) 179 | tgt_no_spaces = _remove_spaces(t) 180 | tgt_len_no_spaces = len(tgt_no_spaces) 181 | 182 | # filter to correct len 183 | samples_no_spaces_filtered = list(filter(lambda x: len(x) == tgt_len_no_spaces, sample_list_no_spaces)) 184 | 185 | # inclusion / exclusion 186 | # in_greedy = False 187 | # in_sample = False 188 | if g == tgt_no_spaces: 189 | ct_greedy += 1 190 | # in_greedy = True 191 | 192 | # idx = 0 193 | for idx, samp in enumerate(samples_no_spaces_filtered): 194 | if samp == tgt_no_spaces: 195 | # in_sample=True 196 | ct_sampled += 1 197 | if idx == 0: 198 | ct_top_sampled += 1 199 | if idx < 5: 200 | ct_top_5_sampled += 1 201 | break 202 | # if in_greedy and idx != 0: 203 | # num_lost += 1 204 | 205 | # how close to correct length 206 | cum_num_correct_length += len(samples_no_spaces_filtered) 207 | 208 | # num words 209 | tgt_spaces = t.count(' ') 210 | for samp in sample_list: 211 | if samp.count(' ') == tgt_spaces: 212 | cum_num_words_correct += 1 213 | 214 | ######## 215 | # for labels 216 | if label_sets is not None: 217 | raise NotImplemented 218 | # assert type(clue_idx) is int 219 | # for name, label_set in label_sets: 220 | # if clue_idx in label_set: 221 | # if in_sample: 222 | # label_counters[name + "_label"] += 1 223 | # # ct_succ_with_label += 1 224 | ######### 225 | 226 | # scale cumulatives by the number of sequences generated 227 | pct_correct_length = cum_num_correct_length / num_sampled 228 | pct_correct_wordct = cum_num_words_correct / num_sampled 229 | 230 | if primary_only: 231 | ret_dict: MetricsDict = MetricsDict( 232 | avg_metrics=dict( 233 | num_match_in_sample=ct_sampled, 234 | num_match_top_sampled=ct_top_sampled, 235 | ) 236 | ) 237 | else: 238 | ret_dict: MetricsDict = MetricsDict( 239 | avg_metrics=dict( 240 | num_exact_match_char_2=ct_greedy, 241 | 242 | num_match_in_sample=ct_sampled, 243 | num_match_top_sampled=ct_top_sampled, 244 | num_match_top_5_sampled=ct_top_5_sampled, 245 | 246 | # num_lost_sample_vs_greedy=num_lost, # when greedy correct but top sampled differs 247 | 248 | # will be averaged 249 | pct_correct_length=pct_correct_length, 250 | pct_correct_wordct=pct_correct_wordct, 251 | ) 252 | ) 253 | if label_sets is not None: 254 | raise NotImplemented 255 | # # don't average 256 | # # todo: should have a different divisor 257 | # ret_dict.no_avg_metrics = dict( 258 | # **label_counters) 259 | return ret_dict -------------------------------------------------------------------------------- /seq2seq/model_runner.py: -------------------------------------------------------------------------------- 1 | 2 | from transformers import T5ForConditionalGeneration, T5TokenizerFast 3 | import seq2seq.common_seq.util as util 4 | import numpy as np 5 | from typing import * 6 | import logging 7 | 8 | from 
seq2seq.common_seq.util_checkpoint import load_ckpt 9 | 10 | 11 | class ModelRunner: 12 | def __init__(self, model_name, ckpt_path, 13 | num_generations=10): 14 | device, gpu_ids = util.get_available_devices(assert_cuda=True) 15 | logging.info(f'{device}, {gpu_ids}') 16 | 17 | self.device = device 18 | self.num_generations = num_generations 19 | self.model = T5ForConditionalGeneration.from_pretrained(model_name) 20 | self.tokenizer = T5TokenizerFast.from_pretrained(model_name) 21 | 22 | load_ckpt(ckpt_path, self.model, map_location=device) 23 | self.model.to(device) 24 | 25 | def generate(self, sentences: List[str]) -> List[np.ndarray]: 26 | # tokenized = self.tokenizer.batch_encode_plus(sentences, max_length=50, return_tensors='pt', padding='max_length', truncation=True) 27 | tokenized = self.tokenizer(sentences, padding='longest', return_tensors='pt') 28 | input_ids = tokenized['input_ids'].to(self.device) 29 | src_mask = tokenized['attention_mask'].to(self.device) 30 | 31 | # greedy decoding 32 | # out_ids = self.model.generate(input_ids, attention_mask=src_mask) 33 | # greedy_decoded = tokenizer.batch_decode(out_ids, skip_special_tokens=True) 34 | 35 | generated_ids_sampled = self.model.generate(input_ids, 36 | attention_mask=src_mask, 37 | num_beams=self.num_generations, 38 | num_return_sequences=self.num_generations, 39 | do_sample=False, 40 | max_length=10, 41 | length_penalty=0.05 42 | ) 43 | 44 | decoded = self.tokenizer.batch_decode(generated_ids_sampled, skip_special_tokens=True) 45 | decoded = np.array_split(decoded, len(sentences)) 46 | 47 | return decoded 48 | 49 | # for output, input_sent in zip(decoded, sentences): 50 | # print(input_sent) 51 | # print(output) 52 | 53 | # set comparison 54 | # orig_set, other_sets = decoded_sets[0], decoded_sets[1:] 55 | # decoded_sets = list(map(set, decoded)) 56 | # for new_set, sent in zip(other_sets, sentences[1:]): 57 | # print(sent) 58 | # pp(new_set) 59 | # pp(f'new: {new_set.difference(orig_set)}') 60 | # pp(f'lost: {orig_set.difference(new_set)}') 61 | # print() 62 | 63 | -------------------------------------------------------------------------------- /seq2seq/multitask_config.py: -------------------------------------------------------------------------------- 1 | import common_seq.collate_fns as cfns 2 | from common_seq import util_metrics 3 | from common_seq.util_multiloader import MultitaskConfig, TaskConfig 4 | # import sys 5 | # sys.path.append('../../decrypt_root') # decrypt_root to enable config import 6 | # import decrypt_root.config as config 7 | 8 | # relative import issues with config 9 | k_curr_dir = "../data/clue_json/curricular" 10 | # k_curr_dir = config.DataDirs.DataExport._curricular 11 | # subdirectories within k_curr_dir with the content 12 | # k_task_dir = config.DataDirs.DataExport._ACW_sub_dir 13 | k_task_dir = "ACW_data" 14 | # k_anag_task_dir = config.DataDirs.DataExport.anag_dir # has train.json 15 | k_anag_task_dir = "anagram" 16 | k_anag_indic_file = f'{k_curr_dir}/{k_anag_task_dir}/anag_indics.json' 17 | 18 | k_default_args = dict( 19 | multitask_dir=k_curr_dir, 20 | reset=True, 21 | val_split_pct=0.99 22 | ) 23 | 24 | # american crossword (with label) 25 | task_ACW = TaskConfig( 26 | dir=k_task_dir, # 2.4m 27 | name="acw", 28 | val_fcn_list=[util_metrics.compute_metrics_sampled_primary], 29 | # adds a label, like 'phrase: ' 30 | collate_fn=cfns.collate_fn_from_pretokenize(cfns.make_pretokenize_prepend_label('phrase')) 31 | ) 32 | task_ACW_descramble = TaskConfig( 33 | dir=k_task_dir, # 2.4m 34 |
name="acw_descramble", 35 | val_fcn_list=[util_metrics.compute_metrics_sampled_primary], 36 | collate_fn = cfns.collate_fn_from_pretokenize(cfns.make_pretokenize_descramble(label='descramble')) 37 | ) 38 | task_ACW_descramble_word = TaskConfig( 39 | dir=k_task_dir, # 2.4m 40 | name="acw_descramble_word", 41 | val_fcn_list=[util_metrics.compute_metrics_sampled], 42 | # add a label and descramble (will lowercase the first letter of the clue) 43 | collate_fn = cfns.collate_fn_from_pretokenize( 44 | cfns.make_pretokenize_descramble(label='descramble word', word_only=True)) 45 | ) 46 | task_anagram = TaskConfig( 47 | dir= k_anag_task_dir, 48 | name="anag_with_indic", 49 | val_fcn_list=[util_metrics.compute_metrics_sampled], 50 | # add a label and descramble (will lowercase the first letter of the clue) 51 | collate_fn = cfns.collate_fn_from_pretokenize( 52 | cfns.make_pretokenize_anagram(label='anagram', 53 | anag_indic_file=k_anag_indic_file)) 54 | ) 55 | 56 | 57 | multi_config = dict( 58 | # ACW only 59 | ACW = MultitaskConfig( 60 | freq_list=[20, 6], # same as 20, 3,3 w.r.t. total pretraining examples 61 | num_warmup=4, 62 | tasks=[task_ACW], 63 | **k_default_args 64 | ), 65 | 66 | # ACW-descramble only 67 | ACW_descramble = MultitaskConfig( 68 | freq_list=[20, 6], # same as 20, 3,3 w.r.t. total pretraining examples 69 | num_warmup=4, 70 | tasks=[task_ACW_descramble], 71 | **k_default_args 72 | ), 73 | 74 | # ACW + ACW-descramble 75 | # top performing 76 | ACW__ACW_descramble = MultitaskConfig( 77 | freq_list=[20, 3, 3], 78 | num_warmup=2, # use 2 -> roughly translates to 4 total epochs 79 | tasks=[task_ACW, task_ACW_descramble], 80 | **k_default_args 81 | ), 82 | 83 | # ACW + ACW-descramble-word 84 | # crosswords + descramble bare 85 | ACW__ACW_descramble_word = MultitaskConfig( 86 | freq_list=[20, 3, 3], 87 | num_warmup=2, # use 2 -> roughly translates to 4 total epochs 88 | tasks=[task_ACW, task_ACW_descramble_word], 89 | **k_default_args 90 | ), 91 | 92 | # ACW + anagram 93 | ACW__anagram = MultitaskConfig( 94 | freq_list=[20, 3, 3], 95 | num_warmup=2, 96 | tasks=[task_ACW, task_anagram], 97 | **k_default_args 98 | ), 99 | 100 | # ACW + ACW-descramble + anagram 101 | # has 7:6 ratio of pretraining batches (i.e. 
more) 102 | ACW__ACW_descramble__anagram = MultitaskConfig( 103 | freq_list=[20, 3, 3, 1], 104 | num_warmup=2, 105 | tasks=[task_ACW, task_ACW_descramble, task_anagram], 106 | **k_default_args 107 | ), 108 | 109 | final_top_result_scaled_up = MultitaskConfig( 110 | freq_list=[20, 3, 3], 111 | num_warmup=4, # scale up for the top performing 112 | tasks=[task_ACW, task_ACW_descramble], 113 | **k_default_args 114 | ), 115 | 116 | # Cryptonite - best multitask approach (ACW + ACW-descramble) 117 | # has an extra epoch of warmup 118 | cfg_crypto_acw_acwdesc = MultitaskConfig( 119 | freq_list=[20, 3, 3], 120 | num_warmup=3, # one more than above 121 | tasks=[task_ACW, task_ACW_descramble], 122 | **k_default_args 123 | ), 124 | ) 125 | -------------------------------------------------------------------------------- /seq2seq/train_clues.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from typing import * 4 | 5 | import wandb 6 | from overrides import overrides 7 | from transformers import PreTrainedTokenizerFast 8 | 9 | import args_cryptic as args 10 | from common_seq import ( 11 | util, 12 | util_metrics 13 | ) 14 | from common_seq.util_dataloader_batch import default_collate_fn_json 15 | from train_abc import ( 16 | RunHelper, 17 | pre_setup, 18 | setup_wandb_for_run, T5Trainer 19 | ) 20 | from multitask_config import multi_config 21 | 22 | 23 | 24 | class ClueTrainer(T5Trainer): 25 | def __init__(self, config, rh, **kwargs): 26 | super().__init__(config, rh, **kwargs) 27 | 28 | if self.config.special: 29 | assert self.config.special in ['no_lens', 'no_len_multi'] 30 | 31 | @overrides 32 | def init_val_fcns(self): 33 | assert self.config.do_sample 34 | def compute_metrics_sampled_curry(*fcn_args): 35 | return util_metrics.compute_metrics_sampled(*fcn_args) 36 | 37 | self.val_fcn_list: List[util_metrics.MetricFcn] = [ 38 | compute_metrics_sampled_curry 39 | ] 40 | 41 | @overrides 42 | def setup_dataloaders(self): 43 | if not self.config.special: 44 | super().setup_dataloaders() 45 | return 46 | 47 | if self.config.special == 'no_lens': 48 | self.setup_dataloaders_no_len() 49 | elif self.config.special == 'no_len_multi': 50 | assert self.config.multitask 51 | self.setup_dataloaders_no_len_multi() 52 | else: 53 | raise NotImplemented 54 | 55 | def setup_dataloaders_no_len(self): 56 | log.info('training with no length specification') 57 | # remove the length specification 58 | def pre_tokenize(batch_list: List[Dict]): 59 | src_text, tgt_text, idxs = [], [], [] 60 | for e in batch_list: 61 | input = e['input'] 62 | splits = input.split(' ') 63 | assert splits[-1][0] == '(' 64 | input = ' '.join(splits[:-1]) 65 | 66 | tgt = e['target'] 67 | idx = e['idx'] 68 | 69 | src_text.append(input) 70 | tgt_text.append(tgt) 71 | idxs.append(idx) 72 | 73 | return src_text, tgt_text, idxs 74 | 75 | def coll_fn(tokenizer: PreTrainedTokenizerFast, batch_list: List[Dict]) -> Dict: 76 | return default_collate_fn_json(tokenizer, batch_list, pre_tokenize_fn=pre_tokenize) 77 | 78 | self.setup_dataloaders_no_multi(batched_collate_fcns=[coll_fn, coll_fn]) 79 | 80 | def setup_dataloaders_no_len_multi(self): 81 | # remove the length specification on multitask dataset 82 | def pre_tokenize(batch_list: List[Dict]): 83 | src_text, tgt_text, idxs = [], [], [] 84 | for e in batch_list: 85 | input = e['input'] 86 | splits = input.split(' ') 87 | assert splits[-1][0] == '(' 88 | input = ' '.join(splits[:-1]) 89 | 90 | tgt = e['target'] 91 | idx = e['idx'] 92 
| 93 | src_text.append(input) 94 | tgt_text.append(tgt) 95 | idxs.append(idx) 96 | 97 | return src_text, tgt_text, idxs 98 | 99 | def coll_fn_multi(tokenizer: PreTrainedTokenizerFast, batch_list: List[Dict]) -> Dict: 100 | return default_collate_fn_json(tokenizer, batch_list, pre_tokenize_fn=pre_tokenize) 101 | 102 | self.setup_dataloaders_multi(batched_collate_fcns=None, 103 | multitask_collate_fn=coll_fn_multi) 104 | 105 | @overrides 106 | def post_run(self): 107 | # get the checkpoint 108 | best_val = self.rh.ckpt_saver.best_vals['dev/num_match_top_sampled'] 109 | ckpt_path, sym_link = best_val[2], best_val[3] 110 | log.info(f'loading from\n\t{ckpt_path}\n\t{sym_link}') 111 | 112 | preds_path = best_val[2] + "preds.json" 113 | log.info(f'for final validation:\n' 114 | f'\t{preds_path}') 115 | 116 | 117 | def add_extra_args(parser): 118 | parser.add_argument('--special', 119 | default=None, 120 | help='Enables special flags, for example' 121 | 'no_lens') 122 | parser.add_argument('--hacky', 123 | action='store_true', 124 | help='Enables train resume on multitask. See train_abc.py for details of what needs to be done.') 125 | 126 | if __name__ == '__main__': 127 | parsed_args = args.get_args(add_extra_args) 128 | pre_setup(parsed_args) 129 | global_record_dir = setup_wandb_for_run(parsed_args) 130 | log = logging.getLogger() 131 | util.config_logger(log, global_record_dir) 132 | log.info(" ".join(sys.argv[:])) 133 | util.set_seed(wandb.config.seed) 134 | 135 | ######## 136 | ### run specific config 137 | ######## 138 | metrics_to_track: List[Tuple[str, bool]] = [ 139 | ('dev/num_match_top_sampled', True), 140 | # ('dev/num_match_top_5_sampled', True), 141 | # ('dev/NLL', False), 142 | ] 143 | # todo: this isn't working 144 | wandb.run.summary["dev/num_match_top_sampled"] = "best_top_sampled" 145 | wandb.run.summary["dev/num_match_in_sample"] = "best_in_sample" 146 | 147 | # multitask 148 | aux_config = dict() 149 | if wandb.config.multitask: 150 | aux_config['multitask_config'] = multi_config[wandb.config.multitask] 151 | # 'acw' must match the name of one of the tasks in multitask_config 152 | metrics_to_track.extend([('multisave', True), 153 | ('multi/acw/num_match_in_sample', True)]) # forces save at the end of multitask training 154 | 155 | local_rh = RunHelper(global_record_dir, metrics_to_track, wandb.config) 156 | ################ 157 | # final setup (consistent for all 158 | local_trainer = ClueTrainer(wandb.config, local_rh, aux_config=aux_config) 159 | wandb.watch(local_trainer.model, log="all") 160 | 161 | local_trainer.run() 162 | -------------------------------------------------------------------------------- /seq2seq/train_descramble.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import sys 4 | from typing import List, Dict 5 | 6 | import wandb 7 | from overrides import overrides 8 | from transformers import PreTrainedTokenizerFast 9 | 10 | import args_cryptic as args 11 | from common_seq import util as util 12 | from common_seq import util_metrics 13 | from common_seq.util_dataloader_batch import default_collate_fn_json 14 | from train_abc import ( 15 | T5Trainer, 16 | setup_wandb_for_run, pre_setup, 17 | RunHelper 18 | ) 19 | 20 | log = logging.getLogger(__name__) 21 | 22 | class AnagTrainer(T5Trainer): 23 | @overrides 24 | def init_val_fcns(self): 25 | self.val_fcn_list: List[util_metrics.MetricFcn] = [ 26 | util_metrics.compute_metrics_sampled 27 | ] 28 | 29 | @overrides 30 | def 
setup_dataloaders(self): 31 | # mods to support new dataloader 32 | if not self.config.randomize_train_scramble: 33 | raise NotImplementedError 34 | 35 | self.setup_dataloaders_xd_descramble() 36 | 37 | def setup_dataloaders_xd_descramble(self): 38 | assert self.config.randomize_train_scramble, 'For descramble, the flag --randomize_train_scramble must be set' 39 | assert self.config.add_defn is not None, 'For descramble, must specify either --add_defn or --no_defn' 40 | 41 | # otherwise randomize 42 | # uses function closures 43 | log.info("Randomizing train set") 44 | rng = random.Random() 45 | rng.seed(42) 46 | 47 | def randomize_letters(s: str) -> str: 48 | x = list(s) 49 | rng.shuffle(x) 50 | return "".join(x) 51 | 52 | def pre_tokenize_fn_xd_cw(batch_list: List[Dict]): 53 | src_text, tgt_text, idxs = [], [], [] 54 | for e in batch_list: 55 | defn = e['defn'] 56 | tgt = e['target'] 57 | tgt_scrambled = randomize_letters(tgt) 58 | 59 | if not self.config.copy: 60 | if self.config.add_defn: 61 | src_text.append(f'{tgt_scrambled} | {defn}') 62 | else: 63 | src_text.append(f'{tgt_scrambled}') 64 | else: # copy 65 | if self.config.add_defn: 66 | src_text.append(f'{tgt} | {defn}') 67 | else: 68 | src_text.append(f'{tgt}') 69 | 70 | tgt_text.append(tgt) 71 | idxs.append(-1) # dummy indices 72 | 73 | return src_text, tgt_text, idxs 74 | 75 | def coll_fn(tokenizer: PreTrainedTokenizerFast, batch_list: List[Dict]) -> Dict: 76 | return default_collate_fn_json(tokenizer, batch_list, pre_tokenize_fn=pre_tokenize_fn_xd_cw) 77 | 78 | self.setup_dataloaders_no_multi(batched_collate_fcns=[coll_fn, coll_fn]) 79 | 80 | 81 | def add_extra_args(parser): 82 | parser.add_argument('--randomize_train_scramble', 83 | action='store_true', 84 | help='Whether to randomize the scrambling in train examples. I.e. ' 85 | 'The collate function will scramble the letters each time, ' 86 | 'so each time a word is shown in train it will (likely) be ordered differently') 87 | # whether to append definition 88 | parser.add_argument('--add_defn', 89 | action='store_true', 90 | dest='add_defn', 91 | help='whether to append phrasal definition') 92 | parser.add_argument('--no_defn', 93 | action='store_false', 94 | dest='add_defn') 95 | parser.set_defaults(add_defn=None) # will be false for eval 96 | parser.add_argument('--copy', 97 | action='store_true', 98 | help='Whether to do a copy task (i.e. no scrambling)') 99 | 100 | 101 | if __name__ == '__main__': 102 | # repeated logic 103 | parsed_args = args.get_args(add_extra_args) # potentially pass extra args 104 | pre_setup(parsed_args) 105 | global_record_dir = setup_wandb_for_run(parsed_args) 106 | log = logging.getLogger() 107 | util.config_logger(log, global_record_dir) 108 | log.info(" ".join(sys.argv[:])) 109 | util.set_seed(wandb.config.seed) 110 | 111 | ######## 112 | ### run specific config 113 | ######## 114 | metrics_to_track = [('dev/NLL', False), 115 | ('dev/num_match_top_sampled', True)] 116 | 117 | local_rh = RunHelper(global_record_dir, metrics_to_track, wandb.config) 118 | #### 119 | # final setup 120 | local_trainer = AnagTrainer(wandb.config, local_rh) 121 | wandb.watch(local_trainer.model, log="all") 122 | 123 | local_trainer.run() 124 | --------------------------------------------------------------------------------
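Once a clue model has been trained with `train_clues.py`, the `ModelRunner` class in `seq2seq/model_runner.py` gives quick batched inference. The sketch below is a minimal example rather than a command shipped with the notebooks: the `t5-base` model name and the checkpoint path are assumptions you should replace with your own run artifacts, and a CUDA device is required because `ModelRunner` calls `get_available_devices(assert_cuda=True)`. Clue inputs keep the training format, i.e. the clue text followed by its answer length, such as `(3)`.
```python
# Minimal inference sketch (illustrative only): the model name and checkpoint
# path below are assumptions, not values provided by the repository.
from seq2seq.model_runner import ModelRunner

runner = ModelRunner(model_name='t5-base',                       # assumed base model
                     ckpt_path='./nocommit/clue_model.pth.tar',  # hypothetical checkpoint from train_clues.py
                     num_generations=10)

# inputs mirror the training format: clue text plus the answer length in parentheses
clues = ['Dog returned to deity (3)']
for clue, candidates in zip(clues, runner.generate(clues)):
    # generate() returns one array of num_generations candidate strings per clue
    print(clue, '->', list(candidates))
```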