├── .gitignore ├── .vscode ├── launch.json └── settings.json ├── LICENSE ├── README.md ├── dataset ├── __init__.py ├── scan_length_resplit.py ├── sequence.py └── text │ ├── cfq.py │ ├── cogs.py │ ├── dm_math.py │ ├── pcfg_set.py │ ├── scan.py │ ├── text_dataset.py │ └── typed_text_dataset.py ├── framework ├── .gitignore ├── __init__.py ├── data_structures │ ├── __init__.py │ ├── dotdict.py │ └── vocabulary.py ├── helpers │ ├── __init__.py │ ├── argument_parser.py │ ├── saver.py │ └── training_helper.py ├── layers │ ├── __init__.py │ ├── cross_entropy_label_smoothing.py │ └── positional_encoding.py ├── loader │ ├── __init__.py │ ├── collate.py │ └── sampler.py ├── utils │ ├── __init__.py │ ├── average.py │ ├── download.py │ ├── gpu_allocator.py │ ├── lockfile.py │ ├── parallel_map.py │ ├── port.py │ ├── process.py │ ├── seed.py │ ├── set_lr.py │ ├── time_meter.py │ └── universal.py └── visualize │ ├── __init__.py │ ├── plot.py │ └── tensorboard.py ├── interfaces ├── __init__.py ├── encoder_decoder.py ├── model_interface.py ├── result.py └── transformer │ ├── __init__.py │ └── encoder_decoder_interface.py ├── layers ├── __init__.py ├── tied_embedding.py └── transformer │ ├── __init__.py │ ├── multi_head_attention.py │ ├── multi_head_relative_pos_attention.py │ ├── relative_transformer.py │ ├── transformer.py │ ├── universal_relative_transformer.py │ └── universal_transformer.py ├── main.py ├── models ├── __init__.py ├── encoder_decoder.py └── transformer_enc_dec.py ├── optimizer ├── __init__.py ├── noam_lr_sched.py └── step_lr_sched.py ├── paper ├── .gitignore ├── config.json ├── lib │ ├── __init__.py │ ├── common.py │ ├── config.py │ ├── matplotlib_config.py │ ├── source.py │ └── stat_tracker.py ├── plot_big_result_table.py ├── plot_big_result_table_iid.py ├── plot_big_result_table_with_init.py ├── plot_cogs_early_stopping.py ├── plot_init.py ├── plot_init_iid.py ├── plot_loss_accuracy.py ├── plot_loss_analysis.py ├── plot_pcfg.py ├── plot_relatrafo_convergece.py ├── plot_scan_eos_performance.py ├── plot_small_batch.py └── run_all.sh ├── plot_dataset_stats.py ├── requrements.txt ├── run.py ├── sweeps ├── cfq_mcd.yaml ├── cfq_mcd_small_batch.yaml ├── cfq_mcd_small_batch_universal.yaml ├── cfq_mcd_universal.yaml ├── cfq_out_length.yaml ├── cfq_out_length_small_batch.yaml ├── cfq_out_length_universal.yaml ├── cfq_out_length_universal_small_batch.yaml ├── cogs_trafo.yaml ├── cogs_trafo_official.yaml ├── dm_math.yaml ├── pcfg_nosched_iid.yaml ├── pcfg_nosched_productivity.yaml ├── pcfg_nosched_systematicity.yaml └── scan_trafo_length_cutoff.yaml └── tasks ├── __init__.py ├── cfq_transformer.py ├── cogs_transofrmer.py ├── dm_math_transformer.py ├── pcfg_transformer.py ├── scan_resplit_transformer.py ├── scan_transformer.py ├── task.py └── transformer_mixin.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | save 3 | __pycache__ 4 | wandb 5 | cache 6 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
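// NOTE (added commentary, not part of the original file): each configuration below
// launches main.py with a named --profile plus task-specific overrides; single- and
// double-dash flag forms are equivalent, since framework/helpers/argument_parser.py
// registers both spellings for every argument.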
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "COGS trafo", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${workspaceFolder}/main.py", 12 | "console": "integratedTerminal", 13 | "args": ["--name", "cogs_trafo", "--profile", "pcfg_trafo", "--task", "cogs_transformer", "--log", "tb", 14 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000"] 15 | }, 16 | 17 | { 18 | "name": "Scan trafo - length", 19 | "type": "python", 20 | "request": "launch", 21 | "program": "${workspaceFolder}/main.py", 22 | "console": "integratedTerminal", 23 | "args": ["--name", "trafo_scan_l", "--profile", "trafo_scan", "--task", "trafo_scan", "--log", "tb", 24 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-scan.train_split", "length", 25 | "--test_interval", "1000"] 26 | }, 27 | 28 | { 29 | "name": "Scan trafo - add turn left", 30 | "type": "python", 31 | "request": "launch", 32 | "program": "${workspaceFolder}/main.py", 33 | "console": "integratedTerminal", 34 | "args": ["--name", "trafo_scan_a", "--profile", "trafo_scan", "--task", "trafo_scan", "--log", "tb", 35 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-scan.train_split", "turn_left", 36 | "--test_interval", "1000"] 37 | }, 38 | 39 | { 40 | "name": "Scan relative trafo", 41 | "type": "python", 42 | "request": "launch", 43 | "program": "${workspaceFolder}/main.py", 44 | "console": "integratedTerminal", 45 | "args": ["--name", "rel_trafo_scan", "--profile", "trafo_scan", "--task", "trafo_scan", "--log", "tb", 46 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-test_interval", "1000", 47 | "-scan.train_split", "turn_left", "-transformer.variant", "relative"] 48 | }, 49 | 50 | { 51 | "name": "Scan relative trafo - length split", 52 | "type": "python", 53 | "request": "launch", 54 | "program": "${workspaceFolder}/main.py", 55 | "console": "integratedTerminal", 56 | "args": ["--name", "rel_trafo_scan", "--profile", "trafo_scan", "--task", "trafo_scan", "--log", "tb", 57 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-test_interval", "1000", 58 | "-scan.train_split", "length", "-transformer.variant", "relative"] 59 | }, 60 | 61 | { 62 | "name": "Scan relative trafo - custom_length_split", 63 | "type": "python", 64 | "request": "launch", 65 | "program": "${workspaceFolder}/main.py", 66 | "console": "integratedTerminal", 67 | "args": ["--name", "rel_trafo_scan", "--profile", "trafo_scan", "--task", "scan_resplit_transformer", "--log", "tb", 68 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-test_interval", "1000", 69 | "-transformer.variant", "relative"] 70 | }, 71 | 72 | { 73 | "name": "Simple trafo scan - EOS paper", 74 | "type": "python", 75 | "request": "launch", 76 | "program": "${workspaceFolder}/main.py", 77 | "console": "integratedTerminal", 78 | "args": ["--name", "trafo_scan_eos", "--profile", "trafo_scan", "--task", "trafo_scan", "--log", "tb", 79 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-scan.train_split", "length", 80 | "--state_size", "128", "-transformer.ff_multiplier", "2", "-transformer.n_heads", "4", 81 | "-transformer.encoder_n_layers", "3", "-transformer.decoder_n_layers", "3", 82 | "--test_interval", "1000"] 83 | }, 84 | 85 | { 86 | "name": "COGS relative trafo", 87 | "type": "python", 88 | "request": "launch", 89 | "program": "${workspaceFolder}/main.py", 90 | "console": "integratedTerminal", 91 | "args": 
["--name", "cogs_trafo_rel", "--profile", "cogs_trafo_small", "--log", "tb", 92 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-batch_size", "64", "-test_interval", "1000", 93 | "--test_batch_size", "64", "--cogs.generalization_test_interval", "2500", "-label_smoothing", "0", 94 | "-transformer.encoder_n_layers", "2", "-transformer.decoder_n_layers", "2", "-transformer.ff_multiplier", "1", 95 | "-transformer.n_heads", "4", "-transformer.variant", "relative"] 96 | }, 97 | 98 | { 99 | "name": "PCFG trafo - productivity", 100 | "type": "python", 101 | "request": "launch", 102 | "program": "${workspaceFolder}/main.py", 103 | "console": "integratedTerminal", 104 | "args": ["--name", "pcfg_trafo_prod", "--profile", "pcfg_trafo", "--log", "tb", 105 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-lr", "1e-4", 106 | "-pcfg.split", "productivity", "-test_interval", "1000", "-amp", "1"] 107 | }, 108 | 109 | { 110 | "name": "PCFG reltrafo", 111 | "type": "python", 112 | "request": "launch", 113 | "program": "${workspaceFolder}/main.py", 114 | "console": "integratedTerminal", 115 | "args": ["--name", "pcfg_rell_trafo", "--profile", "pcfg_trafo", "--log", "tb", 116 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-lr", "1e-4", 117 | "-test_interval", "1000", "-task", "pcfg_transformer", "-amp", "1", 118 | "-transformer.variant", "relative"] 119 | }, 120 | 121 | 122 | { 123 | "name": "CFQ trafo", 124 | "type": "python", 125 | "request": "launch", 126 | "program": "${workspaceFolder}/main.py", 127 | "console": "integratedTerminal", 128 | "args": ["--name", "cfq_trafo", "--profile", "pcfg_trafo", "--log", "tb", "-task", "cfq_trafo", 129 | "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", "-lr", "1e-4", "-amp", "1", 130 | "-test_interval", "100", "-log_sample_level_loss", "1"] 131 | }, 132 | 133 | { 134 | "name": "DM Math relative add or sub", 135 | "type": "python", 136 | "request": "launch", 137 | "program": "${workspaceFolder}/main.py", 138 | "console": "integratedTerminal", 139 | "args": ["--name", "dm_math_add_or_sub", "--profile", "deepmind_math", "--task", "dm_math_transformer", 140 | "--log", "tb", "--keep_alive", "1", "-reset", "1", "-stop_after", "1000000", 141 | "--test_interval", "1000", "-transformer.variant", "relative", "-dm_math.task", 142 | "arithmetic__add_or_sub", "-batch_size", "256"] 143 | }, 144 | 145 | { 146 | "type": "python", 147 | "request": "launch", 148 | "name": "Debug File", 149 | "program": "${file}", 150 | "cwd": "${fileDirname}" 151 | } 152 | ] 153 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.languageServer": "Pylance", 3 | "python.linting.pylintEnabled": false, 4 | "editor.rulers": [120], 5 | "python.analysis.typeCheckingMode": "basic", 6 | "python.linting.flake8Enabled": true, 7 | "python.linting.flake8CategorySeverity.F": "Warning", 8 | "python.linting.flake8Args": [ 9 | "--ignore=E203, E266, E501, W503", 10 | "--max-line-length=120", 11 | "--select=B,C,E,F,W,T4,B9", 12 | "--max-complexity=18" 13 | ], 14 | "workbench.colorCustomizations": { 15 | "editorError.background": "#ff990028", 16 | "editorError.foreground": "#00000000", 17 | "editorWarning.background": "#ff000028", 18 | "editorWarning.foreground": "#00000000", 19 | } 20 | } -------------------------------------------------------------------------------- /LICENSE: 
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Róbert Csordás
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Codebase for training transformers on systematic generalization datasets.
2 | 
3 | The official repository for our EMNLP 2021 paper [The Devil is in the Detail: Simple Tricks Improve Systematic Generalization of Transformers](https://arxiv.org/abs/2108.12284).
4 | 
5 | Please note that this repository is a cleaned-up version of the internal research repository we use. In case you encounter any problems with it, please don't hesitate to contact me.
6 | 
7 | ## Setup
8 | 
9 | This project requires Python 3 (tested with Python 3.8 and 3.9) and PyTorch 1.8.
10 | 
11 | ```bash
12 | pip3 install -r requirements.txt
13 | ```
14 | 
15 | Create a Weights and Biases account and run
16 | ```bash
17 | wandb login
18 | ```
19 | 
20 | More information on setting up Weights and Biases can be found on
21 | https://docs.wandb.com/quickstart.
22 | 
23 | For plotting, LaTeX is required (to avoid Type 3 fonts and to render symbols). Installation is OS-specific.
24 | 
25 | ### Downloading data
26 | 
27 | All datasets are downloaded automatically except the Mathematics Dataset and CFQ, which are hosted on Google Cloud; you have to log in with your Google account to access them.
28 | 
29 | #### Math dataset
30 | Download the .tar.gz file manually from here:
31 | 
32 | https://console.cloud.google.com/storage/browser/mathematics-dataset?pli=1
33 | 
34 | Copy it to the ``cache/dm_math/`` folder. You should have a ``cache/dm_math/mathematics_dataset-v1.0.tar.gz`` file in the project folder if you did everything correctly.
35 | 
36 | #### CFQ
37 | Download the .tar.gz file manually from here:
38 | 
39 | https://storage.cloud.google.com/cfq_dataset/cfq1.1.tar.gz
40 | 
41 | Copy it to the ``cache/CFQ/`` folder. You should have a ``cache/CFQ/cfq1.1.tar.gz`` file in the project folder if you did everything correctly.
42 | 
43 | 
44 | ## Usage
45 | 
46 | ### Running the experiments from the paper on a cluster
47 | 
48 | The code makes use of Weights and Biases for experiment tracking. In the ```sweeps``` directory, we provide sweep configurations for all experiments we have performed. The sweeps are officially meant for hyperparameter optimization, but we use them to run multiple configurations and seeds.
49 | 
50 | To reproduce our results, start a sweep for each of the YAML files in the ```sweeps``` directory and run ```wandb agent``` for each of them in the _root directory of the project_. This will run all the experiments, and they will be displayed on the W&B dashboard. The name of each sweep must match the name of the corresponding file in the ```sweeps``` directory, without the ```.yaml``` extension. More details on how to run W&B sweeps can be found at https://docs.wandb.com/sweeps/quickstart.
51 | 
52 | For example, if you want to run Math Dataset experiments, run ```wandb sweep --name dm_math sweeps/dm_math.yaml```. This creates the sweep and prints out its ID. Then run ```wandb agent <sweep ID>``` with that ID.
53 | 
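Creating all sweeps can also be scripted; the following is a minimal sketch (not a helper shipped with this repository), assuming a logged-in ```wandb``` CLI:

```bash
# Create one W&B sweep per YAML config, named after its file.
# Run from the root directory of the project.
for cfg in sweeps/*.yaml; do
    name=$(basename "$cfg" .yaml)      # e.g. sweeps/dm_math.yaml -> dm_math
    wandb sweep --name "$name" "$cfg"  # prints the ID to pass to `wandb agent`
done
# Then start workers with: wandb agent <sweep ID>
```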
54 | #### Re-creating plots from the paper
55 | 
56 | Edit the config file ```paper/config.json```. Enter your project name in the field "wandb_project" (e.g. "username/project").
57 | 
58 | Run the scripts in the ```paper``` directory. For example:
59 | 
60 | ```bash
61 | cd paper
62 | ./run_all.sh
63 | ```
64 | 
65 | The output will be generated in the ```paper/out/``` directory. Tables will be printed to stdout in LaTeX format.
66 | 
67 | Individual plots can be reproduced by running the corresponding Python files in the ```paper``` directory.
68 | 
69 | ### Running experiments locally
70 | 
71 | It is possible to run single experiments with TensorBoard, without using Weights and Biases. This is intended for debugging the code locally.
72 | 
73 | If you want to run experiments locally, you can use ```run.py```:
74 | 
75 | ```bash
76 | ./run.py sweeps/dm_math.yaml
77 | ```
78 | 
79 | If the sweep in question has multiple parameter choices, ```run.py``` will interactively prompt for a choice of each of them.
80 | 
81 | The experiment also starts a TensorBoard instance automatically on port 7000. If the port is already occupied, it incrementally searches for the next free port.
82 | 
83 | Note that the plotting scripts work only with Weights and Biases.
84 | 
85 | ### Reducing memory usage
86 | 
87 | In case some tasks won't fit on your GPU, play around with the "-max_length_per_batch <number>" argument. It trades speed for memory by slicing batches and executing them in multiple passes. Reduce it until the model fits.
88 | 
89 | ### BibTeX
90 | ```
91 | @inproceedings{csordas2021devil,
92 | title={The Devil is in the Detail: Simple Tricks Improve Systematic Generalization of Transformers},
93 | author={R\'obert Csord\'as and Kazuki Irie and J\"urgen Schmidhuber},
94 | booktitle={Proc. Conf.
on Empirical Methods in Natural Language Processing (EMNLP)}, 95 | year={2021}, 96 | month={November}, 97 | address={Punta Cana, Dominican Republic} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .text.scan import Scan 2 | from .scan_length_resplit import ScanLengthResplit 3 | from .text.dm_math import DeepmindMathDataset 4 | from .text.pcfg_set import PCFGSet 5 | from .text.cogs import COGS 6 | from .text.cfq import CFQ 7 | -------------------------------------------------------------------------------- /dataset/scan_length_resplit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import os 4 | import numpy as np 5 | from framework.utils import download 6 | from framework.data_structures import WordVocabulary 7 | from typing import Dict, Any, Tuple 8 | from .sequence import TextSequenceTestState 9 | 10 | 11 | class ScanLengthResplit(torch.utils.data.Dataset): 12 | in_sentences = [] 13 | out_sentences = [] 14 | index_table = {} 15 | 16 | URL = "https://raw.githubusercontent.com/brendenlake/SCAN/master/tasks.txt" 17 | 18 | def _load_dataset(self, cache_dir: str): 19 | if ScanLengthResplit.in_sentences: 20 | return 21 | 22 | os.makedirs(cache_dir, exist_ok=True) 23 | cache_file = os.path.join(cache_dir, "scan.pth") 24 | 25 | if not os.path.isfile(cache_file): 26 | fn = os.path.join(cache_dir, os.path.split(self.URL)[-1]) 27 | 28 | print("Downloading", self.URL) 29 | download(self.URL, fn, ignore_if_exists=True) 30 | 31 | with open(fn) as f: 32 | for line in f: 33 | line = line.split("OUT:") 34 | line[0] = line[0].replace("IN:", "") 35 | line = [l.strip() for l in line] 36 | 37 | ScanLengthResplit.in_sentences.append(line[0]) 38 | ScanLengthResplit.out_sentences.append(line[1]) 39 | 40 | print("Constructing vocabularies") 41 | ScanLengthResplit.in_vocabulary = WordVocabulary(self.in_sentences) 42 | ScanLengthResplit.out_vocabulary = WordVocabulary(self.out_sentences) 43 | 44 | ScanLengthResplit.in_sentences = [ScanLengthResplit.in_vocabulary(s) 45 | for s in ScanLengthResplit.in_sentences] 46 | ScanLengthResplit.out_sentences = [ScanLengthResplit.out_vocabulary(s) 47 | for s in ScanLengthResplit.out_sentences] 48 | 49 | ScanLengthResplit.max_in_len = max(len(l) for l in ScanLengthResplit.in_sentences) 50 | ScanLengthResplit.max_out_len = max(len(l) for l in ScanLengthResplit.out_sentences) 51 | 52 | print("Done.") 53 | torch.save({ 54 | "in_sentences": ScanLengthResplit.in_sentences, 55 | "out_sentences": ScanLengthResplit.out_sentences, 56 | "in_voc": ScanLengthResplit.in_vocabulary.state_dict(), 57 | "out_voc": ScanLengthResplit.out_vocabulary.state_dict(), 58 | "max_in_len": ScanLengthResplit.max_in_len, 59 | "max_out_len": ScanLengthResplit.max_out_len 60 | }, cache_file) 61 | else: 62 | data = torch.load(cache_file) 63 | ScanLengthResplit.in_vocabulary = WordVocabulary(None) 64 | ScanLengthResplit.out_vocabulary = WordVocabulary(None) 65 | ScanLengthResplit.in_vocabulary.load_state_dict(data["in_voc"]) 66 | ScanLengthResplit.out_vocabulary.load_state_dict(data["out_voc"]) 67 | ScanLengthResplit.in_sentences = data["in_sentences"] 68 | ScanLengthResplit.out_sentences = data["out_sentences"] 69 | ScanLengthResplit.max_in_len = data["max_in_len"] 70 | ScanLengthResplit.max_out_len = data["max_out_len"] 71 | 72 | 73 | def __init__(self, dset: 
str, len_range: Tuple[int, int], train_proprtion: float = 0.9, 74 | cache_dir: str = "./cache/scan_resplit"): 75 | super().__init__() 76 | self.cache_dir = cache_dir 77 | self._load_dataset(cache_dir) 78 | self.len_range = len_range 79 | 80 | assert dset in ["train", "test", "all"] 81 | 82 | self.my_indices = [i for i, o in enumerate(self.out_sentences) if len_range[0] <= len(o) <= len_range[1]] 83 | 84 | if dset != "all": 85 | seed = np.random.RandomState(1234) 86 | test_indices = set(seed.choice(len(self.my_indices), int(len(self.my_indices) * (1 - train_proprtion)), 87 | replace=False).tolist()) 88 | 89 | self.my_indices = [i for ii, i in enumerate(self.my_indices) if (ii in test_indices) ^ (dset == "train")] 90 | 91 | self.this_max_out_len = max(len(self.out_sentences[i]) for i in self.my_indices) 92 | self.this_min_out_len = min(len(self.out_sentences[i]) for i in self.my_indices) 93 | 94 | def __len__(self) -> int: 95 | return len(self.my_indices) 96 | 97 | def __getitem__(self, item: int) -> Dict[str, Any]: 98 | index = self.my_indices[item] 99 | in_seq = ScanLengthResplit.in_sentences[index] 100 | out_seq = ScanLengthResplit.out_sentences[index] 101 | 102 | return { 103 | "in": np.asarray(in_seq, np.int16), 104 | "out": np.asarray(out_seq, np.int16), 105 | "in_len": len(in_seq), 106 | "out_len": len(out_seq) 107 | } 108 | 109 | def get_output_size(self): 110 | return len(self.out_vocabulary) 111 | 112 | def get_input_size(self): 113 | return len(self.in_vocabulary) 114 | 115 | def start_test(self) -> TextSequenceTestState: 116 | return TextSequenceTestState(lambda x: " ".join(self.in_vocabulary(x)), 117 | lambda x: " ".join(self.out_vocabulary(x))) 118 | 119 | def __str__(self): 120 | return f"ScanLengthResplit(range=[{self.this_min_out_len}, {self.this_max_out_len}], len={len(self)})" 121 | 122 | __repr__ = __str__ 123 | -------------------------------------------------------------------------------- /dataset/text/cfq.py: -------------------------------------------------------------------------------- 1 | import json 2 | import mmap 3 | from tqdm import tqdm 4 | import string 5 | from .text_dataset import TextDataset, TextDatasetCache 6 | from typing import Tuple, List 7 | import os 8 | import tarfile 9 | 10 | 11 | class CFQ(TextDataset): 12 | URL = "https://storage.cloud.google.com/cfq_dataset/cfq1.1.tar.gz" 13 | 14 | def tokenize_punctuation(self, text): 15 | # From https://github.com/google-research/google-research/blob/master/cfq/preprocess.py 16 | text = map(lambda c: ' %s ' % c if c in string.punctuation else c, text) 17 | return ' '.join(''.join(text).split()) 18 | 19 | def preprocess_sparql(self, query): 20 | # From https://github.com/google-research/google-research/blob/master/cfq/preprocess.py 21 | """Do various preprocessing on the SPARQL query.""" 22 | # Tokenize braces. 23 | query = query.replace('count(*)', 'count ( * )') 24 | 25 | tokens = [] 26 | for token in query.split(): 27 | # Replace 'ns:' prefixes. 28 | if token.startswith('ns:'): 29 | token = token[3:] 30 | # Replace mid prefixes. 31 | if token.startswith('m.'): 32 | token = 'm_' + token[2:] 33 | tokens.append(token) 34 | 35 | return ' '.join(tokens).replace('\\n', ' ') 36 | 37 | def load_data(self, fname: str) -> Tuple[List[str], List[str]]: 38 | # Split the JSON manually, otherwise it requires infinite RAM and is very slow. 
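# NOTE (assumption about the file layout): dataset.json is one huge JSON array
# whose records each contain a "complexityMeasures" key. Searching the mmap-ed
# file for that key locates the boundary of the next record, and the small
# fixed offsets used below skip the JSON punctuation between records, so every
# record can be decoded on its own with json.loads.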
39 | pin = "complexityMeasures".encode() 40 | offset = 1 41 | cnt = 0 42 | 43 | inputs = [] 44 | outputs = [] 45 | 46 | with open(fname, "r") as f: 47 | data = mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ) 48 | pbar = tqdm(total=len(data)) 49 | pbar.update(offset) 50 | 51 | while True: 52 | pos = data.find(pin, offset+6) 53 | if pos < 0: 54 | this = data[offset: len(data)-2] 55 | else: 56 | this = data[offset: pos-5] 57 | new_offset = pos - 4 58 | pbar.update(new_offset - offset) 59 | offset = new_offset 60 | d = json.loads(this.decode()) 61 | inputs.append(self.tokenize_punctuation(d["questionPatternModEntities"])) 62 | outputs.append(self.preprocess_sparql(d["sparqlPatternModEntities"])) 63 | 64 | cnt += 1 65 | if pos < 0: 66 | break 67 | 68 | return inputs, outputs 69 | 70 | def build_cache(self) -> TextDatasetCache: 71 | index_table = {} 72 | 73 | if not os.path.isdir(os.path.join(self.cache_dir, "cfq")): 74 | gzfile = os.path.join(self.cache_dir, os.path.basename(self.URL)) 75 | if not os.path.isfile(gzfile): 76 | assert False, f"Please download {self.URL} and place it in the {os.path.abspath(self.cache_dir)} "\ 77 | "folder. Google login needed." 78 | 79 | with tarfile.open(gzfile, "r") as tf: 80 | tf.extractall(path=self.cache_dir) 81 | 82 | splitdir = os.path.join(self.cache_dir, "cfq", "splits") 83 | for f in os.listdir(splitdir): 84 | if not f.endswith(".json"): 85 | continue 86 | 87 | name = f[:-5].replace("_split", "") 88 | with open(os.path.join(splitdir, f), "r") as f: 89 | ind = json.loads(f.read()) 90 | 91 | index_table[name] = { 92 | "train": ind["trainIdxs"], 93 | "val": ind["devIdxs"], 94 | "test": ind["testIdxs"] 95 | } 96 | 97 | in_sentences, out_sentences = self.load_data(os.path.join(self.cache_dir, "cfq/dataset.json")) 98 | assert len(in_sentences) == len(out_sentences) 99 | return TextDatasetCache().build(index_table, in_sentences, out_sentences, split_punctuation=False) 100 | -------------------------------------------------------------------------------- /dataset/text/cogs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from framework.utils import download 3 | import csv 4 | from .typed_text_dataset import TypedTextDataset, TypedTextDatasetCache 5 | from ..sequence import TypedTextSequenceTestState 6 | 7 | 8 | class COGS(TypedTextDataset): 9 | URL_BASE = "https://raw.githubusercontent.com/najoungkim/COGS/main/data/" 10 | SPLT_TYPES = ["train", "test", "valid", "gen"] 11 | NAME_MAP = {"valid": "dev"} 12 | 13 | def build_cache(self) -> TypedTextDatasetCache: 14 | 15 | types = [] 16 | type_list = [] 17 | type_map = {} 18 | 19 | index_table = {} 20 | in_sentences = [] 21 | out_sentences = [] 22 | 23 | for st in self.SPLT_TYPES: 24 | fname = self.NAME_MAP.get(st, st) + ".tsv" 25 | split_fn = os.path.join(self.cache_dir, fname) 26 | os.makedirs(os.path.dirname(split_fn), exist_ok=True) 27 | 28 | full_url = self.URL_BASE + fname 29 | print("Downloading", full_url) 30 | download(full_url, split_fn, ignore_if_exists=True) 31 | 32 | index_table[st] = [] 33 | 34 | with open(split_fn, "r") as f: 35 | d = csv.reader(f, delimiter="\t") 36 | for line in d: 37 | i, o, t = line 38 | 39 | index_table[st].append(len(in_sentences)) 40 | in_sentences.append(i) 41 | out_sentences.append(o) 42 | 43 | tind = type_map.get(t) 44 | if tind is None: 45 | type_map[t] = tind = len(type_list) 46 | type_list.append(t) 47 | 48 | types.append(tind) 49 | 50 | assert len(in_sentences) == len(out_sentences) 51 | 52 | return 
TypedTextDatasetCache().build({"default": index_table}, in_sentences, out_sentences, types, type_list) 53 | 54 | def start_test(self) -> TypedTextSequenceTestState: 55 | return TypedTextSequenceTestState(lambda x: " ".join(self.in_vocabulary(x)), 56 | lambda x: " ".join(self.out_vocabulary(x)), 57 | self._cache.type_names) 58 | -------------------------------------------------------------------------------- /dataset/text/pcfg_set.py: -------------------------------------------------------------------------------- 1 | import os 2 | from framework.utils import download 3 | from .text_dataset import TextDataset, TextDatasetCache 4 | 5 | 6 | class PCFGSet(TextDataset): 7 | URLS = { 8 | "simple": "https://raw.githubusercontent.com/i-machine-think/am-i-compositional/master/data/pcfgset/pcfgset", 9 | "productivity": "https://raw.githubusercontent.com/i-machine-think/am-i-compositional/master/data/pcfgset/productivity", 10 | "substitutivity": "https://raw.githubusercontent.com/i-machine-think/am-i-compositional/master/data/pcfgset/substitutivity/primitive", 11 | "systematicity": "https://raw.githubusercontent.com/i-machine-think/am-i-compositional/master/data/pcfgset/systematicity", 12 | } 13 | 14 | def build_cache(self) -> TextDatasetCache: 15 | index_table = {} 16 | in_sentences = [] 17 | out_sentences = [] 18 | 19 | for split_type, url in self.URLS.items(): 20 | index_table[split_type] = {} 21 | 22 | for set in ["test", "train"] + (["dev"] if split_type == "simple" else []): 23 | set_url = f"{url}/{set}" 24 | set_fn = os.path.join(self.cache_dir, split_type, os.path.split(set_url)[-1]) 25 | os.makedirs(os.path.dirname(set_fn), exist_ok=True) 26 | 27 | for f in ["src", "tgt"]: 28 | full_url = f"{set_url}.{f}" 29 | print("Downloading", full_url) 30 | download(full_url, f"{set_fn}.{f}", ignore_if_exists=True) 31 | 32 | this_set = [] 33 | index_table[split_type][set] = this_set 34 | 35 | with open(set_fn + ".src") as f: 36 | for line in f: 37 | in_sentences.append(line.strip()) 38 | this_set.append(len(in_sentences) - 1) 39 | 40 | with open(set_fn + ".tgt") as f: 41 | for line in f: 42 | out_sentences.append(line.strip()) 43 | 44 | assert len(in_sentences) == len(out_sentences) 45 | 46 | return TextDatasetCache().build(index_table, in_sentences, out_sentences) 47 | -------------------------------------------------------------------------------- /dataset/text/scan.py: -------------------------------------------------------------------------------- 1 | import os 2 | from framework.utils import download 3 | from .text_dataset import TextDataset, TextDatasetCache 4 | 5 | 6 | class Scan(TextDataset): 7 | URLS = { 8 | "simple": { 9 | "train": "https://raw.githubusercontent.com/brendenlake/SCAN/master/simple_split/tasks_train_simple.txt", 10 | "test": "https://raw.githubusercontent.com/brendenlake/SCAN/master/simple_split/tasks_test_simple.txt" 11 | }, 12 | "length": { 13 | "train": "https://raw.githubusercontent.com/brendenlake/SCAN/master/length_split/tasks_train_length.txt", 14 | "test": "https://raw.githubusercontent.com/brendenlake/SCAN/master/length_split/tasks_test_length.txt" 15 | }, 16 | "jump": { 17 | "train": "https://raw.githubusercontent.com/brendenlake/SCAN/master/add_prim_split/tasks_train_addprim_jump.txt", 18 | "test": "https://raw.githubusercontent.com/brendenlake/SCAN/master/add_prim_split/tasks_test_addprim_jump.txt" 19 | }, 20 | "turn_left": { 21 | "train": "https://raw.githubusercontent.com/brendenlake/SCAN/master/add_prim_split/tasks_train_addprim_turn_left.txt", 22 | "test": 
"https://raw.githubusercontent.com/brendenlake/SCAN/master/add_prim_split/tasks_test_addprim_turn_left.txt" 23 | }, 24 | } 25 | 26 | def build_cache(self) -> TextDatasetCache: 27 | index_table = {} 28 | in_sentences = [] 29 | out_sentences = [] 30 | 31 | for split_type, split in self.URLS.items(): 32 | index_table[split_type] = {} 33 | 34 | for set, url in split.items(): 35 | fn = os.path.join(self.cache_dir, os.path.split(url)[-1]) 36 | 37 | print("Downloading", url) 38 | download(url, fn, ignore_if_exists=True) 39 | 40 | this_set = [] 41 | index_table[split_type][set] = this_set 42 | 43 | with open(fn) as f: 44 | for line in f: 45 | line = line.split("OUT:") 46 | line[0] = line[0].replace("IN:", "") 47 | line = [l.strip() for l in line] 48 | 49 | in_sentences.append(line[0]) 50 | out_sentences.append(line[1]) 51 | 52 | this_set.append(len(in_sentences) - 1) 53 | 54 | return TextDatasetCache().build(index_table, in_sentences, out_sentences) 55 | -------------------------------------------------------------------------------- /dataset/text/text_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import os 4 | import numpy as np 5 | import framework 6 | from framework.data_structures import WordVocabulary 7 | from typing import List, Dict, Any, Tuple 8 | from ..sequence import TextSequenceTestState 9 | 10 | IndexTable = Dict[str, Dict[str, List[int]]] 11 | VERSION = 5 12 | 13 | 14 | class TextDatasetCache: 15 | version: int 16 | in_sentences: List[List[int]] 17 | out_sentences: List[List[int]] 18 | 19 | index_table: IndexTable 20 | in_vocabulary: WordVocabulary 21 | out_vocabulary: WordVocabulary 22 | 23 | len_histogram: Dict[float, int] 24 | 25 | max_in_len: int 26 | max_out_len: int 27 | 28 | def build(self, index_table: IndexTable, in_sentences: List[str], out_sentences: List[str], 29 | split_punctuation: bool = True): 30 | self.version = VERSION 31 | self.index_table = index_table 32 | 33 | print("Constructing vocabularies") 34 | self.in_vocabulary = WordVocabulary(in_sentences, split_punctuation=split_punctuation) 35 | self.out_vocabulary = WordVocabulary(out_sentences, split_punctuation=split_punctuation) 36 | 37 | self.in_sentences = [self.in_vocabulary(s) for s in in_sentences] 38 | self.out_sentences = [self.out_vocabulary(s) for s in out_sentences] 39 | 40 | print("Calculating length statistics") 41 | counts, bins = np.histogram([len(i)+len(o) for i, o in zip(self.in_sentences, self.out_sentences)]) 42 | self.sum_len_histogram = {k: v for k, v in zip(bins.tolist(), counts.tolist())} 43 | 44 | counts, bins = np.histogram([len(i) for i in self.in_sentences]) 45 | self.in_len_histogram = {k: v for k, v in zip(bins.tolist(), counts.tolist())} 46 | 47 | counts, bins = np.histogram([len(o) for o in self.out_sentences]) 48 | self.out_len_histogram = {k: v for k, v in zip(bins.tolist(), counts.tolist())} 49 | 50 | self.max_in_len = max(len(l) for l in self.in_sentences) 51 | self.max_out_len = max(len(l) for l in self.out_sentences) 52 | print("Done.") 53 | 54 | return self 55 | 56 | def state_dict(self) -> Dict[str, Any]: 57 | return { 58 | "version": self.version, 59 | "index": self.index_table, 60 | "in_sentences": self.in_sentences, 61 | "out_sentences": self.out_sentences, 62 | "in_voc": self.in_vocabulary.state_dict(), 63 | "out_voc": self.out_vocabulary.state_dict(), 64 | "max_in_len": self.max_in_len, 65 | "max_out_len": self.max_out_len, 66 | "in_len_histogram": self.in_len_histogram, 67 | 
"sum_len_histogram": self.sum_len_histogram, 68 | "out_len_histogram": self.out_len_histogram 69 | } 70 | 71 | def load_state_dict(self, data: Dict[str, Any]): 72 | self.version = data.get("version", -1) 73 | if self.version != VERSION: 74 | return 75 | self.index_table = data["index"] 76 | self.in_vocabulary = WordVocabulary(None) 77 | self.out_vocabulary = WordVocabulary(None) 78 | self.in_vocabulary.load_state_dict(data["in_voc"]) 79 | self.out_vocabulary.load_state_dict(data["out_voc"]) 80 | self.in_sentences = data["in_sentences"] 81 | self.out_sentences = data["out_sentences"] 82 | self.max_in_len = data["max_in_len"] 83 | self.max_out_len = data["max_out_len"] 84 | self.in_len_histogram = data["in_len_histogram"] 85 | self.out_len_histogram = data["out_len_histogram"] 86 | self.sum_len_histogram = data["sum_len_histogram"] 87 | 88 | def save(self, fn: str): 89 | torch.save(self.state_dict(), fn) 90 | 91 | @classmethod 92 | def load(cls, fn: str): 93 | res = cls() 94 | try: 95 | data = torch.load(fn) 96 | except: 97 | print(f"Failed to load cache file. {fn}") 98 | res.version = -1 99 | return res 100 | 101 | res.load_state_dict(data) 102 | 103 | return res 104 | 105 | 106 | class TextDataset(torch.utils.data.Dataset): 107 | static_data: Dict[str, TextDatasetCache] = {} 108 | 109 | def build_cache(self) -> TextDatasetCache: 110 | raise NotImplementedError() 111 | 112 | def load_cache_file(self, file) -> TextDatasetCache: 113 | return TextDatasetCache.load(file) 114 | 115 | def _load_dataset(self): 116 | os.makedirs(self.cache_dir, exist_ok=True) 117 | cache_file = os.path.join(self.cache_dir, "cache.pth") 118 | 119 | if os.path.isfile(cache_file): 120 | res = self.load_cache_file(cache_file) 121 | if res.version == VERSION: 122 | return res 123 | else: 124 | print(f"{self.__class__.__name__}: Invalid cache version: {res.version}, current: {VERSION}") 125 | 126 | with framework.utils.LockFile(os.path.join(self.cache_dir, "lock")): 127 | res = self.build_cache() 128 | res.save(cache_file) 129 | return res 130 | 131 | def hist_to_text(self, histogram: Dict[float, int]) -> str: 132 | keys = list(sorted(histogram.keys())) 133 | values = [histogram[k] for k in keys] 134 | percent = (np.cumsum(values) * (100.0 / sum(histogram.values()))).tolist() 135 | return ", ".join(f"{k:.1f}: {v} (>= {p:.1f}%)" for k, v, p in zip(keys, values, percent)) 136 | 137 | def __init__(self, sets: List[str] = ["train"], split_type: List[str] = ["simple"], cache_dir: str = "./cache/", 138 | shared_vocabulary: bool = False): 139 | super().__init__() 140 | 141 | self.cache_dir = os.path.join(cache_dir, self.__class__.__name__) 142 | os.makedirs(self.cache_dir, exist_ok=True) 143 | 144 | assert isinstance(sets, List) 145 | assert isinstance(split_type, List) 146 | 147 | self._cache = TextDataset.static_data.get(self.__class__.__name__) 148 | just_loaded = self._cache is None 149 | if just_loaded: 150 | self._cache = self._load_dataset() 151 | TextDataset.static_data[self.__class__.__name__] = self._cache 152 | 153 | if shared_vocabulary: 154 | self.in_vocabulary = self._cache.in_vocabulary + self._cache.out_vocabulary 155 | self.out_vocabulary = self.in_vocabulary 156 | self.in_remap = self.in_vocabulary.mapfrom(self._cache.in_vocabulary) 157 | self.out_remap = self.out_vocabulary.mapfrom(self._cache.out_vocabulary) 158 | else: 159 | self.in_vocabulary = self._cache.in_vocabulary 160 | self.out_vocabulary = self._cache.out_vocabulary 161 | 162 | if just_loaded: 163 | for k, t in self._cache.index_table.items(): 164 
| print(f"{self.__class__.__name__}: split {k} data:", 165 | ", ".join([f"{k}: {len(v)}" for k, v in t.items()])) 166 | print(f"{self.__class__.__name__}: vocabulary sizes: in: {len(self._cache.in_vocabulary)}, " 167 | f"out: {len(self._cache.out_vocabulary)}") 168 | print(f"{self.__class__.__name__}: max input length: {self._cache.max_in_len}, " 169 | f"max output length: {self._cache.max_out_len}") 170 | print(f"{self.__class__.__name__} sum length histogram: {self.hist_to_text(self._cache.sum_len_histogram)}") 171 | print(f"{self.__class__.__name__} in length histogram: {self.hist_to_text(self._cache.in_len_histogram)}") 172 | print(f"{self.__class__.__name__} out length histogram: {self.hist_to_text(self._cache.out_len_histogram)}") 173 | 174 | self.my_indices = [] 175 | for t in split_type: 176 | for s in sets: 177 | self.my_indices += self._cache.index_table[t][s] 178 | 179 | self.shared_vocabulary = shared_vocabulary 180 | 181 | def get_seqs(self, abs_index: int) -> Tuple[List[int], List[int]]: 182 | in_seq = self._cache.in_sentences[abs_index] 183 | out_seq = self._cache.out_sentences[abs_index] 184 | 185 | if self.shared_vocabulary: 186 | in_seq = [self.in_remap[i] for i in in_seq] 187 | out_seq = [self.out_remap[i] for i in out_seq] 188 | 189 | return in_seq, out_seq 190 | 191 | def __len__(self) -> int: 192 | return len(self.my_indices) 193 | 194 | def __getitem__(self, item: int) -> Dict[str, Any]: 195 | index = self.my_indices[item] 196 | in_seq, out_seq = self.get_seqs(index) 197 | 198 | return { 199 | "in": np.asarray(in_seq, np.int16), 200 | "out": np.asarray(out_seq, np.int16), 201 | "in_len": len(in_seq), 202 | "out_len": len(out_seq) 203 | } 204 | 205 | def get_output_size(self): 206 | return len(self._cache.out_vocabulary) 207 | 208 | def get_input_size(self): 209 | return len(self._cache.in_vocabulary) 210 | 211 | def start_test(self) -> TextSequenceTestState: 212 | return TextSequenceTestState(lambda x: " ".join(self.in_vocabulary(x)), 213 | lambda x: " ".join(self.out_vocabulary(x))) 214 | 215 | @property 216 | def max_in_len(self) -> int: 217 | return self._cache.max_in_len 218 | 219 | @property 220 | def max_out_len(self) -> int: 221 | return self._cache.max_out_len 222 | -------------------------------------------------------------------------------- /dataset/text/typed_text_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any 2 | from .text_dataset import TextDataset, IndexTable, TextDatasetCache 3 | import numpy as np 4 | 5 | 6 | class TypedTextDatasetCache(TextDatasetCache): 7 | def build(self, index_table: IndexTable, in_sentences: List[str], out_sentences: List[str], types: List[int], 8 | type_names: List[str]): 9 | 10 | super().build(index_table, in_sentences, out_sentences) 11 | self.types = types 12 | self.type_names = type_names 13 | return self 14 | 15 | def state_dict(self) -> Dict[str, Any]: 16 | res = super().state_dict() 17 | res["types"] = self.types 18 | res["type_names"] = self.type_names 19 | return res 20 | 21 | def load_state_dict(self, state: Dict[str, Any]): 22 | super().load_state_dict(state) 23 | 24 | self.types = state["types"] 25 | self.type_names = state["type_names"] 26 | 27 | 28 | class TypedTextDataset(TextDataset): 29 | _cache: TypedTextDatasetCache 30 | static_data: Dict[str, TypedTextDatasetCache] = {} 31 | 32 | def load_cache_file(self, file) -> TypedTextDatasetCache: 33 | return TypedTextDatasetCache.load(file) 34 | 35 | def build_cache(self) -> 
TypedTextDatasetCache: 36 | raise NotImplementedError() 37 | 38 | def __init__(self, sets: List[str] = ["train"], cache_dir: str = "./cache/", shared_vocabulary: bool = False): 39 | super().__init__(sets, ["default"], cache_dir, shared_vocabulary) 40 | 41 | def __getitem__(self, item: int) -> Dict[str, Any]: 42 | index = self.my_indices[item] 43 | in_seq, out_seq = self.get_seqs(index) 44 | 45 | return { 46 | "in": np.asarray(in_seq, np.int16), 47 | "out": np.asarray(out_seq, np.int16), 48 | "in_len": len(in_seq), 49 | "out_len": len(out_seq), 50 | "type": self._cache.types[index] 51 | } 52 | -------------------------------------------------------------------------------- /framework/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /framework/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from . import visualize 3 | from . import helpers 4 | from . import loader 5 | from . import data_structures 6 | from . import layers 7 | -------------------------------------------------------------------------------- /framework/data_structures/__init__.py: -------------------------------------------------------------------------------- 1 | from .dotdict import DotDict 2 | from .vocabulary import WordVocabulary, CharVocabulary, ByteVocabulary 3 | -------------------------------------------------------------------------------- /framework/data_structures/dotdict.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, Any, Union 3 | 4 | 5 | class DotDefaultDict(defaultdict): 6 | def __getattr__(self, item): 7 | if item not in self: 8 | raise AttributeError 9 | return self.get(item) 10 | 11 | __setattr__ = defaultdict.__setitem__ 12 | __delattr__ = defaultdict.__delitem__ 13 | 14 | 15 | class DotDict(dict): 16 | def __getattr__(self, item): 17 | if item not in self: 18 | raise AttributeError 19 | return self.get(item) 20 | 21 | __setattr__ = dict.__setitem__ 22 | __delattr__ = dict.__delitem__ 23 | 24 | 25 | def create_recursive_dot_dict(data: Dict[str, Any], cls=DotDict) -> Union[DotDict, DotDefaultDict]: 26 | """ 27 | Takes a dict of string keys and arbitrary values, and creates a tree of DotDicts. 28 | 29 | The keys might contain . in which case child DotDicts are created. 30 | 31 | :param data: Input dict with string keys potentially containing .s. 32 | :param cls: Either DotDict or DotDefaultDict 33 | :return: tree DotDict or DotDefaultDict where the keys are split by . 
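For example (illustrative):
    create_recursive_dot_dict({"a.b": 1, "a.c": 2}) == {"a": {"b": 1, "c": 2}}
where the result additionally supports attribute access, e.g. ``result.a.b == 1``.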
34 | """
35 | res = cls()
36 | for k, v in data.items():
37 | k = k.split(".")
38 | target = res
39 | for i in range(0, len(k)-1):
40 | t2 = target.get(k[i])
41 | if t2 is None:
42 | t2 = cls()
43 | target[k[i]] = t2
44 | 
45 | assert isinstance(t2, cls), f"Trying to overwrite key {'.'.join(k[:i+1])}"
46 | target = t2
47 | 
48 | assert isinstance(target, cls), f"Trying to overwrite key {'.'.join(k)}"
49 | target[k[-1]] = v
50 | return res
51 | 
-------------------------------------------------------------------------------- /framework/data_structures/vocabulary.py: --------------------------------------------------------------------------------
1 | import re
2 | from typing import List, Union, Optional, Dict, Any, Set
3 | 
4 | 
5 | class WordVocabulary:
6 | def __init__(self, list_of_sentences: Optional[List[Union[str, List[str]]]] = None, allow_any_word: bool = False,
7 | split_punctuation: bool = True):
8 | self.words: Dict[str, int] = {}
9 | self.inv_words: Dict[int, str] = {}
10 | self.to_save = ["words", "inv_words", "_unk_index", "allow_any_word", "split_punctuation"]
11 | self.allow_any_word = allow_any_word
12 | self.initialized = False
13 | self.split_punctuation = split_punctuation
14 | 
15 | if list_of_sentences is not None:
16 | words = set()
17 | for s in list_of_sentences:
18 | words |= set(self.split_sentence(s))
19 | 
20 | self._add_set(words)
21 | self.finalize()
22 | 
23 | def finalize(self):
24 | self._unk_index = self.words.get("<UNK>", self.words.get("<unk>"))  # unknown-word token; exact spellings assumed
25 | if self.allow_any_word:
26 | assert self._unk_index is not None
27 | 
28 | self.initialized = True
29 | 
30 | def _add_word(self, w: str):
31 | next_id = len(self.words)
32 | self.words[w] = next_id
33 | self.inv_words[next_id] = w
34 | 
35 | def _add_set(self, words: Set[str]):
36 | for w in sorted(words):
37 | self._add_word(w)
38 | 
39 | def _process_word(self, w: str) -> int:
40 | res = self.words.get(w, self._unk_index)
41 | assert (res != self._unk_index) or self.allow_any_word, f"WARNING: unknown word: '{w}'"
42 | return res
43 | 
44 | def _process_index(self, i: int) -> str:
45 | res = self.inv_words.get(i, None)
46 | if res is None:
47 | return f"<UNK: {i}>"  # placeholder for an out-of-vocabulary index; exact format assumed
48 | return res
49 | 
50 | def __getitem__(self, item: Union[int, str]) -> Union[str, int]:
51 | if isinstance(item, int):
52 | return self._process_index(item)
53 | else:
54 | return self._process_word(item)
55 | 
56 | def split_sentence(self, sentence: Union[str, List[str]]) -> List[str]:
57 | if isinstance(sentence, list):
58 | # Already tokenized.
59 | return sentence 60 | 61 | if self.split_punctuation: 62 | return re.findall(r"\w+|[^\w\s]", sentence, re.UNICODE) 63 | else: 64 | return [x for x in sentence.split(" ") if x] 65 | 66 | def sentence_to_indices(self, sentence: Union[str, List[str]]) -> List[int]: 67 | assert self.initialized 68 | words = self.split_sentence(sentence) 69 | return [self._process_word(w) for w in words] 70 | 71 | def indices_to_sentence(self, indices: List[int]) -> List[str]: 72 | assert self.initialized 73 | return [self._process_index(i) for i in indices] 74 | 75 | def __call__(self, seq: Union[List[Union[str, int]], str]) -> List[Union[int, str]]: 76 | if seq is None or (isinstance(seq, list) and not seq): 77 | return seq 78 | 79 | if isinstance(seq, str) or isinstance(seq[0], str): 80 | return self.sentence_to_indices(seq) 81 | else: 82 | return self.indices_to_sentence(seq) 83 | 84 | def __len__(self) -> int: 85 | return len(self.words) 86 | 87 | def state_dict(self) -> Dict[str, Any]: 88 | return { 89 | k: self.__dict__[k] for k in self.to_save 90 | } 91 | 92 | def load_state_dict(self, state: Dict[str, Any]): 93 | self.initialized = True 94 | self.__dict__.update(state) 95 | 96 | def __add__(self, other): 97 | res = WordVocabulary(allow_any_word=self.allow_any_word and other.allow_any_word, 98 | split_punctuation=self.split_punctuation) 99 | res._add_set(set(self.words.keys()) | set(other.words.keys())) 100 | res.finalize() 101 | return res 102 | 103 | def mapfrom(self, other) -> Dict[int, int]: 104 | return {other.words[w]: i for w, i in self.words.items() if w in other.words} 105 | 106 | 107 | class CharVocabulary: 108 | def __init__(self, chars: Optional[Set[str]]): 109 | self.initialized = False 110 | if chars is not None: 111 | self.from_set(chars) 112 | 113 | def from_set(self, chars: Set[str]): 114 | chars = list(sorted(chars)) 115 | self.to_index = {c: i for i, c in enumerate(chars)} 116 | self.from_index = {i: c for i, c in enumerate(chars)} 117 | self.initialized = True 118 | 119 | def __len__(self): 120 | return len(self.to_index) 121 | 122 | def state_dict(self) -> Dict[str, Any]: 123 | return { 124 | "chars": set(self.to_index.keys()) 125 | } 126 | 127 | def load_state_dict(self, state: Dict[str, Any]): 128 | self.from_set(state["chars"]) 129 | 130 | def str_to_ind(self, data: str) -> List[int]: 131 | return [self.to_index[c] for c in data] 132 | 133 | def ind_to_str(self, data: List[int]) -> str: 134 | return "".join([self.from_index[i] for i in data]) 135 | 136 | def _is_string(self, i): 137 | return isinstance(i, str) 138 | 139 | def __call__(self, seq: Union[List[int], str]) -> Union[List[int], str]: 140 | assert self.initialized 141 | if seq is None or (isinstance(seq, list) and not seq): 142 | return seq 143 | 144 | if self._is_string(seq): 145 | return self.str_to_ind(seq) 146 | else: 147 | return self.ind_to_str(seq) 148 | 149 | def __add__(self, other): 150 | return self.__class__(set(self.to_index.values()) | set(other.to_index.values())) 151 | 152 | 153 | class ByteVocabulary(CharVocabulary): 154 | def ind_to_str(self, data: List[int]) -> bytearray: 155 | return bytearray([self.from_index[i] for i in data]) 156 | 157 | def _is_string(self, i): 158 | return isinstance(i, bytearray) 159 | -------------------------------------------------------------------------------- /framework/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | from .training_helper import TrainingHelper 2 | from .argument_parser import ArgumentParser 3 | 
from .saver import Saver 4 | -------------------------------------------------------------------------------- /framework/helpers/argument_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import re 5 | from ..data_structures.dotdict import create_recursive_dot_dict 6 | 7 | def none_parser(other_parser): 8 | def fn(x): 9 | if x.lower() == "none": 10 | return None 11 | 12 | return other_parser(x) 13 | 14 | return fn 15 | 16 | 17 | class ArgumentParser: 18 | _type = type 19 | 20 | @staticmethod 21 | @none_parser 22 | def int_list_parser(x): 23 | return [int(a) for a in re.split("[,_ ;]", x) if a] 24 | 25 | @staticmethod 26 | @none_parser 27 | def str_list_parser(x): 28 | return x.split(",") 29 | 30 | @staticmethod 31 | @none_parser 32 | def int_or_none_parser(x): 33 | return int(x) 34 | 35 | @staticmethod 36 | @none_parser 37 | def float_or_none_parser(x): 38 | return float(x) 39 | 40 | @staticmethod 41 | @none_parser 42 | def float_list_parser(x): 43 | return [float(a) for a in re.split("[,_ ;]", x) if a] 44 | 45 | @staticmethod 46 | def _merge_args(args, new_args, arg_schemas): 47 | for name, val in new_args.items(): 48 | old = args.get(name) 49 | if old is None: 50 | args[name] = val 51 | else: 52 | args[name] = arg_schemas[name]["updater"](old, val) 53 | 54 | class Profile: 55 | def __init__(self, name, args=None, include=[]): 56 | assert not (args is None and not include), "One of args or include must be defined" 57 | self.name = name 58 | self.args = args 59 | if not isinstance(include, list): 60 | include = [include] 61 | self.include = include 62 | 63 | def get_args(self, arg_schemas, profile_by_name): 64 | res = {} 65 | 66 | for n in self.include: 67 | p = profile_by_name.get(n) 68 | assert p is not None, "Included profile %s doesn't exists" % n 69 | 70 | ArgumentParser._merge_args(res, p.get_args(arg_schemas, profile_by_name), arg_schemas) 71 | 72 | ArgumentParser._merge_args(res, self.args, arg_schemas) 73 | return res 74 | 75 | def __init__(self, description=None, get_train_dir=lambda x: os.path.join("save", x.name)): 76 | self.parser = argparse.ArgumentParser(description=description) 77 | self.profiles = {} 78 | self.args = {} 79 | self.raw = None 80 | self.parsed = None 81 | self.get_train_dir = get_train_dir 82 | self.parser.add_argument("-profile", "--profile", type=str, help="Pre-defined profiles.") 83 | 84 | def add_argument(self, name, type=None, default=None, help="", save=True, parser=lambda x: x, 85 | updater=lambda old, new: new, choice=[]): 86 | assert name not in ["profile"], "Argument name %s is reserved" % name 87 | assert not (type is None and default is None), "Either type or default must be given" 88 | 89 | if type is None: 90 | type = ArgumentParser._type(default) 91 | 92 | self.parser.add_argument(name, "-" + name, type=int if type == bool else type, default=None, help=help) 93 | if name[0] == '-': 94 | name = name[1:] 95 | 96 | self.args[name] = { 97 | "type": type, 98 | "default": int(default) if type == bool else default, 99 | "save": save, 100 | "parser": parser, 101 | "updater": updater, 102 | "choice": choice 103 | } 104 | 105 | def add_profile(self, prof): 106 | if isinstance(prof, list): 107 | for p in prof: 108 | self.add_profile(p) 109 | else: 110 | self.profiles[prof.name] = prof 111 | 112 | def do_parse_args(self, loaded={}): 113 | self.raw = self.parser.parse_args() 114 | 115 | profile = {} 116 | if self.raw.profile: 117 | if loaded: 118 | if 
self.raw.profile != loaded.get("profile"): 119 | assert False, "Loading arguments from file, but a different profile is given." 120 | else: 121 | for pr in self.raw.profile.split(","): 122 | p = self.profiles.get(pr) 123 | assert p is not None, "Invalid profile: %s. Valid profiles: %s" % (pr, self.profiles.keys()) 124 | p = p.get_args(self.args, self.profiles) 125 | self._merge_args(profile, p, self.args) 126 | 127 | for k, v in self.raw.__dict__.items(): 128 | if k in ["profile"]: 129 | continue 130 | 131 | if v is None: 132 | if k in loaded and self.args[k]["save"]: 133 | self.raw.__dict__[k] = loaded[k] 134 | else: 135 | self.raw.__dict__[k] = profile.get(k, self.args[k]["default"]) 136 | 137 | for k, v in self.raw.__dict__.items(): 138 | if k not in self.args: 139 | continue 140 | c = self.args[k]["choice"] 141 | if c and not v in c: 142 | assert False, f"Invalid value {v}. Allowed: {c}" 143 | 144 | self.parsed = create_recursive_dot_dict({k: self.args[k]["parser"](self.args[k]["type"](v)) if v is not None 145 | else None for k, v in self.raw.__dict__.items() if k in self.args}) 146 | 147 | return self.parsed 148 | 149 | def parse_or_cache(self): 150 | if self.parsed is None: 151 | self.do_parse_args() 152 | 153 | def parse(self): 154 | self.parse_or_cache() 155 | return self.parsed 156 | 157 | def to_dict(self): 158 | self.parse_or_cache() 159 | return self.raw.__dict__ 160 | 161 | def clone(self): 162 | parser = ArgumentParser() 163 | parser.profiles = self.profiles 164 | parser.args = self.args 165 | for name, a in self.args.items(): 166 | parser.parser.add_argument("-" + name, type=int if a["type"] == bool else a["type"], default=None) 167 | parser.parse() 168 | return parser 169 | 170 | def from_dict(self, dict): 171 | return self.do_parse_args(dict) 172 | 173 | def save(self, fname): 174 | with open(fname, 'w') as outfile: 175 | json.dump(self.to_dict(), outfile, indent=4) 176 | return True 177 | 178 | def load(self, fname): 179 | if os.path.isfile(fname): 180 | with open(fname, "r") as data_file: 181 | map = json.load(data_file) 182 | 183 | self.from_dict(map) 184 | return self.parsed 185 | 186 | def sync(self, fname=None): 187 | if fname is None: 188 | fname = self._get_save_filename() 189 | 190 | if fname is not None: 191 | if os.path.isfile(fname): 192 | self.load(fname) 193 | 194 | dir = os.path.dirname(fname) 195 | os.makedirs(dir, exist_ok=True) 196 | 197 | self.save(fname) 198 | return self.parsed 199 | 200 | def _get_save_filename(self, opt=None): 201 | opt = self.parse() if opt is None else opt 202 | dir = self.get_train_dir(opt) 203 | return None if dir is None else os.path.join(dir, "args.json") 204 | 205 | def parse_and_sync(self): 206 | opt = self.parse() 207 | return self.sync(self._get_save_filename(opt)) 208 | 209 | def parse_and_try_load(self): 210 | fname = self._get_save_filename() 211 | if fname and os.path.isfile(fname): 212 | self.load(fname) 213 | 214 | return self.parsed 215 | -------------------------------------------------------------------------------- /framework/helpers/saver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import time 4 | from typing import Optional, List 5 | from collections import defaultdict 6 | 7 | 8 | class SaverElement: 9 | def save(self): 10 | raise NotImplementedError() 11 | 12 | def load(self, saved_state): 13 | raise NotImplementedError() 14 | 15 | 16 | class PyObjectSaver(SaverElement): 17 | def __init__(self, obj): 18 | self._obj = obj 19 | 20 | def 
load(self, state): 21 | def _load(target, state): 22 | if hasattr(target, "load_state_dict"): 23 | target.load_state_dict(state) 24 | elif isinstance(target, dict): 25 | for k, v in state.items(): 26 | target[k] = _load(target.get(k), v) 27 | elif isinstance(target, list): 28 | if len(target) != len(state): 29 | target.clear() 30 | for v in state: 31 | target.append(v) 32 | else: 33 | for i, v in enumerate(state): 34 | target[i] = _load(target[i], v) 35 | else: 36 | return state 37 | return target 38 | 39 | _load(self._obj, state) 40 | return True 41 | 42 | def save(self): 43 | def _save(target): 44 | if isinstance(target, (defaultdict, dict)): 45 | res = target.__class__() 46 | res.update({k: _save(v) for k, v in target.items()}) 47 | elif hasattr(target, "state_dict"): 48 | res = target.state_dict() 49 | elif isinstance(target, list): 50 | res = [_save(v) for v in target] 51 | else: 52 | res = target 53 | 54 | return res 55 | 56 | return _save(self._obj) 57 | 58 | 59 | class Saver: 60 | def __init__(self, dir: str, short_interval: int, keep_every_n_hours: Optional[int] = 4, keep_last: int = 1): 61 | self.savers = {} 62 | self.short_interval = short_interval 63 | self.dir = dir 64 | assert keep_last >= 1 65 | self.keep_last = keep_last 66 | self._keep_every_n_seconds = keep_every_n_hours * 3600 if keep_every_n_hours else None 67 | 68 | def register(self, name: str, saver, replace: bool = False): 69 | if not replace: 70 | assert name not in self.savers, "Saver %s already registered" % name 71 | 72 | if isinstance(saver, SaverElement): 73 | self.savers[name] = saver 74 | else: 75 | self.savers[name] = PyObjectSaver(saver) 76 | 77 | def __setitem__(self, key: str, value): 78 | if value is not None: 79 | self.register(key, value) 80 | 81 | def save(self, fname: Optional[str] = None, dir: Optional[str] = None, iter: Optional[int]=None): 82 | state = {} 83 | 84 | if fname is None: 85 | assert iter is not None, "If fname is not given, iter should be." 86 | if dir is None: 87 | dir = self.dir 88 | fname = os.path.join(dir, self.model_name_from_index(iter)) 89 | 90 | dname = os.path.dirname(fname) 91 | if dname: 92 | os.makedirs(dname, exist_ok=True) 93 | 94 | print("Saving %s" % fname) 95 | for name, fns in self.savers.items(): 96 | state[name] = fns.save() 97 | 98 | try: 99 | torch.save(state, fname) 100 | except: 101 | print("WARNING: Save failed. 
Maybe running out of disk space?") 102 | try: 103 | os.remove(fname) 104 | except: 105 | pass 106 | return None 107 | 108 | return fname 109 | 110 | def tick(self, iter: int): 111 | if self.short_interval is None or iter % self.short_interval != 0: 112 | return 113 | 114 | self.save(iter=iter) 115 | self.cleanup() 116 | 117 | @staticmethod 118 | def model_name_from_index(index: int) -> str: 119 | return f"model-{index}.pth" 120 | 121 | @staticmethod 122 | def get_checkpoint_index_list(dir: str) -> List[int]: 123 | return list(reversed(sorted( 124 | [int(fn.split(".")[0].split("-")[-1]) for fn in os.listdir(dir) if fn.split(".")[-1] == "pth"]))) 125 | 126 | def get_ckpts_in_time_window(self, dir: str, index_list: Optional[List[int]]=None): 127 | if index_list is None: 128 | index_list = Saver.get_checkpoint_index_list(dir) 129 | 130 | names = [Saver.model_name_from_index(i) for i in index_list] 131 | if self._keep_every_n_seconds is None: 132 | return names 133 | 134 | now = time.time() 135 | 136 | res = [] 137 | for name in names: 138 | mtime = os.path.getmtime(os.path.join(dir, name)) 139 | if now - mtime > self._keep_every_n_seconds: 140 | break 141 | 142 | res.append(name) 143 | 144 | return res 145 | 146 | @staticmethod 147 | def do_load(fname): 148 | return torch.load(fname) 149 | 150 | def load_last_checkpoint(self) -> Optional[Any]: 151 | if not os.path.isdir(self.dir): 152 | return None 153 | 154 | last_checkpoint = Saver.get_checkpoint_index_list(self.dir) 155 | 156 | if last_checkpoint: 157 | for index in last_checkpoint: 158 | fname = Saver.model_name_from_index(index) 159 | try: 160 | data = self.do_load(os.path.join(self.dir, fname)) 161 | except: 162 | continue 163 | return data 164 | return None 165 | 166 | def cleanup(self): 167 | index_list = self.get_checkpoint_index_list(self.dir) 168 | new_files = self.get_ckpts_in_time_window(self.dir, index_list[self.keep_last:]) 169 | new_files = new_files[:-1] if self._keep_every_n_seconds is not None else new_files 170 | 171 | for f in new_files: 172 | os.remove(os.path.join(self.dir, f)) 173 | 174 | def load_data(self, state) -> bool: 175 | if not state: 176 | return False 177 | 178 | for k, s in state.items(): 179 | if k not in self.savers: 180 | print("WARNING: failed to load state of %s. It doesn't exist." 
% k) 181 | continue 182 | 183 | print(f"Loading {k}") 184 | if not self.savers[k].load(s): 185 | print(f"Failed to load {k}") 186 | return False 187 | 188 | return True 189 | 190 | def load(self, fname=None) -> bool: 191 | if fname is None: 192 | state = self.load_last_checkpoint() 193 | else: 194 | state = self.do_load(fname) 195 | 196 | return self.load_data(state) 197 | -------------------------------------------------------------------------------- /framework/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .positional_encoding import PositionalEncoding, sinusoidal_pos_embedding 2 | from .cross_entropy_label_smoothing import cross_entropy 3 | -------------------------------------------------------------------------------- /framework/layers/cross_entropy_label_smoothing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Optional 4 | import math 5 | 6 | 7 | def cross_entropy(input: torch.Tensor, target: torch.Tensor, reduction: str = "mean", smoothing: float = 0, 8 | ignore_index: Optional[int] = None) -> torch.Tensor: 9 | 10 | # Flatten inputs to 2D 11 | t2 = target.flatten().long() 12 | i2 = input.flatten(end_dim=-2) 13 | 14 | # If no smoothing, use built-in cross_entropy loss 15 | if smoothing == 0: 16 | loss = F.cross_entropy(i2, t2, reduction=reduction, ignore_index=-100 if ignore_index is None else ignore_index) 17 | if reduction == "none": 18 | return loss.view_as(target) 19 | else: 20 | return loss 21 | 22 | # Calculate the softmax cross entropy loss 23 | i2 = F.log_softmax(i2, -1) 24 | right_class = i2.gather(-1, t2.unsqueeze(-1)).squeeze(-1) 25 | others = i2.sum(-1) - right_class 26 | 27 | # KL divergence 28 | loss = (smoothing - 1.0) * right_class - others * smoothing 29 | optimal_loss = -((1.0 - smoothing) * math.log(1 - smoothing) + (i2.shape[1] - 1) * smoothing * math.log(smoothing)) 30 | 31 | loss = loss - optimal_loss 32 | 33 | # Handle masking if ignore_index is specified 34 | if ignore_index is not None: 35 | tmask = t2 != ignore_index 36 | loss = torch.where(tmask, loss, torch.zeros([1], dtype=loss.dtype, device=loss.device)) 37 | n_total = tmask.float().sum() 38 | else: 39 | n_total = t2.nelement() 40 | 41 | # Reduction 42 | if reduction == "none": 43 | return loss.view_as(target) 44 | elif reduction == "mean": 45 | return loss.sum() / n_total 46 | elif reduction == "sum": 47 | return loss.sum() 48 | else: 49 | assert False, f"Invalid reduction {reduction}" 50 | -------------------------------------------------------------------------------- /framework/layers/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import math 4 | from typing import Optional 5 | 6 | 7 | def sinusoidal_pos_embedding(d_model: int, max_len: int = 5000, pos_offset: int = 0, 8 | device: Optional[torch.device] = None): 9 | pe = torch.zeros(max_len, d_model, device=device) 10 | position = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1) + pos_offset 11 | div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float, device=device) * (-math.log(10000.0) / d_model)) 12 | pe[:, 0::2] = torch.sin(position * div_term) 13 | pe[:, 1::2] = torch.cos(position * div_term) 14 | return pe 15 | 16 | 17 | class PositionalEncoding(torch.nn.Module): 18 | r"""Inject some information about the relative or absolute position of 
the tokens 19 | in the sequence. The positional encodings have the same dimension as 20 | the embeddings, so that the two can be summed. Here, we use sine and cosine 21 | functions of different frequencies. 22 | .. math:: 23 | \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d_{model}}) 24 | \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d_{model}}) 25 | \text{where pos is the word position and i is the embed idx} 26 | Args: 27 | d_model: the embed dim (required). 28 | dropout: the dropout value (default=0.1). 29 | max_len: the max. length of the incoming sequence (default=5000). 30 | batch_first: if true, the batch dimension is the first; if not, it's the 2nd. 31 | Examples: 32 | >>> pos_encoder = PositionalEncoding(d_model) 33 | """ 34 | 35 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, batch_first: bool = False, 36 | scale: float = 1): 37 | super(PositionalEncoding, self).__init__() 38 | self.dropout = torch.nn.Dropout(p=dropout) 39 | 40 | pe = sinusoidal_pos_embedding(d_model, max_len, 0) * scale 41 | 42 | self.batch_dim = 0 if batch_first else 1 43 | pe = pe.unsqueeze(self.batch_dim) 44 | 45 | self.register_buffer('pe', pe) 46 | 47 | def get(self, n: int, offset: int) -> torch.Tensor: 48 | return self.pe.narrow(1 - self.batch_dim, start=offset, length=n) 49 | 50 | def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor: 51 | x = x + self.get(x.size(1 - self.batch_dim), offset) 52 | return self.dropout(x) 53 | -------------------------------------------------------------------------------- /framework/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from . import collate 2 | from . import sampler 3 | -------------------------------------------------------------------------------- /framework/loader/collate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from operator import mul 4 | from functools import reduce 5 | from typing import List 6 | import warnings 7 | 8 | 9 | class VarLengthCollate: 10 | def __init__(self, ignore_symbol=0, batch_dim: int = 1): 11 | self.ignore_symbol = ignore_symbol 12 | self.batch_dim = batch_dim 13 | 14 | @staticmethod 15 | def _measure_array_max_dim(batch: List[torch.Tensor]): 16 | s=list(batch[0].size()) 17 | different=[False] * len(s) 18 | for i in range(1, len(batch)): 19 | ns = batch[i].size() 20 | different = [different[j] or s[j]!=ns[j] for j in range(len(s))] 21 | s=[max(s[j], ns[j]) for j in range(len(s))] 22 | return s, different 23 | 24 | def _merge_var_len_array(self, batch: List[torch.Tensor]): 25 | max_size, different = self._measure_array_max_dim(batch) 26 | s=max_size[:self.batch_dim] + [len(batch)] + max_size[self.batch_dim:] 27 | storage = batch[0].storage()._new_shared(reduce(mul, s, 1)) 28 | out = batch[0].new(storage).view(s).fill_(self.ignore_symbol if self.ignore_symbol is not None else 0) 29 | for i, d in enumerate(batch): 30 | bdim = self.batch_dim if len(out.shape)>self.batch_dim else 0 31 | this_o = out.narrow(bdim, i, 1).squeeze(bdim) 32 | for j, diff in enumerate(different): 33 | if diff: 34 | this_o = this_o.narrow(j, 0, d.size(j)) 35 | 36 | this_o.copy_(d) 37 | return out 38 | 39 | def __call__(self, batch): 40 | if isinstance(batch[0], dict): 41 | return {k: self([b[k] for b in batch]) for k in batch[0].keys()} 42 | elif isinstance(batch[0], np.ndarray): 43 | with warnings.catch_warnings(): 44 | # If the source data is mmapped from a file, from_numpy 
will throw a warning that it is readonly. 45 | # However it does not matter, since all batches will be merged anyway, which copies the data. 46 | warnings.filterwarnings("ignore", category=UserWarning) 47 | return self([torch.from_numpy(a) for a in batch]) 48 | elif torch.is_tensor(batch[0]): 49 | return self._merge_var_len_array(batch) 50 | elif isinstance(batch[0], (int, float)): 51 | return torch.Tensor(batch) 52 | else: 53 | assert False, "Unknown type: %s" % type(batch[0]) 54 | -------------------------------------------------------------------------------- /framework/loader/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from .. import utils 4 | import numpy as np 5 | from typing import Dict, Any 6 | 7 | 8 | class InfiniteSampler(torch.utils.data.Sampler): 9 | def __init__(self, data_source: torch.utils.data.Dataset, replacement=True, seed=None): 10 | super().__init__(data_source) 11 | self.data_source = data_source 12 | self.replacement = replacement 13 | self.seed = utils.seed.get_randstate(seed) 14 | 15 | def __iter__(self): 16 | n = len(self.data_source) 17 | if self.replacement: 18 | while True: 19 | yield self.seed.randint(0, n, dtype=np.int64) 20 | else: 21 | i_list = None 22 | pos = n 23 | while True: 24 | if pos >= n: 25 | i_list = self.seed.permutation(n).tolist() 26 | pos = 0 27 | 28 | sample = i_list[pos] 29 | pos += 1 30 | yield sample 31 | 32 | def __len__(self): 33 | return 0x7FFFFFFF 34 | -------------------------------------------------------------------------------- /framework/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .lockfile import LockFile 2 | from .gpu_allocator import use_gpu 3 | from . import universal as U 4 | from . import port 5 | from . import process 6 | from . 
import seed 7 | from .average import Average, MovingAverage 8 | from .download import download 9 | from .time_meter import ElapsedTimeMeter 10 | from .parallel_map import parallel_map 11 | from .set_lr import set_lr -------------------------------------------------------------------------------- /framework/utils/average.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Any, Dict 3 | 4 | 5 | class Average: 6 | SAVE = ["sum", "cnt"] 7 | 8 | def __init__(self): 9 | self.reset() 10 | 11 | def add(self, data: Union[int, float, torch.Tensor]): 12 | if torch.is_tensor(data): 13 | data = data.detach() 14 | 15 | self.sum += data 16 | self.cnt += 1 17 | 18 | def reset(self): 19 | self.sum = 0 20 | self.cnt = 0 21 | 22 | def get(self, reset=True) -> Union[float, torch.Tensor]: 23 | res = self.sum / self.cnt 24 | if reset: 25 | self.reset() 26 | 27 | return res 28 | 29 | def state_dict(self) -> Dict[str, Any]: 30 | return {k: self.__dict__[k] for k in self.SAVE} 31 | 32 | def load_state_dict(self, state: Dict[str, Any]): 33 | self.__dict__.update(state or {}) 34 | 35 | 36 | class MovingAverage(Average): 37 | SAVE = ["sum", "cnt", "history"] 38 | 39 | def __init__(self, window_size: int): 40 | self.window_size = window_size 41 | super().__init__() 42 | 43 | def reset(self): 44 | self.history = [] 45 | super().reset() 46 | 47 | def add(self, data: Union[int, float, torch.Tensor]): 48 | super().add(data) 49 | self.history.append(data.detach() if torch.is_tensor(data) else data) 50 | if self.cnt > self.window_size: 51 | self.sum -= self.history.pop(0) 52 | self.cnt -= 1 53 | 54 | assert self.cnt <= self.window_size 55 | -------------------------------------------------------------------------------- /framework/utils/download.py: -------------------------------------------------------------------------------- 1 | import requests, tarfile, io, os, zipfile, gzip 2 | from typing import Optional 3 | 4 | from io import BytesIO, SEEK_SET, SEEK_END 5 | 6 | 7 | class UrlStream: 8 | def __init__(self, url): 9 | self._url = url 10 | headers = requests.head(url, headers={"Accept-Encoding": "identity"}).headers 11 | headers = {k.lower(): v for k, v in headers.items()} 12 | self._seek_supported = headers.get('accept-ranges') == 'bytes' and 'content-length' in headers 13 | if self._seek_supported: 14 | self._size = int(headers['content-length']) 15 | self._curr_pos = 0 16 | self._buf_start_pos = 0 17 | self._iter = None 18 | self._buffer = None 19 | self._buf_size = 0 20 | self._loaded_all = False 21 | 22 | def _load_all(self): 23 | if self._loaded_all: 24 | return 25 | self._make_request() 26 | old_buf_pos = self._buffer.tell() 27 | self._buffer.seek(0, SEEK_END) 28 | for chunk in self._iter: 29 | self._buffer.write(chunk) 30 | self._buf_size = self._buffer.tell() 31 | self._buffer.seek(old_buf_pos, SEEK_SET) 32 | self._loaded_all = True 33 | 34 | def seekable(self): 35 | return self._seek_supported 36 | 37 | def seek(self, position, whence=SEEK_SET): 38 | if whence == SEEK_END: 39 | assert position <= 0 40 | if self._seek_supported: 41 | self.seek(self._size + position) 42 | else: 43 | self._load_all() 44 | self._buffer.seek(position, SEEK_END) 45 | self._curr_pos = self._buffer.tell() 46 | elif whence == SEEK_SET: 47 | if self._curr_pos != position: 48 | self._curr_pos = position 49 | if self._seek_supported: 50 | self._iter = None 51 | self._buffer = None 52 | else: 53 | self._load_until(position) 54 | self._buffer.seek(position) 55 | self._curr_pos = position 56 | else: 57 | assert False, "Invalid whence %s" % 
whence 58 | 59 | return self.tell() 60 | 61 | def tell(self): 62 | return self._curr_pos 63 | 64 | def _load_until(self, goal_position): 65 | self._make_request() 66 | old_buf_pos = self._buffer.tell() 67 | current_position = self._buffer.seek(0, SEEK_END) 68 | 69 | goal_position = goal_position - self._buf_start_pos 70 | while current_position < goal_position: 71 | try: 72 | d = next(self._iter) 73 | self._buffer.write(d) 74 | current_position += len(d) 75 | except StopIteration: 76 | break 77 | self._buf_size = current_position 78 | self._buffer.seek(old_buf_pos, SEEK_SET) 79 | 80 | def _new_buffer(self): 81 | remaining = self._buffer.read() if self._buffer is not None else None 82 | self._buffer = BytesIO() 83 | if remaining is not None: 84 | self._buffer.write(remaining) 85 | self._buf_start_pos = self._curr_pos 86 | self._buf_size = 0 if remaining is None else len(remaining) 87 | self._buffer.seek(0, SEEK_SET) 88 | self._loaded_all = False 89 | 90 | def _make_request(self): 91 | if self._iter is None: 92 | h = { 93 | "User-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/47.0.2526.80 Safari/537.36", 94 | } 95 | if self._seek_supported: 96 | h["Range"] = "bytes=%d-%d" % (self._curr_pos, self._size - 1) 97 | 98 | r = requests.get(self._url, headers=h, stream=True) 99 | 100 | self._iter = r.iter_content(1024 * 1024) 101 | self._new_buffer() 102 | elif self._seek_supported and self._buf_size > 128 * 1024 * 1024: 103 | self._new_buffer() 104 | 105 | def size(self): 106 | if self._seek_supported: 107 | return self._size 108 | else: 109 | self._load_all() 110 | return self._buf_size 111 | 112 | def read(self, size=None): 113 | if size is None: 114 | size = self.size() 115 | 116 | self._load_until(self._curr_pos + size) 117 | if self._seek_supported: 118 | self._curr_pos = min(self._curr_pos + size, self._size) 119 | 120 | read_data = self._buffer.read(size) 121 | if not self._seek_supported: 122 | self._curr_pos += len(read_data) 123 | return read_data 124 | 125 | def iter_content(self, block_size): 126 | while True: 127 | d = self.read(block_size) 128 | if not len(d): 129 | break 130 | yield d 131 | 132 | 133 | def download(url: str, dest: Optional[str] = None, extract: bool=True, ignore_if_exists: bool = False, 134 | compression: Optional[str] = None): 135 | """ 136 | Download a file from the internet. 137 | 138 | Args: 139 | url: the url to download 140 | dest: destination file if extract=False, or destination dir if extract=True. If None, it will be the last part of URL. 141 | extract: extract a tar.gz or zip file? 142 | ignore_if_exists: don't do anything if file exists 143 | 144 | Returns: 145 | the destination filename. 
146 | """ 147 | 148 | base_url = url.split("?")[0] 149 | 150 | if dest is None: 151 | dest = [f for f in base_url.split("/") if f][-1] 152 | 153 | if os.path.exists(dest) and ignore_if_exists: 154 | return dest 155 | 156 | stream = UrlStream(url) 157 | extension = base_url.split(".")[-1].lower() 158 | 159 | if extract and extension in ['gz', 'bz2', 'zip', 'tgz', 'tar']: 160 | os.makedirs(dest, exist_ok=True) 161 | 162 | if extension == "gz" and not base_url.endswith(".tar.gz"): 163 | decompressed_file = gzip.GzipFile(fileobj=stream) 164 | with open(os.path.join(dest, url.split("/")[-1][:-3]), 'wb') as f: 165 | while True: 166 | d = decompressed_file.read(1024 * 1024) 167 | if not d: 168 | break 169 | f.write(d) 170 | else: 171 | if extension in ['gz', 'bz2', "tgz", "tar"]: 172 | decompressed_file = tarfile.open(fileobj=stream, mode='r|' + 173 | (compression or ( 174 | "gz" if extension == "tgz" else extension))) 175 | elif extension == 'zip': 176 | decompressed_file = zipfile.ZipFile(stream, mode='r') 177 | else: 178 | assert False, "Invalid extension: %s" % extension 179 | 180 | decompressed_file.extractall(dest) 181 | else: 182 | try: 183 | with open(dest, 'wb') as f: 184 | for d in stream.iter_content(1024 * 1024): 185 | f.write(d) 186 | except: 187 | os.remove(dest) 188 | raise 189 | return dest 190 | -------------------------------------------------------------------------------- /framework/utils/gpu_allocator.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import torch 4 | from ..utils.lockfile import LockFile 5 | from typing import List, Dict, Optional 6 | 7 | gpu_fake_usage = [] 8 | 9 | 10 | def get_memory_usage() -> Optional[Dict[int, int]]: 11 | try: 12 | proc = subprocess.Popen("nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits".split(" "), 13 | stdout=subprocess.PIPE) 14 | lines = [s.strip().split(" ") for s in proc.communicate()[0].decode().split("\n") if s] 15 | return {int(g[0][:-1]): int(g[1]) for g in lines} 16 | except: 17 | return None 18 | 19 | 20 | def get_free_gpus() -> Optional[List[int]]: 21 | try: 22 | free = [] 23 | proc = subprocess.Popen("nvidia-smi --query-compute-apps=gpu_uuid --format=csv,noheader,nounits".split(" "), 24 | stdout=subprocess.PIPE) 25 | uuids = [s.strip() for s in proc.communicate()[0].decode().split("\n") if s] 26 | 27 | proc = subprocess.Popen("nvidia-smi --query-gpu=index,uuid --format=csv,noheader,nounits".split(" "), 28 | stdout=subprocess.PIPE) 29 | 30 | id_uid_pair = [s.strip().split(", ") for s in proc.communicate()[0].decode().split("\n") if s] 31 | for i in id_uid_pair: 32 | id, uid = i 33 | 34 | if uid not in uuids: 35 | free.append(int(id)) 36 | 37 | return free 38 | except: 39 | return None 40 | 41 | 42 | def _fix_order(): 43 | os.environ["CUDA_DEVICE_ORDER"] = os.environ.get("CUDA_DEVICE_ORDER", "PCI_BUS_ID") 44 | 45 | 46 | def _create_gpu_usage(n_gpus: int): 47 | global gpu_fake_usage 48 | 49 | for i in range(n_gpus): 50 | a = torch.FloatTensor([0.0]) 51 | a.cuda(i) 52 | gpu_fake_usage.append(a) 53 | 54 | 55 | def allocate(n:int = 1): 56 | _fix_order() 57 | with LockFile("/tmp/gpu_allocation_lock"): 58 | if "CUDA_VISIBLE_DEVICES" in os.environ: 59 | print("WARNING: trying to allocate %d GPUs, but CUDA_VISIBLE_DEVICES already set to %s" % 60 | (n, os.environ["CUDA_VISIBLE_DEVICES"])) 61 | return 62 | 63 | allocated = get_free_gpus() 64 | if allocated is None: 65 | print("WARNING: failed to allocate %d GPUs" % n) 66 | return 67 | 
allocated = allocated[:n] 68 | 69 | if len(allocated) < n: 70 | print("There are no more free GPUs. Allocating the one with the least memory usage.") 71 | usage = get_memory_usage() 72 | if usage is None: 73 | print("WARNING: failed to allocate %d GPUs" % n) 74 | return 75 | 76 | inv_usages = {} 77 | 78 | for k, v in usage.items(): 79 | if v not in inv_usages: 80 | inv_usages[v] = [] 81 | 82 | inv_usages[v].append(k) 83 | 84 | min_usage = list(sorted(inv_usages.keys())) 85 | min_usage_devs = [] 86 | for u in min_usage: 87 | min_usage_devs += inv_usages[u] 88 | 89 | min_usage_devs = [m for m in min_usage_devs if m not in allocated] 90 | 91 | n2 = n - len(allocated) 92 | if n2>len(min_usage_devs): 93 | print("WARNING: trying to allocate %d GPUs but only %d available" % (n, len(min_usage_devs)+len(allocated))) 94 | n2 = len(min_usage_devs) 95 | 96 | allocated += min_usage_devs[:n2] 97 | 98 | os.environ["CUDA_VISIBLE_DEVICES"]=",".join([str(a) for a in allocated]) 99 | _create_gpu_usage(len(allocated)) 100 | 101 | 102 | def use_gpu(gpu:str = "auto", n_autoalloc: int = 1): 103 | _fix_order() 104 | 105 | gpu = gpu.lower() 106 | if gpu in ["auto", ""]: 107 | allocate(n_autoalloc) 108 | elif gpu == "none": 109 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 110 | else: 111 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu 112 | _create_gpu_usage(len(gpu.split(","))) 113 | 114 | return len([g for g in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if g]) 115 | -------------------------------------------------------------------------------- /framework/utils/lockfile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fcntl 3 | import time 4 | import errno 5 | 6 | 7 | class LockFile: 8 | def __init__(self, fname: str): 9 | self._fname = fname 10 | self._fd = None 11 | 12 | def acquire(self): 13 | self._fd=open(self._fname, "w") 14 | try: 15 | os.chmod(self._fname, 0o777) 16 | except PermissionError: 17 | # If another user created it already, we don't have the permission to change the access rights. 18 | # But it can be ignored because the creator already set it right. 
19 | pass 20 | 21 | while True: 22 | try: 23 | fcntl.flock(self._fd, fcntl.LOCK_EX | fcntl.LOCK_NB) 24 | break 25 | except IOError as e: 26 | if e.errno != errno.EAGAIN: 27 | raise 28 | else: 29 | time.sleep(0.1) 30 | 31 | def release(self): 32 | fcntl.flock(self._fd, fcntl.LOCK_UN) 33 | self._fd.close() 34 | self._fd = None 35 | 36 | def __enter__(self): 37 | self.acquire() 38 | 39 | def __exit__(self, exc_type, exc_val, exc_tb): 40 | self.release() 41 | -------------------------------------------------------------------------------- /framework/utils/parallel_map.py: -------------------------------------------------------------------------------- 1 | from torch import multiprocessing 2 | from typing import Iterable, Callable, Any, List 3 | import time 4 | 5 | 6 | def parallel_map(tasks: Iterable, callback: Callable[[Any], Any], max_parallel: int = 32) -> List: 7 | limit = min(multiprocessing.cpu_count(), max_parallel) 8 | processes: List[multiprocessing.Process] = [] 9 | queues: List[multiprocessing.Queue] = [] 10 | indices: List[int] = [] 11 | tlist = [t for t in tasks] 12 | res = [None] * len(tlist) 13 | curr = 0 14 | 15 | def process_return(q, arg): 16 | r = callback(arg) 17 | q.put(r) 18 | 19 | 20 | while curr < len(tlist): 21 | if len(processes) == limit: 22 | ended = [] 23 | for i, q in enumerate(queues): 24 | if not q.empty(): 25 | processes[i].join() 26 | ended.append(i) 27 | res[indices[i]] = q.get() 28 | 29 | for i in sorted(ended, reverse=True): 30 | processes.pop(i) 31 | queues.pop(i) 32 | indices.pop(i) 33 | 34 | if not ended: 35 | time.sleep(0.1) 36 | continue 37 | 38 | queues.append(multiprocessing.Queue()) 39 | indices.append(curr) 40 | processes.append(multiprocessing.Process(target=process_return, args=(queues[-1], tlist[curr]))) 41 | processes[-1].start() 42 | 43 | curr += 1 44 | 45 | for i, p in enumerate(processes): 46 | res[indices[i]] = queues[i].get() 47 | p.join() 48 | 49 | return res -------------------------------------------------------------------------------- /framework/utils/port.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import time 3 | 4 | 5 | def check_used(port: int) -> bool: 6 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 7 | result = sock.connect_ex(('127.0.0.1', port)) 8 | if result == 0: 9 | sock.close() 10 | return True 11 | else: 12 | return False 13 | 14 | 15 | def alloc(start_from: int = 7000) -> int: 16 | while True: 17 | if check_used(start_from): 18 | print("Port already used: %d" % start_from) 19 | start_from += 1 20 | else: 21 | return start_from 22 | 23 | 24 | def wait_for(port: int, timeout: int = 5) -> bool: 25 | start_time = time.time() 26 | while not check_used(port): 27 | if time.time() - start_time > timeout: 28 | return False 29 | 30 | time.sleep(0.1) 31 | return True 32 | -------------------------------------------------------------------------------- /framework/utils/process.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ctypes 3 | import subprocess 4 | import os 5 | 6 | 7 | def run(cmd: str, hide_stderr: bool = True, stdout_mode: str = "print"): 8 | libc_search_dirs = ["/lib", "/lib/x86_64-linux-gnu", "/lib/powerpc64le-linux-gnu"] 9 | 10 | if sys.platform == "linux": 11 | found = None 12 | for d in libc_search_dirs: 13 | file = os.path.join(d, "libc.so.6") 14 | if os.path.isfile(file): 15 | found = file 16 | break 17 | 18 | if not found: 19 | print("WARNING: Cannot find libc.so.6. 
Cannot kill process when parent dies.") 20 | killer = None 21 | else: 22 | libc = ctypes.CDLL(found) 23 | PR_SET_PDEATHSIG = 1 24 | KILL = 9 25 | killer = lambda: libc.prctl(PR_SET_PDEATHSIG, KILL) 26 | else: 27 | print("WARNING: OS not linux. Cannot kill process when parent dies.") 28 | killer = None 29 | 30 | if hide_stderr: 31 | stderr = open(os.devnull,'w') 32 | else: 33 | stderr = None 34 | 35 | if stdout_mode == "hide": 36 | stdout = open(os.devnull, 'w') 37 | elif stdout_mode == "print": 38 | stdout = None 39 | elif stdout_mode == "pipe": 40 | stdout = subprocess.PIPE 41 | else: 42 | assert False, "Invalid stdout mode: %s" % stdout_mode 43 | 44 | return subprocess.Popen(cmd.split(" "), stderr=stderr, stdout=stdout, preexec_fn=killer) 45 | -------------------------------------------------------------------------------- /framework/utils/seed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torch.cuda 4 | import torch.backends.cudnn 5 | import random 6 | import numpy as np 7 | from typing import Optional 8 | 9 | 10 | def fix(offset: int = 0, fix_cudnn: bool = True): 11 | random.seed(0x12345678 + offset) 12 | torch.manual_seed(0x0DABA52 + offset) 13 | torch.cuda.manual_seed(0x0DABA52 + 1 + offset) 14 | np.random.seed(0xC1CAFA52 + offset) 15 | 16 | if fix_cudnn: 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | 20 | 21 | def get_randstate(seed: Optional[int] = None) -> np.random.RandomState: 22 | if seed is None: 23 | worker_info = torch.utils.data.get_worker_info() 24 | if worker_info is not None: 25 | seed = worker_info.seed 26 | else: 27 | seed = random.randint(0, 0x7FFFFFFF) 28 | 29 | return np.random.RandomState(seed) 30 | -------------------------------------------------------------------------------- /framework/utils/set_lr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def set_lr(optim: torch.optim.Optimizer, lr: float): 5 | for param_group in optim.param_groups: 6 | param_group['lr'] = lr 7 | -------------------------------------------------------------------------------- /framework/utils/time_meter.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class ElapsedTimeMeter: 5 | def __init__(self): 6 | self.reset() 7 | 8 | def start(self): 9 | self.start_time = time.time() 10 | 11 | def _curr_timer(self) -> float: 12 | if self.start_time is None: 13 | return 0 14 | 15 | return time.time() - self.start_time 16 | 17 | def stop(self): 18 | self.sum += self._curr_timer() 19 | self.start_time = None 20 | 21 | def get(self, reset=False) -> float: 22 | res = self.sum + self._curr_timer() 23 | if reset: 24 | self.reset() 25 | return res 26 | 27 | def reset(self): 28 | self.start_time = None 29 | self.sum = 0 30 | 31 | def __enter__(self): 32 | assert self.start_time is None 33 | self.start() 34 | 35 | def __exit__(self, *args): 36 | self.stop() 37 | -------------------------------------------------------------------------------- /framework/utils/universal.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | 4 | 5 | def apply_recursive(d, fn: Callable, filter: Callable = None): 6 | if isinstance(d, list): 7 | return [apply_recursive(da, fn, filter) for da in d] 8 | elif isinstance(d, tuple): 9 | return tuple(apply_recursive(list(d), fn, 
filter)) 10 | elif isinstance(d, dict): 11 | return {k: apply_recursive(v, fn, filter) for k, v in d.items()} 12 | else: 13 | if filter is None or filter(d): 14 | return fn(d) 15 | else: 16 | return d 17 | 18 | 19 | def apply_to_tensors(d, fn: Callable): 20 | return apply_recursive(d, fn, torch.is_tensor) -------------------------------------------------------------------------------- /framework/visualize/__init__.py: -------------------------------------------------------------------------------- 1 | from . import plot 2 | from . import tensorboard -------------------------------------------------------------------------------- /framework/visualize/tensorboard.py: -------------------------------------------------------------------------------- 1 | from .. import utils 2 | from ..utils import process 3 | import os 4 | import atexit 5 | from typing import Optional 6 | import shutil 7 | 8 | port: Optional[int] = None 9 | tb_process = None 10 | 11 | 12 | def start(log_dir: str, on_port: Optional[int] = None): 13 | global port 14 | 15 | global tb_process 16 | if tb_process is not None: 17 | return 18 | 19 | port = utils.port.alloc() if on_port is None else on_port 20 | 21 | command = shutil.which("tensorboard") 22 | if command is None: 23 | command = os.path.expanduser("~/.local/bin/tensorboard") 24 | 25 | if os.path.isfile(command): 26 | print("Found tensorboard in", command) 27 | else: 28 | assert False, "Tensorboard not found." 29 | 30 | extra_flags = "" 31 | version = process.run("%s --version" % command, hide_stderr=True, stdout_mode="pipe").communicate()[0].decode() 32 | if int(version[0])>1: 33 | extra_flags = "--bind_all" 34 | 35 | print("Starting Tensorboard server on %d" % port) 36 | tb_process = process.run("%s --port %d --logdir %s %s" % (command, port, log_dir, extra_flags), hide_stderr=True, 37 | stdout_mode="hide") 38 | if not utils.port.wait_for(port): 39 | print("ERROR: failed to start Tensorboard server. 
Server not responding.") 40 | return 41 | print("Done.") 42 | 43 | def kill_tb(): 44 | if tb_process is None: 45 | return 46 | 47 | tb_process.kill() 48 | 49 | atexit.register(kill_tb) 50 | -------------------------------------------------------------------------------- /interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | from .result import Result 2 | from .model_interface import ModelInterface 3 | from .transformer import TransformerEncDecInterface 4 | -------------------------------------------------------------------------------- /interfaces/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from .result import Result 4 | from typing import List, Optional 5 | 6 | 7 | @dataclass 8 | class EncoderDecoderResult(Result): 9 | outputs: torch.Tensor 10 | out_lengths: torch.Tensor 11 | loss: torch.Tensor 12 | 13 | batch_dim = 1 14 | 15 | @staticmethod 16 | def merge(l: List, batch_weights: Optional[List[float]] = None): 17 | if len(l) == 1: 18 | return l[0] 19 | batch_weights = batch_weights if batch_weights is not None else [1] * len(l) 20 | loss = sum([r.loss * w for r, w in zip(l, batch_weights)]) / sum(batch_weights) 21 | out = torch.stack([r.outputs for r in l], l[0].batch_dim) 22 | lens = torch.stack([r.out_lengths for r in l], 0) 23 | return l[0].__class__(out, lens, loss) 24 | -------------------------------------------------------------------------------- /interfaces/model_interface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, Any 3 | from .result import Result 4 | 5 | 6 | class ModelInterface: 7 | def create_input(self, data: Dict[str, torch.Tensor]) -> torch.Tensor: 8 | raise NotImplementedError 9 | 10 | def decode_outputs(self, outputs: Result) -> Any: 11 | raise NotImplementedError 12 | 13 | def __call__(self, data: Dict[str, torch.Tensor]) -> Result: 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /interfaces/result.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, Any, List, Optional 3 | 4 | 5 | class Result: 6 | outputs: torch.Tensor 7 | loss: torch.Tensor 8 | 9 | batch_dim = 0 10 | 11 | def plot(self) -> Dict[str, Any]: 12 | return {} 13 | 14 | @property 15 | def batch_size(self) -> int: 16 | return self.outputs.shape[self.batch_dim] 17 | 18 | @staticmethod 19 | def merge(l: List, batch_weights: Optional[List[float]] = None): 20 | if len(l) == 1: 21 | return l[0] 22 | batch_weights = batch_weights if batch_weights is not None else [1] * len(l) 23 | loss = sum([r.loss * w for r, w in zip(l, batch_weights)]) / sum(batch_weights) 24 | out = torch.stack([r.outputs for r in l], l[0].batch_dim) 25 | return l[0].__class__(out, loss) 26 | -------------------------------------------------------------------------------- /interfaces/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder_decoder_interface import TransformerEncDecInterface 2 | -------------------------------------------------------------------------------- /interfaces/transformer/encoder_decoder_interface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | from typing import Dict, Tuple 4 
| from models.encoder_decoder import add_eos 5 | from models.transformer_enc_dec import TransformerResult 6 | from ..model_interface import ModelInterface 7 | import framework 8 | 9 | from ..encoder_decoder import EncoderDecoderResult 10 | 11 | 12 | class TransformerEncDecInterface(ModelInterface): 13 | def __init__(self, model: torch.nn.Module, label_smoothing: float = 0.0): 14 | self.model = model 15 | self.label_smoothing = label_smoothing 16 | 17 | def loss(self, outputs: TransformerResult, ref: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 18 | l = framework.layers.cross_entropy(outputs.data, ref, reduction='none', smoothing=self.label_smoothing) 19 | l = l.reshape_as(ref) * mask 20 | l = l.sum() / mask.sum() 21 | return l 22 | 23 | def decode_outputs(self, outputs: EncoderDecoderResult) -> Tuple[torch.Tensor, torch.Tensor]: 24 | return outputs.outputs, outputs.out_lengths 25 | 26 | def __call__(self, data: Dict[str, torch.Tensor], train_eos: bool = True) -> EncoderDecoderResult: 27 | in_len = data["in_len"].long() 28 | out_len = data["out_len"].long() 29 | in_with_eos = add_eos(data["in"], data["in_len"], self.model.encoder_eos) 30 | out_with_eos = add_eos(data["out"], data["out_len"], self.model.decoder_sos_eos) 31 | in_len += 1 32 | out_len += 1 33 | 34 | res = self.model(in_with_eos.transpose(0, 1), in_len, out_with_eos.transpose(0, 1), 35 | out_len, teacher_forcing=self.model.training, max_len=out_len.max().item()) 36 | 37 | res.data = res.data.transpose(0, 1) 38 | len_mask = ~self.model.generate_len_mask(out_with_eos.shape[0], out_len if train_eos else (out_len - 1)).\ 39 | transpose(0, 1) 40 | 41 | loss = self.loss(res, out_with_eos, len_mask) 42 | return EncoderDecoderResult(res.data, res.length, loss) 43 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer 2 | from .tied_embedding import TiedEmbedding 3 | -------------------------------------------------------------------------------- /layers/tied_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class TiedEmbedding(torch.nn.Module): 7 | def __init__(self, weights: torch.Tensor): 8 | super().__init__() 9 | 10 | # Hack: won't save it as a parameter 11 | self.w = [weights] 12 | self.bias = torch.nn.Parameter(torch.zeros(self.w[0].shape[0])) 13 | 14 | def forward(self, t: torch.Tensor) -> torch.Tensor: 15 | return F.linear(t, self.w[0], self.bias) 16 | -------------------------------------------------------------------------------- /layers/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import Transformer 2 | from .relative_transformer import RelativeTransformer 3 | from .universal_transformer import UniversalTransformer 4 | from .universal_relative_transformer import UniversalRelativeTransformer 5 | -------------------------------------------------------------------------------- /layers/transformer/multi_head_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | import math 5 | from typing import Optional, Callable, List, Union, Tuple 6 | from dataclasses import dataclass 7 | 8 | 9 | @dataclass 10 | class AttentionMask: 11 | 
src_length_mask: Optional[torch.Tensor] 12 | position_mask: Optional[torch.Tensor] 13 | 14 | 15 | class MultiHeadAttentionBase(torch.nn.Module): 16 | def __init__(self, state_size: int, n_heads: int, dropout: float = 0.1): 17 | assert state_size % n_heads == 0 18 | super().__init__() 19 | self.state_size = state_size 20 | self.projection_size = state_size // n_heads 21 | self.n_heads = n_heads 22 | self.scale = 1.0 / math.sqrt(self.projection_size) 23 | 24 | self.dropout = torch.nn.Dropout(dropout) 25 | self.multi_head_merge = torch.nn.Linear(n_heads * self.projection_size, state_size, bias=False) 26 | 27 | def _masked_softmax(self, logits: torch.Tensor, mask: Optional[AttentionMask]) -> torch.Tensor: 28 | if mask is None or (mask.src_length_mask is None and mask.position_mask is None): 29 | return F.softmax(logits, -1) 30 | 31 | # Output shape: [n_batch * n_heads, n_time_dest, n_time_src] 32 | bb, n_time_dest, n_time_src = logits.shape 33 | 34 | logits = logits.view(bb // self.n_heads, self.n_heads, n_time_dest, n_time_src) 35 | 36 | if mask.position_mask is not None: 37 | logits = logits.masked_fill(mask.position_mask.unsqueeze(0).unsqueeze(0), float("-inf")) 38 | 39 | if mask.src_length_mask is not None: 40 | logits = logits.masked_fill(mask.src_length_mask.unsqueeze(1).unsqueeze(1), float("-inf")) 41 | 42 | logits = F.softmax(logits, -1) 43 | return logits.view(bb, n_time_dest, n_time_src) 44 | 45 | def _attention_read(self, mask: Optional[AttentionMask], logits: torch.Tensor, v: torch.Tensor) -> \ 46 | Tuple[torch.Tensor, torch.Tensor]: 47 | # logits: [n_batch * n_heads, n_out, n_in] 48 | # v: [n_batch * n_heads, n_in, projection_size] 49 | # Output data shape [n_batch * n_heads, n_time_dest, data_size] 50 | # Out attention score shape: [n_batch, n_heads, n_time_dest, n_time_src] 51 | scores = self._masked_softmax(logits * self.scale, mask) 52 | scores = self.dropout(scores) 53 | return torch.bmm(scores, v), scores.view(-1, self.n_heads, *scores.shape[1:]) 54 | 55 | def merged_attention(self, n_batch: int, n_out_steps: int, *args, need_weights: bool = False, **kwargs) -> \ 56 | Tuple[torch.Tensor, torch.Tensor]: 57 | 58 | data, scores = self._attention(*args, **kwargs) 59 | 60 | data = data.view(n_batch, self.n_heads, n_out_steps, -1).permute(0, 2, 1, 3).contiguous().\ 61 | view(n_batch, n_out_steps, -1) 62 | 63 | return self.multi_head_merge(data), scores 64 | 65 | def transform_data(self, input: torch.Tensor, proj: Callable[[torch.Tensor], torch.Tensor], 66 | n_projs: int) -> List[torch.Tensor]: 67 | # Input shape: [n_batch, n_steps, n_channels] 68 | # Output: Tuple of n_projs tensors of dimension: [n_batch * n_heads, n_steps, projection_size] 69 | n_batch, n_steps, _ = input.shape 70 | transformed = proj(input).view(n_batch, n_steps, self.n_heads, n_projs, self.projection_size). 
\ 71 | permute(0, 2, 1, 3, 4).contiguous().view(n_batch * self.n_heads, n_steps, n_projs, self.projection_size) 72 | return transformed.unbind(dim=2) 73 | 74 | def reset_parameters(self): 75 | torch.nn.init.xavier_uniform_(self.multi_head_merge.weight) 76 | 77 | 78 | class AbsPosAttentionBase(MultiHeadAttentionBase): 79 | def _attention(self, mask: Optional[torch.Tensor], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> \ 80 | Tuple[torch.Tensor, torch.Tensor]: 81 | # all inputs should have a shape of [n_batch, n_steps, data_size] 82 | # Output shape [n_batch * n_heads, n_time_dest, data_size] 83 | return self._attention_read(mask, torch.bmm(q, k.transpose(1,2)), v) 84 | 85 | 86 | class MultiHeadAttention(AbsPosAttentionBase): 87 | def __init__(self, state_size: int, n_heads: int, dropout: float=0.1, input_size: Optional[int]=None): 88 | super().__init__(state_size, n_heads, dropout) 89 | self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False) 90 | self.data_to_q = torch.nn.Linear(state_size if input_size is None else input_size, 91 | n_heads * self.projection_size, bias=False) 92 | self.reset_parameters() 93 | 94 | def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask], 95 | need_weights: bool = False): 96 | # Input and output shape: [n_batch, n_steps, data_size] 97 | k, v = self.transform_data(attend_to, self.data_to_kv, 2) 98 | q, = self.transform_data(curr_state, self.data_to_q, 1) 99 | 100 | data, scores = self.merged_attention(curr_state.shape[0], q.shape[1], mask, q, k, v) 101 | if need_weights: 102 | # Calculate the mean over the heads 103 | return data, scores.mean(1) 104 | else: 105 | return data 106 | 107 | def reset_parameters(self): 108 | super().reset_parameters() 109 | 110 | torch.nn.init.xavier_uniform_(self.data_to_q.weight) 111 | torch.nn.init.xavier_uniform_(self.data_to_kv.weight[:self.data_to_kv.weight.shape[0]//2]) 112 | torch.nn.init.xavier_uniform_(self.data_to_kv.weight[self.data_to_kv.weight.shape[0]//2:]) 113 | -------------------------------------------------------------------------------- /layers/transformer/multi_head_relative_pos_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | from typing import Optional, Tuple 5 | from .multi_head_attention import AttentionMask, MultiHeadAttentionBase 6 | import framework 7 | import math 8 | 9 | 10 | class RelativeAttentionBase(MultiHeadAttentionBase): 11 | def __init__(self, state_size: int, n_heads: int, dropout: float): 12 | super().__init__(state_size, n_heads, dropout=dropout) 13 | 14 | def _shift(self, posmat: torch.Tensor) -> torch.Tensor: 15 | # Slice out a matrix diagonally. Each successive row is sliced one position to the left compared to the previous one. 
16 | # shape: [n_batch, n_head, n_out, n_in * 2 - 1] 17 | # return: [n_batch, n_head, n_out, n_in] 18 | p = F.pad(posmat, (0, 1, 0, 1)).flatten(-2) # [n_batch, n_head, (n_out + 1) * n_in * 2] 19 | p = p.narrow(-1, posmat.shape[-1] // 2, posmat.shape[-1] * posmat.shape[-2]).view_as(posmat) 20 | 21 | return p.narrow(-1, 0, (posmat.shape[-1] + 1) // 2) 22 | 23 | def _attention(self, mask: Optional[torch.Tensor], 24 | q_content: torch.Tensor, k_content: torch.Tensor, 25 | q_pos: torch.Tensor, k_pos: torch.Tensor, 26 | v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 27 | 28 | # shape of q_content, k_content, q_pos: [n_batch * n_heads, n_steps, data_size] 29 | # k_pos: [n_heads, n_in * 2 - 1, data_size] 30 | # Output shape [n_batch * n_heads, n_out, data_size] 31 | 32 | n_batch = q_content.shape[0] // self.n_heads 33 | n_out_steps = q_content.shape[1] 34 | 35 | # content-content addressing 36 | content = torch.bmm(q_content, k_content.transpose(1, 2)) 37 | 38 | # content-pos addressing. 39 | pos = torch.matmul(q_pos.view(n_batch, self.n_heads, n_out_steps, -1), k_pos.transpose(-1, -2)) # [n_batch, n_head, n_out, n_in * 2 - 1] 40 | pos = self._shift(pos).flatten(0, 1) 41 | 42 | # Logits shape: [n_batch * n_heads, n_out, n_in] 43 | return self._attention_read(mask, content + pos, v) 44 | 45 | def _get_pos_subset(self, pos_encoding: torch.Tensor, length: int, offset: int) -> torch.Tensor: 46 | l_slice = 2 * length - 1 47 | assert pos_encoding.shape[0] > l_slice 48 | return pos_encoding.narrow(0, pos_encoding.shape[0] // 2 - length + 1 - offset, 2 * length - 1) 49 | 50 | 51 | class FixedRelativeMultiheadAttention(RelativeAttentionBase): 52 | def __init__(self, state_size: int, n_heads: int, dropout: float = 0.0, global_pos_bias: bool = True, 53 | global_content_bias: bool = True, input_size: Optional[int] = None): 54 | super().__init__(state_size, n_heads, dropout) 55 | 56 | self.data_to_kv = torch.nn.Linear(state_size, 2 * n_heads * self.projection_size, bias=False) 57 | self.data_to_q = torch.nn.Linear(state_size if input_size is None else input_size, 58 | n_heads * self.projection_size, bias=False) 59 | 60 | self.global_content_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size])) \ 61 | if global_content_bias else None 62 | self.global_pos_bias = torch.nn.Parameter(torch.zeros([n_heads, self.projection_size])) \ 63 | if global_pos_bias else None 64 | 65 | self.pos_to_pq = torch.nn.Linear(state_size, self.n_heads * self.projection_size, bias=False) 66 | self.register_buffer("pos_encoding", self._create_buffer(1000)) 67 | 68 | def _create_buffer(self, max_len: int): 69 | return framework.layers.sinusoidal_pos_embedding(self.state_size, 2 * max_len - 1, -max_len + 1, 70 | device=self.data_to_q.weight.device) 71 | 72 | def get_pos(self, l: int, offset: int) -> torch.Tensor: 73 | if self.pos_encoding.shape[0] < 2 * (l + offset) - 1: 74 | self.pos_encoding = self._create_buffer(int(2**math.ceil(math.log2(2 * (l + offset) - 1)))) 75 | 76 | return self._get_pos_subset(self.pos_encoding, l, offset) 77 | 78 | def add_head_specific_bias(self, data: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: 79 | # data [batch * n_heads, len, c] 80 | # bias [n_heads, c] 81 | return (data.view(-1, bias.shape[0], *data.shape[1:]) + bias.unsqueeze(1).type_as(data)).view_as(data) \ 82 | if bias is not None else data 83 | 84 | def forward(self, curr_state: torch.Tensor, attend_to: torch.Tensor, mask: Optional[AttentionMask], 85 | pos_offset: int = 0, need_weights: bool = False): 86 | # curr_state: [batch_size, 
out_len, c] 87 | # attend_to: [batch_size, in_len, c] 88 | batch_size, in_len = attend_to.shape[0:2] 89 | out_len = curr_state.shape[1] 90 | 91 | k_content, v = self.transform_data(attend_to, self.data_to_kv, 2) 92 | q, = self.transform_data(curr_state, self.data_to_q, 1) 93 | 94 | k_pos = self.pos_to_pq(self.get_pos(in_len, pos_offset)).view(-1, self.n_heads, self.projection_size).\ 95 | transpose(0, 1) # n_heads, 2*in_len -1 , projection_size 96 | 97 | q_content = self.add_head_specific_bias(q, self.global_content_bias) 98 | q_pos = self.add_head_specific_bias(q, self.global_pos_bias) 99 | 100 | data, scores = self.merged_attention(batch_size, out_len, mask, q_content, k_content, q_pos, k_pos, v, 101 | need_weights=need_weights) 102 | 103 | if need_weights: 104 | # Calculate the mean over the heads 105 | return data, scores.mean(1) 106 | else: 107 | return data 108 | 109 | def reset_parameters(self): 110 | super().reset_parameters() 111 | 112 | torch.nn.init.xavier_uniform_(self.data_to_q.weight) 113 | torch.nn.init.xavier_uniform_(self.pos_to_pq.weight) 114 | torch.nn.init.xavier_uniform_(self.data_to_kv.weight[:self.data_to_kv.weight.shape[0]//2]) 115 | torch.nn.init.xavier_uniform_(self.data_to_kv.weight[self.data_to_kv.weight.shape[0]//2:]) 116 | 117 | if self.global_content_bias is not None: 118 | self.global_content_bias.fill_(0) 119 | 120 | if self.global_pos_bias is not None: 121 | self.global_pos_bias.fill_(0) 122 | -------------------------------------------------------------------------------- /layers/transformer/relative_transformer.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional 3 | import torch 4 | import torch.nn 5 | import torch.nn.functional as F 6 | from .transformer import ActivationFunction 7 | from .multi_head_relative_pos_attention import FixedRelativeMultiheadAttention, AttentionMask 8 | from .multi_head_attention import MultiHeadAttention 9 | from .transformer import Transformer, TransformerEncoderWithLayer, TransformerDecoderWithLayer 10 | 11 | 12 | class RelativeTransformerEncoderLayer(torch.nn.Module): 13 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation: ActivationFunction = F.relu): 14 | super().__init__() 15 | self.self_attn = FixedRelativeMultiheadAttention(d_model, nhead, dropout=dropout) 16 | self.linear1 = torch.nn.Linear(d_model, dim_feedforward) 17 | self.dropout = torch.nn.Dropout(dropout) 18 | self.linear2 = torch.nn.Linear(dim_feedforward, d_model) 19 | 20 | self.norm1 = torch.nn.LayerNorm(d_model) 21 | self.norm2 = torch.nn.LayerNorm(d_model) 22 | self.dropout1 = torch.nn.Dropout(dropout) 23 | self.dropout2 = torch.nn.Dropout(dropout) 24 | 25 | self.activation = activation 26 | self.reset_parameters() 27 | 28 | def forward(self, src: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: 29 | src2 = self.self_attn(src, src, AttentionMask(mask, None)) 30 | src = src + self.dropout1(src2) 31 | src = self.norm1(src) 32 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 33 | src = src + self.dropout2(src2) 34 | src = self.norm2(src) 35 | return src 36 | 37 | def reset_parameters(self): 38 | torch.nn.init.xavier_uniform_(self.linear1.weight, gain=torch.nn.init.calculate_gain('relu') 39 | if self.activation is F.relu else 1.0) 40 | torch.nn.init.xavier_uniform_(self.linear2.weight) 41 | 42 | 43 | class RelativeTransformerDecoderLayer(torch.nn.Module): 44 | def __init__(self, d_model, nhead, dim_feedforward=2048, 
dropout=0.1, activation: ActivationFunction = F.relu): 45 | super().__init__() 46 | 47 | self.self_attn = FixedRelativeMultiheadAttention(d_model, nhead, dropout=dropout) 48 | self.multihead_attn = MultiHeadAttention(d_model, nhead, dropout=dropout) 49 | # Implementation of Feedforward model 50 | self.linear1 = torch.nn.Linear(d_model, dim_feedforward) 51 | self.dropout = torch.nn.Dropout(dropout) 52 | self.linear2 = torch.nn.Linear(dim_feedforward, d_model) 53 | 54 | self.norm1 = torch.nn.LayerNorm(d_model) 55 | self.norm2 = torch.nn.LayerNorm(d_model) 56 | self.norm3 = torch.nn.LayerNorm(d_model) 57 | self.dropout1 = torch.nn.Dropout(dropout) 58 | self.dropout2 = torch.nn.Dropout(dropout) 59 | self.dropout3 = torch.nn.Dropout(dropout) 60 | 61 | self.activation = activation 62 | self.reset_parameters() 63 | 64 | def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, 65 | memory_key_padding_mask: Optional[torch.Tensor] = None, 66 | full_target: Optional[torch.Tensor] = None, pos_offset: int = 0) -> torch.Tensor: 67 | assert pos_offset == 0 or tgt_mask is None 68 | tgt2 = self.self_attn(tgt, tgt if full_target is None else full_target, mask=AttentionMask(None, tgt_mask), 69 | pos_offset=pos_offset) 70 | tgt = tgt + self.dropout1(tgt2) 71 | tgt = self.norm1(tgt) 72 | tgt2 = self.multihead_attn(tgt, memory, mask=AttentionMask(memory_key_padding_mask, None)) 73 | tgt = tgt + self.dropout2(tgt2) 74 | tgt = self.norm2(tgt) 75 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 76 | tgt = tgt + self.dropout3(tgt2) 77 | tgt = self.norm3(tgt) 78 | return tgt 79 | 80 | def reset_parameters(self): 81 | torch.nn.init.xavier_uniform_(self.linear1.weight, gain=torch.nn.init.calculate_gain('relu') 82 | if self.activation is F.relu else 1.0) 83 | torch.nn.init.xavier_uniform_(self.linear2.weight) 84 | 85 | 86 | class RelativeTransformer(Transformer): 87 | def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, 88 | num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, 89 | activation: ActivationFunction = F.relu): 90 | 91 | super().__init__(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout, activation, 92 | TransformerEncoderWithLayer(RelativeTransformerEncoderLayer), 93 | TransformerDecoderWithLayer(RelativeTransformerDecoderLayer)) 94 | -------------------------------------------------------------------------------- /layers/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | from .multi_head_attention import MultiHeadAttention, AttentionMask 5 | from typing import Optional, Callable, Dict 6 | from dataclasses import dataclass 7 | # This file is based on PyTorch's internal implementation 8 | 9 | ActivationFunction = Callable[[torch.Tensor], torch.Tensor] 10 | 11 | 12 | class TransformerEncoderLayer(torch.nn.Module): 13 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation: ActivationFunction = F.relu): 14 | super(TransformerEncoderLayer, self).__init__() 15 | self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout) 16 | self.linear1 = torch.nn.Linear(d_model, dim_feedforward) 17 | self.dropout = torch.nn.Dropout(dropout) 18 | self.linear2 = torch.nn.Linear(dim_feedforward, d_model) 19 | 20 | self.norm1 = torch.nn.LayerNorm(d_model) 21 | self.norm2 = torch.nn.LayerNorm(d_model) 22 | 
self.dropout1 = torch.nn.Dropout(dropout) 23 | self.dropout2 = torch.nn.Dropout(dropout) 24 | 25 | self.activation = activation 26 | self.reset_parameters() 27 | 28 | def forward(self, src: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: 29 | src2 = self.self_attn(src, src, AttentionMask(mask, None)) 30 | src = src + self.dropout1(src2) 31 | src = self.norm1(src) 32 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 33 | src = src + self.dropout2(src2) 34 | src = self.norm2(src) 35 | return src 36 | 37 | def reset_parameters(self): 38 | torch.nn.init.xavier_uniform_(self.linear1.weight, gain=torch.nn.init.calculate_gain('relu') 39 | if self.activation is F.relu else 1.0) 40 | torch.nn.init.xavier_uniform_(self.linear2.weight) 41 | 42 | 43 | class TransformerDecoderLayer(torch.nn.Module): 44 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation: ActivationFunction = F.relu): 45 | super(TransformerDecoderLayer, self).__init__() 46 | 47 | self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout) 48 | self.multihead_attn = MultiHeadAttention(d_model, nhead, dropout=dropout) 49 | self.linear1 = torch.nn.Linear(d_model, dim_feedforward) 50 | self.dropout = torch.nn.Dropout(dropout) 51 | self.linear2 = torch.nn.Linear(dim_feedforward, d_model) 52 | 53 | self.norm1 = torch.nn.LayerNorm(d_model) 54 | self.norm2 = torch.nn.LayerNorm(d_model) 55 | self.norm3 = torch.nn.LayerNorm(d_model) 56 | self.dropout1 = torch.nn.Dropout(dropout) 57 | self.dropout2 = torch.nn.Dropout(dropout) 58 | self.dropout3 = torch.nn.Dropout(dropout) 59 | 60 | self.activation = activation 61 | self.reset_parameters() 62 | 63 | def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, 64 | memory_key_padding_mask: Optional[torch.Tensor] = None, 65 | full_target: Optional[torch.Tensor] = None, pos_offset: int = 0) -> torch.Tensor: 66 | 67 | assert pos_offset == 0 or tgt_mask is None 68 | tgt2 = self.self_attn(tgt, tgt if full_target is None else full_target, mask=AttentionMask(None, tgt_mask)) 69 | tgt = tgt + self.dropout1(tgt2) 70 | tgt = self.norm1(tgt) 71 | tgt2 = self.multihead_attn(tgt, memory, mask=AttentionMask(memory_key_padding_mask, None)) 72 | tgt = tgt + self.dropout2(tgt2) 73 | tgt = self.norm2(tgt) 74 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 75 | tgt = tgt + self.dropout3(tgt2) 76 | tgt = self.norm3(tgt) 77 | return tgt 78 | 79 | def reset_parameters(self): 80 | torch.nn.init.xavier_uniform_(self.linear1.weight, gain=torch.nn.init.calculate_gain('relu') 81 | if self.activation is F.relu else 1.0) 82 | torch.nn.init.xavier_uniform_(self.linear2.weight) 83 | 84 | 85 | class TransformerDecoderBase(torch.nn.Module): 86 | @dataclass 87 | class State: 88 | step: int 89 | state: Dict[int, torch.Tensor] 90 | 91 | def __init__(self, d_model: int): 92 | super().__init__() 93 | self.d_model = d_model 94 | 95 | def create_state(self, batch_size: int, max_length: int, device: torch.device) -> State: 96 | return self.State(0, {i: torch.empty([batch_size, max_length, self.d_model], device=device) 97 | for i in range(len(self.layers))}) 98 | 99 | def one_step_forward(self, state: State, data: torch.Tensor, *args, **kwargs): 100 | assert data.shape[1] == 1, f"One-step forward expects a single timestep, but shape is {data.shape}" 101 | assert state.step < state.state[0].shape[1] 102 | 103 | for i, l in enumerate(self.layers): 104 | state.state[i][:, state.step:state.step +
1] = data 105 | data = l(data, *args, **kwargs, full_target=state.state[i][:, :state.step + 1], 106 | pos_offset=state.step) 107 | 108 | state.step += 1 109 | return data 110 | 111 | 112 | class TransformerEncoder(torch.nn.Module): 113 | def __init__(self, layer, n_layers: int, *args, **kwargs): 114 | super().__init__() 115 | self.layers = torch.nn.ModuleList([layer(*args, **kwargs) for _ in range(n_layers)]) 116 | 117 | def forward(self, data: torch.Tensor, *args, **kwargs): 118 | for l in self.layers: 119 | data = l(data, *args, **kwargs) 120 | return data 121 | 122 | 123 | class TransformerDecoder(TransformerDecoderBase): 124 | def __init__(self, layer, n_layers: int, d_model: int, *args, **kwargs): 125 | super().__init__(d_model) 126 | self.layers = torch.nn.ModuleList([layer(d_model, *args, **kwargs) for _ in range(n_layers)]) 127 | 128 | def forward(self, data: torch.Tensor, *args, **kwargs): 129 | for l in self.layers: 130 | data = l(data, *args, **kwargs) 131 | return data 132 | 133 | 134 | def TransformerEncoderWithLayer(layer=TransformerEncoderLayer): 135 | return lambda *args, **kwargs: TransformerEncoder(layer, *args, **kwargs) 136 | 137 | 138 | def TransformerDecoderWithLayer(layer=TransformerDecoderLayer): 139 | return lambda *args, **kwargs: TransformerDecoder(layer, *args, **kwargs) 140 | 141 | 142 | class Transformer(torch.nn.Module): 143 | def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, 144 | num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, 145 | activation: ActivationFunction = F.relu, encoder_layer=TransformerEncoderWithLayer(), 146 | decoder_layer=TransformerDecoderWithLayer()): 147 | super().__init__() 148 | 149 | self.encoder = encoder_layer(num_encoder_layers, d_model, nhead, dim_feedforward, 150 | dropout, activation) 151 | self.decoder = decoder_layer(num_decoder_layers, d_model, nhead, dim_feedforward, 152 | dropout, activation) 153 | 154 | def forward(self, src: torch.Tensor, tgt: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, 155 | src_length_mask: Optional[torch.Tensor] = None): 156 | 157 | memory = self.encoder(src, src_length_mask) 158 | return self.decoder(tgt, memory, tgt_mask, src_length_mask) 159 | 160 | def generate_square_subsequent_mask(self, sz: int, device: torch.device) -> torch.Tensor: 161 | return torch.triu(torch.ones(sz, sz, dtype=torch.bool, device=device), diagonal=1) 162 | -------------------------------------------------------------------------------- /layers/transformer/universal_relative_transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from .transformer import Transformer, ActivationFunction 3 | from .universal_transformer import UniversalTransformerDecoderWithLayer, UniversalTransformerEncoderWithLayer 4 | from .relative_transformer import RelativeTransformerDecoderLayer, RelativeTransformerEncoderLayer 5 | 6 | 7 | class UniversalRelativeTransformer(Transformer): 8 | def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, 9 | num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, 10 | activation: ActivationFunction = F.relu): 11 | 12 | super().__init__(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout, activation, 13 | UniversalTransformerEncoderWithLayer(RelativeTransformerEncoderLayer), 14 | UniversalTransformerDecoderWithLayer(RelativeTransformerDecoderLayer)) 15 | 
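A minimal usage sketch of the classes above (not part of the repository): it assumes the repo root is on PYTHONPATH so the module path below resolves, and random tensors stand in for already-embedded token sequences. The same calling convention applies to Transformer, UniversalTransformer and UniversalRelativeTransformer; the Universal variants defined next reuse a single layer instance at every depth, so their parameters are shared across layers while the unrolled computation looks identical. For stepwise greedy decoding, create_state/one_step_forward of TransformerDecoderBase replace the full-sequence decoder call.

import torch
from layers.transformer.relative_transformer import RelativeTransformer

model = RelativeTransformer(d_model=64, nhead=4, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=128)
src = torch.randn(2, 11, 64)  # already-embedded source, [batch, src_len, d_model]
tgt = torch.randn(2, 7, 64)   # already-embedded, shifted target, [batch, tgt_len, d_model]
tgt_mask = model.generate_square_subsequent_mask(tgt.shape[1], tgt.device)  # bool mask, True = hidden future position
out = model(src, tgt, tgt_mask=tgt_mask)  # [batch, tgt_len, d_model]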
-------------------------------------------------------------------------------- /layers/transformer/universal_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | from .transformer import TransformerDecoderBase, ActivationFunction, TransformerEncoderLayer, TransformerDecoderLayer, \ 5 | Transformer 6 | 7 | 8 | class UniversalTransformerEncoder(torch.nn.Module): 9 | def __init__(self, layer, depth: int, *args, **kwargs): 10 | super().__init__() 11 | self.layer = layer(*args, **kwargs) 12 | self.layers = [self.layer] * depth 13 | 14 | def forward(self, data: torch.Tensor, *args, **kwargs): 15 | for l in self.layers: 16 | data = l(data, *args, **kwargs) 17 | return data 18 | 19 | 20 | class UniversalTransformerDecoder(TransformerDecoderBase): 21 | def __init__(self, layer, depth: int, d_model: int, *args, **kwargs): 22 | super().__init__(d_model) 23 | self.layer = layer(d_model, *args, **kwargs) 24 | self.layers = [self.layer] * depth 25 | 26 | def forward(self, data: torch.Tensor, *args, **kwargs): 27 | for l in self.layers: 28 | data = l(data, *args, **kwargs) 29 | return data 30 | 31 | 32 | def UniversalTransformerEncoderWithLayer(layer=TransformerEncoderLayer): 33 | return lambda *args, **kwargs: UniversalTransformerEncoder(layer, *args, **kwargs) 34 | 35 | 36 | def UniversalTransformerDecoderWithLayer(layer=TransformerDecoderLayer): 37 | return lambda *args, **kwargs: UniversalTransformerDecoder(layer, *args, **kwargs) 38 | 39 | 40 | class UniversalTransformer(Transformer): 41 | def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = 6, 42 | num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1, 43 | activation: ActivationFunction = F.relu): 44 | 45 | super().__init__(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout, activation, 46 | UniversalTransformerEncoderWithLayer(), 47 | UniversalTransformerDecoderWithLayer()) 48 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import framework 2 | import tasks 3 | 4 | 5 | def register_args(parser: framework.helpers.ArgumentParser): 6 | parser.add_argument("-batch_size", default=128) 7 | parser.add_argument("-lr", default=1e-3) 8 | parser.add_argument("-wd", default=0.0) 9 | parser.add_argument("-lr_warmup", default=0) 10 | parser.add_argument("-test_interval", default=1000) 11 | parser.add_argument("-state_size", default=128) 12 | parser.add_argument("-stop_after", default="None", parser=parser.int_or_none_parser) 13 | parser.add_argument("-task", default="trafo_scan") 14 | parser.add_argument("-dropout", default=0.0) 15 | parser.add_argument("-grad_clip", default="1.0", parser=parser.float_or_none_parser) 16 | parser.add_argument("-scan.train_split", default="simple", parser=parser.str_list_parser) 17 | parser.add_argument("-scan.length_cutoff", default=22) 18 | parser.add_argument("-layer_sizes", default="800,800,256", parser=parser.int_list_parser) 19 | parser.add_argument("-transformer.n_heads", default=4) 20 | parser.add_argument("-transformer.variant", default="scaledinit") 21 | parser.add_argument("-transformer.ff_multiplier", default=2.0) 22 | parser.add_argument("-transformer.encoder_n_layers", default=3) 23 | parser.add_argument("-transformer.decoder_n_layers", default="3", parser=parser.int_or_none_parser) 24 
| parser.add_argument("-transformer.tied_embedding", default=True) 25 | parser.add_argument("-test_batch_size", default="None", parser=parser.int_or_none_parser) 26 | parser.add_argument("-dm_math.tasks", default="algebra__linear_1d", parser=parser.str_list_parser) 27 | parser.add_argument("-dm_math.train_splits", default="easy,medium,hard", parser=parser.str_list_parser) 28 | parser.add_argument("-lr_sched.steps", default="", parser=parser.int_list_parser) 29 | parser.add_argument("-lr_sched.gamma", default=0.1) 30 | parser.add_argument("-lr_sched.type", default="step", choice=["step", "noam"]) 31 | parser.add_argument("-optimizer", default="adam", choice=["adam", "sgd"]) 32 | parser.add_argument("-adam.betas", default="0.9,0.999", parser=parser.float_list_parser) 33 | parser.add_argument("-amp", default=False) 34 | parser.add_argument("-cogs.generalization_test_interval", default=2500) 35 | parser.add_argument("-label_smoothing", default=0.0) 36 | parser.add_argument("-pcfg.split", default="simple", choice=["simple", "productivity", "substitutivity", 37 | "systematicity"]) 38 | parser.add_argument("-cfq.split", default="random", choice=["random", "query_complexity", "question_complexity", 39 | "query_pattern", "question_pattern", "mcd1", "mcd2", 40 | "mcd3"]) 41 | parser.add_argument("-max_length_per_batch", default="none", parser=parser.int_or_none_parser) 42 | parser.add_argument("-log_sample_level_loss", default=False) 43 | 44 | parser.add_profile([ 45 | parser.Profile("cfq_trafo", { 46 | "task": "cfq_trafo", 47 | "transformer.variant": "noscale", 48 | "state_size": 128, 49 | "transformer.n_heads": 16, 50 | "transformer.ff_multiplier": 2, 51 | "transformer.encoder_n_layers": 2, 52 | "transformer.decoder_n_layers": 2, 53 | "grad_clip": 1, 54 | "stop_after": 50000, 55 | "dropout": 0.1, 56 | "batch_size": 512, 57 | "lr": 1e-4, 58 | }), 59 | 60 | parser.Profile("cfq_universal_trafo", { 61 | "transformer.variant": "universal_noscale", 62 | "state_size": 256, 63 | "transformer.n_heads": 4, 64 | "transformer.ff_multiplier": 2, 65 | "transformer.encoder_n_layers": 6, 66 | "transformer.decoder_n_layers": 6, 67 | }, include="cfq_trafo"), 68 | 69 | parser.Profile("cogs_trafo_small", { 70 | "task": "cogs_transformer", 71 | "state_size": 512, 72 | "transformer.n_heads": 4, 73 | "transformer.ff_multiplier": 1, 74 | "transformer.encoder_n_layers": 2, 75 | "transformer.decoder_n_layers": 2, 76 | "grad_clip": "none", 77 | "stop_after": 50000, 78 | "dropout": 0.1, 79 | "batch_size": 128, 80 | "lr": 2, 81 | "lr_sched.type": "noam", 82 | "lr_warmup": 4000, 83 | }), 84 | 85 | parser.Profile("deepmind_math", { 86 | "task": "dm_math_transformer", 87 | "lr": 1e-4, 88 | "stop_after": 50000, 89 | "batch_size": 256, 90 | "mask_loss_weight": 0.001, 91 | "state_size": 512, 92 | "transformer.n_heads": 8, 93 | "transformer.ff_multiplier": 4, 94 | "transformer.encoder_n_layers": 6, 95 | "transformer.decoder_n_layers": 6, 96 | "test_batch_size": 1024, 97 | "grad_clip": 0.1 98 | }), 99 | 100 | parser.Profile("pcfg_trafo", { 101 | "task": "pcfg_transformer", 102 | "state_size": 512, 103 | "transformer.n_heads": 8, 104 | "transformer.ff_multiplier": 4, 105 | "transformer.encoder_n_layers": 6, 106 | "transformer.decoder_n_layers": 6, 107 | "lr": 1e-3, 108 | "grad_clip": "1", 109 | "stop_after": 1000000, 110 | "batch_size": 64 111 | }), 112 | 113 | parser.Profile("trafo_scan", { 114 | "lr": 1e-3, 115 | "grad_clip": "5", 116 | "stop_after": 15000, 117 | "batch_size": 256, 118 | "dropout": 0.5, 119 | "embedding_size": 16, 
120 | "task": "trafo_scan", 121 | "state_size": 128, 122 | "transformer.n_heads": 8, 123 | "test_batch_size": 2048 124 | }) 125 | ]) 126 | 127 | 128 | def main(): 129 | helper = framework.helpers.TrainingHelper(wandb_project_name="modules", 130 | register_args=register_args, extra_dirs=["export", "model_weights"]) 131 | 132 | def invalid_task_error(_): 133 | assert False, f"Invalid task: {helper.args.task}" 134 | 135 | constructors = { 136 | "pcfg_transformer": tasks.PCFGTransformer, 137 | "cogs_transformer": tasks.COGSTransformer, 138 | "trafo_scan": tasks.ScanTransformer, 139 | "scan_resplit_transformer": tasks.ScanResplitTransformer, 140 | "cfq_trafo": tasks.CFQTransformer, 141 | "dm_math_transformer": tasks.DMMathTransformer, 142 | } 143 | 144 | task = constructors.get(helper.args.task, invalid_task_error)(helper) 145 | task.train() 146 | helper.finish() 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer_enc_dec import TransformerEncDecModel 2 | -------------------------------------------------------------------------------- /models/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | 4 | 5 | def add_eos(input: torch.Tensor, lengths: torch.Tensor, eos_id: int): 6 | input = torch.cat((input, torch.zeros_like(input[0:1])), dim=0) 7 | input.scatter_(0, lengths.unsqueeze(0).long(), value=eos_id) 8 | return input 9 | 10 | -------------------------------------------------------------------------------- /models/transformer_enc_dec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | import framework 5 | from layers import Transformer, TiedEmbedding 6 | from typing import Callable, Optional 7 | import math 8 | 9 | 10 | # Cannot be dataclass, because that won't work with gather 11 | class TransformerResult(framework.data_structures.DotDict): 12 | data: torch.Tensor 13 | length: torch.Tensor 14 | 15 | @staticmethod 16 | def create(data: torch.Tensor, length: torch.Tensor): 17 | return TransformerResult({"data": data, "length": length}) 18 | 19 | 20 | class TransformerEncDecModel(torch.nn.Module): 21 | def __init__(self, n_input_tokens: int, n_out_tokens: int, state_size: int = 512, ff_multiplier: float = 4, 22 | max_len: int=5000, transformer = Transformer, tied_embedding: bool=False, 23 | pos_embeddig: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None, 24 | encoder_sos: bool = True, same_enc_dec_embedding: bool = False, embedding_init: str = "pytorch", 25 | in_embedding_size: Optional[int] = None, out_embedding_size: Optional[int] = None, 26 | scale_mode: str = "none", **kwargs): 27 | ''' 28 | Transformer encoder-decoder. 
29 | 30 | :param n_input_tokens: Number of input tokens (size of the input vocabulary) 31 | :param n_out_tokens: Number of output tokens (size of the output vocabulary) 32 | :param state_size: The size of the internal state of the transformer 33 | ''' 34 | super().__init__() 35 | 36 | assert scale_mode in ["none", "opennmt", "down"] 37 | assert embedding_init in ["pytorch", "xavier", "kaiming"] 38 | 39 | assert (not same_enc_dec_embedding) or (n_input_tokens == n_out_tokens) 40 | 41 | self.tied_embedding = tied_embedding 42 | 43 | self.decoder_sos_eos = n_out_tokens 44 | self.encoder_eos = n_input_tokens 45 | self.encoder_sos = n_input_tokens + 1 if encoder_sos else None 46 | self.state_size = state_size 47 | self.embedding_init = embedding_init 48 | self.ff_multiplier = ff_multiplier 49 | self.n_input_tokens = n_input_tokens 50 | self.n_out_tokens = n_out_tokens 51 | self.in_embedding_size = in_embedding_size 52 | self.out_embedding_size = out_embedding_size 53 | self.same_enc_dec_embedding = same_enc_dec_embedding 54 | self.scale_mode = scale_mode 55 | self.pos = pos_embeddig or framework.layers.PositionalEncoding(state_size, max_len=max_len, batch_first=True, 56 | scale=(1.0 / math.sqrt(state_size)) if scale_mode == "down" else 1.0) 57 | 58 | self.register_buffer('int_seq', torch.arange(max_len, dtype=torch.long)) 59 | self.construct(transformer, **kwargs) 60 | self.reset_parameters() 61 | 62 | def pos_embed(self, t: torch.Tensor, offset: int, scale_offset: int) -> torch.Tensor: 63 | if self.scale_mode == "opennmt": 64 | t = t * math.sqrt(t.shape[-1]) 65 | 66 | return self.pos(t, offset) 67 | 68 | def construct(self, transformer, **kwargs): 69 | self.input_embedding = torch.nn.Embedding(self.n_input_tokens + 1 + int(self.encoder_sos is not None), 70 | self.in_embedding_size or self.state_size) 71 | self.output_embedding = self.input_embedding if self.same_enc_dec_embedding else \ 72 | torch.nn.Embedding(self.n_out_tokens+1, self.out_embedding_size or self.state_size) 73 | 74 | if self.in_embedding_size is not None: 75 | self.in_embedding_upscale = torch.nn.Linear(self.in_embedding_size, self.state_size) 76 | 77 | if self.out_embedding_size is not None: 78 | self.out_embedding_upscale = torch.nn.Linear(self.out_embedding_size, self.state_size) 79 | 80 | if self.tied_embedding: 81 | assert self.out_embedding_size is None 82 | self.output_map = TiedEmbedding(self.output_embedding.weight) 83 | else: 84 | self.output_map = torch.nn.Linear(self.state_size, self.n_out_tokens+1) 85 | 86 | self.trafo = transformer(d_model=self.state_size, dim_feedforward=int(self.ff_multiplier*self.state_size), 87 | **kwargs) 88 | 89 | def reset_parameters(self): 90 | if self.embedding_init == "xavier": 91 | torch.nn.init.xavier_uniform_(self.input_embedding.weight) 92 | torch.nn.init.xavier_uniform_(self.output_embedding.weight) 93 | elif self.embedding_init == "kaiming": 94 | torch.nn.init.kaiming_normal_(self.input_embedding.weight) 95 | torch.nn.init.kaiming_normal_(self.output_embedding.weight) 96 | 97 | if not self.tied_embedding: 98 | torch.nn.init.xavier_uniform_(self.output_map.weight) 99 | 100 | def generate_len_mask(self, max_len: int, len: torch.Tensor) -> torch.Tensor: 101 | return self.int_seq[: max_len] >= len.unsqueeze(-1) 102 | 103 | def output_embed(self, x: torch.Tensor) -> torch.Tensor: 104 | o = self.output_embedding(x) 105 | if self.out_embedding_size is not None: 106 | o = self.out_embedding_upscale(o) 107 | return o 108 | 109 | def run_greedy(self, src: torch.Tensor, src_len: torch.Tensor, max_len: int) ->
TransformerResult: 110 | batch_size = src.shape[0] 111 | n_steps = src.shape[1] 112 | 113 | in_len_mask = self.generate_len_mask(n_steps, src_len) 114 | memory = self.trafo.encoder(src, mask=in_len_mask) 115 | 116 | running = torch.ones([batch_size], dtype=torch.bool, device=src.device) 117 | out_len = torch.zeros_like(running, dtype=torch.long) 118 | 119 | next_tgt = self.pos_embed(self.output_embed(torch.full([batch_size,1], self.decoder_sos_eos, dtype=torch.long, 120 | device=src.device)), 0, 1) 121 | 122 | all_outputs = [] 123 | state = self.trafo.decoder.create_state(src.shape[0], max_len, src.device) 124 | 125 | for i in range(max_len): 126 | output = self.trafo.decoder.one_step_forward(state, next_tgt, memory, memory_key_padding_mask=in_len_mask) 127 | 128 | output = self.output_map(output) 129 | all_outputs.append(output) 130 | 131 | out_token = torch.argmax(output[:,-1], -1) 132 | running &= out_token != self.decoder_sos_eos 133 | 134 | out_len[running] = i + 1 135 | next_tgt = self.pos_embed(self.output_embed(out_token).unsqueeze(1), i+1, 1) 136 | 137 | return TransformerResult.create(torch.cat(all_outputs, 1), out_len) 138 | 139 | def run_teacher_forcing(self, src: torch.Tensor, src_len: torch.Tensor, target: torch.Tensor, 140 | target_len: torch.Tensor) -> TransformerResult: 141 | target = self.output_embed(F.pad(target[:, :-1], (1, 0), value=self.decoder_sos_eos).long()) 142 | target = self.pos_embed(target, 0, 1) 143 | 144 | in_len_mask = self.generate_len_mask(src.shape[1], src_len) 145 | 146 | res = self.trafo(src, target, src_length_mask=in_len_mask, 147 | tgt_mask=self.trafo.generate_square_subsequent_mask(target.shape[1], src.device)) 148 | 149 | return TransformerResult.create(self.output_map(res), target_len) 150 | 151 | def input_embed(self, x: torch.Tensor) -> torch.Tensor: 152 | src = self.input_embedding(x.long()) 153 | if self.in_embedding_size is not None: 154 | src = self.in_embedding_upscale(src) 155 | return src 156 | 157 | def forward(self, src: torch.Tensor, src_len: torch.Tensor, target: torch.Tensor, 158 | target_len: torch.Tensor, teacher_forcing: bool, max_len: Optional[int] = None) -> TransformerResult: 159 | ''' 160 | Run transformer encoder-decoder on some input/output pair 161 | 162 | :param src: source tensor. Shape: [N, S], where S is the input sequence length, N is the batch size 163 | :param src_len: length of source sequences. Shape: [N], N is the batch size 164 | :param target: target tensor. Shape: [N, T], where T is the target sequence length, N is the batch size 165 | :param target_len: length of target sequences. Shape: [N], N is the batch size 166 | :param teacher_forcing: use teacher forcing or greedy decoding 167 | :param max_len: overwrite autodetected max length. Useful for parallel execution 168 | :return: prediction of the target tensor.
Shape [N, T, C_out] 169 | ''' 170 | 171 | if self.encoder_sos is not None: 172 | src = F.pad(src, (1, 0), value=self.encoder_sos) 173 | src_len = src_len + 1 174 | 175 | src = self.pos_embed(self.input_embed(src), 0, 0) 176 | 177 | if teacher_forcing: 178 | return self.run_teacher_forcing(src, src_len, target, target_len) 179 | else: 180 | return self.run_greedy(src, src_len, max_len or target_len.max().item()) 181 | -------------------------------------------------------------------------------- /optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .step_lr_sched import StepLrSched 2 | from .noam_lr_sched import NoamLRSched 3 | -------------------------------------------------------------------------------- /optimizer/noam_lr_sched.py: -------------------------------------------------------------------------------- 1 | class NoamLRSched: 2 | def __init__(self, lr: float, state_size: int, warmup_steps: int): 3 | self.lr = lr / (state_size ** 0.5) 4 | self.warmup_steps = warmup_steps 5 | 6 | def get(self, step: int) -> float: 7 | if step >= self.warmup_steps: 8 | return self.lr / float(step + 1) ** 0.5 9 | else: 10 | return self.lr / (self.warmup_steps**1.5) * float(step + 1) 11 | -------------------------------------------------------------------------------- /optimizer/step_lr_sched.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import bisect 3 | 4 | 5 | class StepLrSched: 6 | def __init__(self, lr: float, steps: List[int], gamma: float): 7 | self.steps = [0] + list(sorted(steps)) 8 | self.lrs = [lr * (gamma ** i) for i in range(len(self.steps))] 9 | 10 | def get(self, step: int) -> float: 11 | return self.lrs[bisect.bisect(self.steps, step) - 1] 12 | -------------------------------------------------------------------------------- /paper/.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | out -------------------------------------------------------------------------------- /paper/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "wandb_project": "your_username/your_project" 3 | } 4 | -------------------------------------------------------------------------------- /paper/lib/__init__.py: -------------------------------------------------------------------------------- 1 | from . import common 2 | from .stat_tracker import StatTracker, Stat, MedianTracker 3 | from . import matplotlib_config 4 | from .source import get_runs 5 | from .config import get_config 6 | from . 
import source 7 | -------------------------------------------------------------------------------- /paper/lib/common.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Callable 2 | from .stat_tracker import StatTracker 3 | from typing import Any 4 | 5 | 6 | def construct_name(config_names: List[str], get_name: Callable[[str], str]) -> str: 7 | return "/".join([f"{c}_{get_name(c)}" for c in config_names]) 8 | 9 | 10 | def group(runs, config_names: List[str], get_config = lambda run, name: run.config[name]) -> Dict[str, Any]: 11 | res = {} 12 | for r in runs: 13 | cval = construct_name(config_names, lambda name: get_config(r, name)) 14 | if cval not in res: 15 | res[cval] = [] 16 | 17 | res[cval].append(r) 18 | 19 | return res 20 | 21 | def calc_stat(group_of_runs: Dict[str, List[Any]], filter, tracker=StatTracker) -> Dict[str, Dict[str, StatTracker]]: 22 | all_stats = {} 23 | 24 | for k, rn in group_of_runs.items(): 25 | if k not in all_stats: 26 | all_stats[k] = {} 27 | 28 | stats = all_stats[k] 29 | 30 | for r in rn: 31 | for key, v in r.summary.items(): 32 | if not filter(key): 33 | continue 34 | 35 | if key not in stats: 36 | stats[key] = tracker() 37 | 38 | stats[key].add(v) 39 | 40 | return all_stats 41 | -------------------------------------------------------------------------------- /paper/lib/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def get_config(): 5 | with open('config.json') as json_file: 6 | return json.load(json_file) -------------------------------------------------------------------------------- /paper/lib/matplotlib_config.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | plt.rcParams['text.usetex'] = True # Let TeX do the typesetting 4 | plt.rcParams['text.latex.preamble'] = [r'\usepackage{sansmath}', r'\sansmath'] #Force sans-serif math mode (for axes labels) 5 | plt.rcParams['font.family'] = 'sans-serif' # ...
for regular text 6 | plt.rcParams['font.sans-serif'] = 'Helvetica, Avant Garde, Computer Modern Sans serif' # Choose a nice font here 7 | -------------------------------------------------------------------------------- /paper/lib/source.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from typing import List 3 | from .config import get_config 4 | from typing import Dict 5 | 6 | from gql import gql 7 | 8 | def get_sweep_table(api: wandb.Api, project: str) -> Dict[str, str]: 9 | QUERY = gql(''' 10 | query Sweep($project: String!, $entity: String) { 11 | project(name: $project, entityName: $entity) { 12 | sweeps { 13 | edges { 14 | node { 15 | name 16 | displayName 17 | config 18 | } 19 | } 20 | } 21 | } 22 | }''') 23 | 24 | entity, project = project.split("/") 25 | response = api.client.execute(QUERY, variable_values={ 26 | 'entity': entity, 27 | 'project': project, 28 | }) 29 | 30 | edges = response.get("project", {}).get("sweeps", {}).get("edges") 31 | assert edges 32 | 33 | id_to_name = {} 34 | for sweep in edges: 35 | sweep = sweep["node"] 36 | 37 | name = sweep["displayName"] 38 | if name is None: 39 | name = [s for s in sweep["config"].split("\n") if s.startswith("name:")] 40 | assert len(name)==1 41 | name = name[0].split(":")[1].strip() 42 | 43 | id_to_name[sweep["name"]] = name 44 | 45 | return id_to_name 46 | 47 | 48 | def invert_sweep_id_table(t: Dict[str, str]) -> Dict[str, str]: 49 | repeats = set() 50 | res = {} 51 | for id, name in t.items(): 52 | if name in res: 53 | repeats.add(name) 54 | 55 | res[name] = id 56 | 57 | for r in repeats: 58 | del res[r] 59 | 60 | print("Removed the following duplicated sweeps:", repeats) 61 | 62 | return res 63 | 64 | sweep_table = None 65 | 66 | def get_runs(names: List[str], filters = {}) -> List[wandb.apis.public.Run]: 67 | global sweep_table 68 | api = wandb.Api() 69 | 70 | config = get_config() 71 | project = config["wandb_project"] 72 | 73 | if sweep_table is None: 74 | sweep_table = invert_sweep_id_table(get_sweep_table(api, project)) 75 | 76 | for n in names: 77 | assert n in sweep_table, f"Sweep {n} not found" 78 | 79 | sweep_id_list = [sweep_table[n] for n in names] 80 | filter = {"sweep": {"$in": sweep_id_list}} 81 | filter.update(filters) 82 | res = list(api.runs(project, filter)) 83 | 84 | assert len(res)>0, "Runs not found." 
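# `filters` follows the MongoDB-style query syntax accepted by wandb's public API: the "$in" above restricts
# the results to the resolved sweep ids, and the assert below additionally requires every matched run to have finished.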
85 | assert all(r.state == "finished" for r in res) 86 | print(f"Querying runs {names}: {len(res)} runs loaded") 87 | assert len(res) > 0 88 | return res 89 | -------------------------------------------------------------------------------- /paper/lib/stat_tracker.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Union 3 | import math 4 | import numpy as np 5 | import statistics 6 | 7 | 8 | @dataclass 9 | class Stat: 10 | mean: Union[np.ndarray, float] 11 | std: Union[np.ndarray, float] 12 | n: int 13 | 14 | 15 | class StatTracker: 16 | def __init__(self): 17 | self.sum = 0 18 | self.sqsum = 0 19 | self.n = 0 20 | 21 | def add(self, v: float): 22 | if isinstance(v, np.ndarray): 23 | v = v.astype(np.float32) 24 | self.sum = self.sum + v 25 | self.sqsum = self.sqsum + v**2 26 | self.n += 1 27 | 28 | def get(self) -> Stat: 29 | assert self.n > 0 30 | mean = self.sum / self.n 31 | var = (self.sqsum / self.n - mean ** 2) * self.n/(self.n-1) if self.n>1 else 0 32 | 33 | return Stat(mean, np.sqrt(np.maximum(var,0)), self.n) 34 | 35 | def __repr__(self) -> str: 36 | s = self.get() 37 | return f"Stat(mean: {s.mean}, std: {s.std})" 38 | 39 | def __add__(self, other): 40 | res = StatTracker() 41 | res.sum = other.sum + self.sum 42 | res.sqsum = other.sqsum + self.sqsum 43 | res.n = other.n + self.n 44 | return res 45 | 46 | 47 | class MedianTracker: 48 | def __init__(self): 49 | self.elems = [] 50 | 51 | def add(self, v: float): 52 | if isinstance(v, np.ndarray): 53 | v = v.astype(np.float32) 54 | self.elems.append(v) 55 | 56 | def get(self) -> float: 57 | assert len(self.elems) > 0 58 | return statistics.median(self.elems) 59 | 60 | def __repr__(self) -> str: 61 | return f"Median({self.get()})" 62 | 63 | def __add__(self, other): 64 | res = MedianTracker() 65 | res.elems = self.elems + other.elems 66 | return res 67 | -------------------------------------------------------------------------------- /paper/plot_big_result_table.py: -------------------------------------------------------------------------------- 1 | import lib 2 | from collections import OrderedDict 3 | 4 | data = OrderedDict() 5 | data["SCAN (length cutoff=26)"] = lib.get_runs(["scan_trafo_length_cutoff"], filters = {"config.scan.length_cutoff.value": 26}), "val", None, "$0.00^{[1]}$" 6 | data["a"] = None, None, None, None 7 | data["CFQ Output length"] = lib.get_runs(["cfq_out_length", "cfq_out_length_universal"]), "test", 35000, "$\\sim 0.66^{[2]}$" 8 | cfq_runs = lib.get_runs(["cfq_mcd", "cfq_mcd_universal"]) 9 | data["CFQ MCD 1"] = [r for r in cfq_runs if r.config["cfq.split"] == "mcd1"], "test", 35000, "$0.37\\pm0.02^{[3]}$" 10 | data["CFQ MCD 2"] = [r for r in cfq_runs if r.config["cfq.split"] == "mcd2"], "test", 35000, "$0.08\\pm0.02^{[3]}$" 11 | data["CFQ MCD 3"] = [r for r in cfq_runs if r.config["cfq.split"] == "mcd3"], "test", 35000, "$0.11\\pm0.00^{[3]}$" 12 | data["CFQ MCD mean"] = cfq_runs, "test", 35000, "$0.19\\pm0.01^{[2]}$" 13 | data["b"] = None, None, None, None 14 | data["PCFG Productivity split"] = lib.get_runs(["pcfg_nosched_productivity"]), "val", None, "$0.50\\pm0.02^{[4]}$" 15 | data["PCFG Systematicity split"] = lib.get_runs(["pcfg_nosched_systematicity"]), "val", None, "$0.72\\pm0.00^{[4]}$" 16 | data["c"] = None, None, None, None 17 | data["COGS"] = lib.get_runs(["cogs_trafo"]), "gen", None, "$0.35\\pm0.06^{[5]}$" 18 | data["d"] = None, None, None, None 19 | math_runs = lib.get_runs(["dm_math"]) 20 | data["Math: 
add\\_or\\_sub"] = [r for r in math_runs if r.config["dm_math.task"] == "arithmetic__add_or_sub"], "extrapolate", None, "$\\sim0.91^{[6]*}$" 21 | data["Math: place\\_value"] = [r for r in math_runs if r.config["dm_math.task"] == "numbers__place_value"], "extrapolate", None, "$\\sim0.69^{[6]*}$" 22 | 23 | 24 | columns = OrderedDict() 25 | columns["Trafo"] = ["scaledinit", "noscale", "opennmt"] 26 | columns["Uni. Trafo"] = ["universal_scaledinit", "universal_noscale", "universal_opennmt"] 27 | columns["Rel. Trafo"] = ["relative"] 28 | columns["Rel. Uni. Trafo"] = ["relative_universal"] 29 | 30 | def average_accuracy(runs, split_name, step) -> float: 31 | st = lib.StatTracker() 32 | runs = list(runs) 33 | it = max([r.summary["iteration"] for r in runs]) 34 | for r in runs: 35 | if f"validation/{split_name}/accuracy/total" not in r.summary: 36 | continue 37 | 38 | if step is None: 39 | st.add(r.summary[f"validation/{split_name}/accuracy/total"]) 40 | assert r.summary["iteration"] == it, f"Inconsistend final iteration for run {r.id}: {r.summary['iteration']} instead of {it}" 41 | else: 42 | hist = r.history(keys=[f"validation/{split_name}/accuracy/total", "iteration"], pandas=False) 43 | for h in hist: 44 | if h["iteration"] == step: 45 | st.add(h[f"validation/{split_name}/accuracy/total"]) 46 | break 47 | else: 48 | assert False, f"Step {step} not found." 49 | return st.get() 50 | 51 | def format_results(runs, split_name, step) -> str: 52 | run_group = lib.common.group(runs, ['transformer.variant']) 53 | 54 | cols = [] 55 | for clist in columns.values(): 56 | found = [] 57 | for c in clist: 58 | full_name = f"transformer.variant_{c}" 59 | if full_name in run_group: 60 | found.append(average_accuracy(run_group[full_name], split_name, step)) 61 | 62 | cols.append(max(found, key=lambda x: x.mean) if found else None) 63 | 64 | maxval = max(c.mean for c in cols if c is not None) 65 | cols = [(("\\bf{" if c.mean == maxval else "") + f"{c.mean:.2f} $\\pm$ {c.std:.2f}" + 66 | ("}" if c.mean == maxval else "")) if c is not None else "-" for c in cols] 67 | return " & ".join(cols) 68 | 69 | print(" & " + " & ".join(columns.keys()) + " & Reported\\\\") 70 | print("\\midrule") 71 | for dname, (runs, splitname, at_step, best_other) in data.items(): 72 | if runs is None: 73 | print("\\midrule") 74 | else: 75 | print(f"{dname} & {format_results(runs, splitname, at_step)} & {best_other} \\\\") 76 | -------------------------------------------------------------------------------- /paper/plot_big_result_table_iid.py: -------------------------------------------------------------------------------- 1 | import lib 2 | from collections import OrderedDict 3 | 4 | data = OrderedDict() 5 | 6 | data["SCAN (length cutoff=26)"] = lib.get_runs(["scan_trafo_length_cutoff"], filters = {"config.scan.length_cutoff.value": 26}), "val", "iid", None 7 | data["c"] = None, None, None, None 8 | data["COGS"] = lib.get_runs(["cogs_trafo"]), "gen", "val", None 9 | data["d"] = None, None, None, None 10 | math_runs = lib.get_runs(["dm_math"]) 11 | data["Math: add\\_or\\_sub"] = [r for r in math_runs if r.config["dm_math.task"] == "arithmetic__add_or_sub"], "extrapolate", "interpolate", None 12 | data["Math: place\\_value"] = [r for r in math_runs if r.config["dm_math.task"] == "numbers__place_value"], "extrapolate", "interpolate", None 13 | 14 | 15 | columns = OrderedDict() 16 | columns["Trafo"] = ["scaledinit", "noscale", "opennmt"] 17 | columns["Uni. 
Trafo"] = ["universal_scaledinit", "universal_noscale", "universal_opennmt"] 18 | columns["Rel. Trafo"] = ["relative"] 19 | columns["Rel. Uni. Trafo"] = ["relative_universal"] 20 | 21 | def average_accuracy(runs, split_name, step) -> float: 22 | st = lib.StatTracker() 23 | runs = list(runs) 24 | it = max([r.summary["iteration"] for r in runs]) 25 | for r in runs: 26 | if f"validation/{split_name}/accuracy/total" not in r.summary: 27 | continue 28 | 29 | if step is None: 30 | st.add(r.summary[f"validation/{split_name}/accuracy/total"]) 31 | assert r.summary["iteration"] == it, f"Inconsistend final iteration for run {r.id}: {r.summary['iteration']} instead of {it}" 32 | else: 33 | hist = r.history(keys=[f"validation/{split_name}/accuracy/total", "iteration"], pandas=False) 34 | for h in hist: 35 | if h["iteration"] == step: 36 | st.add(h[f"validation/{split_name}/accuracy/total"]) 37 | break 38 | else: 39 | assert False, f"Step {step} not found." 40 | return st.get() 41 | 42 | def format_results(runs, split_name, ex_split_name, step) -> str: 43 | run_group = lib.common.group(runs, ['transformer.variant']) 44 | 45 | cols = [] 46 | ex_cols = [] 47 | for clist in columns.values(): 48 | found = [] 49 | for c in clist: 50 | full_name = f"transformer.variant_{c}" 51 | if full_name in run_group: 52 | found.append(( 53 | average_accuracy(run_group[full_name], split_name, step), 54 | average_accuracy(run_group[full_name], ex_split_name, step), 55 | )) 56 | 57 | if found: 58 | max_i = max(range(len(found)), key=lambda i: found[i][1].mean) 59 | cols.append(found[max_i][0]) 60 | ex_cols.append(found[max_i][1]) 61 | else: 62 | cols.append(None) 63 | ex_cols.append(None) 64 | 65 | maxval = max(c.mean for c in cols if c is not None) 66 | cols = [(("{\\bf" if c.mean > maxval-0.01 else "") + f"{c.mean:.2f} $\\pm$ {c.std:.2f}" + 67 | ("}" if c.mean > maxval-0.01 else "")) if c is not None else "-" for c in cols] 68 | cols = [c + (f" ({exc.mean:.2f})" if exc else "") for c, exc in zip(cols, ex_cols)] 69 | return " & ".join(cols) 70 | 71 | print(" & " + " & ".join(columns.keys()) + " \\\\") 72 | print("\\midrule") 73 | for dname, (runs, ex_split_name, splitname, at_step) in data.items(): 74 | if runs is None: 75 | print("\\midrule") 76 | else: 77 | print(f"{dname} & {format_results(runs, splitname, ex_split_name, at_step)} \\\\") 78 | -------------------------------------------------------------------------------- /paper/plot_big_result_table_with_init.py: -------------------------------------------------------------------------------- 1 | import lib 2 | from collections import OrderedDict 3 | 4 | data = OrderedDict() 5 | data["SCAN (length cutoff=26)"] = lib.get_runs(["scan_trafo_length_cutoff"], filters = {"config.scan.length_cutoff.value": 26}), "val", None, "$0.00^{[1]}$" 6 | data["a"] = None, None, None, None 7 | data["CFQ Output length"] = lib.get_runs(["cfq_out_length", "cfq_out_length_universal"]), "test", 35000, "$\\sim 0.66^{[2]}$" 8 | cfq_runs = lib.get_runs(["cfq_mcd", "cfq_mcd_universal"]) 9 | data["CFQ MCD 1"] = [r for r in cfq_runs if r.config["cfq.split"] == "mcd1"], "test", 35000, "$0.37\\pm0.02^{[3]}$" 10 | data["CFQ MCD 2"] = [r for r in cfq_runs if r.config["cfq.split"] == "mcd2"], "test", 35000, "$0.08\\pm0.02^{[3]}$" 11 | data["CFQ MCD 3"] = [r for r in cfq_runs if r.config["cfq.split"] == "mcd3"], "test", 35000, "$0.11\\pm0.00^{[3]}$" 12 | data["CFQ MCD mean"] = cfq_runs, "test", 35000, "$0.19\\pm0.01^{[2]}$" 13 | data["b"] = None, None, None, None 14 | data["PCFG Productivity split"] = 
lib.get_runs(["pcfg_nosched_productivity"]), "val", None, "$0.50\\pm0.02^{[4]}$" 15 | data["PCFG Systematicity split"] = lib.get_runs(["pcfg_nosched_systematicity"]), "val", None, "$0.72\\pm0.00^{[4]}$" 16 | data["c"] = None, None, None, None 17 | data["COGS"] = lib.get_runs(["cogs_trafo"]), "gen", None, "$0.35\\pm0.06^{[5]}$" 18 | data["d"] = None, None, None, None 19 | math_runs = lib.get_runs(["dm_math"]) 20 | data["Math: add\\_or\\_sub"] = [r for r in math_runs if r.config["dm_math.task"] == "arithmetic__add_or_sub"], "extrapolate", None, "$\\sim0.91^{[6]*}$" 21 | data["Math: place\\_value"] = [r for r in math_runs if r.config["dm_math.task"] == "numbers__place_value"], "extrapolate", None, "$\\sim0.69^{[6]*}$" 22 | 23 | 24 | columns = OrderedDict() 25 | columns["Trafo"] = ["scaledinit", "noscale", "opennmt"] 26 | columns["Uni. Trafo"] = ["universal_scaledinit", "universal_noscale", "universal_opennmt"] 27 | columns["Rel. Trafo"] = ["relative"] 28 | columns["Rel. Uni. Trafo"] = ["relative_universal"] 29 | 30 | init_type_table = { 31 | "scaledinit": "PED", 32 | "noscale": "No scaling", 33 | "opennmt": "TEU", 34 | "universal_scaledinit": "PED", 35 | "universal_noscale": "No scaling", 36 | "universal_opennmt": "TEU", 37 | "relative": "No scaling", 38 | "relative_universal": "No scaling" 39 | } 40 | 41 | init_order = ["PED", "TEU", "No scaling"] 42 | 43 | def average_accuracy(runs, split_name, step) -> float: 44 | st = lib.StatTracker() 45 | runs = list(runs) 46 | it = max([r.summary["iteration"] for r in runs]) 47 | for r in runs: 48 | if f"validation/{split_name}/accuracy/total" not in r.summary: 49 | continue 50 | 51 | if step is None: 52 | st.add(r.summary[f"validation/{split_name}/accuracy/total"]) 53 | assert r.summary["iteration"] == it, f"Inconsistend final iteration for run {r.id}: {r.summary['iteration']} instead of {it}" 54 | else: 55 | hist = r.history(keys=[f"validation/{split_name}/accuracy/total", "iteration"], pandas=False) 56 | for h in hist: 57 | if h["iteration"] == step: 58 | st.add(h[f"validation/{split_name}/accuracy/total"]) 59 | break 60 | else: 61 | assert False, f"Step {step} not found." 
62 | return st.get() 63 | 64 | def format_results(runs, title, split_name, step, best_other) -> str: 65 | run_group = lib.common.group(runs, ['transformer.variant']) 66 | variants = {v for k, v in init_type_table.items() if f"transformer.variant_{k}" in run_group} 67 | 68 | rtmp = [] 69 | for i in init_order: 70 | if i not in variants: 71 | continue 72 | 73 | cols = [] 74 | for clist in columns.values(): 75 | cols.append(None) 76 | for c in clist: 77 | full_name = f"transformer.variant_{c}" 78 | if init_type_table[c] == i and full_name in run_group: 79 | assert cols[-1] is None, "Can't be multiple variants" 80 | cols[-1] = average_accuracy(run_group[full_name], split_name, step) 81 | 82 | rtmp.append((i, cols)) 83 | 84 | cols = [[r[1][i].mean for r in rtmp if r[1][i]] for i in range(len(columns))] 85 | maxy = [round(max(c), 2) if c else None for c in cols] 86 | 87 | rows = [] 88 | for i, cols in rtmp: 89 | maxx = max(c.mean for c in cols if c is not None) 90 | cols = [(("\\bf{" if maxy[ci] and round(c.mean, 2) == maxy[ci] else "") + \ 91 | #("\\bf{" if abs(c.mean - maxx) < eps else "") + \ 92 | #("\\emph{" if abs(c.mean - max_all) < eps else "") + \ 93 | f"{c.mean:.2f} $\\pm$ {c.std:.2f}" + \ 94 | #("}" if abs(c.mean - max_all) < eps else "") + \ 95 | #("}" if abs(c.mean - maxx) < eps else "") + \ 96 | ("}" if maxy[ci] and round(c.mean, 2) == maxy[ci] else "")) if c is not None else "-" \ 97 | for ci, c in enumerate(cols)] 98 | 99 | rows.append(f"{i} & " + " & ".join(cols)) 100 | 101 | res = f"\\multirow{{{len(rows)}}}{{*}}{{{title}}} & " + rows[0] + f" & \\multirow{{{len(rows)}}}{{*}}{{{best_other}}} \\\\ \n" 102 | for r in rows[1:]: 103 | res += f" & {r} & \\\\ \n" 104 | 105 | return res 106 | 107 | print(" & Init & " + " & ".join(columns.keys()) + " & Reported\\\\") 108 | print("\\midrule") 109 | is_first_in_block = True 110 | for dname, (runs, splitname, at_step, best_other) in data.items(): 111 | if runs is None: 112 | print("\\midrule") 113 | is_first_in_block = True 114 | else: 115 | if not is_first_in_block: 116 | print("\\cmidrule{2-7}") 117 | is_first_in_block = False 118 | print(format_results(runs, dname, splitname, at_step, best_other), end="") 119 | -------------------------------------------------------------------------------- /paper/plot_cogs_early_stopping.py: -------------------------------------------------------------------------------- 1 | import lib 2 | import matplotlib.pyplot as plt 3 | import os 4 | import statistics 5 | from collections import OrderedDict 6 | 7 | format = "pdf" 8 | os.makedirs("out", exist_ok=True) 9 | 10 | runs = lib.get_runs(["cogs_trafo"]) 11 | runs = lib.common.group(runs, ['transformer.variant']) 12 | runs["theirs"] = lib.get_runs(["cogs_trafo_official"]) 13 | 14 | window = 2500 15 | 16 | 17 | fig = plt.figure(figsize=[4.5,1.5]) 18 | 19 | def download(run, *args, **kwargs): 20 | hist = run.history(*args, **kwargs, pandas=False) 21 | points = {p["iteration"]: p for p in hist} 22 | iters = list(sorted(points.keys())) 23 | return iters, points 24 | 25 | def plot_runs(runs): 26 | data = [] 27 | for r in runs: 28 | iters, points = download(r, keys=["validation/gen/accuracy/total", "iteration"]) 29 | acc = [points[i]["validation/gen/accuracy/total"] for i in iters] 30 | 31 | # They might be recorded with different frequencies, so query them twice 32 | iters2, points = download(r, keys=["validation/val/time_since_best_loss", "iteration"]) 33 | 34 | stop_point = iters2[-1] 35 | for i in iters2: 36 | if 
points[i]["validation/val/time_since_best_loss"] >= window: 37 | stop_point = i 38 | break 39 | data.append((iters, acc, stop_point - window)) 40 | 41 | for d in data: 42 | assert d[0] == data[0][0] 43 | 44 | ystat = [lib.StatTracker() for _ in data[0][0]] 45 | for d in data: 46 | for s, v in zip(ystat, d[1]): 47 | s.add(v) 48 | ystat = [s.get() for s in ystat] 49 | mean = [s.mean*100 for s in ystat] 50 | std = [s.std*100 for s in ystat] 51 | 52 | median_stop_point = statistics.median([d[2] for d in data]) 53 | 54 | p = plt.plot(data[0][0], mean) 55 | plt.fill_between(data[0][0], [m-s for m, s in zip(mean, std)], [m+s for m, s in zip(mean, std)], alpha=0.3) 56 | color = p[-1].get_color() 57 | return lambda: plt.axvline(x=median_stop_point, color=color, zorder=-100, linestyle="--") 58 | 59 | 60 | plt.xlabel("Training steps") 61 | plt.ylabel("Gen. accuracy [$\\%$]") 62 | 63 | d = OrderedDict() 64 | d["No scaling"] = "transformer.variant_noscale" 65 | d["Token Emb. Up., Noam"] = "theirs" 66 | d["Position Emb. Down."] = "transformer.variant_scaledinit" 67 | 68 | print(list(runs.keys())) 69 | plot_markers = [] 70 | for n in d.values(): 71 | plot_markers.append(plot_runs(runs[n])) 72 | 73 | # Markers must be plotted after the lines because legend will not work otherwise 74 | for pm in plot_markers: 75 | pm() 76 | 77 | plt.legend(list(d.keys())) 78 | fig.axes[0].xaxis.set_major_formatter(lambda x, _: f"{x//1000:.0f}k" if x >= 1000 else f"{x:.0f}") 79 | plt.xlim(0,50000) 80 | fig.savefig(f"out/cogs_early_stop.{format}", bbox_inches='tight', pad_inches=0.01) 81 | -------------------------------------------------------------------------------- /paper/plot_init.py: -------------------------------------------------------------------------------- 1 | import lib 2 | import os 3 | from collections import OrderedDict 4 | 5 | format = "pdf" 6 | os.makedirs("out", exist_ok=True) 7 | 8 | trafos = OrderedDict() 9 | trafos["TEU"] = "transformer.variant_opennmt" 10 | trafos["No scaling"] = "transformer.variant_noscale" 11 | trafos["PED"] = "transformer.variant_scaledinit" 12 | 13 | runs = lib.get_runs(["cogs_trafo"]) 14 | runs = lib.common.group(runs, ['transformer.variant']) 15 | runs = lib.common.calc_stat(runs, lambda n: n == "validation/gen/accuracy/total") 16 | cogs_runs = {k: v["validation/gen/accuracy/total"].get() for k, v in runs.items()} 17 | 18 | runs = lib.get_runs(["pcfg_nosched_productivity"]) 19 | runs = lib.common.group(runs, ['transformer.variant']) 20 | runs = lib.common.calc_stat(runs, lambda n: n == "validation/val/accuracy/total") 21 | pcfg_runs = {k: v["validation/val/accuracy/total"].get() for k, v in runs.items()} 22 | 23 | all_runs = OrderedDict() 24 | all_runs["COGS"] = cogs_runs 25 | all_runs["PCFG"] = pcfg_runs 26 | 27 | print(pcfg_runs) 28 | 29 | print(" & "+" & ".join(trafos.keys())+"\\\\") 30 | print("\\midrule") 31 | for rname, runs in all_runs.items(): 32 | means = [runs[k].mean for k in trafos.values()] 33 | maxmean = max(means) 34 | nums = [f"{runs[k].mean:.2f} \\pm {runs[k].std:.2f}" for k in trafos.values()] 35 | nums = [f"$\\mathbf{{{n}}}$" if m > maxmean-0.01 else f"${n}$" for n, m in zip(nums, means)] 36 | 37 | print(rname + " & " + " & ".join(nums) + "\\\\") 38 | -------------------------------------------------------------------------------- /paper/plot_init_iid.py: -------------------------------------------------------------------------------- 1 | import lib 2 | from lib.common import calc_stat 3 | from collections import OrderedDict 4 | 5 | runs = 
lib.get_runs(["cogs_trafo"]) 6 | runs = lib.common.group(runs, ['transformer.variant']) 7 | stats = calc_stat(runs, lambda k: k.endswith("/accuracy/total")) 8 | 9 | runs = lib.get_runs(["pcfg_nosched_productivity"]) 10 | runs = lib.common.group(runs, ['transformer.variant']) 11 | pcfg_gen_stats = calc_stat(runs, lambda k: k.endswith("/accuracy/total")) 12 | 13 | runs = lib.get_runs(["pcfg_nosched_iid"]) 14 | runs = lib.common.group(runs, ['transformer.variant']) 15 | pcfg_iid_stats = calc_stat(runs, lambda k: k.endswith("/accuracy/total")) 16 | 17 | 18 | columns = OrderedDict() 19 | columns["IID Validation"] = ["val"] 20 | columns["Gen. Test"] = ["gen"] 21 | 22 | d = OrderedDict() 23 | d["Token Emb. Up."] = "transformer.variant_opennmt" 24 | d["No scaling"] = "transformer.variant_noscale" 25 | d["Pos. Emb. Down."] = "transformer.variant_scaledinit" 26 | 27 | print(stats) 28 | 29 | print(" & & " + " & ".join(columns) + "\\\\") 30 | print("\\midrule") 31 | print("\\parbox[t]{3mm}{\\multirow{3}{*}{\\rotatebox[origin=c]{90}{\\small COGS}}}") 32 | 33 | def print_table(data): 34 | best = [max(v[i].mean for v in data.values()) for i in range(len(columns))] 35 | for vname in d.keys(): 36 | s = data[vname] 37 | s = [("{\\bf" if m - a.mean < 0.005 else "") + f"{a.mean:.2f} $\\pm$ {a.std:.2f}" + ("}" if m - a.mean < 0.005 else "") for a, m in zip(s, best)] 38 | print(" & " + vname + " & " + " & ".join(s) +" \\\\") 39 | 40 | print_table({vname: [stats[vcode][f"validation/{k[0]}/accuracy/total"].get() for k in columns.values()] for vname, vcode in d.items()}) 41 | 42 | print("\\midrule") 43 | print("\\parbox[t]{3mm}{\\multirow{3}{*}{\\rotatebox[origin=c]{90}{\\small PCFG}}}") 44 | 45 | print_table({vname: [pcfg_iid_stats[vcode]["validation/val/accuracy/total"].get(), pcfg_gen_stats[vcode]["validation/val/accuracy/total"].get()] for vname, vcode in d.items()}) 46 | -------------------------------------------------------------------------------- /paper/plot_loss_accuracy.py: -------------------------------------------------------------------------------- 1 | import lib 2 | import matplotlib.pyplot as plt 3 | import os 4 | 5 | format = "pdf" 6 | os.makedirs("out", exist_ok=True) 7 | 8 | runs = lib.get_runs(["cogs_trafo"], filters={"config.transformer.variant.value": "opennmt"}) 9 | 10 | def plot(runs, group, from_iter: int = 0, loss_group=None): 11 | xs = [] 12 | ys = [] 13 | cs = [] 14 | 15 | loss_group = loss_group or group 16 | 17 | for r in runs: 18 | hist = r.history(keys=[f"validation/{group}/accuracy/total", f"validation/{loss_group}/loss", "iteration"], 19 | pandas=False) 20 | for p in hist: 21 | if p["iteration"] < from_iter: 22 | continue 23 | xs.append(p[f"validation/{loss_group}/loss"]) 24 | ys.append(p[f"validation/{group}/accuracy/total"]*100) 25 | cs.append(p["iteration"]) 26 | sc = plt.scatter(xs, ys, c=cs) 27 | plt.xscale('log') 28 | cbar = plt.colorbar(sc, ticks=[min(cs), max(cs)], pad=0.02) 29 | cbar.ax.set_yticklabels([f"{min(cs)//1000}k", f"{max(cs)//1000}k"]) 30 | # divider = make_axes_locatable(ax) 31 | # cax = divider.append_axes("right", size=0.25, pad=0.1) 32 | # plt.colorbar(im, cax) 33 | 34 | # plt.tick_params(axis='x', labelsize=8) 35 | fig.axes[0].yaxis.set_label_coords(-0.120, 0.45) 36 | 37 | def plot_test_axis_labels(): 38 | plt.xlabel("Test loss") 39 | plt.ylabel("Test accuracy [\%]") 40 | 41 | fig = plt.figure(figsize=[4.5,1.5]) 42 | plot(runs, "gen", 1000) 43 | plot_test_axis_labels() 44 | fig.axes[0].xaxis.set_minor_formatter(lambda x, _: f"{x:.2f}") 45 | 
fig.savefig(f"out/cogs_loss_accuracy.{format}", bbox_inches='tight', pad_inches=0.01) 46 | 47 | fig = plt.figure(figsize=[4.5,1.4]) 48 | plot(runs, "val", 1000) 49 | plt.xlabel("Validation loss") 50 | plt.ylabel("Val. accuracy [\%]") 51 | fig.savefig(f"out/cogs_loss_accuracy_val.{format}", bbox_inches='tight', pad_inches=0.01) 52 | 53 | 54 | runs = lib.get_runs(["cfq_mcd"]) 55 | # The API doesn't support dot in names, so filter manually. 56 | runs = [r for r in runs if r.config["cfq.split"] == "mcd1" and r.config["transformer.variant"] == "relative"] 57 | 58 | fig = plt.figure(figsize=[4.5,1.5]) 59 | plot(runs, "test", 1000, loss_group="val") 60 | fig.axes[0].xaxis.set_minor_formatter(lambda x, _: f"{x:.2f}") 61 | plt.xlabel("Validation loss") 62 | plt.ylabel("Test accuracy [\%]") 63 | fig.savefig(f"out/cfq_loss_accuracy.{format}", bbox_inches='tight', pad_inches=0.01) -------------------------------------------------------------------------------- /paper/plot_loss_analysis.py: -------------------------------------------------------------------------------- 1 | import lib 2 | import os 3 | import torch 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | format = "pdf" 8 | runs = lib.get_runs(["cfq_mcd"], filters={"config.cfq.split.value": "mcd1", "config.transformer.variant.value": "relative"}) 9 | 10 | os.makedirs("out/loss_analysis", exist_ok=True) 11 | 12 | for r in runs: 13 | bdir = f"out/loss_analysis/{r.id}" 14 | if not os.path.isdir(bdir): 15 | os.makedirs(f"out/loss_analysis/{r.id}") 16 | r.file("export/loss_details/test.pth").download(root=bdir, replace=True) 17 | 18 | hist = r.history(keys=["validation/test/loss", "iteration"], pandas=False) 19 | x = [p["iteration"] for p in hist] 20 | 21 | data = torch.load(f"{bdir}/export/loss_details/test.pth") 22 | 23 | good_mask = np.sum(data["oks"], 0) > 0 24 | # good_mask=data["oks"][-1] 25 | good = data["losses"][:, good_mask] 26 | bad = data["losses"][:, ~good_mask] 27 | 28 | fig = plt.figure(figsize=[4.5,1.5]) 29 | plt.plot(x, np.mean(good, -1)) 30 | plt.plot(x, np.mean(bad, -1)) 31 | plt.plot(x, np.mean(data["losses"], -1)) 32 | plt.legend(["``Good''", "``Bad''", "Total"], loc="upper left") 33 | fig.axes[0].xaxis.set_major_formatter(lambda x, _: f"{x//1000:.0f}k" if x >= 1000 else f"{x:.0f}") 34 | plt.xlabel("Training steps") 35 | plt.ylabel("Loss") 36 | plt.xlim(0, x[-1]) 37 | fig.savefig(f"out/loss_decomposed.{format}", bbox_inches='tight', pad_inches=0.01) 38 | 39 | fig = plt.figure(figsize=[4.5,1.5]) 40 | plt.hist(good[0], 40, alpha=0.8, zorder=2, range=(0,15)) 41 | plt.hist(good[-1], 40, alpha=0.8, zorder=1, range=(0,15)) 42 | plt.legend([f"Training step {x[0]//1000}k", f"Training step {x[-1]//1000}k"]) 43 | plt.xlabel("Loss") 44 | plt.ylabel("Count") 45 | fig.savefig(f"out/loss_good_hist.{format}", bbox_inches='tight', pad_inches=0.01) 46 | 47 | fig = plt.figure(figsize=[4.5,1.5]) 48 | plt.hist(bad[0], 40, alpha=0.8, zorder=2, range=(0,15)) 49 | plt.hist(bad[-1], 40, alpha=0.8, zorder=1, range=(0,15)) 50 | plt.legend([f"Training step {x[0]//1000}k", f"Training step {x[-1]//1000}k"]) 51 | plt.xlabel("Loss") 52 | plt.ylabel("Count") 53 | fig.savefig(f"out/loss_bad_hist.{format}", bbox_inches='tight', pad_inches=0.01) 54 | 55 | break -------------------------------------------------------------------------------- /paper/plot_pcfg.py: -------------------------------------------------------------------------------- 1 | import lib 2 | import matplotlib.pyplot as plt 3 | import matplotlib 4 | import os 5 | from 
collections import OrderedDict 6 | 7 | format = "pdf" 8 | os.makedirs("out", exist_ok=True) 9 | 10 | runs = lib.get_runs(["pcfg_nosched_productivity"]) 11 | runs = lib.common.group(runs, ['transformer.variant']) 12 | del runs["transformer.variant_opennmt"] 13 | print(list(runs.keys())) 14 | 15 | def get(runs): 16 | acc = {} 17 | loss = {} 18 | 19 | for r in runs: 20 | hist = r.history(keys=["validation/val/accuracy/total", "validation/val/loss", "iteration"], pandas=False) 21 | for p in hist: 22 | i = p["iteration"] 23 | if i not in acc: 24 | acc[i] = lib.StatTracker() 25 | loss[i] = lib.StatTracker() 26 | 27 | loss[i].add(p["validation/val/loss"]) 28 | acc[i].add(p["validation/val/accuracy/total"] * 100) 29 | 30 | x = list(sorted(acc.keys())) 31 | acc = [acc[i].get() for i in x] 32 | loss = [loss[i].get() for i in x] 33 | 34 | return x, acc, loss 35 | 36 | data = {k: get(v) for k, v in runs.items()} 37 | 38 | d = OrderedDict() 39 | d["Standard"] = "transformer.variant_scaledinit" 40 | d["Uni."] = "transformer.variant_universal_noscale" 41 | d["Rel. Uni."] = "transformer.variant_relative_universal" 42 | 43 | fig = plt.figure(figsize=[4.5,1.5]) 44 | for k in d.values(): 45 | plt.plot(data[k][0], [y.mean for y in data[k][1]]) 46 | plt.fill_between(data[k][0], [a.mean - a.std for a in data[k][1]], [a.mean + a.std for a in data[k][1]], alpha=0.3) 47 | 48 | plt.legend(d.keys()) 49 | fig.axes[0].xaxis.set_major_formatter(lambda x, _: f"{x//1000:.0f}k" if x >= 1000 else f"{x:.0f}") 50 | plt.xlabel("Training steps") 51 | plt.ylabel("Accuracy [\\%]") 52 | plt.xlim(0,300000) 53 | plt.ylim(0,90) 54 | fig.savefig(f"out/pcfg_accuracy.{format}", bbox_inches='tight', pad_inches=0.01) 55 | 56 | 57 | fig = plt.figure(figsize=[4.5,1.5]) 58 | for k in d.values(): 59 | plt.plot(data[k][0], [y.mean for y in data[k][2]]) 60 | plt.fill_between(data[k][0], [a.mean - a.std for a in data[k][2]], [a.mean + a.std for a in data[k][2]], alpha=0.3) 61 | 62 | plt.legend(d.keys(), ncol=3) 63 | plt.xlabel("Training steps") 64 | plt.ylabel("Loss") 65 | plt.xlim(0,300000) 66 | fig.axes[0].xaxis.set_major_formatter(lambda x, _: f"{x//1000:.0f}k" if x >= 1000 else f"{x:.0f}") 67 | fig.savefig(f"out/pcfg_loss.{format}", bbox_inches='tight', pad_inches=0.01) 68 | -------------------------------------------------------------------------------- /paper/plot_relatrafo_convergece.py: -------------------------------------------------------------------------------- 1 | import lib 2 | import matplotlib.pyplot as plt 3 | import os 4 | from collections import OrderedDict 5 | import statistics 6 | 7 | format = "pdf" 8 | os.makedirs("out", exist_ok=True) 9 | 10 | 11 | columns = OrderedDict() 12 | columns["Trafo"] = ["scaledinit", "noscale", "opennmt"] 13 | columns["Uni. Trafo"] = ["universal_scaledinit", "universal_noscale", "universal_opennmt"] 14 | columns["Rel. Trafo"] = ["relative"] 15 | columns["Rel. Uni. 
Trafo"] = ["relative_universal"] 16 | 17 | 18 | variants = OrderedDict() 19 | # Opennmt variant unstable, skip 20 | variants["PCFG"] = [r for r in lib.get_runs(["pcfg_nosched_productivity"]) if r.config["transformer.variant"] != "opennmt"], "val" 21 | cfq_runs = lib.get_runs(["cfq_mcd", "cfq_mcd_2seed", "cfq_mcd_universal", "cfq_mcd_universal_2seed"]) 22 | variants["CFQ MCD 1"] = [r for r in cfq_runs if r.config["cfq.split"] == "mcd1"], "test" 23 | variants["CFQ MCD 2"] = [r for r in cfq_runs if r.config["cfq.split"] == "mcd2"], "test" 24 | variants["CFQ MCD 3"] = [r for r in cfq_runs if r.config["cfq.split"] == "mcd3"], "test" 25 | variants["COGS"] = lib.get_runs(["cogs_trafo"]), "gen" 26 | math_runs = lib.get_runs(["dm_math"]) 27 | variants["Math: add\\_or\\_sub"] = [r for r in math_runs if r.config["dm_math.task"] == "arithmetic__add_or_sub"], "extrapolate" 28 | # Relative variant crashes, skip it 29 | variants["Math: place\\_value"] = [r for r in math_runs if r.config["dm_math.task"] == "numbers__place_value" and r.config["transformer.variant"] != "relative"], "extrapolate" 30 | 31 | 32 | def calculate_converged_iter(runs, key="val", threshold = 0.80): 33 | stop_list = [] 34 | for r in runs: 35 | hist = r.history(keys=[f"validation/{key}/accuracy/total", "iteration"], pandas=False) 36 | x = list(sorted(a["iteration"] for a in hist)) 37 | y = [None for _ in x] 38 | 39 | for h in hist: 40 | y[x.index(h["iteration"])] = h[f"validation/{key}/accuracy/total"] 41 | 42 | th = max(y) * threshold 43 | first = None 44 | for i, a in enumerate(y): 45 | if a >= th: 46 | first = i 47 | break 48 | 49 | interp = x[first - 1] + (x[first] - x[first-1])/(y[first] - y[first-1]) * (th - y[first-1]) 50 | stop_list.append(interp) 51 | 52 | return statistics.median(stop_list) 53 | 54 | all_numbers = [] 55 | 56 | for vname, (runs, key) in variants.items(): 57 | runs = lib.common.group(runs, ['transformer.variant']) 58 | 59 | line = [] 60 | for varlist in columns.values(): 61 | found = [] 62 | for v in varlist: 63 | v = f"transformer.variant_{v}" 64 | if v in runs: 65 | found.append(calculate_converged_iter(runs[v], key)) 66 | 67 | line.append(min(found) if found else None) 68 | 69 | all_numbers.append(line) 70 | 71 | best = min(l for l in line if l) 72 | s_line = [(f"{f/1000:.0f}k" if f>10000 else f"{f/1000:.1f}k") if f else "-" for f in line] 73 | s_line = [f"\\textbf{{{s}}}" if n==best else s for s, n in zip(s_line, line)] 74 | 75 | print(vname + " & " + " & ".join(s_line)) 76 | 77 | 78 | fig = plt.figure(figsize=[4.5,2]) 79 | 80 | speedups = [] 81 | for vname, l in zip(variants.keys(), all_numbers): 82 | speedups.append((l[0] / l[2]) if l[0] and l[2] else 0) 83 | speedups.append((l[1] / l[3]) if l[1] and l[3] else 0) 84 | 85 | pos = [2.5 * i for i in range(len(variants))] 86 | plt.barh(pos, speedups[::2]) 87 | plt.barh([p + 1 for p in pos], speedups[1::2]) 88 | plt.yticks([p + 0.5 for p in pos], list(variants.keys())) 89 | plt.legend(["Trafo", "Uni. 
Trafo"]) 90 | plt.axvline(x=1.0, color="red", zorder=-100, linestyle="-") 91 | fig.savefig(f"out/convergence_transformer.{format}", bbox_inches='tight', pad_inches=0.01) 92 | -------------------------------------------------------------------------------- /paper/plot_scan_eos_performance.py: -------------------------------------------------------------------------------- 1 | import lib 2 | from collections import OrderedDict 3 | 4 | runs = lib.get_runs(["scan_trafo_length_cutoff"]) 5 | runs = lib.common.group(runs, ['transformer.variant', "scan.length_cutoff"]) 6 | 7 | variants = OrderedDict() 8 | variants["Trafo"] = ["scaledinit", "noscale", "opennmt"] 9 | variants["Uni. Trafo"] = ["universal_scaledinit", "universal_noscale", "universal_opennmt"] 10 | variants["Rel. Trafo"] = ["relative"] 11 | variants["Rel. Uni. Trafo"] = ["relative_universal"] 12 | 13 | lengths = [22, 24, 25, 26, 27, 28, 30, 32, 33, 36, 40] 14 | 15 | best = [0.58, 0.54, 0.69, 0.82, 0.88, 0.85, 0.89, 0.82, 1.00, 1.00, 1.00] 16 | 17 | stats = lib.common.calc_stat(runs, lambda name: name.endswith("val/accuracy/total"), tracker=lib.MedianTracker) 18 | 19 | ourtab = [] 20 | for i, (v, vlist) in enumerate(variants.items()): 21 | ourtab.append([]) 22 | for l in lengths: 23 | all_stats = [stats.get(f"transformer.variant_{vn}/scan.length_cutoff_{l}") for vn in vlist] 24 | all_stats = [a for a in all_stats if a is not None] 25 | assert all([len(a) == 1 for a in all_stats]) 26 | all_stats = [list(a.values())[0].get() for a in all_stats] 27 | ourtab[-1].append(max(all_stats)) 28 | 29 | for l in ourtab: 30 | for i, v in enumerate(l): 31 | best[i] = max(best[i], v) 32 | 33 | for i, (v, vn) in enumerate(variants.items()): 34 | pstr = [] 35 | for j, val in enumerate(ourtab[i]): 36 | pstr.append(("\\bf" if best[j] - val < 0.02 else "") + f"{val:.2f}") 37 | 38 | print(f"{' & ' if i>0 else ''}\\texttt{{{v}}}\\xspace & {' & '.join(pstr)} \\\\") 39 | -------------------------------------------------------------------------------- /paper/plot_small_batch.py: -------------------------------------------------------------------------------- 1 | import lib 2 | from collections import OrderedDict 3 | import matplotlib.pyplot as plt 4 | import os 5 | 6 | format = "pdf" 7 | os.makedirs("out", exist_ok=True) 8 | 9 | columns = OrderedDict() 10 | columns["Trafo"] = ["scaledinit", "noscale", "opennmt"] 11 | columns["Uni. Trafo"] = ["universal_scaledinit", "universal_noscale", "universal_opennmt"] 12 | columns["Rel. Trafo"] = ["relative"] 13 | columns["Rel. Uni. Trafo"] = ["relative_universal"] 14 | 15 | 16 | cfq_big_runs = lib.get_runs(["cfq_mcd", "cfq_mcd_universal"]) 17 | cfq_small_runs = lib.get_runs(["cfq_mcd_small_batch", "cfq_mcd_small_batch_universal"]) 18 | variants = OrderedDict() 19 | variants["CFQ MCD 1"] = [r for r in cfq_big_runs if r.config["cfq.split"] == "mcd1"], [r for r in cfq_small_runs if r.config["cfq.split"] == "mcd1"], "test" 20 | variants["CFQ MCD 2"] = [r for r in cfq_big_runs if r.config["cfq.split"] == "mcd2"], [r for r in cfq_small_runs if r.config["cfq.split"] == "mcd2"], "test" 21 | variants["CFQ MCD 3"] = [r for r in cfq_big_runs if r.config["cfq.split"] == "mcd3"], [r for r in cfq_small_runs if r.config["cfq.split"] == "mcd3"], "test" 22 | variants["CFQ Out. 
len."] = lib.get_runs(["cfq_out_length", "cfq_out_length_universal"]), lib.get_runs(["cfq_out_length_small_batch", "cfq_out_length_universal_small_batch"]), "test" 23 | 24 | 25 | def average_accuracy(runs, split_name): 26 | st = lib.StatTracker() 27 | runs = list(runs) 28 | it = max([r.summary["iteration"] for r in runs]) 29 | for r in runs: 30 | st.add(r.summary[f"validation/{split_name}/accuracy/total"]) 31 | assert r.summary["iteration"] == it 32 | return st.get() 33 | 34 | drops = [] 35 | details = [] 36 | 37 | for big_runs, small_runs, split_name in variants.values(): 38 | bgroup = lib.common.group(big_runs, ['transformer.variant']) 39 | sgroup = lib.common.group(small_runs, ['transformer.variant']) 40 | 41 | for vlist in columns.values(): 42 | bacc = max([average_accuracy(bgroup[f"transformer.variant_{v}"], split_name) for v in vlist if f"transformer.variant_{v}" in bgroup], key=lambda x: x.mean) 43 | sacc = max([average_accuracy(sgroup[f"transformer.variant_{v}"], split_name) for v in vlist if f"transformer.variant_{v}" in sgroup], key=lambda x: x.mean) 44 | 45 | details.append((bacc,sacc)) 46 | drops.append(sacc.mean/bacc.mean) 47 | 48 | 49 | print("& " + " & ".join(columns.keys())+"\\\\") 50 | print("\\midrule") 51 | for i, vname in enumerate(variants.keys()): 52 | best = max(range(len(columns)), key = lambda j: drops[i * len(columns) + j]) 53 | print(vname + " & " + " & ".join([("\\bf" if j==best else "") + f"{drops[i*len(columns) + j]:.2f}" for j in range(len(columns))])+ "\\\\") 54 | 55 | 56 | 57 | print("& Variant & " + " & ".join(columns.keys())+"\\\\") 58 | for i, vname in enumerate(variants.keys()): 59 | col = [details[i*len(columns) + j] for j in range(len(columns))] 60 | 61 | print("\\midrule") 62 | print(f"\\multirow{{3}}{{*}}{{{vname}}} & Big & "+ " & ".join([f"${c[0].mean:.2f}\\pm{c[0].std:.2f}$" for c in col])+ "\\\\") 63 | print(f" & Small & "+ " & ".join([f"${c[1].mean:.2f}\\pm{c[1].std:.2f}$" for c in col])+ "\\\\") 64 | 65 | best = max(range(len(columns)), key = lambda j: drops[i * len(columns) + j]) 66 | print("\\cmidrule{2-6}") 67 | print(" & Proportion & " + " & ".join([("\\bf" if j==best else "") + f"{drops[i*len(columns) + j]:.2f}" for j in range(len(columns))])+ "\\\\") 68 | 69 | pos = [(len(columns) + 1.5) * i for i in range(len(variants))] 70 | fig = plt.figure(figsize=[4.5,2]) 71 | for i in range(len(columns)): 72 | plt.barh([p + i + 0.5 * (i//2) for p in pos], drops[i::len(columns)]) 73 | 74 | plt.yticks([p + (len(columns)-1 + 0.5)/2 for p in pos], list(variants.keys())) 75 | plt.legend(list(columns.keys()), loc='upper left') 76 | plt.axvline(x=1.0, color="red", zorder=-100, linestyle="-") 77 | fig.savefig(f"out/small_batch_transformer.{format}", bbox_inches='tight', pad_inches=0.01) 78 | -------------------------------------------------------------------------------- /paper/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Remove old plots 4 | rm -r out 2>/dev/null 5 | 6 | # Run all plots 7 | for s in *.py; do 8 | echo "Running $s" 9 | /usr/bin/env python3 $s 10 | done 11 | -------------------------------------------------------------------------------- /plot_dataset_stats.py: -------------------------------------------------------------------------------- 1 | import dataset 2 | from collections import OrderedDict 3 | from multiprocessing import Process, Queue, cpu_count 4 | 5 | datasets = OrderedDict() 6 | 7 | datasets["0"] = None, None, None, None 8 | datasets["Scan (length 
cutoff=26)"] = (dataset.ScanLengthResplit("train", (0, 26)), 9 | dataset.ScanLengthResplit("test", (0, 26)), 10 | dataset.ScanLengthResplit("all", (27, 9999)), 11 | None) 12 | 13 | 14 | datasets["a"] = None, None, None, None 15 | for i in range(1,4): 16 | datasets[f"CFQ MCD {i}"] = (dataset.CFQ(["train"], split_type=[f"mcd{i}"]), 17 | None, 18 | dataset.CFQ(["test"], split_type=[f"mcd{i}"]), 19 | dataset.CFQ(["val"], split_type=[f"mcd{i}"])) 20 | 21 | datasets["CFQ Output Length"]= (dataset.CFQ(["train"], split_type=[f"query_complexity"]), 22 | None, 23 | dataset.CFQ(["test"], split_type=[f"query_complexity"]), 24 | dataset.CFQ(["val"], split_type=[f"query_complexity"])) 25 | 26 | 27 | datasets["b"] = None, None, None, None 28 | datasets["PCFG Productivity"]= (dataset.PCFGSet(["train"], split_type=["productivity"]), 29 | None, 30 | dataset.PCFGSet(["test"], split_type=["productivity"]), 31 | None) 32 | 33 | datasets["PCFG Systematicity"]= (dataset.PCFGSet(["train"], split_type=["systematicity"]), 34 | None, 35 | dataset.PCFGSet(["test"], split_type=["systematicity"]), 36 | None) 37 | 38 | datasets["c"] = None, None, None, None 39 | datasets["COGS"] = (dataset.COGS(["train"]), 40 | dataset.COGS(["valid"]), 41 | dataset.COGS(["gen"]), 42 | None) 43 | 44 | 45 | datasets["d"] = None, None, None, None 46 | datasets["Math: add\\_or\\_sub"] = (dataset.DeepmindMathDataset(["arithmetic__add_or_sub"], 47 | sets=["train_easy", "train_medium", "train_hard"]), 48 | dataset.DeepmindMathDataset(["arithmetic__add_or_sub"], sets=["interpolate"]), 49 | dataset.DeepmindMathDataset(["arithmetic__add_or_sub"], sets=["extrapolate"]), 50 | None) 51 | 52 | datasets["Math: place\\_value"] = (dataset.DeepmindMathDataset(["numbers__place_value"], 53 | sets=["train_easy", "train_medium", "train_hard"]), 54 | dataset.DeepmindMathDataset(["numbers__place_value"], sets=["interpolate"]), 55 | dataset.DeepmindMathDataset(["numbers__place_value"], sets=["extrapolate"]), 56 | None) 57 | 58 | 59 | def get_len(ds): 60 | nproc = cpu_count() * 2 61 | ranges = [] 62 | prev = 0 63 | step = len(ds)//nproc 64 | for _ in range(nproc): 65 | next = prev + step 66 | ranges.append([prev, next]) 67 | prev = next 68 | ranges[-1][-1] = len(ds) 69 | 70 | q = Queue() 71 | def cnt(r): 72 | mo = 0 73 | mi = 0 74 | for i in range(*r): 75 | mo = max(mo, ds[i]["out_len"]) 76 | mi = max(mi, ds[i]["in_len"]) 77 | q.put((mo, mi)) 78 | 79 | procs = [Process(target=cnt, args=(r,)) for r in ranges] 80 | for p in procs: 81 | p.start() 82 | 83 | for p in procs: 84 | p.join() 85 | 86 | lens = [q.get() for _ in procs] 87 | return max([l[0] for l in lens]), max([l[1] for l in lens]) 88 | 89 | print("Dataset & \\# train & \\# IID valid. & \\# gen. test & \\# gen. valid. & Voc. size & Train len. 
& Test len.\\\\") 90 | for name, dsdesc in datasets.items(): 91 | train, val_iid, test_gen, val_gen = dsdesc 92 | if train is None: 93 | print("\\midrule") 94 | continue 95 | 96 | print(f"{name} ", end="") 97 | for ds in dsdesc: 98 | print(" & " + (f"${len(ds)}$" if ds is not None else '-'), end="") 99 | 100 | max_train_out, max_train_in = get_len(train) 101 | max_test_out, max_test_in = get_len(test_gen) 102 | print(f" & ${len(train.in_vocabulary+train.out_vocabulary)}$ & ${max_train_in}$/${max_train_out}$ & ${max_test_in}$/${max_test_out}$ \\\\") 103 | -------------------------------------------------------------------------------- /requrements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.0 2 | tqdm==4.59.0 3 | psutil==5.8.0 4 | matplotlib==3.3.4 5 | tensorboard==2.4.1 6 | future==0.18.2 7 | filelock==3.0.12 8 | setproctitle==1.1.10 9 | wandb==0.12.0 10 | dataclasses=0.8 11 | Pillow==8.3.1 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import yaml 4 | import sys 5 | import subprocess 6 | import os 7 | 8 | if len(sys.argv) != 2: 9 | print(f"Usage: {sys.argv[0]} ") 10 | sys.exit(-1) 11 | 12 | with open(sys.argv[1]) as f: 13 | config = yaml.safe_load(f) 14 | 15 | args = [] 16 | 17 | for p, pval in config["parameters"].items(): 18 | if p in ["log", "sweep_id_for_grid_search"]: 19 | continue 20 | 21 | args.append("-" + p) 22 | if "value" in pval: 23 | assert "values" not in pval 24 | args.append(pval["value"]) 25 | elif "values" in pval: 26 | if len(pval["values"]) == 1: 27 | args.append(pval["values"][0]) 28 | else: 29 | while True: 30 | print(f"Choose value for \"{p}\"") 31 | for i, v in enumerate(pval["values"]): 32 | print(f" {i+1}: {v}") 33 | 34 | choice = input("> ") 35 | if not choice.isdigit() or int(choice) < 1 or int(choice) > len(pval["values"]): 36 | print("Invalid choice.") 37 | continue 38 | 39 | args.append(pval["values"][int(choice) - 1]) 40 | break 41 | 42 | if "name" not in config["parameters"]: 43 | args.append("-name") 44 | args.append(os.path.basename(sys.argv[1]).replace(".yaml", "")) 45 | 46 | replace = { 47 | "${env}": "", 48 | "${program}": config["program"], 49 | "${args}": " ".join([str(a) for a in args]) 50 | } 51 | 52 | cmd = (" ".join([replace.get(c, c) for c in config["command"]])).strip() 53 | print(f"Running {cmd}") 54 | subprocess.run(cmd, shell=True) 55 | -------------------------------------------------------------------------------- /sweeps/cfq_mcd.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | log: 13 | value: wandb 14 | profile: 15 | value: cfq_trafo 16 | transformer.variant: 17 | values: 18 | - relative 19 | - noscale 20 | - scaledinit 21 | cfq.split: 22 | values: 23 | - mcd1 24 | - mcd2 25 | - mcd3 26 | amp: 27 | value: 1 28 | lr: 29 | value: 0.9 30 | lr_sched.type: 31 | value: noam 32 | lr_warmup: 33 | value: 4000 34 | batch_size: 35 | value: 4096 36 | sweep_id_for_grid_search: 37 | distribution: categorical 38 | values: 39 | - 1 40 | - 2 41 | - 3 42 | - 4 43 | - 5 44 | -------------------------------------------------------------------------------- /sweeps/cfq_mcd_small_batch.yaml: 
-------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | log: 13 | value: wandb 14 | profile: 15 | value: cfq_trafo 16 | transformer.variant: 17 | values: 18 | - relative 19 | - noscale 20 | cfq.split: 21 | values: 22 | - mcd1 23 | - mcd2 24 | - mcd3 25 | max_length_per_batch: 26 | value: 50 27 | amp: 28 | value: 1 29 | sweep_id_for_grid_search: 30 | distribution: categorical 31 | values: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | - 5 37 | 38 | -------------------------------------------------------------------------------- /sweeps/cfq_mcd_small_batch_universal.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | log: 13 | value: wandb 14 | profile: 15 | value: cfq_universal_trafo 16 | transformer.variant: 17 | values: 18 | - relative_universal 19 | - universal_noscale 20 | cfq.split: 21 | values: 22 | - mcd1 23 | - mcd2 24 | - mcd3 25 | max_length_per_batch: 26 | value: 50 27 | amp: 28 | value: 1 29 | sweep_id_for_grid_search: 30 | distribution: categorical 31 | values: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | - 5 37 | -------------------------------------------------------------------------------- /sweeps/cfq_mcd_universal.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | log: 13 | value: wandb 14 | profile: 15 | value: cfq_universal_trafo 16 | transformer.variant: 17 | values: 18 | - relative_universal 19 | - universal_noscale 20 | - universal_scaledinit 21 | cfq.split: 22 | values: 23 | - mcd1 24 | - mcd2 25 | - mcd3 26 | amp: 27 | value: 1 28 | lr: 29 | value: 2.24 30 | lr_sched.type: 31 | value: noam 32 | lr_warmup: 33 | value: 8000 34 | batch_size: 35 | value: 2048 36 | sweep_id_for_grid_search: 37 | distribution: categorical 38 | values: 39 | - 1 40 | - 2 41 | - 3 42 | - 4 43 | - 5 44 | -------------------------------------------------------------------------------- /sweeps/cfq_out_length.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | log: 13 | value: wandb 14 | profile: 15 | value: cfq_trafo 16 | transformer.variant: 17 | values: 18 | - relative 19 | - opennmt 20 | - scaledinit 21 | - noscale 22 | cfq.split: 23 | value: query_complexity 24 | amp: 25 | value: 1 26 | lr: 27 | value: 0.9 28 | lr_sched.type: 29 | value: noam 30 | lr_warmup: 31 | value: 4000 32 | batch_size: 33 | value: 4096 34 | max_length_per_batch: 35 | value: 9999 36 | sweep_id_for_grid_search: 37 | distribution: categorical 38 | values: 39 | - 1 40 | - 2 41 | - 3 42 | - 4 43 | - 5 44 | -------------------------------------------------------------------------------- /sweeps/cfq_out_length_small_batch.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 
| command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | log: 13 | value: wandb 14 | profile: 15 | value: cfq_trafo 16 | transformer.variant: 17 | values: 18 | - relative 19 | - opennmt 20 | - scaledinit 21 | cfq.split: 22 | value: query_complexity 23 | amp: 24 | value: 1 25 | max_length_per_batch: 26 | value: 50 27 | sweep_id_for_grid_search: 28 | distribution: categorical 29 | values: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | - 5 35 | -------------------------------------------------------------------------------- /sweeps/cfq_out_length_universal.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | log: 13 | value: wandb 14 | profile: 15 | value: cfq_universal_trafo 16 | transformer.variant: 17 | values: 18 | - relative_universal 19 | - universal_noscale 20 | - universal_scaledinit 21 | - universal_opennmt 22 | cfq.split: 23 | value: query_complexity 24 | amp: 25 | value: 1 26 | lr: 27 | value: 2.24 28 | lr_sched.type: 29 | value: noam 30 | lr_warmup: 31 | value: 8000 32 | batch_size: 33 | value: 2048 34 | sweep_id_for_grid_search: 35 | distribution: categorical 36 | values: 37 | - 1 38 | - 2 39 | - 3 40 | - 4 41 | - 5 -------------------------------------------------------------------------------- /sweeps/cfq_out_length_universal_small_batch.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | log: 13 | value: wandb 14 | profile: 15 | value: cfq_universal_trafo 16 | transformer.variant: 17 | values: 18 | - relative_universal 19 | - universal_noscale 20 | cfq.split: 21 | value: query_complexity 22 | amp: 23 | value: 1 24 | sweep_id_for_grid_search: 25 | distribution: categorical 26 | values: 27 | - 1 28 | - 2 29 | - 3 30 | - 4 31 | - 5 32 | -------------------------------------------------------------------------------- /sweeps/cogs_trafo.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | name: 13 | value: cogs_trafo_small 14 | log: 15 | value: wandb 16 | profile: 17 | value: cogs_trafo_small 18 | transformer.variant: 19 | values: 20 | - opennmt 21 | - noscale 22 | - scaledinit 23 | - universal_noscale 24 | - relative 25 | - relative_universal 26 | - universal_scaledinit 27 | - universal_opennmt 28 | lr_sched.type: 29 | value: step 30 | grad_clip: 31 | value: 1.0 32 | lr_warmup: 33 | value: 0 34 | lr: 35 | value: 0.0001 36 | cogs.generalization_test_interval: 37 | value: 500 38 | test_interval: 39 | value: 500 40 | test_batch_size: 41 | value: 512 42 | stop_after: 43 | value: 50000 44 | sweep_id_for_grid_search: 45 | distribution: categorical 46 | values: 47 | - 1 48 | - 2 49 | - 3 50 | - 4 51 | - 5 52 | -------------------------------------------------------------------------------- /sweeps/cogs_trafo_official.yaml: 
-------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | name: 13 | value: cogs_trafo_small 14 | log: 15 | value: wandb 16 | profile: 17 | value: cogs_trafo_small 18 | transformer.variant: 19 | value: opennmt 20 | cogs.generalization_test_interval: 21 | value: 500 22 | test_interval: 23 | value: 500 24 | test_batch_size: 25 | value: 512 26 | stop_after: 27 | value: 50000 28 | sweep_id_for_grid_search: 29 | distribution: categorical 30 | values: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | - 5 36 | -------------------------------------------------------------------------------- /sweeps/dm_math.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | name: 13 | value: trafo_scan 14 | log: 15 | value: wandb 16 | profile: 17 | value: deepmind_math 18 | task: 19 | value: dm_math_transformer 20 | transformer.variant: 21 | values: 22 | - relative 23 | - noscale 24 | - universal_noscale 25 | - relative_universal 26 | - universal_scaledinit 27 | - scaledinit 28 | dm_math.task: 29 | values: 30 | - numbers__place_value 31 | - arithmetic__add_or_sub 32 | lr: 33 | value: 1e-4 34 | stop_after: 35 | value: 50000 36 | batch_size: 37 | value: 256 38 | amp: 39 | value: 1 40 | sweep_id_for_grid_search: 41 | distribution: categorical 42 | values: 43 | - 1 44 | - 2 45 | - 3 46 | - 4 47 | - 5 48 | -------------------------------------------------------------------------------- /sweeps/pcfg_nosched_iid.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | name: 13 | value: scan 14 | log: 15 | value: wandb 16 | profile: 17 | value: pcfg_trafo 18 | task: 19 | value: pcfg_transformer 20 | transformer.variant: 21 | values: 22 | - scaledinit 23 | - opennmt 24 | - noscale 25 | pcfg.split: 26 | value: simple 27 | lr: 28 | value: 1e-4 29 | stop_after: 30 | value: 300000 31 | amp: 32 | value: 1 33 | sweep_id_for_grid_search: 34 | distribution: categorical 35 | values: 36 | - 1 37 | - 2 38 | - 3 39 | - 4 40 | - 5 41 | -------------------------------------------------------------------------------- /sweeps/pcfg_nosched_productivity.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | name: 13 | value: scan 14 | log: 15 | value: wandb 16 | profile: 17 | value: pcfg_trafo 18 | task: 19 | value: pcfg_transformer 20 | transformer.variant: 21 | values: 22 | - relative_universal 23 | - universal_noscale 24 | - universal_scaledinit 25 | - universal_opennmt 26 | - scaledinit 27 | - opennmt 28 | - noscale 29 | pcfg.split: 30 | value: productivity 31 | lr: 32 | value: 1e-4 33 | stop_after: 34 | value: 300000 35 | amp: 36 | value: 1 37 | sweep_id_for_grid_search: 38 | distribution: categorical 39 | values: 40 | - 1 41 | - 2 42 | - 3 43 | - 4 
44 | - 5 45 | -------------------------------------------------------------------------------- /sweeps/pcfg_nosched_systematicity.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | name: 13 | value: scan 14 | log: 15 | value: wandb 16 | profile: 17 | value: pcfg_trafo 18 | task: 19 | value: pcfg_transformer 20 | transformer.variant: 21 | values: 22 | - relative_universal 23 | - relative 24 | - universal_noscale 25 | - scaledinit 26 | - universal_scaledinit 27 | - noscale 28 | - opennmt 29 | - universal_opennmt 30 | pcfg.split: 31 | value: systematicity 32 | lr: 33 | value: 1e-4 34 | stop_after: 35 | value: 300000 36 | amp: 37 | value: 1 38 | sweep_id_for_grid_search: 39 | distribution: categorical 40 | values: 41 | - 1 42 | - 2 43 | - 3 44 | - 4 45 | - 5 46 | -------------------------------------------------------------------------------- /sweeps/scan_trafo_length_cutoff.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | command: 3 | - ${env} 4 | - python3 5 | - ${program} 6 | - ${args} 7 | method: grid 8 | metric: 9 | name: validation/mean_accuracy 10 | goal: maximize 11 | parameters: 12 | name: 13 | value: trafo_scan_length_cutoff 14 | log: 15 | value: wandb 16 | profile: 17 | value: trafo_scan 18 | stop_after: 19 | value: 50000 20 | scan.length_cutoff: 21 | values: 22 | - 22 23 | - 24 24 | - 25 25 | - 26 26 | - 27 27 | - 28 28 | - 30 29 | - 32 30 | - 33 31 | - 36 32 | - 40 33 | transformer.variant: 34 | values: 35 | - noscale 36 | - universal_noscale 37 | - relative 38 | - relative_universal 39 | - scaledinit 40 | - universal_scaledinit 41 | task: 42 | value: scan_resplit_transformer 43 | sweep_id_for_grid_search: 44 | distribution: categorical 45 | values: 46 | - 1 47 | - 2 48 | - 3 49 | - 4 50 | - 5 51 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .pcfg_transformer import PCFGTransformer 2 | from .cogs_transofrmer import COGSTransformer 3 | from .scan_transformer import ScanTransformer 4 | from .scan_resplit_transformer import ScanResplitTransformer 5 | from .cfq_transformer import CFQTransformer 6 | from .dm_math_transformer import DMMathTransformer 7 | -------------------------------------------------------------------------------- /tasks/cfq_transformer.py: -------------------------------------------------------------------------------- 1 | import dataset 2 | from .task import Task 3 | from .transformer_mixin import TransformerMixin 4 | from typing import Tuple, Any 5 | import framework 6 | import dataset.sequence 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class CFQTransformer(TransformerMixin, Task): 12 | VALID_NUM_WORKERS = 0 13 | MAX_LENGHT_PER_BATCH = 101 14 | 15 | def create_datasets(self): 16 | self.batch_dim = 1 17 | self.train_set = dataset.CFQ(["train"], split_type=[self.helper.args.cfq.split]) 18 | self.valid_sets.val = dataset.CFQ(["val"], split_type=[self.helper.args.cfq.split]) 19 | self.valid_sets.test = dataset.CFQ(["test"], split_type=[self.helper.args.cfq.split]) 20 | 21 | def __init__(self, helper: framework.helpers.TrainingHelper): 22 | super().__init__(helper) 23 | self.init_valid_details() 24 | 25 | def 
init_valid_details(self): 26 | self.helper.state.full_loss_log = {} 27 | 28 | def validate_on_name(self, name: str) -> Tuple[Any, float]: 29 | res, loss = self.validate_on(self.valid_sets[name], self.valid_loaders[name]) 30 | if self.helper.args.log_sample_level_loss and isinstance(res, dataset.sequence.TextSequenceTestState): 31 | losses, oks = res.get_sample_info() 32 | if name not in self.helper.state.full_loss_log: 33 | self.helper.state.full_loss_log[name] = [], [] 34 | 35 | self.helper.state.full_loss_log[name][0].append(losses) 36 | self.helper.state.full_loss_log[name][1].append(oks) 37 | 38 | return res, loss 39 | 40 | def save_valid_details(self): 41 | for name, (losses, oks) in self.helper.state.full_loss_log.items(): 42 | losses = np.asfarray(losses) 43 | oks = np.asarray(oks, dtype=bool) 44 | torch.save({"losses": losses, "oks": oks}, self.helper.get_storage_path(f"loss_details/{name}.pth")) 45 | 46 | def train(self): 47 | super().train() 48 | self.save_valid_details() 49 | -------------------------------------------------------------------------------- /tasks/cogs_transofrmer.py: -------------------------------------------------------------------------------- 1 | import dataset 2 | from typing import Dict, Any 3 | from .task import Task 4 | from interfaces import Result 5 | from .transformer_mixin import TransformerMixin 6 | import framework 7 | 8 | 9 | class COGSTransformer(TransformerMixin, Task): 10 | VALID_NUM_WORKERS = 0 11 | 12 | def __init__(self, helper: framework.helpers.TrainingHelper): 13 | super().__init__(helper) 14 | 15 | def create_datasets(self): 16 | self.batch_dim = 1 17 | self.train_set = dataset.COGS(["train"], shared_vocabulary=True) 18 | self.valid_sets.val = dataset.COGS(["valid"], shared_vocabulary=True) 19 | self.slow_valid_set = dataset.COGS(["gen"], shared_vocabulary=True) 20 | 21 | def create_loaders(self): 22 | super().create_loaders() 23 | self.slow_valid_loader = self.create_valid_loader(self.slow_valid_set) 24 | 25 | def do_generalization_test(self) -> Dict[str, Any]: 26 | d = {} 27 | test, loss = self.validate_on(self.slow_valid_set, self.slow_valid_loader) 28 | 29 | d["validation/gen/loss"] = loss 30 | d.update({f"validation/gen/{k}": v for k, v in test.plot().items()}) 31 | d.update(self.update_best_accuracies("validation/gen", test.accuracy, loss)) 32 | return d 33 | 34 | def plot(self, res: Result) -> Dict[str, Any]: 35 | d = super().plot(res) 36 | if (self.helper.state.iter % self.helper.args.cogs.generalization_test_interval == 0) or \ 37 | (self.helper.state.iter == self.helper.args.test_interval): 38 | d.update(self.do_generalization_test()) 39 | return d 40 | 41 | def train(self): 42 | super().train() 43 | if self.helper.state.iter % self.helper.args.cogs.generalization_test_interval != 0: 44 | # Redo the test, but only if it was not already done 45 | self.helper.summary.log(self.do_generalization_test()) -------------------------------------------------------------------------------- /tasks/dm_math_transformer.py: -------------------------------------------------------------------------------- 1 | import dataset 2 | from .task import Task 3 | from .transformer_mixin import TransformerMixin 4 | 5 | 6 | class DMMathTransformer(TransformerMixin, Task): 7 | def create_datasets(self): 8 | self.batch_dim = 1 9 | self.train_set = dataset.DeepmindMathDataset(self.helper.args.dm_math.tasks, sets=[f"train_{s}" 10 | for s in self.helper.args.dm_math.train_splits]) 11 | 12 | self.valid_sets.interpolate =
dataset.DeepmindMathDataset(self.helper.args.dm_math.tasks, sets=["interpolate"]) 13 | self.valid_sets.iid = dataset.DeepmindMathDataset(self.helper.args.dm_math.tasks, sets=[f"test_{s}" for s in 14 | self.helper.args.dm_math.train_splits]) 15 | 16 | extrapolate = dataset.DeepmindMathDataset(self.helper.args.dm_math.tasks, sets=["extrapolate"]) 17 | if len(extrapolate) != 0: 18 | self.valid_sets["extrapolate"] = extrapolate 19 | -------------------------------------------------------------------------------- /tasks/pcfg_transformer.py: -------------------------------------------------------------------------------- 1 | import dataset 2 | from typing import Dict, Any 3 | from .task import Task 4 | from .transformer_mixin import TransformerMixin 5 | 6 | 7 | class PCFGTransformer(TransformerMixin, Task): 8 | VALID_NUM_WORKERS = 0 9 | MAX_LENGHT_PER_BATCH = 200 10 | 11 | def create_datasets(self): 12 | self.batch_dim = 1 13 | self.train_set = dataset.PCFGSet(["train"], split_type=[self.helper.args.pcfg.split], shared_vocabulary=True) 14 | self.valid_sets.val = dataset.PCFGSet(["test"], split_type=[self.helper.args.pcfg.split], 15 | shared_vocabulary=True) 16 | -------------------------------------------------------------------------------- /tasks/scan_resplit_transformer.py: -------------------------------------------------------------------------------- 1 | import dataset 2 | from .scan_transformer import ScanTransformer 3 | 4 | 5 | class ScanResplitTransformer(ScanTransformer): 6 | def create_datasets(self): 7 | self.batch_dim = 1 8 | self.train_set = dataset.ScanLengthResplit("train", (0, self.helper.args.scan.length_cutoff)) 9 | self.valid_sets.val = dataset.ScanLengthResplit("all", (self.helper.args.scan.length_cutoff+1, 9999)) 10 | self.valid_sets.iid = dataset.ScanLengthResplit("test", (0, self.helper.args.scan.length_cutoff)) 11 | -------------------------------------------------------------------------------- /tasks/scan_transformer.py: -------------------------------------------------------------------------------- 1 | import dataset 2 | from .task import Task 3 | from .transformer_mixin import TransformerMixin 4 | 5 | 6 | class ScanTransformer(TransformerMixin, Task): 7 | VALID_NUM_WORKERS = 0 8 | 9 | def create_datasets(self): 10 | self.batch_dim = 1 11 | self.train_set = dataset.Scan(["train"], split_type=self.helper.args.scan.train_split) 12 | self.valid_sets.val = dataset.Scan(["test"], split_type=self.helper.args.scan.train_split) 13 | self.valid_sets.iid = dataset.Scan(["test"]) 14 | -------------------------------------------------------------------------------- /tasks/transformer_mixin.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | from layers.transformer import Transformer, UniversalTransformer, RelativeTransformer, UniversalRelativeTransformer 3 | from models import TransformerEncDecModel 4 | from interfaces import TransformerEncDecInterface 5 | 6 | 7 | class TransformerMixin: 8 | def create_model(self) -> torch.nn.Module: 9 | rel_args = dict(pos_embeddig=(lambda x, offset: x), embedding_init="xavier") 10 | trafos = { 11 | "scaledinit": (Transformer, dict(embedding_init="kaiming", scale_mode="down")), 12 | "opennmt": (Transformer, dict(embedding_init="xavier", scale_mode="opennmt")), 13 | "noscale": (Transformer, {}), 14 | "universal_noscale": (UniversalTransformer, {}), 15 | "universal_scaledinit": (UniversalTransformer, dict(embedding_init="kaiming", scale_mode="down")), 16 | "universal_opennmt": 
(UniversalTransformer, dict(embedding_init="xavier", scale_mode="opennmt")), 17 | "relative": (RelativeTransformer, rel_args), 18 | "relative_universal": (UniversalRelativeTransformer, rel_args) 19 | } 20 | 21 | constructor, args = trafos[self.helper.args.transformer.variant] 22 | 23 | return TransformerEncDecModel(len(self.train_set.in_vocabulary), 24 | len(self.train_set.out_vocabulary), self.helper.args.state_size, 25 | nhead=self.helper.args.transformer.n_heads, 26 | num_encoder_layers=self.helper.args.transformer.encoder_n_layers, 27 | num_decoder_layers=self.helper.args.transformer.decoder_n_layers or \ 28 | self.helper.args.transformer.encoder_n_layers, 29 | ff_multiplier=self.helper.args.transformer.ff_multiplier, 30 | transformer=constructor, 31 | tied_embedding=self.helper.args.transformer.tied_embedding, **args) 32 | 33 | def create_model_interface(self): 34 | self.model_interface = TransformerEncDecInterface(self.model, label_smoothing=self.helper.args.label_smoothing) 35 | --------------------------------------------------------------------------------
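
The sweeps above configure the learning-rate schedule with "lr_sched.type: noam", a base multiplier "lr" (0.9 or 2.24), and a warmup length "lr_warmup". As a reference, the sketch below shows the standard Noam schedule from "Attention Is All You Need" that this naming presumably refers to; it is an illustration under that assumption, not a copy of the repository's scheduler, and d_model stands in for the model's state_size.

# Minimal sketch of the standard Noam schedule (an assumption about what
# "lr_sched.type: noam" implements; the in-repo scheduler may differ).
def noam_lr(step: int, base_lr: float, d_model: int, warmup: int) -> float:
    # Linear warmup until `warmup` steps, then decay proportional to
    # step ** -0.5, both scaled by d_model ** -0.5 and the base multiplier.
    step = max(step, 1)  # guard against step 0
    return base_lr * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# With the cfq_mcd.yaml values (lr=0.9, lr_warmup=4000) and an assumed
# state_size of 512, the peak rate at step 4000 is about
# 0.9 * 512 ** -0.5 * 4000 ** -0.5, i.e. roughly 6.3e-4.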