├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   ├── arxiv-11
│   │   └── arxiv-11.py
│   └── hyperpartisan_news_detection
│       ├── dev.jsonl
│       ├── test.jsonl
│       └── train.jsonl
├── examples
│   └── LRA
│       ├── README.md
│       ├── code
│       │   ├── attention.py
│       │   ├── attention_fnet.py
│       │   ├── attention_linear.py
│       │   ├── attention_linformer.py
│       │   ├── attention_nystrom.py
│       │   ├── attention_performer.py
│       │   ├── attention_ponet.py
│       │   ├── attention_reformer.py
│       │   ├── dataset.py
│       │   ├── lra_config.py
│       │   ├── model.py
│       │   ├── model_wrapper.py
│       │   ├── run_tasks.py
│       │   └── run_tasks.sh
│       └── datasets
│           ├── cifar10.py
│           ├── create_datasets.sh
│           ├── delete_repeat.sh
│           ├── listops.py
│           ├── pathfinder.py
│           ├── requirements.txt
│           ├── retrieval.py
│           └── text.py
├── extra
│   ├── classifier_trainer.py
│   └── dataset_dict.py
├── image
│   ├── consumption.png
│   ├── model.png
│   └── performance.png
├── metrics
│   ├── macro_f1_and_acc
│   │   └── macro_f1_and_acc.py
│   └── micro_f1_and_acc
│       └── micro_f1_and_acc.py
├── requirements.txt
├── run_glue.py
├── run_long_classification.py
├── run_pretrained.py
└── run_shell
    ├── 1-pretrain_bookcorpus_wikipedia.sh
    ├── 1-pretrain_bookcorpus_wikitext.sh
    ├── 2-GLUE.sh
    ├── 3-LongTask.sh
    └── D1-arxiv11.sh
/.gitattributes:
--------------------------------------------------------------------------------
1 | outputs/ponet-base-uncased/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | .AppleDouble
4 | .LSOverride
5 |
6 | # Icon must end with two \r
7 | Icon
8 |
9 |
10 | # Thumbnails
11 | ._*
12 |
13 | # Files that might appear in the root of a volume
14 | .DocumentRevisions-V100
15 | .fseventsd
16 | .Spotlight-V100
17 | .TemporaryItems
18 | .Trashes
19 | .VolumeIcon.icns
20 | .com.apple.timemachine.donotpresent
21 |
22 | # Directories potentially created on remote AFP share
23 | .AppleDB
24 | .AppleDesktop
25 | Network Trash Folder
26 | Temporary Items
27 | .apdisk
28 |
29 | # Byte-compiled / optimized / DLL files
30 | __pycache__/
31 | *.py[cod]
32 | *$py.class
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | share/python-wheels/
52 | *.egg-info/
53 | .installed.cfg
54 | *.egg
55 | MANIFEST
56 |
57 | # PyInstaller
58 | # Usually these files are written by a python script from a template
59 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
60 | *.manifest
61 | *.spec
62 |
63 | # Installer logs
64 | pip-log.txt
65 | pip-delete-this-directory.txt
66 |
67 | # Unit test / coverage reports
68 | htmlcov/
69 | .tox/
70 | .nox/
71 | .coverage
72 | .coverage.*
73 | .cache
74 | nosetests.xml
75 | coverage.xml
76 | *.cover
77 | *.py,cover
78 | .hypothesis/
79 | .pytest_cache/
80 | cover/
81 |
82 | # Translations
83 | *.mo
84 | *.pot
85 |
86 | # Django stuff:
87 | *.log
88 | local_settings.py
89 | db.sqlite3
90 | db.sqlite3-journal
91 |
92 | # Flask stuff:
93 | instance/
94 | .webassets-cache
95 |
96 | # Scrapy stuff:
97 | .scrapy
98 |
99 | # Sphinx documentation
100 | docs/_build/
101 |
102 | # PyBuilder
103 | .pybuilder/
104 | target/
105 |
106 | # Jupyter Notebook
107 | .ipynb_checkpoints
108 |
109 | # IPython
110 | profile_default/
111 | ipython_config.py
112 |
113 | # pyenv
114 | # For a library or package, you might want to ignore these files since the code is
115 | # intended to run in multiple environments; otherwise, check them in:
116 | # .python-version
117 |
118 | # pipenv
119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
122 | # install all needed dependencies.
123 | #Pipfile.lock
124 |
125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
126 | __pypackages__/
127 |
128 | # Celery stuff
129 | celerybeat-schedule
130 | celerybeat.pid
131 |
132 | # SageMath parsed files
133 | *.sage.py
134 |
135 | # Environments
136 | .env
137 | .venv
138 | env/
139 | venv/
140 | ENV/
141 | env.bak/
142 | venv.bak/
143 |
144 | # Spyder project settings
145 | .spyderproject
146 | .spyproject
147 |
148 | # Rope project settings
149 | .ropeproject
150 |
151 | # mkdocs documentation
152 | /site
153 |
154 | # mypy
155 | .mypy_cache/
156 | .dmypy.json
157 | dmypy.json
158 |
159 | # Pyre type checker
160 | .pyre/
161 |
162 | # pytype static type analyzer
163 | .pytype/
164 |
165 | # Cython debug symbols
166 | cython_debug/
167 |
168 | .vscode/*
169 | !.vscode/settings.json
170 | !.vscode/tasks.json
171 | !.vscode/launch.json
172 | !.vscode/extensions.json
173 | *.code-workspace
174 |
175 | # Local History for Visual Studio Code
176 | .history/
177 |
178 | datasets/arxiv-11/data
179 | outputs/*
180 | logs/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 |
3 | > This is a code repository for our paper ***[PoNet: Pooling Network for Efficient Token Mixing in Long Sequences](https://arxiv.org/abs/2110.02442)***. The full source code has been released.
4 |
5 | Transformer-based models have achieved great success in various NLP, vision, and speech tasks. However, the core of Transformer, the self-attention mechanism, has a quadratic time and memory complexity with respect to the sequence length, which hinders applications of Transformer-based models to long sequences. Many approaches have been proposed to mitigate this problem, such as sparse attention mechanisms, low-rank matrix approximations and scalable kernels, and token mixing alternatives to self-attention. We propose a novel Pooling Network (PoNet) for token mixing in long sequences with linear complexity. We design multi-granularity pooling and pooling fusion to capture different levels of contextual information and combine their interactions with tokens. On the Long Range Arena benchmark, PoNet significantly outperforms Transformer and achieves competitive accuracy, while being only slightly slower than the fastest model, FNet, across all sequence lengths measured on GPUs. We also conduct systematic studies on the transfer learning capability of PoNet and observe that PoNet achieves 95.7% of the accuracy of BERT on the GLUE benchmark, outperforming FNet by 4.5% relative. Comprehensive ablation analysis demonstrates effectiveness of the designed multi-granularity pooling and pooling fusion for token mixing in long sequences and efficacy of the designed pre-training tasks for PoNet to learn transferable contextualized language representations.
6 |
7 |
8 | *(Figures: see `image/model.png`, `image/performance.png`, and `image/consumption.png` in this repository.)*
9 |
10 |
11 |
12 |
13 | ## Instruction
14 |
15 | ##### Python environment
16 |
17 | The required packages are listed in `requirements.txt`.
18 |
19 | If you are using an NVIDIA GPU and your CUDA installation supports version 11.7, you can use the following commands to create the required virtual Python environment:
20 |
21 | ```shell
22 | conda create -n ponet python=3.8
23 | conda activate ponet
24 | pip install -r requirements.txt
25 | ```
26 |
27 | ##### Special Setup for the Arxiv-11 Dataset
28 |
29 | The data can be obtained from `https://github.com/LiqunW/Long-document-dataset`.
30 |
31 | We also provide a script to download it; please refer to the shell file `run_shell/D1-arxiv11.sh`.
32 |
33 | ##### Run
34 |
35 | For pre-training, GLUE, and long-text tasks, please refer to the shell scripts under the `run_shell` folder.
36 |
37 | For LRA, please refer to `examples/LRA/README.md`.
38 |
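PoNet is also available on the Hugging Face Hub as [`chtan/ponet-base-uncased`](https://huggingface.co/chtan/ponet-base-uncased) (see the Changelog below). Below is a minimal sketch of loading it through the Transformers library; since the PoNet modeling code lives on the Hub rather than inside the Transformers package, we assume `trust_remote_code=True` is needed:

```python
from transformers import AutoTokenizer, AutoModel

# Load the PoNet tokenizer and encoder from the Hub checkpoint.
# trust_remote_code=True lets Transformers run the custom PoNet modeling code (assumption).
tokenizer = AutoTokenizer.from_pretrained("chtan/ponet-base-uncased", trust_remote_code=True)
model = AutoModel.from_pretrained("chtan/ponet-base-uncased", trust_remote_code=True)

# Encode a sentence and inspect the contextualized token representations.
inputs = tokenizer("PoNet mixes tokens with multi-granularity pooling.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # expected: (batch_size, sequence_length, hidden_size)
```
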
39 | ## Changelog
40 |
41 | - [x] [2023.05.22]
42 |   - The source code has been submitted to the ***[huggingface hub](https://huggingface.co/chtan/ponet-base-uncased)***. PoNet can now be used directly through the Transformers library.
43 |   - Since recent versions of PyTorch natively support scatter-max (via `scatter_reduce` with `reduce="amax"`), we removed the third-party `pytorch-scatter` package and use the official function instead.
44 | - Old codes are moved to ***[tag v1.0](https://github.com/lxchtan/PoNet/tree/v1.0)***.
45 |
46 | - [x] [2022.07.20] Add a brief introduction to the paper in README.
47 |
48 | - [x] [2022.07.09] The pretrained checkpoint has been moved to Google Drive.
49 |
50 | - [x] [2022.07.09] Release the source code
51 | - [x] [2021.10.19] Pre-train Tasks
52 | - [x] [2021.10.19] GLUE Tasks
53 | - [x] [2022.03.15] LRA Tasks
54 | - [x] [2022.07.09] Long-Text Tasks
55 | - [x] [2021.10.19] Release the pretrained checkpoints
56 |
57 | ## Cite
58 |
59 | ```bibtex
60 | @inproceedings{DBLP:conf/iclr/TanCWZZL22,
61 | author = {Chao{-}Hong Tan and
62 | Qian Chen and
63 | Wen Wang and
64 | Qinglin Zhang and
65 | Siqi Zheng and
66 | Zhen{-}Hua Ling},
67 | title = {PoNet: Pooling Network for Efficient Token Mixing in Long Sequences},
68 | booktitle = {The Tenth International Conference on Learning Representations, {ICLR}
69 | 2022, Virtual Event, April 25-29, 2022},
70 | publisher = {OpenReview.net},
71 | year = {2022},
72 | url = {https://openreview.net/forum?id=9jInD9JjicF},
73 | }
74 | ```
--------------------------------------------------------------------------------
/datasets/arxiv-11/arxiv-11.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Lint as: python3
17 | import os
18 | import pickle
19 | import datasets
20 | from datasets.tasks import TextClassification
21 |
22 | _DESCRIPTION = """\
23 | Arxiv-11
24 | """
25 |
26 | _DOWNLOAD_URL = "datasets/arxiv-11/data"
27 |
28 | class Arxiv11Config(datasets.BuilderConfig):
29 | def __init__(self, **kwargs):
30 | super().__init__(version=datasets.Version("1.0.0", ""), **kwargs)
31 |
32 | class Arxiv11(datasets.GeneratorBasedBuilder):
33 | BUILDER_CONFIGS = [
34 | Arxiv11Config(
35 | name="plain_text",
36 | description="Plain text",
37 | )
38 | ]
39 |
40 | def _info(self):
41 | return datasets.DatasetInfo(
42 | description=_DESCRIPTION,
43 | features=datasets.Features(
44 | {
45 | "text": datasets.Value("string"),
46 | "label": datasets.features.ClassLabel(
47 | names=['cs.AI', 'cs.NE', 'cs.cv', 'cs.CE', 'math.ST', 'cs.SY', 'cs.PL', 'cs.DS', 'cs.IT', 'math.GR', 'math.AC']
48 | )
49 | }
50 | ),
51 | supervised_keys=None,
52 | homepage="https://github.com/LiqunW/Long-document-dataset",
53 | task_templates=[TextClassification(text_column="text", label_column="label")],
54 | )
55 |
56 | def _split_generators(self, dl_manager):
57 | data_dir = _DOWNLOAD_URL
58 | with open(os.path.join(data_dir, 'Dataset.txt'), 'rb') as Dataset_file, open(os.path.join(data_dir, 'Labels_file.txt'), 'rb') as Labels_file:
59 | self.Dataset = pickle.load(Dataset_file)
60 | self.Labels = pickle.load(Labels_file)
61 |
62 | self.nTotal = len(self.Dataset)
63 | self.nTrain = int(self.nTotal*0.8)
64 | self.trainDataset = self.Dataset[0: self.nTrain]
65 | self.trainLabels = self.Labels[0: self.nTrain]
66 |
67 | self.nVal = int(self.nTotal*0.1)
68 | self.valDataset = self.Dataset[self.nTrain: self.nTrain+self.nVal]
69 | self.valLabels = self.Labels[self.nTrain: self.nTrain+self.nVal]
70 |
71 | self.nTest = self.nTotal - self.nTrain - self.nVal
72 | self.testDataset = self.Dataset[self.nTrain+self.nVal: self.nTotal]
73 | self.testLabels = self.Labels[self.nTrain + self.nVal: self.nTotal]
74 |
75 | return [
76 | datasets.SplitGenerator(
77 | name=datasets.Split.TRAIN, gen_kwargs={"file_list": [self.trainDataset, self.trainLabels]}
78 | ),
79 | datasets.SplitGenerator(
80 | name=datasets.Split.VALIDATION, gen_kwargs={"file_list": [self.valDataset, self.valLabels]}
81 | ),
82 | datasets.SplitGenerator(
83 | name=datasets.Split.TEST, gen_kwargs={"file_list": [self.testDataset, self.testLabels]}
84 | ),
85 | ]
86 |
87 | def _generate_examples(self, file_list):
88 | """Generate arxiv-11 examples."""
89 | for id_, (d, l) in enumerate(zip(*file_list)):
90 | with open(os.path.join(os.path.sep.join(_DOWNLOAD_URL.split(os.path.sep)[:-1]), d), encoding="UTF-8") as f:
91 | yield str(id_), {"text": f.read(), "label": l-1}
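
# Usage sketch (comment added for illustration; not part of the original script).
# Once the raw Arxiv-11 data has been downloaded into `datasets/arxiv-11/data`
# (see `run_shell/D1-arxiv11.sh`), this builder can be loaded with the `datasets`
# library from the repository root, e.g.:
#
#   from datasets import load_dataset
#   arxiv11 = load_dataset("datasets/arxiv-11/arxiv-11.py", name="plain_text")
#   print(arxiv11["train"][0]["label"])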
--------------------------------------------------------------------------------
/examples/LRA/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## LRA Benchmark
3 |
4 | We released the source code for the LRA benchmark, which is based on [Nystromformer](https://github.com/mlpen/Nystromformer).
5 |
6 | To prepare the datasets, one would need the following packages:
7 | ```
8 | tensorboard>=2.3.0, tensorflow>=2.3.1, tensorflow-datasets>=4.0.1, tensorflow_text
9 | ```
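
For instance, these can be installed with pip (a sketch; the version pins simply mirror the list above):

```shell
pip install "tensorboard>=2.3.0" "tensorflow>=2.3.1" "tensorflow-datasets>=4.0.1" tensorflow_text
```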
10 |
11 | To prepare the datasets, one would need to download the source code from [LRA repo](https://github.com/google-research/long-range-arena) and place `long-range-arena` folder in folder `LRA/datasets/` and also download [lra_release.gz](https://storage.googleapis.com/long-range-arena/lra_release.gz) released by LRA repo and place the unzipped folder in folder `LRA/datasets/`. The directory structure would be
12 | ```
13 | LRA/datasets/long-range-arena
14 | LRA/datasets/lra_release
15 | ```
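
One possible way to fetch and arrange these files (a sketch, assuming `git` and `wget` are available and that `lra_release.gz` is a gzipped tarball):

```shell
cd LRA/datasets/
git clone https://github.com/google-research/long-range-arena.git long-range-arena
wget https://storage.googleapis.com/long-range-arena/lra_release.gz
tar -xvf lra_release.gz   # should produce the lra_release/ folder
```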
16 | You may need to delete the code `train_dataset = train_dataset.repeat()` in `./long-range-arena/lra_benchmarks/image/input_pipeline.py`.
17 |
18 | > You can simply run `delete_repeat.sh` to comment out the code above.
19 |
20 | Then, run `sh create_datasets.sh` and it will create train, dev, and test dataset pickle files for each task.
21 |
22 | To run the LRA tasks, one would need
23 |
24 | ```
25 | pytorch==2.0.1, transformers==3.3.1, performer-pytorch
26 | ```
27 | To run an LRA experiment, `run_tasks.sh` will be helpful.
28 |
29 | ```shell
30 | bash run_tasks.sh ponet
31 | ```
32 |
33 | Or you can run the following command in the `code` folder:
34 |
35 | ```
36 | python3 run_tasks.py --model <model> --task <task>
37 | ```
38 | where `<model>` can be set to `softmax, ponet, fnet, nystrom-64, reformer-2, performer-256`, corresponding to standard self-attention, PoNet, FNet, Nystromformer with 64 landmarks, Reformer with 2 LSH hashes, and Performer with a 256-dimensional random projection. `<task>` can be set to `listops, text, retrieval, image, pathfinder32-curv_contour_length_14`. The best models and log files will be saved to the `LRA/logs/` folder.
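
For example, the following trains PoNet on the ListOps task (assuming the dataset pickle files have already been created under `LRA/datasets/`):

```shell
python3 run_tasks.py --model ponet --task listops
```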
39 |
40 |
--------------------------------------------------------------------------------
/examples/LRA/code/attention.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | import json
6 | from torch.utils.checkpoint import checkpoint
7 |
8 |
9 | class SoftmaxAttention(nn.Module):
10 | def __init__(self, config):
11 | super().__init__()
12 | self.drop_attn = torch.nn.Dropout(p=config["attention_dropout"])
13 | self.head_dim = config["head_dim"]
14 |
15 | def forward(self, Q, K, V, mask):
16 | dot = torch.matmul(Q, torch.transpose(K, -2, -1))
17 | dot = dot / math.sqrt(self.head_dim)
18 | dot = dot - 1e6 * (1 - mask[:, None, None, :])
19 |
20 | attn = nn.functional.softmax(dot, dim=-1)
21 | attn = self.drop_attn(attn)
22 |
23 | X = torch.matmul(attn, V)
24 | return X
25 |
26 |
27 | class NoneAttention(nn.Module):
28 | def __init__(self, config):
29 | super().__init__()
30 |
31 | def forward(self, Q, K, V, mask):
32 | return V
33 |
34 |
35 | class Attention(nn.Module):
36 | def __init__(self, config):
37 | super().__init__()
38 |
39 | self.grad_checkpointing = config["attention_grad_checkpointing"]
40 |
41 | self.dim = config["transformer_dim"]
42 | self.head_dim = config["head_dim"]
43 | self.num_head = config["num_head"]
44 |
45 | self.attn_type = config["attn_type"]
46 |
47 | self.W_q = nn.Linear(self.dim, self.num_head * self.head_dim)
48 | self.W_k = nn.Linear(self.dim, self.num_head * self.head_dim)
49 | self.W_v = nn.Linear(self.dim, self.num_head * self.head_dim)
50 |
51 | if self.attn_type == "softmax":
52 | self.attn = SoftmaxAttention(config)
53 | elif self.attn_type == "none":
54 | self.attn = NoneAttention(config)
55 | elif self.attn_type.startswith("linformer"):
56 | from attention_linformer import LinformerAttention
57 | self.attn = LinformerAttention(config)
58 | elif self.attn_type.startswith("reformer"):
59 | from attention_reformer import LSHAttention
60 | self.attn = LSHAttention(config, self.W_q, self.W_k, self.W_v)
61 | elif self.attn_type.startswith("nystrom"):
62 | from attention_nystrom import NystromAttention
63 | self.attn = NystromAttention(config)
64 | elif self.attn_type.startswith("fnet"):
65 | from attention_fnet import FFTAttention
66 | self.attn = FFTAttention(config)
67 | elif self.attn_type.startswith("ponet"):
68 | from attention_ponet import PoNetAttention
69 | self.attn = PoNetAttention(config)
70 | self.W_local = nn.Linear(self.dim, self.dim)
71 | self.W_segment = nn.Linear(self.dim, self.dim)
72 | elif self.attn_type.startswith("performer"):
73 | from attention_performer import PerformerAttention
74 | self.attn = PerformerAttention(config)
75 | elif self.attn_type.startswith("linear"):
76 | from attention_linear import LinearAttention
77 | self.attn = LinearAttention(config)
78 |
79 | self.ff = nn.Linear(self.num_head * self.head_dim, self.dim)
80 | self.sp_ff = nn.Identity()
81 |
82 | def forward(self, X, mask):
83 |
84 | if self.attn_type.startswith("longformer") or self.attn_type.startswith("reformer") or self.attn_type.startswith("fnet"):
85 | with torch.cuda.amp.autocast(enabled=False):
86 | attn_out = self.attn(X.float(), mask.float())
87 | elif self.attn_type.startswith("ponet"):
88 | Q = self.split_heads(self.W_q(X))
89 | K = self.split_heads(self.W_k(X))
90 | O = self.split_heads(self.W_v(X))
91 |
92 | local = self.W_local(X)
93 | segment = self.W_segment(X)
94 | with torch.cuda.amp.autocast(enabled=False):
95 | attn_out = self.attn(X.float(), Q.float(), K.float(), O.float(), local.float(), segment.float(), mask.float())
96 | else:
97 | Q = self.split_heads(self.W_q(X))
98 | K = self.split_heads(self.W_k(X))
99 | V = self.split_heads(self.W_v(X))
100 | with torch.cuda.amp.autocast(enabled=False):
101 | if self.grad_checkpointing:
102 | attn_out = checkpoint(self.attn, Q.float(), K.float(), V.float(), mask.float())
103 | else:
104 | attn_out = self.attn(Q.float(), K.float(), V.float(), mask.float())
105 | attn_out = self.combine_heads(attn_out)
106 |
107 | if self.attn_type.startswith("fnet"):
108 | self.ff = self.sp_ff
109 | out = self.ff(attn_out)
110 |
111 | return out
112 |
113 | def combine_heads(self, X):
114 | X = X.transpose(1, 2)
115 | X = X.reshape(X.size(0), X.size(1), self.num_head * self.head_dim)
116 | return X
117 |
118 | def split_heads(self, X):
119 | X = X.reshape(X.size(0), X.size(1), self.num_head, self.head_dim)
120 | X = X.transpose(1, 2)
121 | return X
122 |
--------------------------------------------------------------------------------
/examples/LRA/code/attention_fnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 |
6 | class FFTAttention(nn.Module):
7 |
8 | def __init__(self, config):
9 | super().__init__()
10 |
11 | def forward(self, X, mask):
12 | X = torch.fft.fft(torch.fft.fft(X, dim=-1), dim=-2).real
13 |
14 | return X
15 |
--------------------------------------------------------------------------------
/examples/LRA/code/attention_linear.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | class LinearAttention(nn.Module):
8 |
9 | def __init__(self, config):
10 | super().__init__()
11 |
12 | def forward(self, Q, K, V, mask):
13 | Q = (nn.functional.elu(Q) + 1) / math.sqrt(math.sqrt(Q.size(2)))
14 | K = (nn.functional.elu(K) + 1) * mask[:, None, :, None] / math.sqrt(math.sqrt(K.size(2)))
15 | V = V * mask[:, None, :, None]
16 |
17 | X = torch.matmul(Q, torch.matmul(torch.transpose(K, -2, -1), V))
18 |
19 | return X
20 |
--------------------------------------------------------------------------------
/examples/LRA/code/attention_linformer.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | class LinformerAttention(nn.Module):
8 | projection_matrix = None
9 |
10 | def __init__(self, config):
11 | super().__init__()
12 |
13 | self.num_head = config["num_head"]
14 | self.head_dim = config["head_dim"]
15 | self.linformer_k = config["linformer_k"]
16 | self.seq_len = config["max_seq_len"]
17 |
18 | if LinformerAttention.projection_matrix is not None:
19 | self.E = LinformerAttention.projection_matrix
20 | else:
21 | LinformerAttention.projection_matrix = nn.Parameter(torch.Tensor(self.num_head, self.linformer_k, self.seq_len))
22 | torch.nn.init.normal_(LinformerAttention.projection_matrix, std=0.02)
23 | self.E = LinformerAttention.projection_matrix
24 |
25 | def forward(self, Q, K, V, mask):
26 | K = torch.matmul(self.E, K * mask[:, None, :, None])
27 | V = torch.matmul(self.E, V * mask[:, None, :, None])
28 |
29 | dot = torch.matmul(Q, torch.transpose(K, -2, -1))
30 | dot = dot / math.sqrt(self.head_dim)
31 |
32 | attn = nn.functional.softmax(dot, dim=-1)
33 |
34 | X = torch.matmul(attn, V)
35 |
36 | return X
37 |
38 | def extra_repr(self):
39 | return f'linformer_k={self.linformer_k}'
40 |
--------------------------------------------------------------------------------
/examples/LRA/code/attention_nystrom.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | class NystromAttention(nn.Module):
8 | def __init__(self, config):
9 | super().__init__()
10 |
11 | self.head_dim = config["head_dim"]
12 | self.num_head = config["num_head"]
13 |
14 | self.num_landmarks = config["num_landmarks"]
15 | self.seq_len = config["seq_len"]
16 |
17 | if "inv_coeff_init_option" in config:
18 | self.init_option = config["inv_init_coeff_option"]
19 | else:
20 | self.init_option = "original"
21 |
22 | self.use_conv = "conv_kernel_size" in config
23 | if self.use_conv:
24 | self.conv = nn.Conv2d(
25 | in_channels=self.num_head, out_channels=self.num_head,
26 | kernel_size=(config["conv_kernel_size"], 1), padding=(config["conv_kernel_size"] // 2, 0),
27 | bias=False,
28 | groups=self.num_head)
29 |
30 | def forward(self, Q, K, V, mask):
31 |
32 | Q = Q * mask[:, None, :, None] / math.sqrt(math.sqrt(self.head_dim))
33 | K = K * mask[:, None, :, None] / math.sqrt(math.sqrt(self.head_dim))
34 |
35 | if self.num_landmarks == self.seq_len:
36 | attn = torch.nn.functional.softmax(torch.matmul(Q, K.transpose(-1, -2)) - 1e9 * (1 - mask[:, None, None, :]), dim=-1)
37 | X = torch.matmul(attn, V)
38 | else:
39 | Q_landmarks = Q.reshape(-1, self.num_head, self.num_landmarks, self.seq_len // self.num_landmarks, self.head_dim).mean(dim=-2)
40 | K_landmarks = K.reshape(-1, self.num_head, self.num_landmarks, self.seq_len // self.num_landmarks, self.head_dim).mean(dim=-2)
41 |
42 | kernel_1 = torch.nn.functional.softmax(torch.matmul(Q, K_landmarks.transpose(-1, -2)), dim=-1)
43 | kernel_2 = torch.nn.functional.softmax(torch.matmul(Q_landmarks, K_landmarks.transpose(-1, -2)), dim=-1)
44 | kernel_3 = torch.nn.functional.softmax(torch.matmul(Q_landmarks, K.transpose(-1, -2)) - 1e9 * (1 - mask[:, None, None, :]), dim=-1)
45 | X = torch.matmul(torch.matmul(kernel_1, self.iterative_inv(kernel_2)), torch.matmul(kernel_3, V))
46 |
47 | if self.use_conv:
48 | X += self.conv(V * mask[:, None, :, None])
49 |
50 | return X
51 |
52 | def iterative_inv(self, mat, n_iter=6):
53 | I = torch.eye(mat.size(-1), device=mat.device)
54 | K = mat
55 |
56 | if self.init_option == "original":
57 | V = 1 / torch.max(torch.sum(K, dim=-2)) * K.transpose(-1, -2)
58 | else:
59 | V = 1 / torch.max(torch.sum(K, dim=-2), dim=-1).values[:, :, None, None] * K.transpose(-1, -2)
60 |
61 | for _ in range(n_iter):
62 | KV = torch.matmul(K, V)
63 | V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV, 15 * I - torch.matmul(KV, 7 * I - KV)))
64 | return V
65 |
66 | def extra_repr(self):
67 | return f'num_landmarks={self.num_landmarks}, seq_len={self.seq_len}'
68 |
--------------------------------------------------------------------------------
/examples/LRA/code/attention_performer.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | from performer_pytorch import FastAttention
6 |
7 |
8 | class PerformerAttention(nn.Module):
9 | def __init__(self, config):
10 | super().__init__()
11 |
12 | self.head_dim = config["head_dim"]
13 | self.rp_dim = config["rp_dim"]
14 | self.kernel_type = config["kernel_type"]
15 | if self.kernel_type == "relu":
16 | self.attn_fn = FastAttention(dim_heads=self.head_dim, nb_features=self.rp_dim, causal=False, kernel_fn=nn.ReLU())
17 | elif self.kernel_type == "exp":
18 | self.attn_fn = FastAttention(dim_heads=self.head_dim, nb_features=self.rp_dim, causal=False, kernel_fn=torch.exp)
19 |
20 | def forward(self, Q, K, V, mask):
21 | return self.attn_fn(
22 | Q / math.sqrt(math.sqrt(self.head_dim)),
23 | K / math.sqrt(math.sqrt(self.head_dim)) * mask[:, None, :, None],
24 | V * mask[:, None, :, None])
25 |
26 | def extra_repr(self):
27 | return f'rp_dim={self.rp_dim}, kernel_type={self.kernel_type}'
28 |
--------------------------------------------------------------------------------
/examples/LRA/code/attention_ponet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 | def segment_max(src, index, dim=1):
6 |     out = torch.zeros_like(src).scatter_reduce(  # segment max pooling via the native scatter-amax reduction
7 |         dim, index.unsqueeze(-1).expand_as(src), src, reduce="amax", include_self=False
8 |     )
9 |     dummy = index.unsqueeze(-1).expand(*index.shape[:2], out.size(-1))
10 |     return torch.gather(out, dim, dummy).to(dtype=src.dtype)  # broadcast each segment's max back to its token positions
11 |
12 | def get_win_max(hidden_states, kernel_size=3):
13 |     m = nn.MaxPool1d(kernel_size, stride=1, padding=kernel_size//2)  # sliding-window (local) max pooling over the sequence
14 |     out = m(hidden_states.permute(0,2,1)).permute(0,2,1)
15 |     return out
16 |
17 | class PoNetAttention(nn.Module):
18 | def __init__(self, config):
19 | super().__init__()
20 |
21 | self.num_head = config["num_head"]
22 | self.head_dim = config["head_dim"]
23 | self.segment_num = config["segment_num"]
24 |
25 | #XXX: More natural implementation.
26 | def get_segment_index(self, input_ids):
27 | sentence_len = input_ids.shape[1]
28 | segment_num = self.segment_num
29 | segment_len = sentence_len // segment_num + 1
30 | mask = torch.arange(start=0, end=segment_num, dtype=torch.long, device=input_ids.device).view(1, segment_num, 1).repeat(1, 1, segment_len).view(1, -1)[:, :sentence_len].repeat(input_ids.shape[0], 1)
31 | return mask
32 |
33 |
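    # Descriptive note (added comment, not in the original file): the forward pass below
    # implements PoNet's multi-granularity token mixing.
    #   * Global aggregation (GA): a per-head global query q (masked average of Q over the
    #     sequence) attends over K, producing one global vector v per head.
    #   * Segment max pooling (SMP): the `segment` projection is max-pooled within each
    #     segment given by get_segment_index (via segment_max).
    #   * Local max pooling (LMP): the `local` projection is max-pooled over a sliding
    #     window (via get_win_max).
    # The branches are fused as (v + segment) * O + local; padded positions are masked
    # out before pooling and zeroed in the final output.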
34 | def forward(self, hidden_states, Q, K, O, local, segment, attention_mask):
35 | # bdlh
36 | context_layer_q = Q
37 | context_layer_k = K
38 | context_layer_v = context_layer_k
39 | context_layer_o = O
40 |
41 | if attention_mask is not None:
42 | _attention_mask = (attention_mask[:,None,:,None] < 0.5)
43 |
44 | if attention_mask is not None:
45 | context_layer_q.masked_fill_(_attention_mask, 0.0)
46 | q = context_layer_q.sum(dim=-2) / torch.ones_like(_attention_mask).to(dtype=context_layer_q.dtype).masked_fill(_attention_mask, 0.0).sum(dim=-2)
47 | else:
48 | q = context_layer_q.mean(dim=-2)
49 | att = torch.einsum("bdh,bdlh -> bdl", q, context_layer_k) / math.sqrt(context_layer_q.shape[-1])
50 |
51 | if attention_mask is not None:
52 | att.masked_fill_(_attention_mask.squeeze(-1), -10000)
53 | att_prob = att.softmax(dim=-1)
54 |
55 | v = torch.einsum('bdlh,bdl->bdh', context_layer_v, att_prob)
56 |
57 | context_layer_segment = segment
58 | context_layer_local = local
59 | if attention_mask is not None:
60 | _attention_mask = _attention_mask.squeeze(1)
61 | context_layer_segment.masked_fill_(_attention_mask, -10000)
62 | context_layer_local.masked_fill_(_attention_mask, -10000)
63 |
64 | context_layer_local = get_win_max(context_layer_local)
65 | segment_index = self.get_segment_index(hidden_states)
66 | context_layer_segment = segment_max(context_layer_segment, index=segment_index)
67 |
68 | context_layer_local = context_layer_local.view(*context_layer_local.shape[:2], self.num_head, self.head_dim).permute(0, 2, 1, 3)
69 | context_layer_segment = context_layer_segment.view(*context_layer_segment.shape[:2], self.num_head, self.head_dim).permute(0, 2, 1, 3)
70 |
71 | context_layer = (v.unsqueeze(dim=-2) + context_layer_segment) * context_layer_o + context_layer_local
72 | context_layer = context_layer.permute(0, 2, 1, 3).reshape(*hidden_states.shape[:2], -1)
73 | if attention_mask is not None:
74 | context_layer.masked_fill_(_attention_mask, 0.0)
75 |
76 | outputs = context_layer
77 | return outputs
--------------------------------------------------------------------------------
/examples/LRA/code/attention_reformer.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from transformers.models.reformer.modeling_reformer import LSHSelfAttention, ReformerConfig
5 |
6 |
7 | class LSHAttention(LSHSelfAttention):
8 | def __init__(self, config, query, key, value):
9 | self.num_hash = config["num_hash"]
10 | reformer_config = ReformerConfig()
11 | reformer_config.attention_head_size = config["head_dim"]
12 | reformer_config.num_attention_heads = config["num_head"]
13 | reformer_config.attn_layers = ["lsh"]
14 | reformer_config.num_hashes = config["num_hash"]
15 | reformer_config.is_decoder = False
16 | reformer_config.max_position_embeddings = config["max_seq_len"]
17 | reformer_config.hidden_size = config["transformer_dim"]
18 | super().__init__(reformer_config)
19 | self.query_key.weight = query.weight
20 | self.value.weight = value.weight
21 |
22 | def forward(self, X, mask):
23 | return super().forward(hidden_states=X, attention_mask=mask).hidden_states
24 |
25 | def extra_repr(self):
26 | return f'num_hash={self.num_hash}'
27 |
--------------------------------------------------------------------------------
/examples/LRA/code/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from torch.utils.data.dataset import Dataset
5 | import sys
6 | import os
7 | import random
8 | import json
9 | import pickle
10 | import numpy as np
11 |
12 | class LRADataset(Dataset):
13 | def __init__(self, file_path, endless):
14 |
15 | self.endless = endless
16 | with open(file_path, "rb") as f:
17 | self.examples = pickle.load(f)
18 | random.shuffle(self.examples)
19 | self.curr_idx = 0
20 |
21 | print(f"Loaded {file_path}... size={len(self.examples)}", flush = True)
22 |
23 | def __len__(self):
24 | if self.endless:
25 | return 1000000000
26 | else:
27 | return len(self.examples)
28 |
29 | def create_inst(self, inst):
30 | output = {}
31 | output["input_ids_0"] = torch.tensor(inst["input_ids_0"], dtype = torch.long)
32 | output["mask_0"] = (output["input_ids_0"] != 0).float()
33 | if "input_ids_1" in inst:
34 | output["input_ids_1"] = torch.tensor(inst["input_ids_1"], dtype = torch.long)
35 | output["mask_1"] = (output["input_ids_1"] != 0).float()
36 | output["label"] = torch.tensor(inst["label"], dtype = torch.long)
37 | return output
38 |
39 | def __getitem__(self, i):
40 | if not self.endless:
41 | return self.create_inst(self.examples[i])
42 |
43 | if self.curr_idx >= len(self.examples):
44 | random.shuffle(self.examples)
45 | self.curr_idx = 0
46 | inst = self.examples[self.curr_idx]
47 | self.curr_idx += 1
48 |
49 | return self.create_inst(inst)
50 |
--------------------------------------------------------------------------------
/examples/LRA/code/lra_config.py:
--------------------------------------------------------------------------------
1 |
2 | config = {
3 | "listops":{
4 | "dataset":{
5 | "train":96000,
6 | "dev":2000,
7 | "test":2000,
8 | },
9 | "model":{
10 | "learn_pos_emb":True,
11 | "tied_weights":False,
12 | "embedding_dim":64,
13 | "transformer_dim":64,
14 | "transformer_hidden_dim":128,
15 | "head_dim":32,
16 | "num_head":2,
17 | "num_layers":2,
18 | "vocab_size":32,
19 | "max_seq_len":2000,
20 | "dropout_prob":0.1,
21 | "attention_dropout":0.1,
22 | "pooling_mode":"MEAN",
23 | "num_classes":10,
24 | },
25 | "training":{
26 | "batch_size":32,
27 | "learning_rate":0.0001,
28 | "warmup":1000,
29 | "lr_decay":"linear",
30 | "weight_decay":0,
31 | "eval_frequency":50,
32 | "num_train_steps":5000,
33 | "num_eval_steps":62,
34 | },
35 | "gpu_memory":{
36 | "softmax":32,
37 | "nystrom-32":32,
38 | "nystrom-64":32,
39 | "nystrom-128":32,
40 | "nystrom-256":32,
41 | "ponet": 32,
42 | "fnet": 32,
43 | "linformer-256":32,
44 | "reformer-2":32,
45 | "performer-256":32,
46 | "linear":32,
47 | },
48 | "extra_attn_config":{
49 | "softmax":{"attention_grad_checkpointing":True},
50 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35},
51 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35},
52 | "nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35},
53 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35},
54 | "ponet":{"attention_grad_checkpointing":False, "segment_num":64},
55 | "fnet":{"attention_grad_checkpointing":False},
56 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256},
57 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2},
58 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"},
59 | "linear":{"attention_grad_checkpointing":False},
60 | }
61 | },
62 | "image":{
63 | "dataset":{
64 | "train":45000,
65 | "dev":5000,
66 | "test":10000,
67 | },
68 | "model":{
69 | "learn_pos_emb":True,
70 | "tied_weights":False,
71 | "embedding_dim":64,
72 | "transformer_dim":64,
73 | "transformer_hidden_dim":128,
74 | "head_dim":32,
75 | "num_head":2,
76 | "num_layers":2,
77 | "vocab_size":512,
78 | "max_seq_len":1024,
79 | "dropout_prob":0.1,
80 | "attention_dropout":0.1,
81 | "pooling_mode":"MEAN",
82 | "num_classes": 10,
83 | },
84 | "training":{
85 | "batch_size":256,
86 | "learning_rate":0.0001,
87 | "warmup":175,
88 | "lr_decay":"linear",
89 | "weight_decay":0,
90 | "eval_frequency":175,
91 | "num_train_steps":35000,
92 | "num_eval_steps":20,
93 | },
94 | "gpu_memory":{
95 | "softmax":128,
96 | "nystrom-32":128,
97 | "nystrom-64":128,
98 | "nystrom-128":128,
99 | "nystrom-256":128,
100 | "ponet": 128,
101 | "fnet": 128,
102 | "linformer-256":128,
103 | "reformer-2":128,
104 | "performer-256":128,
105 | "linear":128,
106 | },
107 | "extra_attn_config":{
108 | "softmax":{"attention_grad_checkpointing":True},
109 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35},
110 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35},
111 | "nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35},
112 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35},
113 | "ponet":{"attention_grad_checkpointing":False, "segment_num":32},
114 | "fnet":{"attention_grad_checkpointing":False},
115 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256},
116 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2},
117 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"},
118 | "linear":{"attention_grad_checkpointing":False},
119 | }
120 | },
121 | "pathfinder32":{
122 | "model":{
123 | "learn_pos_emb":True,
124 | "tied_weights":False,
125 | "embedding_dim":64,
126 | "transformer_dim":64,
127 | "transformer_hidden_dim":128,
128 | "head_dim":32,
129 | "num_head":2,
130 | "num_layers":2,
131 | "vocab_size":512,
132 | "max_seq_len":1024,
133 | "dropout_prob":0.1,
134 | "attention_dropout":0.1,
135 | "pooling_mode":"MEAN",
136 | "num_classes": 2,
137 | },
138 | "training":{
139 | "batch_size":256,
140 | "learning_rate":0.0001,
141 | "warmup":312,
142 | "lr_decay":"linear",
143 | "weight_decay":0,
144 | "eval_frequency":312,
145 | "num_train_steps":62400,
146 | "num_eval_steps":312,
147 | },
148 | "gpu_memory":{
149 | "softmax":128,
150 | "nystrom-32":128,
151 | "nystrom-64":128,
152 | "nystrom-128":128,
153 | "nystrom-256":128,
154 | "ponet":128,
155 | "fnet":128,
156 | "linformer-256":128,
157 | "reformer-2":128,
158 | "performer-256":128,
159 | "linear":128,
160 | },
161 | "extra_attn_config":{
162 | "softmax":{"attention_grad_checkpointing":True},
163 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35},
164 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35},
165 | "nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35},
166 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35},
167 | "ponet":{"attention_grad_checkpointing":False, "segment_num":1},
168 | "fnet":{"attention_grad_checkpointing":False},
169 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256},
170 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2},
171 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"},
172 | "linear":{"attention_grad_checkpointing":False},
173 | }
174 | },
175 | "retrieval":{
176 | "dataset":{
177 | "train":147086,
178 | "dev":18090,
179 | "test":17437,
180 | },
181 | "model":{
182 | "learn_pos_emb":True,
183 | "tied_weights":False,
184 | "embedding_dim":64,
185 | "transformer_dim":64,
186 | "transformer_hidden_dim":128,
187 | "head_dim":32,
188 | "num_head":2,
189 | "num_layers":2,
190 | "vocab_size":512,
191 | "max_seq_len":4000,
192 | "dropout_prob":0.1,
193 | "attention_dropout":0.1,
194 | "pooling_mode":"MEAN",
195 | "num_classes": 2,
196 | },
197 | "training":{
198 | "batch_size":32,
199 | "learning_rate":0.0001,
200 | "warmup":800,
201 | "lr_decay":"linear",
202 | "weight_decay":0,
203 | "eval_frequency":300,
204 | "num_train_steps":30000,
205 | "num_eval_steps":565,
206 | },
207 | "gpu_memory":{
208 | "softmax":32,
209 | "nystrom-32":32,
210 | "nystrom-64":32,
211 | "nystrom-128":32,
212 | "nystrom-256":32,
213 | "ponet":32,
214 | "fnet":32,
215 | "linformer-256":32,
216 | "reformer-2":32,
217 | "performer-256":32,
218 | "linear":32,
219 | },
220 | "extra_attn_config":{
221 | "softmax":{"attention_grad_checkpointing":True},
222 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35},
223 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35},
224 | "nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35},
225 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35},
226 | "ponet":{"attention_grad_checkpointing":False, "segment_num":64},
227 | "fnet":{"attention_grad_checkpointing":False},
228 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256},
229 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2},
230 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"},
231 | "linear":{"attention_grad_checkpointing":False},
232 | }
233 | },
234 | "text":{
235 | "dataset":{
236 | "train":25000,
237 | "dev":25000,
238 | "test":25000,
239 | },
240 | "model":{
241 | "learn_pos_emb":True,
242 | "tied_weights":False,
243 | "embedding_dim":64,
244 | "transformer_dim":64,
245 | "transformer_hidden_dim":128,
246 | "head_dim":32,
247 | "num_head":2,
248 | "num_layers":2,
249 | "vocab_size":512,
250 | "max_seq_len":4000,
251 | "dropout_prob":0.1,
252 | "attention_dropout":0.1,
253 | "pooling_mode":"MEAN",
254 | "num_classes": 2,
255 | },
256 | "training":{
257 | "batch_size":32,
258 | "learning_rate":0.0001,
259 | "warmup":8000,
260 | "lr_decay":"linear",
261 | "weight_decay":0,
262 | "eval_frequency":500,
263 | "num_train_steps":20000,
264 | "num_eval_steps":781,
265 | },
266 | "gpu_memory":{
267 | "softmax":32,
268 | "nystrom-32":32,
269 | "nystrom-64":32,
270 | "nystrom-128":32,
271 | "nystrom-256":32,
272 | "ponet":32,
273 | "fnet":32,
274 | "linformer-256":32,
275 | "reformer-2":32,
276 | "performer-256":32,
277 | "linear":32,
278 | },
279 | "extra_attn_config":{
280 | "softmax":{"attention_grad_checkpointing":True},
281 | "nystrom-32":{"attention_grad_checkpointing":False, "num_landmarks":32, "conv_kernel_size":35},
282 | "nystrom-64":{"attention_grad_checkpointing":False, "num_landmarks":64, "conv_kernel_size":35},
283 | "nystrom-128":{"attention_grad_checkpointing":False, "num_landmarks":128, "conv_kernel_size":35},
284 | "nystrom-256":{"attention_grad_checkpointing":False, "num_landmarks":256, "conv_kernel_size":35},
285 |       "ponet":{"attention_grad_checkpointing":False, "segment_num":4096}, # We write 4096 here for the get_segment_index function; the actual segment number is 2048.
286 | "fnet":{"attention_grad_checkpointing":False},
287 | "linformer-256":{"attention_grad_checkpointing":False, "linformer_k":256},
288 | "reformer-2":{"attention_grad_checkpointing":False, "num_hash":2},
289 | "performer-256":{"attention_grad_checkpointing":False, "rp_dim":256, "kernel_type":"relu"},
290 | "linear":{"attention_grad_checkpointing":False},
291 | }
292 | }
293 | }
294 |
295 | config["pathfinder32-curv_baseline"] = config["pathfinder32"]
296 | config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"]
297 | config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"]
298 |
--------------------------------------------------------------------------------
/examples/LRA/code/model.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import math
6 | from torch.utils.checkpoint import checkpoint
7 | from attention import Attention
8 |
9 |
10 | class Embeddings(nn.Module):
11 | def __init__(self, config):
12 | super().__init__()
13 |
14 | assert config["embedding_dim"] == config["transformer_dim"]
15 |
16 | self.dim = config["embedding_dim"]
17 |
18 | self.word_embeddings = nn.Embedding(config["vocab_size"], config["embedding_dim"])
19 | torch.nn.init.normal_(self.word_embeddings.weight, std=0.02)
20 |
21 | self.position_embeddings = nn.Embedding(config["max_seq_len"], config["embedding_dim"])
22 | torch.nn.init.normal_(self.position_embeddings.weight, std=0.02)
23 |
24 | self.dropout = torch.nn.Dropout(p=config["dropout_prob"])
25 |
26 | def fixed_pos_emb(self, seq_len, device):
27 | position = torch.arange(0, seq_len, device=device)[:, np.newaxis]
28 | div_term = torch.exp(torch.arange(0, self.dim, 2, device=device) * -(math.log(10000.0) / self.dim))
29 | pos_embed = torch.stack([torch.sin(position * div_term), torch.cos(position * div_term)], -1).reshape(seq_len, -1)
30 | return pos_embed
31 |
32 | def forward(self, input_ids):
33 |
34 | batch_size, seq_len = input_ids.size()
35 |
36 | X_token = self.word_embeddings(input_ids)
37 |
38 | position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)[None, :].repeat(batch_size, 1)
39 | X_pos = self.position_embeddings(position_ids)
40 |
41 | X = X_token + X_pos
42 |
43 | X = self.dropout(X)
44 |
45 | return X
46 |
47 |
48 | class Transformer(nn.Module):
49 | def __init__(self, config):
50 | super().__init__()
51 |
52 | self.norm1 = nn.LayerNorm(config["transformer_dim"])
53 | self.mha = Attention(config)
54 | self.dropout1 = torch.nn.Dropout(p=config["dropout_prob"])
55 | self.norm2 = nn.LayerNorm(config["transformer_dim"])
56 |
57 | self.mlpblock = nn.Sequential(
58 | nn.Linear(config["transformer_dim"], config["transformer_hidden_dim"]),
59 | nn.GELU(),
60 | torch.nn.Dropout(p=config["dropout_prob"]),
61 | nn.Linear(config["transformer_hidden_dim"], config["transformer_dim"]),
62 | torch.nn.Dropout(p=config["dropout_prob"])
63 | )
64 |
65 | def forward(self, X, mask):
66 | X = self.dropout1(self.mha(self.norm1(X), mask)) + X
67 | X = self.mlpblock(self.norm2(X)) + X
68 | return X
69 |
70 |
71 | class Model(nn.Module):
72 | def __init__(self, config):
73 | super().__init__()
74 |
75 | self.num_layers = config["num_layers"]
76 | self.tied_weights = config["tied_weights"]
77 |
78 | self.embeddings = Embeddings(config)
79 |
80 | if self.tied_weights:
81 | self.transformer = Transformer(config)
82 | else:
83 | for idx in range(self.num_layers):
84 | setattr(self, f"transformer_{idx}", Transformer(config))
85 |
86 | self.norm = nn.LayerNorm(config["transformer_dim"])
87 |
88 | def forward(self, input_ids, mask=None):
89 |
90 | X = self.embeddings(input_ids)
91 |
92 | if mask is None:
93 | mask = torch.ones_like(input_ids)
94 |
95 | if self.tied_weights:
96 | for idx in range(self.num_layers):
97 | X = self.transformer(X, mask)
98 | else:
99 | for idx in range(self.num_layers):
100 | X = getattr(self, f"transformer_{idx}")(X, mask)
101 |
102 | X = self.norm(X) * mask[:, :, None]
103 |
104 | return X
105 |
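A minimal usage sketch of the encoder above, mirroring how run_tasks.py assembles the config (the "listops"/"ponet" keys are assumed to exist in lra_config.py exactly as run_tasks.py uses them):

    import math
    import torch
    import lra_config
    from model import Model

    cfg = dict(lra_config.config["listops"]["model"])          # shallow copy so the module-level config is not mutated
    cfg.update(lra_config.config["listops"]["extra_attn_config"]["ponet"])
    cfg["attn_type"] = "ponet"
    cfg["max_seq_len"] = int(2 ** math.ceil(math.log2(cfg["max_seq_len"])))  # same rounding as run_tasks.py

    encoder = Model(cfg)
    input_ids = torch.randint(0, cfg["vocab_size"], (2, cfg["max_seq_len"]))
    hidden = encoder(input_ids)   # mask defaults to all ones
    print(hidden.shape)           # (batch, max_seq_len, transformer_dim)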
--------------------------------------------------------------------------------
/examples/LRA/code/model_wrapper.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | from model import Model
6 |
7 |
8 | def pooling(inp, mode):
9 | if mode == "CLS":
10 | pooled = inp[:, 0, :]
11 | elif mode == "MEAN":
12 | pooled = inp.mean(dim=1)
13 | else:
14 |         raise ValueError(f"Unsupported pooling mode: {mode}")
15 | return pooled
16 |
17 |
18 | def append_cls(inp, mask, vocab_size):
19 | batch_size = inp.size(0)
20 | cls_id = ((vocab_size - 1) * torch.ones(batch_size, dtype=torch.long, device=inp.device)).long()
21 | cls_mask = torch.ones(batch_size, dtype=torch.float, device=mask.device)
22 | inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1)
23 | mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1)
24 | return inp, mask
25 |
26 |
27 | class SCHead(nn.Module):
28 | def __init__(self, config):
29 | super().__init__()
30 | self.pooling_mode = config["pooling_mode"]
31 | self.mlpblock = nn.Sequential(
32 | nn.Linear(config["transformer_dim"], config["transformer_hidden_dim"]),
33 | nn.ReLU(),
34 | nn.Linear(config["transformer_hidden_dim"], config["num_classes"])
35 | )
36 |
37 | def forward(self, inp):
38 | seq_score = self.mlpblock(pooling(inp, self.pooling_mode))
39 | return seq_score
40 |
41 |
42 | class ModelForSC(nn.Module):
43 | def __init__(self, config):
44 | super().__init__()
45 |
46 | self.enable_amp = config["mixed_precision"]
47 | self.pooling_mode = config["pooling_mode"]
48 | self.vocab_size = config["vocab_size"]
49 |
50 | self.model = Model(config)
51 |
52 | self.seq_classifer = SCHead(config)
53 |
54 | def forward(self, input_ids_0, mask_0, label):
55 |
56 | with torch.cuda.amp.autocast(enabled=self.enable_amp):
57 |
58 | if self.pooling_mode == "CLS":
59 | input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
60 |
61 | token_out = self.model(input_ids_0, mask_0)
62 | seq_scores = self.seq_classifer(token_out)
63 |
64 | seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
65 | seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
66 | outputs = {}
67 | outputs["loss"] = seq_loss
68 | outputs["accu"] = seq_accu
69 |
70 | return outputs
71 |
72 |
73 | class SCHeadDual(nn.Module):
74 | def __init__(self, config):
75 | super().__init__()
76 | self.pooling_mode = config["pooling_mode"]
77 | self.mlpblock = nn.Sequential(
78 | nn.Linear(config["transformer_dim"] * 4, config["transformer_hidden_dim"]),
79 | nn.ReLU(),
80 | nn.Linear(config["transformer_hidden_dim"], config["num_classes"])
81 | )
82 |
83 | def forward(self, inp_0, inp_1):
84 | X_0 = pooling(inp_0, self.pooling_mode)
85 | X_1 = pooling(inp_1, self.pooling_mode)
86 | seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1))
87 | return seq_score
88 |
89 |
90 | class ModelForSCDual(nn.Module):
91 | def __init__(self, config):
92 | super().__init__()
93 |
94 | self.enable_amp = config["mixed_precision"]
95 | self.pooling_mode = config["pooling_mode"]
96 | self.vocab_size = config["vocab_size"]
97 |
98 | self.model = Model(config)
99 |
100 | self.seq_classifer = SCHeadDual(config)
101 |
102 | def forward(self, input_ids_0, input_ids_1, mask_0, mask_1, label):
103 |
104 | with torch.cuda.amp.autocast(enabled=self.enable_amp):
105 |
106 | if self.pooling_mode == "CLS":
107 | input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
108 | input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)
109 |
110 | token_out_0 = self.model(input_ids_0, mask_0)
111 | token_out_1 = self.model(input_ids_1, mask_1)
112 | seq_scores = self.seq_classifer(token_out_0, token_out_1)
113 |
114 | seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
115 | seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
116 | outputs = {}
117 | outputs["loss"] = seq_loss
118 | outputs["accu"] = seq_accu
119 |
120 | return outputs
121 |
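With CLS pooling, append_cls above prepends a [CLS] token whose id is vocab_size - 1 and drops the last token so the sequence length is unchanged; SCHeadDual then classifies the concatenation [X_0, X_1, X_0 * X_1, X_0 - X_1] of the two pooled encodings. A tiny check of append_cls:

    import torch
    from model_wrapper import append_cls

    inp = torch.tensor([[5, 6, 7, 8]])
    mask = torch.ones(1, 4)
    inp, mask = append_cls(inp, mask, vocab_size=10)
    print(inp)    # tensor([[9, 5, 6, 7]]) -- CLS id is vocab_size - 1; the last token is dropped
    print(mask)   # tensor([[1., 1., 1., 1.]])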
--------------------------------------------------------------------------------
/examples/LRA/code/run_tasks.py:
--------------------------------------------------------------------------------
1 | from model_wrapper import ModelForSC, ModelForSCDual
2 | from dataset import LRADataset
3 | from torch.utils.data import DataLoader
4 | import torch
5 | import torch.nn as nn
6 | import time
7 | import os
8 | import json
9 | import pickle
10 | import numpy as np
11 | import argparse
12 | import math
13 | import itertools
14 | import lra_config
15 | from transformers import set_seed
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--model", type=str, help="model", dest="model", required=True)
19 | parser.add_argument("--task", type=str, help="task", dest="task", required=True)
20 | parser.add_argument("--skip_train", type=int, help="skip_train", dest="skip_train", default=0)
21 | parser.add_argument("--output", type=str, help="output dir", required=True)
22 | parser.add_argument("--segment_num", type=int, default=None)
23 | # parser.add_argument("--seed", type = int, default=43)
24 |
25 | args = parser.parse_args()
26 |
27 | attn_type = args.model
28 | task = args.task
29 |
30 | checkpoint_dir = args.output
31 |
32 | print(lra_config.config[task]["extra_attn_config"].keys(), flush=True)
33 |
34 | model_config = lra_config.config[task]["model"]
35 | model_config.update(lra_config.config[task]["extra_attn_config"][attn_type])
36 |
37 | model_config["mixed_precision"] = True
38 | model_config["attn_type"] = attn_type
39 | model_config["max_seq_len"] = int(2 ** math.ceil(math.log2(model_config["max_seq_len"])))
40 |
41 | if args.segment_num is not None:
42 | model_config["segment_num"] = args.segment_num
43 |
44 | training_config = lra_config.config[task]["training"]
45 | gpu_memory_config = lra_config.config[task]["gpu_memory"]
46 |
47 | device_ids = list(range(torch.cuda.device_count()))
48 | print(f"GPU list: {device_ids}")
49 |
50 | print(json.dumps([model_config, training_config], indent=4))
51 |
52 | # set_seed(args.seed)
53 |
54 | if task == "retrieval":
55 | model = ModelForSCDual(model_config)
56 | else:
57 | model = ModelForSC(model_config)
58 |
59 | print(model)
60 | print(f"parameter_size: {[weight.size() for weight in model.parameters()]}", flush=True)
61 | print(f"num_parameter: {np.sum([np.prod(weight.size()) for weight in model.parameters()])}", flush=True)
62 |
63 | model = model.cuda()
64 | model = nn.DataParallel(model, device_ids=device_ids)
65 |
66 | ds_iter = {
67 | "train": enumerate(DataLoader(LRADataset(f"../datasets/{task}.train.pickle", True), batch_size=training_config["batch_size"], drop_last=True)),
68 | "dev": enumerate(DataLoader(LRADataset(f"../datasets/{task}.dev.pickle", True), batch_size=training_config["batch_size"], drop_last=True)),
69 | "test": enumerate(DataLoader(LRADataset(f"../datasets/{task}.test.pickle", False), batch_size=training_config["batch_size"], drop_last=True)),
70 | }
71 |
72 | optimizer = torch.optim.AdamW(
73 | model.parameters(),
74 | lr=training_config["learning_rate"],
75 | betas=(0.9, 0.999), eps=1e-6, weight_decay=training_config["weight_decay"]
76 | )
77 |
78 | lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
79 | optimizer=optimizer,
80 | max_lr=training_config["learning_rate"],
81 | pct_start=training_config["warmup"] / training_config["num_train_steps"],
82 | anneal_strategy=training_config["lr_decay"],
83 | total_steps=training_config["num_train_steps"]
84 | )
85 |
86 | amp_scaler = torch.cuda.amp.GradScaler() if model_config["mixed_precision"] else None
87 |
88 |
89 | def step(component, step_idx):
90 | t0 = time.time()
91 |
92 | optimizer.zero_grad()
93 |
94 | _, batch = next(ds_iter[component])
95 | for key in batch:
96 | batch[key] = batch[key].cuda()
97 |
98 | if component == "train":
99 | outputs = {}
100 |
101 | partial_inputs_list = [{} for _ in range(accumu_steps)]
102 | for key in batch:
103 | for idx, inp in enumerate(torch.chunk(batch[key], accumu_steps, dim=0)):
104 | partial_inputs_list[idx][key] = inp
105 |
106 | for partial_inputs in partial_inputs_list:
107 | partial_outputs = model(**partial_inputs)
108 | for key in partial_outputs:
109 | partial_outputs[key] = partial_outputs[key].mean() / accumu_steps
110 | if key not in outputs:
111 | outputs[key] = partial_outputs[key]
112 | else:
113 | outputs[key] += partial_outputs[key]
114 | amp_scaler.scale(partial_outputs["loss"]).backward()
115 |
116 | amp_scaler.step(optimizer)
117 | amp_scaler.update()
118 | lr_scheduler.step()
119 | else:
120 | with torch.no_grad():
121 | outputs = {}
122 |
123 | partial_inputs_list = [{} for _ in range(accumu_steps)]
124 | for key in batch:
125 | for idx, inp in enumerate(torch.chunk(batch[key], accumu_steps, dim=0)):
126 | partial_inputs_list[idx][key] = inp
127 |
128 | for partial_inputs in partial_inputs_list:
129 | partial_outputs = model(**partial_inputs)
130 | for key in partial_outputs:
131 | partial_outputs[key] = partial_outputs[key].mean() / accumu_steps
132 | if key not in outputs:
133 | outputs[key] = partial_outputs[key]
134 | else:
135 | outputs[key] += partial_outputs[key]
136 |
137 | t1 = time.time()
138 |
139 | batch_size = batch[list(batch.keys())[0]].size(0)
140 | t_escape = t1 - t0
141 | learning_rate = optimizer.param_groups[0]["lr"]
142 | loss = outputs["loss"].data.item()
143 | accu = outputs["accu"].data.item()
144 | time_since_start = time.time() - init_t
145 |
146 | print(f"step={step_idx}, tt={time_since_start:.1f}, t={t_escape:.3f}, bs={batch_size}, lr={learning_rate:.6f}, loss={loss:.4f}, accu={accu:.4f}\t\t\t\t", end="\r", flush=True)
147 |
148 | summary[component]["t"] += t_escape
149 | summary[component]["loss"].append(loss)
150 | summary[component]["accu"].append(accu)
151 |
152 |
153 | def print_summary(summary, save_if_improved, train_step_idx):
154 | summary["loss"] = np.mean(summary["loss"])
155 | summary["accu"] = np.mean(summary["accu"])
156 |
157 | print()
158 | if summary["accu"] > summary["best_accu"]:
159 | summary["best_accu"] = summary["accu"]
160 | if save_if_improved:
161 | best_accu = summary["best_accu"]
162 | torch.save({"model_state_dict": model.module.state_dict()}, log_f_path.replace(".log", ".model"))
163 | print(f"best_accu={best_accu}. Saved best model")
164 |
165 | summary_round = {"train_step_idx": train_step_idx}
166 | for key in summary:
167 | if type(summary[key]) is str:
168 | summary_round[key] = summary[key]
169 | else:
170 | summary_round[key] = round(summary[key], 4)
171 |
172 | print(summary_round, flush=True)
173 | log_f.write(json.dumps(summary_round, sort_keys=True) + "\n")
174 | log_f.flush()
175 |
176 | summary["t"] = 0
177 | summary["loss"] = []
178 | summary["accu"] = []
179 |
180 |
181 | init_t = time.time()
182 |
183 | log_f_path = os.path.join(checkpoint_dir, f"{task}_{attn_type}_output.log")
184 | log_f = open(log_f_path, "a+")
185 |
186 | summary = {
187 | component: {"t": 0, "loss": [], "accu": [], "best_accu": 0, "component": component}
188 | for component in ["train", "dev", "test"]
189 | }
190 |
191 | accumu_steps = max(training_config["batch_size"] // len(device_ids) // gpu_memory_config[attn_type], 1)
192 | print(f"accumu_steps={accumu_steps}")
193 |
194 | train_step_idx = None
195 | if args.skip_train == 0:
196 | try:
197 | model.train()
198 | for train_step_idx in range(training_config["num_train_steps"]):
199 | outputs = step("train", train_step_idx)
200 |
201 | if (train_step_idx + 1) % training_config["eval_frequency"] == 0:
202 | print_summary(summary["train"], False, train_step_idx)
203 | model.eval()
204 | for dev_step_idx in range(training_config["num_eval_steps"]):
205 | outputs = step("dev", dev_step_idx)
206 | print_summary(summary["dev"], True, train_step_idx)
207 | model.train()
208 | except KeyboardInterrupt as e:
209 | print(e)
210 |
211 | checkpoint = torch.load(log_f_path.replace(".log", ".model"), map_location="cpu")
212 | model.module.load_state_dict(checkpoint["model_state_dict"])
213 | model.eval()
214 | try:
215 | for test_step_idx in itertools.count():
216 | outputs = step("test", test_step_idx)
217 | except StopIteration:
218 | print_summary(summary["test"], False, train_step_idx)
219 |
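The configured batch is split into accumu_steps micro-batches so each GPU only processes gpu_memory_config[attn_type] examples per forward pass, and the final test loop simply runs until the DataLoader iterator is exhausted and raises StopIteration. A worked example of the accumulation arithmetic (the numbers are illustrative, not the actual config values):

    batch_size = 32         # training_config["batch_size"]
    n_gpu = 1               # len(device_ids)
    per_gpu_capacity = 8    # gpu_memory_config[attn_type]
    accumu_steps = max(batch_size // n_gpu // per_gpu_capacity, 1)
    print(accumu_steps)     # 4 -> each optimizer step accumulates gradients over 4 chunks of 8 examples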
--------------------------------------------------------------------------------
/examples/LRA/code/run_tasks.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=`pwd`
2 | export CUDA_VISIBLE_DEVICES=0
3 |
4 | MODEL=$1
5 | OUTPUT=../logs
6 |
7 | if [ ! -d ${OUTPUT} ]; then
8 | mkdir ${OUTPUT}
9 | fi
10 |
11 | python3 run_tasks.py --model ${MODEL} --task listops --output ${OUTPUT}
12 | python3 run_tasks.py --model ${MODEL} --task text --output ${OUTPUT}
13 | python3 run_tasks.py --model ${MODEL} --task retrieval --output ${OUTPUT}
14 | python3 run_tasks.py --model ${MODEL} --task image --output ${OUTPUT}
15 | python3 run_tasks.py --model ${MODEL} --task pathfinder32-curv_contour_length_14 --output ${OUTPUT}
16 |
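The script takes the attention variant as its only argument and runs the five LRA tasks in sequence, e.g.:

    bash run_tasks.sh ponet    # the argument must be a key of extra_attn_config in lra_config.py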
--------------------------------------------------------------------------------
/examples/LRA/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("./long-range-arena/lra_benchmarks/image/")
3 | import input_pipeline
4 | import numpy as np
5 | import pickle
6 |
7 | train_ds, eval_ds, test_ds, num_classes, vocab_size, input_shape = input_pipeline.get_cifar10_datasets(
8 | n_devices = 1, batch_size = 1, normalize = False)
9 |
10 | mapping = {"train":train_ds, "dev": eval_ds, "test":test_ds}
11 | for component in mapping:
12 | ds_list = []
13 | for idx, inst in enumerate(iter(mapping[component])):
14 | ds_list.append({
15 | "input_ids_0":inst["inputs"].numpy()[0].reshape(-1),
16 | "label":inst["targets"].numpy()[0]
17 | })
18 | if idx % 100 == 0:
19 | print(f"{idx}\t\t", end = "\r")
20 | with open(f"image.{component}.pickle", "wb") as f:
21 | pickle.dump(ds_list, f)
22 |
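Each converter in this directory writes a plain pickled list of dicts whose keys match what run_tasks.py loads through LRADataset (image.train.pickle here, <task>.<split>.pickle in general). A minimal sketch of inspecting the output:

    import pickle

    with open("image.train.pickle", "rb") as f:
        records = pickle.load(f)
    # each record holds the flattened pixel sequence under "input_ids_0" and the class id under "label"
    print(len(records), records[0]["input_ids_0"].shape, records[0]["label"])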
--------------------------------------------------------------------------------
/examples/LRA/datasets/create_datasets.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=`pwd`/long-range-arena
2 |
3 | python3 cifar10.py
4 | python3 pathfinder.py
5 | python3 listops.py
6 | python3 retrieval.py
7 | python3 text.py
8 |
--------------------------------------------------------------------------------
/examples/LRA/datasets/delete_repeat.sh:
--------------------------------------------------------------------------------
1 | sed -i 's|train_dataset = train_dataset.repeat()|# train_dataset = train_dataset.repeat()|g' ./long-range-arena/lra_benchmarks/image/input_pipeline.py
--------------------------------------------------------------------------------
/examples/LRA/datasets/listops.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("./long-range-arena/lra_benchmarks/listops/")
3 | import input_pipeline
4 | import numpy as np
5 | import pickle
6 |
7 | train_ds, eval_ds, test_ds, encoder = input_pipeline.get_datasets(
8 | n_devices = 1, task_name = "basic", data_dir = "./lra_release/lra_release/listops-1000/",
9 | batch_size = 1, max_length = 2000)
10 |
11 | mapping = {"train":train_ds, "dev": eval_ds, "test":test_ds}
12 | for component in mapping:
13 | ds_list = []
14 | for idx, inst in enumerate(iter(mapping[component])):
15 | ds_list.append({
16 | "input_ids_0":np.concatenate([inst["inputs"].numpy()[0], np.zeros(48, dtype = np.int32)]),
17 | "label":inst["targets"].numpy()[0]
18 | })
19 | if idx % 100 == 0:
20 | print(f"{idx}\t\t", end = "\r")
21 | with open(f"listops.{component}.pickle", "wb") as f:
22 | pickle.dump(ds_list, f)
23 |
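The 48 trailing zeros pad each ListOps sequence (max_length 2000) to 2048, the next power of two, matching the max_seq_len rounding in run_tasks.py; retrieval.py and text.py pad their 4000-token sequences with 96 zeros to 4096 for the same reason. A quick check:

    import math

    for max_length, pad in [(2000, 48), (4000, 96)]:
        assert max_length + pad == 2 ** math.ceil(math.log2(max_length))   # 2048 and 4096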
--------------------------------------------------------------------------------
/examples/LRA/datasets/pathfinder.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import os
4 | import pickle
5 | import tensorflow as tf
6 | import random
7 |
8 | root_dir = "./lra_release/lra_release/"
9 | subdir = "pathfinder32"
10 | for diff_level in ["curv_baseline", "curv_contour_length_9", "curv_contour_length_14"]:
11 | data_dir = os.path.join(root_dir, subdir, diff_level)
12 | metadata_list = [
13 | os.path.join(data_dir, "metadata", file)
14 | for file in os.listdir(os.path.join(data_dir, "metadata"))
15 | if file.endswith(".npy")
16 | ]
17 | ds_list = []
18 | for idx, metadata_file in enumerate(metadata_list):
19 | print(idx, len(metadata_list), metadata_file, "\t\t", end = "\r")
20 | for inst_meta in tf.io.read_file(metadata_file).numpy().decode("utf-8").split("\n")[:-1]:
21 | metadata = inst_meta.split(" ")
22 | img_path = os.path.join(data_dir, metadata[0], metadata[1])
23 | img_bin = tf.io.read_file(img_path)
24 | if len(img_bin.numpy()) == 0:
25 | print()
26 | print("detected empty image")
27 | continue
28 | img = tf.image.decode_png(img_bin)
29 | seq = img.numpy().reshape(-1).astype(np.int32)
30 | label = int(metadata[3])
31 | ds_list.append({"input_ids_0":seq, "label":label})
32 |
33 | random.shuffle(ds_list)
34 |
35 | bp80 = int(len(ds_list) * 0.8)
36 | bp90 = int(len(ds_list) * 0.9)
37 | train = ds_list[:bp80]
38 | dev = ds_list[bp80:bp90]
39 | test = ds_list[bp90:]
40 |
41 | with open(f"{subdir}-{diff_level}.train.pickle", "wb") as f:
42 | pickle.dump(train, f)
43 | with open(f"{subdir}-{diff_level}.dev.pickle", "wb") as f:
44 | pickle.dump(dev, f)
45 | with open(f"{subdir}-{diff_level}.test.pickle", "wb") as f:
46 | pickle.dump(test, f)
47 |
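random.shuffle is not seeded here, so regenerating the pickles produces a different 80/10/10 split each time. A minimal tweak if a reproducible split is wanted (the seed value is arbitrary):

    import random

    random.seed(1234)   # call once before the loop so random.shuffle(ds_list) gives the same split every run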
--------------------------------------------------------------------------------
/examples/LRA/datasets/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorboard >= 2.3.0
2 | tensorflow >= 2.3.1
3 | tensorflow-datasets >= 4.0.1
4 | tensorflow_text
--------------------------------------------------------------------------------
/examples/LRA/datasets/retrieval.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("./long-range-arena/lra_benchmarks/matching/")
3 | import input_pipeline
4 | import numpy as np
5 | import pickle
6 |
7 | train_ds, eval_ds, test_ds, encoder = input_pipeline.get_matching_datasets(
8 | n_devices = 1, task_name = None, data_dir = "./lra_release/lra_release/tsv_data/",
9 | batch_size = 1, fixed_vocab = None, max_length = 4000, tokenizer = "char",
10 | vocab_file_path = None)
11 |
12 | mapping = {"train":train_ds, "dev": eval_ds, "test":test_ds}
13 | for component in mapping:
14 | ds_list = []
15 | for idx, inst in enumerate(iter(mapping[component])):
16 | ds_list.append({
17 | "input_ids_0":np.concatenate([inst["inputs1"].numpy()[0], np.zeros(96, dtype = np.int32)]),
18 | "input_ids_1":np.concatenate([inst["inputs2"].numpy()[0], np.zeros(96, dtype = np.int32)]),
19 | "label":inst["targets"].numpy()[0]
20 | })
21 | if idx % 100 == 0:
22 | print(f"{idx}\t\t", end = "\r")
23 | with open(f"retrieval.{component}.pickle", "wb") as f:
24 | pickle.dump(ds_list, f)
25 |
--------------------------------------------------------------------------------
/examples/LRA/datasets/text.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("./long-range-arena/lra_benchmarks/text_classification/")
3 | import input_pipeline
4 | import numpy as np
5 | import pickle
6 |
7 | train_ds, eval_ds, test_ds, encoder = input_pipeline.get_tc_datasets(
8 | n_devices = 1, task_name = "imdb_reviews", data_dir = None,
9 | batch_size = 1, fixed_vocab = None, max_length = 4000)
10 |
11 | mapping = {"train":train_ds, "dev": eval_ds, "test":test_ds}
12 | for component in mapping:
13 | ds_list = []
14 | for idx, inst in enumerate(iter(mapping[component])):
15 | ds_list.append({
16 | "input_ids_0":np.concatenate([inst["inputs"].numpy()[0], np.zeros(96, dtype = np.int32)]),
17 | "label":inst["targets"].numpy()[0]
18 | })
19 | if idx % 100 == 0:
20 | print(f"{idx}\t\t", end = "\r")
21 | with open(f"text.{component}.pickle", "wb") as f:
22 | pickle.dump(ds_list, f)
23 |
--------------------------------------------------------------------------------
/extra/classifier_trainer.py:
--------------------------------------------------------------------------------
1 | from transformers.trainer import Trainer
2 |
3 | import math
4 | import os
5 | import shutil
6 | import sys
7 | import time
8 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
9 |
10 | from tqdm.auto import tqdm
11 |
12 |
13 | # Integrations must be imported before ML frameworks:
14 | # isort: off
15 | from transformers.integrations import (
16 | hp_params,
17 | is_fairscale_available,
18 | )
19 |
20 | # isort: on
21 |
22 | import torch
23 | import torch.distributed as dist
24 | from packaging import version
25 | from torch import nn
26 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
27 | from torch.utils.data.distributed import DistributedSampler
28 |
29 | from transformers import __version__
30 | from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
31 | from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
32 | from transformers.dependency_versions_check import dep_version_check
33 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
34 | from transformers.trainer_callback import (
35 | DefaultFlowCallback,
36 | ProgressCallback,
37 | TrainerState,
38 | )
39 | from transformers.trainer_pt_utils import (
40 | IterableDatasetShard,
41 | get_model_param_count,
42 | )
43 | from transformers.trainer_utils import (
44 | HPSearchBackend,
45 | ShardedDDPOption,
46 | TrainOutput,
47 | has_length,
48 | speed_metrics,
49 | )
50 | from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
51 | from transformers.utils import (
52 | is_accelerate_available,
53 | is_apex_available,
54 | is_datasets_available,
55 | is_in_notebook,
56 | is_safetensors_available,
57 | is_sagemaker_mp_enabled,
58 | is_torch_tpu_available,
59 | logging,
60 | )
61 |
62 |
63 | DEFAULT_CALLBACKS = [DefaultFlowCallback]
64 | DEFAULT_PROGRESS_CALLBACK = ProgressCallback
65 |
66 | if is_in_notebook():
67 | from transformers.utils.notebook import NotebookProgressCallback
68 |
69 | DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
70 |
71 | if is_apex_available():
72 | from apex import amp
73 |
74 | if is_datasets_available():
75 | import datasets
76 |
77 | if is_torch_tpu_available(check_device=False):
78 | import torch_xla.core.xla_model as xm
79 | import torch_xla.debug.metrics as met
80 | import torch_xla.distributed.parallel_loader as pl
81 |
82 | if is_fairscale_available():
83 | dep_version_check("fairscale")
84 | import fairscale
85 | from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
86 | from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
87 | from fairscale.nn.wrap import auto_wrap
88 | from fairscale.optim import OSS
89 | from fairscale.optim.grad_scaler import ShardedGradScaler
90 |
91 |
92 | if is_sagemaker_mp_enabled():
93 | import smdistributed.modelparallel.torch as smp
94 | from smdistributed.modelparallel import __version__ as SMP_VERSION
95 |
96 | IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
97 |
98 | from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
99 | else:
100 | IS_SAGEMAKER_MP_POST_1_10 = False
101 |
102 |
103 | if is_safetensors_available():
104 | import safetensors.torch
105 |
106 |
107 | skip_first_batches = None
108 | if is_accelerate_available():
109 | from accelerate import __version__ as accelerate_version
110 |
111 | if version.parse(accelerate_version) >= version.parse("0.16"):
112 | from accelerate import skip_first_batches
113 |
114 |
115 | if TYPE_CHECKING:
116 | import optuna
117 |
118 | logger = logging.get_logger(__name__)
119 |
120 |
121 | # Name of the files used for checkpointing
122 | TRAINING_ARGS_NAME = "training_args.bin"
123 | TRAINER_STATE_NAME = "trainer_state.json"
124 | OPTIMIZER_NAME = "optimizer.pt"
125 | SCHEDULER_NAME = "scheduler.pt"
126 | SCALER_NAME = "scaler.pt"
127 |
128 |
129 | class Classifier_Trainer(Trainer):
130 | def prediction_step(
131 | self,
132 | model: nn.Module,
133 | inputs: Dict[str, Union[torch.Tensor, Any]],
134 | prediction_loss_only: bool,
135 | ignore_keys: Optional[List[str]] = None,
136 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
137 | loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
138 | if logits is not None:
139 | if type(logits) == tuple:
140 | logits = tuple([l.argmax(dim=-1) for l in logits])
141 | else:
142 | logits = logits.argmax(dim=-1)
143 | return (loss, logits, labels)
144 |
145 |
146 | class SM_Trainer(Trainer):
147 | def prediction_step(
148 | self,
149 | model: nn.Module,
150 | inputs: Dict[str, Union[torch.Tensor, Any]],
151 | prediction_loss_only: bool,
152 | ignore_keys: Optional[List[str]] = None,
153 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
154 | loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
155 | if logits is not None:
156 | bz = labels[0].shape[0] if type(labels) in [tuple, list] else labels.shape[0]
157 | mlm_loss = logits[0].view(1, 1).repeat(bz, 1)
158 | sso_loss = logits[1].view(1, 1).repeat(bz, 1)
159 | logits = tuple([mlm_loss, sso_loss,] + [l.argmax(dim=-1) for l in logits[2:]])
160 | return (loss, logits, labels)
161 |
162 | def _inner_training_loop(
163 | self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
164 | ):
165 | self._train_batch_size = batch_size
166 | # Data loader and number of training steps
167 | train_dataloader = self.get_train_dataloader()
168 |
169 | # Setting up training control variables:
170 | # number of training epochs: num_train_epochs
171 | # number of training steps per epoch: num_update_steps_per_epoch
172 | # total number of training steps to execute: max_steps
173 | total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
174 |
175 | len_dataloader = None
176 | if has_length(train_dataloader):
177 | len_dataloader = len(train_dataloader)
178 | num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
179 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
180 | num_examples = self.num_examples(train_dataloader)
181 | if args.max_steps > 0:
182 | max_steps = args.max_steps
183 | num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
184 | args.max_steps % num_update_steps_per_epoch > 0
185 | )
186 | # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
187 | # the best we can do.
188 | num_train_samples = args.max_steps * total_train_batch_size
189 | else:
190 | max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
191 | num_train_epochs = math.ceil(args.num_train_epochs)
192 | num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
193 | elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
194 | max_steps = args.max_steps
195 | # Setting a very large number of epochs so we go as many times as necessary over the iterator.
196 | num_train_epochs = sys.maxsize
197 | num_update_steps_per_epoch = max_steps
198 | num_examples = total_train_batch_size * args.max_steps
199 | num_train_samples = args.max_steps * total_train_batch_size
200 | else:
201 | raise ValueError(
202 | "args.max_steps must be set to a positive value if dataloader does not have a length, was"
203 | f" {args.max_steps}"
204 | )
205 |
206 | # Compute absolute values for logging, eval, and save if given as ratio
207 | if args.logging_steps and args.logging_steps < 1:
208 | args.logging_steps = math.ceil(max_steps * args.logging_steps)
209 | if args.eval_steps and args.eval_steps < 1:
210 | args.eval_steps = math.ceil(max_steps * args.eval_steps)
211 | if args.save_steps and args.save_steps < 1:
212 | args.save_steps = math.ceil(max_steps * args.save_steps)
213 |
214 | if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
215 | if self.args.n_gpu > 1:
216 | # nn.DataParallel(model) replicates the model, creating new variables and module
217 | # references registered here no longer work on other gpus, breaking the module
218 | raise ValueError(
219 | "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
220 | " (torch.distributed.launch)."
221 | )
222 | else:
223 | debug_overflow = DebugUnderflowOverflow(self.model) # noqa
224 |
225 | delay_optimizer_creation = (
226 | self.sharded_ddp is not None
227 | and self.sharded_ddp != ShardedDDPOption.SIMPLE
228 | or is_sagemaker_mp_enabled()
229 | or self.fsdp is not None
230 | )
231 | if args.deepspeed:
232 | deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
233 | self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
234 | )
235 | self.model = deepspeed_engine.module
236 | self.model_wrapped = deepspeed_engine
237 | self.deepspeed = deepspeed_engine
238 | self.optimizer = optimizer
239 | self.lr_scheduler = lr_scheduler
240 | elif not delay_optimizer_creation:
241 | self.create_optimizer_and_scheduler(num_training_steps=max_steps)
242 |
243 | self.state = TrainerState()
244 | self.state.is_hyper_param_search = trial is not None
245 |
246 | # Activate gradient checkpointing if needed
247 | if args.gradient_checkpointing:
248 | self.model.gradient_checkpointing_enable()
249 |
250 | model = self._wrap_model(self.model_wrapped)
251 |
252 | if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
253 | self._load_from_checkpoint(resume_from_checkpoint, model)
254 |
255 | # for the rest of this function `model` is the outside model, whether it was wrapped or not
256 | if model is not self.model:
257 | self.model_wrapped = model
258 |
259 | if delay_optimizer_creation:
260 | self.create_optimizer_and_scheduler(num_training_steps=max_steps)
261 |
262 | # Check if saved optimizer or scheduler states exist
263 | self._load_optimizer_and_scheduler(resume_from_checkpoint)
264 |
265 | # important: at this point:
266 | # self.model is the Transformers Model
267 | # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
268 |
269 | # Train!
270 | logger.info("***** Running training *****")
271 | logger.info(f" Num examples = {num_examples:,}")
272 | logger.info(f" Num Epochs = {num_train_epochs:,}")
273 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size:,}")
274 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
275 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
276 | logger.info(f" Total optimization steps = {max_steps:,}")
277 | logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
278 |
279 | self.state.epoch = 0
280 | start_time = time.time()
281 | epochs_trained = 0
282 | steps_trained_in_current_epoch = 0
283 | steps_trained_progress_bar = None
284 |
285 | # Check if continuing training from a checkpoint
286 | if resume_from_checkpoint is not None and os.path.isfile(
287 | os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
288 | ):
289 | self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
290 | epochs_trained = self.state.global_step // num_update_steps_per_epoch
291 | if not args.ignore_data_skip:
292 | steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
293 | steps_trained_in_current_epoch *= args.gradient_accumulation_steps
294 | else:
295 | steps_trained_in_current_epoch = 0
296 |
297 | logger.info(" Continuing training from checkpoint, will skip to saved global_step")
298 | logger.info(f" Continuing training from epoch {epochs_trained}")
299 | logger.info(f" Continuing training from global step {self.state.global_step}")
300 | if not args.ignore_data_skip:
301 | if skip_first_batches is None:
302 | logger.info(
303 | f" Will skip the first {epochs_trained} epochs then the first"
304 | f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
305 |                         " you can install the latest version of Accelerate with `pip install -U accelerate`. You can"
306 | " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
307 | " training on data already seen by your model."
308 | )
309 | else:
310 | logger.info(
311 | f" Will skip the first {epochs_trained} epochs then the first"
312 | f" {steps_trained_in_current_epoch} batches in the first epoch."
313 | )
314 | if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
315 | steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
316 | steps_trained_progress_bar.set_description("Skipping the first batches")
317 |
318 | # Update the references
319 | self.callback_handler.model = self.model
320 | self.callback_handler.optimizer = self.optimizer
321 | self.callback_handler.lr_scheduler = self.lr_scheduler
322 | self.callback_handler.train_dataloader = train_dataloader
323 | if self.hp_name is not None and self._trial is not None:
324 | # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
325 | # parameter to Train when using DDP.
326 | self.state.trial_name = self.hp_name(self._trial)
327 | if trial is not None:
328 | assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
329 | self.state.trial_params = hp_params(assignments)
330 | else:
331 | self.state.trial_params = None
332 | # This should be the same if the state has been saved but in case the training arguments changed, it's safer
333 | # to set this after the load.
334 | self.state.max_steps = max_steps
335 | self.state.num_train_epochs = num_train_epochs
336 | self.state.is_local_process_zero = self.is_local_process_zero()
337 | self.state.is_world_process_zero = self.is_world_process_zero()
338 |
339 | # tr_loss is a tensor to avoid synchronization of TPUs through .item()
340 | tr_loss = torch.tensor(0.0).to(args.device)
341 | mlm_tr_loss = torch.tensor(0.0).to(args.device)
342 | sso_tr_loss = torch.tensor(0.0).to(args.device)
343 |         # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
344 | self._total_loss_scalar = 0.0
345 | self._mlm_total_loss_scalar = 0.0
346 | self._sop_total_loss_scalar = 0.0
347 | self._globalstep_last_logged = self.state.global_step
348 | model.zero_grad()
349 |
350 | self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
351 |
352 | # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
353 | if not args.ignore_data_skip:
354 | for epoch in range(epochs_trained):
355 | is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
356 | train_dataloader.sampler, RandomSampler
357 | )
358 | if is_torch_less_than_1_11 or not is_random_sampler:
359 | # We just need to begin an iteration to create the randomization of the sampler.
360 | # That was before PyTorch 1.11 however...
361 | for _ in train_dataloader:
362 | break
363 | else:
364 | # Otherwise we need to call the whooooole sampler cause there is some random operation added
365 | # AT THE VERY END!
366 | _ = list(train_dataloader.sampler)
367 |
368 | total_batched_samples = 0
369 | for epoch in range(epochs_trained, num_train_epochs):
370 | if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
371 | train_dataloader.sampler.set_epoch(epoch)
372 | elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
373 | train_dataloader.dataset.set_epoch(epoch)
374 |
375 | if is_torch_tpu_available():
376 | parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
377 | epoch_iterator = parallel_loader
378 | else:
379 | epoch_iterator = train_dataloader
380 |
381 | # Reset the past mems state at the beginning of each epoch if necessary.
382 | if args.past_index >= 0:
383 | self._past = None
384 |
385 | steps_in_epoch = (
386 | len(epoch_iterator)
387 | if len_dataloader is not None
388 | else args.max_steps * args.gradient_accumulation_steps
389 | )
390 | self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
391 |
392 | if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
393 | self._load_rng_state(resume_from_checkpoint)
394 |
395 | rng_to_sync = False
396 | steps_skipped = 0
397 | if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
398 | epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
399 | steps_skipped = steps_trained_in_current_epoch
400 | steps_trained_in_current_epoch = 0
401 | rng_to_sync = True
402 |
403 | step = -1
404 | for step, inputs in enumerate(epoch_iterator):
405 | total_batched_samples += 1
406 | if rng_to_sync:
407 | self._load_rng_state(resume_from_checkpoint)
408 | rng_to_sync = False
409 |
410 | # Skip past any already trained steps if resuming training
411 | if steps_trained_in_current_epoch > 0:
412 | steps_trained_in_current_epoch -= 1
413 | if steps_trained_progress_bar is not None:
414 | steps_trained_progress_bar.update(1)
415 | if steps_trained_in_current_epoch == 0:
416 | self._load_rng_state(resume_from_checkpoint)
417 | continue
418 | elif steps_trained_progress_bar is not None:
419 | steps_trained_progress_bar.close()
420 | steps_trained_progress_bar = None
421 |
422 | if step % args.gradient_accumulation_steps == 0:
423 | self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
424 |
425 | if (
426 | (total_batched_samples % args.gradient_accumulation_steps != 0)
427 | and args.parallel_mode == ParallelMode.DISTRIBUTED
428 | and args._no_sync_in_gradient_accumulation
429 | and hasattr(model, "no_sync")
430 | ):
431 | # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
432 | with model.no_sync():
433 | tr_loss_step, mlm_loss_step, sso_loss_step = self.training_step(model, inputs)
434 | else:
435 | tr_loss_step, mlm_loss_step, sso_loss_step = self.training_step(model, inputs)
436 |
437 | if (
438 | args.logging_nan_inf_filter
439 | and not is_torch_tpu_available()
440 | and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
441 | ):
442 | # if loss is nan or inf simply add the average of previous logged losses
443 | tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
444 | mlm_tr_loss += mlm_loss_step / (1 + self.state.global_step - self._globalstep_last_logged)
445 | sso_tr_loss += sso_loss_step / (1 + self.state.global_step - self._globalstep_last_logged)
446 | else:
447 | tr_loss += tr_loss_step
448 | mlm_tr_loss += mlm_loss_step
449 | sso_tr_loss += sso_loss_step
450 |
451 | self.current_flos += float(self.floating_point_ops(inputs))
452 |
453 | # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
454 | if self.deepspeed:
455 | self.deepspeed.step()
456 |
457 | if total_batched_samples % args.gradient_accumulation_steps == 0 or (
458 | # last step in epoch but step is always smaller than gradient_accumulation_steps
459 | steps_in_epoch <= args.gradient_accumulation_steps
460 | and (step + 1) == steps_in_epoch
461 | ):
462 | # Gradient clipping
463 | if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
464 | # deepspeed does its own clipping
465 |
466 | if self.do_grad_scaling:
467 | # Reduce gradients first for XLA
468 | if is_torch_tpu_available():
469 | gradients = xm._fetch_gradients(self.optimizer)
470 | xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
471 | # AMP: gradients need unscaling
472 | self.scaler.unscale_(self.optimizer)
473 |
474 | if is_sagemaker_mp_enabled() and args.fp16:
475 | self.optimizer.clip_master_grads(args.max_grad_norm)
476 | elif hasattr(self.optimizer, "clip_grad_norm"):
477 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
478 | self.optimizer.clip_grad_norm(args.max_grad_norm)
479 | elif hasattr(model, "clip_grad_norm_"):
480 | # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
481 | model.clip_grad_norm_(args.max_grad_norm)
482 | else:
483 | # Revert to normal clipping otherwise, handling Apex or full precision
484 | nn.utils.clip_grad_norm_(
485 | amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
486 | args.max_grad_norm,
487 | )
488 |
489 | # Optimizer step
490 | optimizer_was_run = True
491 | if self.deepspeed:
492 | pass # called outside the loop
493 | elif is_torch_tpu_available():
494 | if self.do_grad_scaling:
495 | self.scaler.step(self.optimizer)
496 | self.scaler.update()
497 | else:
498 | xm.optimizer_step(self.optimizer)
499 | elif self.do_grad_scaling:
500 | scale_before = self.scaler.get_scale()
501 | self.scaler.step(self.optimizer)
502 | self.scaler.update()
503 | scale_after = self.scaler.get_scale()
504 | optimizer_was_run = scale_before <= scale_after
505 | else:
506 | self.optimizer.step()
507 |
508 | if optimizer_was_run and not self.deepspeed:
509 | # Delay optimizer scheduling until metrics are generated
510 | if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
511 | self.lr_scheduler.step()
512 |
513 | model.zero_grad()
514 | self.state.global_step += 1
515 | self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
516 | self.control = self.callback_handler.on_step_end(args, self.state, self.control)
517 |
518 | self._maybe_log_save_evaluate(tr_loss, mlm_tr_loss, sso_tr_loss, model, trial, epoch, ignore_keys_for_eval)
519 | else:
520 | self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
521 |
522 | if self.control.should_epoch_stop or self.control.should_training_stop:
523 | break
524 | if step < 0:
525 | logger.warning(
526 | "There seems to be not a single sample in your epoch_iterator, stopping training at step"
527 | f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
528 | f" num_steps ({max_steps}) higher than the number of available samples."
529 | )
530 | self.control.should_training_stop = True
531 |
532 | self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
533 | self._maybe_log_save_evaluate(tr_loss, mlm_tr_loss, sso_tr_loss, model, trial, epoch, ignore_keys_for_eval)
534 |
535 | if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
536 | if is_torch_tpu_available():
537 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
538 | xm.master_print(met.metrics_report())
539 | else:
540 | logger.warning(
541 | "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
542 | "configured. Check your training configuration if this is unexpected."
543 | )
544 | if self.control.should_training_stop:
545 | break
546 |
547 | if args.past_index and hasattr(self, "_past"):
548 | # Clean the state at the end of training
549 | delattr(self, "_past")
550 |
551 | logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
552 | if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
553 |             # Wait for everyone to get here so we are sure the model has been saved by process 0.
554 | if is_torch_tpu_available():
555 | xm.rendezvous("load_best_model_at_end")
556 | elif args.parallel_mode == ParallelMode.DISTRIBUTED:
557 | dist.barrier()
558 | elif is_sagemaker_mp_enabled():
559 | smp.barrier()
560 |
561 | self._load_best_model()
562 |
563 | # add remaining tr_loss
564 | self._total_loss_scalar += tr_loss.item()
565 | self._mlm_total_loss_scalar += mlm_tr_loss.item()
566 | self._sop_total_loss_scalar += sso_tr_loss.item()
567 |
568 | train_loss = self._total_loss_scalar / self.state.global_step
569 | train_mlm_loss = self._mlm_total_loss_scalar / self.state.global_step
570 | train_sso_loss = self._sop_total_loss_scalar / self.state.global_step
571 |
572 | metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
573 | self.store_flos()
574 | metrics["total_flos"] = self.state.total_flos
575 | metrics["train_loss"] = train_loss
576 | metrics["train_mlm_loss"] = train_mlm_loss
577 | metrics["train_sso_loss"] = train_sso_loss
578 |
579 | self.is_in_train = False
580 |
581 | self._memory_tracker.stop_and_update_metrics(metrics)
582 |
583 | self.log(metrics)
584 |
585 | run_dir = self._get_output_dir(trial)
586 | checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
587 |
588 | # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
589 | if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
590 | for checkpoint in checkpoints_sorted:
591 | if checkpoint != self.state.best_model_checkpoint:
592 | logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
593 | shutil.rmtree(checkpoint)
594 |
595 | self.control = self.callback_handler.on_train_end(args, self.state, self.control)
596 |
597 | return TrainOutput(self.state.global_step, train_loss, metrics)
598 |
599 | def _maybe_log_save_evaluate(self, tr_loss, mlm_tr_loss, sso_tr_loss, model, trial, epoch, ignore_keys_for_eval):
600 | if self.control.should_log:
601 | if is_torch_tpu_available():
602 | xm.mark_step()
603 |
604 | logs: Dict[str, float] = {}
605 |
606 | # all_gather + mean() to get average loss over all processes
607 | tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
608 | mlm_tr_loss_scalar = self._nested_gather(mlm_tr_loss).mean().item()
609 | sso_tr_loss_scalar = self._nested_gather(sso_tr_loss).mean().item()
610 |
611 | # reset tr_loss to zero
612 | tr_loss -= tr_loss
613 | mlm_tr_loss -= mlm_tr_loss
614 | sso_tr_loss -= sso_tr_loss
615 |
616 | logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
617 | logs["mlm_loss"] = round(mlm_tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
618 | logs["sso_loss"] = round(sso_tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
619 | logs["learning_rate"] = self._get_learning_rate()
620 |
621 | self._total_loss_scalar += tr_loss_scalar
622 | self._mlm_total_loss_scalar += mlm_tr_loss_scalar
623 | self._sop_total_loss_scalar += sso_tr_loss_scalar
624 | self._globalstep_last_logged = self.state.global_step
625 | self.store_flos()
626 |
627 | self.log(logs)
628 |
629 | metrics = None
630 | if self.control.should_evaluate:
631 | if isinstance(self.eval_dataset, dict):
632 | metrics = {}
633 | for eval_dataset_name, eval_dataset in self.eval_dataset.items():
634 | dataset_metrics = self.evaluate(
635 | eval_dataset=eval_dataset,
636 | ignore_keys=ignore_keys_for_eval,
637 | metric_key_prefix=f"eval_{eval_dataset_name}",
638 | )
639 | metrics.update(dataset_metrics)
640 | else:
641 | metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
642 | self._report_to_hp_search(trial, self.state.global_step, metrics)
643 |
644 | # Run delayed LR scheduler now that metrics are populated
645 | if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
646 | self.lr_scheduler.step(metrics[self.args.metric_for_best_model])
647 |
648 | if self.control.should_save:
649 | self._save_checkpoint(model, trial, metrics=metrics)
650 | self.control = self.callback_handler.on_save(self.args, self.state, self.control)
651 |
652 |     def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
653 | """
654 | Perform a training step on a batch of inputs.
655 |
656 | Subclass and override to inject custom behavior.
657 |
658 | Args:
659 | model (`nn.Module`):
660 | The model to train.
661 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
662 | The inputs and targets of the model.
663 |
664 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
665 | argument `labels`. Check your model's documentation for all accepted arguments.
666 |
667 | Return:
668 |             `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: The detached total training loss, MLM loss, and SSO loss for this batch.
669 | """
670 | model.train()
671 | inputs = self._prepare_inputs(inputs)
672 |
673 | if is_sagemaker_mp_enabled():
674 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
675 | return loss_mb.reduce_mean().detach().to(self.args.device)
676 |
677 | with self.compute_loss_context_manager():
678 | loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
679 |
680 | mlm_loss = outputs['mlm_loss'].detach()
681 | sso_loss = outputs['sso_loss'].detach()
682 |
683 | if self.args.n_gpu > 1:
684 | loss = loss.mean() # mean() to average on multi-gpu parallel training
685 | mlm_loss = mlm_loss.mean()
686 | sso_loss = sso_loss.mean()
687 |
688 | if self.args.gradient_accumulation_steps > 1 and not self.deepspeed:
689 | # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
690 | loss = loss / self.args.gradient_accumulation_steps
691 | mlm_loss = mlm_loss / self.args.gradient_accumulation_steps
692 | sso_loss = sso_loss / self.args.gradient_accumulation_steps
693 |
694 | if self.do_grad_scaling:
695 | self.scaler.scale(loss).backward()
696 | elif self.use_apex:
697 | with amp.scale_loss(loss, self.optimizer) as scaled_loss:
698 | scaled_loss.backward()
699 | elif self.deepspeed:
700 | # loss gets scaled under gradient_accumulation_steps in deepspeed
701 | loss = self.deepspeed.backward(loss)
702 | else:
703 | loss.backward()
704 |
705 | return loss.detach(), mlm_loss, sso_loss
706 |
--------------------------------------------------------------------------------
/extra/dataset_dict.py:
--------------------------------------------------------------------------------
1 |
2 | from datasets.dataset_dict import DatasetDict as oldDatasetDict
3 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4 | from datasets.features import Features
5 |
6 |
7 | class DatasetDict(oldDatasetDict):
8 | # add new_fingerprints args
9 | def map(
10 | self,
11 | function,
12 | with_indices: bool = False,
13 | input_columns: Optional[Union[str, List[str]]] = None,
14 | batched: bool = False,
15 | batch_size: Optional[int] = 1000,
16 | remove_columns: Optional[List[str]] = None,
17 | keep_in_memory: bool = False,
18 | load_from_cache_file: bool = True,
19 | cache_file_names: Optional[Dict[str, Optional[str]]] = None,
20 | writer_batch_size: Optional[int] = 1000,
21 | features: Optional[Features] = None,
22 | disable_nullable: bool = False,
23 | fn_kwargs: Optional[dict] = None,
24 | num_proc: Optional[int] = None,
25 | new_fingerprints=None,
26 | ) -> "DatasetDict":
27 | """Apply a function to all the elements in the table (individually or in batches)
28 |         and update the table (if the function does update examples).
29 | The transformation is applied to all the datasets of the dataset dictionary.
30 |
31 | Args:
32 | function (`callable`): with one of the following signature:
33 | - `function(example: Dict) -> Union[Dict, Any]` if `batched=False` and `with_indices=False`
34 | - `function(example: Dict, indices: int) -> Union[Dict, Any]` if `batched=False` and `with_indices=True`
35 | - `function(batch: Dict[List]) -> Union[Dict, Any]` if `batched=True` and `with_indices=False`
36 | - `function(batch: Dict[List], indices: List[int]) -> Union[Dict, Any]` if `batched=True` and `with_indices=True`
37 | with_indices (`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx): ...`.
38 | input_columns (`Optional[Union[str, List[str]]]`, defaults to `None`): The columns to be passed into `function` as
39 | positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument.
40 | batched (`bool`, defaults to `False`): Provide batch of examples to `function`
41 | batch_size (`Optional[int]`, defaults to `1000`): Number of examples per batch provided to `function` if `batched=True`
42 | `batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function`
43 | remove_columns (`Optional[List[str]]`, defaults to `None`): Remove a selection of columns while doing the mapping.
44 | Columns will be removed before updating the examples with the output of `function`, i.e. if `function` is adding
45 | columns with names in `remove_columns`, these columns will be kept.
46 | keep_in_memory (`bool`, defaults to `False`): Keep the dataset in memory instead of writing it to a cache file.
47 | load_from_cache_file (`bool`, defaults to `True`): If a cache file storing the current computation from `function`
48 | can be identified, use it instead of recomputing.
49 | cache_file_names (`Optional[Dict[str, str]]`, defaults to `None`): Provide the name of a path for the cache file. It is used to store the
50 | results of the computation instead of the automatically generated cache file name.
51 | You have to provide one :obj:`cache_file_name` per dataset in the dataset dictionary.
52 | writer_batch_size (:obj:`int`, default `1000`): Number of rows per write operation for the cache file writer.
53 | This value is a good trade-off between memory usage during the processing, and processing speed.
54 | Higher value makes the processing do fewer lookups, lower value consume less temporary memory while running `.map()`.
55 | features (`Optional[datasets.Features]`, defaults to `None`): Use a specific Features to store the cache file
56 | instead of the automatically generated one.
57 |             disable_nullable (`bool`, defaults to `False`): Disallow null values in the table.
58 | fn_kwargs (`Optional[Dict]`, defaults to `None`): Keyword arguments to be passed to `function`
59 | num_proc (`Optional[int]`, defaults to `None`): Number of processes for multiprocessing. By default it doesn't
60 | use multiprocessing.
61 | """
62 | self._check_values_type()
63 | if cache_file_names is None:
64 | cache_file_names = {k: None for k in self}
65 | if new_fingerprints is None:
66 | new_fingerprints = {k: None for k in self}
67 | return DatasetDict(
68 | {
69 | k: dataset.map(
70 | function=function,
71 | with_indices=with_indices,
72 | input_columns=input_columns,
73 | batched=batched,
74 | batch_size=batch_size,
75 | remove_columns=remove_columns,
76 | keep_in_memory=keep_in_memory,
77 | load_from_cache_file=load_from_cache_file,
78 | cache_file_name=cache_file_names[k],
79 | writer_batch_size=writer_batch_size,
80 | features=features,
81 | disable_nullable=disable_nullable,
82 | fn_kwargs=fn_kwargs,
83 | num_proc=num_proc,
84 | new_fingerprint=new_fingerprints[k]
85 | )
86 | for k, dataset in self.items()
87 | }
88 | )
89 |
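A minimal usage sketch of the added new_fingerprints argument (the dataset name and mapping function are illustrative): one fingerprint string per split is forwarded as new_fingerprint to the underlying datasets.Dataset.map, which is handy when the automatically computed fingerprint would not capture a change in the processing.

    from datasets import load_dataset
    from dataset_dict import DatasetDict

    raw = DatasetDict(load_dataset("wikitext", "wikitext-2-raw-v1"))

    def add_length(example):
        return {"n_chars": len(example["text"])}

    processed = raw.map(
        add_length,
        new_fingerprints={split: f"add_length-v1-{split}" for split in raw},  # one fingerprint per split
    )
    print(processed)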
--------------------------------------------------------------------------------
/image/consumption.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxchtan/PoNet/cc9a29a3c7732e8c5cdda7eae32e556a2fa71b1d/image/consumption.png
--------------------------------------------------------------------------------
/image/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxchtan/PoNet/cc9a29a3c7732e8c5cdda7eae32e556a2fa71b1d/image/model.png
--------------------------------------------------------------------------------
/image/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxchtan/PoNet/cc9a29a3c7732e8c5cdda7eae32e556a2fa71b1d/image/performance.png
--------------------------------------------------------------------------------
/metrics/macro_f1_and_acc/macro_f1_and_acc.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """F1 metric."""
16 |
17 | from sklearn.metrics import f1_score, accuracy_score
18 |
19 | import datasets
20 |
21 |
22 | _DESCRIPTION = """
23 | This metric reports the macro-averaged F1 score together with accuracy. The F1 score is the
24 | harmonic mean of precision and recall: F1 = 2 * (precision * recall) / (precision + recall)
25 | """
26 |
27 | _KWARGS_DESCRIPTION = """
28 | Args:
29 | predictions: Predicted labels, as returned by a model.
30 | references: Ground truth labels.
31 | labels: The set of labels to include when average != 'binary', and
32 | their order if average is None. Labels present in the data can
33 | be excluded, for example to calculate a multiclass average ignoring
34 | a majority negative class, while labels not present in the data will
35 | result in 0 components in a macro average. For multilabel targets,
36 | labels are column indices. By default, all labels in y_true and
37 | y_pred are used in sorted order.
38 | average: This parameter is required for multiclass/multilabel targets.
39 | If None, the scores for each class are returned. Otherwise, this
40 | determines the type of averaging performed on the data:
41 | binary: Only report results for the class specified by pos_label.
42 | This is applicable only if targets (y_{true,pred}) are binary.
43 | micro: Calculate metrics globally by counting the total true positives,
44 | false negatives and false positives.
45 | macro: Calculate metrics for each label, and find their unweighted mean.
46 | This does not take label imbalance into account.
47 | weighted: Calculate metrics for each label, and find their average
48 | weighted by support (the number of true instances for each label).
49 | This alters ‘macro’ to account for label imbalance; it can result
50 | in an F-score that is not between precision and recall.
51 | samples: Calculate metrics for each instance, and find their average
52 | (only meaningful for multilabel classification).
53 | sample_weight: Sample weights.
54 | Returns:
55 | f1: macro-averaged F1 score. accuracy: classification accuracy.
56 | Examples:
57 |
58 | >>> metric = datasets.load_metric("metrics/macro_f1_and_acc")
59 | >>> results = metric.compute(references=[0, 1], predictions=[0, 1])
60 | >>> print(results)
61 | {'f1': 1.0, 'accuracy': 1.0}
62 | """
63 |
64 | _CITATION = """\
65 | @article{scikit-learn,
66 | title={Scikit-learn: Machine Learning in {P}ython},
67 | author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
68 | and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
69 | and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
70 | Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
71 | journal={Journal of Machine Learning Research},
72 | volume={12},
73 | pages={2825--2830},
74 | year={2011}
75 | }
76 | """
77 |
78 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
79 | class Macro_F1_and_Acc(datasets.Metric):
80 | def _info(self):
81 | return datasets.MetricInfo(
82 | description=_DESCRIPTION,
83 | citation=_CITATION,
84 | inputs_description=_KWARGS_DESCRIPTION,
85 | features=datasets.Features(
86 | {
87 | "predictions": datasets.Value("int64"),
88 | "references": datasets.Value("int64"),
89 | }
90 | ),
91 | reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"],
92 | )
93 |
94 | def _compute(self, predictions, references, labels=None, pos_label=1, average="macro", sample_weight=None):
95 | return {
96 | "f1": f1_score(
97 | references,
98 | predictions,
99 | labels=labels,
100 | pos_label=pos_label,
101 | average=average,
102 | sample_weight=sample_weight,
103 | ),
104 | "accuracy": accuracy_score(
105 | references,
106 | predictions,
107 | )
108 | }
109 |
--------------------------------------------------------------------------------
/metrics/micro_f1_and_acc/micro_f1_and_acc.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """F1 metric."""
16 |
17 | from sklearn.metrics import f1_score, accuracy_score
18 |
19 | import datasets
20 |
21 |
22 | _DESCRIPTION = """
23 | This metric reports the micro-averaged F1 score together with accuracy. The F1 score is the
24 | harmonic mean of precision and recall: F1 = 2 * (precision * recall) / (precision + recall)
25 | """
26 |
27 | _KWARGS_DESCRIPTION = """
28 | Args:
29 | predictions: Predicted labels, as returned by a model.
30 | references: Ground truth labels.
31 | labels: The set of labels to include when average != 'binary', and
32 | their order if average is None. Labels present in the data can
33 | be excluded, for example to calculate a multiclass average ignoring
34 | a majority negative class, while labels not present in the data will
35 | result in 0 components in a macro average. For multilabel targets,
36 | labels are column indices. By default, all labels in y_true and
37 | y_pred are used in sorted order.
38 | average: This parameter is required for multiclass/multilabel targets.
39 | If None, the scores for each class are returned. Otherwise, this
40 | determines the type of averaging performed on the data:
41 | binary: Only report results for the class specified by pos_label.
42 | This is applicable only if targets (y_{true,pred}) are binary.
43 | micro: Calculate metrics globally by counting the total true positives,
44 | false negatives and false positives.
45 | macro: Calculate metrics for each label, and find their unweighted mean.
46 | This does not take label imbalance into account.
47 | weighted: Calculate metrics for each label, and find their average
48 | weighted by support (the number of true instances for each label).
49 | This alters ‘macro’ to account for label imbalance; it can result
50 | in an F-score that is not between precision and recall.
51 | samples: Calculate metrics for each instance, and find their average
52 | (only meaningful for multilabel classification).
53 | sample_weight: Sample weights.
54 | Returns:
55 | f1: micro-averaged F1 score. accuracy: classification accuracy.
56 | Examples:
57 |
58 | >>> metric = datasets.load_metric("metrics/micro_f1_and_acc")
59 | >>> results = metric.compute(references=[0, 1], predictions=[0, 1])
60 | >>> print(results)
61 | {'f1': 1.0, 'accuracy': 1.0}
62 | """
63 |
64 | _CITATION = """\
65 | @article{scikit-learn,
66 | title={Scikit-learn: Machine Learning in {P}ython},
67 | author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
68 | and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
69 | and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
70 | Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
71 | journal={Journal of Machine Learning Research},
72 | volume={12},
73 | pages={2825--2830},
74 | year={2011}
75 | }
76 | """
77 |
78 | @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
79 | class Micro_F1_and_Acc(datasets.Metric):
80 | def _info(self):
81 | return datasets.MetricInfo(
82 | description=_DESCRIPTION,
83 | citation=_CITATION,
84 | inputs_description=_KWARGS_DESCRIPTION,
85 | features=datasets.Features(
86 | {
87 | "predictions": datasets.Value("int64"),
88 | "references": datasets.Value("int64"),
89 | }
90 | ),
91 | reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html"],
92 | )
93 |
94 | def _compute(self, predictions, references, labels=None, pos_label=1, average="micro", sample_weight=None):
95 | return {
96 | "f1": f1_score(
97 | references,
98 | predictions,
99 | labels=labels,
100 | pos_label=pos_label,
101 | average=average,
102 | sample_weight=sample_weight,
103 | ),
104 | "accuracy": accuracy_score(
105 | references,
106 | predictions,
107 | )
108 | }
109 |
--------------------------------------------------------------------------------
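Note: the two metric scripts above differ only in the default `average` passed to `f1_score`. A short illustrative sketch (not shipped with the repository) of why both variants are kept: on imbalanced labels, macro and micro averaging give different F1 values, while micro F1 coincides with accuracy for single-label classification.

    from sklearn.metrics import accuracy_score, f1_score

    references  = [0, 0, 0, 1]   # imbalanced toy labels
    predictions = [0, 0, 0, 0]   # the minority class is never predicted

    # per-class F1 is ~0.857 and 0.0, averaged without weighting
    # (sklearn may warn that precision for the unpredicted class is ill-defined and set it to 0)
    print(f1_score(references, predictions, average="macro"))   # ~0.429
    # global TP/FP/FN counting; equals accuracy here
    print(f1_score(references, predictions, average="micro"))   # 0.75
    print(accuracy_score(references, predictions))              # 0.75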
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch == 2.0.1
2 | transformers == 4.29.1
3 | datasets >= 1.1.3
4 | sentencepiece != 0.1.92
5 | protobuf
6 | tensorboard
7 | fairscale
8 | scipy
9 | scikit-learn
10 | accelerate
--------------------------------------------------------------------------------
/run_glue.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Finetuning the library models for sequence classification on GLUE."""
17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments.
18 | import logging
19 | import os
20 | import random
21 | import sys
22 | from dataclasses import dataclass, field
23 | from typing import Optional
24 |
25 | import numpy as np
26 | from datasets import load_dataset, load_metric
27 |
28 | import transformers
29 | from transformers import (
30 | AutoConfig,
31 | AutoModelForSequenceClassification,
32 | AutoTokenizer,
33 | DataCollatorWithPadding,
34 | EvalPrediction,
35 | HfArgumentParser,
36 | PretrainedConfig,
37 | Trainer,
38 | TrainingArguments,
39 | default_data_collator,
40 | set_seed,
41 | )
42 | from transformers.trainer_utils import get_last_checkpoint, is_main_process
43 | from transformers.utils import check_min_version
44 |
45 |
46 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
47 | check_min_version("4.21.3")
48 |
49 | task_to_keys = {
50 | "cola": ("sentence", None),
51 | "mnli": ("premise", "hypothesis"),
52 | "mrpc": ("sentence1", "sentence2"),
53 | "qnli": ("question", "sentence"),
54 | "qqp": ("question1", "question2"),
55 | "rte": ("sentence1", "sentence2"),
56 | "sst2": ("sentence", None),
57 | "stsb": ("sentence1", "sentence2"),
58 | "wnli": ("sentence1", "sentence2"),
59 | }
60 |
61 | logger = logging.getLogger(__name__)
62 |
63 |
64 | @dataclass
65 | class DataTrainingArguments:
66 | """
67 | Arguments pertaining to what data we are going to input our model for training and eval.
68 |
69 | Using `HfArgumentParser` we can turn this class
70 | into argparse arguments to be able to specify them on
71 | the command line.
72 | """
73 |
74 | task_name: Optional[str] = field(
75 | default=None,
76 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
77 | )
78 | max_seq_length: int = field(
79 | default=128,
80 | metadata={
81 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
82 | "than this will be truncated, sequences shorter will be padded."
83 | },
84 | )
85 | overwrite_cache: bool = field(
86 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
87 | )
88 | pad_to_max_length: bool = field(
89 | default=True,
90 | metadata={
91 | "help": "Whether to pad all samples to `max_seq_length`. "
92 | "If False, will pad the samples dynamically when batching to the maximum length in the batch."
93 | },
94 | )
95 | max_train_samples: Optional[int] = field(
96 | default=None,
97 | metadata={
98 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
99 | "value if set."
100 | },
101 | )
102 | max_eval_samples: Optional[int] = field(
103 | default=None,
104 | metadata={
105 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
106 | "value if set."
107 | },
108 | )
109 | max_predict_samples: Optional[int] = field(
110 | default=None,
111 | metadata={
112 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
113 | "value if set."
114 | },
115 | )
116 | train_file: Optional[str] = field(
117 | default=None, metadata={"help": "A csv or a json file containing the training data."}
118 | )
119 | validation_file: Optional[str] = field(
120 | default=None, metadata={"help": "A csv or a json file containing the validation data."}
121 | )
122 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
123 |
124 | def __post_init__(self):
125 | if self.task_name is not None:
126 | self.task_name = self.task_name.lower()
127 | if self.task_name not in task_to_keys.keys():
128 | raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
129 | elif self.train_file is None or self.validation_file is None:
130 | raise ValueError("Need either a GLUE task or a training/validation file.")
131 | else:
132 | train_extension = self.train_file.split(".")[-1]
133 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file."
134 | validation_extension = self.validation_file.split(".")[-1]
135 | assert (
136 | validation_extension == train_extension
137 | ), "`validation_file` should have the same extension (csv or json) as `train_file`."
138 |
139 |
140 | @dataclass
141 | class ModelArguments:
142 | """
143 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
144 | """
145 |
146 | model_name_or_path: str = field(
147 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
148 | )
149 | config_name: Optional[str] = field(
150 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
151 | )
152 | tokenizer_name: Optional[str] = field(
153 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
154 | )
155 | cache_dir: Optional[str] = field(
156 | default=None,
157 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
158 | )
159 | use_fast_tokenizer: bool = field(
160 | default=True,
161 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
162 | )
163 | model_revision: str = field(
164 | default="main",
165 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
166 | )
167 | use_auth_token: bool = field(
168 | default=False,
169 | metadata={
170 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
171 | "with private models)."
172 | },
173 | )
174 |
175 | def main():
176 | # See all possible arguments in src/transformers/training_args.py
177 | # or by passing the --help flag to this script.
178 | # We now keep distinct sets of args, for a cleaner separation of concerns.
179 |
180 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
181 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
182 | # If we pass only one argument to the script and it's the path to a json file,
183 | # let's parse it to get our arguments.
184 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
185 | else:
186 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
187 |
188 | # Detecting last checkpoint.
189 | last_checkpoint = None
190 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
191 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
192 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
193 | raise ValueError(
194 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
195 | "Use --overwrite_output_dir to overcome."
196 | )
197 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
198 | logger.info(
199 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
200 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
201 | )
202 |
203 | # Setup logging
204 | logging.basicConfig(
205 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
206 | datefmt="%m/%d/%Y %H:%M:%S",
207 | handlers=[logging.StreamHandler(sys.stdout)],
208 | )
209 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
210 |
211 | # Log on each process the small summary:
212 | logger.warning(
213 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
214 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
215 | )
216 | # Set the verbosity to info of the Transformers logger (on main process only):
217 | if is_main_process(training_args.local_rank):
218 | transformers.utils.logging.set_verbosity_info()
219 | transformers.utils.logging.enable_default_handler()
220 | transformers.utils.logging.enable_explicit_format()
221 | logger.info(f"Training/evaluation parameters {training_args}")
222 |
223 | # Set seed before initializing model.
224 | set_seed(training_args.seed)
225 |
226 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
227 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
228 | #
229 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
230 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
231 | # label if at least two columns are provided.
232 | #
233 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
234 | # single column. You can easily tweak this behavior (see below)
235 | #
236 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
237 | # download the dataset.
238 | if data_args.task_name is not None:
239 | # Downloading and loading a dataset from the hub.
240 | datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir)
241 | else:
242 | # Loading a dataset from your local files.
243 | # CSV/JSON training and evaluation files are needed.
244 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
245 |
246 | # Get the test dataset: you can provide your own CSV/JSON test file (see below)
247 | # when you use `do_predict` without specifying a GLUE benchmark task.
248 | if training_args.do_predict:
249 | if data_args.test_file is not None:
250 | train_extension = data_args.train_file.split(".")[-1]
251 | test_extension = data_args.test_file.split(".")[-1]
252 | assert (
253 | test_extension == train_extension
254 | ), "`test_file` should have the same extension (csv or json) as `train_file`."
255 | data_files["test"] = data_args.test_file
256 | else:
257 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
258 |
259 | for key in data_files.keys():
260 | logger.info(f"load a local file for {key}: {data_files[key]}")
261 |
262 | if data_args.train_file.endswith(".csv"):
263 | # Loading a dataset from local csv files
264 | datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir)
265 | else:
266 | # Loading a dataset from local json files
267 | datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir)
268 | # See more about loading any type of standard or custom dataset at
269 | # https://huggingface.co/docs/datasets/loading_datasets.html.
270 |
271 | # Labels
272 | if data_args.task_name is not None:
273 | is_regression = data_args.task_name == "stsb"
274 | if not is_regression:
275 | label_list = datasets["train"].features["label"].names
276 | num_labels = len(label_list)
277 | else:
278 | num_labels = 1
279 | else:
280 | # Trying to have good defaults here, don't hesitate to tweak to your needs.
281 | is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"]
282 | if is_regression:
283 | num_labels = 1
284 | else:
285 | # A useful fast method:
286 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
287 | label_list = datasets["train"].unique("label")
288 | label_list.sort() # Let's sort it for determinism
289 | num_labels = len(label_list)
290 |
291 | # Load pretrained model and tokenizer
292 | #
293 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
294 | # download model & vocab.
295 | config = AutoConfig.from_pretrained(
296 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
297 | num_labels=num_labels,
298 | finetuning_task=data_args.task_name,
299 | cache_dir=model_args.cache_dir,
300 | revision=model_args.model_revision,
301 | use_auth_token=True if model_args.use_auth_token else None,
302 | trust_remote_code=True
303 | )
304 |
305 | tokenizer = AutoTokenizer.from_pretrained(
306 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
307 | cache_dir=model_args.cache_dir,
308 | use_fast=model_args.use_fast_tokenizer,
309 | revision=model_args.model_revision,
310 | use_auth_token=True if model_args.use_auth_token else None,
311 | trust_remote_code=True
312 | )
313 |
314 | model = AutoModelForSequenceClassification.from_pretrained(
315 | model_args.model_name_or_path,
316 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
317 | config=config,
318 | cache_dir=model_args.cache_dir,
319 | revision=model_args.model_revision,
320 | use_auth_token=True if model_args.use_auth_token else None,
321 | trust_remote_code=True
322 | )
323 |
324 | # Preprocessing the datasets
325 | if data_args.task_name is not None:
326 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
327 | else:
328 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
329 | non_label_column_names = [name for name in datasets["train"].column_names if name != "label"]
330 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
331 | sentence1_key, sentence2_key = "sentence1", "sentence2"
332 | else:
333 | if len(non_label_column_names) >= 2:
334 | sentence1_key, sentence2_key = non_label_column_names[:2]
335 | else:
336 | sentence1_key, sentence2_key = non_label_column_names[0], None
337 |
338 | # Padding strategy
339 | if data_args.pad_to_max_length:
340 | padding = "max_length"
341 | else:
342 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch
343 | padding = False
344 |
345 | # Some models have set the order of the labels to use, so let's make sure we do use it.
346 | label_to_id = None
347 | if (
348 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
349 | and data_args.task_name is not None
350 | and not is_regression
351 | ):
352 | # Some have all caps in their config, some don't.
353 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
354 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
355 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
356 | else:
357 | logger.warning(
358 | "Your model seems to have been trained with labels, but they don't match the dataset: ",
359 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
360 | "\nIgnoring the model labels as a result.",
361 | )
362 | elif data_args.task_name is None and not is_regression:
363 | label_to_id = {v: i for i, v in enumerate(label_list)}
364 |
365 | if data_args.max_seq_length > tokenizer.model_max_length:
366 | logger.warning(
367 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
368 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
369 | )
370 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
371 |
372 | def preprocess_function(examples):
373 | # Tokenize the texts
374 | args = (
375 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
376 | )
377 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
378 |
379 | # Map labels to IDs (not necessary for GLUE tasks)
380 | if label_to_id is not None and "label" in examples:
381 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
382 | return result
383 |
384 | datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
385 | if training_args.do_train:
386 | if "train" not in datasets:
387 | raise ValueError("--do_train requires a train dataset")
388 | train_dataset = datasets["train"]
389 | if data_args.max_train_samples is not None:
390 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
391 |
392 | if training_args.do_eval:
393 | if "validation" not in datasets and "validation_matched" not in datasets:
394 | raise ValueError("--do_eval requires a validation dataset")
395 | eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
396 | if data_args.max_eval_samples is not None:
397 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
398 |
399 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
400 | if "test" not in datasets and "test_matched" not in datasets:
401 | raise ValueError("--do_predict requires a test dataset")
402 | predict_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
403 | if data_args.max_predict_samples is not None:
404 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
405 |
406 | # Log a few random samples from the training set:
407 | if training_args.do_train:
408 | for index in random.sample(range(len(train_dataset)), 3):
409 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
410 |
411 | # Get the metric function
412 | if data_args.task_name is not None:
413 | # metric = load_metric("../huggingface/datasets/metrics/glue", data_args.task_name)
414 | metric = load_metric('glue', data_args.task_name)
415 | # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from
416 | # compute_metrics
417 |
418 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
419 | # predictions and label_ids field) and has to return a dictionary string to float.
420 | def compute_metrics(p: EvalPrediction):
421 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
422 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
423 | print("percent: ", preds.sum()/preds.shape[0])
424 | if data_args.task_name is not None:
425 | result = metric.compute(predictions=preds, references=p.label_ids)
426 | if len(result) > 1:
427 | result["combined_score"] = np.mean(list(result.values())).item()
428 | return result
429 | elif is_regression:
430 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
431 | else:
432 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
433 |
434 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
435 | if data_args.pad_to_max_length:
436 | data_collator = default_data_collator
437 | elif training_args.fp16:
438 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
439 | else:
440 | data_collator = None
441 |
442 | # Initialize our Trainer
443 | trainer = Trainer(
444 | model=model,
445 | args=training_args,
446 | train_dataset=train_dataset if training_args.do_train else None,
447 | eval_dataset=eval_dataset if training_args.do_eval else None,
448 | compute_metrics=compute_metrics,
449 | tokenizer=tokenizer,
450 | data_collator=data_collator,
451 | )
452 |
453 | # Training
454 | if training_args.do_train:
455 | checkpoint = None
456 | if training_args.resume_from_checkpoint is not None:
457 | checkpoint = training_args.resume_from_checkpoint
458 | elif last_checkpoint is not None:
459 | checkpoint = last_checkpoint
460 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
461 | metrics = train_result.metrics
462 | max_train_samples = (
463 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
464 | )
465 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
466 |
467 | trainer.save_model() # Saves the tokenizer too for easy upload
468 |
469 | trainer.log_metrics("train", metrics)
470 | trainer.save_metrics("train", metrics)
471 | trainer.save_state()
472 |
473 | # Evaluation
474 | if training_args.do_eval:
475 | logger.info("*** Evaluate ***")
476 |
477 | # Loop to handle MNLI double evaluation (matched, mis-matched)
478 | tasks = [data_args.task_name]
479 | eval_datasets = [eval_dataset]
480 | if data_args.task_name == "mnli":
481 | tasks.append("mnli-mm")
482 | eval_datasets.append(datasets["validation_mismatched"])
483 |
484 | for eval_dataset, task in zip(eval_datasets, tasks):
485 | metrics = trainer.evaluate(eval_dataset=eval_dataset)
486 |
487 | max_eval_samples = (
488 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
489 | )
490 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
491 |
492 | trainer.log_metrics(f"eval-{task}", metrics)
493 | trainer.save_metrics(f"eval-{task}", metrics)
494 |
495 | if training_args.do_predict:
496 | logger.info("*** Predict ***")
497 |
498 | # Loop to handle MNLI double evaluation (matched, mis-matched)
499 | tasks = [data_args.task_name]
500 | predict_datasets = [predict_dataset]
501 | if data_args.task_name == "mnli":
502 | tasks.append("mnli-mm")
503 | predict_datasets.append(datasets["test_mismatched"])
504 |
505 | for predict_dataset, task in zip(predict_datasets, tasks):
506 | # Removing the `label` column because it contains -1 and Trainer won't like that.
507 | predict_dataset = predict_dataset.remove_columns("label")
508 | predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
509 | predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
510 |
511 | output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
512 | if trainer.is_world_process_zero():
513 | with open(output_predict_file, "w") as writer:
514 | logger.info(f"***** Predict results {task} *****")
515 | writer.write("index\tprediction\n")
516 | for index, item in enumerate(predictions):
517 | if is_regression:
518 | writer.write(f"{index}\t{item:3.3f}\n")
519 | else:
520 | item = label_list[item]
521 | writer.write(f"{index}\t{item}\n")
522 |
523 | if training_args.push_to_hub:
524 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "text-classification"}
525 | if data_args.task_name is not None:
526 | kwargs["language"] = "en"
527 | kwargs["dataset_tags"] = "glue"
528 | kwargs["dataset_args"] = data_args.task_name
529 | kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"
530 |
531 | trainer.push_to_hub(**kwargs)
532 |
533 |
534 | def _mp_fn(index):
535 | # For xla_spawn (TPUs)
536 | main()
537 |
538 |
539 | if __name__ == "__main__":
540 | main()
--------------------------------------------------------------------------------
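Note: an illustrative sketch of the JSON-config entry point in run_glue.py above. When the script receives a single *.json argument, `HfArgumentParser.parse_json_file` maps each key onto the `ModelArguments` / `DataTrainingArguments` / `TrainingArguments` fields; the checkpoint path, task and hyper-parameters below are assumptions (see run_shell/2-GLUE.sh for the settings actually used in this repository).

    import json

    config = {
        "model_name_or_path": "outputs/ponet-base-uncased",  # illustrative checkpoint path
        "task_name": "mrpc",
        "do_train": True,
        "do_eval": True,
        "max_seq_length": 128,
        "per_device_train_batch_size": 32,
        "learning_rate": 2e-5,
        "num_train_epochs": 3,
        "output_dir": "outputs/glue-mrpc",
        "overwrite_output_dir": True,
    }
    with open("glue_mrpc.json", "w") as f:
        json.dump(config, f, indent=2)

    # then: python run_glue.py glue_mrpc.json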
/run_long_classification.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Finetuning the library models for sequence classification on GLUE."""
17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments.
18 |
19 | import torch
20 | import logging
21 | import os
22 | import random
23 | import sys
24 | from dataclasses import dataclass, field
25 | from typing import Optional
26 |
27 | import numpy as np
28 | from datasets import load_dataset, load_metric
29 |
30 | import transformers
31 | from transformers import (
32 | AutoConfig,
33 | AutoModelForSequenceClassification,
34 | AutoTokenizer,
35 | DataCollatorWithPadding,
36 | EvalPrediction,
37 | HfArgumentParser,
38 | PretrainedConfig,
39 | Trainer,
40 | TrainingArguments,
41 | default_data_collator,
42 | set_seed,
43 | )
44 |
45 | from transformers.trainer_utils import get_last_checkpoint, is_main_process
46 | from transformers.utils import check_min_version
47 | from nltk.tokenize import sent_tokenize
48 |
49 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
50 | check_min_version("4.21.3")
51 |
52 | task_to_metrics = {
53 | "arxiv": ("metrics/micro_f1_and_acc", None),
54 | "imdb": ("glue", "mrpc"),
55 | "yelp": ("metrics/micro_f1_and_acc", None),
56 | "hnd": ("glue", "mrpc"),
57 | }
58 |
59 | task_to_datasets = {
60 | "arxiv": {'path': 'datasets/arxiv-11'},
61 | "imdb": {'path': 'imdb'},
62 | "yelp": {'path': 'yelp_review_full'},
63 | "hnd": {
64 | 'path': 'json',
65 | 'data_files': {
66 | 'train': ['datasets/hyperpartisan_news_detection/train.jsonl'],
67 | 'validation': ['datasets/hyperpartisan_news_detection/dev.jsonl'],
68 | 'test': ['datasets/hyperpartisan_news_detection/test.jsonl'],
69 | }
70 | },
71 | }
72 |
73 | logger = logging.getLogger(__name__)
74 |
75 |
76 | @dataclass
77 | class DataTrainingArguments:
78 | """
79 | Arguments pertaining to what data we are going to input our model for training and eval.
80 |
81 | Using `HfArgumentParser` we can turn this class
82 | into argparse arguments to be able to specify them on
83 | the command line.
84 | """
85 |
86 | task_name: str = field(
87 | default=None,
88 | metadata={"help": "The name of the task to train."},
89 | )
90 | max_seq_length: int = field(
91 | default=128,
92 | metadata={
93 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
94 | "than this will be truncated, sequences shorter will be padded."
95 | },
96 | )
97 | overwrite_cache: bool = field(
98 | default=False,
99 | metadata={"help": "Overwrite the cached preprocessed datasets or not."}
100 | )
101 | pad_to_max_length: bool = field(
102 | default=False,
103 | metadata={
104 | "help": "Whether to pad all samples to `max_seq_length`. "
105 | "If False, will pad the samples dynamically when batching to the maximum length in the batch."
106 | },
107 | )
108 | max_train_samples: Optional[int] = field(
109 | default=None,
110 | metadata={
111 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
112 | "value if set."
113 | },
114 | )
115 | max_eval_samples: Optional[int] = field(
116 | default=None,
117 | metadata={
118 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
119 | "value if set."
120 | },
121 | )
122 | max_predict_samples: Optional[int] = field(
123 | default=None,
124 | metadata={
125 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
126 | "value if set."
127 | },
128 | )
129 | train_file: Optional[str] = field(
130 | default=None, metadata={"help": "A csv or a json file containing the training data."}
131 | )
132 | validation_file: Optional[str] = field(
133 | default=None, metadata={"help": "A csv or a json file containing the validation data."}
134 | )
135 | test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})
136 |
137 | @dataclass
138 | class ModelArguments:
139 | """
140 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
141 | """
142 |
143 | model_name_or_path: str = field(
144 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
145 | )
146 | config_name: Optional[str] = field(
147 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
148 | )
149 | tokenizer_name: Optional[str] = field(
150 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
151 | )
152 | cache_dir: Optional[str] = field(
153 | default=None,
154 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
155 | )
156 | use_fast_tokenizer: bool = field(
157 | default=True,
158 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
159 | )
160 | model_revision: str = field(
161 | default="main",
162 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
163 | )
164 | use_auth_token: bool = field(
165 | default=False,
166 | metadata={
167 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
168 | "with private models)."
169 | },
170 | )
171 | gc: bool = field(
172 | default=True,
173 | metadata={
174 | "help": "Use gradient checkpointing."
175 | },
176 | )
177 |
178 | def main():
179 | # See all possible arguments in src/transformers/training_args.py
180 | # or by passing the --help flag to this script.
181 | # We now keep distinct sets of args, for a cleaner separation of concerns.
182 |
183 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
184 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
185 | # If we pass only one argument to the script and it's the path to a json file,
186 | # let's parse it to get our arguments.
187 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
188 | else:
189 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
190 |
191 | # Detecting last checkpoint.
192 | last_checkpoint = None
193 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
194 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
195 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
196 | raise ValueError(
197 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
198 | "Use --overwrite_output_dir to overcome."
199 | )
200 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
201 | logger.info(
202 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
203 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
204 | )
205 |
206 | # Setup logging
207 | logging.basicConfig(
208 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
209 | datefmt="%m/%d/%Y %H:%M:%S",
210 | handlers=[logging.StreamHandler(sys.stdout)],
211 | )
212 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
213 |
214 | # Log on each process the small summary:
215 | logger.warning(
216 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
217 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
218 | )
219 | # Set the verbosity to info of the Transformers logger (on main process only):
220 | if is_main_process(training_args.local_rank):
221 | transformers.utils.logging.set_verbosity_info()
222 | transformers.utils.logging.enable_default_handler()
223 | transformers.utils.logging.enable_explicit_format()
224 | logger.info(f"Training/evaluation parameters {training_args}")
225 |
226 | # Set seed before initializing model.
227 | set_seed(training_args.seed)
228 |
229 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
230 | # download the dataset.
231 | datasets = load_dataset(
232 | **task_to_datasets[data_args.task_name],
233 | cache_dir=model_args.cache_dir
234 | )
235 |
236 | data_args.validation_split_percentage = 10
237 | if "validation" not in datasets.keys():
238 | if data_args.task_name == 'imdb':
239 | # test set is large enough -> ensure similar size and distribution between val/test datasets
240 | datasets["validation"] = datasets['test']
241 | else:
242 | datasets["validation"] = load_dataset(
243 | **task_to_datasets[data_args.task_name],
244 | split=f"train[:{data_args.validation_split_percentage}%]",
245 | cache_dir=model_args.cache_dir,
246 | )
247 | datasets["train"] = load_dataset(
248 | **task_to_datasets[data_args.task_name],
249 | split=f"train[{data_args.validation_split_percentage+10}%:]",
250 | cache_dir=model_args.cache_dir,
251 | )
252 |
253 | # Labels
254 | label_list = datasets["train"].unique("label")
255 | label_list.sort() # Let's sort it for determinism
256 | num_labels = len(label_list)
257 | is_regression = False
258 |
259 | # Load pretrained model and tokenizer
260 | #
261 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
262 | # download model & vocab.
263 | config = AutoConfig.from_pretrained(
264 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
265 | num_labels=num_labels,
266 | finetuning_task=data_args.task_name,
267 | cache_dir=model_args.cache_dir,
268 | revision=model_args.model_revision,
269 | use_auth_token=True if model_args.use_auth_token else None,
270 | trust_remote_code=True
271 | )
272 | tokenizer = AutoTokenizer.from_pretrained(
273 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
274 | cache_dir=model_args.cache_dir,
275 | use_fast=model_args.use_fast_tokenizer,
276 | revision=model_args.model_revision,
277 | use_auth_token=True if model_args.use_auth_token else None,
278 | trust_remote_code=True
279 | )
280 |
281 |
282 | if model_args.gc:
283 | config.gradient_checkpointing = True
284 | config.use_cache = False
285 |
286 | model = AutoModelForSequenceClassification.from_pretrained(
287 | model_args.model_name_or_path,
288 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
289 | config=config,
290 | cache_dir=model_args.cache_dir,
291 | revision=model_args.model_revision,
292 | use_auth_token=True if model_args.use_auth_token else None,
293 | trust_remote_code=True
294 | )
295 |
296 | # extend position embeddings
297 | if data_args.max_seq_length > tokenizer.model_max_length:
298 | logger.warning(
299 | "Copying the position embedding due to "
300 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
301 | f"model ({tokenizer.model_max_length})."
302 | )
303 | max_pos = data_args.max_seq_length
304 | config.max_position_embeddings = max_pos
305 | tokenizer.model_max_length = max_pos
306 | tokenizer.init_kwargs['model_max_length'] = tokenizer.model_max_length
307 | current_max_pos, embed_size = model.ponet.embeddings.position_embeddings.weight.shape
308 | assert max_pos > current_max_pos
309 | # allocate a larger position embedding matrix
310 | new_pos_embed = model.ponet.embeddings.position_embeddings.weight.new_empty(max_pos, embed_size)
311 | # copy position embeddings over and over to initialize the new position embeddings
312 | k = 0
313 | step = current_max_pos
314 | while k < max_pos - 1:
315 | new_pos_embed[k:(k + step)] = model.ponet.embeddings.position_embeddings.weight[:]
316 | k += step
317 | model.ponet.embeddings.position_embeddings.weight.data = new_pos_embed
318 | model.ponet.embeddings.position_ids.data = torch.tensor([i for i in range(max_pos)]).reshape(1, max_pos)
319 |
320 | # Preprocessing the datasets
321 | sentence1_key, sentence2_key = "text", None
322 |
323 | # Padding strategy
324 | if data_args.pad_to_max_length:
325 | padding = "max_length"
326 | else:
327 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch
328 | padding = False
329 |
330 | # Some models have set the order of the labels to use, so let's make sure we do use it.
331 | label_to_id = None
332 | if (
333 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
334 | and data_args.task_name is not None
335 | and not is_regression
336 | ):
337 | # Some have all caps in their config, some don't.
338 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
339 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
340 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
341 | else:
342 | logger.warning(
343 | "Your model seems to have been trained with labels, but they don't match the dataset: ",
344 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
345 | "\nIgnoring the model labels as a result.",
346 | )
347 | elif data_args.task_name is None and not is_regression:
348 | label_to_id = {v: i for i, v in enumerate(label_list)}
349 |
350 | if data_args.max_seq_length > tokenizer.model_max_length:
351 | logger.warning(
352 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
353 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
354 | )
355 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
356 |
357 | def preprocess_function_arxiv(examples):
358 | # Tokenize the texts
359 | segment_ids = []
360 | args = (
361 | ([ex.replace('\n\n', ' ').replace('\n', ' ') for ex in examples[sentence1_key]],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
362 | )
363 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
364 | for ex in examples[sentence1_key]:
365 | seg_lens = list(map(len, tokenizer([eex.replace('\n', ' ') for eex in ex.split('\n\n')], add_special_tokens=False, max_length=max_seq_length, truncation=True)['input_ids']))
366 | segment_id = [0] + sum([[i]*sl for i, sl in enumerate(seg_lens, start=1)], [])
367 | segment_id = segment_id[:max_seq_length-1]
368 | segment_ids.append(segment_id + [segment_id[-1]+1])
369 | result["segment_ids"] = segment_ids
370 |
371 | # Map labels to IDs
372 | if label_to_id is not None and "label" in examples:
373 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
374 | return result
375 |
376 | def preprocess_function(examples):
377 | # Tokenize the texts
378 | segment_ids = []
379 | args = (
380 | (examples[sentence1_key], ) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
381 | )
382 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
383 | for ex in examples[sentence1_key]:
384 | seg_lens = list(map(len, tokenizer(sent_tokenize(ex), add_special_tokens=False, max_length=max_seq_length, truncation=True)['input_ids']))
385 | segment_id = [0] + sum([[i]*sl for i, sl in enumerate(seg_lens, start=1)], [])
386 | segment_id = segment_id[:max_seq_length-1]
387 | segment_ids.append(segment_id + [segment_id[-1]+1])
388 | result["segment_ids"] = segment_ids
389 |
390 | # Map labels to IDs
391 | if label_to_id is not None and "label" in examples:
392 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
393 | return result
394 |
395 | datasets = datasets.map(
396 | preprocess_function if data_args.task_name != 'arxiv' else preprocess_function_arxiv,
397 | batched=True,
398 | load_from_cache_file=not data_args.overwrite_cache
399 | )
400 |
401 | # raise RuntimeError
402 | if training_args.do_train:
403 | if "train" not in datasets:
404 | raise ValueError("--do_train requires a train dataset")
405 | # train_dataset = datasets["train"].shuffle(seed=training_args.seed)
406 | train_dataset = datasets["train"]
407 | if data_args.max_train_samples is not None:
408 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
409 |
410 | if training_args.do_eval:
411 | if "validation" not in datasets:
412 | raise ValueError("--do_eval requires a validation dataset")
413 | eval_dataset = datasets["validation"]
414 | if data_args.max_eval_samples is not None:
415 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
416 |
417 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
418 | if "test" not in datasets:
419 | raise ValueError("--do_predict requires a test dataset")
420 | predict_dataset = datasets["test"]
421 | if data_args.max_predict_samples is not None:
422 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
423 |
424 | # Log a few random samples from the training set:
425 | if training_args.do_train:
426 | for index in random.sample(range(len(train_dataset)), 3):
427 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
428 |
429 | # Get the metric function
430 | metric = load_metric(*task_to_metrics[data_args.task_name])
431 | # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from
432 | # compute_metrics
433 |
434 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
435 | # predictions and label_ids field) and has to return a dictionary string to float.
436 | def compute_metrics(p: EvalPrediction):
437 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
438 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
439 | result = metric.compute(predictions=preds, references=p.label_ids)
440 | if len(result) > 1:
441 | result["combined_score"] = np.mean(list(result.values())).item()
442 | return result
443 |
444 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
445 | if data_args.pad_to_max_length:
446 | data_collator = default_data_collator
447 | elif training_args.fp16:
448 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
449 | else:
450 | data_collator = None
451 |
452 | # Initialize our Trainer
453 | trainer = Trainer(
454 | model=model,
455 | args=training_args,
456 | train_dataset=train_dataset if training_args.do_train else None,
457 | eval_dataset=eval_dataset if training_args.do_eval else None,
458 | compute_metrics=compute_metrics,
459 | tokenizer=tokenizer,
460 | data_collator=data_collator,
461 | )
462 |
463 | # Training
464 | if training_args.do_train:
465 | checkpoint = None
466 | # if training_args.resume_from_checkpoint is not None:
467 | # checkpoint = training_args.resume_from_checkpoint
468 | # elif last_checkpoint is not None:
469 | # checkpoint = last_checkpoint
470 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
471 | metrics = train_result.metrics
472 | max_train_samples = (
473 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
474 | )
475 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
476 |
477 | # trainer.save_model() # Saves the tokenizer too for easy upload
478 |
479 | trainer.log_metrics("train", metrics)
480 | trainer.save_metrics("train", metrics)
481 | trainer.save_state()
482 |
483 | # Evaluation
484 | if training_args.do_eval:
485 | logger.info("*** Evaluate ***")
486 |
487 | task = data_args.task_name
488 | metrics = trainer.evaluate(eval_dataset=eval_dataset)
489 |
490 | max_eval_samples = (
491 | data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
492 | )
493 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
494 |
495 | trainer.log_metrics(f"eval-{task}", metrics)
496 | trainer.save_metrics(f"eval-{task}", metrics)
497 |
498 | if training_args.do_predict:
499 | logger.info("*** Predict ***")
500 |
501 | task = data_args.task_name
502 | metrics = trainer.evaluate(eval_dataset=predict_dataset)
503 |
504 | max_predict_samples = (
505 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
506 | )
507 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
508 |
509 | trainer.log_metrics(f"predict-{task}", metrics)
510 | trainer.save_metrics(f"predict-{task}", metrics)
511 |
512 |
513 | def _mp_fn(index):
514 | # For xla_spawn (TPUs)
515 | main()
516 |
517 |
518 | if __name__ == "__main__":
519 | main()
--------------------------------------------------------------------------------
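Note: an illustrative sketch (not part of the script) of the segment_ids layout built by preprocess_function / preprocess_function_arxiv in run_long_classification.py above. Index 0 is reserved for the leading special token, each sentence (or paragraph, for arxiv) gets its own id starting from 1, and the trailing special token receives one extra id.

    # assume two sentences that tokenize to 4 and 3 sub-tokens, and max_seq_length = 512
    seg_lens = [4, 3]
    segment_id = [0] + sum([[i] * sl for i, sl in enumerate(seg_lens, start=1)], [])
    segment_id = segment_id[:512 - 1]
    segment_id = segment_id + [segment_id[-1] + 1]
    print(segment_id)  # [0, 1, 1, 1, 1, 2, 2, 2, 3] -> [CLS], sentence 1, sentence 2, [SEP]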
/run_pretrained.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2020 The HuggingFace Team All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset.
18 |
19 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20 | https://huggingface.co/models?filter=masked-lm
21 | """
22 | # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
23 | import torch
24 | import logging
25 | import math
26 | import os
27 | import sys
28 | from dataclasses import dataclass, field
29 | from typing import Optional
30 |
31 | from extra.dataset_dict import DatasetDict as newDatasetDict
32 |
33 | from datasets import load_dataset, concatenate_datasets
34 |
35 | import numpy as np
36 |
37 | import transformers
38 | from transformers import (
39 | CONFIG_MAPPING,
40 | MODEL_FOR_MASKED_LM_MAPPING,
41 | AutoConfig,
42 | AutoModelForPreTraining,
43 | EvalPrediction,
44 | AutoTokenizer,
45 | DataCollatorForLanguageModeling,
46 | HfArgumentParser,
47 | TrainingArguments,
48 | set_seed,
49 | )
50 | from extra.classifier_trainer import SM_Trainer as Trainer
51 | from transformers.trainer_utils import get_last_checkpoint, is_main_process
52 | from transformers.utils import check_min_version
53 | import random
54 |
55 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
56 | check_min_version("4.21.3")
57 |
58 | logger = logging.getLogger(__name__)
59 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
60 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
61 |
62 |
63 | @dataclass
64 | class ModelArguments:
65 | """
66 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
67 | """
68 |
69 | model_name_or_path: Optional[str] = field(
70 | default=None,
71 | metadata={
72 |             "help": "The model checkpoint for weights initialization. "
73 | "Don't set if you want to train a model from scratch."
74 | },
75 | )
76 | model_type: Optional[str] = field(
77 | default=None,
78 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
79 | )
80 | config_name: Optional[str] = field(
81 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
82 | )
83 | tokenizer_name: Optional[str] = field(
84 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
85 | )
86 | cache_dir: Optional[str] = field(
87 | default=None,
88 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
89 | )
90 | use_fast_tokenizer: bool = field(
91 | default=True,
92 |         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
93 | )
94 | model_revision: str = field(
95 | default="main",
96 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
97 | )
98 | use_auth_token: bool = field(
99 | default=False,
100 | metadata={
101 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
102 | "with private models)."
103 | },
104 | )
105 |
106 |
107 | @dataclass
108 | class DataTrainingArguments:
109 | """
110 | Arguments pertaining to what data we are going to input our model for training and eval.
111 | """
112 |
113 | dataset_name: Optional[str] = field(
114 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
115 | )
116 | dataset_config_name: Optional[str] = field(
117 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
118 | )
119 | dataset2_name: Optional[str] = field(
120 |         default=None, metadata={"help": "The name of the second dataset to use (via the datasets library); it is concatenated with the first dataset."}
121 | )
122 | dataset2_config_name: Optional[str] = field(
123 |         default=None, metadata={"help": "The configuration name of the second dataset to use (via the datasets library)."}
124 | )
125 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
126 | validation_file: Optional[str] = field(
127 | default=None,
128 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
129 | )
130 | overwrite_cache: bool = field(
131 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
132 | )
133 | validation_split_percentage: Optional[int] = field(
134 | default=5,
135 | metadata={
136 | "help": "The percentage of the train set used as validation set in case there's no validation split"
137 | },
138 | )
139 | max_seq_length: Optional[int] = field(
140 | default=None,
141 | metadata={
142 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
143 | "than this will be truncated."
144 | },
145 | )
146 | preprocessing_num_workers: Optional[int] = field(
147 | default=None,
148 | metadata={"help": "The number of processes to use for the preprocessing."},
149 | )
150 | mlm_probability: float = field(
151 | default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
152 | )
153 | line_by_line: bool = field(
154 | default=False,
155 | metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
156 | )
157 | pad_to_max_length: bool = field(
158 | default=False,
159 | metadata={
160 | "help": "Whether to pad all samples to `max_seq_length`. "
161 | "If False, will pad the samples dynamically when batching to the maximum length in the batch."
162 | },
163 | )
164 | dupe_factor: int = field(
165 | default=5,
166 | metadata={
167 | "help": "Number of times to duplicate the input data (with different masks)."
168 | },
169 | )
170 | max_train_samples: Optional[int] = field(
171 | default=None,
172 | metadata={
173 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
174 | "value if set."
175 | },
176 | )
177 | max_eval_samples: Optional[int] = field(
178 | default=None,
179 | metadata={
180 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
181 | "value if set."
182 | },
183 | )
184 |
185 | def __post_init__(self):
186 | if self.dataset_name is None and self.train_file is None and self.validation_file is None:
187 | raise ValueError("Need either a dataset name or a training/validation file.")
188 | else:
189 | if self.train_file is not None:
190 | extension = self.train_file.split(".")[-1]
191 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
192 | if self.validation_file is not None:
193 | extension = self.validation_file.split(".")[-1]
194 | assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
195 |
196 |
197 | def main():
198 | # See all possible arguments in src/transformers/training_args.py
199 | # or by passing the --help flag to this script.
200 | # We now keep distinct sets of args, for a cleaner separation of concerns.
201 |
202 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
203 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
204 | # If we pass only one argument to the script and it's the path to a json file,
205 | # let's parse it to get our arguments.
206 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
207 | else:
208 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
209 |
210 | # Detecting last checkpoint.
211 | last_checkpoint = None
212 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
213 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
214 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
215 | raise ValueError(
216 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
217 | "Use --overwrite_output_dir to overcome."
218 | )
219 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
220 | logger.info(
221 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
222 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
223 | )
224 |
225 | # Setup logging
226 | logging.basicConfig(
227 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
228 | datefmt="%m/%d/%Y %H:%M:%S",
229 | handlers=[logging.StreamHandler(sys.stdout)],
230 | )
231 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
232 |
233 | # Log on each process the small summary:
234 | logger.warning(
235 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
236 |         + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
237 | )
238 | # Set the verbosity to info of the Transformers logger (on main process only):
239 | if is_main_process(training_args.local_rank):
240 | transformers.utils.logging.set_verbosity_info()
241 | transformers.utils.logging.enable_default_handler()
242 | transformers.utils.logging.enable_explicit_format()
243 | logger.info(f"Training/evaluation parameters {training_args}")
244 |
245 | # Set seed before initializing model.
246 | set_seed(training_args.seed)
247 |
248 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
249 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
250 |     # (the dataset will be downloaded automatically from the datasets Hub).
251 | #
252 | # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
253 | # behavior (see below)
254 | #
255 |     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
256 | # download the dataset.
257 | if data_args.dataset_name is not None:
258 | # Downloading and loading a dataset from the hub.
259 | datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
260 | if "validation" not in datasets.keys():
261 | datasets["validation"] = load_dataset(
262 | data_args.dataset_name,
263 | data_args.dataset_config_name,
264 | split=f"train[:{data_args.validation_split_percentage}%]",
265 | cache_dir=model_args.cache_dir,
266 | )
267 | datasets["train"] = load_dataset(
268 | data_args.dataset_name,
269 | data_args.dataset_config_name,
270 | split=f"train[{data_args.validation_split_percentage}%:]",
271 | cache_dir=model_args.cache_dir,
272 | )
273 | if data_args.dataset2_name is not None:
274 | datasets_2 = load_dataset(data_args.dataset2_name, data_args.dataset2_config_name, cache_dir=model_args.cache_dir)
275 | for k in datasets.keys():
276 | datasets[k] = concatenate_datasets([datasets[k], datasets_2[k]])
277 | else:
278 | data_files = {}
279 | if data_args.train_file is not None:
280 | data_files["train"] = data_args.train_file
281 | if data_args.validation_file is not None:
282 | data_files["validation"] = data_args.validation_file
283 | extension = data_args.train_file.split(".")[-1]
284 | if extension == "txt":
285 | extension = "text"
286 | datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
287 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
288 | # https://huggingface.co/docs/datasets/loading_datasets.html.
289 |
290 |     # XXX: monkey-patch DatasetDict.map with the extended version from extra/dataset_dict.py, which accepts the per-split `new_fingerprints` argument used below; reconstruct later
291 | if datasets.__class__.__name__ == 'DatasetDict':
292 | setattr(datasets.__class__, 'map', newDatasetDict.map)
293 |
294 | # Load pretrained model and tokenizer
295 | #
296 | # Distributed training:
297 | # The .from_pretrained methods guarantee that only one local process can concurrently
298 | # download model & vocab.
299 | config_kwargs = {
300 | "cache_dir": model_args.cache_dir,
301 | "revision": model_args.model_revision,
302 | "use_auth_token": True if model_args.use_auth_token else None,
303 | "trust_remote_code": True,
304 | }
305 | if model_args.config_name:
306 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
307 | elif model_args.model_name_or_path:
308 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
309 | else:
310 | config = CONFIG_MAPPING[model_args.model_type]()
311 | logger.warning("You are instantiating a new config instance from scratch.")
312 |
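313 |     # `segment_ids` is added to model_input_names so that the per-sentence ids built by
314 |     # group_texts below are treated as model inputs by the tokenizer and data collator.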
313 | tokenizer_kwargs = {
314 | "cache_dir": model_args.cache_dir,
315 | "use_fast": model_args.use_fast_tokenizer,
316 | "revision": model_args.model_revision,
317 | "use_auth_token": True if model_args.use_auth_token else None,
318 | "model_input_names": ['input_ids', 'token_type_ids', 'attention_mask', 'segment_ids'],
319 | "trust_remote_code": True,
320 | }
321 | if model_args.tokenizer_name:
322 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
323 | elif model_args.model_name_or_path:
324 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
325 | else:
326 | raise ValueError(
327 |             "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
328 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
329 | )
330 |
331 | if model_args.model_name_or_path:
332 | model = AutoModelForPreTraining.from_pretrained(
333 | model_args.model_name_or_path,
334 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
335 | config=config,
336 | cache_dir=model_args.cache_dir,
337 | revision=model_args.model_revision,
338 | use_auth_token=True if model_args.use_auth_token else None,
339 | trust_remote_code=True
340 | )
341 | else:
342 | logger.info("Training new model from scratch")
343 | model = AutoModelForPreTraining.from_config(config, trust_remote_code=True)
344 |
345 | model.resize_token_embeddings(len(tokenizer))
346 |
347 | # Preprocessing the datasets.
348 | # First we tokenize all the texts.
349 | if training_args.do_train:
350 | column_names = datasets["train"].column_names
351 | else:
352 | column_names = datasets["validation"].column_names
353 | text_column_name = "text" if "text" in column_names else column_names[0]
354 |
355 | if data_args.max_seq_length is None:
356 | max_seq_length = tokenizer.model_max_length
357 | if max_seq_length > 1024:
358 | logger.warning(
359 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
360 | "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
361 | )
362 | max_seq_length = 1024
363 | else:
364 | if data_args.max_seq_length > tokenizer.model_max_length:
365 | logger.warning(
366 |                 f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
367 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
368 | )
369 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
370 |
371 | if data_args.line_by_line:
372 | # When using line_by_line, we just tokenize each nonempty line.
373 | padding = "max_length" if data_args.pad_to_max_length else False
374 |
375 | def tokenize_function(examples):
376 | # Remove empty lines
377 | examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
378 | return tokenizer(
379 | examples["text"],
380 | padding=padding,
381 | truncation=True,
382 | max_length=max_seq_length,
383 | add_special_tokens=False,
384 | # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
385 | # receives the `special_tokens_mask`.
386 | return_special_tokens_mask=True,
387 | )
388 |
389 | tokenized_datasets = datasets.map(
390 | tokenize_function,
391 | batched=True,
392 | num_proc=data_args.preprocessing_num_workers,
393 | remove_columns=[text_column_name],
394 | load_from_cache_file=not data_args.overwrite_cache,
395 | )
396 | else:
397 | # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
398 | # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
399 | # efficient when it receives the `special_tokens_mask`.
400 | def tokenize_function(examples):
401 | return tokenizer(examples[text_column_name], return_special_tokens_mask=True, add_special_tokens=False)
402 |
403 | tokenized_datasets = datasets.map(
404 | tokenize_function,
405 | batched=True,
406 | num_proc=data_args.preprocessing_num_workers,
407 | remove_columns=column_names,
408 | load_from_cache_file=not data_args.overwrite_cache,
409 | new_fingerprints={'train': 'f100a6a7741b77ef', 'validation': 'f815a2392c2825cb'}
410 | )
411 |
412 |     # Main data processing function that groups the tokenized segments of each document into
413 |     # sentence-pair pretraining examples of up to max_seq_length tokens.
414 | def group_texts(examples):
415 |         """Builds sentence-pair pretraining examples (with sentence_structural_label and segment_ids) from a batch of tokenized segments, repeated dupe_factor times."""
416 | results = {k:[] for k in examples.keys()}
417 | results["sentence_structural_label"] = []
418 | results["segment_ids"] = []
419 | for _ in range(data_args.dupe_factor):
420 | # Account for special tokens
421 | max_num_tokens = max_seq_length - tokenizer.num_special_tokens_to_add(pair=True)
422 | short_seq_prob = 0.1
423 | fk = 'input_ids'
424 |
425 |             # We *usually* want to fill up the entire sequence, since short sequences
426 |             # are generally wasted computation.
427 |             # However, we *sometimes*
428 |             # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
429 |             # sequences to minimize the mismatch between pretraining and fine-tuning.
430 |             # The `target_seq_length` is just a rough target, however, whereas
431 |             # `max_num_tokens` is a hard limit.
432 | target_seq_length = max_num_tokens
433 | if random.random() < short_seq_prob:
434 | target_seq_length = random.randint(2, max_num_tokens)
435 |
436 | # get_all_chunk
437 | total_chunk = []
438 |             current_chunk = [] # a buffer storing the current working segments
439 | current_length = 0
440 | i = 0
441 | while i < len(examples[fk]):
442 | segment = examples[fk][i] # get a segment
443 | if not segment:
444 | i += 1
445 | continue
446 | current_chunk.append(examples['input_ids'][i]) # add a segment to current chunk
447 | current_length += len(segment) # overall token length
448 |                 # once the current length reaches the target length or the end of the document, store the finished chunk
449 | if i == len(examples[fk]) - 1 or current_length >= target_seq_length:
450 | if current_chunk:
451 | total_chunk.append(current_chunk)
452 |
453 | current_chunk = [] # clear current chunk
454 | current_length = 0 # reset current text length
455 | i += 1 # go to next line
456 |
457 | # We DON'T just concatenate all of the tokens from a document into a long
458 | # sequence and choose an arbitrary split point because this would make the
459 |             # sentence structure prediction task too easy. Instead, we split the input into
460 | # segments "A" and "B" based on the actual "sentences" provided by the user
461 | # input.
462 |             current_chunk = [] # a buffer storing the current working segments
463 | current_length = 0
464 | i = 0
465 | chunk_id = -1
466 | while i < len(examples[fk]):
467 | segment = examples[fk][i] # get a segment
468 | if not segment:
469 | i += 1
470 | continue
471 | current_chunk.append(examples['input_ids'][i]) # add a segment to current chunk
472 | current_length += len(segment) # overall token length
473 |                 # once the current length reaches the target length or the end of the document, start building tokens A and B
474 | if i == len(examples[fk]) - 1 or current_length >= target_seq_length:
475 | if current_chunk:
476 | chunk_id += 1
477 | # `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
478 | a_end = 1
479 |                         # if the current chunk has two or more segments, randomly choose how many of them go into the `A` (first) sentence
480 | if len(current_chunk) >= 2:
481 | a_end = random.randint(1, len(current_chunk) - 1)
482 | # token a
483 | tokens_a = []
484 | a_segment_ids = []
485 | for j in range(a_end):
486 | tokens_a.extend(current_chunk[j])
487 | a_segment_ids.extend([j] * len(current_chunk[j]))
488 |
489 | # token b
490 | tokens_b = []
491 | b_segment_ids = []
492 | for j in range(a_end, len(current_chunk)):
493 | tokens_b.extend(current_chunk[j])
494 | b_segment_ids.extend([j] * len(current_chunk[j]))
495 |
496 | if len(tokens_a) == 0 or len(tokens_b) == 0:
497 | continue
498 |
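499 |                         # Three-way sentence structure label: 0 = B follows A in the original
500 |                         # order, 1 = A and B are swapped, 2 = B is drawn from a different chunk.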
499 | rdn = random.random()
500 | if rdn < 1/3:
501 | is_next = 1
502 | tokens_a, tokens_b = tokens_b, tokens_a
503 | a_segment_ids, b_segment_ids = b_segment_ids, a_segment_ids
504 | elif rdn < 2/3 and len(total_chunk) > 1:
505 | is_next = 2
506 | while True:
507 | rid = random.randint(0, len(total_chunk)-1)
508 | if rid != chunk_id:
509 | break
510 | another_chunk = total_chunk[rid]
511 | tokens_b = sum(another_chunk, [])
512 | b_segment_ids = sum([[acid]*len(ac) for acid, ac in enumerate(another_chunk)], [])
513 | else:
514 | is_next = 0
515 |
516 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, a_segment_ids, b_segment_ids):
517 | """Truncates a pair of sequences to a maximum sequence length."""
518 | while True:
519 | total_length = len(tokens_a) + len(tokens_b)
520 | if total_length <= max_num_tokens:
521 | break
522 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
523 | trunc_segment_ids = a_segment_ids if len(a_segment_ids) > len(b_segment_ids) else b_segment_ids
524 | assert len(trunc_tokens) >= 1
525 | # We want to sometimes truncate from the front and sometimes from the
526 | # back to add more randomness and avoid biases.
527 | if random.random() < 0.5:
528 | del trunc_tokens[0]
529 | del trunc_segment_ids[0]
530 | else:
531 | trunc_tokens.pop()
532 | trunc_segment_ids.pop()
533 |
534 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, a_segment_ids, b_segment_ids)
535 | # add special tokens
536 | input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
537 | # add token type ids, 0 for sentence a, 1 for sentence b
538 | token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
539 | attention_mask = [1] * len(input_ids)
540 | assert len(tokens_a) >= 1
541 | assert len(tokens_b) >= 1
542 |
543 | results["input_ids"].append(input_ids)
544 | results["token_type_ids"].append(token_type_ids)
545 | results["attention_mask"].append(attention_mask)
546 | results["sentence_structural_label"].append(is_next)
547 | results["special_tokens_mask"].append([1] + [0] * len(tokens_a) + [1] + [0] * len(tokens_b) + [1])
548 |
549 | a_segment_ids = [asi-a_segment_ids[0]+1 for asi in a_segment_ids]
550 | b_segment_ids = [bsi-b_segment_ids[0]+a_segment_ids[-1]+2 for bsi in b_segment_ids]
551 | results["segment_ids"].append([0] + a_segment_ids + [a_segment_ids[-1]+1] + b_segment_ids + [b_segment_ids[-1]+1])
552 | current_chunk = [] # clear current chunk
553 | current_length = 0 # reset current text length
554 | i += 1 # go to next line
555 | return results
556 |
557 |     # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts builds its
558 |     # sentence pairs within each group of 1,000 texts. You can adjust that batch_size here but a higher value
559 |     # might be slower to preprocess.
560 | #
561 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
562 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
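563 |     # The validation split is grouped with dupe_factor temporarily set to 1 (a single pass over the
564 |     # data); the original dupe_factor is then restored before grouping the training split.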
563 | ori_dupe_factor = data_args.dupe_factor
564 | data_args.dupe_factor = 1
565 | tokenized_datasets["validation"] = tokenized_datasets["validation"].map(
566 | group_texts,
567 | batched=True,
568 | num_proc=data_args.preprocessing_num_workers,
569 | load_from_cache_file=not data_args.overwrite_cache,
570 | new_fingerprint="g5888790b41eabcb"
571 | )
572 | data_args.dupe_factor = ori_dupe_factor
573 | tokenized_datasets["train"] = tokenized_datasets["train"].map(
574 | group_texts,
575 | batched=True,
576 | num_proc=data_args.preprocessing_num_workers,
577 | load_from_cache_file=not data_args.overwrite_cache,
578 | new_fingerprint="ed9f003830c2481d"
579 | )
580 |
581 | if training_args.do_train:
582 | if "train" not in tokenized_datasets:
583 | raise ValueError("--do_train requires a train dataset")
584 | train_dataset = tokenized_datasets["train"]
585 | if data_args.max_train_samples is not None:
586 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
587 |
588 | if training_args.do_eval:
589 | if "validation" not in tokenized_datasets:
590 | raise ValueError("--do_eval requires a validation dataset")
591 | eval_dataset = tokenized_datasets["validation"]
592 | if data_args.max_eval_samples is not None:
593 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
594 |
595 | # Data collator
596 | # This one will take care of randomly masking the tokens.
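597 |     # With fp16 and line-by-line (dynamic) padding, pad to a multiple of 8 so sequence lengths stay
598 |     # friendly to tensor cores; otherwise the collator simply pads to the longest sequence in a batch.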
597 | pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
598 | data_collator = DataCollatorForLanguageModeling(
599 | tokenizer=tokenizer,
600 | mlm_probability=data_args.mlm_probability,
601 | pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
602 | )
603 |
604 | def compute_metrics(p: EvalPrediction):
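605 |         # Layout assumed by the metrics below: predictions = (mlm_loss, sop_loss, mlm_predictions,
606 |         # sop_predictions) and label_ids = (mlm_labels, sentence_structural_label), with -100
607 |         # marking MLM positions that should be ignored.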
605 | p.predictions[2][p.predictions[2]==-100] = -1
606 | out = {
607 | "mlm_loss": p.predictions[0].mean().item(),
608 | "sop_loss": p.predictions[1].mean().item(),
609 | "mlm_accuracy": ((p.predictions[2] == p.label_ids[0]).astype(np.float32).sum()/(p.label_ids[0] != -100).astype(np.float32).sum()).item(),
610 | "sop_accuracy": (p.predictions[3] == p.label_ids[1]).astype(np.float32).mean().item(),
611 | }
612 | return out
613 |
614 | # Initialize our Trainer
615 | trainer = Trainer(
616 | model=model,
617 | args=training_args,
618 | train_dataset=train_dataset if training_args.do_train else None,
619 | eval_dataset=eval_dataset if training_args.do_eval else None,
620 | compute_metrics=compute_metrics,
621 | tokenizer=tokenizer,
622 | data_collator=data_collator,
623 | )
624 |
625 | # Training
626 | if training_args.do_train:
627 | checkpoint = None
628 | if training_args.resume_from_checkpoint is not None:
629 | checkpoint = training_args.resume_from_checkpoint
630 | elif last_checkpoint is not None:
631 | checkpoint = last_checkpoint
632 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
633 | trainer.save_model() # Saves the tokenizer too for easy upload
634 | metrics = train_result.metrics
635 |
636 | max_train_samples = (
637 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
638 | )
639 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
640 |
641 | trainer.log_metrics("train", metrics)
642 | trainer.save_metrics("train", metrics)
643 | trainer.save_state()
644 |
645 | # Evaluation
646 | if training_args.do_eval:
647 | logger.info("*** Evaluate ***")
648 |
649 | metrics = trainer.evaluate()
650 |
651 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
652 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
653 | perplexity = math.exp(metrics["eval_loss"])
654 | metrics["perplexity"] = perplexity
655 |
656 | trainer.log_metrics("eval", metrics)
657 | trainer.save_metrics("eval", metrics)
658 |
659 | if training_args.push_to_hub:
660 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tags": "fill-mask"}
661 | if data_args.dataset_name is not None:
662 | kwargs["dataset_tags"] = data_args.dataset_name
663 | if data_args.dataset_config_name is not None:
664 | kwargs["dataset_args"] = data_args.dataset_config_name
665 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
666 | else:
667 | kwargs["dataset"] = data_args.dataset_name
668 |
669 | trainer.push_to_hub(**kwargs)
670 |
671 |
672 | def _mp_fn(index):
673 | # For xla_spawn (TPUs)
674 | main()
675 |
676 |
677 | if __name__ == "__main__":
678 | main()
679 |
--------------------------------------------------------------------------------
/run_shell/1-pretrain_bookcorpus_wikipedia.sh:
--------------------------------------------------------------------------------
1 | SUFFIX=PoNet_bookcorpus_wikipedia_dupe5
2 | LOGNAME=`date +%Y%m%d%H`_${SUFFIX}.log
3 | OUTPUT=outputs/${SUFFIX}
4 | MODEL_PATH=chtan/ponet-base-uncased
5 |
6 | if [ ! -d logs ]; then
7 | mkdir logs
8 | fi
9 |
10 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port 31030 run_pretrained.py \
11 | --config_name ${MODEL_PATH} \
12 | --tokenizer_name ${MODEL_PATH} \
13 | --dataset_name bookcorpus \
14 | --dataset2_name wikipedia \
15 | --dataset2_config_name 20200501.en \
16 | --label_names labels next_sentence_label \
17 | --save_total_limit 5 \
18 | --dupe_factor 5 \
19 | --num_train_epochs 500 \
20 | --warmup_steps 5000 \
21 | --learning_rate 1e-4 \
22 | --evaluation_strategy steps \
23 | --save_steps 5000 \
24 | --eval_steps 5000 \
25 | --logging_dir ${OUTPUT} \
26 | --report_to tensorboard \
27 | --do_train \
28 | --do_eval \
29 | --ignore_data_skip \
30 | --per_device_train_batch_size 48 \
31 | --gradient_accumulation_steps 1 \
32 | --fp16 \
33 | --sharded_ddp simple \
34 | --output_dir ${OUTPUT} > logs/${LOGNAME} 2>&1 &
--------------------------------------------------------------------------------
/run_shell/1-pretrain_bookcorpus_wikitext.sh:
--------------------------------------------------------------------------------
1 | SUFFIX=PoNet_bookcorpus_wikitext_dupe5
2 | LOGNAME=`date +%Y%m%d%H`_${SUFFIX}.log
3 | OUTPUT=outputs/${SUFFIX}
4 | MODEL_PATH=chtan/ponet-base-uncased
5 |
6 | if [ ! -d logs ]; then
7 | mkdir logs
8 | fi
9 |
10 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port 31029 run_pretrained.py \
11 | --config_name ${MODEL_PATH} \
12 | --tokenizer_name ${MODEL_PATH} \
13 | --dataset_name bookcorpus \
14 | --dataset2_name wikitext \
15 | --dataset2_config_name wikitext-103-raw-v1 \
16 | --label_names labels next_sentence_label \
17 | --save_total_limit 5 \
18 | --dupe_factor 5 \
19 | --num_train_epochs 500 \
20 | --warmup_steps 5000 \
21 | --learning_rate 1e-4 \
22 | --evaluation_strategy steps \
23 | --save_steps 5000 \
24 | --eval_steps 5000 \
25 | --logging_dir ${OUTPUT} \
26 | --report_to tensorboard \
27 | --do_train \
28 | --do_eval \
29 | --ignore_data_skip \
30 | --per_device_train_batch_size 48 \
31 | --gradient_accumulation_steps 1 \
32 | --fp16 \
33 | --sharded_ddp simple \
34 | --output_dir ${OUTPUT} > logs/${LOGNAME} 2>&1 &
--------------------------------------------------------------------------------
/run_shell/2-GLUE.sh:
--------------------------------------------------------------------------------
1 | num_train_epochs=4
2 | MODEL=chtan/ponet-base-uncased
3 | OUTPRE=`pwd`
4 |
5 | if [ ! -d logs/glue ]; then
6 | mkdir -p logs/glue
7 | fi
8 |
9 | cal_mlp(){
10 | NAME=`date +%Y%m%d%H`_${TASK_NAME}_ep${num_train_epochs}_bz$((bz*GAS))_lr${lr}
11 | OUTPUT=outputs/glue/${NAME}
12 |
13 | CUDA_VISIBLE_DEVICES=${GPUID} python -u run_glue.py \
14 | --model_name_or_path ${MODEL} \
15 | --overwrite_output_dir \
16 | --task_name $TASK_NAME \
17 | --do_train \
18 | --do_eval \
19 | --max_seq_length 128 \
21 | --report_to tensorboard \
22 | --per_device_train_batch_size ${bz} \
23 | --learning_rate ${lr} \
24 | --num_train_epochs ${num_train_epochs} \
25 | --save_total_limit 1 \
26 | --evaluation_strategy steps \
27 | --save_steps 5000 \
28 | --save_strategy no \
29 | --fp16 \
30 | --logging_dir ${OUTPUT} \
31 | --output_dir ${OUTPUT} > logs/glue/${NAME}.log 2>&1
32 | }
33 |
34 |
35 | search(){
36 | GAS=1
37 | bz=128
38 | lr=3e-4; cal_mlp
39 | lr=1e-4; cal_mlp
40 | lr=5e-5; cal_mlp
41 | lr=3e-5; cal_mlp
42 |
43 | GAS=1
44 | bz=64
45 | lr=3e-4; cal_mlp
46 | lr=1e-4; cal_mlp
47 | lr=5e-5; cal_mlp
48 | lr=3e-5; cal_mlp
49 |
50 | GAS=1
51 | bz=32
52 | lr=3e-4; cal_mlp
53 | lr=1e-4; cal_mlp
54 | lr=5e-5; cal_mlp
55 | lr=3e-5; cal_mlp
56 |
57 | GAS=1
58 | bz=16
59 | lr=3e-4; cal_mlp
60 | lr=1e-4; cal_mlp
61 | lr=5e-5; cal_mlp
62 | lr=3e-5; cal_mlp
63 |
64 | GAS=1
65 | bz=8
66 | lr=3e-4; cal_mlp
67 | lr=1e-4; cal_mlp
68 | lr=5e-5; cal_mlp
69 | lr=3e-5; cal_mlp
70 | }
71 |
72 | GPUID=0; TASK_NAME=cola; search &
73 | # GPUID=1; TASK_NAME=stsb; search &
74 | # GPUID=2; TASK_NAME=mrpc; search &
75 | # GPUID=3; TASK_NAME=rte; search &
76 | # GPUID=4; TASK_NAME=sst2; search &
77 | # GPUID=5; TASK_NAME=qqp; search &
78 | # GPUID=6; TASK_NAME=qnli; search &
79 | # GPUID=7; TASK_NAME=mnli; search &
80 |
--------------------------------------------------------------------------------
/run_shell/3-LongTask.sh:
--------------------------------------------------------------------------------
1 | MODEL=chtan/ponet-base-uncased
2 | OUTPRE=`pwd`
3 |
4 | cal(){
5 | NAME=`date +%Y%m%d%H`_${MAINTASK}_ep${num_train_epochs}_bz$((bz*GAS))_lr${lr}
6 | OUTPUT=outputs/${MAINTASK}/${NAME}
7 |
8 | CUDA_VISIBLE_DEVICES=${GPUID} torchrun --nproc_per_node=${GPU_NUMS} run_long_classification.py \
9 | --model_name_or_path ${MODEL} \
10 | --task_name ${MAINTASK} \
11 | --overwrite_output_dir \
12 | --do_train \
13 | --do_eval \
14 | --do_predict \
15 | --max_seq_length 4096 \
16 | --gradient_accumulation_steps ${GAS} \
17 | --per_device_train_batch_size ${bz} \
18 | --per_device_eval_batch_size ${bz} \
19 | --learning_rate ${lr} \
20 | --num_train_epochs ${num_train_epochs} \
21 | --save_total_limit 1 \
22 | --evaluation_strategy steps \
23 | --logging_steps 500 \
24 | --eval_steps 500 \
25 | --save_steps 5000 \
26 | --load_best_model_at_end \
27 | --metric_for_best_model f1 \
28 | --save_strategy no \
29 | --report_to tensorboard \
30 | --fp16 \
31 | --logging_dir ${OUTPUT} \
32 | --output_dir ${OUTPUT} > logs/${MAINTASK}/${NAME}.log 2>&1
33 | }
34 |
35 | search(){
36 | num_train_epochs=${NEP}
37 | if [ ! -d logs/${MAINTASK} ]; then
38 | mkdir -p logs/${MAINTASK}
39 | fi
40 | lr=3e-5; cal
41 | lr=5e-5; cal
42 | }
43 |
44 | GPUID=0,1
45 | GPU_NUMS=2
46 | GAS=1
47 | bz=16
48 | ### ------- Arxiv-11 --------
49 | MAINTASK=arxiv; NEP=10; search;
50 | ### ------- IMDb --------
51 | # MAINTASK=imdb; NEP=10; search;
52 | ### ------- HND --------
53 | # MAINTASK=hnd; NEP=10; search;
54 | ### ------- Yelp-5 --------
55 | # MAINTASK=yelp; NEP=2; search;
56 |
--------------------------------------------------------------------------------
/run_shell/D1-arxiv11.sh:
--------------------------------------------------------------------------------
1 | WORKDIR=`pwd`
2 | DATADIR=${WORKDIR}/datasets/arxiv-11
3 | TMPDIR=`mktemp -d`
4 | git clone https://github.com/LiqunW/Long-document-dataset ${TMPDIR}
5 | cd ${TMPDIR}
6 | RARDIR=rar_dir
7 | mkdir ${RARDIR}
8 | mv cs.*.rar math.*.rar ${RARDIR}/
9 |
10 | unrar_dir(){
11 | src_path=`readlink -f $1`
12 | dst_path=`readlink -f $2`
13 | rar_files=`find $src_path -name '*.rar'`
14 | IFS=$'\n'; array=$rar_files; unset IFS
15 | for rar_file in $array; do
16 | file_path=`echo $rar_file | sed -e "s;$src_path;$dst_path;"`
17 | ext_path=${file_path%/*}
18 | if [ ! -d $ext_path ]; then
19 | mkdir -p $ext_path
20 | fi
21 | unrar x $rar_file $ext_path
22 | done
23 | }
24 |
25 | mkdir data
26 | unrar_dir "${RARDIR}" "data"
27 |
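28 | # Appending an instantiation to the upstream dataloader.py means running it builds
29 | # Dataset.txt and Labels_file.txt from the extracted data/ directory.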
28 | if [ ! -f dataloader.py_ori ]; then
29 | cp dataloader.py dataloader.py_ori
30 | echo "Dataloader('data', 32)" >> dataloader.py
31 | fi
32 | python dataloader.py
33 |
34 | mv Dataset.txt Labels_file.txt data/
35 | mv data ${DATADIR}
36 |
37 | # if [ -d ${TMPDIR} ]; then
38 | # rm -r ${TMPDIR}
39 | # fi
--------------------------------------------------------------------------------