├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── datasets ├── arxiv-11 │ └── arxiv-11.py └── hyperpartisan_news_detection │ ├── dev.jsonl │ ├── test.jsonl │ └── train.jsonl ├── examples └── LRA │ ├── README.md │ ├── code │ ├── attention.py │ ├── attention_fnet.py │ ├── attention_linear.py │ ├── attention_linformer.py │ ├── attention_nystrom.py │ ├── attention_performer.py │ ├── attention_ponet.py │ ├── attention_reformer.py │ ├── dataset.py │ ├── lra_config.py │ ├── model.py │ ├── model_wrapper.py │ ├── run_tasks.py │ └── run_tasks.sh │ └── datasets │ ├── cifar10.py │ ├── create_datasets.sh │ ├── delete_repeat.sh │ ├── listops.py │ ├── pathfinder.py │ ├── requirements.txt │ ├── retrieval.py │ └── text.py ├── extra ├── classifier_trainer.py └── dataset_dict.py ├── image ├── consumption.png ├── model.png └── performance.png ├── metrics ├── macro_f1_and_acc │ └── macro_f1_and_acc.py └── micro_f1_and_acc │ └── micro_f1_and_acc.py ├── requirements.txt ├── run_glue.py ├── run_long_classification.py ├── run_pretrained.py └── run_shell ├── 1-pretrain_bookcorpus_wikipedia.sh ├── 1-pretrain_bookcorpus_wikitext.sh ├── 2-GLUE.sh ├── 3-LongTask.sh └── D1-arxiv11.sh /.gitattributes: -------------------------------------------------------------------------------- 1 | outputs/ponet-base-uncased/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # General 2 | .DS_Store 3 | .AppleDouble 4 | .LSOverride 5 | 6 | # Icon must end with two \r 7 | Icon 8 | 9 | 10 | # Thumbnails 11 | ._* 12 | 13 | # Files that might appear in the root of a volume 14 | .DocumentRevisions-V100 15 | .fseventsd 16 | .Spotlight-V100 17 | .TemporaryItems 18 | .Trashes 19 | .VolumeIcon.icns 20 | .com.apple.timemachine.donotpresent 21 | 22 | # Directories potentially created on remote AFP share 23 | .AppleDB 24 | .AppleDesktop 25 | Network Trash Folder 26 | Temporary Items 27 | .apdisk 28 | 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | share/python-wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | MANIFEST 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .nox/ 71 | .coverage 72 | .coverage.* 73 | .cache 74 | nosetests.xml 75 | coverage.xml 76 | *.cover 77 | *.py,cover 78 | .hypothesis/ 79 | .pytest_cache/ 80 | cover/ 81 | 82 | # Translations 83 | *.mo 84 | *.pot 85 | 86 | # Django stuff: 87 | *.log 88 | local_settings.py 89 | db.sqlite3 90 | db.sqlite3-journal 91 | 92 | # Flask stuff: 93 | instance/ 94 | .webassets-cache 95 | 96 | # Scrapy stuff: 97 | .scrapy 98 | 99 | # Sphinx documentation 100 | docs/_build/ 101 | 102 | # PyBuilder 103 | .pybuilder/ 104 | target/ 105 | 106 | # Jupyter Notebook 107 | .ipynb_checkpoints 108 | 109 | # IPython 110 | profile_default/ 111 | ipython_config.py 112 | 113 | # pyenv 114 | # For a library or package, you might want to ignore these files since the code is 115 | # intended to run in multiple environments; otherwise, check them in: 116 | # .python-version 117 | 118 | # pipenv 119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 122 | # install all needed dependencies. 123 | #Pipfile.lock 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | .vscode/* 169 | !.vscode/settings.json 170 | !.vscode/tasks.json 171 | !.vscode/launch.json 172 | !.vscode/extensions.json 173 | *.code-workspace 174 | 175 | # Local History for Visual Studio Code 176 | .history/ 177 | 178 | datasets/arxiv-11/data 179 | outputs/* 180 | logs/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | > This is a code repository for our paper ***[PoNet: Pooling Network for Efficient Token Mixing in Long Sequences](https://arxiv.org/abs/2110.02442)***. The full source code has been released. 4 | 5 | Transformer-based models have achieved great success in various NLP, vision, and speech tasks. However, the core of Transformer, the self-attention mechanism, has a quadratic time and memory complexity with respect to the sequence length, which hinders applications of Transformer-based models to long sequences. Many approaches have been proposed to mitigate this problem, such as sparse attention mechanisms, low-rank matrix approximations and scalable kernels, and token mixing alternatives to self-attention. We propose a novel Pooling Network (PoNet) for token mixing in long sequences with linear complexity. We design multi-granularity pooling and pooling fusion to capture different levels of contextual information and combine their interactions with tokens. On the Long Range Arena benchmark, PoNet significantly outperforms Transformer and achieves competitive accuracy, while being only slightly slower than the fastest model, FNet, across all sequence lengths measured on GPUs. We also conduct systematic studies on the transfer learning capability of PoNet and observe that PoNet achieves 95.7% of the accuracy of BERT on the GLUE benchmark, outperforming FNet by 4.5% relative. Comprehensive ablation analysis demonstrates effectiveness of the designed multi-granularity pooling and pooling fusion for token mixing in long sequences and efficacy of the designed pre-training tasks for PoNet to learn transferable contextualized language representations. 6 | 7 |
8 | *(Figures: `image/model.png`, `image/consumption.png`, `image/performance.png`)* 9 | 10 | 11 |
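As a minimal usage sketch (this assumes the `chtan/ponet-base-uncased` checkpoint listed in the changelog below, and assumes `trust_remote_code=True` is needed because the modeling code is hosted on the Hub rather than in the Transformers core), PoNet can be loaded through the Transformers `AutoModel` API:

```python
# Hedged sketch, not part of the original README: load the pretrained PoNet encoder
# from the Hugging Face Hub (checkpoint name taken from the changelog below).
from transformers import AutoTokenizer, AutoModel

name = "chtan/ponet-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModel.from_pretrained(name, trust_remote_code=True)

inputs = tokenizer("PoNet mixes tokens with multi-granularity pooling.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # expected: (batch, seq_len, hidden_size)
```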
12 | 13 | ## Instructions 14 | 15 | ##### Python environment 16 | 17 | The required packages are listed in `requirements.txt`. 18 | 19 | If you are using an NVIDIA GPU with CUDA 11.7 support, you can create the required virtual Python environment as follows: 20 | 21 | ```shell 22 | conda create -n ponet python=3.8 23 | conda activate ponet 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | ##### Preparing the Arxiv-11 dataset 28 | 29 | The data can be obtained from `https://github.com/LiqunW/Long-document-dataset`. 30 | 31 | We also provide a script to download and prepare it; see `run_shell/D1-arxiv11.sh`. 32 | 33 | ##### Run 34 | 35 | For the pre-training, GLUE, and long-text tasks, please refer to the shell files under the `run_shell` folder. 36 | 37 | For LRA, please refer to `examples/LRA/README.md`. 38 | 39 | ## Changelog 40 | 41 | - [x] [2023.05.22] 42 | - The source code has been published to the ***[Hugging Face Hub](https://huggingface.co/chtan/ponet-base-uncased)***. PoNet can now be used directly through the Transformers library. 43 | - Since recent versions of PyTorch support `scatter_max` operations natively, we removed the third-party `pytorch-scatter` package and use the official functions instead. 44 | - The old code has been moved to ***[tag v1.0](https://github.com/lxchtan/PoNet/tree/v1.0)***. 45 | 46 | - [x] [2022.07.20] Added a brief introduction to the paper in the README. 47 | 48 | - [x] [2022.07.09] The pretrained checkpoint was moved to Google Drive. 49 | 50 | - [x] [2022.07.09] Released the source code 51 | - [x] [2021.10.19] Pre-train Tasks 52 | - [x] [2021.10.19] GLUE Tasks 53 | - [x] [2022.03.15] LRA Tasks 54 | - [x] [2022.07.09] Long-Text Tasks 55 | - [x] [2021.10.19] Released the pretrained checkpoints 56 | 57 | ## Cite 58 | 59 | ```bibtex 60 | @inproceedings{DBLP:conf/iclr/TanCWZZL22, 61 | author = {Chao{-}Hong Tan and 62 | Qian Chen and 63 | Wen Wang and 64 | Qinglin Zhang and 65 | Siqi Zheng and 66 | Zhen{-}Hua Ling}, 67 | title = {PoNet: Pooling Network for Efficient Token Mixing in Long Sequences}, 68 | booktitle = {The Tenth International Conference on Learning Representations, {ICLR} 69 | 2022, Virtual Event, April 25-29, 2022}, 70 | publisher = {OpenReview.net}, 71 | year = {2022}, 72 | url = {https://openreview.net/forum?id=9jInD9JjicF}, 73 | } 74 | ``` -------------------------------------------------------------------------------- /datasets/arxiv-11/arxiv-11.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
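# Usage note (added; a sketch, not part of the original loading script): after the raw
# Arxiv-11 data has been downloaded into `datasets/arxiv-11/data` (see
# `run_shell/D1-arxiv11.sh`), this script can be loaded with the Hugging Face
# `datasets` library, e.g. `datasets.load_dataset("datasets/arxiv-11/arxiv-11.py")`.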
15 | 16 | # Lint as: python3 17 | import os 18 | import pickle 19 | import datasets 20 | from datasets.tasks import TextClassification 21 | 22 | _DESCRIPTION = """\ 23 | Arxiv-11 24 | """ 25 | 26 | _DOWNLOAD_URL = "datasets/arxiv-11/data" 27 | 28 | class Arxiv11Config(datasets.BuilderConfig): 29 | def __init__(self, **kwargs): 30 | super().__init__(version=datasets.Version("1.0.0", ""), **kwargs) 31 | 32 | class Arxiv11(datasets.GeneratorBasedBuilder): 33 | BUILDER_CONFIGS = [ 34 | Arxiv11Config( 35 | name="plain_text", 36 | description="Plain text", 37 | ) 38 | ] 39 | 40 | def _info(self): 41 | return datasets.DatasetInfo( 42 | description=_DESCRIPTION, 43 | features=datasets.Features( 44 | { 45 | "text": datasets.Value("string"), 46 | "label": datasets.features.ClassLabel( 47 | names=['cs.AI', 'cs.NE', 'cs.cv', 'cs.CE', 'math.ST', 'cs.SY', 'cs.PL', 'cs.DS', 'cs.IT', 'math.GR', 'math.AC'] 48 | ) 49 | } 50 | ), 51 | supervised_keys=None, 52 | homepage="https://github.com/LiqunW/Long-document-dataset", 53 | task_templates=[TextClassification(text_column="text", label_column="label")], 54 | ) 55 | 56 | def _split_generators(self, dl_manager): 57 | data_dir = _DOWNLOAD_URL 58 | with open(os.path.join(data_dir, 'Dataset.txt'), 'rb') as Dataset_file, open(os.path.join(data_dir, 'Labels_file.txt'), 'rb') as Labels_file: 59 | self.Dataset = pickle.load(Dataset_file) 60 | self.Labels = pickle.load(Labels_file) 61 | 62 | self.nTotal = len(self.Dataset) 63 | self.nTrain = int(self.nTotal*0.8) 64 | self.trainDataset = self.Dataset[0: self.nTrain] 65 | self.trainLabels = self.Labels[0: self.nTrain] 66 | 67 | self.nVal = int(self.nTotal*0.1) 68 | self.valDataset = self.Dataset[self.nTrain: self.nTrain+self.nVal] 69 | self.valLabels = self.Labels[self.nTrain: self.nTrain+self.nVal] 70 | 71 | self.nTest = self.nTotal - self.nTrain - self.nVal 72 | self.testDataset = self.Dataset[self.nTrain+self.nVal: self.nTotal] 73 | self.testLabels = self.Labels[self.nTrain + self.nVal: self.nTotal] 74 | 75 | return [ 76 | datasets.SplitGenerator( 77 | name=datasets.Split.TRAIN, gen_kwargs={"file_list": [self.trainDataset, self.trainLabels]} 78 | ), 79 | datasets.SplitGenerator( 80 | name=datasets.Split.VALIDATION, gen_kwargs={"file_list": [self.valDataset, self.valLabels]} 81 | ), 82 | datasets.SplitGenerator( 83 | name=datasets.Split.TEST, gen_kwargs={"file_list": [self.testDataset, self.testLabels]} 84 | ), 85 | ] 86 | 87 | def _generate_examples(self, file_list): 88 | """Generate arxiv-11 examples.""" 89 | for id_, (d, l) in enumerate(zip(*file_list)): 90 | with open(os.path.join(os.path.sep.join(_DOWNLOAD_URL.split(os.path.sep)[:-1]), d), encoding="UTF-8") as f: 91 | yield str(id_), {"text": f.read(), "label": l-1} -------------------------------------------------------------------------------- /examples/LRA/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## LRA Benchmark 3 | 4 | We released the source code for LRA benchmark based on [Nystromformer](https://github.com/mlpen/Nystromformer). 
5 | 6 | To prepare the datasets, the following packages are needed: 7 | ``` 8 | tensorboard>=2.3.0, tensorflow>=2.3.1, tensorflow-datasets>=4.0.1, tensorflow_text 9 | ``` 10 | 11 | Then download the source code from the [LRA repo](https://github.com/google-research/long-range-arena) and place the `long-range-arena` folder in `LRA/datasets/`, and also download [lra_release.gz](https://storage.googleapis.com/long-range-arena/lra_release.gz) released by the LRA repo and place the unzipped folder in `LRA/datasets/`. The directory structure should be 12 | ``` 13 | LRA/datasets/long-range-arena 14 | LRA/datasets/lra_release 15 | ``` 16 | You may need to delete the code `train_dataset = train_dataset.repeat()` in `./long-range-arena/lra_benchmarks/image/input_pipeline.py`. 17 | 18 | > You can simply run `delete_repeat.sh` to comment out the code above. 19 | 20 | Then, run `sh create_datasets.sh` and it will create train, dev, and test dataset pickle files for each task. 21 | 22 | To run the LRA tasks, the following packages are needed: 23 | 24 | ``` 25 | pytorch==2.0.1, transformers==3.3.1, performer-pytorch 26 | ``` 27 | To run an LRA experiment, `run_tasks.sh` will be helpful: 28 | 29 | ```shell 30 | bash run_tasks.sh ponet 31 | ``` 32 | 33 | Or you can run the following command in the `code` folder: 34 | 35 | ``` 36 | python3 run_tasks.py --model <model> --task <task> --output <output_dir> 37 | ``` 38 | where `<model>` can be set to `softmax, nystrom-64, reformer-2, performer-256` (standard self-attention, Nystromformer with 64 landmarks, Reformer with 2 LSH rounds, and Performer with a 256-dimensional random projection, respectively), as well as `ponet`, `fnet`, `linformer-256`, and `linear` (see `lra_config.py`); `<task>` can be set to `listops, text, retrieval, image, pathfinder32-curv_contour_length_14`; and `<output_dir>` specifies where to save checkpoints and logs. The best models and log files will be saved to the `LRA/logs/` folder. 39 | 40 | -------------------------------------------------------------------------------- /examples/LRA/code/attention.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import json 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class SoftmaxAttention(nn.Module): 10 | def __init__(self, config): 11 | super().__init__() 12 | self.drop_attn = torch.nn.Dropout(p=config["attention_dropout"]) 13 | self.head_dim = config["head_dim"] 14 | 15 | def forward(self, Q, K, V, mask): 16 | dot = torch.matmul(Q, torch.transpose(K, -2, -1)) 17 | dot = dot / math.sqrt(self.head_dim) 18 | dot = dot - 1e6 * (1 - mask[:, None, None, :]) 19 | 20 | attn = nn.functional.softmax(dot, dim=-1) 21 | attn = self.drop_attn(attn) 22 | 23 | X = torch.matmul(attn, V) 24 | return X 25 | 26 | 27 | class NoneAttention(nn.Module): 28 | def __init__(self, config): 29 | super().__init__() 30 | 31 | def forward(self, Q, K, V, mask): 32 | return V 33 | 34 | 35 | class Attention(nn.Module): 36 | def __init__(self, config): 37 | super().__init__() 38 | 39 | self.grad_checkpointing = config["attention_grad_checkpointing"] 40 | 41 | self.dim = config["transformer_dim"] 42 | self.head_dim = config["head_dim"] 43 | self.num_head = config["num_head"] 44 | 45 | self.attn_type = config["attn_type"] 46 | 47 | self.W_q = nn.Linear(self.dim, self.num_head * self.head_dim) 48 | self.W_k = nn.Linear(self.dim, self.num_head * self.head_dim) 49 | self.W_v = nn.Linear(self.dim, self.num_head * self.head_dim) 50 | 51 | if self.attn_type == "softmax": 52 | self.attn = SoftmaxAttention(config) 53 | elif self.attn_type == "none": 54 | self.attn = NoneAttention(config) 55 | elif self.attn_type.startswith("linformer"):
56 | from attention_linformer import LinformerAttention 57 | self.attn = LinformerAttention(config) 58 | elif self.attn_type.startswith("reformer"): 59 | from attention_reformer import LSHAttention 60 | self.attn = LSHAttention(config, self.W_q, self.W_k, self.W_v) 61 | elif self.attn_type.startswith("nystrom"): 62 | from attention_nystrom import NystromAttention 63 | self.attn = NystromAttention(config) 64 | elif self.attn_type.startswith("fnet"): 65 | from attention_fnet import FFTAttention 66 | self.attn = FFTAttention(config) 67 | elif self.attn_type.startswith("ponet"): 68 | from attention_ponet import PoNetAttention 69 | self.attn = PoNetAttention(config) 70 | self.W_local = nn.Linear(self.dim, self.dim) 71 | self.W_segment = nn.Linear(self.dim, self.dim) 72 | elif self.attn_type.startswith("performer"): 73 | from attention_performer import PerformerAttention 74 | self.attn = PerformerAttention(config) 75 | elif self.attn_type.startswith("linear"): 76 | from attention_linear import LinearAttention 77 | self.attn = LinearAttention(config) 78 | 79 | self.ff = nn.Linear(self.num_head * self.head_dim, self.dim) 80 | self.sp_ff = nn.Identity() 81 | 82 | def forward(self, X, mask): 83 | 84 | if self.attn_type.startswith("longformer") or self.attn_type.startswith("reformer") or self.attn_type.startswith("fnet"): 85 | with torch.cuda.amp.autocast(enabled=False): 86 | attn_out = self.attn(X.float(), mask.float()) 87 | elif self.attn_type.startswith("ponet"): 88 | Q = self.split_heads(self.W_q(X)) 89 | K = self.split_heads(self.W_k(X)) 90 | O = self.split_heads(self.W_v(X)) 91 | 92 | local = self.W_local(X) 93 | segment = self.W_segment(X) 94 | with torch.cuda.amp.autocast(enabled=False): 95 | attn_out = self.attn(X.float(), Q.float(), K.float(), O.float(), local.float(), segment.float(), mask.float()) 96 | else: 97 | Q = self.split_heads(self.W_q(X)) 98 | K = self.split_heads(self.W_k(X)) 99 | V = self.split_heads(self.W_v(X)) 100 | with torch.cuda.amp.autocast(enabled=False): 101 | if self.grad_checkpointing: 102 | attn_out = checkpoint(self.attn, Q.float(), K.float(), V.float(), mask.float()) 103 | else: 104 | attn_out = self.attn(Q.float(), K.float(), V.float(), mask.float()) 105 | attn_out = self.combine_heads(attn_out) 106 | 107 | if self.attn_type.startswith("fnet"): 108 | self.ff = self.sp_ff 109 | out = self.ff(attn_out) 110 | 111 | return out 112 | 113 | def combine_heads(self, X): 114 | X = X.transpose(1, 2) 115 | X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim) 116 | return X 117 | 118 | def split_heads(self, X): 119 | X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim) 120 | X = X.transpose(1, 2) 121 | return X 122 | -------------------------------------------------------------------------------- /examples/LRA/code/attention_fnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | class FFTAttention(nn.Module): 7 | 8 | def __init__(self, config): 9 | super().__init__() 10 | 11 | def forward(self, X, mask): 12 | X = torch.fft.fft(torch.fft.fft(X, dim=-1), dim=-2).real 13 | 14 | return X 15 | -------------------------------------------------------------------------------- /examples/LRA/code/attention_linear.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class LinearAttention(nn.Module): 8 | 9 | def __init__(self, config): 10 | super().__init__() 11 | 
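    # Descriptive note (added): this module implements kernelized linear attention.
    # A positive feature map phi(x) = elu(x) + 1 is applied to Q and K, padded positions
    # are zeroed out via `mask`, and the output is computed as phi(Q) @ (phi(K)^T @ V).
    # Forming the small phi(K)^T @ V matrix first makes the cost linear in the sequence
    # length; note that this variant omits the usual row-wise normalization term.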
12 | def forward(self, Q, K, V, mask): 13 | Q = (nn.functional.elu(Q) + 1) / math.sqrt(math.sqrt(Q.size(2))) 14 | K = (nn.functional.elu(K) + 1) * mask[:, None, :, None] / math.sqrt(math.sqrt(K.size(2))) 15 | V = V * mask[:, None, :, None] 16 | 17 | X = torch.matmul(Q, torch.matmul(torch.transpose(K, -2, -1), V)) 18 | 19 | return X 20 | -------------------------------------------------------------------------------- /examples/LRA/code/attention_linformer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class LinformerAttention(nn.Module): 8 | projection_matrix = None 9 | 10 | def __init__(self, config): 11 | super().__init__() 12 | 13 | self.num_head = config["num_head"] 14 | self.head_dim = config["head_dim"] 15 | self.linformer_k = config["linformer_k"] 16 | self.seq_len = config["max_seq_len"] 17 | 18 | if LinformerAttention.projection_matrix is not None: 19 | self.E = LinformerAttention.projection_matrix 20 | else: 21 | LinformerAttention.projection_matrix = nn.Parameter(torch.Tensor(self.num_head, self.linformer_k, self.seq_len)) 22 | torch.nn.init.normal_(LinformerAttention.projection_matrix, std=0.02) 23 | self.E = LinformerAttention.projection_matrix 24 | 25 | def forward(self, Q, K, V, mask): 26 | K = torch.matmul(self.E, K * mask[:, None, :, None]) 27 | V = torch.matmul(self.E, V * mask[:, None, :, None]) 28 | 29 | dot = torch.matmul(Q, torch.transpose(K, -2, -1)) 30 | dot = dot / math.sqrt(self.head_dim) 31 | 32 | attn = nn.functional.softmax(dot, dim=-1) 33 | 34 | X = torch.matmul(attn, V) 35 | 36 | return X 37 | 38 | def extra_repr(self): 39 | return f'linformer_k={self.linformer_k}' 40 | -------------------------------------------------------------------------------- /examples/LRA/code/attention_nystrom.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class NystromAttention(nn.Module): 8 | def __init__(self, config): 9 | super().__init__() 10 | 11 | self.head_dim = config["head_dim"] 12 | self.num_head = config["num_head"] 13 | 14 | self.num_landmarks = config["num_landmarks"] 15 | self.seq_len = config["seq_len"] 16 | 17 | if "inv_coeff_init_option" in config: 18 | self.init_option = config["inv_init_coeff_option"] 19 | else: 20 | self.init_option = "original" 21 | 22 | self.use_conv = "conv_kernel_size" in config 23 | if self.use_conv: 24 | self.conv = nn.Conv2d( 25 | in_channels=self.num_head, out_channels=self.num_head, 26 | kernel_size=(config["conv_kernel_size"], 1), padding=(config["conv_kernel_size"] // 2, 0), 27 | bias=False, 28 | groups=self.num_head) 29 | 30 | def forward(self, Q, K, V, mask): 31 | 32 | Q = Q * mask[:, None, :, None] / math.sqrt(math.sqrt(self.head_dim)) 33 | K = K * mask[:, None, :, None] / math.sqrt(math.sqrt(self.head_dim)) 34 | 35 | if self.num_landmarks == self.seq_len: 36 | attn = torch.nn.functional.softmax(torch.matmul(Q, K.transpose(-1, -2)) - 1e9 * (1 - mask[:, None, None, :]), dim=-1) 37 | X = torch.matmul(attn, V) 38 | else: 39 | Q_landmarks = Q.reshape(-1, self.num_head, self.num_landmarks, self.seq_len // self.num_landmarks, self.head_dim).mean(dim=-2) 40 | K_landmarks = K.reshape(-1, self.num_head, self.num_landmarks, self.seq_len // self.num_landmarks, self.head_dim).mean(dim=-2) 41 | 42 | kernel_1 = torch.nn.functional.softmax(torch.matmul(Q, K_landmarks.transpose(-1, -2)), dim=-1) 43 | kernel_2 = 
torch.nn.functional.softmax(torch.matmul(Q_landmarks, K_landmarks.transpose(-1, -2)), dim=-1) 44 | kernel_3 = torch.nn.functional.softmax(torch.matmul(Q_landmarks, K.transpose(-1, -2)) - 1e9 * (1 - mask[:, None, None, :]), dim=-1) 45 | X = torch.matmul(torch.matmul(kernel_1, self.iterative_inv(kernel_2)), torch.matmul(kernel_3, V)) 46 | 47 | if self.use_conv: 48 | X += self.conv(V * mask[:, None, :, None]) 49 | 50 | return X 51 | 52 | def iterative_inv(self, mat, n_iter=6): 53 | I = torch.eye(mat.size(-1), device=mat.device) 54 | K = mat 55 | 56 | if self.init_option == "original": 57 | V = 1 / torch.max(torch.sum(K, dim=-2)) * K.transpose(-1, -2) 58 | else: 59 | V = 1 / torch.max(torch.sum(K, dim=-2), dim=-1).values[:, :, None, None] * K.transpose(-1, -2) 60 | 61 | for _ in range(n_iter): 62 | KV = torch.matmul(K, V) 63 | V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV, 15 * I - torch.matmul(KV, 7 * I - KV))) 64 | return V 65 | 66 | def extra_repr(self): 67 | return f'num_landmarks={self.num_landmarks}, seq_len={self.seq_len}' 68 | -------------------------------------------------------------------------------- /examples/LRA/code/attention_performer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | from performer_pytorch import FastAttention 6 | 7 | 8 | class PerformerAttention(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | 12 | self.head_dim = config["head_dim"] 13 | self.rp_dim = config["rp_dim"] 14 | self.kernel_type = config["kernel_type"] 15 | if self.kernel_type == "relu": 16 | self.attn_fn = FastAttention(dim_heads=self.head_dim, nb_features=self.rp_dim, causal=False, kernel_fn=nn.ReLU()) 17 | elif self.kernel_type == "exp": 18 | self.attn_fn = FastAttention(dim_heads=self.head_dim, nb_features=self.rp_dim, causal=False, kernel_fn=torch.exp) 19 | 20 | def forward(self, Q, K, V, mask): 21 | return self.attn_fn( 22 | Q / math.sqrt(math.sqrt(self.head_dim)), 23 | K / math.sqrt(math.sqrt(self.head_dim)) * mask[:, None, :, None], 24 | V * mask[:, None, :, None]) 25 | 26 | def extra_repr(self): 27 | return f'rp_dim={self.rp_dim}, kernel_type={self.kernel_type}' 28 | -------------------------------------------------------------------------------- /examples/LRA/code/attention_ponet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | def segment_max(src, index, dim=1): 6 | out = torch.zeros_like(src).scatter_reduce( 7 | dim, index.unsqueeze(-1).expand_as(src), src, reduce="amax", include_self=False 8 | ) 9 | dummy = index.unsqueeze(-1).expand(*index.shape[:2], out.size(-1)) 10 | return torch.gather(out, dim, dummy).to(dtype=src.dtype) 11 | 12 | def get_win_max(hidden_states, kernel_size=3): 13 | m = nn.MaxPool1d(kernel_size, stride=1, padding=kernel_size//2) 14 | out = m(hidden_states.permute(0,2,1)).permute(0,2,1) 15 | return out 16 | 17 | class PoNetAttention(nn.Module): 18 | def __init__(self, config): 19 | super().__init__() 20 | 21 | self.num_head = config["num_head"] 22 | self.head_dim = config["head_dim"] 23 | self.segment_num = config["segment_num"] 24 | 25 | #XXX: More natural implementation. 
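  # Descriptive note (added): get_segment_index splits the (padded) sequence into
  # `segment_num` contiguous, equal-length segments and returns each position's segment id;
  # segment_max then max-pools features within each segment (the segment granularity of
  # PoNet's multi-granularity pooling), while get_win_max above provides the local
  # sliding-window max-pooling and the forward pass adds the global attention-style pooling.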
26 | def get_segment_index(self, input_ids): 27 | sentence_len = input_ids.shape[1] 28 | segment_num = self.segment_num 29 | segment_len = sentence_len // segment_num + 1 30 | mask = torch.arange(start=0, end=segment_num, dtype=torch.long, device=input_ids.device).view(1, segment_num, 1).repeat(1, 1, segment_len).view(1, -1)[:, :sentence_len].repeat(input_ids.shape[0], 1) 31 | return mask 32 | 33 | 34 | def forward(self, hidden_states, Q, K, O, local, segment, attention_mask): 35 | # bdlh 36 | context_layer_q = Q 37 | context_layer_k = K 38 | context_layer_v = context_layer_k 39 | context_layer_o = O 40 | 41 | if attention_mask is not None: 42 | _attention_mask = (attention_mask[:,None,:,None] < 0.5) 43 | 44 | if attention_mask is not None: 45 | context_layer_q.masked_fill_(_attention_mask, 0.0) 46 | q = context_layer_q.sum(dim=-2) / torch.ones_like(_attention_mask).to(dtype=context_layer_q.dtype).masked_fill(_attention_mask, 0.0).sum(dim=-2) 47 | else: 48 | q = context_layer_q.mean(dim=-2) 49 | att = torch.einsum("bdh,bdlh -> bdl", q, context_layer_k) / math.sqrt(context_layer_q.shape[-1]) 50 | 51 | if attention_mask is not None: 52 | att.masked_fill_(_attention_mask.squeeze(-1), -10000) 53 | att_prob = att.softmax(dim=-1) 54 | 55 | v = torch.einsum('bdlh,bdl->bdh', context_layer_v, att_prob) 56 | 57 | context_layer_segment = segment 58 | context_layer_local = local 59 | if attention_mask is not None: 60 | _attention_mask = _attention_mask.squeeze(1) 61 | context_layer_segment.masked_fill_(_attention_mask, -10000) 62 | context_layer_local.masked_fill_(_attention_mask, -10000) 63 | 64 | context_layer_local = get_win_max(context_layer_local) 65 | segment_index = self.get_segment_index(hidden_states) 66 | context_layer_segment = segment_max(context_layer_segment, index=segment_index) 67 | 68 | context_layer_local = context_layer_local.view(*context_layer_local.shape[:2], self.num_head, self.head_dim).permute(0, 2, 1, 3) 69 | context_layer_segment = context_layer_segment.view(*context_layer_segment.shape[:2], self.num_head, self.head_dim).permute(0, 2, 1, 3) 70 | 71 | context_layer = (v.unsqueeze(dim=-2) + context_layer_segment) * context_layer_o + context_layer_local 72 | context_layer = context_layer.permute(0, 2, 1, 3).reshape(*hidden_states.shape[:2], -1) 73 | if attention_mask is not None: 74 | context_layer.masked_fill_(_attention_mask, 0.0) 75 | 76 | outputs = context_layer 77 | return outputs -------------------------------------------------------------------------------- /examples/LRA/code/attention_reformer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from transformers.models.reformer.modeling_reformer import LSHSelfAttention, ReformerConfig 5 | 6 | 7 | class LSHAttention(LSHSelfAttention): 8 | def __init__(self, config, query, key, value): 9 | self.num_hash = config["num_hash"] 10 | reformer_config = ReformerConfig() 11 | reformer_config.attention_head_size = config["head_dim"] 12 | reformer_config.num_attention_heads = config["num_head"] 13 | reformer_config.attn_layers = ["lsh"] 14 | reformer_config.num_hashes = config["num_hash"] 15 | reformer_config.is_decoder = False 16 | reformer_config.max_position_embeddings = config["max_seq_len"] 17 | reformer_config.hidden_size = config["transformer_dim"] 18 | super().__init__(reformer_config) 19 | self.query_key.weight = query.weight 20 | self.value.weight = value.weight 21 | 22 | def forward(self, X, mask): 23 | return 
super().forward(hidden_states=X, attention_mask=mask).hidden_states 24 | 25 | def extra_repr(self): 26 | return f'num_hash={self.num_hash}' 27 | -------------------------------------------------------------------------------- /examples/LRA/code/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from torch.utils.data.dataset import Dataset 5 | import sys 6 | import os 7 | import random 8 | import json 9 | import pickle 10 | import numpy as np 11 | 12 | class LRADataset(Dataset): 13 | def __init__(self, file_path, endless): 14 | 15 | self.endless = endless 16 | with open(file_path, "rb") as f: 17 | self.examples = pickle.load(f) 18 | random.shuffle(self.examples) 19 | self.curr_idx = 0 20 | 21 | print(f"Loaded {file_path}... size={len(self.examples)}", flush = True) 22 | 23 | def __len__(self): 24 | if self.endless: 25 | return 1000000000 26 | else: 27 | return len(self.examples) 28 | 29 | def create_inst(self, inst): 30 | output = {} 31 | output["input_ids_0"] = torch.tensor(inst["input_ids_0"], dtype = torch.long) 32 | output["mask_0"] = (output["input_ids_0"] != 0).float() 33 | if "input_ids_1" in inst: 34 | output["input_ids_1"] = torch.tensor(inst["input_ids_1"], dtype = torch.long) 35 | output["mask_1"] = (output["input_ids_1"] != 0).float() 36 | output["label"] = torch.tensor(inst["label"], dtype = torch.long) 37 | return output 38 | 39 | def __getitem__(self, i): 40 | if not self.endless: 41 | return self.create_inst(self.examples[i]) 42 | 43 | if self.curr_idx >= len(self.examples): 44 | random.shuffle(self.examples) 45 | self.curr_idx = 0 46 | inst = self.examples[self.curr_idx] 47 | self.curr_idx += 1 48 | 49 | return self.create_inst(inst) 50 | -------------------------------------------------------------------------------- /examples/LRA/code/lra_config.py: -------------------------------------------------------------------------------- 1 | 2 | config = { 3 | "listops":{ 4 | "dataset":{ 5 | "train":96000, 6 | "dev":2000, 7 | "test":2000, 8 | }, 9 | "model":{ 10 | "learn_pos_emb":True, 11 | "tied_weights":False, 12 | "embedding_dim":64, 13 | "transformer_dim":64, 14 | "transformer_hidden_dim":128, 15 | "head_dim":32, 16 | "num_head":2, 17 | "num_layers":2, 18 | "vocab_size":32, 19 | "max_seq_len":2000, 20 | "dropout_prob":0.1, 21 | "attention_dropout":0.1, 22 | "pooling_mode":"MEAN", 23 | "num_classes":10, 24 | }, 25 | "training":{ 26 | "batch_size":32, 27 | "learning_rate":0.0001, 28 | "warmup":1000, 29 | "lr_decay":"linear", 30 | "weight_decay":0, 31 | "eval_frequency":50, 32 | "num_train_steps":5000, 33 | "num_eval_steps":62, 34 | }, 35 | "gpu_memory":{ 36 | "softmax":32, 37 | "nystrom-32":32, 38 | "nystrom-64":32, 39 | "nystrom-128":32, 40 | "nystrom-256":32, 41 | "ponet": 32, 42 | "fnet": 32, 43 | "linformer-256":32, 44 | "reformer-2":32, 45 | "performer-256":32, 46 | "linear":32, 47 | }, 48 | "extra_attn_config":{ 49 | "softmax":{"attention_grad_checkpointing":True}, 50 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35}, 51 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35}, 52 | "nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35}, 53 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35}, 54 | "ponet":{"attention_grad_checkpointing":False, "segment_num":64}, 55 | 
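            # Note (added): for the PoNet entries, "segment_num" is the number of equal-length
            # segments used for segment max-pooling (see get_segment_index in attention_ponet.py).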
"fnet":{"attention_grad_checkpointing":False}, 56 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256}, 57 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2}, 58 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"}, 59 | "linear":{"attention_grad_checkpointing":False}, 60 | } 61 | }, 62 | "image":{ 63 | "dataset":{ 64 | "train":45000, 65 | "dev":5000, 66 | "test":10000, 67 | }, 68 | "model":{ 69 | "learn_pos_emb":True, 70 | "tied_weights":False, 71 | "embedding_dim":64, 72 | "transformer_dim":64, 73 | "transformer_hidden_dim":128, 74 | "head_dim":32, 75 | "num_head":2, 76 | "num_layers":2, 77 | "vocab_size":512, 78 | "max_seq_len":1024, 79 | "dropout_prob":0.1, 80 | "attention_dropout":0.1, 81 | "pooling_mode":"MEAN", 82 | "num_classes": 10, 83 | }, 84 | "training":{ 85 | "batch_size":256, 86 | "learning_rate":0.0001, 87 | "warmup":175, 88 | "lr_decay":"linear", 89 | "weight_decay":0, 90 | "eval_frequency":175, 91 | "num_train_steps":35000, 92 | "num_eval_steps":20, 93 | }, 94 | "gpu_memory":{ 95 | "softmax":128, 96 | "nystrom-32":128, 97 | "nystrom-64":128, 98 | "nystrom-128":128, 99 | "nystrom-256":128, 100 | "ponet": 128, 101 | "fnet": 128, 102 | "linformer-256":128, 103 | "reformer-2":128, 104 | "performer-256":128, 105 | "linear":128, 106 | }, 107 | "extra_attn_config":{ 108 | "softmax":{"attention_grad_checkpointing":True}, 109 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35}, 110 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35}, 111 | "nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35}, 112 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35}, 113 | "ponet":{"attention_grad_checkpointing":False, "segment_num":32}, 114 | "fnet":{"attention_grad_checkpointing":False}, 115 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256}, 116 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2}, 117 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"}, 118 | "linear":{"attention_grad_checkpointing":False}, 119 | } 120 | }, 121 | "pathfinder32":{ 122 | "model":{ 123 | "learn_pos_emb":True, 124 | "tied_weights":False, 125 | "embedding_dim":64, 126 | "transformer_dim":64, 127 | "transformer_hidden_dim":128, 128 | "head_dim":32, 129 | "num_head":2, 130 | "num_layers":2, 131 | "vocab_size":512, 132 | "max_seq_len":1024, 133 | "dropout_prob":0.1, 134 | "attention_dropout":0.1, 135 | "pooling_mode":"MEAN", 136 | "num_classes": 2, 137 | }, 138 | "training":{ 139 | "batch_size":256, 140 | "learning_rate":0.0001, 141 | "warmup":312, 142 | "lr_decay":"linear", 143 | "weight_decay":0, 144 | "eval_frequency":312, 145 | "num_train_steps":62400, 146 | "num_eval_steps":312, 147 | }, 148 | "gpu_memory":{ 149 | "softmax":128, 150 | "nystrom-32":128, 151 | "nystrom-64":128, 152 | "nystrom-128":128, 153 | "nystrom-256":128, 154 | "ponet":128, 155 | "fnet":128, 156 | "linformer-256":128, 157 | "reformer-2":128, 158 | "performer-256":128, 159 | "linear":128, 160 | }, 161 | "extra_attn_config":{ 162 | "softmax":{"attention_grad_checkpointing":True}, 163 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35}, 164 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35}, 165 | 
"nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35}, 166 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35}, 167 | "ponet":{"attention_grad_checkpointing":False, "segment_num":1}, 168 | "fnet":{"attention_grad_checkpointing":False}, 169 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256}, 170 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2}, 171 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"}, 172 | "linear":{"attention_grad_checkpointing":False}, 173 | } 174 | }, 175 | "retrieval":{ 176 | "dataset":{ 177 | "train":147086, 178 | "dev":18090, 179 | "test":17437, 180 | }, 181 | "model":{ 182 | "learn_pos_emb":True, 183 | "tied_weights":False, 184 | "embedding_dim":64, 185 | "transformer_dim":64, 186 | "transformer_hidden_dim":128, 187 | "head_dim":32, 188 | "num_head":2, 189 | "num_layers":2, 190 | "vocab_size":512, 191 | "max_seq_len":4000, 192 | "dropout_prob":0.1, 193 | "attention_dropout":0.1, 194 | "pooling_mode":"MEAN", 195 | "num_classes": 2, 196 | }, 197 | "training":{ 198 | "batch_size":32, 199 | "learning_rate":0.0001, 200 | "warmup":800, 201 | "lr_decay":"linear", 202 | "weight_decay":0, 203 | "eval_frequency":300, 204 | "num_train_steps":30000, 205 | "num_eval_steps":565, 206 | }, 207 | "gpu_memory":{ 208 | "softmax":32, 209 | "nystrom-32":32, 210 | "nystrom-64":32, 211 | "nystrom-128":32, 212 | "nystrom-256":32, 213 | "ponet":32, 214 | "fnet":32, 215 | "linformer-256":32, 216 | "reformer-2":32, 217 | "performer-256":32, 218 | "linear":32, 219 | }, 220 | "extra_attn_config":{ 221 | "softmax":{"attention_grad_checkpointing":True}, 222 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35}, 223 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35}, 224 | "nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35}, 225 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35}, 226 | "ponet":{"attention_grad_checkpointing":False, "segment_num":64}, 227 | "fnet":{"attention_grad_checkpointing":False}, 228 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256}, 229 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2}, 230 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"}, 231 | "linear":{"attention_grad_checkpointing":False}, 232 | } 233 | }, 234 | "text":{ 235 | "dataset":{ 236 | "train":25000, 237 | "dev":25000, 238 | "test":25000, 239 | }, 240 | "model":{ 241 | "learn_pos_emb":True, 242 | "tied_weights":False, 243 | "embedding_dim":64, 244 | "transformer_dim":64, 245 | "transformer_hidden_dim":128, 246 | "head_dim":32, 247 | "num_head":2, 248 | "num_layers":2, 249 | "vocab_size":512, 250 | "max_seq_len":4000, 251 | "dropout_prob":0.1, 252 | "attention_dropout":0.1, 253 | "pooling_mode":"MEAN", 254 | "num_classes": 2, 255 | }, 256 | "training":{ 257 | "batch_size":32, 258 | "learning_rate":0.0001, 259 | "warmup":8000, 260 | "lr_decay":"linear", 261 | "weight_decay":0, 262 | "eval_frequency":500, 263 | "num_train_steps":20000, 264 | "num_eval_steps":781, 265 | }, 266 | "gpu_memory":{ 267 | "softmax":32, 268 | "nystrom-32":32, 269 | "nystrom-64":32, 270 | "nystrom-128":32, 271 | "nystrom-256":32, 272 | "ponet":32, 273 | "fnet":32, 274 | "linformer-256":32, 275 | 
"reformer-2":32, 276 | "performer-256":32, 277 | "linear":32, 278 | }, 279 | "extra_attn_config":{ 280 | "softmax":{"attention_grad_checkpointing":True}, 281 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35}, 282 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35}, 283 | "nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35}, 284 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35}, 285 | "ponet":{"attention_grad_checkpointing":False, "segment_num":4096}, # Here we write down 4096 for the function get_segment_index. Actually segment num is 2048. 286 | "fnet":{"attention_grad_checkpointing":False}, 287 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256}, 288 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2}, 289 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"}, 290 | "linear":{"attention_grad_checkpointing":False}, 291 | } 292 | } 293 | } 294 | 295 | config["pathfinder32-curv_baseline"] = config["pathfinder32"] 296 | config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"] 297 | config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"] 298 | -------------------------------------------------------------------------------- /examples/LRA/code/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import math 6 | from torch.utils.checkpoint import checkpoint 7 | from attention import Attention 8 | 9 | 10 | class Embeddings(nn.Module): 11 | def __init__(self, config): 12 | super().__init__() 13 | 14 | assert config["embedding_dim"] == config["transformer_dim"] 15 | 16 | self.dim = config["embedding_dim"] 17 | 18 | self.word_embeddings = nn.Embedding(config["vocab_size"], config["embedding_dim"]) 19 | torch.nn.init.normal_(self.word_embeddings.weight, std=0.02) 20 | 21 | self.position_embeddings = nn.Embedding(config["max_seq_len"], config["embedding_dim"]) 22 | torch.nn.init.normal_(self.position_embeddings.weight, std=0.02) 23 | 24 | self.dropout = torch.nn.Dropout(p=config["dropout_prob"]) 25 | 26 | def fixed_pos_emb(self, seq_len, device): 27 | position = torch.arange(0, seq_len, device=device)[:, np.newaxis] 28 | div_term = torch.exp(torch.arange(0, self.dim, 2, device=device) * -(math.log(10000.0) / self.dim)) 29 | pos_embed = torch.stack([torch.sin(position * div_term), torch.cos(position * div_term)], -1).reshape(seq_len, -1) 30 | return pos_embed 31 | 32 | def forward(self, input_ids): 33 | 34 | batch_size, seq_len = input_ids.size() 35 | 36 | X_token = self.word_embeddings(input_ids) 37 | 38 | position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)[None, :].repeat(batch_size, 1) 39 | X_pos = self.position_embeddings(position_ids) 40 | 41 | X = X_token + X_pos 42 | 43 | X = self.dropout(X) 44 | 45 | return X 46 | 47 | 48 | class Transformer(nn.Module): 49 | def __init__(self, config): 50 | super().__init__() 51 | 52 | self.norm1 = nn.LayerNorm(config["transformer_dim"]) 53 | self.mha = Attention(config) 54 | self.dropout1 = torch.nn.Dropout(p=config["dropout_prob"]) 55 | self.norm2 = nn.LayerNorm(config["transformer_dim"]) 56 | 57 | self.mlpblock = nn.Sequential( 58 | nn.Linear(config["transformer_dim"], config["transformer_hidden_dim"]), 59 | nn.GELU(), 60 | 
torch.nn.Dropout(p=config["dropout_prob"]), 61 | nn.Linear(config["transformer_hidden_dim"], config["transformer_dim"]), 62 | torch.nn.Dropout(p=config["dropout_prob"]) 63 | ) 64 | 65 | def forward(self, X, mask): 66 | X = self.dropout1(self.mha(self.norm1(X), mask)) + X 67 | X = self.mlpblock(self.norm2(X)) + X 68 | return X 69 | 70 | 71 | class Model(nn.Module): 72 | def __init__(self, config): 73 | super().__init__() 74 | 75 | self.num_layers = config["num_layers"] 76 | self.tied_weights = config["tied_weights"] 77 | 78 | self.embeddings = Embeddings(config) 79 | 80 | if self.tied_weights: 81 | self.transformer = Transformer(config) 82 | else: 83 | for idx in range(self.num_layers): 84 | setattr(self, f"transformer_{idx}", Transformer(config)) 85 | 86 | self.norm = nn.LayerNorm(config["transformer_dim"]) 87 | 88 | def forward(self, input_ids, mask=None): 89 | 90 | X = self.embeddings(input_ids) 91 | 92 | if mask is None: 93 | mask = torch.ones_like(input_ids) 94 | 95 | if self.tied_weights: 96 | for idx in range(self.num_layers): 97 | X = self.transformer(X, mask) 98 | else: 99 | for idx in range(self.num_layers): 100 | X = getattr(self, f"transformer_{idx}")(X, mask) 101 | 102 | X = self.norm(X) * mask[:, :, None] 103 | 104 | return X 105 | -------------------------------------------------------------------------------- /examples/LRA/code/model_wrapper.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | from model import Model 6 | 7 | 8 | def pooling(inp, mode): 9 | if mode == "CLS": 10 | pooled = inp[:, 0, :] 11 | elif mode == "MEAN": 12 | pooled = inp.mean(dim=1) 13 | else: 14 | raise Exception() 15 | return pooled 16 | 17 | 18 | def append_cls(inp, mask, vocab_size): 19 | batch_size = inp.size(0) 20 | cls_id = ((vocab_size - 1) * torch.ones(batch_size, dtype=torch.long, device=inp.device)).long() 21 | cls_mask = torch.ones(batch_size, dtype=torch.float, device=mask.device) 22 | inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1) 23 | mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1) 24 | return inp, mask 25 | 26 | 27 | class SCHead(nn.Module): 28 | def __init__(self, config): 29 | super().__init__() 30 | self.pooling_mode = config["pooling_mode"] 31 | self.mlpblock = nn.Sequential( 32 | nn.Linear(config["transformer_dim"], config["transformer_hidden_dim"]), 33 | nn.ReLU(), 34 | nn.Linear(config["transformer_hidden_dim"], config["num_classes"]) 35 | ) 36 | 37 | def forward(self, inp): 38 | seq_score = self.mlpblock(pooling(inp, self.pooling_mode)) 39 | return seq_score 40 | 41 | 42 | class ModelForSC(nn.Module): 43 | def __init__(self, config): 44 | super().__init__() 45 | 46 | self.enable_amp = config["mixed_precision"] 47 | self.pooling_mode = config["pooling_mode"] 48 | self.vocab_size = config["vocab_size"] 49 | 50 | self.model = Model(config) 51 | 52 | self.seq_classifer = SCHead(config) 53 | 54 | def forward(self, input_ids_0, mask_0, label): 55 | 56 | with torch.cuda.amp.autocast(enabled=self.enable_amp): 57 | 58 | if self.pooling_mode == "CLS": 59 | input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) 60 | 61 | token_out = self.model(input_ids_0, mask_0) 62 | seq_scores = self.seq_classifer(token_out) 63 | 64 | seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) 65 | seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) 66 | outputs = {} 67 | outputs["loss"] = seq_loss 68 | outputs["accu"] = seq_accu 69 | 70 | 
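        # Note (added): loss and accuracy are returned per example (reduction="none");
        # run_tasks.py averages them and divides by the gradient-accumulation steps.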
return outputs 71 | 72 | 73 | class SCHeadDual(nn.Module): 74 | def __init__(self, config): 75 | super().__init__() 76 | self.pooling_mode = config["pooling_mode"] 77 | self.mlpblock = nn.Sequential( 78 | nn.Linear(config["transformer_dim"] * 4, config["transformer_hidden_dim"]), 79 | nn.ReLU(), 80 | nn.Linear(config["transformer_hidden_dim"], config["num_classes"]) 81 | ) 82 | 83 | def forward(self, inp_0, inp_1): 84 | X_0 = pooling(inp_0, self.pooling_mode) 85 | X_1 = pooling(inp_1, self.pooling_mode) 86 | seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1)) 87 | return seq_score 88 | 89 | 90 | class ModelForSCDual(nn.Module): 91 | def __init__(self, config): 92 | super().__init__() 93 | 94 | self.enable_amp = config["mixed_precision"] 95 | self.pooling_mode = config["pooling_mode"] 96 | self.vocab_size = config["vocab_size"] 97 | 98 | self.model = Model(config) 99 | 100 | self.seq_classifer = SCHeadDual(config) 101 | 102 | def forward(self, input_ids_0, input_ids_1, mask_0, mask_1, label): 103 | 104 | with torch.cuda.amp.autocast(enabled=self.enable_amp): 105 | 106 | if self.pooling_mode == "CLS": 107 | input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) 108 | input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size) 109 | 110 | token_out_0 = self.model(input_ids_0, mask_0) 111 | token_out_1 = self.model(input_ids_1, mask_1) 112 | seq_scores = self.seq_classifer(token_out_0, token_out_1) 113 | 114 | seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) 115 | seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) 116 | outputs = {} 117 | outputs["loss"] = seq_loss 118 | outputs["accu"] = seq_accu 119 | 120 | return outputs 121 | -------------------------------------------------------------------------------- /examples/LRA/code/run_tasks.py: -------------------------------------------------------------------------------- 1 | from model_wrapper import ModelForSC, ModelForSCDual 2 | from dataset import LRADataset 3 | from torch.utils.data import DataLoader 4 | import torch 5 | import torch.nn as nn 6 | import time 7 | import os 8 | import json 9 | import pickle 10 | import numpy as np 11 | import argparse 12 | import math 13 | import itertools 14 | import lra_config 15 | from transformers import set_seed 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--model", type=str, help="model", dest="model", required=True) 19 | parser.add_argument("--task", type=str, help="task", dest="task", required=True) 20 | parser.add_argument("--skip_train", type=int, help="skip_train", dest="skip_train", default=0) 21 | parser.add_argument("--output", type=str, help="output dir", required=True) 22 | parser.add_argument("--segment_num", type=int, default=None) 23 | # parser.add_argument("--seed", type = int, default=43) 24 | 25 | args = parser.parse_args() 26 | 27 | attn_type = args.model 28 | task = args.task 29 | 30 | checkpoint_dir = args.output 31 | 32 | print(lra_config.config[task]["extra_attn_config"].keys(), flush=True) 33 | 34 | model_config = lra_config.config[task]["model"] 35 | model_config.update(lra_config.config[task]["extra_attn_config"][attn_type]) 36 | 37 | model_config["mixed_precision"] = True 38 | model_config["attn_type"] = attn_type 39 | model_config["max_seq_len"] = int(2 ** math.ceil(math.log2(model_config["max_seq_len"]))) 40 | 41 | if args.segment_num is not None: 42 | model_config["segment_num"] = args.segment_num 43 | 44 | training_config = 
lra_config.config[task]["training"] 45 | gpu_memory_config = lra_config.config[task]["gpu_memory"] 46 | 47 | device_ids = list(range(torch.cuda.device_count())) 48 | print(f"GPU list: {device_ids}") 49 | 50 | print(json.dumps([model_config, training_config], indent=4)) 51 | 52 | # set_seed(args.seed) 53 | 54 | if task == "retrieval": 55 | model = ModelForSCDual(model_config) 56 | else: 57 | model = ModelForSC(model_config) 58 | 59 | print(model) 60 | print(f"parameter_size: {[weight.size() for weight in model.parameters()]}", flush=True) 61 | print(f"num_parameter: {np.sum([np.prod(weight.size()) for weight in model.parameters()])}", flush=True) 62 | 63 | model = model.cuda() 64 | model = nn.DataParallel(model, device_ids=device_ids) 65 | 66 | ds_iter = { 67 | "train": enumerate(DataLoader(LRADataset(f"../datasets/{task}.train.pickle", True), batch_size=training_config["batch_size"], drop_last=True)), 68 | "dev": enumerate(DataLoader(LRADataset(f"../datasets/{task}.dev.pickle", True), batch_size=training_config["batch_size"], drop_last=True)), 69 | "test": enumerate(DataLoader(LRADataset(f"../datasets/{task}.test.pickle", False), batch_size=training_config["batch_size"], drop_last=True)), 70 | } 71 | 72 | optimizer = torch.optim.AdamW( 73 | model.parameters(), 74 | lr=training_config["learning_rate"], 75 | betas=(0.9, 0.999), eps=1e-6, weight_decay=training_config["weight_decay"] 76 | ) 77 | 78 | lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( 79 | optimizer=optimizer, 80 | max_lr=training_config["learning_rate"], 81 | pct_start=training_config["warmup"] / training_config["num_train_steps"], 82 | anneal_strategy=training_config["lr_decay"], 83 | total_steps=training_config["num_train_steps"] 84 | ) 85 | 86 | amp_scaler = torch.cuda.amp.GradScaler() if model_config["mixed_precision"] else None 87 | 88 | 89 | def step(component, step_idx): 90 | t0 = time.time() 91 | 92 | optimizer.zero_grad() 93 | 94 | _, batch = next(ds_iter[component]) 95 | for key in batch: 96 | batch[key] = batch[key].cuda() 97 | 98 | if component == "train": 99 | outputs = {} 100 | 101 | partial_inputs_list = [{} for _ in range(accumu_steps)] 102 | for key in batch: 103 | for idx, inp in enumerate(torch.chunk(batch[key], accumu_steps, dim=0)): 104 | partial_inputs_list[idx][key] = inp 105 | 106 | for partial_inputs in partial_inputs_list: 107 | partial_outputs = model(**partial_inputs) 108 | for key in partial_outputs: 109 | partial_outputs[key] = partial_outputs[key].mean() / accumu_steps 110 | if key not in outputs: 111 | outputs[key] = partial_outputs[key] 112 | else: 113 | outputs[key] += partial_outputs[key] 114 | amp_scaler.scale(partial_outputs["loss"]).backward() 115 | 116 | amp_scaler.step(optimizer) 117 | amp_scaler.update() 118 | lr_scheduler.step() 119 | else: 120 | with torch.no_grad(): 121 | outputs = {} 122 | 123 | partial_inputs_list = [{} for _ in range(accumu_steps)] 124 | for key in batch: 125 | for idx, inp in enumerate(torch.chunk(batch[key], accumu_steps, dim=0)): 126 | partial_inputs_list[idx][key] = inp 127 | 128 | for partial_inputs in partial_inputs_list: 129 | partial_outputs = model(**partial_inputs) 130 | for key in partial_outputs: 131 | partial_outputs[key] = partial_outputs[key].mean() / accumu_steps 132 | if key not in outputs: 133 | outputs[key] = partial_outputs[key] 134 | else: 135 | outputs[key] += partial_outputs[key] 136 | 137 | t1 = time.time() 138 | 139 | batch_size = batch[list(batch.keys())[0]].size(0) 140 | t_escape = t1 - t0 141 | learning_rate = 
optimizer.param_groups[0]["lr"] 142 | loss = outputs["loss"].data.item() 143 | accu = outputs["accu"].data.item() 144 | time_since_start = time.time() - init_t 145 | 146 | print(f"step={step_idx}, tt={time_since_start:.1f}, t={t_escape:.3f}, bs={batch_size}, lr={learning_rate:.6f}, loss={loss:.4f}, accu={accu:.4f}\t\t\t\t", end="\r", flush=True) 147 | 148 | summary[component]["t"] += t_escape 149 | summary[component]["loss"].append(loss) 150 | summary[component]["accu"].append(accu) 151 | 152 | 153 | def print_summary(summary, save_if_improved, train_step_idx): 154 | summary["loss"] = np.mean(summary["loss"]) 155 | summary["accu"] = np.mean(summary["accu"]) 156 | 157 | print() 158 | if summary["accu"] > summary["best_accu"]: 159 | summary["best_accu"] = summary["accu"] 160 | if save_if_improved: 161 | best_accu = summary["best_accu"] 162 | torch.save({"model_state_dict": model.module.state_dict()}, log_f_path.replace(".log", ".model")) 163 | print(f"best_accu={best_accu}. Saved best model") 164 | 165 | summary_round = {"train_step_idx": train_step_idx} 166 | for key in summary: 167 | if type(summary[key]) is str: 168 | summary_round[key] = summary[key] 169 | else: 170 | summary_round[key] = round(summary[key], 4) 171 | 172 | print(summary_round, flush=True) 173 | log_f.write(json.dumps(summary_round, sort_keys=True) + "\n") 174 | log_f.flush() 175 | 176 | summary["t"] = 0 177 | summary["loss"] = [] 178 | summary["accu"] = [] 179 | 180 | 181 | init_t = time.time() 182 | 183 | log_f_path = os.path.join(checkpoint_dir, f"{task}_{attn_type}_output.log") 184 | log_f = open(log_f_path, "a+") 185 | 186 | summary = { 187 | component: {"t": 0, "loss": [], "accu": [], "best_accu": 0, "component": component} 188 | for component in ["train", "dev", "test"] 189 | } 190 | 191 | accumu_steps = max(training_config["batch_size"] // len(device_ids) // gpu_memory_config[attn_type], 1) 192 | print(f"accumu_steps={accumu_steps}") 193 | 194 | train_step_idx = None 195 | if args.skip_train == 0: 196 | try: 197 | model.train() 198 | for train_step_idx in range(training_config["num_train_steps"]): 199 | outputs = step("train", train_step_idx) 200 | 201 | if (train_step_idx + 1) % training_config["eval_frequency"] == 0: 202 | print_summary(summary["train"], False, train_step_idx) 203 | model.eval() 204 | for dev_step_idx in range(training_config["num_eval_steps"]): 205 | outputs = step("dev", dev_step_idx) 206 | print_summary(summary["dev"], True, train_step_idx) 207 | model.train() 208 | except KeyboardInterrupt as e: 209 | print(e) 210 | 211 | checkpoint = torch.load(log_f_path.replace(".log", ".model"), map_location="cpu") 212 | model.module.load_state_dict(checkpoint["model_state_dict"]) 213 | model.eval() 214 | try: 215 | for test_step_idx in itertools.count(): 216 | outputs = step("test", test_step_idx) 217 | except StopIteration: 218 | print_summary(summary["test"], False, train_step_idx) 219 | -------------------------------------------------------------------------------- /examples/LRA/code/run_tasks.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=`pwd` 2 | export CUDA_VISIBLE_DEVICES=0 3 | 4 | MODEL=$1 5 | OUTPUT=../logs 6 | 7 | if [ ! 
-d ${OUTPUT} ]; then 8 | mkdir ${OUTPUT} 9 | fi 10 | 11 | python3 run_tasks.py --model ${MODEL} --task listops --output ${OUTPUT} 12 | python3 run_tasks.py --model ${MODEL} --task text --output ${OUTPUT} 13 | python3 run_tasks.py --model ${MODEL} --task retrieval --output ${OUTPUT} 14 | python3 run_tasks.py --model ${MODEL} --task image --output ${OUTPUT} 15 | python3 run_tasks.py --model ${MODEL} --task pathfinder32-curv_contour_length_14 --output ${OUTPUT} 16 | -------------------------------------------------------------------------------- /examples/LRA/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("./long-range-arena/lra_benchmarks/image/") 3 | import input_pipeline 4 | import numpy as np 5 | import pickle 6 | 7 | train_ds, eval_ds, test_ds, num_classes, vocab_size, input_shape = input_pipeline.get_cifar10_datasets( 8 | n_devices = 1, batch_size = 1, normalize = False) 9 | 10 | mapping = {"train":train_ds, "dev": eval_ds, "test":test_ds} 11 | for component in mapping: 12 | ds_list = [] 13 | for idx, inst in enumerate(iter(mapping[component])): 14 | ds_list.append({ 15 | "input_ids_0":inst["inputs"].numpy()[0].reshape(-1), 16 | "label":inst["targets"].numpy()[0] 17 | }) 18 | if idx % 100 == 0: 19 | print(f"{idx}\t\t", end = "\r") 20 | with open(f"image.{component}.pickle", "wb") as f: 21 | pickle.dump(ds_list, f) 22 | -------------------------------------------------------------------------------- /examples/LRA/datasets/create_datasets.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=`pwd`/long-range-arena 2 | 3 | python3 cifar10.py 4 | python3 pathfinder.py 5 | python3 listops.py 6 | python3 retrieval.py 7 | python3 text.py 8 | -------------------------------------------------------------------------------- /examples/LRA/datasets/delete_repeat.sh: -------------------------------------------------------------------------------- 1 | sed -i 's|train_dataset = train_dataset.repeat()|# train_dataset = train_dataset.repeat()|g' ./long-range-arena/lra_benchmarks/image/input_pipeline.py -------------------------------------------------------------------------------- /examples/LRA/datasets/listops.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("./long-range-arena/lra_benchmarks/listops/") 3 | import input_pipeline 4 | import numpy as np 5 | import pickle 6 | 7 | train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets( 8 | n_devices = 1, task_name = "basic", data_dir = "./lra_release/lra_release/listops-1000/", 9 | batch_size = 1, max_length = 2000) 10 | 11 | mapping = {"train":train_ds, "dev": eval_ds, "test":test_ds} 12 | for component in mapping: 13 | ds_list = [] 14 | for idx, inst in enumerate(iter(mapping[component])): 15 | ds_list.append({ 16 | "input_ids_0":np.concatenate([inst["inputs"].numpy()[0], np.zeros(48, dtype = np.int32)]), 17 | "label":inst["targets"].numpy()[0] 18 | }) 19 | if idx % 100 == 0: 20 | print(f"{idx}\t\t", end = "\r") 21 | with open(f"listops.{component}.pickle", "wb") as f: 22 | pickle.dump(ds_list, f) 23 | -------------------------------------------------------------------------------- /examples/LRA/datasets/pathfinder.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import pickle 5 | import tensorflow as tf 6 | import random 7 | 8 | root_dir = 
"./lra_release/lra_release/" 9 | subdir = "pathfinder32" 10 | for diff_level in ["curv_baseline", "curv_contour_length_9", "curv_contour_length_14"]: 11 | data_dir = os.path.join(root_dir, subdir, diff_level) 12 | metadata_list = [ 13 | os.path.join(data_dir, "metadata", file) 14 | for file in os.listdir(os.path.join(data_dir, "metadata")) 15 | if file.endswith(".npy") 16 | ] 17 | ds_list = [] 18 | for idx, metadata_file in enumerate(metadata_list): 19 | print(idx, len(metadata_list), metadata_file, "\t\t", end = "\r") 20 | for inst_meta in tf.io.read_file(metadata_file).numpy().decode("utf-8").split("\n")[:-1]: 21 | metadata = inst_meta.split(" ") 22 | img_path = os.path.join(data_dir, metadata[0], metadata[1]) 23 | img_bin = tf.io.read_file(img_path) 24 | if len(img_bin.numpy()) == 0: 25 | print() 26 | print("detected empty image") 27 | continue 28 | img = tf.image.decode_png(img_bin) 29 | seq = img.numpy().reshape(-1).astype(np.int32) 30 | label = int(metadata[3]) 31 | ds_list.append({"input_ids_0":seq, "label":label}) 32 | 33 | random.shuffle(ds_list) 34 | 35 | bp80 = int(len(ds_list) * 0.8) 36 | bp90 = int(len(ds_list) * 0.9) 37 | train = ds_list[:bp80] 38 | dev = ds_list[bp80:bp90] 39 | test = ds_list[bp90:] 40 | 41 | with open(f"{subdir}-{diff_level}.train.pickle", "wb") as f: 42 | pickle.dump(train, f) 43 | with open(f"{subdir}-{diff_level}.dev.pickle", "wb") as f: 44 | pickle.dump(dev, f) 45 | with open(f"{subdir}-{diff_level}.test.pickle", "wb") as f: 46 | pickle.dump(test, f) 47 | -------------------------------------------------------------------------------- /examples/LRA/datasets/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboard >= 2.3.0 2 | tensorflow >= 2.3.1 3 | tensorflow-datasets >= 4.0.1 4 | tensorflow_text -------------------------------------------------------------------------------- /examples/LRA/datasets/retrieval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("./long-range-arena/lra_benchmarks/matching/") 3 | import input_pipeline 4 | import numpy as np 5 | import pickle 6 | 7 | train_ds, eval_ds, test_ds, encoder = input_pipeline.get_matching_datasets( 8 | n_devices = 1, task_name = None, data_dir = "./lra_release/lra_release/tsv_data/", 9 | batch_size = 1, fixed_vocab = None, max_length = 4000, tokenizer = "char", 10 | vocab_file_path = None) 11 | 12 | mapping = {"train":train_ds, "dev": eval_ds, "test":test_ds} 13 | for component in mapping: 14 | ds_list = [] 15 | for idx, inst in enumerate(iter(mapping[component])): 16 | ds_list.append({ 17 | "input_ids_0":np.concatenate([inst["inputs1"].numpy()[0], np.zeros(96, dtype = np.int32)]), 18 | "input_ids_1":np.concatenate([inst["inputs2"].numpy()[0], np.zeros(96, dtype = np.int32)]), 19 | "label":inst["targets"].numpy()[0] 20 | }) 21 | if idx % 100 == 0: 22 | print(f"{idx}\t\t", end = "\r") 23 | with open(f"retrieval.{component}.pickle", "wb") as f: 24 | pickle.dump(ds_list, f) 25 | -------------------------------------------------------------------------------- /examples/LRA/datasets/text.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("./long-range-arena/lra_benchmarks/text_classification/") 3 | import input_pipeline 4 | import numpy as np 5 | import pickle 6 | 7 | train_ds, eval_ds, test_ds, encoder = input_pipeline.get_tc_datasets( 8 | n_devices = 1, task_name = "imdb_reviews", data_dir = None, 9 | 
batch_size = 1, fixed_vocab = None, max_length = 4000) 10 | 11 | mapping = {"train":train_ds, "dev": eval_ds, "test":test_ds} 12 | for component in mapping: 13 | ds_list = [] 14 | for idx, inst in enumerate(iter(mapping[component])): 15 | ds_list.append({ 16 | "input_ids_0":np.concatenate([inst["inputs"].numpy()[0], np.zeros(96, dtype = np.int32)]), 17 | "label":inst["targets"].numpy()[0] 18 | }) 19 | if idx % 100 == 0: 20 | print(f"{idx}\t\t", end = "\r") 21 | with open(f"text.{component}.pickle", "wb") as f: 22 | pickle.dump(ds_list, f) 23 | -------------------------------------------------------------------------------- /extra/classifier_trainer.py: -------------------------------------------------------------------------------- 1 | from transformers.trainer import Trainer 2 | 3 | import math 4 | import os 5 | import shutil 6 | import sys 7 | import time 8 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 9 | 10 | from tqdm.auto import tqdm 11 | 12 | 13 | # Integrations must be imported before ML frameworks: 14 | # isort: off 15 | from transformers.integrations import ( 16 | hp_params, 17 | is_fairscale_available, 18 | ) 19 | 20 | # isort: on 21 | 22 | import torch 23 | import torch.distributed as dist 24 | from packaging import version 25 | from torch import nn 26 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler 27 | from torch.utils.data.distributed import DistributedSampler 28 | 29 | from transformers import __version__ 30 | from transformers.debug_utils import DebugOption, DebugUnderflowOverflow 31 | from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled 32 | from transformers.dependency_versions_check import dep_version_check 33 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11 34 | from transformers.trainer_callback import ( 35 | DefaultFlowCallback, 36 | ProgressCallback, 37 | TrainerState, 38 | ) 39 | from transformers.trainer_pt_utils import ( 40 | IterableDatasetShard, 41 | get_model_param_count, 42 | ) 43 | from transformers.trainer_utils import ( 44 | HPSearchBackend, 45 | ShardedDDPOption, 46 | TrainOutput, 47 | has_length, 48 | speed_metrics, 49 | ) 50 | from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments 51 | from transformers.utils import ( 52 | is_accelerate_available, 53 | is_apex_available, 54 | is_datasets_available, 55 | is_in_notebook, 56 | is_safetensors_available, 57 | is_sagemaker_mp_enabled, 58 | is_torch_tpu_available, 59 | logging, 60 | ) 61 | 62 | 63 | DEFAULT_CALLBACKS = [DefaultFlowCallback] 64 | DEFAULT_PROGRESS_CALLBACK = ProgressCallback 65 | 66 | if is_in_notebook(): 67 | from transformers.utils.notebook import NotebookProgressCallback 68 | 69 | DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback 70 | 71 | if is_apex_available(): 72 | from apex import amp 73 | 74 | if is_datasets_available(): 75 | import datasets 76 | 77 | if is_torch_tpu_available(check_device=False): 78 | import torch_xla.core.xla_model as xm 79 | import torch_xla.debug.metrics as met 80 | import torch_xla.distributed.parallel_loader as pl 81 | 82 | if is_fairscale_available(): 83 | dep_version_check("fairscale") 84 | import fairscale 85 | from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP 86 | from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP 87 | from fairscale.nn.wrap import auto_wrap 88 | from fairscale.optim import OSS 89 | 
from fairscale.optim.grad_scaler import ShardedGradScaler 90 | 91 | 92 | if is_sagemaker_mp_enabled(): 93 | import smdistributed.modelparallel.torch as smp 94 | from smdistributed.modelparallel import __version__ as SMP_VERSION 95 | 96 | IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") 97 | 98 | from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat 99 | else: 100 | IS_SAGEMAKER_MP_POST_1_10 = False 101 | 102 | 103 | if is_safetensors_available(): 104 | import safetensors.torch 105 | 106 | 107 | skip_first_batches = None 108 | if is_accelerate_available(): 109 | from accelerate import __version__ as accelerate_version 110 | 111 | if version.parse(accelerate_version) >= version.parse("0.16"): 112 | from accelerate import skip_first_batches 113 | 114 | 115 | if TYPE_CHECKING: 116 | import optuna 117 | 118 | logger = logging.get_logger(__name__) 119 | 120 | 121 | # Name of the files used for checkpointing 122 | TRAINING_ARGS_NAME = "training_args.bin" 123 | TRAINER_STATE_NAME = "trainer_state.json" 124 | OPTIMIZER_NAME = "optimizer.pt" 125 | SCHEDULER_NAME = "scheduler.pt" 126 | SCALER_NAME = "scaler.pt" 127 | 128 | 129 | class Classifier_Trainer(Trainer): 130 | def prediction_step( 131 | self, 132 | model: nn.Module, 133 | inputs: Dict[str, Union[torch.Tensor, Any]], 134 | prediction_loss_only: bool, 135 | ignore_keys: Optional[List[str]] = None, 136 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 137 | loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) 138 | if logits is not None: 139 | if type(logits) == tuple: 140 | logits = tuple([l.argmax(dim=-1) for l in logits]) 141 | else: 142 | logits = logits.argmax(dim=-1) 143 | return (loss, logits, labels) 144 | 145 | 146 | class SM_Trainer(Trainer): 147 | def prediction_step( 148 | self, 149 | model: nn.Module, 150 | inputs: Dict[str, Union[torch.Tensor, Any]], 151 | prediction_loss_only: bool, 152 | ignore_keys: Optional[List[str]] = None, 153 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 154 | loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) 155 | if logits is not None: 156 | bz = labels[0].shape[0] if type(labels) in [tuple, list] else labels.shape[0] 157 | mlm_loss = logits[0].view(1, 1).repeat(bz, 1) 158 | sso_loss = logits[1].view(1, 1).repeat(bz, 1) 159 | logits = tuple([mlm_loss, sso_loss,] + [l.argmax(dim=-1) for l in logits[2:]]) 160 | return (loss, logits, labels) 161 | 162 | def _inner_training_loop( 163 | self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None 164 | ): 165 | self._train_batch_size = batch_size 166 | # Data loader and number of training steps 167 | train_dataloader = self.get_train_dataloader() 168 | 169 | # Setting up training control variables: 170 | # number of training epochs: num_train_epochs 171 | # number of training steps per epoch: num_update_steps_per_epoch 172 | # total number of training steps to execute: max_steps 173 | total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size 174 | 175 | len_dataloader = None 176 | if has_length(train_dataloader): 177 | len_dataloader = len(train_dataloader) 178 | num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps 179 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) 180 | 
num_examples = self.num_examples(train_dataloader) 181 | if args.max_steps > 0: 182 | max_steps = args.max_steps 183 | num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( 184 | args.max_steps % num_update_steps_per_epoch > 0 185 | ) 186 | # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's 187 | # the best we can do. 188 | num_train_samples = args.max_steps * total_train_batch_size 189 | else: 190 | max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) 191 | num_train_epochs = math.ceil(args.num_train_epochs) 192 | num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs 193 | elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size 194 | max_steps = args.max_steps 195 | # Setting a very large number of epochs so we go as many times as necessary over the iterator. 196 | num_train_epochs = sys.maxsize 197 | num_update_steps_per_epoch = max_steps 198 | num_examples = total_train_batch_size * args.max_steps 199 | num_train_samples = args.max_steps * total_train_batch_size 200 | else: 201 | raise ValueError( 202 | "args.max_steps must be set to a positive value if dataloader does not have a length, was" 203 | f" {args.max_steps}" 204 | ) 205 | 206 | # Compute absolute values for logging, eval, and save if given as ratio 207 | if args.logging_steps and args.logging_steps < 1: 208 | args.logging_steps = math.ceil(max_steps * args.logging_steps) 209 | if args.eval_steps and args.eval_steps < 1: 210 | args.eval_steps = math.ceil(max_steps * args.eval_steps) 211 | if args.save_steps and args.save_steps < 1: 212 | args.save_steps = math.ceil(max_steps * args.save_steps) 213 | 214 | if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: 215 | if self.args.n_gpu > 1: 216 | # nn.DataParallel(model) replicates the model, creating new variables and module 217 | # references registered here no longer work on other gpus, breaking the module 218 | raise ValueError( 219 | "Currently --debug underflow_overflow is not supported under DP. Please use DDP" 220 | " (torch.distributed.launch)." 
221 | ) 222 | else: 223 | debug_overflow = DebugUnderflowOverflow(self.model) # noqa 224 | 225 | delay_optimizer_creation = ( 226 | self.sharded_ddp is not None 227 | and self.sharded_ddp != ShardedDDPOption.SIMPLE 228 | or is_sagemaker_mp_enabled() 229 | or self.fsdp is not None 230 | ) 231 | if args.deepspeed: 232 | deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( 233 | self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint 234 | ) 235 | self.model = deepspeed_engine.module 236 | self.model_wrapped = deepspeed_engine 237 | self.deepspeed = deepspeed_engine 238 | self.optimizer = optimizer 239 | self.lr_scheduler = lr_scheduler 240 | elif not delay_optimizer_creation: 241 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 242 | 243 | self.state = TrainerState() 244 | self.state.is_hyper_param_search = trial is not None 245 | 246 | # Activate gradient checkpointing if needed 247 | if args.gradient_checkpointing: 248 | self.model.gradient_checkpointing_enable() 249 | 250 | model = self._wrap_model(self.model_wrapped) 251 | 252 | if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: 253 | self._load_from_checkpoint(resume_from_checkpoint, model) 254 | 255 | # for the rest of this function `model` is the outside model, whether it was wrapped or not 256 | if model is not self.model: 257 | self.model_wrapped = model 258 | 259 | if delay_optimizer_creation: 260 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 261 | 262 | # Check if saved optimizer or scheduler states exist 263 | self._load_optimizer_and_scheduler(resume_from_checkpoint) 264 | 265 | # important: at this point: 266 | # self.model is the Transformers Model 267 | # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. 268 | 269 | # Train! 270 | logger.info("***** Running training *****") 271 | logger.info(f" Num examples = {num_examples:,}") 272 | logger.info(f" Num Epochs = {num_train_epochs:,}") 273 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size:,}") 274 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") 275 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 276 | logger.info(f" Total optimization steps = {max_steps:,}") 277 | logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") 278 | 279 | self.state.epoch = 0 280 | start_time = time.time() 281 | epochs_trained = 0 282 | steps_trained_in_current_epoch = 0 283 | steps_trained_progress_bar = None 284 | 285 | # Check if continuing training from a checkpoint 286 | if resume_from_checkpoint is not None and os.path.isfile( 287 | os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) 288 | ): 289 | self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) 290 | epochs_trained = self.state.global_step // num_update_steps_per_epoch 291 | if not args.ignore_data_skip: 292 | steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) 293 | steps_trained_in_current_epoch *= args.gradient_accumulation_steps 294 | else: 295 | steps_trained_in_current_epoch = 0 296 | 297 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 298 | logger.info(f" Continuing training from epoch {epochs_trained}") 299 | logger.info(f" Continuing training from global step {self.state.global_step}") 300 | if not args.ignore_data_skip: 301 | if skip_first_batches is None: 302 | logger.info( 303 | f" Will skip the first {epochs_trained} epochs then the first" 304 | f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time," 305 | " you can install the latest version of Accelerate with `pip install -U accelerate`.You can" 306 | " also add the `--ignore_data_skip` flag to your launch command, but you will resume the" 307 | " training on data already seen by your model." 308 | ) 309 | else: 310 | logger.info( 311 | f" Will skip the first {epochs_trained} epochs then the first" 312 | f" {steps_trained_in_current_epoch} batches in the first epoch." 313 | ) 314 | if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None: 315 | steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) 316 | steps_trained_progress_bar.set_description("Skipping the first batches") 317 | 318 | # Update the references 319 | self.callback_handler.model = self.model 320 | self.callback_handler.optimizer = self.optimizer 321 | self.callback_handler.lr_scheduler = self.lr_scheduler 322 | self.callback_handler.train_dataloader = train_dataloader 323 | if self.hp_name is not None and self._trial is not None: 324 | # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial 325 | # parameter to Train when using DDP. 326 | self.state.trial_name = self.hp_name(self._trial) 327 | if trial is not None: 328 | assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial 329 | self.state.trial_params = hp_params(assignments) 330 | else: 331 | self.state.trial_params = None 332 | # This should be the same if the state has been saved but in case the training arguments changed, it's safer 333 | # to set this after the load. 
334 | self.state.max_steps = max_steps 335 | self.state.num_train_epochs = num_train_epochs 336 | self.state.is_local_process_zero = self.is_local_process_zero() 337 | self.state.is_world_process_zero = self.is_world_process_zero() 338 | 339 | # tr_loss is a tensor to avoid synchronization of TPUs through .item() 340 | tr_loss = torch.tensor(0.0).to(args.device) 341 | mlm_tr_loss = torch.tensor(0.0).to(args.device) 342 | sso_tr_loss = torch.tensor(0.0).to(args.device) 343 | # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses 344 | self._total_loss_scalar = 0.0 345 | self._mlm_total_loss_scalar = 0.0 346 | self._sop_total_loss_scalar = 0.0 347 | self._globalstep_last_logged = self.state.global_step 348 | model.zero_grad() 349 | 350 | self.control = self.callback_handler.on_train_begin(args, self.state, self.control) 351 | 352 | # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. 353 | if not args.ignore_data_skip: 354 | for epoch in range(epochs_trained): 355 | is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( 356 | train_dataloader.sampler, RandomSampler 357 | ) 358 | if is_torch_less_than_1_11 or not is_random_sampler: 359 | # We just need to begin an iteration to create the randomization of the sampler. 360 | # That was before PyTorch 1.11 however... 361 | for _ in train_dataloader: 362 | break 363 | else: 364 | # Otherwise we need to call the whooooole sampler cause there is some random operation added 365 | # AT THE VERY END! 366 | _ = list(train_dataloader.sampler) 367 | 368 | total_batched_samples = 0 369 | for epoch in range(epochs_trained, num_train_epochs): 370 | if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): 371 | train_dataloader.sampler.set_epoch(epoch) 372 | elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): 373 | train_dataloader.dataset.set_epoch(epoch) 374 | 375 | if is_torch_tpu_available(): 376 | parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) 377 | epoch_iterator = parallel_loader 378 | else: 379 | epoch_iterator = train_dataloader 380 | 381 | # Reset the past mems state at the beginning of each epoch if necessary. 
382 | if args.past_index >= 0: 383 | self._past = None 384 | 385 | steps_in_epoch = ( 386 | len(epoch_iterator) 387 | if len_dataloader is not None 388 | else args.max_steps * args.gradient_accumulation_steps 389 | ) 390 | self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) 391 | 392 | if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: 393 | self._load_rng_state(resume_from_checkpoint) 394 | 395 | rng_to_sync = False 396 | steps_skipped = 0 397 | if skip_first_batches is not None and steps_trained_in_current_epoch > 0: 398 | epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) 399 | steps_skipped = steps_trained_in_current_epoch 400 | steps_trained_in_current_epoch = 0 401 | rng_to_sync = True 402 | 403 | step = -1 404 | for step, inputs in enumerate(epoch_iterator): 405 | total_batched_samples += 1 406 | if rng_to_sync: 407 | self._load_rng_state(resume_from_checkpoint) 408 | rng_to_sync = False 409 | 410 | # Skip past any already trained steps if resuming training 411 | if steps_trained_in_current_epoch > 0: 412 | steps_trained_in_current_epoch -= 1 413 | if steps_trained_progress_bar is not None: 414 | steps_trained_progress_bar.update(1) 415 | if steps_trained_in_current_epoch == 0: 416 | self._load_rng_state(resume_from_checkpoint) 417 | continue 418 | elif steps_trained_progress_bar is not None: 419 | steps_trained_progress_bar.close() 420 | steps_trained_progress_bar = None 421 | 422 | if step % args.gradient_accumulation_steps == 0: 423 | self.control = self.callback_handler.on_step_begin(args, self.state, self.control) 424 | 425 | if ( 426 | (total_batched_samples % args.gradient_accumulation_steps != 0) 427 | and args.parallel_mode == ParallelMode.DISTRIBUTED 428 | and args._no_sync_in_gradient_accumulation 429 | and hasattr(model, "no_sync") 430 | ): 431 | # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. 
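# Note (added comment): unlike the stock Trainer, the training_step() override at the bottom of this class returns a tuple (loss, mlm_loss, sso_loss), which is why three values are unpacked below.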
432 | with model.no_sync(): 433 | tr_loss_step, mlm_loss_step, sso_loss_step = self.training_step(model, inputs) 434 | else: 435 | tr_loss_step, mlm_loss_step, sso_loss_step = self.training_step(model, inputs) 436 | 437 | if ( 438 | args.logging_nan_inf_filter 439 | and not is_torch_tpu_available() 440 | and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) 441 | ): 442 | # if loss is nan or inf simply add the average of previous logged losses 443 | tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) 444 | mlm_tr_loss += mlm_loss_step / (1 + self.state.global_step - self._globalstep_last_logged) 445 | sso_tr_loss += sso_loss_step / (1 + self.state.global_step - self._globalstep_last_logged) 446 | else: 447 | tr_loss += tr_loss_step 448 | mlm_tr_loss += mlm_loss_step 449 | sso_tr_loss += sso_loss_step 450 | 451 | self.current_flos += float(self.floating_point_ops(inputs)) 452 | 453 | # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps 454 | if self.deepspeed: 455 | self.deepspeed.step() 456 | 457 | if total_batched_samples % args.gradient_accumulation_steps == 0 or ( 458 | # last step in epoch but step is always smaller than gradient_accumulation_steps 459 | steps_in_epoch <= args.gradient_accumulation_steps 460 | and (step + 1) == steps_in_epoch 461 | ): 462 | # Gradient clipping 463 | if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: 464 | # deepspeed does its own clipping 465 | 466 | if self.do_grad_scaling: 467 | # Reduce gradients first for XLA 468 | if is_torch_tpu_available(): 469 | gradients = xm._fetch_gradients(self.optimizer) 470 | xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) 471 | # AMP: gradients need unscaling 472 | self.scaler.unscale_(self.optimizer) 473 | 474 | if is_sagemaker_mp_enabled() and args.fp16: 475 | self.optimizer.clip_master_grads(args.max_grad_norm) 476 | elif hasattr(self.optimizer, "clip_grad_norm"): 477 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 478 | self.optimizer.clip_grad_norm(args.max_grad_norm) 479 | elif hasattr(model, "clip_grad_norm_"): 480 | # Some models (like FullyShardedDDP) have a specific way to do gradient clipping 481 | model.clip_grad_norm_(args.max_grad_norm) 482 | else: 483 | # Revert to normal clipping otherwise, handling Apex or full precision 484 | nn.utils.clip_grad_norm_( 485 | amp.master_params(self.optimizer) if self.use_apex else model.parameters(), 486 | args.max_grad_norm, 487 | ) 488 | 489 | # Optimizer step 490 | optimizer_was_run = True 491 | if self.deepspeed: 492 | pass # called outside the loop 493 | elif is_torch_tpu_available(): 494 | if self.do_grad_scaling: 495 | self.scaler.step(self.optimizer) 496 | self.scaler.update() 497 | else: 498 | xm.optimizer_step(self.optimizer) 499 | elif self.do_grad_scaling: 500 | scale_before = self.scaler.get_scale() 501 | self.scaler.step(self.optimizer) 502 | self.scaler.update() 503 | scale_after = self.scaler.get_scale() 504 | optimizer_was_run = scale_before <= scale_after 505 | else: 506 | self.optimizer.step() 507 | 508 | if optimizer_was_run and not self.deepspeed: 509 | # Delay optimizer scheduling until metrics are generated 510 | if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 511 | self.lr_scheduler.step() 512 | 513 | model.zero_grad() 514 | self.state.global_step += 1 515 | self.state.epoch = epoch + (step + 1 + steps_skipped) / 
steps_in_epoch 516 | self.control = self.callback_handler.on_step_end(args, self.state, self.control) 517 | 518 | self._maybe_log_save_evaluate(tr_loss, mlm_tr_loss, sso_tr_loss, model, trial, epoch, ignore_keys_for_eval) 519 | else: 520 | self.control = self.callback_handler.on_substep_end(args, self.state, self.control) 521 | 522 | if self.control.should_epoch_stop or self.control.should_training_stop: 523 | break 524 | if step < 0: 525 | logger.warning( 526 | "There seems to be not a single sample in your epoch_iterator, stopping training at step" 527 | f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" 528 | f" num_steps ({max_steps}) higher than the number of available samples." 529 | ) 530 | self.control.should_training_stop = True 531 | 532 | self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) 533 | self._maybe_log_save_evaluate(tr_loss, mlm_tr_loss, sso_tr_loss, model, trial, epoch, ignore_keys_for_eval) 534 | 535 | if DebugOption.TPU_METRICS_DEBUG in self.args.debug: 536 | if is_torch_tpu_available(): 537 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 538 | xm.master_print(met.metrics_report()) 539 | else: 540 | logger.warning( 541 | "You enabled PyTorch/XLA debug metrics but you don't have a TPU " 542 | "configured. Check your training configuration if this is unexpected." 543 | ) 544 | if self.control.should_training_stop: 545 | break 546 | 547 | if args.past_index and hasattr(self, "_past"): 548 | # Clean the state at the end of training 549 | delattr(self, "_past") 550 | 551 | logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") 552 | if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: 553 | # Wait for everyone to get here so we are sur the model has been saved by process 0. 554 | if is_torch_tpu_available(): 555 | xm.rendezvous("load_best_model_at_end") 556 | elif args.parallel_mode == ParallelMode.DISTRIBUTED: 557 | dist.barrier() 558 | elif is_sagemaker_mp_enabled(): 559 | smp.barrier() 560 | 561 | self._load_best_model() 562 | 563 | # add remaining tr_loss 564 | self._total_loss_scalar += tr_loss.item() 565 | self._mlm_total_loss_scalar += mlm_tr_loss.item() 566 | self._sop_total_loss_scalar += sso_tr_loss.item() 567 | 568 | train_loss = self._total_loss_scalar / self.state.global_step 569 | train_mlm_loss = self._mlm_total_loss_scalar / self.state.global_step 570 | train_sso_loss = self._sop_total_loss_scalar / self.state.global_step 571 | 572 | metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) 573 | self.store_flos() 574 | metrics["total_flos"] = self.state.total_flos 575 | metrics["train_loss"] = train_loss 576 | metrics["train_mlm_loss"] = train_mlm_loss 577 | metrics["train_sso_loss"] = train_sso_loss 578 | 579 | self.is_in_train = False 580 | 581 | self._memory_tracker.stop_and_update_metrics(metrics) 582 | 583 | self.log(metrics) 584 | 585 | run_dir = self._get_output_dir(trial) 586 | checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) 587 | 588 | # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. 
589 | if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: 590 | for checkpoint in checkpoints_sorted: 591 | if checkpoint != self.state.best_model_checkpoint: 592 | logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") 593 | shutil.rmtree(checkpoint) 594 | 595 | self.control = self.callback_handler.on_train_end(args, self.state, self.control) 596 | 597 | return TrainOutput(self.state.global_step, train_loss, metrics) 598 | 599 | def _maybe_log_save_evaluate(self, tr_loss, mlm_tr_loss, sso_tr_loss, model, trial, epoch, ignore_keys_for_eval): 600 | if self.control.should_log: 601 | if is_torch_tpu_available(): 602 | xm.mark_step() 603 | 604 | logs: Dict[str, float] = {} 605 | 606 | # all_gather + mean() to get average loss over all processes 607 | tr_loss_scalar = self._nested_gather(tr_loss).mean().item() 608 | mlm_tr_loss_scalar = self._nested_gather(mlm_tr_loss).mean().item() 609 | sso_tr_loss_scalar = self._nested_gather(sso_tr_loss).mean().item() 610 | 611 | # reset tr_loss to zero 612 | tr_loss -= tr_loss 613 | mlm_tr_loss -= mlm_tr_loss 614 | sso_tr_loss -= sso_tr_loss 615 | 616 | logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) 617 | logs["mlm_loss"] = round(mlm_tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) 618 | logs["sso_loss"] = round(sso_tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) 619 | logs["learning_rate"] = self._get_learning_rate() 620 | 621 | self._total_loss_scalar += tr_loss_scalar 622 | self._mlm_total_loss_scalar += mlm_tr_loss_scalar 623 | self._sop_total_loss_scalar += sso_tr_loss_scalar 624 | self._globalstep_last_logged = self.state.global_step 625 | self.store_flos() 626 | 627 | self.log(logs) 628 | 629 | metrics = None 630 | if self.control.should_evaluate: 631 | if isinstance(self.eval_dataset, dict): 632 | metrics = {} 633 | for eval_dataset_name, eval_dataset in self.eval_dataset.items(): 634 | dataset_metrics = self.evaluate( 635 | eval_dataset=eval_dataset, 636 | ignore_keys=ignore_keys_for_eval, 637 | metric_key_prefix=f"eval_{eval_dataset_name}", 638 | ) 639 | metrics.update(dataset_metrics) 640 | else: 641 | metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) 642 | self._report_to_hp_search(trial, self.state.global_step, metrics) 643 | 644 | # Run delayed LR scheduler now that metrics are populated 645 | if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): 646 | self.lr_scheduler.step(metrics[self.args.metric_for_best_model]) 647 | 648 | if self.control.should_save: 649 | self._save_checkpoint(model, trial, metrics=metrics) 650 | self.control = self.callback_handler.on_save(self.args, self.state, self.control) 651 | 652 | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: 653 | """ 654 | Perform a training step on a batch of inputs. 655 | 656 | Subclass and override to inject custom behavior. 657 | 658 | Args: 659 | model (`nn.Module`): 660 | The model to train. 661 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 662 | The inputs and targets of the model. 663 | 664 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 665 | argument `labels`. Check your model's documentation for all accepted arguments. 666 | 667 | Return: 668 | `torch.Tensor`: The tensor with training loss on this batch. 
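Note: in this subclass the method actually returns a tuple `(loss, mlm_loss, sso_loss)`, where `mlm_loss` and `sso_loss` are detached tensors tracked for logging.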
669 | """ 670 | model.train() 671 | inputs = self._prepare_inputs(inputs) 672 | 673 | if is_sagemaker_mp_enabled(): 674 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) 675 | return loss_mb.reduce_mean().detach().to(self.args.device) 676 | 677 | with self.compute_loss_context_manager(): 678 | loss, outputs = self.compute_loss(model, inputs, return_outputs=True) 679 | 680 | mlm_loss = outputs['mlm_loss'].detach() 681 | sso_loss = outputs['sso_loss'].detach() 682 | 683 | if self.args.n_gpu > 1: 684 | loss = loss.mean() # mean() to average on multi-gpu parallel training 685 | mlm_loss = mlm_loss.mean() 686 | sso_loss = sso_loss.mean() 687 | 688 | if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: 689 | # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` 690 | loss = loss / self.args.gradient_accumulation_steps 691 | mlm_loss = mlm_loss / self.args.gradient_accumulation_steps 692 | sso_loss = sso_loss / self.args.gradient_accumulation_steps 693 | 694 | if self.do_grad_scaling: 695 | self.scaler.scale(loss).backward() 696 | elif self.use_apex: 697 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 698 | scaled_loss.backward() 699 | elif self.deepspeed: 700 | # loss gets scaled under gradient_accumulation_steps in deepspeed 701 | loss = self.deepspeed.backward(loss) 702 | else: 703 | loss.backward() 704 | 705 | return loss.detach(), mlm_loss, sso_loss 706 | -------------------------------------------------------------------------------- /extra/dataset_dict.py: -------------------------------------------------------------------------------- 1 | 2 | from datasets.dataset_dict import DatasetDict as oldDatasetDict 3 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 4 | from datasets.features import Features 5 | 6 | 7 | class DatasetDict(oldDatasetDict): 8 | # add new_fingerprints args 9 | def map( 10 | self, 11 | function, 12 | with_indices: bool = False, 13 | input_columns: Optional[Union[str, List[str]]] = None, 14 | batched: bool = False, 15 | batch_size: Optional[int] = 1000, 16 | remove_columns: Optional[List[str]] = None, 17 | keep_in_memory: bool = False, 18 | load_from_cache_file: bool = True, 19 | cache_file_names: Optional[Dict[str, Optional[str]]] = None, 20 | writer_batch_size: Optional[int] = 1000, 21 | features: Optional[Features] = None, 22 | disable_nullable: bool = False, 23 | fn_kwargs: Optional[dict] = None, 24 | num_proc: Optional[int] = None, 25 | new_fingerprints=None, 26 | ) -> "DatasetDict": 27 | """Apply a function to all the elements in the table (individually or in batches) 28 | and update the table (if function does updated examples). 29 | The transformation is applied to all the datasets of the dataset dictionary. 30 | 31 | Args: 32 | function (`callable`): with one of the following signature: 33 | - `function(example: Dict) -> Union[Dict, Any]` if `batched=False` and `with_indices=False` 34 | - `function(example: Dict, indices: int) -> Union[Dict, Any]` if `batched=False` and `with_indices=True` 35 | - `function(batch: Dict[List]) -> Union[Dict, Any]` if `batched=True` and `with_indices=False` 36 | - `function(batch: Dict[List], indices: List[int]) -> Union[Dict, Any]` if `batched=True` and `with_indices=True` 37 | with_indices (`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`. 
38 | input_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): The columns to be passed into `function` as 39 | positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument. 40 | batched (`bool`, defaults to `False`): Provide batch of examples to `function` 41 | batch_size (`Optional[int]`, defaults to `1000`): Number of examples per batch provided to `function` if `batched=True` 42 | `batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function` 43 | remove_columns (`Optional[List[str]]`, defaults to `None`): Remove a selection of columns while doing the mapping. 44 | Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding 45 | columns with names in `remove_columns`, these columns will be kept. 46 | keep_in_memory (`bool`, defaults to `False`): Keep the dataset in memory instead of writing it to a cache file. 47 | load_from_cache_file (`bool`, defaults to `True`): If a cache file storing the current computation from `function` 48 | can be identified, use it instead of recomputing. 49 | cache_file_names (`Optional[Dict[str, str]]`, defaults to `None`): Provide the name of a path for the cache file. It is used to store the 50 | results of the computation instead of the automatically generated cache file name. 51 | You have to provide one :obj:`cache_file_name` per dataset in the dataset dictionary. 52 | writer_batch_size (:obj:`int`, default `1000`): Number of rows per write operation for the cache file writer. 53 | This value is a good trade-off between memory usage during the processing, and processing speed. 54 | Higher value makes the processing do fewer lookups, lower value consumes less temporary memory while running `.map()`. 55 | features (`Optional[datasets.Features]`, defaults to `None`): Use a specific Features to store the cache file 56 | instead of the automatically generated one. 57 | disable_nullable (`bool`, defaults to `False`): Disallow null values in the table. 58 | fn_kwargs (`Optional[Dict]`, defaults to `None`): Keyword arguments to be passed to `function` 59 | num_proc (`Optional[int]`, defaults to `None`): Number of processes for multiprocessing. By default it doesn't 60 | use multiprocessing.
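new_fingerprints (`Optional[Dict[str, str]]`, defaults to `None`): New fingerprint to assign to each transformed dataset, forwarded to each `datasets.Dataset.map` call as `new_fingerprint`. You have to provide one fingerprint per dataset in the dataset dictionary.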
61 | """ 62 | self._check_values_type() 63 | if cache_file_names is None: 64 | cache_file_names = {k: None for k in self} 65 | if new_fingerprints is None: 66 | new_fingerprints = {k: None for k in self} 67 | return DatasetDict( 68 | { 69 | k: dataset.map( 70 | function=function, 71 | with_indices=with_indices, 72 | input_columns=input_columns, 73 | batched=batched, 74 | batch_size=batch_size, 75 | remove_columns=remove_columns, 76 | keep_in_memory=keep_in_memory, 77 | load_from_cache_file=load_from_cache_file, 78 | cache_file_name=cache_file_names[k], 79 | writer_batch_size=writer_batch_size, 80 | features=features, 81 | disable_nullable=disable_nullable, 82 | fn_kwargs=fn_kwargs, 83 | num_proc=num_proc, 84 | new_fingerprint=new_fingerprints[k] 85 | ) 86 | for k, dataset in self.items() 87 | } 88 | ) 89 | -------------------------------------------------------------------------------- /image/consumption.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxchtan/PoNet/cc9a29a3c7732e8c5cdda7eae32e556a2fa71b1d/image/consumption.png -------------------------------------------------------------------------------- /image/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxchtan/PoNet/cc9a29a3c7732e8c5cdda7eae32e556a2fa71b1d/image/model.png -------------------------------------------------------------------------------- /image/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxchtan/PoNet/cc9a29a3c7732e8c5cdda7eae32e556a2fa71b1d/image/performance.png -------------------------------------------------------------------------------- /metrics/macro_f1_and_acc/macro_f1_and_acc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """F1 metric.""" 16 | 17 | from sklearn.metrics import f1_score, accuracy_score 18 | 19 | import datasets 20 | 21 | 22 | _DESCRIPTION = """ 23 | The F1 score is the harmonic mean of the precision and recall. It can be computed with: 24 | F1 = 2 * (precision * recall) / (precision + recall) 25 | """ 26 | 27 | _KWARGS_DESCRIPTION = """ 28 | Args: 29 | predictions: Predicted labels, as returned by a model. 30 | references: Ground truth labels. 31 | labels: The set of labels to include when average != 'binary', and 32 | their order if average is None. Labels present in the data can 33 | be excluded, for example to calculate a multiclass average ignoring 34 | a majority negative class, while labels not present in the data will 35 | result in 0 components in a macro average. For multilabel targets, 36 | labels are column indices. By default, all labels in y_true and 37 | y_pred are used in sorted order. 
38 | average: This parameter is required for multiclass/multilabel targets. 39 | If None, the scores for each class are returned. Otherwise, this 40 | determines the type of averaging performed on the data: 41 | binary: Only report results for the class specified by pos_label. 42 | This is applicable only if targets (y_{true,pred}) are binary. 43 | micro: Calculate metrics globally by counting the total true positives, 44 | false negatives and false positives. 45 | macro: Calculate metrics for each label, and find their unweighted mean. 46 | This does not take label imbalance into account. 47 | weighted: Calculate metrics for each label, and find their average 48 | weighted by support (the number of true instances for each label). 49 | This alters ‘macro’ to account for label imbalance; it can result 50 | in an F-score that is not between precision and recall. 51 | samples: Calculate metrics for each instance, and find their average 52 | (only meaningful for multilabel classification). 53 | sample_weight: Sample weights. 54 | Returns: 55 | f1: F1 score. 56 | Examples: 57 | 58 | >>> f1_metric = datasets.load_metric("f1") 59 | >>> results = f1_metric.compute(references=[0, 1], predictions=[0, 1]) 60 | >>> print(results) 61 | {'f1': 1.0} 62 | """ 63 | 64 | _CITATION = """\ 65 | @article{scikit-learn, 66 | title={Scikit-learn: Machine Learning in {P}ython}, 67 | author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. 68 | and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. 69 | and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and 70 | Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.}, 71 | journal={Journal of Machine Learning Research}, 72 | volume={12}, 73 | pages={2825--2830}, 74 | year={2011} 75 | } 76 | """ 77 | 78 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 79 | class Micro_F1_and_Acc(datasets.Metric): 80 | def _info(self): 81 | return datasets.MetricInfo( 82 | description=_DESCRIPTION, 83 | citation=_CITATION, 84 | inputs_description=_KWARGS_DESCRIPTION, 85 | features=datasets.Features( 86 | { 87 | "predictions": datasets.Value("int64"), 88 | "references": datasets.Value("int64"), 89 | } 90 | ), 91 | reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"], 92 | ) 93 | 94 | def _compute(self, predictions, references, labels=None, pos_label=1, average="macro", sample_weight=None): 95 | return { 96 | "f1": f1_score( 97 | references, 98 | predictions, 99 | labels=labels, 100 | pos_label=pos_label, 101 | average=average, 102 | sample_weight=sample_weight, 103 | ), 104 | "accuracy": accuracy_score( 105 | references, 106 | predictions, 107 | ) 108 | } 109 | -------------------------------------------------------------------------------- /metrics/micro_f1_and_acc/micro_f1_and_acc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """F1 metric.""" 16 | 17 | from sklearn.metrics import f1_score, accuracy_score 18 | 19 | import datasets 20 | 21 | 22 | _DESCRIPTION = """ 23 | The F1 score is the harmonic mean of the precision and recall. It can be computed with: 24 | F1 = 2 * (precision * recall) / (precision + recall) 25 | """ 26 | 27 | _KWARGS_DESCRIPTION = """ 28 | Args: 29 | predictions: Predicted labels, as returned by a model. 30 | references: Ground truth labels. 31 | labels: The set of labels to include when average != 'binary', and 32 | their order if average is None. Labels present in the data can 33 | be excluded, for example to calculate a multiclass average ignoring 34 | a majority negative class, while labels not present in the data will 35 | result in 0 components in a macro average. For multilabel targets, 36 | labels are column indices. By default, all labels in y_true and 37 | y_pred are used in sorted order. 38 | average: This parameter is required for multiclass/multilabel targets. 39 | If None, the scores for each class are returned. Otherwise, this 40 | determines the type of averaging performed on the data: 41 | binary: Only report results for the class specified by pos_label. 42 | This is applicable only if targets (y_{true,pred}) are binary. 43 | micro: Calculate metrics globally by counting the total true positives, 44 | false negatives and false positives. 45 | macro: Calculate metrics for each label, and find their unweighted mean. 46 | This does not take label imbalance into account. 47 | weighted: Calculate metrics for each label, and find their average 48 | weighted by support (the number of true instances for each label). 49 | This alters ‘macro’ to account for label imbalance; it can result 50 | in an F-score that is not between precision and recall. 51 | samples: Calculate metrics for each instance, and find their average 52 | (only meaningful for multilabel classification). 53 | sample_weight: Sample weights. 54 | Returns: 55 | f1: F1 score. 56 | Examples: 57 | 58 | >>> f1_metric = datasets.load_metric("f1") 59 | >>> results = f1_metric.compute(references=[0, 1], predictions=[0, 1]) 60 | >>> print(results) 61 | {'f1': 1.0} 62 | """ 63 | 64 | _CITATION = """\ 65 | @article{scikit-learn, 66 | title={Scikit-learn: Machine Learning in {P}ython}, 67 | author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V. 68 | and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P. 69 | and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and 70 | Cournapeau, D. and Brucher, M. and Perrot, M. 
and Duchesnay, E.}, 71 | journal={Journal of Machine Learning Research}, 72 | volume={12}, 73 | pages={2825--2830}, 74 | year={2011} 75 | } 76 | """ 77 | 78 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) 79 | class Micro_F1_and_Acc(datasets.Metric): 80 | def _info(self): 81 | return datasets.MetricInfo( 82 | description=_DESCRIPTION, 83 | citation=_CITATION, 84 | inputs_description=_KWARGS_DESCRIPTION, 85 | features=datasets.Features( 86 | { 87 | "predictions": datasets.Value("int64"), 88 | "references": datasets.Value("int64"), 89 | } 90 | ), 91 | reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"], 92 | ) 93 | 94 | def _compute(self, predictions, references, labels=None, pos_label=1, average="micro", sample_weight=None): 95 | return { 96 | "f1": f1_score( 97 | references, 98 | predictions, 99 | labels=labels, 100 | pos_label=pos_label, 101 | average=average, 102 | sample_weight=sample_weight, 103 | ), 104 | "accuracy": accuracy_score( 105 | references, 106 | predictions, 107 | ) 108 | } 109 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch == 2.0.1 2 | transformers == 4.29.1 3 | datasets >= 1.1.3 4 | sentencepiece != 0.1.92 5 | protobuf 6 | tensorboard 7 | fairscale 8 | scipy 9 | scikit-learn 10 | accelerate -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 18 | import logging 19 | import os 20 | import random 21 | import sys 22 | from dataclasses import dataclass, field 23 | from typing import Optional 24 | 25 | import numpy as np 26 | from datasets import load_dataset, load_metric 27 | 28 | import transformers 29 | from transformers import ( 30 | AutoConfig, 31 | AutoModelForSequenceClassification, 32 | AutoTokenizer, 33 | DataCollatorWithPadding, 34 | EvalPrediction, 35 | HfArgumentParser, 36 | PretrainedConfig, 37 | Trainer, 38 | TrainingArguments, 39 | default_data_collator, 40 | set_seed, 41 | ) 42 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 43 | from transformers.utils import check_min_version 44 | 45 | 46 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
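# Note: requirements.txt (above) pins transformers == 4.29.1, which satisfies this 4.21.3 minimum;
# the check below only guards against running the script with an older local installation.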
47 | check_min_version("4.21.3") 48 | 49 | task_to_keys = { 50 | "cola": ("sentence", None), 51 | "mnli": ("premise", "hypothesis"), 52 | "mrpc": ("sentence1", "sentence2"), 53 | "qnli": ("question", "sentence"), 54 | "qqp": ("question1", "question2"), 55 | "rte": ("sentence1", "sentence2"), 56 | "sst2": ("sentence", None), 57 | "stsb": ("sentence1", "sentence2"), 58 | "wnli": ("sentence1", "sentence2"), 59 | } 60 | 61 | logger = logging.getLogger(__name__) 62 | 63 | 64 | @dataclass 65 | class DataTrainingArguments: 66 | """ 67 | Arguments pertaining to what data we are going to input our model for training and eval. 68 | 69 | Using `HfArgumentParser` we can turn this class 70 | into argparse arguments to be able to specify them on 71 | the command line. 72 | """ 73 | 74 | task_name: Optional[str] = field( 75 | default=None, 76 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 77 | ) 78 | max_seq_length: int = field( 79 | default=128, 80 | metadata={ 81 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 82 | "than this will be truncated, sequences shorter will be padded." 83 | }, 84 | ) 85 | overwrite_cache: bool = field( 86 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 87 | ) 88 | pad_to_max_length: bool = field( 89 | default=True, 90 | metadata={ 91 | "help": "Whether to pad all samples to `max_seq_length`. " 92 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 93 | }, 94 | ) 95 | max_train_samples: Optional[int] = field( 96 | default=None, 97 | metadata={ 98 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 99 | "value if set." 100 | }, 101 | ) 102 | max_eval_samples: Optional[int] = field( 103 | default=None, 104 | metadata={ 105 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 106 | "value if set." 107 | }, 108 | ) 109 | max_predict_samples: Optional[int] = field( 110 | default=None, 111 | metadata={ 112 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 113 | "value if set." 114 | }, 115 | ) 116 | train_file: Optional[str] = field( 117 | default=None, metadata={"help": "A csv or a json file containing the training data."} 118 | ) 119 | validation_file: Optional[str] = field( 120 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 121 | ) 122 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 123 | 124 | def __post_init__(self): 125 | if self.task_name is not None: 126 | self.task_name = self.task_name.lower() 127 | if self.task_name not in task_to_keys.keys(): 128 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys())) 129 | elif self.train_file is None or self.validation_file is None: 130 | raise ValueError("Need either a GLUE task or a training/validation file.") 131 | else: 132 | train_extension = self.train_file.split(".")[-1] 133 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 134 | validation_extension = self.validation_file.split(".")[-1] 135 | assert ( 136 | validation_extension == train_extension 137 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 
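# Illustrative usage (hypothetical values; any GLUE task key from task_to_keys works the same way):
#
#   python run_glue.py \
#     --model_name_or_path <path-or-hub-id> \
#     --task_name sst2 \
#     --max_seq_length 128 \
#     --do_train --do_eval \
#     --output_dir <output-dir>
#
# HfArgumentParser turns every field of the dataclasses above (and of TrainingArguments) into a
# command-line flag; alternatively, a single JSON file with the same keys can be passed and is
# handled by the parse_json_file branch in main() below.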
138 | 139 | 140 | @dataclass 141 | class ModelArguments: 142 | """ 143 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 144 | """ 145 | 146 | model_name_or_path: str = field( 147 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 148 | ) 149 | config_name: Optional[str] = field( 150 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 151 | ) 152 | tokenizer_name: Optional[str] = field( 153 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 154 | ) 155 | cache_dir: Optional[str] = field( 156 | default=None, 157 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 158 | ) 159 | use_fast_tokenizer: bool = field( 160 | default=True, 161 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 162 | ) 163 | model_revision: str = field( 164 | default="main", 165 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 166 | ) 167 | use_auth_token: bool = field( 168 | default=False, 169 | metadata={ 170 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 171 | "with private models)." 172 | }, 173 | ) 174 | 175 | def main(): 176 | # See all possible arguments in src/transformers/training_args.py 177 | # or by passing the --help flag to this script. 178 | # We now keep distinct sets of args, for a cleaner separation of concerns. 179 | 180 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 181 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 182 | # If we pass only one argument to the script and it's the path to a json file, 183 | # let's parse it to get our arguments. 184 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 185 | else: 186 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 187 | 188 | # Detecting last checkpoint. 189 | last_checkpoint = None 190 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 191 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 192 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 193 | raise ValueError( 194 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 195 | "Use --overwrite_output_dir to overcome." 196 | ) 197 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 198 | logger.info( 199 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 200 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
201 | ) 202 | 203 | # Setup logging 204 | logging.basicConfig( 205 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 206 | datefmt="%m/%d/%Y %H:%M:%S", 207 | handlers=[logging.StreamHandler(sys.stdout)], 208 | ) 209 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 210 | 211 | # Log on each process the small summary: 212 | logger.warning( 213 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 214 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 215 | ) 216 | # Set the verbosity to info of the Transformers logger (on main process only): 217 | if is_main_process(training_args.local_rank): 218 | transformers.utils.logging.set_verbosity_info() 219 | transformers.utils.logging.enable_default_handler() 220 | transformers.utils.logging.enable_explicit_format() 221 | logger.info(f"Training/evaluation parameters {training_args}") 222 | 223 | # Set seed before initializing model. 224 | set_seed(training_args.seed) 225 | 226 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 227 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 228 | # 229 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 230 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 231 | # label if at least two columns are provided. 232 | # 233 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 234 | # single column. You can easily tweak this behavior (see below) 235 | # 236 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 237 | # download the dataset. 238 | if data_args.task_name is not None: 239 | # Downloading and loading a dataset from the hub. 240 | datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir) 241 | else: 242 | # Loading a dataset from your local files. 243 | # CSV/JSON training and evaluation files are needed. 244 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 245 | 246 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 247 | # when you use `do_predict` without specifying a GLUE benchmark task. 248 | if training_args.do_predict: 249 | if data_args.test_file is not None: 250 | train_extension = data_args.train_file.split(".")[-1] 251 | test_extension = data_args.test_file.split(".")[-1] 252 | assert ( 253 | test_extension == train_extension 254 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 
255 | data_files["test"] = data_args.test_file 256 | else: 257 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 258 | 259 | for key in data_files.keys(): 260 | logger.info(f"load a local file for {key}: {data_files[key]}") 261 | 262 | if data_args.train_file.endswith(".csv"): 263 | # Loading a dataset from local csv files 264 | datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir) 265 | else: 266 | # Loading a dataset from local json files 267 | datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir) 268 | # See more about loading any type of standard or custom dataset at 269 | # https://huggingface.co/docs/datasets/loading_datasets.html. 270 | 271 | # Labels 272 | if data_args.task_name is not None: 273 | is_regression = data_args.task_name == "stsb" 274 | if not is_regression: 275 | label_list = datasets["train"].features["label"].names 276 | num_labels = len(label_list) 277 | else: 278 | num_labels = 1 279 | else: 280 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 281 | is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"] 282 | if is_regression: 283 | num_labels = 1 284 | else: 285 | # A useful fast method: 286 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 287 | label_list = datasets["train"].unique("label") 288 | label_list.sort() # Let's sort it for determinism 289 | num_labels = len(label_list) 290 | 291 | # Load pretrained model and tokenizer 292 | # 293 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 294 | # download model & vocab. 295 | config = AutoConfig.from_pretrained( 296 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 297 | num_labels=num_labels, 298 | finetuning_task=data_args.task_name, 299 | cache_dir=model_args.cache_dir, 300 | revision=model_args.model_revision, 301 | use_auth_token=True if model_args.use_auth_token else None, 302 | trust_remote_code=True 303 | ) 304 | 305 | tokenizer = AutoTokenizer.from_pretrained( 306 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 307 | cache_dir=model_args.cache_dir, 308 | use_fast=model_args.use_fast_tokenizer, 309 | revision=model_args.model_revision, 310 | use_auth_token=True if model_args.use_auth_token else None, 311 | trust_remote_code=True 312 | ) 313 | 314 | model = AutoModelForSequenceClassification.from_pretrained( 315 | model_args.model_name_or_path, 316 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 317 | config=config, 318 | cache_dir=model_args.cache_dir, 319 | revision=model_args.model_revision, 320 | use_auth_token=True if model_args.use_auth_token else None, 321 | trust_remote_code=True 322 | ) 323 | 324 | # Preprocessing the datasets 325 | if data_args.task_name is not None: 326 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 327 | else: 328 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 
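# For illustration (hypothetical column names): a CSV with columns ["premise", "hypothesis", "label"]
# has no "sentence1"/"sentence2" pair, so the len(non_label_column_names) >= 2 branch below selects
# sentence1_key="premise", sentence2_key="hypothesis"; a single-text file with columns ["text", "label"]
# falls through to sentence1_key="text", sentence2_key=None.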
329 | non_label_column_names = [name for name in datasets["train"].column_names if name != "label"] 330 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 331 | sentence1_key, sentence2_key = "sentence1", "sentence2" 332 | else: 333 | if len(non_label_column_names) >= 2: 334 | sentence1_key, sentence2_key = non_label_column_names[:2] 335 | else: 336 | sentence1_key, sentence2_key = non_label_column_names[0], None 337 | 338 | # Padding strategy 339 | if data_args.pad_to_max_length: 340 | padding = "max_length" 341 | else: 342 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 343 | padding = False 344 | 345 | # Some models have set the order of the labels to use, so let's make sure we do use it. 346 | label_to_id = None 347 | if ( 348 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 349 | and data_args.task_name is not None 350 | and not is_regression 351 | ): 352 | # Some have all caps in their config, some don't. 353 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 354 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 355 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} 356 | else: 357 | logger.warning( 358 | "Your model seems to have been trained with labels, but they don't match the dataset: " 359 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 360 | "\nIgnoring the model labels as a result.", 361 | ) 362 | elif data_args.task_name is None and not is_regression: 363 | label_to_id = {v: i for i, v in enumerate(label_list)} 364 | 365 | if data_args.max_seq_length > tokenizer.model_max_length: 366 | logger.warning( 367 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the " 368 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
369 | ) 370 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 371 | 372 | def preprocess_function(examples): 373 | # Tokenize the texts 374 | args = ( 375 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 376 | ) 377 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 378 | 379 | # Map labels to IDs (not necessary for GLUE tasks) 380 | if label_to_id is not None and "label" in examples: 381 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 382 | return result 383 | 384 | datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache) 385 | if training_args.do_train: 386 | if "train" not in datasets: 387 | raise ValueError("--do_train requires a train dataset") 388 | train_dataset = datasets["train"] 389 | if data_args.max_train_samples is not None: 390 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 391 | 392 | if training_args.do_eval: 393 | if "validation" not in datasets and "validation_matched" not in datasets: 394 | raise ValueError("--do_eval requires a validation dataset") 395 | eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] 396 | if data_args.max_eval_samples is not None: 397 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 398 | 399 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: 400 | if "test" not in datasets and "test_matched" not in datasets: 401 | raise ValueError("--do_predict requires a test dataset") 402 | predict_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] 403 | if data_args.max_predict_samples is not None: 404 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 405 | 406 | # Log a few random samples from the training set: 407 | if training_args.do_train: 408 | for index in random.sample(range(len(train_dataset)), 3): 409 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 410 | 411 | # Get the metric function 412 | if data_args.task_name is not None: 413 | # metric = load_metric("../huggingface/datasets/metrics/glue", data_args.task_name) 414 | metric = load_metric('glue', data_args.task_name) 415 | # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from 416 | # compute_metrics 417 | 418 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 419 | # predictions and label_ids field) and has to return a dictionary string to float. 420 | def compute_metrics(p: EvalPrediction): 421 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 422 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 423 | print("percent: ", preds.sum()/preds.shape[0]) 424 | if data_args.task_name is not None: 425 | result = metric.compute(predictions=preds, references=p.label_ids) 426 | if len(result) > 1: 427 | result["combined_score"] = np.mean(list(result.values())).item() 428 | return result 429 | elif is_regression: 430 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 431 | else: 432 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 433 | 434 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 
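# In short: with pad_to_max_length every feature already has a fixed length, so the lightweight
# default_data_collator is enough; under fp16, padding each batch to a multiple of 8 keeps sequence
# lengths friendly to tensor cores; otherwise data_collator=None lets Trainer fall back to
# DataCollatorWithPadding(tokenizer) for dynamic per-batch padding.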
435 | if data_args.pad_to_max_length: 436 | data_collator = default_data_collator 437 | elif training_args.fp16: 438 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 439 | else: 440 | data_collator = None 441 | 442 | # Initialize our Trainer 443 | trainer = Trainer( 444 | model=model, 445 | args=training_args, 446 | train_dataset=train_dataset if training_args.do_train else None, 447 | eval_dataset=eval_dataset if training_args.do_eval else None, 448 | compute_metrics=compute_metrics, 449 | tokenizer=tokenizer, 450 | data_collator=data_collator, 451 | ) 452 | 453 | # Training 454 | if training_args.do_train: 455 | checkpoint = None 456 | if training_args.resume_from_checkpoint is not None: 457 | checkpoint = training_args.resume_from_checkpoint 458 | elif last_checkpoint is not None: 459 | checkpoint = last_checkpoint 460 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 461 | metrics = train_result.metrics 462 | max_train_samples = ( 463 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 464 | ) 465 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 466 | 467 | trainer.save_model() # Saves the tokenizer too for easy upload 468 | 469 | trainer.log_metrics("train", metrics) 470 | trainer.save_metrics("train", metrics) 471 | trainer.save_state() 472 | 473 | # Evaluation 474 | if training_args.do_eval: 475 | logger.info("*** Evaluate ***") 476 | 477 | # Loop to handle MNLI double evaluation (matched, mis-matched) 478 | tasks = [data_args.task_name] 479 | eval_datasets = [eval_dataset] 480 | if data_args.task_name == "mnli": 481 | tasks.append("mnli-mm") 482 | eval_datasets.append(datasets["validation_mismatched"]) 483 | 484 | for eval_dataset, task in zip(eval_datasets, tasks): 485 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 486 | 487 | max_eval_samples = ( 488 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 489 | ) 490 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 491 | 492 | trainer.log_metrics(f"eval-{task}", metrics) 493 | trainer.save_metrics(f"eval-{task}", metrics) 494 | 495 | if training_args.do_predict: 496 | logger.info("*** Predict ***") 497 | 498 | # Loop to handle MNLI double evaluation (matched, mis-matched) 499 | tasks = [data_args.task_name] 500 | predict_datasets = [predict_dataset] 501 | if data_args.task_name == "mnli": 502 | tasks.append("mnli-mm") 503 | predict_datasets.append(datasets["test_mismatched"]) 504 | 505 | for predict_dataset, task in zip(predict_datasets, tasks): 506 | # Removing the `label` columns because it contains -1 and Trainer won't like that. 
507 | predict_dataset = predict_dataset.remove_columns("label") 508 | predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions 509 | predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) 510 | 511 | output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt") 512 | if trainer.is_world_process_zero(): 513 | with open(output_predict_file, "w") as writer: 514 | logger.info(f"***** Predict results {task} *****") 515 | writer.write("index\tprediction\n") 516 | for index, item in enumerate(predictions): 517 | if is_regression: 518 | writer.write(f"{index}\t{item:3.3f}\n") 519 | else: 520 | item = label_list[item] 521 | writer.write(f"{index}\t{item}\n") 522 | 523 | if training_args.push_to_hub: 524 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "text-classification"} 525 | if data_args.task_name is not None: 526 | kwargs["language"] = "en" 527 | kwargs["dataset_tags"] = "glue" 528 | kwargs["dataset_args"] = data_args.task_name 529 | kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}" 530 | 531 | trainer.push_to_hub(**kwargs) 532 | 533 | 534 | def _mp_fn(index): 535 | # For xla_spawn (TPUs) 536 | main() 537 | 538 | 539 | if __name__ == "__main__": 540 | main() -------------------------------------------------------------------------------- /run_long_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on long-text tasks (arxiv, imdb, yelp, hnd).""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 18 | 19 | import torch 20 | import logging 21 | import os 22 | import random 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import numpy as np 28 | from datasets import load_dataset, load_metric 29 | 30 | import transformers 31 | from transformers import ( 32 | AutoConfig, 33 | AutoModelForSequenceClassification, 34 | AutoTokenizer, 35 | DataCollatorWithPadding, 36 | EvalPrediction, 37 | HfArgumentParser, 38 | PretrainedConfig, 39 | Trainer, 40 | TrainingArguments, 41 | default_data_collator, 42 | set_seed, 43 | ) 44 | 45 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 46 | from transformers.utils import check_min_version 47 | from nltk.tokenize import sent_tokenize 48 | 49 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
50 | check_min_version("4.21.3") 51 | 52 | task_to_metrics = { 53 | "arxiv": ("metrics/micro_f1_and_acc", None), 54 | "imdb": ("glue", "mrpc"), 55 | "yelp": ("metrics/micro_f1_and_acc", None), 56 | "hnd": ("glue", "mrpc"), 57 | } 58 | 59 | task_to_datasets = { 60 | "arxiv": {'path': 'datasets/arxiv-11'}, 61 | "imdb": {'path': 'imdb'}, 62 | "yelp": {'path': 'yelp_review_full'}, 63 | "hnd": { 64 | 'path': 'json', 65 | 'data_files': { 66 | 'train': ['datasets/hyperpartisan_news_detection/train.jsonl'], 67 | 'validation': ['datasets/hyperpartisan_news_detection/dev.jsonl'], 68 | 'test': ['datasets/hyperpartisan_news_detection/test.jsonl'], 69 | } 70 | }, 71 | } 72 | 73 | logger = logging.getLogger(__name__) 74 | 75 | 76 | @dataclass 77 | class DataTrainingArguments: 78 | """ 79 | Arguments pertaining to what data we are going to input our model for training and eval. 80 | 81 | Using `HfArgumentParser` we can turn this class 82 | into argparse arguments to be able to specify them on 83 | the command line. 84 | """ 85 | 86 | task_name: str = field( 87 | default=None, 88 | metadata={"help": "The name of the task to train."}, 89 | ) 90 | max_seq_length: int = field( 91 | default=128, 92 | metadata={ 93 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 94 | "than this will be truncated, sequences shorter will be padded." 95 | }, 96 | ) 97 | overwrite_cache: bool = field( 98 | default=False, 99 | metadata={"help": "Overwrite the cached preprocessed datasets or not."} 100 | ) 101 | pad_to_max_length: bool = field( 102 | default=False, 103 | metadata={ 104 | "help": "Whether to pad all samples to `max_seq_length`. " 105 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 106 | }, 107 | ) 108 | max_train_samples: Optional[int] = field( 109 | default=None, 110 | metadata={ 111 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 112 | "value if set." 113 | }, 114 | ) 115 | max_eval_samples: Optional[int] = field( 116 | default=None, 117 | metadata={ 118 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 119 | "value if set." 120 | }, 121 | ) 122 | max_predict_samples: Optional[int] = field( 123 | default=None, 124 | metadata={ 125 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 126 | "value if set." 127 | }, 128 | ) 129 | train_file: Optional[str] = field( 130 | default=None, metadata={"help": "A csv or a json file containing the training data."} 131 | ) 132 | validation_file: Optional[str] = field( 133 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 134 | ) 135 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."}) 136 | 137 | @dataclass 138 | class ModelArguments: 139 | """ 140 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
141 | """ 142 | 143 | model_name_or_path: str = field( 144 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 145 | ) 146 | config_name: Optional[str] = field( 147 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 148 | ) 149 | tokenizer_name: Optional[str] = field( 150 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 151 | ) 152 | cache_dir: Optional[str] = field( 153 | default=None, 154 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 155 | ) 156 | use_fast_tokenizer: bool = field( 157 | default=True, 158 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 159 | ) 160 | model_revision: str = field( 161 | default="main", 162 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 163 | ) 164 | use_auth_token: bool = field( 165 | default=False, 166 | metadata={ 167 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 168 | "with private models)." 169 | }, 170 | ) 171 | gc: bool = field( 172 | default=True, 173 | metadata={ 174 | "help": "Use gradient checkpointing." 175 | }, 176 | ) 177 | 178 | def main(): 179 | # See all possible arguments in src/transformers/training_args.py 180 | # or by passing the --help flag to this script. 181 | # We now keep distinct sets of args, for a cleaner separation of concerns. 182 | 183 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 184 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 185 | # If we pass only one argument to the script and it's the path to a json file, 186 | # let's parse it to get our arguments. 187 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 188 | else: 189 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 190 | 191 | # Detecting last checkpoint. 192 | last_checkpoint = None 193 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 194 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 195 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 196 | raise ValueError( 197 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 198 | "Use --overwrite_output_dir to overcome." 199 | ) 200 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 201 | logger.info( 202 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 203 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
204 | ) 205 | 206 | # Setup logging 207 | logging.basicConfig( 208 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 209 | datefmt="%m/%d/%Y %H:%M:%S", 210 | handlers=[logging.StreamHandler(sys.stdout)], 211 | ) 212 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 213 | 214 | # Log on each process the small summary: 215 | logger.warning( 216 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 217 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 218 | ) 219 | # Set the verbosity to info of the Transformers logger (on main process only): 220 | if is_main_process(training_args.local_rank): 221 | transformers.utils.logging.set_verbosity_info() 222 | transformers.utils.logging.enable_default_handler() 223 | transformers.utils.logging.enable_explicit_format() 224 | logger.info(f"Training/evaluation parameters {training_args}") 225 | 226 | # Set seed before initializing model. 227 | set_seed(training_args.seed) 228 | 229 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 230 | # download the dataset. 231 | datasets = load_dataset( 232 | **task_to_datasets[data_args.task_name], 233 | cache_dir=model_args.cache_dir 234 | ) 235 | 236 | data_args.validation_split_percentage = 10 237 | if "validation" not in datasets.keys(): 238 | if data_args.task_name == 'imdb': 239 | # test set is large enough -> ensure similar size and distribution between val/test datasets 240 | datasets["validation"] = datasets['test'] 241 | else: 242 | datasets["validation"] = load_dataset( 243 | **task_to_datasets[data_args.task_name], 244 | split=f"train[:{data_args.validation_split_percentage}%]", 245 | cache_dir=model_args.cache_dir, 246 | ) 247 | datasets["train"] = load_dataset( 248 | **task_to_datasets[data_args.task_name], 249 | split=f"train[{data_args.validation_split_percentage+10}%:]", 250 | cache_dir=model_args.cache_dir, 251 | ) 252 | 253 | # Labels 254 | label_list = datasets["train"].unique("label") 255 | label_list.sort() # Let's sort it for determinism 256 | num_labels = len(label_list) 257 | is_regression = False 258 | 259 | # Load pretrained model and tokenizer 260 | # 261 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 262 | # download model & vocab. 
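# Note: trust_remote_code=True lets AutoConfig/AutoTokenizer/AutoModelForSequenceClassification load
# the custom PoNet classes shipped with the checkpoint rather than requiring the architecture to be
# bundled with the transformers release; the same kwargs are passed to all three from_pretrained calls below.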
263 | config = AutoConfig.from_pretrained( 264 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 265 | num_labels=num_labels, 266 | finetuning_task=data_args.task_name, 267 | cache_dir=model_args.cache_dir, 268 | revision=model_args.model_revision, 269 | use_auth_token=True if model_args.use_auth_token else None, 270 | trust_remote_code=True 271 | ) 272 | tokenizer = AutoTokenizer.from_pretrained( 273 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 274 | cache_dir=model_args.cache_dir, 275 | use_fast=model_args.use_fast_tokenizer, 276 | revision=model_args.model_revision, 277 | use_auth_token=True if model_args.use_auth_token else None, 278 | trust_remote_code=True 279 | ) 280 | 281 | 282 | if model_args.gc: 283 | config.gradient_checkpointing = True 284 | config.use_cache = False 285 | 286 | model = AutoModelForSequenceClassification.from_pretrained( 287 | model_args.model_name_or_path, 288 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 289 | config=config, 290 | cache_dir=model_args.cache_dir, 291 | revision=model_args.model_revision, 292 | use_auth_token=True if model_args.use_auth_token else None, 293 | trust_remote_code=True 294 | ) 295 | 296 | # extend position embeddings 297 | if data_args.max_seq_length > tokenizer.model_max_length: 298 | logger.warning( 299 | "Copying the position embedding due to " 300 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the " 301 | f"model ({tokenizer.model_max_length})." 302 | ) 303 | max_pos = data_args.max_seq_length 304 | config.max_position_embeddings = max_pos 305 | tokenizer.model_max_length = max_pos 306 | tokenizer.init_kwargs['model_max_length'] = tokenizer.model_max_length 307 | current_max_pos, embed_size = model.ponet.embeddings.position_embeddings.weight.shape 308 | assert max_pos > current_max_pos 309 | # allocate a larger position embedding matrix 310 | new_pos_embed = model.ponet.embeddings.position_embeddings.weight.new_empty(max_pos, embed_size) 311 | # copy position embeddings over and over to initialize the new position embeddings 312 | k = 0 313 | step = current_max_pos 314 | while k < max_pos - 1: 315 | new_pos_embed[k:(k + step)] = model.ponet.embeddings.position_embeddings.weight[:] 316 | k += step 317 | model.ponet.embeddings.position_embeddings.weight.data = new_pos_embed 318 | model.ponet.embeddings.position_ids.data = torch.tensor([i for i in range(max_pos)]).reshape(1, max_pos) 319 | 320 | # Preprocessing the datasets 321 | sentence1_key, sentence2_key = "text", None 322 | 323 | # Padding strategy 324 | if data_args.pad_to_max_length: 325 | padding = "max_length" 326 | else: 327 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 328 | padding = False 329 | 330 | # Some models have set the order of the labels to use, so let's make sure we do use it. 331 | label_to_id = None 332 | if ( 333 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 334 | and data_args.task_name is not None 335 | and not is_regression 336 | ): 337 | # Some have all caps in their config, some don't. 
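# Hypothetical example: a checkpoint whose config has label2id = {"NEGATIVE": 0, "POSITIVE": 1} is
# lower-cased to {"negative": 0, "positive": 1}; if those keys match the dataset's label_list, the
# mapping is reused so the dataset labels line up with the classifier head the model was trained with,
# otherwise it is ignored with a warning.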
338 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 339 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 340 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} 341 | else: 342 | logger.warning( 343 | "Your model seems to have been trained with labels, but they don't match the dataset: " 344 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 345 | "\nIgnoring the model labels as a result.", 346 | ) 347 | elif data_args.task_name is None and not is_regression: 348 | label_to_id = {v: i for i, v in enumerate(label_list)} 349 | 350 | if data_args.max_seq_length > tokenizer.model_max_length: 351 | logger.warning( 352 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the " 353 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 354 | ) 355 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 356 | 357 | def preprocess_function_arxiv(examples): 358 | # Tokenize the texts 359 | segment_ids = [] 360 | args = ( 361 | ([ex.replace('\n\n', ' ').replace('\n', ' ') for ex in examples[sentence1_key]],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 362 | ) 363 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 364 | for ex in examples[sentence1_key]: 365 | seg_lens = list(map(len, tokenizer([eex.replace('\n', ' ') for eex in ex.split('\n\n')], add_special_tokens=False, max_length=max_seq_length, truncation=True)['input_ids'])) 366 | segment_id = [0] + sum([[i]*sl for i, sl in enumerate(seg_lens, start=1)], []) 367 | segment_id = segment_id[:max_seq_length-1] 368 | segment_ids.append(segment_id + [segment_id[-1]+1]) 369 | result["segment_ids"] = segment_ids 370 | 371 | # Map labels to IDs 372 | if label_to_id is not None and "label" in examples: 373 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 374 | return result 375 | 376 | def preprocess_function(examples): 377 | # Tokenize the texts 378 | segment_ids = [] 379 | args = ( 380 | (examples[sentence1_key], ) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 381 | ) 382 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 383 | for ex in examples[sentence1_key]: 384 | seg_lens = list(map(len, tokenizer(sent_tokenize(ex), add_special_tokens=False, max_length=max_seq_length, truncation=True)['input_ids'])) 385 | segment_id = [0] + sum([[i]*sl for i, sl in enumerate(seg_lens, start=1)], []) 386 | segment_id = segment_id[:max_seq_length-1] 387 | segment_ids.append(segment_id + [segment_id[-1]+1]) 388 | result["segment_ids"] = segment_ids 389 | 390 | # Map labels to IDs 391 | if label_to_id is not None and "label" in examples: 392 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 393 | return result 394 | 395 | datasets = datasets.map( 396 | preprocess_function if data_args.task_name != 'arxiv' else preprocess_function_arxiv, 397 | batched=True, 398 | load_from_cache_file=not data_args.overwrite_cache 399 | ) 400 | 401 | # raise RuntimeError 402 | if training_args.do_train: 403 | if "train" not in datasets: 404 | raise ValueError("--do_train requires a train dataset") 405 | # train_dataset = datasets["train"].shuffle(seed=training_args.seed) 406 | train_dataset = datasets["train"] 407 | if
data_args.max_train_samples is not None: 408 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 409 | 410 | if training_args.do_eval: 411 | if "validation" not in datasets: 412 | raise ValueError("--do_eval requires a validation dataset") 413 | eval_dataset = datasets["validation"] 414 | if data_args.max_eval_samples is not None: 415 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 416 | 417 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: 418 | if "test" not in datasets: 419 | raise ValueError("--do_predict requires a test dataset") 420 | predict_dataset = datasets["test"] 421 | if data_args.max_predict_samples is not None: 422 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 423 | 424 | # Log a few random samples from the training set: 425 | if training_args.do_train: 426 | for index in random.sample(range(len(train_dataset)), 3): 427 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 428 | 429 | # Get the metric function 430 | metric = load_metric(*task_to_metrics[data_args.task_name]) 431 | # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from 432 | # compute_metrics 433 | 434 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 435 | # predictions and label_ids field) and has to return a dictionary string to float. 436 | def compute_metrics(p: EvalPrediction): 437 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 438 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 439 | result = metric.compute(predictions=preds, references=p.label_ids) 440 | if len(result) > 1: 441 | result["combined_score"] = np.mean(list(result.values())).item() 442 | return result 443 | 444 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 
445 | if data_args.pad_to_max_length: 446 | data_collator = default_data_collator 447 | elif training_args.fp16: 448 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 449 | else: 450 | data_collator = None 451 | 452 | # Initialize our Trainer 453 | trainer = Trainer( 454 | model=model, 455 | args=training_args, 456 | train_dataset=train_dataset if training_args.do_train else None, 457 | eval_dataset=eval_dataset if training_args.do_eval else None, 458 | compute_metrics=compute_metrics, 459 | tokenizer=tokenizer, 460 | data_collator=data_collator, 461 | ) 462 | 463 | # Training 464 | if training_args.do_train: 465 | checkpoint = None 466 | # if training_args.resume_from_checkpoint is not None: 467 | # checkpoint = training_args.resume_from_checkpoint 468 | # elif last_checkpoint is not None: 469 | # checkpoint = last_checkpoint 470 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 471 | metrics = train_result.metrics 472 | max_train_samples = ( 473 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 474 | ) 475 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 476 | 477 | # trainer.save_model() # Saves the tokenizer too for easy upload 478 | 479 | trainer.log_metrics("train", metrics) 480 | trainer.save_metrics("train", metrics) 481 | trainer.save_state() 482 | 483 | # Evaluation 484 | if training_args.do_eval: 485 | logger.info("*** Evaluate ***") 486 | 487 | task = data_args.task_name 488 | metrics = trainer.evaluate(eval_dataset=eval_dataset) 489 | 490 | max_eval_samples = ( 491 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 492 | ) 493 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 494 | 495 | trainer.log_metrics(f"eval-{task}", metrics) 496 | trainer.save_metrics(f"eval-{task}", metrics) 497 | 498 | if training_args.do_predict: 499 | logger.info("*** Predict ***") 500 | 501 | task = data_args.task_name 502 | metrics = trainer.evaluate(eval_dataset=predict_dataset) 503 | 504 | max_predict_samples = ( 505 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 506 | ) 507 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 508 | 509 | trainer.log_metrics(f"predict-{task}", metrics) 510 | trainer.save_metrics(f"predict-{task}", metrics) 511 | 512 | 513 | def _mp_fn(index): 514 | # For xla_spawn (TPUs) 515 | main() 516 | 517 | 518 | if __name__ == "__main__": 519 | main() -------------------------------------------------------------------------------- /run_pretrained.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Team All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) 
on a text file or a dataset. 18 | 19 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 20 | https://huggingface.co/models?filter=masked-lm 21 | """ 22 | # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. 23 | import torch 24 | import logging 25 | import math 26 | import os 27 | import sys 28 | from dataclasses import dataclass, field 29 | from typing import Optional 30 | 31 | from extra.dataset_dict import DatasetDict as newDatasetDict 32 | 33 | from datasets import load_dataset, concatenate_datasets 34 | 35 | import numpy as np 36 | 37 | import transformers 38 | from transformers import ( 39 | CONFIG_MAPPING, 40 | MODEL_FOR_MASKED_LM_MAPPING, 41 | AutoConfig, 42 | AutoModelForPreTraining, 43 | EvalPrediction, 44 | AutoTokenizer, 45 | DataCollatorForLanguageModeling, 46 | HfArgumentParser, 47 | TrainingArguments, 48 | set_seed, 49 | ) 50 | from extra.classifier_trainer import SM_Trainer as Trainer 51 | from transformers.trainer_utils import get_last_checkpoint, is_main_process 52 | from transformers.utils import check_min_version 53 | import random 54 | 55 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 56 | check_min_version("4.21.3") 57 | 58 | logger = logging.getLogger(__name__) 59 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 60 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 61 | 62 | 63 | @dataclass 64 | class ModelArguments: 65 | """ 66 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 67 | """ 68 | 69 | model_name_or_path: Optional[str] = field( 70 | default=None, 71 | metadata={ 72 | "help": "The model checkpoint for weights initialization." 73 | "Don't set if you want to train a model from scratch." 74 | }, 75 | ) 76 | model_type: Optional[str] = field( 77 | default=None, 78 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 79 | ) 80 | config_name: Optional[str] = field( 81 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 82 | ) 83 | tokenizer_name: Optional[str] = field( 84 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 85 | ) 86 | cache_dir: Optional[str] = field( 87 | default=None, 88 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 89 | ) 90 | use_fast_tokenizer: bool = field( 91 | default=True, 92 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 93 | ) 94 | model_revision: str = field( 95 | default="main", 96 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 97 | ) 98 | use_auth_token: bool = field( 99 | default=False, 100 | metadata={ 101 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 102 | "with private models)." 103 | }, 104 | ) 105 | 106 | 107 | @dataclass 108 | class DataTrainingArguments: 109 | """ 110 | Arguments pertaining to what data we are going to input our model for training and eval. 
111 | """ 112 | 113 | dataset_name: Optional[str] = field( 114 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 115 | ) 116 | dataset_config_name: Optional[str] = field( 117 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 118 | ) 119 | dataset2_name: Optional[str] = field( 120 | default=None, metadata={"help": "The name of the dataset2 to use (via the datasets library)."} 121 | ) 122 | dataset2_config_name: Optional[str] = field( 123 | default=None, metadata={"help": "The configuration name of the dataset2 to use (via the datasets library)."} 124 | ) 125 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 126 | validation_file: Optional[str] = field( 127 | default=None, 128 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 129 | ) 130 | overwrite_cache: bool = field( 131 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 132 | ) 133 | validation_split_percentage: Optional[int] = field( 134 | default=5, 135 | metadata={ 136 | "help": "The percentage of the train set used as validation set in case there's no validation split" 137 | }, 138 | ) 139 | max_seq_length: Optional[int] = field( 140 | default=None, 141 | metadata={ 142 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 143 | "than this will be truncated." 144 | }, 145 | ) 146 | preprocessing_num_workers: Optional[int] = field( 147 | default=None, 148 | metadata={"help": "The number of processes to use for the preprocessing."}, 149 | ) 150 | mlm_probability: float = field( 151 | default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} 152 | ) 153 | line_by_line: bool = field( 154 | default=False, 155 | metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, 156 | ) 157 | pad_to_max_length: bool = field( 158 | default=False, 159 | metadata={ 160 | "help": "Whether to pad all samples to `max_seq_length`. " 161 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 162 | }, 163 | ) 164 | dupe_factor: int = field( 165 | default=5, 166 | metadata={ 167 | "help": "Number of times to duplicate the input data (with different masks)." 168 | }, 169 | ) 170 | max_train_samples: Optional[int] = field( 171 | default=None, 172 | metadata={ 173 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 174 | "value if set." 175 | }, 176 | ) 177 | max_eval_samples: Optional[int] = field( 178 | default=None, 179 | metadata={ 180 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 181 | "value if set." 182 | }, 183 | ) 184 | 185 | def __post_init__(self): 186 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 187 | raise ValueError("Need either a dataset name or a training/validation file.") 188 | else: 189 | if self.train_file is not None: 190 | extension = self.train_file.split(".")[-1] 191 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." 
192 |         if self.validation_file is not None:
193 |             extension = self.validation_file.split(".")[-1]
194 |             assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
195 | 
196 | 
197 | def main():
198 |     # See all possible arguments in src/transformers/training_args.py
199 |     # or by passing the --help flag to this script.
200 |     # We now keep distinct sets of args, for a cleaner separation of concerns.
201 | 
202 |     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
203 |     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
204 |         # If we pass only one argument to the script and it's the path to a json file,
205 |         # let's parse it to get our arguments.
206 |         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
207 |     else:
208 |         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
209 | 
210 |     # Detecting last checkpoint.
211 |     last_checkpoint = None
212 |     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
213 |         last_checkpoint = get_last_checkpoint(training_args.output_dir)
214 |         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
215 |             raise ValueError(
216 |                 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
217 |                 "Use --overwrite_output_dir to overwrite it."
218 |             )
219 |         elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
220 |             logger.info(
221 |                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
222 |                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
223 |             )
224 | 
225 |     # Setup logging
226 |     logging.basicConfig(
227 |         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
228 |         datefmt="%m/%d/%Y %H:%M:%S",
229 |         handlers=[logging.StreamHandler(sys.stdout)],
230 |     )
231 |     logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
232 | 
233 |     # Log a short summary on each process:
234 |     logger.warning(
235 |         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
236 |         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
237 |     )
238 |     # Set the verbosity to info of the Transformers logger (on main process only):
239 |     if is_main_process(training_args.local_rank):
240 |         transformers.utils.logging.set_verbosity_info()
241 |         transformers.utils.logging.enable_default_handler()
242 |         transformers.utils.logging.enable_explicit_format()
243 |     logger.info(f"Training/evaluation parameters {training_args}")
244 | 
245 |     # Set seed before initializing model.
246 |     set_seed(training_args.seed)
247 | 
248 |     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
249 |     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
250 |     # (the dataset will be downloaded automatically from the datasets Hub).
251 |     #
252 |     # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
253 |     # behavior (see below).
254 |     #
255 |     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
256 |     # download the dataset.
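    # For example, with the default --validation_split_percentage of 5, a hub dataset that ships without a
    # validation split is divided below into validation = train[:5%] and train = train[5%:].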
257 | if data_args.dataset_name is not None: 258 | # Downloading and loading a dataset from the hub. 259 | datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) 260 | if "validation" not in datasets.keys(): 261 | datasets["validation"] = load_dataset( 262 | data_args.dataset_name, 263 | data_args.dataset_config_name, 264 | split=f"train[:{data_args.validation_split_percentage}%]", 265 | cache_dir=model_args.cache_dir, 266 | ) 267 | datasets["train"] = load_dataset( 268 | data_args.dataset_name, 269 | data_args.dataset_config_name, 270 | split=f"train[{data_args.validation_split_percentage}%:]", 271 | cache_dir=model_args.cache_dir, 272 | ) 273 | if data_args.dataset2_name is not None: 274 | datasets_2 = load_dataset(data_args.dataset2_name, data_args.dataset2_config_name, cache_dir=model_args.cache_dir) 275 | for k in datasets.keys(): 276 | datasets[k] = concatenate_datasets([datasets[k], datasets_2[k]]) 277 | else: 278 | data_files = {} 279 | if data_args.train_file is not None: 280 | data_files["train"] = data_args.train_file 281 | if data_args.validation_file is not None: 282 | data_files["validation"] = data_args.validation_file 283 | extension = data_args.train_file.split(".")[-1] 284 | if extension == "txt": 285 | extension = "text" 286 | datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 287 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 288 | # https://huggingface.co/docs/datasets/loading_datasets.html. 289 | 290 | # XXX: inject patch, reconstruct later 291 | if datasets.__class__.__name__ == 'DatasetDict': 292 | setattr(datasets.__class__, 'map', newDatasetDict.map) 293 | 294 | # Load pretrained model and tokenizer 295 | # 296 | # Distributed training: 297 | # The .from_pretrained methods guarantee that only one local process can concurrently 298 | # download model & vocab. 299 | config_kwargs = { 300 | "cache_dir": model_args.cache_dir, 301 | "revision": model_args.model_revision, 302 | "use_auth_token": True if model_args.use_auth_token else None, 303 | "trust_remote_code": True, 304 | } 305 | if model_args.config_name: 306 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 307 | elif model_args.model_name_or_path: 308 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 309 | else: 310 | config = CONFIG_MAPPING[model_args.model_type]() 311 | logger.warning("You are instantiating a new config instance from scratch.") 312 | 313 | tokenizer_kwargs = { 314 | "cache_dir": model_args.cache_dir, 315 | "use_fast": model_args.use_fast_tokenizer, 316 | "revision": model_args.model_revision, 317 | "use_auth_token": True if model_args.use_auth_token else None, 318 | "model_input_names": ['input_ids', 'token_type_ids', 'attention_mask', 'segment_ids'], 319 | "trust_remote_code": True, 320 | } 321 | if model_args.tokenizer_name: 322 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) 323 | elif model_args.model_name_or_path: 324 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) 325 | else: 326 | raise ValueError( 327 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 328 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
329 | ) 330 | 331 | if model_args.model_name_or_path: 332 | model = AutoModelForPreTraining.from_pretrained( 333 | model_args.model_name_or_path, 334 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 335 | config=config, 336 | cache_dir=model_args.cache_dir, 337 | revision=model_args.model_revision, 338 | use_auth_token=True if model_args.use_auth_token else None, 339 | trust_remote_code=True 340 | ) 341 | else: 342 | logger.info("Training new model from scratch") 343 | model = AutoModelForPreTraining.from_config(config, trust_remote_code=True) 344 | 345 | model.resize_token_embeddings(len(tokenizer)) 346 | 347 | # Preprocessing the datasets. 348 | # First we tokenize all the texts. 349 | if training_args.do_train: 350 | column_names = datasets["train"].column_names 351 | else: 352 | column_names = datasets["validation"].column_names 353 | text_column_name = "text" if "text" in column_names else column_names[0] 354 | 355 | if data_args.max_seq_length is None: 356 | max_seq_length = tokenizer.model_max_length 357 | if max_seq_length > 1024: 358 | logger.warning( 359 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " 360 | "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." 361 | ) 362 | max_seq_length = 1024 363 | else: 364 | if data_args.max_seq_length > tokenizer.model_max_length: 365 | logger.warning( 366 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 367 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 368 | ) 369 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 370 | 371 | if data_args.line_by_line: 372 | # When using line_by_line, we just tokenize each nonempty line. 373 | padding = "max_length" if data_args.pad_to_max_length else False 374 | 375 | def tokenize_function(examples): 376 | # Remove empty lines 377 | examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()] 378 | return tokenizer( 379 | examples["text"], 380 | padding=padding, 381 | truncation=True, 382 | max_length=max_seq_length, 383 | add_special_tokens=False, 384 | # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it 385 | # receives the `special_tokens_mask`. 386 | return_special_tokens_mask=True, 387 | ) 388 | 389 | tokenized_datasets = datasets.map( 390 | tokenize_function, 391 | batched=True, 392 | num_proc=data_args.preprocessing_num_workers, 393 | remove_columns=[text_column_name], 394 | load_from_cache_file=not data_args.overwrite_cache, 395 | ) 396 | else: 397 | # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. 398 | # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more 399 | # efficient when it receives the `special_tokens_mask`. 
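        # Here each raw line is tokenized without special tokens; `group_texts` below packs these per-line
        # token lists into sentence-pair examples and only then adds the model's special tokens via
        # `build_inputs_with_special_tokens`.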
400 |         def tokenize_function(examples):
401 |             return tokenizer(examples[text_column_name], return_special_tokens_mask=True, add_special_tokens=False)
402 | 
403 |         tokenized_datasets = datasets.map(
404 |             tokenize_function,
405 |             batched=True,
406 |             num_proc=data_args.preprocessing_num_workers,
407 |             remove_columns=column_names,
408 |             load_from_cache_file=not data_args.overwrite_cache,
409 |             new_fingerprints={'train': 'f100a6a7741b77ef', 'validation': 'f815a2392c2825cb'}
410 |         )
411 | 
412 |         # Main data processing function that packs the tokenized lines into sentence-pair examples of up to
413 |         # max_seq_length tokens.
414 |         def group_texts(examples):
415 |             """Creates sentence-pair pretraining examples from a batch of tokenized lines (treated as one document)."""
416 |             results = {k:[] for k in examples.keys()}
417 |             results["sentence_structural_label"] = []
418 |             results["segment_ids"] = []
419 |             for _ in range(data_args.dupe_factor):
420 |                 # Account for special tokens
421 |                 max_num_tokens = max_seq_length - tokenizer.num_special_tokens_to_add(pair=True)
422 |                 short_seq_prob = 0.1
423 |                 fk = 'input_ids'
424 | 
425 |                 # We *usually* want to fill up the entire sequence since we are padding
426 |                 # up to `max_seq_length` anyways, so short sequences are generally wasted
427 |                 # computation. However, we *sometimes*
428 |                 # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
429 |                 # sequences to minimize the mismatch between pretraining and fine-tuning.
430 |                 # The `target_seq_length` is just a rough target however, whereas
431 |                 # `max_num_tokens` is a hard limit.
432 |                 target_seq_length = max_num_tokens
433 |                 if random.random() < short_seq_prob:
434 |                     target_seq_length = random.randint(2, max_num_tokens)
435 | 
436 |                 # First pass: collect all chunks of roughly `target_seq_length` tokens.
437 |                 total_chunk = []
438 |                 current_chunk = []  # a buffer storing the current working segments
439 |                 current_length = 0
440 |                 i = 0
441 |                 while i < len(examples[fk]):
442 |                     segment = examples[fk][i]  # get a segment
443 |                     if not segment:
444 |                         i += 1
445 |                         continue
446 |                     current_chunk.append(examples['input_ids'][i])  # add a segment to current chunk
447 |                     current_length += len(segment)  # overall token length
448 |                     # once the target length is reached or the end of the batch is hit, close the current chunk
449 |                     if i == len(examples[fk]) - 1 or current_length >= target_seq_length:
450 |                         if current_chunk:
451 |                             total_chunk.append(current_chunk)
452 | 
453 |                         current_chunk = []  # clear current chunk
454 |                         current_length = 0  # reset current text length
455 |                     i += 1  # go to next line
456 | 
457 |                 # We DON'T just concatenate all of the tokens from a document into a long
458 |                 # sequence and choose an arbitrary split point because this would make the
459 |                 # next sentence prediction task too easy. Instead, we split the input into
460 |                 # segments "A" and "B" based on the actual "sentences" provided by the user
461 |                 # input.
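                # Second pass: rebuild the same chunks segment by segment and split each one into sentences A
                # and B. The sentence-structure objective below labels an example 1 when A and B are swapped,
                # 2 when B is replaced by a randomly chosen different chunk, and 0 when the original order is
                # kept.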
462 |                 current_chunk = []  # a buffer storing the current working segments
463 |                 current_length = 0
464 |                 i = 0
465 |                 chunk_id = -1
466 |                 while i < len(examples[fk]):
467 |                     segment = examples[fk][i]  # get a segment
468 |                     if not segment:
469 |                         i += 1
470 |                         continue
471 |                     current_chunk.append(examples['input_ids'][i])  # add a segment to current chunk
472 |                     current_length += len(segment)  # overall token length
473 |                     # once the target length is reached or the end of the batch is hit, build sentences A and B from the current chunk
474 |                     if i == len(examples[fk]) - 1 or current_length >= target_seq_length:
475 |                         if current_chunk:
476 |                             chunk_id += 1
477 |                             # `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
478 |                             a_end = 1
479 |                             # if the current chunk has at least 2 segments, pick a random split point for the `A` (first) sentence
480 |                             if len(current_chunk) >= 2:
481 |                                 a_end = random.randint(1, len(current_chunk) - 1)
482 |                             # sentence A tokens
483 |                             tokens_a = []
484 |                             a_segment_ids = []
485 |                             for j in range(a_end):
486 |                                 tokens_a.extend(current_chunk[j])
487 |                                 a_segment_ids.extend([j] * len(current_chunk[j]))
488 | 
489 |                             # sentence B tokens
490 |                             tokens_b = []
491 |                             b_segment_ids = []
492 |                             for j in range(a_end, len(current_chunk)):
493 |                                 tokens_b.extend(current_chunk[j])
494 |                                 b_segment_ids.extend([j] * len(current_chunk[j]))
495 | 
496 |                             if len(tokens_a) == 0 or len(tokens_b) == 0:
497 |                                 continue
498 | 
499 |                             rdn = random.random()
500 |                             if rdn < 1/3:
501 |                                 is_next = 1
502 |                                 tokens_a, tokens_b = tokens_b, tokens_a
503 |                                 a_segment_ids, b_segment_ids = b_segment_ids, a_segment_ids
504 |                             elif rdn < 2/3 and len(total_chunk) > 1:
505 |                                 is_next = 2
506 |                                 while True:
507 |                                     rid = random.randint(0, len(total_chunk)-1)
508 |                                     if rid != chunk_id:
509 |                                         break
510 |                                 another_chunk = total_chunk[rid]
511 |                                 tokens_b = sum(another_chunk, [])
512 |                                 b_segment_ids = sum([[acid]*len(ac) for acid, ac in enumerate(another_chunk)], [])
513 |                             else:
514 |                                 is_next = 0
515 | 
516 |                             def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, a_segment_ids, b_segment_ids):
517 |                                 """Truncates a pair of sequences to a maximum sequence length."""
518 |                                 while True:
519 |                                     total_length = len(tokens_a) + len(tokens_b)
520 |                                     if total_length <= max_num_tokens:
521 |                                         break
522 |                                     trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
523 |                                     trunc_segment_ids = a_segment_ids if len(a_segment_ids) > len(b_segment_ids) else b_segment_ids
524 |                                     assert len(trunc_tokens) >= 1
525 |                                     # We want to sometimes truncate from the front and sometimes from the
526 |                                     # back to add more randomness and avoid biases.
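                                    # Each iteration removes a single token from whichever of the two sequences is currently longer.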
527 | if random.random() < 0.5: 528 | del trunc_tokens[0] 529 | del trunc_segment_ids[0] 530 | else: 531 | trunc_tokens.pop() 532 | trunc_segment_ids.pop() 533 | 534 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, a_segment_ids, b_segment_ids) 535 | # add special tokens 536 | input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b) 537 | # add token type ids, 0 for sentence a, 1 for sentence b 538 | token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b) 539 | attention_mask = [1] * len(input_ids) 540 | assert len(tokens_a) >= 1 541 | assert len(tokens_b) >= 1 542 | 543 | results["input_ids"].append(input_ids) 544 | results["token_type_ids"].append(token_type_ids) 545 | results["attention_mask"].append(attention_mask) 546 | results["sentence_structural_label"].append(is_next) 547 | results["special_tokens_mask"].append([1] + [0] * len(tokens_a) + [1] + [0] * len(tokens_b) + [1]) 548 | 549 | a_segment_ids = [asi-a_segment_ids[0]+1 for asi in a_segment_ids] 550 | b_segment_ids = [bsi-b_segment_ids[0]+a_segment_ids[-1]+2 for bsi in b_segment_ids] 551 | results["segment_ids"].append([0] + a_segment_ids + [a_segment_ids[-1]+1] + b_segment_ids + [b_segment_ids[-1]+1]) 552 | current_chunk = [] # clear current chunk 553 | current_length = 0 # reset current text length 554 | i += 1 # go to next line 555 | return results 556 | 557 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a 558 | # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value 559 | # might be slower to preprocess. 560 | # 561 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: 562 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 563 | ori_dupe_factor = data_args.dupe_factor 564 | data_args.dupe_factor = 1 565 | tokenized_datasets["validation"] = tokenized_datasets["validation"].map( 566 | group_texts, 567 | batched=True, 568 | num_proc=data_args.preprocessing_num_workers, 569 | load_from_cache_file=not data_args.overwrite_cache, 570 | new_fingerprint="g5888790b41eabcb" 571 | ) 572 | data_args.dupe_factor = ori_dupe_factor 573 | tokenized_datasets["train"] = tokenized_datasets["train"].map( 574 | group_texts, 575 | batched=True, 576 | num_proc=data_args.preprocessing_num_workers, 577 | load_from_cache_file=not data_args.overwrite_cache, 578 | new_fingerprint="ed9f003830c2481d" 579 | ) 580 | 581 | if training_args.do_train: 582 | if "train" not in tokenized_datasets: 583 | raise ValueError("--do_train requires a train dataset") 584 | train_dataset = tokenized_datasets["train"] 585 | if data_args.max_train_samples is not None: 586 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 587 | 588 | if training_args.do_eval: 589 | if "validation" not in tokenized_datasets: 590 | raise ValueError("--do_eval requires a validation dataset") 591 | eval_dataset = tokenized_datasets["validation"] 592 | if data_args.max_eval_samples is not None: 593 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 594 | 595 | # Data collator 596 | # This one will take care of randomly masking the tokens. 
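    # DataCollatorForLanguageModeling re-samples the MLM masks each time a batch is collated, so the
    # `dupe_factor` copies of a document see different masked positions over the course of training.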
597 | pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length 598 | data_collator = DataCollatorForLanguageModeling( 599 | tokenizer=tokenizer, 600 | mlm_probability=data_args.mlm_probability, 601 | pad_to_multiple_of=8 if pad_to_multiple_of_8 else None, 602 | ) 603 | 604 | def compute_metrics(p: EvalPrediction): 605 | p.predictions[2][p.predictions[2]==-100] = -1 606 | out = { 607 | "mlm_loss": p.predictions[0].mean().item(), 608 | "sop_loss": p.predictions[1].mean().item(), 609 | "mlm_accuracy": ((p.predictions[2] == p.label_ids[0]).astype(np.float32).sum()/(p.label_ids[0] != -100).astype(np.float32).sum()).item(), 610 | "sop_accuracy": (p.predictions[3] == p.label_ids[1]).astype(np.float32).mean().item(), 611 | } 612 | return out 613 | 614 | # Initialize our Trainer 615 | trainer = Trainer( 616 | model=model, 617 | args=training_args, 618 | train_dataset=train_dataset if training_args.do_train else None, 619 | eval_dataset=eval_dataset if training_args.do_eval else None, 620 | compute_metrics=compute_metrics, 621 | tokenizer=tokenizer, 622 | data_collator=data_collator, 623 | ) 624 | 625 | # Training 626 | if training_args.do_train: 627 | checkpoint = None 628 | if training_args.resume_from_checkpoint is not None: 629 | checkpoint = training_args.resume_from_checkpoint 630 | elif last_checkpoint is not None: 631 | checkpoint = last_checkpoint 632 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 633 | trainer.save_model() # Saves the tokenizer too for easy upload 634 | metrics = train_result.metrics 635 | 636 | max_train_samples = ( 637 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 638 | ) 639 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 640 | 641 | trainer.log_metrics("train", metrics) 642 | trainer.save_metrics("train", metrics) 643 | trainer.save_state() 644 | 645 | # Evaluation 646 | if training_args.do_eval: 647 | logger.info("*** Evaluate ***") 648 | 649 | metrics = trainer.evaluate() 650 | 651 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 652 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 653 | perplexity = math.exp(metrics["eval_loss"]) 654 | metrics["perplexity"] = perplexity 655 | 656 | trainer.log_metrics("eval", metrics) 657 | trainer.save_metrics("eval", metrics) 658 | 659 | if training_args.push_to_hub: 660 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "fill-mask"} 661 | if data_args.dataset_name is not None: 662 | kwargs["dataset_tags"] = data_args.dataset_name 663 | if data_args.dataset_config_name is not None: 664 | kwargs["dataset_args"] = data_args.dataset_config_name 665 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 666 | else: 667 | kwargs["dataset"] = data_args.dataset_name 668 | 669 | trainer.push_to_hub(**kwargs) 670 | 671 | 672 | def _mp_fn(index): 673 | # For xla_spawn (TPUs) 674 | main() 675 | 676 | 677 | if __name__ == "__main__": 678 | main() 679 | -------------------------------------------------------------------------------- /run_shell/1-pretrain_bookcorpus_wikipedia.sh: -------------------------------------------------------------------------------- 1 | SUFFIX=PoNet_bookcourpus_wikipedia_dupe5 2 | LOGNAME=`date +%Y%m%d%H`_${SUFFIX}.log 3 | OUTPUT=outputs/${SUFFIX} 4 | MODEL_PATH=chtan/ponet-base-uncased 5 | 6 | if [ ! 
-d logs ]; then 7 | mkdir logs 8 | fi 9 | 10 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port 31030 run_pretrained.py \ 11 | --config_name ${MODEL_PATH} \ 12 | --tokenizer_name ${MODEL_PATH} \ 13 | --dataset_name bookcorpus \ 14 | --dataset2_name wikipedia \ 15 | --dataset2_config_name 20200501.en \ 16 | --label_names labels next_sentence_label \ 17 | --save_total_limit 5 \ 18 | --dupe_factor 5 \ 19 | --num_train_epochs 500 \ 20 | --warmup_steps 5000 \ 21 | --learning_rate 1e-4 \ 22 | --evaluation_strategy steps \ 23 | --save_steps 5000 \ 24 | --eval_steps 5000 \ 25 | --logging_dir ${OUTPUT} \ 26 | --report_to tensorboard \ 27 | --do_train \ 28 | --do_eval \ 29 | --ignore_data_skip \ 30 | --per_device_train_batch_size 48 \ 31 | --gradient_accumulation_steps 1 \ 32 | --fp16 \ 33 | --sharded_ddp simple \ 34 | --output_dir ${OUTPUT} > logs/${LOGNAME} 2>&1 & -------------------------------------------------------------------------------- /run_shell/1-pretrain_bookcorpus_wikitext.sh: -------------------------------------------------------------------------------- 1 | SUFFIX=PoNet_bookcourpus_wikitext_dupe5 2 | LOGNAME=`date +%Y%m%d%H`_${SUFFIX}.log 3 | OUTPUT=outputs/${SUFFIX} 4 | MODEL_PATH=chtan/ponet-base-uncased 5 | 6 | if [ ! -d logs ]; then 7 | mkdir logs 8 | fi 9 | 10 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port 31029 run_pretrained.py \ 11 | --config_name ${MODEL_PATH} \ 12 | --tokenizer_name ${MODEL_PATH} \ 13 | --dataset_name bookcorpus \ 14 | --dataset2_name wikitext \ 15 | --dataset2_config_name wikitext-103-raw-v1 \ 16 | --label_names labels next_sentence_label \ 17 | --save_total_limit 5 \ 18 | --dupe_factor 5 \ 19 | --num_train_epochs 500 \ 20 | --warmup_steps 5000 \ 21 | --learning_rate 1e-4 \ 22 | --evaluation_strategy steps \ 23 | --save_steps 5000 \ 24 | --eval_steps 5000 \ 25 | --logging_dir ${OUTPUT} \ 26 | --report_to tensorboard \ 27 | --do_train \ 28 | --do_eval \ 29 | --ignore_data_skip \ 30 | --per_device_train_batch_size 48 \ 31 | --gradient_accumulation_steps 1 \ 32 | --fp16 \ 33 | --sharded_ddp simple \ 34 | --output_dir ${OUTPUT} > logs/${LOGNAME} 2>&1 & -------------------------------------------------------------------------------- /run_shell/2-GLUE.sh: -------------------------------------------------------------------------------- 1 | num_train_epochs=4 2 | MODEL=chtan/ponet-base-uncased 3 | OUTPRE=`pwd` 4 | 5 | if [ ! 
-d logs/glue ]; then 6 | mkdir -p logs/glue 7 | fi 8 | 9 | cal_mlp(){ 10 | NAME=`date +%Y%m%d%H`_${TASK_NAME}_ep${num_train_epochs}_bz$((bz*GAS))_lr${lr} 11 | OUTPUT=outputs/glue/${NAME} 12 | 13 | CUDA_VISIBLE_DEVICES=${GPUID} python -u run_glue.py \ 14 | --model_name_or_path ${MODEL} \ 15 | --overwrite_output_dir \ 16 | --task_name $TASK_NAME \ 17 | --do_train \ 18 | --do_eval \ 19 | --max_seq_length 128 \ 20 | --overwrite_output_dir \ 21 | --report_to tensorboard \ 22 | --per_device_train_batch_size ${bz} \ 23 | --learning_rate ${lr} \ 24 | --num_train_epochs ${num_train_epochs} \ 25 | --save_total_limit 1 \ 26 | --evaluation_strategy steps \ 27 | --save_steps 5000 \ 28 | --save_strategy no \ 29 | --fp16 \ 30 | --logging_dir ${OUTPUT} \ 31 | --output_dir ${OUTPUT} > logs/glue/${NAME}.log 2>&1 32 | } 33 | 34 | 35 | search(){ 36 | GAS=1 37 | bz=128 38 | lr=3e-4; cal_mlp 39 | lr=1e-4; cal_mlp 40 | lr=5e-5; cal_mlp 41 | lr=3e-5; cal_mlp 42 | 43 | GAS=1 44 | bz=64 45 | lr=3e-4; cal_mlp 46 | lr=1e-4; cal_mlp 47 | lr=5e-5; cal_mlp 48 | lr=3e-5; cal_mlp 49 | 50 | GAS=1 51 | bz=32 52 | lr=3e-4; cal_mlp 53 | lr=1e-4; cal_mlp 54 | lr=5e-5; cal_mlp 55 | lr=3e-5; cal_mlp 56 | 57 | GAS=1 58 | bz=16 59 | lr=3e-4; cal_mlp 60 | lr=1e-4; cal_mlp 61 | lr=5e-5; cal_mlp 62 | lr=3e-5; cal_mlp 63 | 64 | GAS=1 65 | bz=8 66 | lr=3e-4; cal_mlp 67 | lr=1e-4; cal_mlp 68 | lr=5e-5; cal_mlp 69 | lr=3e-5; cal_mlp 70 | } 71 | 72 | GPUID=0; TASK_NAME=cola; search & 73 | # GPUID=1; TASK_NAME=stsb; search & 74 | # GPUID=2; TASK_NAME=mrpc; search & 75 | # GPUID=3; TASK_NAME=rte; search & 76 | # GPUID=4; TASK_NAME=sst2; search & 77 | # GPUID=5; TASK_NAME=qqp; search & 78 | # GPUID=6; TASK_NAME=qnli; search & 79 | # GPUID=7; TASK_NAME=mnli; search & 80 | -------------------------------------------------------------------------------- /run_shell/3-LongTask.sh: -------------------------------------------------------------------------------- 1 | MODEL=chtan/ponet-base-uncased 2 | OUTPRE=`pwd` 3 | 4 | cal(){ 5 | NAME=`date +%Y%m%d%H`_${MAINTASK}_ep${num_train_epochs}_bz$((bz*GAS))_lr${lr} 6 | OUTPUT=outputs/${MAINTASK}/${NAME} 7 | 8 | CUDA_VISIBLE_DEVICES=${GPUID} torchrun --nproc_per_node=${GPU_NUMS} run_long_classification.py \ 9 | --model_name_or_path ${MODEL} \ 10 | --task_name ${MAINTASK} \ 11 | --overwrite_output_dir \ 12 | --do_train \ 13 | --do_eval \ 14 | --do_predict \ 15 | --max_seq_length 4096 \ 16 | --gradient_accumulation_steps ${GAS} \ 17 | --per_device_train_batch_size ${bz} \ 18 | --per_device_eval_batch_size ${bz} \ 19 | --learning_rate ${lr} \ 20 | --num_train_epochs ${num_train_epochs} \ 21 | --save_total_limit 1 \ 22 | --evaluation_strategy steps \ 23 | --logging_steps 500 \ 24 | --eval_steps 500 \ 25 | --save_steps 5000 \ 26 | --load_best_model_at_end \ 27 | --metric_for_best_model f1 \ 28 | --save_strategy no \ 29 | --report_to tensorboard \ 30 | --fp16 \ 31 | --logging_dir ${OUTPUT} \ 32 | --output_dir ${OUTPUT} > logs/${MAINTASK}/${NAME}.log 2>&1 33 | } 34 | 35 | search(){ 36 | num_train_epochs=${NEP} 37 | if [ ! 
-d logs/${MAINTASK} ]; then 38 | mkdir -p logs/${MAINTASK} 39 | fi 40 | lr=3e-5; cal 41 | lr=5e-5; cal 42 | } 43 | 44 | GPUID=0,1 45 | GPU_NUMS=2 46 | GAS=1 47 | bz=16 48 | ### ------- Arxiv-11 -------- 49 | MAINTASK=arxiv; NEP=10; search; 50 | ### ------- IMDb -------- 51 | # MAINTASK=imdb; NEP=10; search; 52 | ### ------- HND -------- 53 | # MAINTASK=hnd; NEP=10; search; 54 | ### ------- Yelp-5 -------- 55 | # MAINTASK=yelp; NEP=2; search; 56 | -------------------------------------------------------------------------------- /run_shell/D1-arxiv11.sh: -------------------------------------------------------------------------------- 1 | WORKDIR=`pwd` 2 | DATADIR=${WORKDIR}/datasets/arxiv-11 3 | TMPDIR=`mktemp -d` 4 | git clone https://github.com/LiqunW/Long-document-dataset ${TMPDIR} 5 | cd ${TMPDIR} 6 | RARDIR=rar_dir 7 | mkdir ${RARDIR} 8 | mv cs.*.rar math.*.rar ${RARDIR}/ 9 | 10 | unrar_dir(){ 11 | src_path=`readlink -f $1` 12 | dst_path=`readlink -f $2` 13 | rar_files=`find $src_path -name '*.rar'` 14 | IFS=$'\n'; array=$rar_files; unset IFS 15 | for rar_file in $array; do 16 | file_path=`echo $rar_file | sed -e "s;$src_path;$dst_path;"` 17 | ext_path=${file_path%/*} 18 | if [ ! -d $ext_path ]; then 19 | mkdir -p $ext_path 20 | fi 21 | unrar x $rar_file $ext_path 22 | done 23 | } 24 | 25 | mkdir data 26 | unrar_dir "${RARDIR}" "data" 27 | 28 | if [ ! -f dataloader.py_ori ]; then 29 | cp dataloader.py dataloader.py_ori 30 | echo "Dataloader('data', 32)" >> dataloader.py 31 | fi 32 | python dataloader.py 33 | 34 | mv Dataset.txt Labels_file.txt data/ 35 | mv data ${DATADIR} 36 | 37 | # if [ -d ${TMPDIR} ]; then 38 | # rm -r ${TMPDIR} 39 | # fi --------------------------------------------------------------------------------