├── .gitignore
├── LICENSE
├── README.md
├── actions
│   ├── __init__.py
│   ├── evaluate.py
│   └── train.py
├── args.py
├── data
│   ├── __init__.py
│   └── utils.py
├── main.py
├── models
│   ├── __init__.py
│   ├── attention.py
│   ├── embeddings.py
│   ├── long_range_lm.py
│   ├── transformer.py
│   └── utils.py
├── preprocess
│   ├── encode_eval_data.sh
│   ├── tokenize_eval_data.py
│   └── tokenize_pg19_train.py
├── requirements.txt
├── slurm.py
├── tokenize_eval_data.py
└── utils
    └── __init__.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SuffixLM 2 | 3 | This is the repository for our NAACL 2022 paper [CHAPTERBREAK: A Challenge Dataset for Long-Range Language Models](https://arxiv.org/pdf/2204.10878.pdf). 4 | 5 | 6 | # Setup 7 | 8 | This repository requires `Python 3.8` and `cuda11/11.1.0`. Install all dependencies by running the following command: 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | 14 | # Data 15 | 16 | ## Data download 17 | 18 | Update: ChapterBreak can now be downloaded from HuggingFace: https://huggingface.co/datasets/simsun131/chapterbreak. 19 | ``` 20 | from datasets import load_dataset 21 | 22 | dataset = load_dataset("simsun131/chapterbreak") 23 | ``` 24 | 25 | Old link: 26 | 27 | Download [ChapterBreak](https://drive.google.com/drive/folders/1JkFHspT56_yRWwXVj47Fw0PzHtitODt5?usp=sharing) with various prefix lengths. Download the long-fanfic data (13,682 filtered long fanfics posted on Archive of Our Own) [here](https://drive.google.com/drive/folders/1Wb5dG6PABOleYGDG9rARVy1nSQ2WhHK2?usp=sharing). Please refer to Appendix A for more details about the filtering of this dataset. 28 | 29 | ## Data format 30 | 31 | Our data contains two splits, `pg19` and `ao3`. Each split is a dictionary where the key is the id of a work and the value is a list of examples extracted from that work: 32 | ``` 33 | { 34 | (workid): [{ 35 | 'ctx': (the context preceding the gold suffix) 36 | 'pos': (the gold suffix) 37 | 'negs': (list of 5 negatives) 38 | }, 39 | 40 | ...] 41 | } 42 | ``` 43 | For the `ao3_longfic` dataset, the data is stored in the following format: 44 | ``` 45 | { 46 | (workid): { 47 | 'title': (title of this work) 48 | 'fandom': (fandom this work belongs to, separated by ',') 49 | 'genre': (genres this work belongs to, separated by ',') 50 | 'Summary': (summary written by the author) 51 | 'num_chaps': (number of chapters) 52 | 'num_words': (number of words) 53 | 'text': (the work content) 54 | } 55 | } 56 | ``` 57 | 58 | 59 | ## Notes 60 | 61 | When evaluating each example of ChapterBreak, if the total length of (prefix + suffix) is greater than the maximum input length of a model, truncate the prefix from the left and the suffix from the right. In our experiments, we truncate the suffix to at most 128 tokens because the segment size of our SuffixLM is 128. We show in Appendix G that the variation in suffix lengths does not explain the large performance gap between SuffixLM and the other LMs evaluated in our work. 62 | 63 | 64 | 65 | # Model 66 | ## Model download 67 | 68 | We provide a pre-trained SuffixLM [here](https://drive.google.com/file/d/1eQJABPau-rbeag2_aZI-nrtZaD_HgT0t/view?usp=sharing). This model is trained on PG19 with a segment size of 128 and a maximum input sequence of 10K tokens. To train your own SuffixLM, please read the following sections. 
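As a concrete illustration of the data format described above, here is a minimal sketch of loading a downloaded ChapterBreak JSON file and iterating over its examples. It assumes the JSON stores the two splits `pg19` and `ao3` as top-level keys, each mapping a work id to a list of examples as in the Data format section; the file path is only a placeholder.

```
import json

# Minimal sketch (not part of this repo): inspect a downloaded ChapterBreak file.
# Assumes the two splits ('pg19', 'ao3') are top-level keys of the JSON.
with open("path/to/downloaded/chapterbreak/chapterbreak_ctx_512.json") as f:
    chapterbreak = json.load(f)

for split in ("pg19", "ao3"):
    for work_id, examples in chapterbreak[split].items():
        for ex in examples:
            prefix = ex["ctx"]       # context preceding the gold suffix
            gold_suffix = ex["pos"]  # gold suffix
            negatives = ex["negs"]   # list of 5 negative suffixes
```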
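The truncation rule from the Notes section can be implemented with a small helper like the sketch below. This is illustrative only and not code from this repository: the GPT-2 tokenizer is just an example stand-in for whichever tokenizer matches the model being evaluated, and the function name and signature are our own.

```
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # example only; use the evaluated model's tokenizer

def truncate_example(prefix, suffix, max_input_len, max_suffix_len=128):
    # Truncate the suffix from the right (keep at most max_suffix_len tokens),
    # then truncate the prefix from the left so prefix + suffix fits max_input_len.
    suffix_ids = tokenizer.encode(suffix)[:max_suffix_len]
    budget = max_input_len - len(suffix_ids)
    prefix_ids = tokenizer.encode(prefix)
    prefix_ids = prefix_ids[-budget:] if budget > 0 else []
    return prefix_ids, suffix_ids
```

For SuffixLM itself, the prefix is instead truncated at the segment level inside `SLMEvalDataset` (see `data/utils.py`), so a helper like this is only meant to clarify the rule when evaluating other LMs.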
69 | 70 | ## Evaluate SuffixLM 71 | 72 | To evaluate SuffixLM, first run the following commands to encode the ChapterBreak dataset into binary data: 73 | 74 | ``` 75 | input_path=path/to/downloaded/chapterbreak/chapterbreak_ctx_512.json 76 | output_path=/path/to/output/encoded/chapterbreak 77 | mkdir -p $output_path 78 | python tokenize_eval_data.py \ 79 | --input-path $input_path \ 80 | --output-path $output_path \ 81 | --tokenize-only 82 | ``` 83 | After running the above command, you should see two files created in your `output_path`: `pg19_ctx512.pkl` and `ao3_ctx512.pkl`. To use contexts of other lengths, replace the `512` in the above example with the corresponding sequence length. The original suffixes are around 150 words long; we truncate them to 128 tokens because the segment size of our SuffixLM is set to 128. 84 | 85 | Next, make sure you have a trained or downloaded SuffixLM model named `best_checkpoint.pt` in your experiment folder. 86 | 87 | Run the following command to evaluate SuffixLM on the `PG19` split with prefix length 512, replacing `[port_number]` with a random port number: 88 | 89 | ``` 90 | data_path=/path/to/encoded/chapterbreak/pg19_ctx512.pkl 91 | experiment_path=/path/to/trained/suffixlm/model 92 | python -m torch.distributed.run --master_port [port_number] main.py \ 93 | --data-path $data_path \ 94 | --fp16 \ 95 | --action eval-seg-lm \ 96 | --batch-size 1 \ 97 | --restore $experiment_path/best_checkpoint.pt 98 | 99 | ``` 100 | 101 | ## Train SuffixLM 102 | 103 | To train SuffixLM, first download the PG19 data from [here](https://github.com/deepmind/pg19), then encode it for training SuffixLM. You can encode the sharded data in parallel using commands such as the following: 104 | ``` 105 | IN_PATH=/path/to/raw/text/train 106 | OUT_PATH=/path/to/encoded/text/train-tok 107 | for (( SHARD_ID={s_id}; SHARD_ID<{e_id}; SHARD_ID++ )); do 108 | python encode_pg19_train.py \ 109 | --input-path $IN_PATH --output-path $OUT_PATH --shard-id $SHARD_ID --shard-size 100 \ 110 | --chunk-size 128 --batch-size 64 111 | done 112 | ``` 113 | 114 | You also need to encode the `test` and `eval` sets similarly. After encoding the data, you should have a folder named `train-tok` containing the encoded training files and a folder named `valid-tok` containing the encoded validation files. Train SuffixLM with the following command: 
115 | 116 | ``` 117 | data_path=/path/to/data 118 | experiment_path=/path/to/save/checkpoints 119 | mkdir -p $experiment_path 120 | export NGPU=1 121 | python -m torch.distributed.launch --master_port=[port_number] main.py \ 122 | --data-path $data_path \ 123 | --fp16 \ 124 | --split train-tok \ 125 | --action train-seg-lm \ 126 | --checkpoint-path $experiment_path 127 | ``` 128 | 129 | 130 | 131 | # Citation 132 | 133 | ``` 134 | @inproceedings{long21, 135 | author={Simeng Sun and Katherine Thai and Mohit Iyyer}, 136 | Booktitle = {North American Association for Computational Linguistics}, 137 | Year = "2022", 138 | Title={ChapterBreak: A Challenge Dataset for Long-Range Language Models}, 139 | } 140 | ``` 141 | -------------------------------------------------------------------------------- /actions/__init__.py: -------------------------------------------------------------------------------- 1 | from actions.train import Trainer -------------------------------------------------------------------------------- /actions/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pdb 4 | import wandb 5 | import shutil 6 | import pickle 7 | import logging 8 | import numpy as np 9 | import time 10 | import torch 11 | import torch.nn as nn 12 | from torch import optim 13 | import numpy as np 14 | from apex import amp 15 | from tqdm import tqdm 16 | from models.utils import checkpoint 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | class Evaluator(object): 22 | 23 | def __init__(self, args, model, dataloaders, eval_type="segment"): 24 | 25 | self.args = args 26 | self.model = model 27 | self.model = model.cuda() 28 | self.model = nn.parallel.DistributedDataParallel(self.model, 29 | device_ids=[args.local_rank], 30 | output_device=args.local_rank, broadcast_buffers=True) 31 | self.eval_type = eval_type 32 | self.dataloaders = dataloaders 33 | 34 | self.modules = {'model': self.model} 35 | 36 | def eval_segment(self): 37 | self.model.eval() 38 | 39 | total_loss, total_num_examples = 0, 0 40 | for dl in self.dataloaders: 41 | for batch in tqdm(dl): 42 | loss, pad_mask = self.model(batch) 43 | 44 | num_examples = pad_mask.sum() 45 | loss = loss.sum() 46 | 47 | total_loss += loss 48 | total_num_examples += num_examples 49 | 50 | self.model.train() 51 | 52 | return total_loss/total_num_examples 53 | 54 | def eval_suffix_identification(self): 55 | self.model.eval() 56 | 57 | total_num = 0 58 | crrc_num = 0 59 | res = [] 60 | for i, batch in enumerate(tqdm(self.dataloaders)): 61 | batch = { 62 | 'data': batch['data'].squeeze(0), 63 | 'padding_mask': batch['padding_mask'].squeeze(0) 64 | } 65 | 66 | dr = self.model(batch) 67 | correct = ((dr - dr[0]) > 0).sum() == 0 68 | crrc_num += (1 * correct).item() 69 | total_num += 1 70 | 71 | res.append(((1 * correct).item(), dr.cpu().tolist())) 72 | print(f"suffix identification accuray: {crrc_num / total_num}") 73 | 74 | def __call__(self, ): 75 | 76 | with torch.no_grad(): 77 | 78 | if self.eval_type == "segment": 79 | return self.eval_segment() 80 | 81 | else: 82 | return self.eval_suffix_identification() 83 | 84 | -------------------------------------------------------------------------------- /actions/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pdb 4 | import wandb 5 | import shutil 6 | import random 7 | import logging 8 | import numpy as np 9 | import time 10 | import torch 11 | import 
torch.nn as nn 12 | from torch import optim 13 | from torch.optim.lr_scheduler import CosineAnnealingLR 14 | import numpy as np 15 | from apex import amp 16 | from tqdm import tqdm 17 | from actions.evaluate import Evaluator 18 | from models.utils import checkpoint 19 | 20 | random.seed(42) 21 | logger = logging.getLogger(__name__) 22 | torch.set_num_threads(1) 23 | class Trainer(object): 24 | 25 | def __init__(self, args, model, dataloader_lst, train_type="segment", validation_dataloader_lst=None, clip=0.25): 26 | 27 | self.args = args 28 | self.model = model 29 | self.train_type = train_type 30 | self.dataloaders = dataloader_lst 31 | self.validation_dataloaders = validation_dataloader_lst 32 | self.clip = clip 33 | self.epoch = 0 34 | if self.args.wandb: 35 | wandb.init(project=self.args.project_name, config=vars(args)) 36 | print(f"setting run name to {args.wandb_run_name}") 37 | wandb.run.name = args.wandb_run_name 38 | wandb.run.save() 39 | self.steps = 0 40 | 41 | self.optimizer = optim.Adam(model.parameters(), 42 | args.learning_rate, 43 | betas=(0.9, 0.999), eps=1e-08) 44 | 45 | self.lr_scheduler = CosineAnnealingLR(self.optimizer, args.max_steps//args.accumulate_steps, eta_min=args.final_lr) 46 | 47 | 48 | self.model = model.cuda() 49 | if args.fp16: 50 | self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1') 51 | 52 | self.model = nn.parallel.DistributedDataParallel(self.model, 53 | device_ids=[args.local_rank], 54 | output_device=args.local_rank, broadcast_buffers=True, 55 | find_unused_parameters=True) 56 | 57 | if self.args.wandb: 58 | wandb.watch(self.model, log='all') 59 | 60 | self.modules = { 61 | 'model': self.model, 62 | 'optimizer': self.optimizer, 63 | 'lr_scheduler': self.lr_scheduler 64 | } 65 | 66 | self.best_eval_loss = sys.maxsize 67 | 68 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 69 | datefmt='%m/%d/%Y %H:%M:%S', 70 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 71 | logger.warning("Process rank: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 72 | args.local_rank, args.n_gpu_per_node, bool(args.local_rank != -1), args.fp16) 73 | 74 | def optimize(self, step_i): 75 | 76 | if self.args.fp16: 77 | torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.clip) 78 | else: 79 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip) 80 | 81 | self.optimizer.step() 82 | self.optimizer.zero_grad() 83 | self.model.zero_grad() 84 | 85 | if step_i < self.args.warmup_steps: 86 | self.optimizer.param_groups[0]['lr'] = self.args.learning_rate * step_i / self.args.warmup_steps 87 | return self.optimizer.param_groups[0]['lr'] 88 | else: 89 | self.lr_scheduler.step() 90 | return self.lr_scheduler.get_lr()[0] 91 | 92 | def compute_gradient(self, batch): 93 | self.model.train() 94 | 95 | loss, pad_mask = self.model(batch) 96 | 97 | num_examples = pad_mask.sum() 98 | loss = loss.sum() 99 | 100 | if self.args.fp16: 101 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 102 | scaled_loss.backward() 103 | else: 104 | loss.backward() 105 | 106 | return loss, num_examples 107 | 108 | def save_checkpoint(self, step, best=False): 109 | 110 | checkpoint_path = checkpoint( 111 | self.epoch, step, self.modules, 112 | self.args.checkpoint_path, 113 | max_checkpoints=self.args.max_checkpoints 114 | ) 115 | 116 | if best: 117 | dirname = os.path.dirname(checkpoint_path) 118 | basename = os.path.basename(checkpoint_path) 119 | best_checkpoint_path = 
os.path.join(dirname, f'best_{basename}') 120 | shutil.copy2(checkpoint_path, best_checkpoint_path) 121 | 122 | def train_epoch_suffix_lm(self, ): 123 | 124 | total_loss = 0 125 | total_num_examples = 0 126 | def eval_and_save(optimize_steps): 127 | vdl_lst = [iter(vdl) for vdl in self.validation_dataloaders] 128 | evaluator = Evaluator(self.args, self.model, vdl_lst) 129 | eval_loss = evaluator() 130 | 131 | if self.args.wandb: 132 | wandb.log({'eval_loss': eval_loss}, step=self.steps//self.args.accumulate_steps) 133 | logger.info(f"eval_loss {eval_loss}") 134 | else: 135 | logger.info(f"eval_loss {eval_loss}") 136 | 137 | self.save_checkpoint(optimize_steps, best=eval_loss < self.best_eval_loss) 138 | self.best_eval_loss = eval_loss if eval_loss < self.best_eval_loss else self.best_eval_loss 139 | 140 | if self.args.wandb: 141 | wandb.log({'best_eval_loss': self.best_eval_loss}, step=self.steps//self.args.accumulate_steps) 142 | logger.info(f"best_eval_loss {self.best_eval_loss}") 143 | else: 144 | logger.info(f"best_eval_loss {self.best_eval_loss}") 145 | 146 | def sample_batch(dataloaders): 147 | 148 | if len(dataloaders) == 0: 149 | return None 150 | 151 | dl_idx = random.choice(range(len(dataloaders))) 152 | try: 153 | batch = next(dataloaders[dl_idx]) 154 | return batch 155 | 156 | except: 157 | dataloaders.pop(dl_idx) 158 | return sample_batch(dataloaders) 159 | 160 | def optimize_batch(batch, total_loss, total_num_examples): 161 | loss, num_examples = self.compute_gradient(batch) 162 | self.steps += 1 163 | 164 | total_loss += loss 165 | total_num_examples += num_examples 166 | optimize_steps = self.steps // self.args.accumulate_steps 167 | 168 | if self.steps % self.args.accumulate_steps == 0: 169 | lr = self.optimize(optimize_steps) 170 | if self.args.wandb: 171 | wandb.log({'train_loss': total_loss / total_num_examples, 172 | 'steps': self.steps, 173 | 'optimize_steps': self.steps // self.args.accumulate_steps, 174 | 'lr': lr}, step=self.steps//self.args.accumulate_steps) 175 | logger.info(f"train_loss {total_loss.item() / total_num_examples} optimize_steps {self.steps // self.args.accumulate_steps} lr {lr}") 176 | else: 177 | logger.info(f"train_loss {total_loss.item() / total_num_examples} optimize_steps {self.steps // self.args.accumulate_steps} lr {lr}") 178 | 179 | 180 | if optimize_steps % self.args.eval_every == 0: 181 | eval_and_save(optimize_steps) 182 | 183 | if optimize_steps % self.args.ckpt_every == 0: 184 | self.save_checkpoint(optimize_steps) 185 | 186 | return optimize_steps, total_loss, total_num_examples 187 | 188 | dataloaders = [iter(dl) for dl in self.dataloaders] 189 | 190 | while True: 191 | batch = sample_batch(dataloaders) 192 | if batch is None: 193 | break 194 | 195 | optimize_steps, total_loss, total_num_examples = optimize_batch(batch, 196 | total_loss, 197 | total_num_examples) 198 | if optimize_steps >= self.args.max_steps: 199 | break 200 | 201 | 202 | lr = self.optimize(self.steps) 203 | if self.args.wandb: 204 | wandb.log({'train_loss': total_loss / total_num_examples, 205 | 'steps': self.steps, 206 | 'lr': lr}, step=self.steps//self.args.accumulate_steps) 207 | logger.info(f"train_loss {total_loss.item() / total_num_examples} optimize_steps {self.steps // self.args.accumulate_steps} lr {lr}") 208 | eval_and_save(optimize_steps) 209 | 210 | 211 | def train_suffix_lm(self): 212 | optimize_steps = self.steps // self.args.accumulate_steps 213 | while (optimize_steps < self.args.max_steps) and (self.epoch < self.args.max_epochs): 214 | self.epoch 
+= 1 215 | self.train_epoch_suffix_lm() 216 | 217 | 218 | def __call__(self, ): 219 | 220 | if self.train_type == "segment": 221 | self.train_suffix_lm() 222 | 223 | else: 224 | raise NotImplementedError 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | ACTIONS = ['train-seg-lm', 'train-lm', 'eval-seg-lm', 'eval-lm', 'preprocess-data'] 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser() 7 | 8 | # general args 9 | parser.add_argument("--action", type=str, default="train", choices=ACTIONS) 10 | parser.add_argument("--checkpoint-path", type=str, default=None) 11 | parser.add_argument("--max-checkpoints", type=int, default=5) 12 | parser.add_argument("--debug", action="store_true", default=False) 13 | parser.add_argument("--wandb", action="store_true", default=False) 14 | parser.add_argument("--project-name", type=str, default="lrlm") 15 | parser.add_argument("--wandb-run-name", type=str, default=None) 16 | parser.add_argument("--restore", type=str, default=None, help="reload checkpoint path") 17 | parser.add_argument("--load-optimizer", action="store_true", default=False) 18 | 19 | # data args 20 | parser.add_argument("--data-path", type=str, default=None) 21 | parser.add_argument("--batch-size", type=int, default=32) 22 | parser.add_argument("--split", type=str, default="train") 23 | parser.add_argument("--max-books", type=int, default=30000) 24 | parser.add_argument("--max-chunk-per-seq", type=int, default=64) 25 | parser.add_argument("--max-tokens-per-batch", type=int, default=10240) 26 | parser.add_argument("--chunk-size-list", type=int, default=[128], nargs='+') 27 | parser.add_argument("--preprocess-data", action="store_true", default=False) 28 | 29 | # eval args 30 | parser.add_argument("--eval-out-path", type=str, default=None) 31 | 32 | # decoder(suffixLM) args 33 | parser.add_argument("--embedding-size", type=int, default=768) 34 | parser.add_argument("--num-heads", type=int, default=8) 35 | parser.add_argument("--model-size", type=int, default=768) 36 | parser.add_argument("--num-layers", type=int, default=6) 37 | parser.add_argument("--hidden-dim", type=int, default=2048) 38 | 39 | # LRLM general 40 | parser.add_argument("--train-small-roberta-from-scratch", action="store_true", default=False) 41 | parser.add_argument("--init-std", type=float, default=0.01) 42 | 43 | # train args 44 | parser.add_argument("--fp16", action="store_true", default=False) 45 | parser.add_argument("--learning-rate", type=float, default=7e-5) 46 | parser.add_argument("--final-lr", type=float, default=1e-7) 47 | parser.add_argument("--max-steps", type=int, default=200000) 48 | parser.add_argument("--warmup-steps", type=int, default=400) 49 | parser.add_argument("--max-epochs", type=int, default=150) 50 | parser.add_argument("--accumulate-steps", type=int, default=4) 51 | parser.add_argument("--optimizer", type=str, default="adam") 52 | parser.add_argument("--dropout-p", type=float, default=0.1) 53 | parser.add_argument("--local_rank", type=int, default=-1) 54 | parser.add_argument("--master_port", type=int, default=-1) 55 | parser.add_argument("--eval-every", type=int, default=500) 56 | parser.add_argument("--ckpt-every", type=int, default=1000) 57 | args = parser.parse_args() 58 | return args 59 | 60 | -------------------------------------------------------------------------------- 
/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pdb 4 | import pickle 5 | import torch 6 | import random 7 | import logging 8 | import numpy as np 9 | from tqdm import tqdm 10 | import torch.nn.functional as F 11 | from torch.utils.data import Dataset 12 | from torch.nn.utils.rnn import pad_sequence 13 | logger = logging.getLogger(__name__) 14 | random.seed(42) 15 | 16 | from transformers import RobertaTokenizer, GPT2Tokenizer 17 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 18 | cls_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token) 19 | eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) 20 | pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 21 | 22 | class SLMEvalDataset(Dataset): 23 | """ 24 | Dataset for ChapterBreak, need to preprocess 25 | the chapterbreak data with `tokenize_eval_data.py` first 26 | """ 27 | def __init__(self, args, data_path, batch_size): 28 | 29 | with open(data_path, "rb") as f: 30 | data = pickle.load(f) 31 | 32 | assert batch_size == 1, "Need to have one example per evaluation batch" 33 | self.args = args 34 | self.data = [example for book_id in data for example in data[book_id]] 35 | 36 | def __len__(self): 37 | return len(self.data) 38 | 39 | def _get_cls_item(self, idx): 40 | 41 | this_item = self.data[idx] 42 | 43 | ctx_vec = this_item['ctx_vec'] 44 | pos_vec = this_item['pos_vec'] 45 | 46 | # tokenized context could be longer, need to cut it down 47 | # to the maximize chunk length allowed by the trained suffixLM 48 | ctx_vec = ctx_vec[-(self.args.max_tokens_per_batch // self.args.chunk_size_list[0]-1):] 49 | 50 | negs_vec = this_item['negs_vec'] 51 | 52 | pos_data = torch.cat([ctx_vec, pos_vec], dim=0) 53 | negs_data = [torch.cat([ctx_vec, neg_vec]) for neg_vec in negs_vec] 54 | 55 | batch_data = torch.stack([pos_data] + negs_data) 56 | padding_mask = torch.zeros(batch_data.shape[0], batch_data.shape[1]).float() 57 | 58 | return { 59 | 'data': batch_data, 60 | 'padding_mask': padding_mask 61 | } 62 | 63 | def __getitem__(self, idx): 64 | 65 | return self._get_cls_item(idx) 66 | 67 | 68 | class TextDataset(Dataset): 69 | """ 70 | Dataset for loading data for training SuffixLM 71 | """ 72 | def __init__(self, args, data_path, max_num_chunks, split, max_books=500, chunk_size=256, fileids=None): 73 | ''' 74 | load tokenized and binarized training data data, 75 | divide to chunks of max_tokens_per_batch 76 | 77 | chunk_size: chunk_size can be different for each dataset 78 | num_chunks: the seq_len for segment-level LM, depends on 79 | max_token_per_batch and chunk_size 80 | ''' 81 | self.max_books = max_books 82 | self.max_num_chunks = max_num_chunks 83 | self.args = args 84 | self.chunk_size_list = args.chunk_size_list 85 | self.max_tokens_per_batch = args.max_tokens_per_batch 86 | 87 | # load roberta tokenized input ids 88 | book_data, book_fid = self.load_tokenized(os.path.join(data_path, split), chunk_size, fileids) # {'book_id': [#segments, chunk_size]} {'book_id': 'file_id'} 89 | print(f'number of books {len(book_data)}') 90 | 91 | self.data = self.group_segment(book_data, max_num_chunks) # [ [<=max_num_chunks, chunk_size], ] 92 | print(f'number of examples {len(self.data)}') 93 | 94 | def _segment_book(self, book_input_ids, 
chunk_size=256): 95 | ''' 96 | 1. read data and do sentence tokenization 97 | 2. keep adding sentence until after tokenization > chunk_size tokens, step back and pad 98 | ''' 99 | this_book_ids = [] 100 | segment_ids = [] 101 | for sent_ids in book_input_ids: 102 | 103 | if len(segment_ids) + len(sent_ids) < chunk_size - 1: 104 | segment_ids.extend(sent_ids) 105 | 106 | else: 107 | if len(segment_ids) != 0 and len(segment_ids) < chunk_size - 1: 108 | segment_ids += [eos_id] 109 | segment_ids.extend([pad_id] * (chunk_size - len(segment_ids) - 1)) 110 | segment_ids = [cls_id] + segment_ids 111 | assert len(segment_ids) == chunk_size, pdb.set_trace() 112 | this_book_ids.append(segment_ids) 113 | segment_ids = sent_ids 114 | 115 | if len(segment_ids) == 0 or len(segment_ids) >= chunk_size - 1: 116 | this_ids = sent_ids if len(segment_ids) == 0 else segment_ids 117 | num_sent_chunks = len(this_ids) // (chunk_size - 1) + 1 118 | for chunk_id in range(num_sent_chunks): 119 | sid = chunk_id * (chunk_size - 1) 120 | eid = min((chunk_id+1) * (chunk_size - 1), len(this_ids)) 121 | if eid == (chunk_id+1) * (chunk_size - 1): 122 | this_book_ids.append([cls_id] + this_ids[sid:eid]) 123 | else: 124 | this_book_ids.append([cls_id] + this_ids[sid:eid] + [eos_id] + [pad_id] * (chunk_size - (eid-sid) - 2)) 125 | segment_ids = [] 126 | 127 | assert all(len(x) == chunk_size for x in this_book_ids), pdb.set_trace() 128 | return torch.tensor(this_book_ids) 129 | 130 | def group_segment(self, book_data, num_chunks): 131 | ret = [] 132 | for book_id in book_data: 133 | this_book_data = book_data[book_id] 134 | this_book_data = torch.split(this_book_data, num_chunks) 135 | this_book_data = [bd for bd in this_book_data if bd.shape[0] == num_chunks] 136 | ret.extend(this_book_data) 137 | return ret 138 | 139 | def load_tokenized(self, data_path, chunk_size, fileids): 140 | ''' 141 | load tokenized training data 142 | each .pkl file contain _n_ books, 143 | each book contains field 'input_ids', 144 | which is a list of list of roberta token_ids, 145 | each list is a sentence, no `bos` and `eos`, need to insert later 146 | 147 | output format: {'book_id': [#segments, chunk_size]} 148 | } 149 | ''' 150 | files = sorted([fn for fn in os.listdir(data_path) if fn.endswith(".pkl")]) 151 | fileids = set(fileids) 152 | 153 | ret = {} 154 | book_fid_map = {} 155 | for fname in tqdm(files): 156 | 157 | with open(os.path.join(data_path, fname), "rb") as f: 158 | data = pickle.load(f) 159 | 160 | for book_id in data: 161 | book_data = data[book_id]['input_ids'] 162 | book_data = self._segment_book(book_data, chunk_size=chunk_size) 163 | ret[book_id]= book_data 164 | book_fid_map[book_id] = fname 165 | if len(ret) >= self.max_books: 166 | break 167 | 168 | if len(ret) >= self.max_books: 169 | break 170 | 171 | return ret, book_fid_map 172 | 173 | def __len__(self): 174 | return len(self.data) 175 | 176 | def __getitem__(self, idx): 177 | 178 | this_item = torch.tensor(self.data[idx]) 179 | padding_mask = torch.zeros(self.max_num_chunks).float() 180 | length = this_item.shape[0] 181 | if length < self.max_num_chunks: 182 | padding_mask[-(self.max_num_chunks-length):] = float('-inf') 183 | 184 | ret = { 185 | 'data': F.pad(this_item, (0, 0, 0, self.max_num_chunks-length)), 186 | 'padding_mask': padding_mask 187 | } 188 | 189 | return ret 190 | 191 | 192 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | 
import pdb 3 | from args import parse_args 4 | from data.utils import SLMEvalDataset 5 | from actions.train import Trainer 6 | from actions.evaluate import Evaluator 7 | from models.long_range_lm import LRLM 8 | from slurm import init_distributed_mode 9 | from models.utils import restore 10 | from utils import prepare_data 11 | 12 | import torch 13 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler 14 | 15 | def main(): 16 | args = parse_args() 17 | 18 | print('='*42) 19 | print('All configs:') 20 | v_configs = vars(args) 21 | for k in v_configs: 22 | print('\t{:20s} {:50s}'.format(k, str(v_configs[k]))) 23 | print('='*42) 24 | 25 | init_distributed_mode(args) 26 | 27 | if args.action == "preprocess-data": 28 | dl_lst, vdl_lst = prepare_data(args) 29 | exit(0) 30 | 31 | elif args.action == "train-seg-lm": 32 | 33 | dl_lst, vdl_lst = prepare_data(args) 34 | 35 | model = LRLM(args) 36 | print(model) 37 | 38 | actioner = Trainer(args, model, dl_lst, validation_dataloader_lst=vdl_lst, train_type="segment") 39 | 40 | if args.restore: 41 | 42 | restore_modules = { 43 | module_name: module 44 | for module_name, module in actioner.modules.items() 45 | } 46 | 47 | epoch, step = restore( 48 | args.restore, 49 | restore_modules, 50 | num_checkpoints=1, 51 | map_location=torch.device('cuda'), 52 | strict=False 53 | ) 54 | 55 | actioner.steps = 0 if not args.load_optimizer else step * args.accumulate_steps 56 | actioner.epoch = 0 57 | 58 | elif args.action == "eval-seg-lm": 59 | 60 | ds = SLMEvalDataset(args, args.data_path, batch_size=1) 61 | sampler = SequentialSampler(ds) 62 | dl = DataLoader(ds, sampler=sampler, batch_size=args.batch_size) 63 | 64 | args.dropout_p = 0.0 65 | model = LRLM(args) 66 | actioner = Evaluator(args, model, dl, eval_type="suffix_identification") 67 | 68 | restore_modules = { 69 | module_name: module 70 | for module_name, module in actioner.modules.items() 71 | } 72 | 73 | epoch, step = restore( 74 | args.restore, 75 | restore_modules, 76 | num_checkpoints=1, 77 | map_location=torch.device('cuda'), 78 | strict=True 79 | ) 80 | 81 | else: 82 | raise NotImplementedError 83 | 84 | actioner() 85 | 86 | if __name__ == "__main__": 87 | main() 88 | 89 | 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.long_range_lm import LRLM -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A module which implements various attention mechanisms 3 | ''' 4 | import pdb 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from utils import same_tensor 10 | 11 | 12 | 13 | class MultiHeadedAttention(nn.Module): 14 | ''' Implement a multi-headed attention module ''' 15 | def __init__(self, embed_dim, num_heads=1): 16 | ''' Initialize the attention module ''' 17 | super(MultiHeadedAttention, self).__init__() 18 | 19 | # ensure valid inputs 20 | assert embed_dim % num_heads == 0, \ 21 | f'num_heads={num_heads} should evenly divide embed_dim={embed_dim}' 22 | 23 | # store off the scale and input params 24 | self.embed_dim = embed_dim 25 | self.num_heads = num_heads 26 | self.projection_dim = embed_dim // num_heads 27 | self.scale = self.projection_dim ** -0.5 28 | 29 | # Combine projections for multiple heads 
into a single linear layer for efficiency 30 | self.input_weights = nn.Parameter(torch.Tensor(3 * embed_dim, embed_dim)) 31 | self.output_projection = nn.Linear(embed_dim, embed_dim, bias=False) 32 | 33 | self.reset_parameters() 34 | 35 | def reset_parameters(self): 36 | ''' Reset parameters using xavier initialization ''' 37 | # Initialize using Xavier 38 | gain = nn.init.calculate_gain('linear') 39 | nn.init.xavier_uniform_(self.input_weights, gain) 40 | nn.init.xavier_uniform_(self.output_projection.weight, gain) 41 | 42 | def project(self, inputs, index=0, chunks=1): 43 | ''' Produce a linear projection using the weights ''' 44 | batch_size = inputs.shape[0] 45 | start = index * self.embed_dim 46 | end = start + chunks * self.embed_dim 47 | projections = F.linear(inputs, self.input_weights[start:end]).chunk(chunks, dim=-1) 48 | 49 | output_projections = [] 50 | for projection in projections: 51 | # transform projection to (BH x T x E) 52 | output_projections.append( 53 | projection.view( 54 | batch_size, 55 | -1, 56 | self.num_heads, 57 | self.projection_dim 58 | ).transpose(2, 1).contiguous().view( 59 | batch_size * self.num_heads, 60 | -1, 61 | self.projection_dim 62 | ) 63 | ) 64 | 65 | return output_projections 66 | 67 | def attention(self, values, keys, queries, key_mask=None, mask=None): 68 | ''' Scaled dot product attention with optional masks ''' 69 | 70 | logits = self.scale * torch.bmm(queries, keys.transpose(2, 1)) 71 | 72 | if mask is not None: 73 | logits += mask 74 | 75 | if key_mask is not None: 76 | logits_shape = logits.shape 77 | batch_size = logits_shape[0] // self.num_heads 78 | logits = logits.view(batch_size, self.num_heads, logits_shape[1], logits_shape[2]) 79 | logits.masked_fill_(key_mask[:, None, None], float('-inf')) 80 | logits = logits.view(logits_shape) 81 | values = torch.nan_to_num(values) 82 | 83 | attended = torch.bmm(F.softmax(logits, dim=-1), values) 84 | 85 | # By this point the values, keys, and queries all have B * H as their first dimension 86 | batch_size = queries.shape[0] // self.num_heads 87 | return attended.view( 88 | batch_size, 89 | self.num_heads, 90 | -1, 91 | self.projection_dim 92 | ).transpose(2, 1).contiguous().view( 93 | batch_size, 94 | -1, 95 | self.num_heads * self.projection_dim 96 | ) 97 | 98 | def forward(self, values, keys, queries, # pylint:disable=arguments-differ 99 | key_mask=None, attention_mask=None, num_queries=0): 100 | ''' Forward pass of the attention ''' 101 | # pylint:disable=unbalanced-tuple-unpacking 102 | if same_tensor(values, keys, queries): 103 | values, keys, queries = self.project(values, chunks=3) 104 | elif same_tensor(values, keys): 105 | values, keys = self.project(values, chunks=2) 106 | queries, = self.project(queries, 2) 107 | else: 108 | values, = self.project(values, 0) 109 | keys, = self.project(keys, 1) 110 | queries, = self.project(queries, 2) 111 | # pylint:enable=unbalanced-tuple-unpacking 112 | 113 | if num_queries: 114 | queries = queries[:, -num_queries:] 115 | 116 | attended = self.attention(values, keys, queries, key_mask, attention_mask) 117 | 118 | return self.output_projection(attended) 119 | 120 | 121 | -------------------------------------------------------------------------------- /models/embeddings.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import torch 3 | import math 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | class TokenEmbedding(nn.Embedding): 8 | ''' An embedding layer used for 
the transformer ''' 9 | def __init__(self, num_embeddings, embedding_dim, padding_idx=0): 10 | super(TokenEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx=padding_idx) 11 | 12 | self.scale = embedding_dim ** 0.5 13 | nn.init.constant_(self.weight[padding_idx], 0) 14 | nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5) 15 | 16 | def forward(self, inputs, transpose=False): # pylint:disable=arguments-differ 17 | ''' Implement the forward pass of the embedding ''' 18 | if transpose: 19 | return F.linear(inputs, self.weight) 20 | else: 21 | return self.scale * super(TokenEmbedding, self).forward(inputs) 22 | 23 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 24 | missing_keys, unexpected_keys, error_msgs): 25 | ''' 26 | Not sure if this is the best approach, but override the internal function to support loading 27 | from a different sized vocabulary. 28 | ''' 29 | if strict: 30 | super(TokenEmbedding, self)._load_from_state_dict( 31 | state_dict, prefix, local_metadata, strict, 32 | missing_keys, unexpected_keys, error_msgs 33 | ) 34 | else: 35 | # Support loading from a different sized vocabulary 36 | weight_name = prefix + 'weight' 37 | for key, param in state_dict.items(): 38 | if key == weight_name: 39 | old_vocab_size = len(param) 40 | new_vocab_size = len(self.weight) 41 | vocab_size_diff = new_vocab_size - old_vocab_size 42 | if vocab_size_diff > 0: 43 | param = torch.cat((param, self.weight[old_vocab_size:]), 0) 44 | else: 45 | param = param[:new_vocab_size] 46 | 47 | self.weight.data.copy_(param) 48 | 49 | 50 | return self.compute_bias(qlen, klen) # shape (1, num_heads, qlen, klen) 51 | 52 | 53 | class PositionEmbedding(nn.Module): 54 | ''' Produce position embeddings ''' 55 | def __init__(self, dim, freq=1e4): 56 | ''' Initialize the PositionEmbedding ''' 57 | super(PositionEmbedding, self).__init__() 58 | 59 | self.dim = dim # require the number of dimension to be even 60 | self.freq = freq 61 | 62 | _embeddings = threading.local() 63 | def forward(self, inputs): # pylint:disable=arguments-differ 64 | ''' Implement the forward pass of the embedding ''' 65 | device = inputs.device 66 | max_length = inputs.shape[1] 67 | embedding_store = PositionEmbedding._embeddings.__dict__ 68 | device_store = embedding_store.get(device, {}) 69 | if ( 70 | not device_store or 71 | self.dim not in device_store or 72 | device_store[self.dim].shape[0] < max_length 73 | ): 74 | positions = torch.arange(0., max_length, device=device).unsqueeze(1) 75 | 76 | # the tensor2tensor code is slightly different than described in the paper 77 | # dividing by (self.dim - 2) produces nearly identical results to their version 78 | # when comparing the tensorflow results to these torch results 79 | dims = torch.arange(0., self.dim, 2., device=device).unsqueeze(0) / (self.dim - 2) 80 | 81 | sin = torch.sin(positions / torch.pow(self.freq, dims)) 82 | cos = torch.cos(positions / torch.pow(self.freq, dims)) 83 | 84 | embeddings = torch.stack((sin, cos), 0) 85 | device_store[self.dim] = embeddings.transpose(0, 1).contiguous().view(-1, self.dim) 86 | 87 | embeddings = device_store[self.dim] 88 | embedding_store[device] = device_store 89 | 90 | return embeddings[:max_length].unsqueeze(0) 91 | -------------------------------------------------------------------------------- /models/long_range_lm.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | import 
torch.nn.functional as F 6 | 7 | from models.transformer import Transformer 8 | from transformers import ( 9 | RobertaTokenizer, 10 | RobertaModel, 11 | ) 12 | 13 | from data.utils import pad_id 14 | 15 | 16 | class LRLM(nn.Module): 17 | 18 | def __init__(self, args): 19 | super(LRLM, self).__init__() 20 | 21 | self.args = args 22 | 23 | roberta = RobertaModel.from_pretrained('roberta-base') 24 | 25 | if 'eval' in self.args.action: 26 | roberta.eval() 27 | del roberta.encoder.layer[-10:] 28 | self.encoder = roberta 29 | if 'eval' in self.args.action: 30 | self.encoder.eval() 31 | print(self.encoder) 32 | 33 | self.sentence_lm = Transformer(args, encoder=False) 34 | 35 | def _encode_pretrained_lm(self, batch): 36 | ''' 37 | In: 38 | batch['data'] -> tensor([bsz, num_chunks, segment_size]) 39 | batch['padding_mask'] -> tensor([bsz, num_chunks]) 40 | Out: 41 | encoded batch -> tensor([bsz, num_chunks, d_model]) 42 | ''' 43 | 44 | bsz, num_chunks, segment_size = batch['data'].shape 45 | data = batch['data'].contiguous().view(-1, segment_size) 46 | attention_mask = (data != pad_id).long() 47 | 48 | 49 | 50 | out = self.encoder(input_ids=data, 51 | attention_mask=attention_mask, 52 | output_attentions=False, 53 | output_hidden_states=False).last_hidden_state[:, 0, :] # bsz*num_chunks x 768 54 | 55 | if 'eval' in self.args.action: 56 | outputs = out.detach() 57 | else: 58 | outputs = out 59 | 60 | return outputs.contiguous().view(bsz, num_chunks, outputs.shape[-1]) 61 | 62 | def sentence_lm_forward(self, batch): 63 | ''' 64 | In: 65 | batch -> tensor([bsz, num_chunks+1, 768]) 66 | Out: 67 | loss -> tensor([bsz, num_chunks]) 68 | ''' 69 | # bsz x (num_chunks) x d_model 70 | slm_out = self.sentence_lm(batch) # shift happens in the transformer 71 | 72 | # right shift batch by one [bsz, num_chunks-1, d_model] 73 | target = batch['data'][:, 1:, :] 74 | padding_mask = batch['padding_mask'][:, 1:] 75 | padding_mask[padding_mask == 0] = 1.0 76 | padding_mask[padding_mask < 0] = 0 77 | 78 | if self.args.action == "eval-seg-lm": 79 | slm_out = slm_out[0, -1] # should all be the same during eval-suffix, since context is the same 80 | target = target[:, -1] 81 | dr = torch.matmul(slm_out, target.transpose(1, 0)) 82 | return dr 83 | 84 | # a: predicted next segment representation 85 | # b: encoded gold next segment representation 86 | # model the density ratio p(b | a) / p(b) to maximize MI(encoded context, next segment) 87 | bsz, L, embed_dim = slm_out.shape 88 | slm_out_reshape = slm_out.contiguous().view(-1, embed_dim) 89 | target_reshape = target.contiguous().view(-1, embed_dim).transpose(1,0) 90 | dr_out = torch.matmul(slm_out_reshape, target_reshape) 91 | dr = dr_out.contiguous().view(bsz, L, -1) 92 | 93 | # contrastive loss 94 | bsz, n, _ = dr.shape 95 | loss = - F.log_softmax(dr * padding_mask[:, :, None], dim=2).reshape(bsz, n, bsz, n) 96 | loss = loss.transpose(2,1)[range(bsz), range(bsz)][:, range(n), range(n)] 97 | 98 | return loss, padding_mask 99 | 100 | def forward(self, raw_batch): 101 | ''' 102 | IN: 103 | raw_batch: 104 | raw_batch['data'] -> tensor([bsz, num_chunks, segment_size]) 105 | raw_batch['padding_mask'] -> tensor([bsz, num_chunks]) 106 | 107 | OUT: 108 | loss 109 | ''' 110 | 111 | batch = {} 112 | 113 | batch['data'] = self._encode_pretrained_lm(raw_batch) 114 | batch['data'] = F.pad(batch['data'], (0,0,1,0), "constant", 0) # pad left (segment-level padding) 115 | batch['padding_mask'] = F.pad(raw_batch['padding_mask'], (1, 0), "constant", 0) 116 | 117 | return 
self.sentence_lm_forward(batch) 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A module which implements the basic Transformer 3 | ''' 4 | import uuid 5 | import threading 6 | import pdb 7 | import sys 8 | import torch 9 | from torch import nn 10 | import numpy as np 11 | from utils import triu 12 | from models.attention import MultiHeadedAttention 13 | from models.embeddings import PositionEmbedding, TokenEmbedding 14 | from transformers import RobertaTokenizer 15 | 16 | 17 | class TransformerSublayer(nn.Module): 18 | ''' 19 | Implements a sub layer of the transformer model, which consists of: 20 | 1) A sub layer module 21 | 2) Followed by dropout 22 | 3) Plus a residual connection 23 | 4) With layer normalization 24 | ''' 25 | def __init__(self, sublayer, sublayer_shape, dropout_p=0.1, init_std=0.02): 26 | ''' Initialize the transformer sublayer ''' 27 | super(TransformerSublayer, self).__init__() 28 | self.init_std = init_std 29 | self.sublayer = sublayer 30 | self.norm = nn.LayerNorm(sublayer_shape) 31 | self.dropout = nn.Dropout(dropout_p, inplace=True) 32 | self.reset_parameters() 33 | 34 | def reset_parameters(self): 35 | # ''' Reset parameters using xavier initialiation ''' 36 | # self.norm.reset_parameters() 37 | nn.init.normal_(self.norm.weight, 1.0, self.init_std) 38 | 39 | def forward(self, inputs, *sublayer_args, **sublayer_kwargs): # pylint:disable=arguments-differ 40 | ''' The forward pass of the sublayer ''' 41 | out = self.dropout(self.sublayer(*sublayer_args, **sublayer_kwargs)) 42 | return self.norm(inputs + out) 43 | 44 | 45 | class TransformerFFN(nn.Module): 46 | ''' Implements the Transformer feed-forward network ''' 47 | def __init__(self, embedding_size, hidden_dim, init_std=0.02): 48 | super(TransformerFFN, self).__init__() 49 | 50 | self.init_std = init_std 51 | self.relu = nn.ReLU() 52 | self.hidden = nn.Linear(embedding_size, hidden_dim) 53 | self.output = nn.Linear(hidden_dim, embedding_size) 54 | self.reset_parameters() 55 | 56 | def reset_parameters(self): 57 | ''' Reset parameters using xavier initialiation ''' 58 | nn.init.normal_(self.hidden.weight, 0., self.init_std) 59 | nn.init.normal_(self.output.weight, 0., self.init_std) 60 | nn.init.constant_(self.hidden.bias, 0.) 61 | nn.init.constant_(self.output.bias, 0.) 
62 | 63 | def forward(self, inputs): # pylint:disable=arguments-differ 64 | ''' The forward pass of the feed-forward network ''' 65 | return self.output(self.relu(self.hidden(inputs))) 66 | 67 | 68 | class TransformerLayer(nn.Module): 69 | ''' Implements a single decoder layer in a transformer decoder stack ''' 70 | def __init__(self, config, num_heads, dim, hidden_dim, causal=True, 71 | dropout_p=0.1): 72 | ''' Initialize the transformer layer ''' 73 | super(TransformerLayer, self).__init__() 74 | 75 | self.uuid = uuid.uuid4() 76 | 77 | self.ffn = TransformerSublayer( 78 | TransformerFFN(dim, hidden_dim), 79 | dim, dropout_p) 80 | 81 | self.self_attention = TransformerSublayer( 82 | MultiHeadedAttention(dim, num_heads=num_heads), 83 | dim, dropout_p) 84 | 85 | # unidirectional lm 86 | self.causal = causal 87 | 88 | # enforce learned heads to look at local windows 89 | self.config = config 90 | 91 | def reset_parameters(self): 92 | ''' Reset the parameters of the module ''' 93 | self.ffn.reset_parameters() 94 | self.self_attention.reset_parameters() 95 | 96 | def forward(self, inputs, global_mask=None): # pylint:disable=arguments-differ 97 | ''' The forward pass ''' 98 | 99 | state = inputs['state'] 100 | cache = inputs.get('cache') 101 | decoder_position = state.shape[1] - 1 102 | L = state.shape[1] 103 | # each layer might have different config 104 | residual = state 105 | kwargs = {} 106 | kwargs['attention_mask'] = self.mask(state) # just causal mask 107 | kwargs['key_mask'] = inputs['padding_mask'] 108 | 109 | state = self.self_attention( 110 | residual, # residual 111 | state, state, state, **kwargs # passed to attention 112 | ) 113 | 114 | state = self.ffn( 115 | state, # residual 116 | state # passed to FF layer 117 | ) 118 | 119 | return {'state': state, 'padding_mask': inputs['padding_mask']} 120 | 121 | _masks = threading.local() 122 | def mask(self, inputs): 123 | ''' 124 | Get a self-attention mask 125 | The mask will be of shape [T x T] containing elements from the set {0, -inf} 126 | Input shape: (B x T x E) 127 | Output shape: (T x T) 128 | ''' 129 | if not self.causal: 130 | return None 131 | 132 | dim = inputs.shape[1] 133 | device = inputs.device 134 | mask_store = TransformerLayer._masks.__dict__ 135 | if device not in mask_store or (device in mask_store and mask_store[device].shape[1] < dim): 136 | mask = inputs.new_full((dim, dim), float('-inf')) 137 | mask_store[device] = triu(mask, 1, 1, 1) 138 | 139 | mask = mask_store[device] 140 | return mask[None, :dim, :dim] 141 | 142 | class Transformer(nn.Module): 143 | ''' The Transformer LM module ''' 144 | def __init__(self, config, encoder=False): 145 | ''' Initialize the Transformer ''' 146 | super(Transformer, self).__init__() 147 | 148 | self.config = config 149 | 150 | self.encoder = encoder 151 | if encoder and config.train_encoder_from_scratch: # need to have token embedding 152 | # to get the embedding table size, load roberta-base tokenizer 153 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 154 | self.padding_idx = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 155 | print(f"embedding table size {len(tokenizer)}") 156 | self.token_embedding = TokenEmbedding( 157 | len(tokenizer), 158 | config.embedding_size, 159 | padding_idx=self.padding_idx 160 | ) 161 | self.position_embedding = PositionEmbedding(config.embedding_size) 162 | 163 | self.dropout = nn.Dropout(config.dropout_p, inplace=True) 164 | 165 | self.layers = self.create_layers(config, encoder=encoder) 166 | 167 | 
self.reset_named_parameters() 168 | 169 | @classmethod 170 | def create_layers(self, config, encoder=False, rpe=None): 171 | ''' Create the transformer decoders ''' 172 | kwargs = {'dropout_p': config.dropout_p, 173 | 'causal': not encoder} # sublayer kwargs 174 | 175 | args = [config, config.num_heads, config.model_size, config.hidden_dim] 176 | 177 | layers = nn.ModuleList([ 178 | TransformerLayer(*args, **kwargs) 179 | for layer_i in range(config.num_layers) 180 | ]) 181 | 182 | return layers 183 | 184 | def reset_named_parameters(self): 185 | 186 | for layer in self.layers: 187 | layer.reset_parameters() 188 | 189 | if hasattr(self, "token_embedding"): 190 | self.token_embedding.reset_parameters() 191 | 192 | def forward(self, batch): # pylint:disable=arguments-differ 193 | ''' batch: bsz x l x embed_dim if input is segment vector 194 | else 195 | (bsz * num_chunks) x num_tokens_per_chunk otherwise 196 | ''' 197 | 198 | if self.encoder: 199 | padding_mask = batch['padding_mask'].bool() 200 | batch = batch['data'] 201 | _, L = batch.shape 202 | 203 | else: 204 | padding_mask = batch['padding_mask'][:, :-1].bool() 205 | batch = batch['data'][:, :-1, :] 206 | bsz, L, embed_dim = batch.shape 207 | 208 | pos_added_batch = self.embed(batch, token_embedding=getattr(self, "token_embedding", None)) 209 | decoded = {'state': pos_added_batch, 'padding_mask' : padding_mask} 210 | 211 | # decoded['state'][batch == self.padding_idx] = 0 212 | for i, decoder in enumerate(self.layers): 213 | decoded = decoder(decoded) 214 | return decoded['state'] # bs x L x hidden_dim 215 | 216 | 217 | def embed(self, inputs, token_embedding=None): 218 | ''' Embed the given inputs ''' 219 | if token_embedding is None: # input is segment vector, no need to encode each token 220 | return self.dropout(inputs + self.position_embedding(inputs)) 221 | else: 222 | return self.dropout(token_embedding(inputs) + self.position_embedding(inputs)) 223 | 224 | 225 | 226 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pdb 3 | import shutil 4 | import tempfile 5 | from collections import OrderedDict 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | def restore(path, modules, num_checkpoints=1, map_location=None, strict=True): 11 | ''' 12 | Restore from a checkpoint 13 | 14 | Args: 15 | path - path to restore from 16 | modules - a dict of name to object that supports the method load_state_dict 17 | ''' 18 | if not os.path.isfile(path): 19 | print(f'Cannot find checkpoint: {path}') 20 | return 0, 0 21 | 22 | print(f'Loading checkpoint {path}') 23 | state = torch.load(path, map_location=map_location) 24 | 25 | if 'model' in modules: 26 | model_state = state['model'] 27 | root, ext = os.path.splitext(path) 28 | 29 | # strip any trailing digits 30 | base = root.rstrip(''.join(str(i) for i in range(10))) 31 | 32 | # determine the integer representation of the trailing digits 33 | idx = root[len(base):] 34 | start_idx = int(idx) if idx else 0 35 | 36 | count = 1 37 | for idx in range(1, num_checkpoints): 38 | # use the digits as the start index for loading subsequent checkpoints for averaging 39 | path = f'{base}{start_idx + idx}{ext}' 40 | if not os.path.isfile(path): 41 | print(f'Cannot find checkpoint: {path} Skipping it!') 42 | continue 43 | 44 | print(f'Averaging with checkpoint {path}') 45 | previous_state = torch.load(path, 
map_location=map_location) 46 | previous_model_state = previous_state['model'] 47 | for name, param in model_state.items(): 48 | param.mul_(count).add_(previous_model_state[name]).div_(count + 1) 49 | 50 | count += 1 51 | 52 | new_model_state = state['model'].copy() 53 | #for key in state['model']: 54 | # new_key = key#.replace('module.', '') 55 | # new_model_state[new_key] = state['model'][key] 56 | #del new_model_state[key] 57 | state['model'] = new_model_state 58 | 59 | for name, obj in modules.items(): 60 | if isinstance(obj, nn.Module): 61 | obj.load_state_dict(state[name], strict=strict) 62 | else: 63 | obj.load_state_dict(state[name]) 64 | return state['epoch'], state['step'] 65 | 66 | def checkpoint(epoch, step, modules, directory, filename='checkpoint.pt', max_checkpoints=5): 67 | ''' 68 | Save a checkpoint 69 | Args: 70 | epoch - current epoch 71 | step - current step 72 | modules - a dict of name to object that supports the method state_dict 73 | directory - the directory to save the checkpoint file 74 | filename - the filename of the checkpoint 75 | max_checkpoints - how many checkpoints to keep 76 | ''' 77 | if not os.path.isdir(directory): 78 | os.makedirs(directory) 79 | 80 | state = { 81 | 'step': step, 82 | 'epoch': epoch, 83 | } 84 | 85 | for name, obj in modules.items(): 86 | state[name] = obj.state_dict() 87 | 88 | with tempfile.NamedTemporaryFile() as temp_checkpoint_file: 89 | torch.save(state, temp_checkpoint_file) 90 | 91 | checkpoint_path = os.path.join(directory, filename) 92 | if os.path.exists(checkpoint_path): 93 | root, ext = os.path.splitext(filename) 94 | for i in range(max_checkpoints - 2, -1, -1): 95 | previous_path = os.path.join(directory, f'{root}{i}{ext}') if i else checkpoint_path 96 | if os.path.exists(previous_path): 97 | backup_path = os.path.join(directory, f'{root}{i+1}{ext}') 98 | if os.path.exists(backup_path): 99 | os.replace(previous_path, backup_path) 100 | else: 101 | os.rename(previous_path, backup_path) 102 | 103 | shutil.copy(temp_checkpoint_file.name, f'{checkpoint_path}.incomplete') 104 | os.rename(f'{checkpoint_path}.incomplete', checkpoint_path) 105 | 106 | return checkpoint_path 107 | 108 | -------------------------------------------------------------------------------- /preprocess/encode_eval_data.sh: -------------------------------------------------------------------------------- 1 | data_path=/mnt/nfs/work1/miyyer/simengsun/data/PG19/valid_hard_suffix/chapter_breaks_new/rawmaxlen-4096-suffix-168.pkl 2 | out_path=/mnt/nfs/work1/miyyer/simengsun/data/PG19/valid_hard_tok/chapter_breaks_new/rawmaxlen-4096-suffix-168.pkl 3 | # tokenize the eval data (encoding is skipped via --tokenize-only); create the output directory, not the output file 4 | mkdir -p $(dirname $out_path) 5 | python tokenize_eval_data.py \ 6 | --input-path $data_path \ 7 | --output-path $out_path \ 8 | --chunk-size 128 \ 9 | --suffix-size 128 \ 10 | --tokenize-only 11 | -------------------------------------------------------------------------------- /preprocess/tokenize_eval_data.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pdb 4 | import pickle 5 | import argparse 6 | from tqdm import tqdm 7 | import torch 8 | import random 9 | import torch 10 | from nltk.tokenize import sent_tokenize 11 | torch.set_num_threads(1) 12 | random.seed(42) 13 | from transformers import RobertaTokenizer, RobertaModel, AutoModel, AutoConfig 14 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 15 | cls_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token) 16 | eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) 17 | pad_id 
= tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--input-path", type=str, default=None) 22 | parser.add_argument("--output-path", type=str, default=None) 23 | parser.add_argument("--action", type=str, default=None) 24 | parser.add_argument("--chunk-size", type=int, default=256) 25 | parser.add_argument("--suffix-size", type=int, default=256) 26 | parser.add_argument("--tokenize-only", action="store_true", default=False) 27 | args = parser.parse_args() 28 | return args 29 | 30 | args = parse_args() 31 | 32 | model = RobertaModel.from_pretrained('roberta-base') 33 | mode = model.cuda() 34 | for param in model.parameters(): 35 | param.requires_grad = False 36 | 37 | def tokenize(args, data, chunk_size=256, single_chunk=False): 38 | ''' 39 | tokenize the data with roberta tokenizer and chunk to 256, 40 | respect sentence boundaries 41 | ''' 42 | sent_tok_data = sent_tokenize(data) 43 | segment_ids = [] 44 | this_book_ids = [] 45 | sent_chunk_map = {} 46 | for si, sent in enumerate(sent_tok_data): 47 | sent_ids = tokenizer(sent)['input_ids'][1:-1] # get rid of the , added later when adding to this_book_ids 48 | if len(segment_ids) + len(sent_ids) < chunk_size - 1: 49 | segment_ids.extend(sent_ids) 50 | 51 | else: 52 | # if adding new sentence leads to >256 tokens 53 | # back-off a sentence 54 | # pdb.set_trace() 55 | if len(segment_ids) != 0 and len(segment_ids) < chunk_size - 1: 56 | segment_ids += [eos_id] 57 | segment_ids.extend([pad_id] * (chunk_size - len(segment_ids) - 1)) 58 | segment_ids = [cls_id] + segment_ids 59 | assert len(segment_ids) == chunk_size, pdb.set_trace() 60 | this_book_ids.append(segment_ids) 61 | segment_ids = sent_ids 62 | 63 | if len(segment_ids) == 0 or len(segment_ids) >= chunk_size - 1: 64 | # split the sentence into multiple chunks of chunk_size 65 | this_ids = sent_ids if len(segment_ids) == 0 else segment_ids 66 | num_sent_chunks = len(this_ids) // (chunk_size - 1) + 1 67 | for chunk_id in range(num_sent_chunks): 68 | sid = chunk_id * (chunk_size - 1) 69 | eid = min((chunk_id+1) * (chunk_size - 1), len(this_ids)) 70 | if eid == (chunk_id+1) * (chunk_size - 1): 71 | this_book_ids.append([cls_id] + this_ids[sid:eid]) 72 | else: 73 | this_book_ids.append([cls_id] + this_ids[sid:eid] + [eos_id] + [pad_id] * (chunk_size - (eid-sid) - 2)) 74 | segment_ids = [] 75 | 76 | sent_chunk_map[si] = len(this_book_ids) 77 | 78 | if len(segment_ids) != 0: 79 | segment_ids += [eos_id] 80 | segment_ids.extend([pad_id] * (chunk_size - len(segment_ids) - 1)) 81 | segment_ids = [cls_id] + segment_ids 82 | assert len(segment_ids) == chunk_size, pdb.set_trace() 83 | this_book_ids.append(segment_ids) 84 | 85 | sent_chunk_map[si] = len(this_book_ids) 86 | 87 | if chunk_size != args.chunk_size: 88 | if chunk_size > args.chunk_size: 89 | for i in range(len(this_book_ids)): 90 | this_book_ids[i] = this_book_ids[i][:args.chunk_size] 91 | else: 92 | for i in range(len(this_book_ids)): 93 | this_book_ids[i] = this_book_ids[i] + [pad_id] * (args.chunk_size - len(this_book_ids[i])) 94 | assert len(this_book_ids[i]) == args.chunk_size 95 | 96 | if single_chunk: 97 | if len(this_book_ids) != 1: 98 | try: 99 | this_book_ids = [this_book_ids[0]] 100 | except: 101 | pdb.set_trace() 102 | else: 103 | assert all(len(x) == chunk_size for x in this_book_ids), pdb.set_trace() 104 | 105 | return torch.tensor(this_book_ids) 106 | 107 | 108 | def encode(input_ids, batch_size=64): 109 | ''' 110 | input_ids: 
#chunks x chunk_size(256) 111 | return 112 | #chunks x #hidden_dim 113 | ''' 114 | # create attention mask 115 | input_ids = input_ids.cuda() 116 | attention_mask = (input_ids != pad_id).long().cuda() 117 | sent_vecs = [] 118 | for batch_id in range(input_ids.shape[0] // batch_size + 1): 119 | batch_sid = batch_id * batch_size 120 | batch_eid = min((batch_id+1) * batch_size, input_ids.shape[0]) 121 | try: 122 | outputs = model(input_ids=input_ids[batch_sid:batch_eid], 123 | attention_mask=attention_mask[batch_sid:batch_eid], 124 | output_attentions=False, 125 | output_hidden_states=False) 126 | except: 127 | pdb.set_trace() 128 | this_batch_sent_vecs = outputs.last_hidden_state[:,0,:].detach() # batch_size x hidden_dim 129 | sent_vecs.append(this_batch_sent_vecs.cpu()) 130 | 131 | sent_vecs = torch.cat(sent_vecs) # #chunks x #hidden_dim 132 | return sent_vecs 133 | 134 | def encode_file(args, tokenize_only=False): 135 | 136 | try: 137 | with open(os.path.join(args.input_path), "rb") as f: 138 | data = pickle.load(f) 139 | except: 140 | raise FileNotFoundError(f"invalid input file path: {args.input_path}") 141 | 142 | all_data = {} 143 | for book_id in tqdm(data): 144 | all_data[book_id] = [] 145 | for example in data[book_id]: 146 | ctx = example['ctx'] 147 | if args.suffix_size != args.chunk_size: 148 | pos = tokenizer.decode(tokenizer.encode(example['pos'])[1:-1][:args.suffix_size]) 149 | negs = [tokenizer.decode(tokenizer.encode(neg)[1:-1][:args.suffix_size]) \ 150 | for neg in example['negs']] 151 | else: 152 | pos = example['pos'] 153 | negs = example['negs'] 154 | 155 | ctx_ids = tokenize(args, ctx, chunk_size=args.chunk_size) 156 | pos_ids = tokenize(args, pos, chunk_size=args.suffix_size, single_chunk=True) 157 | negs_ids = [tokenize(args, neg, chunk_size=args.suffix_size, single_chunk=True) for neg in negs] 158 | all_data[book_id].append( 159 | { 160 | 'ctx_vec': encode(ctx_ids) if not tokenize_only else ctx_ids, 161 | 'pos_vec': encode(pos_ids) if not tokenize_only else pos_ids, 162 | 'negs_vec': [encode(neg_ids) for neg_ids in negs_ids] if not tokenize_only else negs_ids 163 | } 164 | ) 165 | 166 | with open(os.path.join(args.output_path), "wb") as f: 167 | pickle.dump(all_data, f, protocol=pickle.HIGHEST_PROTOCOL) 168 | 169 | 170 | def main(): 171 | # if tokenize_only, only run tokenizer, else extract [cls] vector from roberta-base output 172 | encode_file(args, tokenize_only=args.tokenize_only) 173 | 174 | if __name__ == "__main__": 175 | main() 176 | -------------------------------------------------------------------------------- /preprocess/tokenize_pg19_train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | tokenize PG19, store input_ids 4 | 5 | ''' 6 | 7 | import os 8 | import pdb 9 | import torch 10 | import random 11 | import pickle 12 | import torch 13 | import argparse 14 | from tqdm import tqdm 15 | from nltk.tokenize import sent_tokenize 16 | random.seed(42) 17 | from transformers import RobertaTokenizer, RobertaModel 18 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 19 | cls_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token) 20 | eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) 21 | pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--input-path", type=str, default=None) 26 | parser.add_argument("--output-path", type=str, default=None) 27 | parser.add_argument("--shard-id", type=int, 
default=None) 28 | parser.add_argument("--shard-size", type=int, default=None) 29 | args = parser.parse_args() 30 | return args 31 | 32 | def tokenize(in_path, book_name, chunk_size=256): 33 | ''' 34 | store a list of tokenized sentence for each book 35 | ''' 36 | with open(os.path.join(in_path, book_name), "r") as f: 37 | data = ' '.join([l.strip() for l in f.readlines()]) 38 | 39 | sent_tok_data = sent_tokenize(data) 40 | 41 | this_book_sents = [] 42 | for sent in tqdm(sent_tok_data): 43 | sent_ids = tokenizer(sent)['input_ids'][1:-1] # get rid of the , added later when adding to this_book_ids 44 | this_book_sents.append(sent_ids) 45 | return this_book_sents 46 | 47 | 48 | def main(): 49 | args = parse_args() 50 | books = sorted(os.listdir(args.input_path)) 51 | 52 | shard_id = args.shard_id 53 | shard_size = args.shard_size 54 | this_books = books[shard_size*shard_id:shard_size*(shard_id+1)] 55 | encoded_all = {} 56 | for book in this_books: 57 | book_id = book.strip(".txt") 58 | input_ids = tokenize(args.input_path, book) 59 | encoded_all[book_id] = {'input_ids': input_ids} 60 | 61 | with open(os.path.join(args.output_path, f"{shard_id:04}.pkl"), 'wb') as f: 62 | pickle.dump(encoded_all, f, protocol=pickle.HIGHEST_PROTOCOL) 63 | 64 | if __name__ == "__main__": 65 | main() 66 | 67 | 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=4.5=1_gnu 6 | aiohttp=3.8.1=pypi_0 7 | aiosignal=1.2.0=pypi_0 8 | apex=0.1=pypi_0 9 | async-timeout=4.0.2=pypi_0 10 | attrs=21.3.0=pypi_0 11 | blas=1.0=mkl 12 | blis=0.7.7=pypi_0 13 | bzip2=1.0.8=h7b6447c_0 14 | ca-certificates=2021.7.5=h06a4308_1 15 | catalogue=2.0.7=pypi_0 16 | certifi=2021.5.30=py38h06a4308_0 17 | charset-normalizer=2.0.6=pypi_0 18 | click=8.0.1=pypi_0 19 | configparser=5.0.2=pypi_0 20 | cudatoolkit=11.1.74=h6bb024c_0 21 | cymem=2.0.6=pypi_0 22 | datasets=1.17.0=pypi_0 23 | dill=0.3.4=pypi_0 24 | docker-pycreds=0.4.0=pypi_0 25 | en-core-web-sm=3.2.0=pypi_0 26 | ffmpeg=4.3=hf484d3e_0 27 | filelock=3.1.0=pypi_0 28 | freetype=2.10.4=h5ab3b9f_0 29 | frozenlist=1.2.0=pypi_0 30 | fsspec=2021.11.1=pypi_0 31 | gitdb=4.0.7=pypi_0 32 | gitpython=3.1.24=pypi_0 33 | gmp=6.2.1=h2531618_2 34 | gnutls=3.6.15=he1e5248_0 35 | huggingface-hub=0.1.2=pypi_0 36 | idna=3.2=pypi_0 37 | intel-openmp=2021.3.0=h06a4308_3350 38 | jinja2=3.1.0=pypi_0 39 | joblib=1.0.1=pypi_0 40 | jpeg=9b=h024ee3a_2 41 | lame=3.100=h7b6447c_0 42 | langcodes=3.3.0=pypi_0 43 | lcms2=2.12=h3be6417_0 44 | ld_impl_linux-64=2.35.1=h7274673_9 45 | libffi=3.3=he6710b0_2 46 | libgcc-ng=9.3.0=h5101ec6_17 47 | libgomp=9.3.0=h5101ec6_17 48 | libiconv=1.15=h63c8f33_5 49 | libidn2=2.3.2=h7f8727e_0 50 | libpng=1.6.37=hbc83047_0 51 | libstdcxx-ng=9.3.0=hd4cf53a_17 52 | libtasn1=4.16.0=h27cfd23_0 53 | libtiff=4.2.0=h85742a9_0 54 | libunistring=0.9.10=h27cfd23_0 55 | libuv=1.40.0=h7b6447c_0 56 | libwebp-base=1.2.0=h27cfd23_0 57 | lz4-c=1.9.3=h295c915_1 58 | markupsafe=2.1.1=pypi_0 59 | mkl=2021.3.0=h06a4308_520 60 | mkl-service=2.4.0=py38h7f8727e_0 61 | mkl_fft=1.3.0=py38h42c9631_2 62 | mkl_random=1.2.2=py38h51133e4_0 63 | multidict=5.2.0=pypi_0 64 | multiprocess=0.70.12.2=pypi_0 65 | murmurhash=1.0.6=pypi_0 66 | ncurses=6.2=he6710b0_1 67 | nettle=3.7.3=hbbd107a_1 68 | ninja=1.10.2=hff7bd54_1 69 | 
nltk=3.6.3=pypi_0 70 | numpy=1.20.3=py38hf144106_0 71 | numpy-base=1.20.3=py38h74d4b33_0 72 | olefile=0.46=pyhd3eb1b0_0 73 | openh264=2.1.0=hd408876_0 74 | openjpeg=2.4.0=h3ad879b_0 75 | openssl=1.1.1l=h7f8727e_0 76 | packaging=21.0=pypi_0 77 | pandas=1.3.5=pypi_0 78 | pathtools=0.1.2=pypi_0 79 | pathy=0.6.1=pypi_0 80 | pillow=8.3.1=py38h2c7a002_0 81 | pip=21.0.1=py38h06a4308_0 82 | preshed=3.0.6=pypi_0 83 | promise=2.3=pypi_0 84 | protobuf=3.18.0=pypi_0 85 | psutil=5.8.0=pypi_0 86 | pyarrow=6.0.1=pypi_0 87 | pydantic=1.8.2=pypi_0 88 | pyparsing=2.4.7=pypi_0 89 | python=3.8.11=h12debd9_0_cpython 90 | python-dateutil=2.8.2=pypi_0 91 | pytz=2021.3=pypi_0 92 | pyyaml=5.4.1=pypi_0 93 | readline=8.1=h27cfd23_0 94 | regex=2021.9.24=pypi_0 95 | requests=2.26.0=pypi_0 96 | sacremoses=0.0.46=pypi_0 97 | scikit-learn=1.0.2=pypi_0 98 | scipy=1.7.1=pypi_0 99 | sentencepiece=0.1.96=pypi_0 100 | sentry-sdk=1.4.3=pypi_0 101 | setuptools=58.0.4=py38h06a4308_0 102 | shortuuid=1.0.1=pypi_0 103 | six=1.16.0=pyhd3eb1b0_0 104 | sklearn=0.0=pypi_0 105 | smart-open=5.2.1=pypi_0 106 | smmap=4.0.0=pypi_0 107 | spacy=3.2.3=pypi_0 108 | spacy-legacy=3.0.9=pypi_0 109 | spacy-loggers=1.0.1=pypi_0 110 | sqlite=3.36.0=hc218d9a_0 111 | srsly=2.4.2=pypi_0 112 | subprocess32=3.5.4=pypi_0 113 | termcolor=1.1.0=pypi_0 114 | thinc=8.0.15=pypi_0 115 | threadpoolctl=3.1.0=pypi_0 116 | tk=8.6.11=h1ccaba5_0 117 | tokenizers=0.10.3=pypi_0 118 | torch=1.9.1+cu111=pypi_0 119 | torchaudio=0.9.1=pypi_0 120 | torchvision=0.10.1+cu111=pypi_0 121 | tqdm=4.62.3=pypi_0 122 | transformers=4.13.0.dev0=dev_0 123 | typer=0.4.0=pypi_0 124 | typing_extensions=3.10.0.2=pyh06a4308_0 125 | urllib3=1.26.7=pypi_0 126 | wandb=0.12.2=pypi_0 127 | wasabi=0.9.0=pypi_0 128 | wheel=0.37.0=pyhd3eb1b0_1 129 | xxhash=2.0.2=pypi_0 130 | xz=5.2.5=h7b6447c_0 131 | yarl=1.7.2=pypi_0 132 | yaspin=2.1.0=pypi_0 133 | zlib=1.2.11=h7b6447c_3 134 | zstd=1.4.9=haebb681_0 -------------------------------------------------------------------------------- /slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import pdb 8 | from logging import getLogger 9 | import os 10 | import sys 11 | import torch 12 | import socket 13 | import signal 14 | import subprocess 15 | import random 16 | 17 | 18 | logger = getLogger() 19 | 20 | 21 | def sig_handler(signum, frame): 22 | logger.warning("Signal handler called with signal " + str(signum)) 23 | prod_id = int(os.environ['SLURM_PROCID']) 24 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 25 | if prod_id == 0: 26 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 27 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 28 | else: 29 | logger.warning("Not the master process, no need to requeue.") 30 | sys.exit(-1) 31 | 32 | 33 | def term_handler(signum, frame): 34 | logger.warning("Signal handler called with signal " + str(signum)) 35 | logger.warning("Bypassing SIGTERM.") 36 | 37 | 38 | def init_signal_handler(): 39 | """ 40 | Handle signals sent by SLURM for time limit / pre-emption. 
41 | """ 42 | signal.signal(signal.SIGUSR1, sig_handler) 43 | signal.signal(signal.SIGTERM, term_handler) 44 | logger.warning("Signal handler installed.") 45 | 46 | 47 | def init_distributed_mode(params): 48 | """ 49 | Handle single and multi-GPU / multi-node / SLURM jobs. 50 | Initialize the following variables: 51 | - n_nodes 52 | - node_id 53 | - local_rank 54 | - global_rank 55 | - world_size 56 | """ 57 | #params.is_slurm_job = True if not params.debug else False 58 | params.is_slurm_job = False 59 | # SLURM job 60 | if params.is_slurm_job: 61 | 62 | #assert params.local_rank == -1 # on the cluster, this is handled by SLURM 63 | 64 | SLURM_VARIABLES = [ 65 | 'SLURM_JOB_ID', 66 | 'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS', 'SLURM_TASKS_PER_NODE', 67 | 'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU', 68 | 'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID' 69 | ] 70 | 71 | PREFIX = "%i - " % int(os.environ['SLURM_PROCID']) 72 | for name in SLURM_VARIABLES: 73 | value = os.environ.get(name, None) 74 | print(PREFIX + "%s: %s" % (name, str(value))) 75 | 76 | # # job ID 77 | # params.job_id = os.environ['SLURM_JOB_ID'] 78 | 79 | # number of nodes / node ID 80 | params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES']) 81 | params.node_id = int(os.environ['SLURM_NODEID']) 82 | 83 | # local rank on the current node / global rank 84 | params.local_rank = int(os.environ['SLURM_LOCALID']) 85 | params.global_rank = int(os.environ['SLURM_PROCID']) 86 | 87 | # number of processes / GPUs per node 88 | params.world_size = int(os.environ['SLURM_NTASKS']) 89 | params.n_gpu_per_node = params.world_size // params.n_nodes 90 | 91 | # define master address and master port 92 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 93 | params.master_addr = hostnames.split()[0].decode('utf-8') 94 | #params.master_port = 20091#random.randint(10001, 20000) 95 | #assert 10001 <= params.master_port <= 20000 or params.world_size == 1 96 | print(PREFIX + "Master address: %s" % params.master_addr) 97 | print(PREFIX + "Master port : %i" % params.master_port) 98 | 99 | # set environment variables for 'env://' 100 | os.environ['MASTER_ADDR'] = params.master_addr 101 | os.environ['MASTER_PORT'] = str(params.master_port) 102 | os.environ['WORLD_SIZE'] = str(params.world_size) 103 | os.environ['RANK'] = str(params.global_rank) 104 | 105 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 106 | elif params.local_rank != -1: 107 | 108 | assert params.master_port == -1 109 | 110 | # read environment variables 111 | params.global_rank = int(os.environ['RANK']) 112 | params.world_size = int(os.environ['WORLD_SIZE']) 113 | params.n_gpu_per_node = int(os.environ['NGPU']) 114 | 115 | # number of nodes / node ID 116 | params.n_nodes = params.world_size // params.n_gpu_per_node 117 | params.node_id = params.global_rank // params.n_gpu_per_node 118 | 119 | # local job (single GPU) 120 | else: 121 | assert params.local_rank == -1 122 | assert params.master_port == -1 123 | params.n_nodes = 1 124 | params.node_id = 0 125 | params.local_rank = 0 126 | params.global_rank = 0 127 | params.world_size = 1 128 | params.n_gpu_per_node = 1 129 | 130 | # sanity checks 131 | assert params.n_nodes >= 1 132 | assert 0 <= params.node_id < params.n_nodes 133 | assert 0 <= params.local_rank <= params.global_rank < params.world_size 134 | assert params.world_size == params.n_nodes * params.n_gpu_per_node 135 | 136 | # define whether this is the 
master process / if we are in distributed mode 137 | params.is_master = params.node_id == 0 and params.local_rank == 0 138 | params.multi_node = params.n_nodes > 1 139 | params.multi_gpu = params.world_size > 1 140 | 141 | # summary 142 | PREFIX = "%i - " % params.global_rank 143 | print(PREFIX + "Number of nodes: %i" % params.n_nodes) 144 | print(PREFIX + "Node ID : %i" % params.node_id) 145 | print(PREFIX + "Local rank : %i" % params.local_rank) 146 | print(PREFIX + "Global rank : %i" % params.global_rank) 147 | print(PREFIX + "World size : %i" % params.world_size) 148 | print(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node) 149 | print(PREFIX + "Master : %s" % str(params.is_master)) 150 | print(PREFIX + "Multi-node : %s" % str(params.multi_node)) 151 | print(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu)) 152 | print(PREFIX + "Hostname : %s" % socket.gethostname()) 153 | 154 | # set GPU device 155 | torch.cuda.set_device(params.local_rank) 156 | 157 | # initialize multi-GPU 158 | # if params.multi_gpu: 159 | 160 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 161 | # 'env://' will read these environment variables: 162 | # MASTER_PORT - required; has to be a free port on machine with rank 0 163 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 164 | # WORLD_SIZE - required; can be set either here, or in a call to init function 165 | # RANK - required; can be set either here, or in a call to init function 166 | 167 | print("Initializing PyTorch distributed ...") 168 | torch.distributed.init_process_group( 169 | init_method='env://', 170 | backend='nccl', 171 | ) 172 | 173 | 174 | -------------------------------------------------------------------------------- /tokenize_eval_data.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | import pickle 5 | import argparse 6 | from tqdm import tqdm 7 | import torch 8 | import random 9 | import pdb 10 | from nltk.tokenize import sent_tokenize 11 | torch.set_num_threads(1) 12 | random.seed(42) 13 | from transformers import RobertaTokenizer, RobertaModel, AutoModel, AutoConfig 14 | tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 15 | cls_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token) 16 | eos_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) 17 | pad_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--input-path", type=str, default=None) 22 | parser.add_argument("--output-path", type=str, default=None) 23 | parser.add_argument("--action", type=str, default=None) 24 | parser.add_argument("--chunk-size", type=int, default=128) 25 | parser.add_argument("--suffix-size", type=int, default=128) 26 | parser.add_argument("--tokenize-only", action="store_true", default=False) 27 | args = parser.parse_args() 28 | return args 29 | 30 | args = parse_args() 31 | 32 | if not args.tokenize_only: 33 | model = RobertaModel.from_pretrained('roberta-base') 34 | model = model.cuda() 35 | for param in model.parameters(): 36 | param.requires_grad = False 37 | 38 | def tokenize(args, data, chunk_size=128, single_chunk=False): 39 | ''' 40 | tokenize the data with roberta tokenizer and chunk input to chunk_size 41 | each chunk starts from a new sentence unless content from previous chunk 42 | overflows the current chunk 43 | ''' 44 | sent_tok_data = sent_tokenize(data) 45 | segment_ids = [] 46 | 
this_book_ids = [] 47 | sent_chunk_map = {} 48 | for si, sent in enumerate(sent_tok_data): 49 | sent_ids = tokenizer(sent)['input_ids'][1:-1] # get rid of the , added later when adding to this_book_ids 50 | if len(segment_ids) + len(sent_ids) < chunk_size - 1: 51 | segment_ids.extend(sent_ids) 52 | 53 | else: 54 | if len(segment_ids) != 0 and len(segment_ids) < chunk_size - 1: 55 | segment_ids += [eos_id] 56 | segment_ids.extend([pad_id] * (chunk_size - len(segment_ids) - 1)) 57 | segment_ids = [cls_id] + segment_ids 58 | assert len(segment_ids) == chunk_size, pdb.set_trace() 59 | this_book_ids.append(segment_ids) 60 | segment_ids = sent_ids 61 | 62 | if len(segment_ids) == 0 or len(segment_ids) >= chunk_size - 1: 63 | # split the sentence into multiple chunks of chunk_size 64 | this_ids = sent_ids if len(segment_ids) == 0 else segment_ids 65 | num_sent_chunks = len(this_ids) // (chunk_size - 1) + 1 66 | for chunk_id in range(num_sent_chunks): 67 | sid = chunk_id * (chunk_size - 1) 68 | eid = min((chunk_id+1) * (chunk_size - 1), len(this_ids)) 69 | if eid == (chunk_id+1) * (chunk_size - 1): 70 | this_book_ids.append([cls_id] + this_ids[sid:eid]) 71 | else: 72 | this_book_ids.append([cls_id] + this_ids[sid:eid] + [eos_id] + [pad_id] * (chunk_size - (eid-sid) - 2)) 73 | segment_ids = [] 74 | 75 | sent_chunk_map[si] = len(this_book_ids) 76 | 77 | if len(segment_ids) != 0: 78 | segment_ids += [eos_id] 79 | segment_ids.extend([pad_id] * (chunk_size - len(segment_ids) - 1)) 80 | segment_ids = [cls_id] + segment_ids 81 | assert len(segment_ids) == chunk_size, pdb.set_trace() 82 | this_book_ids.append(segment_ids) 83 | 84 | sent_chunk_map[si] = len(this_book_ids) 85 | 86 | if chunk_size != args.chunk_size: 87 | if chunk_size > args.chunk_size: 88 | for i in range(len(this_book_ids)): 89 | this_book_ids[i] = this_book_ids[i][:args.chunk_size] 90 | else: 91 | for i in range(len(this_book_ids)): 92 | this_book_ids[i] = this_book_ids[i] + [pad_id] * (args.chunk_size - len(this_book_ids[i])) 93 | assert len(this_book_ids[i]) == args.chunk_size 94 | 95 | if single_chunk: 96 | if len(this_book_ids) != 1: 97 | try: 98 | this_book_ids = [this_book_ids[0]] 99 | except: 100 | pdb.set_trace() 101 | else: 102 | assert all(len(x) == chunk_size for x in this_book_ids), pdb.set_trace() 103 | 104 | return torch.tensor(this_book_ids) 105 | 106 | 107 | def encode(input_ids, batch_size=64): 108 | ''' 109 | input_ids: #chunks x chunk_size(256) 110 | return 111 | #chunks x #hidden_dim 112 | ''' 113 | # create attention mask 114 | input_ids = input_ids.cuda() 115 | attention_mask = (input_ids != pad_id).long().cuda() 116 | sent_vecs = [] 117 | for batch_id in range(input_ids.shape[0] // batch_size + 1): 118 | batch_sid = batch_id * batch_size 119 | batch_eid = min((batch_id+1) * batch_size, input_ids.shape[0]) 120 | try: 121 | outputs = model(input_ids=input_ids[batch_sid:batch_eid], 122 | attention_mask=attention_mask[batch_sid:batch_eid], 123 | output_attentions=False, 124 | output_hidden_states=False) 125 | except: 126 | pdb.set_trace() 127 | this_batch_sent_vecs = outputs.last_hidden_state[:,0,:].detach() # batch_size x hidden_dim 128 | sent_vecs.append(this_batch_sent_vecs.cpu()) 129 | 130 | sent_vecs = torch.cat(sent_vecs) # #chunks x #hidden_dim 131 | return sent_vecs 132 | 133 | def tokenize_encode(split_name, args, data): 134 | all_data = {} 135 | for book_id in tqdm(data): 136 | all_data[book_id] = [] 137 | for example in data[book_id]: 138 | ctx = example['ctx'] 139 | if args.suffix_size != args.chunk_size: 
140 | pos = tokenizer.decode(tokenizer.encode(example['pos'])[1:-1][:args.suffix_size]) 141 | negs = [tokenizer.decode(tokenizer.encode(neg)[1:-1][:args.suffix_size]) \ 142 | for neg in example['negs']] 143 | else: 144 | pos = example['pos'] 145 | negs = example['negs'] 146 | 147 | ctx_ids = tokenize(args, ctx, chunk_size=args.chunk_size) 148 | pos_ids = tokenize(args, pos, chunk_size=args.suffix_size, single_chunk=True) 149 | negs_ids = [tokenize(args, neg, chunk_size=args.suffix_size, single_chunk=True) for neg in negs] 150 | all_data[book_id].append( 151 | { 152 | 'ctx_vec': encode(ctx_ids) if not args.tokenize_only else ctx_ids, 153 | 'pos_vec': encode(pos_ids) if not args.tokenize_only else pos_ids, 154 | 'negs_vec': [encode(neg_ids) for neg_ids in negs_ids] if not args.tokenize_only else negs_ids 155 | } 156 | ) 157 | 158 | ctx_size = os.path.splitext(os.path.basename(args.input_path))[0].split('_')[-1] 159 | print(f"ctx_size: {ctx_size}") 160 | with open(os.path.join(args.output_path, f'{split_name}_ctx{ctx_size}.pkl'), "wb") as f: 161 | pickle.dump(all_data, f, protocol=pickle.HIGHEST_PROTOCOL) 162 | 163 | def encode_file(args, tokenize_only=False): 164 | 165 | try: 166 | with open(os.path.join(args.input_path), "r") as f: 167 | data = f.read() 168 | data = json.loads(data) 169 | except: 170 | raise FileNotFoundError(f"invalid input file: {args.input_path}") 171 | 172 | tokenize_encode('ao3', args, data['ao3']) 173 | tokenize_encode('pg19', args, data['pg19']) 174 | 175 | def main(): 176 | # if tokenize_only, only run tokenizer, else extract [cls] vector from roberta-base output 177 | encode_file(args, tokenize_only=args.tokenize_only) 178 | 179 | if __name__ == "__main__": 180 | main() 181 | 182 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Various utilities 3 | ''' 4 | import contextlib 5 | import io 6 | import random 7 | import sys 8 | import threading 9 | from itertools import tee, zip_longest 10 | from subprocess import check_output, CalledProcessError 11 | from data.utils import * 12 | import os  # used explicitly by prepare_data below 13 | import numpy as np 14 | import torch 15 | from tqdm import tqdm 16 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler 17 | 18 | 19 | 20 | INF = float('inf') 21 | NEG_INF = float('-inf') 22 | CHUNK_SIZES = [32, 64, 128, 256, 512] 23 | 24 | def prepare_data(args): 25 | 26 | dl_lst, vdl_lst = [], [] 27 | 28 | def construct_dl(Dataset, chunk_size=128, fileids=None): 29 | kwargs = {"chunk_size": chunk_size} 30 | kwargs['max_books'] = 500 if args.debug else args.max_books 31 | if fileids is not None: 32 | kwargs['fileids'] = fileids 33 | # build dataset and dataloader 34 | ds = Dataset(args, 35 | args.data_path, 36 | args.max_tokens_per_batch // chunk_size, 37 | args.split, 38 | **kwargs) 39 | vds = Dataset(args, args.data_path, 40 | args.max_tokens_per_batch // chunk_size, 41 | "valid-tok", 42 | **kwargs) 43 | print(f"num_chunks {args.max_tokens_per_batch // chunk_size} chunk_size {chunk_size}") 44 | 45 | # using SequentialSampler to get most in-book negatives 46 | sampler = SequentialSampler(ds) 47 | 48 | dl = DataLoader(ds, sampler=sampler, batch_size=args.batch_size) 49 | vdl = DataLoader(vds, sampler=SequentialSampler(vds), batch_size=args.batch_size) 50 | return dl, vdl 51 | 52 | CHUNK_SIZES = args.chunk_size_list 53 | 54 | file_lst = [fn for fn in os.listdir(os.path.join(args.data_path, args.split)) if 'data' not in fn] 55 | file_lst += [fn for fn in 
os.listdir(os.path.join(args.data_path, args.split.replace('train', 'valid'))) if 'data' not in fn] 56 | if args.debug: 57 | file_lst = file_lst[:10] 58 | random.shuffle(file_lst) 59 | shard_size = len(file_lst) // len(CHUNK_SIZES) 60 | for i, chunk_size in enumerate(CHUNK_SIZES): 61 | files = file_lst[i*shard_size:(i+1)*shard_size] 62 | if len(files) == 0: 63 | continue 64 | print(f"constructing dataloader for chunk_size {chunk_size}") 65 | dl, vdl = construct_dl(TextDataset, chunk_size=chunk_size, 66 | fileids=files) 67 | dl_lst.append(dl) 68 | vdl_lst.append(vdl) 69 | 70 | return dl_lst, vdl_lst 71 | 72 | 73 | # pylint:disable=line-too-long 74 | def ceildiv(x, y): 75 | ''' https://stackoverflow.com/questions/14822184/is-there-a-ceiling-equivalent-of-operator-in-python#17511341 ''' 76 | return -(-x // y) 77 | # pylint:enable=line-too-long 78 | 79 | 80 | def pairwise(iterable, longest=False): 81 | ''' 82 | See itertools recipes: 83 | https://docs.python.org/3/library/itertools.html#itertools-recipes 84 | s -> (s0,s1), (s1,s2), (s2, s3), ... 85 | ''' 86 | x, y = tee(iterable) 87 | next(y, None) 88 | zip_func = zip_longest if longest else zip 89 | return zip_func(x, y) 90 | 91 | 92 | def grouper(iterable, n, fillvalue=None, padded=False): # pylint:disable=invalid-name 93 | ''' 94 | See itertools recipes: 95 | https://docs.python.org/3/library/itertools.html#itertools-recipes 96 | Collect data into fixed-length chunks or blocks 97 | ''' 98 | args = [iter(iterable)] * n 99 | groups = zip_longest(*args, fillvalue=fillvalue) 100 | if padded: 101 | # keep the fill value 102 | return groups 103 | else: 104 | # ignore the fill value 105 | return [[x for x in group if x is not fillvalue] for group in groups] 106 | 107 | 108 | def partition(seq, num): 109 | ''' Partition a sequence into num equal parts (potentially except for the last slice) ''' 110 | return [seq[i:i + num] for i in range(0, len(seq), num)] 111 | 112 | 113 | def divvy(num, chunks): 114 | ''' Divvy a number into an array of equal sized chunks ''' 115 | chunk_mod = (num % chunks) 116 | chunk_size = num // chunks 117 | return [chunk_size + 1] * chunk_mod + [chunk_size] * (chunks - chunk_mod) 118 | 119 | 120 | def triu(inputs, diagonal=0, span=1, stride=1, offset=0): 121 | ''' 122 | Returns an upper triangular matrix, but allows span which determines how many contiguous 123 | elements of the matrix to consider as a "single" number, e.g. 124 | 125 | >> triu(torch.full((8, 8), float('-inf'), 1, 2) 126 | tensor([[0.0000, 0.0000, -inf, -inf, -inf, -inf, -inf, -inf], 127 | [0.0000, 0.0000, -inf, -inf, -inf, -inf, -inf, -inf], 128 | [0.0000, 0.0000, 0.0000, 0.0000, -inf, -inf, -inf, -inf], 129 | [0.0000, 0.0000, 0.0000, 0.0000, -inf, -inf, -inf, -inf], 130 | [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -inf, -inf], 131 | [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -inf, -inf], 132 | [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], 133 | [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]) 134 | ''' 135 | for i, row in enumerate(inputs): 136 | row[:span * (diagonal + i // stride) + offset] = 0. 
137 | 138 | return inputs 139 | 140 | 141 | def get_version_string(): 142 | ''' Return a git version string for the repo ''' 143 | try: 144 | version = check_output(['git', 'describe', '--always', '--dirty'], encoding='utf-8') 145 | except CalledProcessError: 146 | raise RuntimeError('Call to "git describe" failed!') 147 | 148 | return version 149 | 150 | 151 | def to_numpy_dtype(dtype): 152 | ''' Convert a torch dtype to a numpy dtype ''' 153 | return np.dtype(dtype.__reduce__().replace('torch.', '')) 154 | 155 | 156 | def left_pad(x, dim=-1, count=1, fill=0): 157 | ''' left pad the given tensor ''' 158 | if not count: 159 | return x 160 | 161 | shape = list(x.shape) 162 | dims = len(shape) 163 | dim = dim % dims 164 | fill_shape = shape[:dim] + [count] + shape[dim + 1:] 165 | return torch.cat((x.new_full(fill_shape, fill), x), dim) 166 | 167 | 168 | def right_pad(x, dim=-1, count=1, fill=0): 169 | ''' right pad the given tensor ''' 170 | if not count: 171 | return x 172 | 173 | shape = list(x.shape) 174 | dims = len(shape) 175 | dim = dim % dims 176 | fill_shape = shape[:dim] + [count] + shape[dim + 1:] 177 | return torch.cat((x, x.new_full(fill_shape, fill)), dim) 178 | 179 | 180 | def left_shift(x, dim=-1, shift=1, fill=None): 181 | ''' left shift the given tensor ''' 182 | if not shift: 183 | return x 184 | 185 | if fill is not None: 186 | x = right_pad(x, dim, shift, fill) 187 | 188 | shape = list(x.shape) 189 | dims = len(shape) 190 | dim = dim % dims 191 | return x[tuple(slice(shift if d == dim else 0, s + shift) for d, s in enumerate(shape))] 192 | 193 | 194 | def right_shift(x, dim=-1, shift=1, fill=None): 195 | ''' Right shift the given tensor ''' 196 | if not shift: 197 | return x 198 | 199 | if fill is not None: 200 | x = left_pad(x, dim, shift, fill) 201 | 202 | shape = list(x.shape) 203 | dims = len(shape) 204 | dim = dim % dims 205 | return x[tuple(slice(-shift if d == dim else s) for d, s in enumerate(shape))] 206 | 207 | 208 | def same_tensor(tensor, *args): 209 | ''' Do the input tensors all point to the same underlying data ''' 210 | for other in args: 211 | if not torch.is_tensor(other): 212 | return False 213 | 214 | if tensor.device != other.device: 215 | return False 216 | 217 | if tensor.dtype != other.dtype: 218 | return False 219 | 220 | if tensor.data_ptr() != other.data_ptr(): 221 | return False 222 | 223 | return True 224 | 225 | 226 | class TQDMStreamWrapper(io.IOBase): 227 | ''' A wrapper around an existing IO stream to funnel to tqdm ''' 228 | def __init__(self, stream): 229 | ''' Initialize the stream wrapper ''' 230 | super(TQDMStreamWrapper, self).__init__() 231 | self.stream = stream 232 | 233 | def write(self, line): 234 | ''' Potentially write to the stream ''' 235 | if line.rstrip(): # avoid printing empty lines (only whitespace) 236 | tqdm.write(line, file=self.stream) 237 | 238 | 239 | _STREAMS = threading.local() 240 | _STREAMS.stdout_stack = [] 241 | @contextlib.contextmanager 242 | def tqdm_wrap_stdout(): 243 | ''' Wrap a sys.stdout and funnel it to tqdm.write ''' 244 | _STREAMS.stdout_stack.append(sys.stdout) 245 | sys.stdout = TQDMStreamWrapper(sys.stdout) 246 | yield 247 | sys.stdout = _STREAMS.stdout_stack.pop() 248 | 249 | 250 | @contextlib.contextmanager 251 | def tqdm_unwrap_stdout(): 252 | ''' Unwrap a tqdm.write and funnel it to sys.stdout ''' 253 | saved = sys.stdout 254 | sys.stdout = _STREAMS.stdout_stack.pop() 255 | yield 256 | _STREAMS.stdout_stack.append(sys.stdout) 257 | sys.stdout = saved 258 | 259 | 260 | # Recursively split or 
chunk the given data structure. split_or_chunk is based on 261 | # torch.nn.parallel.scatter_gather.scatter 262 | def split_or_chunk(inputs, num_chunks_or_sections, dim=0): 263 | r""" 264 | Splits tensors into approximately equal chunks or specified chunk sizes (based on the 265 | 'num_chunks_or_sections'). Duplicates references to objects that are not tensors. 266 | """ 267 | def split_map(obj): 268 | if isinstance(obj, torch.Tensor): 269 | if isinstance(num_chunks_or_sections, int): 270 | return torch.chunk(obj, num_chunks_or_sections, dim=dim) 271 | else: 272 | return torch.split(obj, num_chunks_or_sections, dim=dim) 273 | if isinstance(obj, tuple) and obj: 274 | return list(zip(*map(split_map, obj))) 275 | if isinstance(obj, list) and obj: 276 | return list(map(list, zip(*map(split_map, obj)))) 277 | if isinstance(obj, dict) and obj: 278 | return list(map(type(obj), zip(*map(split_map, obj.items())))) 279 | if isinstance(num_chunks_or_sections, int): 280 | return [obj for chunk in range(num_chunks_or_sections)] 281 | else: 282 | return [obj for chunk in num_chunks_or_sections] 283 | 284 | # After split_map is called, a split_map cell will exist. This cell 285 | # has a reference to the actual function split_map, which has references 286 | # to a closure that has a reference to the split_map cell (because the 287 | # fn is recursive). To avoid this reference cycle, we set the function to 288 | # None, clearing the cell 289 | try: 290 | return split_map(inputs) 291 | finally: 292 | split_map = None 293 | 294 | 295 | # Recursively split or chunk the given data structure. split_or_chunk is based on 296 | # torch.nn.parallel.scatter_gather.gather 297 | def cat(outputs, dim=0): 298 | r""" 299 | Concatenates tensors recursively in collections. 300 | """ 301 | def cat_map(outputs): 302 | out = outputs[0] 303 | if isinstance(out, torch.Tensor): 304 | return torch.cat(outputs, dim=dim) 305 | if out is None: 306 | return None 307 | if isinstance(out, dict): 308 | if not all((len(out) == len(d) for d in outputs)): 309 | raise ValueError('All dicts must have the same number of keys') 310 | return type(out)(((k, cat_map([d[k] for d in outputs])) 311 | for k in out)) 312 | return type(out)(map(cat_map, zip(*outputs))) 313 | 314 | # Recursive function calls like this create reference cycles. 315 | # Setting the function to None clears the refcycle. 316 | try: 317 | return cat_map(outputs) 318 | finally: 319 | cat_map = None 320 | 321 | 322 | def get_random_seed_fn(seed, cuda=True): 323 | ''' Return a function that sets a random seed ''' 324 | def set_random_seed(worker_id=0): # pylint:disable=unused-argument 325 | random.seed(seed) 326 | np.random.seed(seed) 327 | torch.manual_seed(seed) 328 | if cuda and torch.cuda.is_available(): 329 | torch.cuda.manual_seed(seed) 330 | 331 | return set_random_seed 332 | 333 | 334 | 335 | 336 | 337 | 338 | --------------------------------------------------------------------------------
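# --- illustrative sketch (not part of the repository): how the additive causal mask used by
# TransformerLayer.mask in models/transformer.py is built from utils.triu. The toy sequence
# length below is arbitrary.
import torch
from utils import triu

T = 4
mask = torch.full((T, T), float('-inf'))
mask = triu(mask, diagonal=1, span=1, stride=1)  # zero the diagonal and lower triangle, keep -inf above it
# mask[i, j] stays -inf only for j > i, so adding it to the attention logits hides future positions:
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])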