├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── example_inputs
│   ├── harry_potter.txt
│   └── harry_potter_full.txt
├── requirements.txt
└── src
    ├── __init__.py
    ├── configs
    │   ├── data
    │   │   ├── contract_nli.json
    │   │   ├── gov_report.json
    │   │   ├── hotpotqa.json
    │   │   ├── hotpotqa_second_only.json
    │   │   ├── narative_qa.json
    │   │   ├── qasper.json
    │   │   ├── qmsum.json
    │   │   ├── quality.json
    │   │   ├── squad.json
    │   │   ├── squad_ordered_distractors.json
    │   │   ├── squad_shuffled_distractors.json
    │   │   └── summ_screen_fd.json
    │   ├── model
    │   │   ├── bart_base_sled.json
    │   │   ├── bart_large_sled.json
    │   │   └── primera_govreport_sled.json
    │   └── training
    │       └── base_training_args.json
    ├── index_building.py
    ├── inference-example.py
    ├── metrics
    │   ├── __init__.py
    │   └── metrics.py
    ├── random_training_unlimiformer.py
    ├── run.py
    ├── run_generation.py
    ├── unlimiformer.py
    ├── usage.py
    └── utils
        ├── __init__.py
        ├── config.py
        ├── custom_hf_argument_parser.py
        ├── custom_seq2seq_trainer.py
        ├── decoding.py
        ├── duplicates.py
        └── override_training_args.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .vscode/
131 | wandb/
132 | output*
133 | local*
134 | *.pdf
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | @article{bertsch2023unlimiformer,
2 | title={Unlimiformer: Long-Range Transformers with Unlimited Length Input},
3 | author={Bertsch, Amanda and Alon, Uri and Neubig, Graham and Gormley, Matthew R},
4 | journal={arXiv preprint arXiv:2305.01625},
5 | year={2023}
6 | }
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Amanda Bertsch, Uri Alon, Graham Neubig, and Matthew R. Gormley
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unlimiformer: Long-Range Transformers with Unlimited Length Input (NeurIPS 2023)
2 | 
3 |
4 | This is the official implementation of the paper:
5 |
6 | [Amanda Bertsch](https://www.cs.cmu.edu/~abertsch/), [Uri Alon](https://urialon.ml/), [Graham Neubig](http://www.phontron.com/), and [Matthew R. Gormley](http://www.cs.cmu.edu/~mgormley/):
7 | [Unlimiformer: Long-Range Transformers with Unlimited Length Input](https://arxiv.org/pdf/2305.01625) (to appear in **NeurIPS 2023**)
8 |
9 | Unlimiformer is a method for augmenting pretrained encoder-decoder models with retrieval-based attention, without changing the mathematical definition of attention.
10 | This allows the use of unlimited length inputs with any pretrained encoder-decoder!
11 | See also our [**Tweet**](https://twitter.com/abertsch72/status/1654110919977324545?s=20).
12 |
13 | Unlimiformer can be used to improve the performance of an already-trained model. For best results, the model can be trained with Unlimiformer training.
14 |
15 | If you have any questions on this work, please open a [GitHub issue](https://github.com/abertsch72/unlimiformer/issues) or email the authors at ```abertsch@cs.cmu.edu, ualon@cs.cmu.edu```
16 |
17 | ## **_October 2023_** - Unlimiformer will appear at NeurIPS 2023!
18 |
19 | ## **_August 2023_** - Unlimiformer now supports **Llama-2** (and all its derivatives)!
20 | To prompt Llama-2 with extremely long inputs, for example, the content of an *entire book*, use:
21 | ```bash
22 | python src/run_generation.py --model_type llama --model_name_or_path meta-llama/Llama-2-13b-chat-hf \
23 | --prefix "[INST] <<SYS>>\n You are a helpful assistant. Answer with detailed responses according to the entire instruction or question. \n<</SYS>>\n\n Summarize the following book: " \
24 | --prompt example_inputs/harry_potter_full.txt \
25 | --suffix " [/INST]" --test_unlimiformer --fp16 --length 200 --layer_begin 16 \
26 | --index_devices 1 --datastore_device 1
27 | ```
28 | * The final prompt will be the concatenation of the contents of the `--prefix`, `--prompt`, and `--suffix` flags.
29 | * The flag `--prompt` may contain either a path to a text file (e.g., `example_inputs/harry_potter_full.txt`) or the concrete prompt string.
30 | * The flag `--test_unlimiformer` is required to enable Unlimiformer.
31 | * The flag `--length` determines the desired output length.
32 | * The flag `--layer_begin` determines the layer from which Unlimiformer will start to be applied. For example, if we set `--layer_begin 20`, the first 20 layers of the model will perform the standard attention over the last `context_window_size` tokens of the prompt as usual, and the 21st layer and above will attend to the _entire long input_. From our initial experiments, the value of `--layer_begin` should be more than half of the total number of layers in the model, and tuning it dramatically changes the quality of the output (a starting-point sketch follows this list).
33 | * The flags: `--datastore_device N` and `--index_devices N1 N2 N3 ...` specify on which GPUs to store Unlimiformer's datastore and index (the base model will be stored on GPU #0).
34 | * Add the flag `--stream_output` to make the generated tokens appear one by one as they are generated.
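
Following the `--layer_begin` rule of thumb above, here is a tiny illustrative sketch for picking a starting value (the layer count below is an assumption used only to make the arithmetic concrete; tune from there):
```python
# Illustrative only: pick an initial --layer_begin just above half of the
# model's total layer count, then tune it for output quality.
num_layers = 40                    # e.g., Llama-2-13b has 40 decoder layers
layer_begin = num_layers // 2 + 1  # one possible starting point; the example command above uses 16
print(f"--layer_begin {layer_begin}")
```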
35 |
36 |
37 | ## Getting Started
38 |
39 | ### General Instructions
40 | Copy the files from `src` into your source code folder.
41 |
42 | You'll need to set values for the Unlimiformer-specific arguments outlined in [`usage.py`](https://github.com/abertsch72/unlimiformer/blob/main/src/usage.py) - you can add these arguments wherever you usually process hyperparameters. To use the model, you must set `test_unlimiformer=True`. For datastore usage, the model must be in evaluation mode (e.g., call ```model.eval()``` before inference).
43 |
44 | [`inference-example.py`](https://github.com/abertsch72/unlimiformer/blob/main/src/inference-example.py) outlines a minimal example for running a sequence through an Unlimiformer model, using the default arguments.
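
A condensed sketch of that flow (assuming the files from `src` are importable from your working directory, and mirroring the argument names used in `inference-example.py`):
```python
import torch
from transformers import AutoTokenizer, BartForConditionalGeneration

# these imports come from the files copied out of `src`
from unlimiformer import Unlimiformer
from usage import UnlimiformerArguments

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("abertsch/unlimiformer-bart-govreport-alternating")

# collect the Unlimiformer-specific arguments (defaults here; override as needed)
defaults = UnlimiformerArguments()
unlimiformer_kwargs = {
    'layer_begin': defaults.layer_begin,
    'layer_end': defaults.layer_end,
    'unlimiformer_head_num': defaults.unlimiformer_head_num,
    'exclude_attention': defaults.unlimiformer_exclude,
    'chunk_overlap': defaults.unlimiformer_chunk_overlap,
    'model_encoder_max_len': defaults.unlimiformer_chunk_size,
    'verbose': defaults.unlimiformer_verbose,
    'tokenizer': tokenizer,
    'unlimiformer_training': defaults.unlimiformer_training,
    'use_datastore': defaults.use_datastore,
    'flat_index': defaults.flat_index,
    'test_datastore': defaults.test_datastore,
    'reconstruct_embeddings': defaults.reconstruct_embeddings,
    'gpu_datastore': defaults.gpu_datastore,
    'gpu_index': defaults.gpu_index,
}

# convert the model in place, switch to evaluation mode, and generate on an
# arbitrarily long (untruncated) input
model = Unlimiformer.convert_model(model, **unlimiformer_kwargs)
model.eval()
model.to(device)

example = tokenizer("<your very long input here>", truncation=False, return_tensors="pt").to(device)
output = tokenizer.batch_decode(model.generate(**example, max_length=512), skip_special_tokens=True)[0]
print(output)
```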
45 |
46 | [`run.py`](https://github.com/abertsch72/unlimiformer/blob/main/src/run.py) is an example of a full training setup that integrates Unlimiformer, adapted from [SLED](https://github.com/Mivg/SLED). See full command lines below.
47 |
48 | ### Reproducing the Experiments from the Paper - Command Lines
49 |
50 | To run a standard finetuning + evaluation of BART-base on the GovReport dataset (as an example), use:
51 | ```bash
52 | python src/run.py \
53 | src/configs/training/base_training_args.json \
54 | src/configs/data/gov_report.json \
55 | --output_dir output_train_bart_base_local/ \
56 | --learning_rate 1e-5 \
57 | --model_name_or_path facebook/bart-base \
58 | --max_source_length 1024 \
59 | --eval_max_source_length 1024 --do_eval=True \
60 | --eval_steps 1000 --save_steps 1000 \
61 | --per_device_eval_batch_size 1 --per_device_train_batch_size 2 \
62 | --extra_metrics bertscore
63 | ```
64 |
65 | * To use Unlimiformer at **training** time (called "Retrieval training" in the paper), use: `--unlimiformer_training --max_source_length 16384`
66 | * In this case, you might want to use Unlimiformer also at **test**/validation time; to do so, also use: `--test_unlimiformer --eval_max_source_length 999999`
67 | * Alternatively, to use the computationally cheaper "Random-encoded training" at **training** time, use `--random_unlimiformer_training --max_source_length 16384`
68 | * To alternate between "retrieval training" and "random-encoded training", use both flags: `--unlimiformer_training --random_unlimiformer_training --max_source_length 16384`
69 |
70 | For additional flags and options, see [`usage.py`](https://github.com/abertsch72/unlimiformer/blob/main/src/usage.py)
71 |
72 |
73 | ## Recommended settings
74 |
75 | ### To evaluate with Unlimiformer
76 | At evaluation time, we recommend the default value for each setting.
77 |
78 | ### To train with Unlimiformer
79 | For an inexpensive method, we recommend training as usual and using Unlimiformer during early stopping. To do so, set ```knn=True``` and leave all other values at default.
80 |
81 |
82 | For best performance, there are 3 expensive settings for training. The best one varies by dataset.
83 | 1. Set ```random_unlimiformer_training=True```: this is the *random-encoded training* setting from the paper
84 | 2. Set ```unlimiformer_training=True```: this is the *retrieval training* setting from the paper
85 | 3. Set ```random_unlimiformer_training=True``` AND ```unlimiformer_training=True```: this is the *alternating training* setting from the paper
86 |
87 | See Table 5 in the paper for a more detailed breakdown of relative training costs.
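
As a rough flag-level summary of the three settings above (flag names taken from the command lines earlier in this README; this is an illustrative sketch, not a drop-in config):
```python
# Illustrative only: the three expensive training settings above, expressed as
# the run.py flags they correspond to.
RETRIEVAL = ["--unlimiformer_training", "--max_source_length", "16384"]
RANDOM_ENCODED = ["--random_unlimiformer_training", "--max_source_length", "16384"]
ALTERNATING = ["--unlimiformer_training", "--random_unlimiformer_training", "--max_source_length", "16384"]
```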
88 |
89 | ## Tips for very large inputs
90 | ### For training
91 | * you may need to truncate your inputs at training time, e.g. to 8k or 16k tokens. You can use the full inputs at evaluation time
92 | * you can also try splitting your inputs into 16k-token-chunks and training on each one as its own example
93 | ### For evaluation (including early stopping)
94 | * if you're consistently running out of CUDA memory, set ```use_datastore=True``` to use a Faiss datastore to store hidden states.
95 | * if you're still having issues, set ```gpu_datastore=False``` or ```gpu_index=False```, but note that this will degrade performance
96 |
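As a sketch of how those memory-saving settings look when passed programmatically (argument names as in `inference-example.py`; adjust to your own argument handling):
```python
# Sketch: memory-saving evaluation settings, merged into the same kwargs
# dictionary shown in the Getting Started sketch above before converting the model.
memory_saving_kwargs = {
    'use_datastore': True,    # store encoder hidden states in a Faiss datastore
    'gpu_datastore': False,   # keep the datastore on CPU if CUDA memory is tight
    'gpu_index': False,       # keep the Faiss index on CPU (slower, but lighter on GPU memory)
}
unlimiformer_kwargs.update(memory_saving_kwargs)
model = Unlimiformer.convert_model(model, **unlimiformer_kwargs)
```
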
97 | ## Trained models
98 | The following models from the paper are available on Hugging Face. Please note that you must add the Unlimiformer-specific files to your repository, and load these models with ```test_unlimiformer=True```. *If you download these models from Hugging Face, they may not use Unlimiformer by default!*
99 |
100 | ### Table 3: low-cost training methods
101 | | Dataset | Method | Hugging Face link |
102 | | ------------- | ------------- | ------------- |
103 | | GovReport | Baseline: BART-base | [abertsch/bart-base-govreport](https://huggingface.co/abertsch/bart-base-govreport) |
104 | | GovReport | BART-base + Unlimiformer early stopping | [abertsch/unlimiformer-bart-govreport-earlyk](https://huggingface.co/abertsch/unlimiformer-bart-govreport-earlyk) |
105 | | SummScreen | Baseline: BART-base | [abertsch/bart-base-summscreen](https://huggingface.co/abertsch/bart-base-summscreen) |
106 | | SummScreen | BART-base + Unlimiformer early stopping | [abertsch/unlimiformer-bart-summscreen-earlyk](https://huggingface.co/abertsch/unlimiformer-bart-summscreen-earlyk) |
107 |
108 |
109 | ### Table 4: Long-range training methods
110 | | Dataset | Method | Hugging Face link |
111 | | ------------- | ------------- | ------------- |
112 | | GovReport | BART + Unlimiformer (alternating training) | [abertsch/unlimiformer-bart-govreport-alternating](https://huggingface.co/abertsch/unlimiformer-bart-govreport-alternating) |
113 | | SummScreen | BART + Unlimiformer (retrieval training) | [abertsch/unlimiformer-bart-summscreen-retrieval](https://huggingface.co/abertsch/unlimiformer-bart-summscreen-retrieval) |
114 |
115 | ### Table 5: BookSum
116 | | Dataset | Method | Hugging Face link |
117 | | ------------- | ------------- | ------------- |
118 | | BookSum | Baseline: BART-base | [abertsch/bart-base-booksum](https://huggingface.co/abertsch/bart-base-booksum) |
119 | | BookSum | BART-base + Unlimiformer early stopping | [abertsch/unlimiformer-bart-booksum-earlyk](https://huggingface.co/abertsch/unlimiformer-bart-booksum-earlyk) |
120 | | BookSum | BART-base + Unlimiformer (random-encoding training) | [abertsch/unlimiformer-bart-booksum-random-encoding](https://huggingface.co/abertsch/unlimiformer-bart-booksum-random-encoding) |
121 | | BookSum | BART-base + Unlimiformer (alternating training) | [abertsch/unlimiformer-bart-booksum-alternating](https://huggingface.co/abertsch/unlimiformer-bart-booksum-alternating) |
122 |
123 | ## Results
124 |
125 |
126 |
127 |
128 |
129 |
130 | ## Citation
131 | If you use our method or models, please cite [our paper](https://arxiv.org/abs/2305.01625):
132 | ```
133 | @article{bertsch2023unlimiformer,
134 | title={Unlimiformer: Long-Range Transformers with Unlimited Length Input},
135 | author={Bertsch, Amanda and Alon, Uri and Neubig, Graham and Gormley, Matthew R},
136 | journal={arXiv preprint arXiv:2305.01625},
137 | year={2023}
138 | }
139 | ```
140 |
141 |
142 |
143 |
--------------------------------------------------------------------------------
/example_inputs/harry_potter.txt:
--------------------------------------------------------------------------------
1 | Harry Potter and the Sorcerer's Stone
2 |
3 |
4 | CHAPTER ONE
5 |
6 | THE BOY WHO LIVED
7 |
8 | Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say
9 | that they were perfectly normal, thank you very much. They were the last
10 | people you'd expect to be involved in anything strange or mysterious,
11 | because they just didn't hold with such nonsense.
12 |
13 | Mr. Dursley was the director of a firm called Grunnings, which made
14 | drills. He was a big, beefy man with hardly any neck, although he did
15 | have a very large mustache. Mrs. Dursley was thin and blonde and had
16 | nearly twice the usual amount of neck, which came in very useful as she
17 | spent so much of her time craning over garden fences, spying on the
18 | neighbors. The Dursleys had a small son called Dudley and in their
19 | opinion there was no finer boy anywhere.
20 |
21 | The Dursleys had everything they wanted, but they also had a secret, and
22 | their greatest fear was that somebody would discover it. They didn't
23 | think they could bear it if anyone found out about the Potters. Mrs.
24 | Potter was Mrs. Dursley's sister, but they hadn't met for several years;
25 | in fact, Mrs. Dursley pretended she didn't have a sister, because her
26 | sister and her good-for-nothing husband were as unDursleyish as it was
27 | possible to be. The Dursleys shuddered to think what the neighbors would
28 | say if the Potters arrived in the street. The Dursleys knew that the
29 | Potters had a small son, too, but they had never even seen him. This boy
30 | was another good reason for keeping the Potters away; they didn't want
31 | Dudley mixing with a child like that.
32 |
33 | When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story
34 | starts, there was nothing about the cloudy sky outside to suggest that
35 | strange and mysterious things would soon be happening all over the
36 | country. Mr. Dursley hummed as he picked out his most boring tie for
37 | work, and Mrs. Dursley gossiped away happily as she wrestled a screaming
38 | Dudley into his high chair.
39 |
40 | None of them noticed a large, tawny owl flutter past the window.
41 |
42 | At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs.
43 | Dursley on the cheek, and tried to kiss Dudley good-bye but missed,
44 | because Dudley was now having a tantrum and throwing his cereal at the
45 | walls. "Little tyke," chortled Mr. Dursley as he left the house. He got
46 | into his car and backed out of number four's drive.
47 |
48 | It was on the corner of the street that he noticed the first sign of
49 | something peculiar -- a cat reading a map. For a second, Mr. Dursley
50 | didn't realize what he had seen -- then he jerked his head around to
51 | look again. There was a tabby cat standing on the corner of Privet
52 | Drive, but there wasn't a map in sight. What could he have been thinking
53 | of? It must have been a trick of the light. Mr. Dursley blinked and
54 | stared at the cat. It stared back. As Mr. Dursley drove around the
55 | corner and up the road, he watched the cat in his mirror. It was now
56 | reading the sign that said Privet Drive -- no, looking at the sign; cats
57 | couldn't read maps or signs. Mr. Dursley gave himself a little shake and
58 | put the cat out of his mind. As he drove toward town he thought of
59 | nothing except a large order of drills he was hoping to get that day.
60 |
61 | But on the edge of town, drills were driven out of his mind by something
62 | else. As he sat in the usual morning traffic jam, he couldn't help
63 | noticing that there seemed to be a lot of strangely dressed people
64 | about. People in cloaks. Mr. Dursley couldn't bear people who dressed in
65 | funny clothes -- the getups you saw on young people! He supposed this
66 | was some stupid new fashion. He drummed his fingers on the steering
67 | wheel and his eyes fell on a huddle of these weirdos standing quite
68 | close by. They were whispering excitedly together. Mr. Dursley was
69 | enraged to see that a couple of them weren't young at all; why, that man
70 | had to be older than he was, and wearing an emerald-green cloak! The
71 | nerve of him! But then it struck Mr. Dursley that this was probably some
72 | silly stunt -- these people were obviously collecting for something...
73 | yes, that would be it. The traffic moved on and a few minutes later, Mr.
74 | Dursley arrived in the Grunnings parking lot, his mind back on drills.
75 |
76 | Mr. Dursley always sat with his back to the window in his office on the
77 | ninth floor. If he hadn't, he might have found it harder to concentrate
78 | on drills that morning. He didn't see the owls swooping past in broad
79 | daylight, though people down in the street did; they pointed and gazed
80 | open-mouthed as owl after owl sped overhead. Most of them had never
81 | seen an owl even at nighttime. Mr. Dursley, however, had a perfectly
82 | normal, owl-free morning. He yelled at five different people. He made
83 | several important telephone calls and shouted a bit more. He was in a
84 | very good mood until lunchtime, when he thought he'd stretch his legs
85 | and walk across the road to buy himself a bun from the bakery.
86 |
87 | He'd forgotten all about the people in cloaks until he passed a group of
88 | them next to the baker's. He eyed them angrily as he passed. He didn't
89 | know why, but they made him uneasy. This bunch were whispering
90 | excitedly, too, and he couldn't see a single collecting tin. It was on
91 | his way back past them, clutching a large doughnut in a bag, that he
92 | caught a few words of what they were saying.
93 |
94 | "The Potters, that's right, that's what I heard yes, their son, Harry"
95 |
96 | Mr. Dursley stopped dead. Fear flooded him. He looked back at the
97 | whisperers as if he wanted to say something to them, but thought better
98 | of it.
99 |
100 | He dashed back across the road, hurried up to his office, snapped at his
101 | secretary not to disturb him, seized his telephone, and had almost
102 | finished dialing his home number when he changed his mind. He put the
103 | receiver back down and stroked his mustache, thinking... no, he was
104 | being stupid. Potter wasn't such an unusual name. He was sure there were
105 | lots of people called Potter who had a son called Harry. Come to think
106 | of it, he wasn't even sure his nephew was called Harry. He'd never even
107 | seen the boy. It might have been Harvey. Or Harold. There was no point
108 | in worrying Mrs. Dursley; she always got so upset at any mention of her
109 | sister. He didn't blame her -- if he'd had a sister like that... but all
110 | the same, those people in cloaks...
111 |
112 | He found it a lot harder to concentrate on drills that afternoon and
113 | when he left the building at five o'clock, he was still so worried that
114 | he walked straight into someone just outside the door.
115 |
116 | "Sorry," he grunted, as the tiny old man stumbled and almost fell. It
117 | was a few seconds before Mr. Dursley realized that the man was wearing a
118 | violet cloak. He didn't seem at all upset at being almost knocked to the
119 | ground. On the contrary, his face split into a wide smile and he said in
120 | a squeaky voice that made passersby stare, "Don't be sorry, my dear sir,
121 | for nothing could upset me today! Rejoice, for You-Know-Who has gone at
122 | last! Even Muggles like yourself should be celebrating, this happy,
123 | happy day!"
124 |
125 | And the old man hugged Mr. Dursley around the
126 | middle and walked off.
127 |
128 | Mr. Dursley stood rooted to the spot. He had been
129 | hugged by a complete stranger. He also thought he
130 | had been called a Muggle, whatever that was. He was
131 | rattled. He hurried to his car and set off for home,
132 | hoping he was imagining things, which he had never
133 | hoped before, because he didn’t approve of
134 | imagination.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | sentencepiece
2 | protobuf<=3.20.1
3 | nltk
4 | datasets>=1.17.0
5 | absl-py
6 | rouge-score
7 | pandas
8 | transformers>=4.27.0
9 | wandb
10 | makefun>=1.14.0
11 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abertsch72/unlimiformer/e38b0149488636da9528c7504d2befdfb76f6d98/src/__init__.py
--------------------------------------------------------------------------------
/src/configs/data/contract_nli.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "contract_nli",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "pad_prefix": true,
7 | "generation_max_length": 8,
8 | "num_train_epochs": 20,
9 | "metric_names": ["exact_match"],
10 | "metric_for_best_model": "exact_match",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/gov_report.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "gov_report",
4 | "max_source_length": 16384,
5 | "generation_max_length": 1024,
6 | "max_prefix_length": 0,
7 | "pad_prefix": false,
8 | "num_train_epochs": 10,
9 | "metric_names": ["rouge"],
10 | "metric_for_best_model": "rouge/geometric_mean",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/hotpotqa.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "hotpotqa",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "pad_prefix": true,
7 | "generation_max_length": 128,
8 | "num_train_epochs": 9,
9 | "metric_names": ["f1", "exact_match"],
10 | "metric_for_best_model": "f1",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/hotpotqa_second_only.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "hotpotqa_second_only",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "pad_prefix": true,
7 | "generation_max_length": 128,
8 | "num_train_epochs": 9,
9 | "metric_names": ["f1", "exact_match"],
10 | "metric_for_best_model": "f1",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/narative_qa.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "narrative_qa",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "pad_prefix": true,
7 | "num_train_epochs": 2,
8 | "generation_max_length": 128,
9 | "metric_names": ["f1"],
10 | "metric_for_best_model": "f1",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/qasper.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "qasper",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "pad_prefix": true,
7 | "generation_max_length": 128,
8 | "num_train_epochs": 20,
9 | "metric_names": ["f1"],
10 | "metric_for_best_model": "f1",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/qmsum.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "qmsum",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "pad_prefix": true,
7 | "num_train_epochs": 20,
8 | "generation_max_length": 1024,
9 | "metric_names": ["rouge"],
10 | "metric_for_best_model": "rouge/geometric_mean",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/quality.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "quality",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 160,
6 | "pad_prefix": true,
7 | "num_train_epochs": 20,
8 | "generation_max_length": 128,
9 | "metric_names": ["exact_match"],
10 | "metric_for_best_model": "exact_match",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/squad.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "squad",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "pad_prefix": true,
7 | "num_train_epochs": 3,
8 | "generation_max_length": 128,
9 | "metric_names": ["f1", "exact_match"],
10 | "metric_for_best_model": "f1",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/squad_ordered_distractors.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "squad_ordered_distractors",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "pad_prefix": true,
7 | "num_train_epochs": 3,
8 | "generation_max_length": 128,
9 | "metric_names": ["f1", "exact_match"],
10 | "metric_for_best_model": "f1",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/squad_shuffled_distractors.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "squad_shuffled_distractors",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "pad_prefix": true,
7 | "num_train_epochs": 3,
8 | "generation_max_length": 128,
9 | "metric_names": ["f1", "exact_match"],
10 | "metric_for_best_model": "f1",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/data/summ_screen_fd.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "summ_screen_fd",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 0,
6 | "pad_prefix": false,
7 | "num_train_epochs": 10,
8 | "generation_max_length": 1024,
9 | "metric_names": ["rouge"],
10 | "metric_for_best_model": "rouge/geometric_mean",
11 | "greater_is_better": true
12 | }
--------------------------------------------------------------------------------
/src/configs/model/bart_base_sled.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name_or_path": "tau/bart-base-sled",
3 | "use_auth_token": false,
4 | "max_target_length": 1024,
5 | "fp16": true
6 | }
--------------------------------------------------------------------------------
/src/configs/model/bart_large_sled.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name_or_path": "tau/bart-large-sled",
3 | "use_auth_token": false,
4 | "max_target_length": 1024,
5 | "fp16": true
6 | }
--------------------------------------------------------------------------------
/src/configs/model/primera_govreport_sled.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "tau/sled",
3 | "underlying_config": "allenai/PRIMERA",
4 | "context_size": 4096,
5 | "window_fraction": 0.5,
6 | "prepend_prefix": true,
7 | "encode_prefix": true,
8 | "sliding_method": "dynamic"
9 | }
--------------------------------------------------------------------------------
/src/configs/training/base_training_args.json:
--------------------------------------------------------------------------------
1 | {
2 | "eval_steps_override": 0.5,
3 | "save_steps_override": 0.5,
4 | "evaluation_strategy": "steps",
5 | "eval_fraction": 1000,
6 | "predict_with_generate": true,
7 | "gradient_checkpointing": true,
8 | "do_train": true,
9 | "do_eval": true,
10 | "seed": 42,
11 | "warmup_ratio": 0.1,
12 | "save_total_limit": 2,
13 | "preprocessing_num_workers": 1,
14 | "load_best_model_at_end": true,
15 | "lr_scheduler": "linear",
16 | "adam_epsilon": 1e-6,
17 | "adam_beta1": 0.9,
18 | "adam_beta2": 0.98,
19 | "weight_decay": 0.001,
20 | "patience": 10,
21 | "extra_metrics": "bertscore"
22 | }
--------------------------------------------------------------------------------
/src/index_building.py:
--------------------------------------------------------------------------------
1 | import faiss
2 | import faiss.contrib.torch_utils
3 | import time
4 | import logging
5 |
6 | import torch
7 | import numpy as np
8 |
9 | code_size = 64
10 |
11 | class DatastoreBatch():
12 | def __init__(self, dim, batch_size, flat_index=False, gpu_index=False, verbose=False, index_device=None) -> None:
13 | self.indices = []
14 | self.batch_size = batch_size
15 | self.device = index_device if index_device is not None else torch.device('cuda' if gpu_index else 'cpu')
16 | for i in range(batch_size):
17 | self.indices.append(Datastore(dim, use_flat_index=flat_index, gpu_index=gpu_index, verbose=verbose, device=self.device))
18 |
19 | def move_to_gpu(self):
20 | for i in range(self.batch_size):
21 | self.indices[i].move_to_gpu()
22 |
23 | def add_keys(self, keys, num_keys_to_add_at_a_time=100000):
24 | for i in range(self.batch_size):
25 | self.indices[i].add_keys(keys[i], num_keys_to_add_at_a_time)
26 |
27 | def train_index(self, keys):
28 | for index, example_keys in zip(self.indices, keys):
29 | index.train_index(example_keys)
30 |
31 | def search(self, queries, k):
32 | found_scores, found_values = [], []
33 | for i in range(self.batch_size):
34 | scores, values = self.indices[i].search(queries[i], k)
35 | found_scores.append(scores)
36 | found_values.append(values)
37 | return torch.stack(found_scores, dim=0), torch.stack(found_values, dim=0)
38 |
39 | def search_and_reconstruct(self, queries, k):
40 | found_scores, found_values = [], []
41 | found_vectors = []
42 | for i in range(self.batch_size):
43 | scores, values, vectors = self.indices[i].search_and_reconstruct(queries[i], k)
44 | found_scores.append(scores)
45 | found_values.append(values)
46 | found_vectors.append(vectors)
47 | return torch.stack(found_scores, dim=0), torch.stack(found_values, dim=0), torch.stack(found_vectors, dim=0)
48 |
49 | class Datastore():
50 | def __init__(self, dim, use_flat_index=False, gpu_index=False, verbose=False, device=None) -> None:
51 | self.dimension = dim
52 | self.device = device if device is not None else torch.device('cuda' if gpu_index else 'cpu')
53 | self.logger = logging.getLogger('index_building')
54 | self.logger.setLevel(20)
55 | self.use_flat_index = use_flat_index
56 | self.gpu_index = gpu_index
57 |
58 | # Initialize faiss index
59 | # TODO: is preprocessing efficient enough to spend time on?
60 | if not use_flat_index:
61 | self.index = faiss.IndexFlatIP(self.dimension) # inner product index because we use IP attention
62 |
63 | # need to wrap in index ID map to enable add_with_ids
64 | # self.index = faiss.IndexIDMap(self.index)
65 |
66 | self.index_size = 0
67 | # if self.gpu_index:
68 | # self.move_to_gpu()
69 |
70 | def move_to_gpu(self):
71 | if self.use_flat_index:
72 | # self.keys = self.keys.to(self.device)
73 | return
74 | else:
75 | co = faiss.GpuClonerOptions()
76 | co.useFloat16 = True
77 | self.index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), self.device.index, self.index, co)
78 |
79 | def train_index(self, keys):
80 | if self.use_flat_index:
81 | self.add_keys(keys=keys, index_is_trained=True)
82 | else:
83 | keys = keys.cpu().float()
84 | ncentroids = int(keys.shape[0] / 128)
85 | self.index = faiss.IndexIVFPQ(self.index, self.dimension,
86 | ncentroids, code_size, 8)
87 | self.index.nprobe = min(32, ncentroids)
88 | # if not self.gpu_index:
89 | # keys = keys.cpu()
90 |
91 | self.logger.info('Training index')
92 | start_time = time.time()
93 | self.index.train(keys)
94 | self.logger.info(f'Training took {time.time() - start_time} s')
95 | self.add_keys(keys=keys, index_is_trained=True)
96 | # self.keys = None
97 | if self.gpu_index:
98 | self.move_to_gpu()
99 |
100 | def add_keys(self, keys, num_keys_to_add_at_a_time=1000000, index_is_trained=False):
101 | self.keys = keys
102 | if not self.use_flat_index and index_is_trained:
103 | start = 0
104 | while start < keys.shape[0]:
105 | end = min(len(keys), start + num_keys_to_add_at_a_time)
106 | to_add = keys[start:end]
107 | # if not self.gpu_index:
108 | # to_add = to_add.cpu()
109 | # self.index.add_with_ids(to_add, torch.arange(start+self.index_size, end+self.index_size))
110 | self.index.add(to_add)
111 | self.index_size += end - start
112 | start = end
113 | if (start % 1000000) == 0:
114 | self.logger.info(f'Added {start} tokens so far')
115 | # else:
116 | # self.keys.append(keys)
117 |
118 | # self.logger.info(f'Adding total {start} keys')
119 | # self.logger.info(f'Adding took {time.time() - start_time} s')
120 |
121 | def search_and_reconstruct(self, queries, k):
122 | if len(queries.shape) == 1: # searching for only 1 vector, add one extra dim
123 | self.logger.info("Searching for a single vector; unsqueezing")
124 | queries = queries.unsqueeze(0)
125 | # self.logger.info("Searching with reconstruct")
126 | assert queries.shape[-1] == self.dimension # query vectors are same shape as "key" vectors
127 | scores, values, vectors = self.index.index.search_and_reconstruct(queries.cpu().detach(), k)
128 | # self.logger.info("Searching done")
129 | return scores, values, vectors
130 |
131 | def search(self, queries, k):
132 | # model_device = queries.device
133 | # model_dtype = queries.dtype
134 | if len(queries.shape) == 1: # searching for only 1 vector, add one extra dim
135 | self.logger.info("Searching for a single vector; unsqueezing")
136 | queries = queries.unsqueeze(0)
137 | assert queries.shape[-1] == self.dimension # query vectors are same shape as "key" vectors
138 | # if not self.gpu_index:
139 | # queries = queries.cpu()
140 | # else:
141 | # queries = queries.to(self.device)
142 | if self.use_flat_index:
143 | if self.gpu_index:
144 | scores, values = faiss.knn_gpu(faiss.StandardGpuResources(), queries, self.keys, k,
145 | metric=faiss.METRIC_INNER_PRODUCT, device=self.device.index)
146 | else:
147 | scores, values = faiss.knn(queries, self.keys, k, metric=faiss.METRIC_INNER_PRODUCT)
148 | scores = torch.from_numpy(scores).to(queries.dtype)
149 | values = torch.from_numpy(values) #.to(model_dtype)
150 | else:
151 | scores, values = self.index.search(queries.float(), k)
152 |
153 | # avoid returning -1 as a value
154 | # TODO: get a handle on the attention mask and mask the values that were -1
155 | values = torch.where(torch.logical_or(values < 0, values >= self.keys.shape[0]), torch.zeros_like(values), values)
156 | # self.logger.info("Searching done")
157 | # return scores.to(model_dtype).to(model_device), values.to(model_device)
158 | return scores, values
159 |
160 |
161 |
--------------------------------------------------------------------------------
/src/inference-example.py:
--------------------------------------------------------------------------------
1 | from unlimiformer import Unlimiformer
2 | from random_training_unlimiformer import RandomTrainingUnlimiformer
3 | from usage import UnlimiformerArguments, training_addin
4 |
5 | from transformers import BartForConditionalGeneration, AutoTokenizer
6 | from datasets import load_dataset
7 | import torch
8 |
9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10 |
11 | # example using govreport
12 | modelname = "abertsch/unlimiformer-bart-govreport-alternating"
13 | dataset = load_dataset("urialon/gov_report_validation")
14 |
15 | tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
16 | model = BartForConditionalGeneration.from_pretrained(modelname)
17 |
18 | example_input = dataset['validation'][0]['input']
19 |
20 | example = tokenizer(example_input, truncation=False, return_tensors="pt")
21 | truncated_example = tokenizer(example_input, truncation=True, max_length=1024, return_tensors="pt")
22 |
23 | example.to(device)
24 | truncated_example.to(device)
25 |
26 | print(f"INPUT LENGTH (tokens): {example['input_ids'].shape[-1]}")
27 |
28 |
29 | defaults = UnlimiformerArguments()
30 | unlimiformer_kwargs = {
31 | 'layer_begin': defaults.layer_begin,
32 | 'layer_end': defaults.layer_end,
33 | 'unlimiformer_head_num': defaults.unlimiformer_head_num,
34 | 'exclude_attention': defaults.unlimiformer_exclude,
35 | 'chunk_overlap': defaults.unlimiformer_chunk_overlap,
36 | 'model_encoder_max_len': defaults.unlimiformer_chunk_size,
37 | 'verbose': defaults.unlimiformer_verbose, 'tokenizer': tokenizer,
38 | 'unlimiformer_training': defaults.unlimiformer_training,
39 | 'use_datastore': defaults.use_datastore,
40 | 'flat_index': defaults.flat_index,
41 | 'test_datastore': defaults.test_datastore,
42 | 'reconstruct_embeddings': defaults.reconstruct_embeddings,
43 | 'gpu_datastore': defaults.gpu_datastore,
44 | 'gpu_index': defaults.gpu_index
45 | }
46 |
47 | model.to(device)
48 | # the output of the model /without/ using unlimiformer
49 | truncated_out = tokenizer.batch_decode(model.generate(**truncated_example, max_length=512))
50 |
51 | model = Unlimiformer.convert_model(model, **unlimiformer_kwargs)
52 | model.eval()
53 | model.to(device)
54 |
55 | # the output of the model /with/ unlimiformer
56 | unlimiformer_out = tokenizer.batch_decode(model.generate(**example, max_length=512), skip_special_tokens=True)[0]
57 | print(unlimiformer_out)
58 |
--------------------------------------------------------------------------------
/src/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import load_metric, download_metric
2 |
--------------------------------------------------------------------------------
/src/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | import os
3 | import importlib
4 | from abc import ABC, abstractmethod
5 | import inspect
6 | import shutil
7 |
8 | import numpy as np
9 |
10 | from utils.decoding import decode
11 | from datasets import load_metric as hf_load_metric
12 | from huggingface_hub import hf_hub_download
13 |
14 |
15 | class Metric(ABC):
16 | def __init__(self, **kwargs) -> None:
17 | super().__init__()
18 | self._kwargs = kwargs
19 |
20 | self.prefix = os.path.splitext(os.path.basename(inspect.getfile(self.__class__)))[0]
21 | self.requires_decoded = False
22 |
23 | def __call__(self, id_to_pred, id_to_labels, is_decoded=False):
24 | if self.requires_decoded and is_decoded is False:
25 | id_to_pred = self._decode(id_to_pred)
26 | id_to_labels = self._decode(id_to_labels)
27 | return self._compute_metrics(id_to_pred, id_to_labels)
28 |
29 | @abstractmethod
30 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]:
31 | return
32 |
33 | def _decode(self, id_to_something):
34 | tokenizer = self._kwargs.get("tokenizer")
35 | data_args = self._kwargs.get("data_args")
36 | return decode(id_to_something, tokenizer, data_args)
37 |
38 |
39 | class MetricCollection(Metric):
40 | def __init__(self, metrics: List[Metric], **kwargs):
41 | super().__init__(**kwargs)
42 | self._metrics = metrics
43 |
44 | def __call__(self, id_to_pred, id_to_labels):
45 | return self._compute_metrics(id_to_pred, id_to_labels)
46 |
47 | def _compute_metrics(self, id_to_pred, id_to_labels):
48 | results = {}
49 |
50 | id_to_pred_decoded = None
51 | id_to_labels_decoded = None
52 | for metric in self._metrics:
53 | metric_prefix = f"{metric.prefix}/" if metric.prefix else ""
54 | if metric.requires_decoded:
55 | if id_to_pred_decoded is None:
56 | id_to_pred_decoded = self._decode(id_to_pred)
57 | if id_to_labels_decoded is None:
58 | id_to_labels_decoded = self._decode(id_to_labels)
59 |
60 | result = metric(id_to_pred_decoded, id_to_labels_decoded, is_decoded=True)
61 | else:
62 | result = metric(id_to_pred, id_to_labels)
63 |
64 | results.update({f"{metric_prefix}{k}": np.mean(v) if type(v) is list else v for k, v in result.items() if type(v) is not str})
65 |
66 | results["num_predicted"] = len(id_to_pred)
67 | results["mean_prediction_length_characters"] = np.mean([len(pred) for pred in id_to_pred_decoded.values()])
68 |
69 | elem = next(iter(id_to_pred.values()))
70 | if not ((isinstance(elem, list) and isinstance(elem[0], str)) or isinstance(elem, str)):
71 | tokenizer = self._kwargs["tokenizer"]
72 | results["mean_prediction_length_tokens"] = np.mean(
73 | [np.count_nonzero(np.array(pred) != tokenizer.pad_token_id) for pred in id_to_pred.values()]
74 | ) # includes BOS/EOS tokens
75 |
76 | results = {key: round(value, 4) for key, value in results.items()}
77 | return results
78 |
79 |
80 | def load_metric(paths: List[str], **kwargs):
81 | if paths is None or len(paths) == 0:
82 | return None
83 | if isinstance(paths, str):
84 | paths = [paths]
85 | else:
86 | paths = [path for path in paths]
87 |
88 | metric_cls_list = []
89 |
90 | scrolls_custom_metrics = []
91 | to_remove = []
92 | for i, path in enumerate(paths):
93 | if not os.path.isfile(path):
94 | scrolls_custom_metrics.append(path)
95 | to_remove.append(i)
96 | for i in sorted(to_remove, reverse=True):
97 | del paths[i]
98 | if len(scrolls_custom_metrics) > 0:
99 | scrolls_custom_metrics.insert(0, "") # In order to have an identifying comma in the beginning
100 | metric_cls_list.append(ScrollsWrapper(",".join(scrolls_custom_metrics), **kwargs))
101 |
102 | for path in paths:
103 | path = path.strip()
104 | if len(path) == 0:
105 | continue
106 | if os.path.isfile(path) is False:
107 | path = os.path.join("src", "metrics", f"{path}.py")
108 |
109 | module = path[:-3].replace(os.sep, ".")
110 |
111 | metric_cls = import_main_class(module)
112 | metric_cls_list.append(metric_cls(**kwargs))
113 |
114 | return MetricCollection(metric_cls_list, **kwargs)
115 |
116 |
117 | # Modified from datasets.load
118 | def import_main_class(module_path):
119 | """Import a module at module_path and return its main class"""
120 | module = importlib.import_module(module_path)
121 |
122 | main_cls_type = Metric
123 |
124 | # Find the main class in our imported module
125 | module_main_cls = None
126 | for name, obj in module.__dict__.items():
127 | if isinstance(obj, type) and issubclass(obj, main_cls_type):
128 | if inspect.isabstract(obj):
129 | continue
130 | module_main_cls = obj
131 | break
132 |
133 | return module_main_cls
134 |
135 |
136 | class ScrollsWrapper(Metric):
137 | def __init__(self, comma_separated_metric_names, **kwargs) -> None:
138 | super().__init__(**kwargs)
139 | self.prefix = None
140 |
141 | self._metric = hf_load_metric(download_metric(), comma_separated_metric_names, keep_in_memory=True)
142 |
143 | self.requires_decoded = True
144 |
145 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]:
146 | return self._metric.compute(**self._metric.convert_from_map_format(id_to_pred, id_to_labels))
147 |
148 | class HFMetricWrapper(Metric):
149 | def __init__(self, metric_name, **kwargs) -> None:
150 | super().__init__(**kwargs)
151 | self._metric = hf_load_metric(metric_name)
152 | self.kwargs = HFMetricWrapper.metric_specific_kwargs.get(metric_name, {})
153 | self.requires_decoded = True
154 | self.prefix = metric_name
155 | self.requires_decoded = True
156 |
157 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]:
158 | return self._metric.compute(**self.convert_from_map_format(id_to_pred, id_to_labels), **self.kwargs)
159 |
160 | def convert_from_map_format(self, id_to_pred, id_to_labels):
161 | index_to_id = list(id_to_pred.keys())
162 | predictions = [id_to_pred[id_] for id_ in index_to_id]
163 | references = [id_to_labels[id_] for id_ in index_to_id]
164 | return {"predictions": predictions, "references": references}
165 |
166 | metric_specific_kwargs = {
167 | 'bertscore': {
168 | # 'model_type': 'microsoft/deberta-large-mnli' or the larger 'microsoft/deberta-xlarge-mnli'
169 | 'model_type': 'facebook/bart-large-mnli', # has context window of 1024,
170 | 'num_layers': 11 # according to: https://docs.google.com/spreadsheets/d/1RKOVpselB98Nnh_EOC4A2BYn8_201tmPODpNWu4w7xI/edit#gid=0
171 | }
172 | }
173 |
174 |
175 | def download_metric():
176 | # here we load the custom metrics
177 | scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type='dataset')
178 | updated_scrolls_metric_path = (
179 | os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
180 | )
181 | shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
182 | return updated_scrolls_metric_path
183 |
--------------------------------------------------------------------------------
/src/random_training_unlimiformer.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from enum import Enum, auto
6 | from unlimiformer import Unlimiformer, ModelType, UnlimiformerBART, UnlimiformerT5, UnlimiformerLED
7 | from transformers import BartModel, BartForConditionalGeneration, \
8 | T5Model, T5ForConditionalGeneration, \
9 | LEDModel, LEDForConditionalGeneration, \
10 | AutoModelForSeq2SeqLM
11 |
12 | class RandomTrainingUnlimiformer(Unlimiformer[ModelType]):
13 | def __init__(self, model: ModelType, *args, **kwargs):
14 | super().__init__(model, *args, **kwargs)
15 | self.training_hooks_injected = False
16 | self.train_step = 0
17 |
18 | @classmethod
19 | def convert_model(cls, model, *args, **kwargs):
20 | # model_clone = AutoModelForSeq2SeqLM.from_config(model.config)
21 | # model_clone.load_state_dict(model.state_dict()).to(args.device)
22 | type_to_class = {
23 | BartModel: RandomUnlimiformerBART,
24 | BartForConditionalGeneration: RandomUnlimiformerBART,
25 | T5Model: RandomUnlimiformerT5,
26 | T5ForConditionalGeneration: RandomUnlimiformerT5,
27 | LEDModel: RandomUnlimiformerLED,
28 | LEDForConditionalGeneration: RandomUnlimiformerLED,
29 | }
30 | type_to_class[type(model)](model, *args, **kwargs)
31 | return model
32 |
33 | def pre_eval_hook(self):
34 | self.remove_training_hooks(self.model)
35 | self.inject_hooks(self.model)
36 | self.original_model_eval_func()
37 |
38 | def pre_train_hook(self, mode=True):
39 | # mode=True means model.train() is called
40 | # mode=False means model.eval() is called
41 | torch.cuda.empty_cache()
42 | if mode is True:
43 | self.break_out(self.model)
44 | self.remove_training_hooks(self.model)
45 | if self.unlimiformer_training and self.train_step % 2 == 0:
46 | super().inject_training_hooks(self.model)
47 | else:
48 | self.inject_training_hooks(self.model)
49 | self.train_step += 1
50 | self.original_model_train_func(mode)
51 |
52 | def inject_training_hooks(self, model):
53 | if self.training_hooks_injected:
54 | return
55 | # self.original_forward_func = model.forward
56 | model.forward = self.random_inputs_forward_hook
57 |
58 | decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end)
59 |
60 | self.original_decoder_layer_self_attn_forward_funcs = []
61 | for decoder_layer in decoder_layers_to_run:
62 | attention = self.self_attention(decoder_layer)
63 | self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward)
64 | attention.forward = self.create_self_attn_random_pre_forward_hook(attention.forward)
65 |
66 | self.original_decoder_layer_forward_funcs = []
67 | for decoder_layer in decoder_layers_to_run:
68 | self.original_decoder_layer_forward_funcs.append(decoder_layer.forward)
69 | decoder_layer.forward = self.create_decoder_layer_random_func(decoder_layer.forward, decoder_layer)
70 |
71 | self.original_decoder_layer_cross_attn_forward_funcs = []
72 | for i, decoder_layer in enumerate(decoder_layers_to_run):
73 | attention = self.cross_attention(decoder_layer)
74 | self.original_decoder_layer_cross_attn_forward_funcs.append(attention.forward)
75 |
76 | self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run)
77 |
78 | self.training_hooks_injected = True
79 |
80 | def create_self_attn_random_pre_forward_hook(self, original_self_attn_forward_func):
81 | def self_attention_pre_forward_hook(*args, **kwargs):
82 | kwargs['past_key_value'] = None
83 | return original_self_attn_forward_func(*args, **kwargs)
84 |
85 | return self_attention_pre_forward_hook
86 |
87 | def create_decoder_layer_random_func(self, decoder_layer_original_forward_func, decoder_layer):
88 | def checkpointed_decoder_layer(
89 | hidden_states: torch.Tensor,
90 | attention_mask=None,
91 | encoder_hidden_states=None,
92 | encoder_attention_mask=None,
93 | layer_head_mask=None,
94 | cross_attn_layer_head_mask=None,
95 | past_key_value=None,
96 | output_attentions=False,
97 | position_bias=None,
98 | encoder_decoder_position_bias=None,
99 | use_cache=True):
100 |
101 |
102 |
103 | def sample_and_forward(hidden_states, attention_mask,
104 | encoder_hidden_states, encoder_attention_mask, layer_head_mask,
105 | cross_attn_layer_head_mask, past_key_value,
106 | output_attentions, use_cache, long_inputs, long_inputs_mask, rand_indices,
107 | position_bias, encoder_decoder_position_bias):
108 |
109 | sampled_input, _ = self.sample_long_input(long_inputs, long_inputs_mask, rand_indices)
110 | key, value = self.create_key_value(sampled_input, decoder_layer)
111 | decoder_layer_args = self.create_decoder_layer_args(
112 | hidden_states=hidden_states,
113 | attention_mask=attention_mask,
114 | encoder_hidden_states=encoder_hidden_states,
115 | encoder_attention_mask=encoder_attention_mask,
116 | layer_head_mask=layer_head_mask,
117 | cross_attn_layer_head_mask=cross_attn_layer_head_mask,
118 | past_key_value=past_key_value,
119 | output_attentions=output_attentions,
120 | position_bias=position_bias,
121 | encoder_decoder_position_bias=encoder_decoder_position_bias,
122 | use_cache=use_cache,
123 | key=key,value=value
124 | )
125 | return decoder_layer_original_forward_func(**decoder_layer_args)
126 |
127 |
128 | with torch.no_grad():
129 | # This sampling must be done outside of the checkpoint, to ensure that the same sampling happens
130 | # both in "forward" and "backward" passes
131 | rand_indices = self.sample_random_indices()
132 |
133 | return torch.utils.checkpoint.checkpoint(
134 | sample_and_forward, hidden_states, attention_mask,
135 | encoder_hidden_states, encoder_attention_mask, layer_head_mask,
136 | cross_attn_layer_head_mask, None,
137 | output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask, rand_indices,
138 | position_bias, encoder_decoder_position_bias)
139 |
140 | return checkpointed_decoder_layer
141 |
142 | def sample_random_indices(self):
143 | rand_indices_list = []
144 | seq_lens = self.long_inputs_mask.sum(-1).tolist()
145 | for seq_len in seq_lens:
146 | if seq_len < self.actual_model_window_size:
147 | rand_indices = torch.arange(self.actual_model_window_size).to(self.device)
148 | rand_indices_list.append(rand_indices)
149 | continue
150 |
151 | rand_indices = torch.randperm(seq_len)[:self.actual_model_window_size].to(self.device)
152 | if seq_len < self.actual_model_window_size:
153 | padding = max(self.actual_model_window_size - seq_len, 0)
154 | rand_indices = torch.cat([rand_indices, torch.arange(padding).to(self.device) + seq_len], axis=-1).to(self.device)
155 | rand_indices_list.append(rand_indices)
156 | rand_indices = torch.stack(rand_indices_list, dim=0)
157 | return rand_indices
158 |
159 | def random_inputs_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
160 | self.model.base_model.decoder.gradient_checkpointing = False
161 | self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask)
162 |
163 | # TODO: should the inputs be sampled or the truncated beginning?
164 | # if self.random_knn_initial_inputs:
165 | # encoded_inputs, encoded_inputs_mask = self.sample_long_input(self.long_inputs_encoded, self.long_inputs_mask)
166 | # else:
167 | encoded_inputs = self.long_inputs_encoded[:, :self.actual_model_window_size]
168 | encoded_inputs_mask = self.long_inputs_mask[:, :self.actual_model_window_size]
169 | return self.original_forward_func(encoder_outputs=(encoded_inputs, ), labels=labels, attention_mask=encoded_inputs_mask, **kwargs)
170 |
171 | def sample_long_input(self, long_inputs_encoded, long_inputs_mask, random_indices=None):
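    | # Gather the encoder states (and mask) at the given positions, yielding window-sized
    | # tensors per batch element; inputs that already fit in the window are returned unchanged.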
172 | if long_inputs_mask.shape[-1] < self.actual_model_window_size:
173 | return long_inputs_encoded, long_inputs_mask
174 | batch_size = long_inputs_encoded.shape[0]
175 |
176 | if random_indices is None:
177 | random_indices = self.sample_random_indices()
178 | random_mask = torch.zeros_like(long_inputs_mask).to(self.device) \
179 | .scatter_(dim=-1, index=random_indices, src=torch.ones_like(random_indices)).bool().to(self.device)
180 | sampled_input = long_inputs_encoded[random_mask].reshape(batch_size, self.actual_model_window_size, -1).to(self.device)
181 | sampled_mask = long_inputs_mask[random_mask].reshape(batch_size, self.actual_model_window_size).to(self.device)
182 | return sampled_input, sampled_mask
183 |
184 | def chunked_encode_input(self, input_ids, attention_mask):
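    | # Encode the long input with the model's encoder in overlapping chunks, keeping only each
    | # chunk's update region, and concatenate the results into (batch, source_len, dim).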
185 | long_inputs_encoded = []
186 | long_inputs_mask = []
187 | window_indices = self.window_indices(input_ids.shape[-1])
188 |
189 | self.is_input_encoding_pass = True
190 | for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
191 | chunk = input_ids[:, context_start_ind:context_end_ind]
192 | chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
193 | output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True)
194 | encoder_last_hidden_state = output.last_hidden_state # (batch, time, dim)
195 |
196 | # keep only this chunk's update region: (batch, chunked_time, dim)
197 | encoder_last_hidden_state = encoder_last_hidden_state[:, update_start_ind:update_end_ind] # (batch, chunked_time, dim)
198 | chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind] # (batch, chunked_time)
199 |
200 | long_inputs_encoded.append(encoder_last_hidden_state) # (batch, chunked_source_len, dim)
201 | long_inputs_mask.append(chunk_attention_mask) # (batch, chunked_source_len)
202 |
203 | long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1) # (batch, source_len, dim)
204 | long_inputs_mask = torch.cat(long_inputs_mask, dim=1) # (batch, source_len)
205 |
206 | self.is_input_encoding_pass = False
207 | if self.verbose:
208 | print(f'Input: '
209 | f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| '
210 | f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}')
211 | print()
212 | return long_inputs_encoded, long_inputs_mask
213 |
214 | class RandomUnlimiformerBART(RandomTrainingUnlimiformer[BartModel], UnlimiformerBART):
215 | def __init__(self, model: BartModel, *args, **kwargs):
216 | super().__init__(model, *args, **kwargs)
217 |
218 | class RandomUnlimiformerT5(RandomTrainingUnlimiformer[T5Model], UnlimiformerT5):
219 | def __init__(self, model: T5Model, *args, **kwargs):
220 | super().__init__(model, *args, **kwargs)
221 |
222 | class RandomUnlimiformerLED(RandomTrainingUnlimiformer[LEDModel], UnlimiformerLED):
223 | def __init__(self, model: LEDModel, *args, **kwargs):
224 | super().__init__(model, *args, **kwargs)
--------------------------------------------------------------------------------
/src/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 | import logging
21 | import os
22 | import sys
23 |
24 | import numpy as np
25 | from unlimiformer import Unlimiformer
26 | from random_training_unlimiformer import RandomTrainingUnlimiformer
27 |
28 | import nltk
29 |
30 | # we import the logging frameworks before any other import to make sure all monkey patching for the logging is active
31 | # from sled import SledConfig
32 |
33 | import wandb
34 | import torch
35 |
36 | sys.path.insert(0, os.path.dirname(__file__)) # seq2seq package path
37 | sys.path.insert(0, os.getcwd())
38 |
39 | from dataclasses import dataclass, field, replace
40 | from typing import List, Optional
41 | import json
42 | from copy import deepcopy
43 | import torch.nn.functional as F
44 |
45 | import datasets
46 |
47 | import transformers
48 | from transformers import (
49 | AutoConfig,
50 | AutoModelForSeq2SeqLM,
51 | AutoTokenizer,
52 | EarlyStoppingCallback,
53 | set_seed, WEIGHTS_NAME,
54 | )
55 | from transformers.trainer_utils import get_last_checkpoint
56 | from transformers import DataCollatorForSeq2Seq
57 |
58 | from datasets import load_dataset
59 |
60 | # noinspection PyUnresolvedReferences
61 | # import sled # *** required so that SledModels will be registered for the AutoClasses ***
62 |
63 | from utils.config import handle_args_to_ignore
64 | from utils.decoding import decode
65 | from metrics import load_metric
66 | from utils.duplicates import drop_duplicates_in_input
67 | from utils.override_training_args import TrainingOverridesArguments
68 | from utils.custom_seq2seq_trainer import CustomTrainer
69 | from utils.custom_hf_argument_parser import CustomHfArgumentParser
70 | from metrics.metrics import HFMetricWrapper, MetricCollection
71 |
72 | logger = logging.getLogger('sled')
73 |
74 | PREFIX_DOC_SEP = '\n\n'
75 |
76 | DEBUG = os.environ.get('DEBUG', 'false').lower() in {'1', 'true', 'yes'} # If set, will set some configuration to help debug
77 | if DEBUG:
78 | assert not torch.cuda.is_available() or torch.cuda.device_count() == 1
79 |
80 |
81 | @dataclass
82 | class ModelArguments:
83 | """
84 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
85 | """
86 |
87 | model_name_or_path: str = field(
88 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
89 | )
90 | config_name: Optional[str] = field(
91 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
92 | )
93 | tokenizer_name: Optional[str] = field(
94 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
95 | )
96 | cache_dir: Optional[str] = field(
97 | default=None,
98 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
99 | )
100 | use_fast_tokenizer: bool = field(
101 | default=True,
102 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
103 | )
104 | model_revision: str = field(
105 | default="main",
106 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
107 | )
108 | drop_duplicates_in_eval: bool = field(
109 | default=True,
110 | )
111 |
112 | def __post_init__(self):
113 | pass
114 |
115 |
116 |
117 | @dataclass
118 | class DataTrainingArguments:
119 | """
120 | Arguments pertaining to what data we are going to input to our model for training and eval.
121 | """
122 |
123 | dataset_name: Optional[str] = field(
124 | default=None,
125 | metadata={
126 | "help": "The name of the dataset to use (via the datasets library) or name of the file in src/data."
127 | },
128 | )
129 | dataset_config_name: Optional[str] = field(
130 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
131 | )
132 | metric_names: Optional[List[str]] = field(
133 | default=None,
134 | metadata={"help": "The name of the metric to use (from src/metrics)."},
135 | )
136 | input_column: Optional[str] = field(
137 | default=None,
138 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
139 | )
140 | input_prefix_column: Optional[str] = field(
141 | default=None,
142 | metadata={"help": "The name of the column in the datasets containing the input prefix (e.g. questions), when those exist."},
143 | )
144 | output_column: Optional[str] = field(
145 | default=None,
146 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
147 | )
148 | train_file: Optional[str] = field(
149 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
150 | )
151 | validation_file: Optional[str] = field(
152 | default=None,
153 | metadata={
154 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
155 | "(a jsonlines or csv file)."
156 | },
157 | )
158 | test_file: Optional[str] = field(
159 | default=None,
160 | metadata={
161 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
162 | },
163 | )
164 | overwrite_cache: bool = field(
165 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
166 | )
167 | preprocessing_num_workers: Optional[int] = field(
168 | default=None,
169 | metadata={"help": "The number of processes to use for the preprocessing."},
170 | )
171 | max_source_length: Optional[int] = field(
172 | default=None,
173 | metadata={
174 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
175 | "than this will be truncated, sequences shorter will be padded."
176 | },
177 | )
178 | eval_max_source_length: Optional[int] = field(
179 | default=None,
180 | metadata={"help": "if None, will be same as max_source_length"},
181 | )
182 | max_prefix_length: Optional[int] = field(
183 | default=0,
184 | metadata={
185 | "help": "The maximum total input_prefix sequence length after tokenization. Sequences longer "
186 | "than this will be truncated, sequences shorter will be padded from the left "
187 | "(only used if prefixes are not merged)."
188 | },
189 | )
190 | max_target_length: Optional[int] = field(
191 | default=128,
192 | metadata={
193 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
194 | "than this will be truncated, sequences shorter will be padded."
195 | },
196 | )
197 | val_max_target_length: Optional[int] = field(
198 | default=None,
199 | metadata={
200 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
201 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
202 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
203 | "during ``evaluate`` and ``predict``."
204 | },
205 | )
206 | pad_to_max_length: bool = field(
207 | default=False,
208 | metadata={
209 | "help": "Whether to pad all samples to model maximum sentence length. "
210 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
211 | "efficient on GPU but very bad for TPU."
212 | },
213 | )
214 | max_train_samples: Optional[int] = field(
215 | default=None,
216 | metadata={
217 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
218 | "value if set."
219 | },
220 | )
221 | max_eval_samples: Optional[int] = field(
222 | default=None,
223 | metadata={
224 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
225 | "value if set."
226 | },
227 | )
228 | max_predict_samples: Optional[int] = field(
229 | default=None,
230 | metadata={
231 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
232 | "value if set."
233 | },
234 | )
235 | num_beams: Optional[int] = field(
236 | default=None,
237 | metadata={
238 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
239 | "which is used during ``evaluate`` and ``predict``."
240 | },
241 | )
242 | ignore_pad_token_for_loss: bool = field(
243 | default=True,
244 | metadata={
245 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
246 | },
247 | )
248 | source_prefix: Optional[str] = field(
249 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
250 | )
251 | data_dir: Optional[str] = field(
252 | default=None,
253 | metadata={"help": "Defining the data_dir of the dataset configuration."},
254 | )
255 | download_mode: Optional[str] = field(
256 | default=None,
257 | metadata={
258 | "help": "Defining the download_mode when loading the dataset. Options are `reuse_dataset_if_exists` (default), `reuse_cache_if_exists` and `force_redownload`."
259 | },
260 | )
261 | evaluate_on_training_data: bool = field(
262 | default=False,
263 | metadata={"help": "Whether to evaluate on training data or not, to make sure the model can overfit."},
264 | )
265 | folder_suffix: str = field(
266 | default="",
267 | metadata={"help": "Suffix to append to the output folder of the run"},
268 | )
269 | preprocess_only: bool = field(
270 | default=False,
271 | metadata={"help": "Preprocess only: Don't start training, just do the things before"},
272 | )
273 | assign_zero_to_too_long_val_examples: bool = field(
274 | default=False,
275 | metadata={
276 | "help": "If true, all sequences longer than max_source_length will be assigned a score of 0 in the metric evaluation"
277 | },
278 | )
279 | shared_storage: bool = field(
280 | default=True,
281 | metadata={"help": "Whether nodes share the same storage"},
282 | )
283 | trim_very_long_strings: bool = field(
284 | default=False,
285 | metadata={"help": "Whether to trim very long strings before tokenizing them"},
286 | )
287 | pad_prefix: bool = field(
288 | default=False,
289 | metadata={
290 | "help": "Whether to pad the prefix if it exists to max_prefix_length. "
291 | "Note - important if you are using a SLED model on an input that contains an input_prefix"
292 | },
293 | )
294 | test_start_ind: Optional[int] = field(
295 | default=None,
296 | metadata={"help": "if given, uses the test set starting from this index"},
297 | )
298 | test_end_ind: Optional[int] = field(
299 | default=None,
300 | metadata={"help": "if given, uses the test set ending at this index"},
301 | )
302 | # Uri:
303 | patience: Optional[int] = field(
304 | default=None,
305 | )
306 | length_penalty: Optional[float] = field(
307 | default=1.0,
308 | )
309 | extra_metrics: Optional[List[str]] = field(
310 | default=None,
311 | metadata={"help": "The name of the metric to use (from src/metrics)."},
312 | )
313 | chunked_training_size: Optional[int] = field(
314 | default=None,
315 | )
316 | oracle_training: Optional[bool] = field(
317 | default=False,
318 | metadata={"help": "If True, train on the input sentences that provide the highest ROUGE score with the labels"}
319 | )
320 | oracle_merge: Optional[bool] = field(
321 | default=False,
322 | metadata={"help": "If True, merge the oracle dataset and the standard training dataset"}
323 | )
324 | def __post_init__(self):
325 | if self.val_max_target_length is None:
326 | self.val_max_target_length = self.max_target_length
327 | if self.pad_prefix and self.max_prefix_length == 0:
328 | raise ValueError('When padding prefix, you must set a max_prefix_length')
329 | assert self.max_prefix_length == 0 or self.max_prefix_length <= 0.5*self.max_source_length,\
330 | 'If max_prefix_length is given, it must be much shorter than the total input'
331 | # Uri:
332 | if self.eval_max_source_length is None:
333 | self.eval_max_source_length = self.max_source_length
334 |
335 |
336 | @dataclass
337 | class UnlimiformerArguments:
338 | """
339 | Arguments controlling how Unlimiformer is applied during training and evaluation.
340 | """
341 | test_unlimiformer: Optional[bool] = field(
342 | default=False,
343 | metadata={
344 | "help": "whether to use KNN."
345 | },
346 | )
347 | unlimiformer_verbose: Optional[bool] = field(
348 | default=False,
349 | metadata={
350 | "help": "whether to print KNN intermediate predictions (mostly for debugging)."
351 | },
352 | )
353 | layer_begin: Optional[int] = field(
354 | default=0,
355 | metadata={"help": "The layer to begin applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. "
356 | "By default, it will be applied to all layers: [0:None]"},
357 | )
358 | layer_end: Optional[int] = field(
359 | default=None,
360 | metadata={"help": "The layer to end applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. "
361 | "By default, it will be applied to all layers: [0:None]"},
362 | )
363 | unlimiformer_chunk_overlap: Optional[float] = field(
364 | default=0.5,
365 | metadata={"help": "The fraction of overlap between input chunks"},
366 | )
367 | unlimiformer_chunk_size: Optional[int] = field(
368 | default=None,
369 | metadata={"help": "The size of each input chunk"},
370 | )
371 | unlimiformer_head_num: Optional[int] = field(
372 | default=None,
373 | metadata={"help": "The head to apply KNN to (if None, apply to all heads)"},
374 | )
375 | unlimiformer_exclude: Optional[bool] = field(
376 | default=False,
377 | metadata={
378 | "help": "If True, prioritize the inputs that are **not** in the standard attention window."
379 | },
380 | )
381 | random_unlimiformer_training: Optional[bool] = field(
382 | default=False,
383 | )
384 | unlimiformer_training: Optional[bool] = field(
385 | default=False,
386 | )
387 | use_datastore: Optional[bool] = field(default=False)
388 | flat_index: Optional[bool] = field(default=False)
389 | test_datastore: Optional[bool] = field(default=False)
390 | reconstruct_embeddings: Optional[bool] = field(default=False)
391 | gpu_datastore: Optional[bool] = field(default=True)
392 | gpu_index: Optional[bool] = field(default=True)
393 |
394 |
395 | def main():
396 | handle_args_to_ignore(sys.argv) # Just for sweeps
397 |
398 | # See all possible arguments in src/transformers/training_args.py
399 | # or by passing the --help flag to this script.
400 | # We now keep distinct sets of args, for a cleaner separation of concerns.
401 |
402 | parser = CustomHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingOverridesArguments, UnlimiformerArguments))
403 | model_args, data_args, training_args, unlimiformer_args = parser.parse_dictionary_and_args()
404 |
405 | set_up_logging(training_args)
406 | logger.info(f"Training Arguments: {training_args}")
407 | logger.info(f"Data Arguments: {data_args}")
408 | logger.info(f"Model Arguments: {model_args}")
409 | logger.info(f"Unlimiformer Arguments: {unlimiformer_args}")
410 |
411 |
412 | # Added to avoid wandb.errors.UsageError: Error communicating with wandb process
413 | wandb.init(settings=wandb.Settings(start_method="fork"), name=training_args.output_dir)
414 |
415 | # Used to find missing dependencies early on
416 | load_metric(data_args.metric_names, **locals())
417 | load_extra_metrics(data_args.extra_metrics)
418 |
419 | if data_args.source_prefix is None and model_args.model_name_or_path in [
420 | "t5-small",
421 | "t5-base",
422 | "t5-large",
423 | "t5-3b",
424 | "t5-11b",
425 | ]:
426 | logger.warning(
427 | "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
428 | "`--source_prefix 'summarize: ' `"
429 | )
430 |
431 | # Detecting last checkpoint.
432 | last_checkpoint = _detect_last_checkpoint(training_args)
433 |
434 | # Set seed before initializing model.
435 | set_seed(training_args.seed)
436 |
437 | seq2seq_dataset = _get_dataset(data_args, model_args, training_args)
438 |
439 | # Load pretrained model and tokenizer
440 | #
441 | # Distributed training:
442 | # The .from_pretrained methods guarantee that only one local process can concurrently
443 | # download model & vocab.
444 | config_name = None
445 | if model_args.config_name:
446 | config_name = model_args.config_name
447 | else:
448 | if os.path.isfile(model_args.model_name_or_path):
449 | config_name = os.path.dirname(model_args.model_name_or_path)
450 | else:
451 | config_name = model_args.model_name_or_path
452 |
453 | config_overrides = {}
454 | if training_args.gradient_checkpointing is not None:
455 | config_overrides["gradient_checkpointing"] = training_args.gradient_checkpointing
456 |
457 | config = AutoConfig.from_pretrained(
458 | config_name,
459 | cache_dir=model_args.cache_dir,
460 | revision=model_args.model_revision,
461 | use_auth_token=training_args.use_auth_token,
462 | **config_overrides
463 | )
464 | # override for sled models to make sure we are explicit in our request
465 | # if isinstance(config, SledConfig) and (not data_args.pad_prefix or data_args.max_prefix_length == 0):
466 | # logger.warning('Setting prepend_prefix to False if using a SLED model, as the input does not have a prefix or '
467 | # 'pad_prefix is False (all prefixes must be of the same length for SLED). If you do not use SLED '
468 | # 'or finetune on a dataset with no prefixes, ignore this warning')
469 | # config.prepend_prefix = False
470 |
471 | if model_args.model_name_or_path is None:
472 | # Padding for divisibility by 8
473 | if config.vocab_size % 8 != 0 and training_args.fp16_padding:
474 | config.vocab_size += 8 - (config.vocab_size % 8)
475 |
476 | tokenizer_name = None
477 | if model_args.tokenizer_name:
478 | tokenizer_name = model_args.tokenizer_name
479 | else:
480 | if os.path.isfile(model_args.model_name_or_path):
481 | tokenizer_name = os.path.dirname(model_args.model_name_or_path)
482 | else:
483 | tokenizer_name = model_args.model_name_or_path
484 | tokenizer = AutoTokenizer.from_pretrained(
485 | tokenizer_name,
486 | cache_dir=model_args.cache_dir,
487 | use_fast=model_args.use_fast_tokenizer,
488 | revision=model_args.model_revision,
489 | use_auth_token=training_args.use_auth_token,
490 | )
491 | if model_args.model_name_or_path is not None:
492 | model = AutoModelForSeq2SeqLM.from_pretrained(
493 | model_args.model_name_or_path,
494 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
495 | config=config,
496 | cache_dir=model_args.cache_dir,
497 | revision=model_args.model_revision,
498 | use_auth_token=training_args.use_auth_token,
499 | )
500 | else:
501 | model = AutoModelForSeq2SeqLM.from_config(
502 | config,
503 | )
504 | if unlimiformer_args.test_unlimiformer:
505 | unlimiformer_kwargs = {
506 | 'layer_begin': unlimiformer_args.layer_begin,
507 | 'layer_end': unlimiformer_args.layer_end,
508 | 'unlimiformer_head_num': unlimiformer_args.unlimiformer_head_num,
509 | 'exclude_attention': unlimiformer_args.unlimiformer_exclude,
510 | 'chunk_overlap': unlimiformer_args.unlimiformer_chunk_overlap,
511 | 'model_encoder_max_len': unlimiformer_args.unlimiformer_chunk_size,
512 | 'verbose': unlimiformer_args.unlimiformer_verbose, 'tokenizer': tokenizer,
513 | 'unlimiformer_training': unlimiformer_args.unlimiformer_training,
514 | 'use_datastore': unlimiformer_args.use_datastore,
515 | 'flat_index': unlimiformer_args.flat_index,
516 | 'test_datastore': unlimiformer_args.test_datastore,
517 | 'reconstruct_embeddings': unlimiformer_args.reconstruct_embeddings,
518 | 'gpu_datastore': unlimiformer_args.gpu_datastore,
519 | 'gpu_index': unlimiformer_args.gpu_index
520 | }
521 | if unlimiformer_args.random_unlimiformer_training:
522 | model = RandomTrainingUnlimiformer.convert_model(model, **unlimiformer_kwargs)
523 | else:
524 | model = Unlimiformer.convert_model(model, **unlimiformer_kwargs)
525 |
526 | model.config.use_cache = True
527 | if training_args.gradient_checkpointing and getattr(model.config, 'use_cache', False) and training_args.do_train:
528 | logger.warning('Cannot use cache in models when using gradient checkpointing. turning it off')
529 | model.config.use_cache = False
530 |
531 | model.resize_token_embeddings(len(tokenizer))
532 |
533 | if model.config.decoder_start_token_id is None:
534 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
535 |
536 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
537 |
538 | # Preprocessing the datasets.
539 | # We need to tokenize inputs and targets.
540 | if training_args.do_train:
541 | column_names = seq2seq_dataset["train"].column_names
542 | elif training_args.do_eval:
543 | column_names = seq2seq_dataset["validation"].column_names
544 | elif training_args.do_predict:
545 | column_names = seq2seq_dataset["test"].column_names
546 | else:
547 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
548 | return
549 |
550 | # Get the column names for input/target.
551 | if data_args.input_column is None:
552 | input_column = "input"
553 | else:
554 | input_column = data_args.input_column
555 | if input_column not in column_names:
556 | raise ValueError(
557 | f"--input_column' value '{data_args.input_column}' needs to be one of: {', '.join(column_names)}"
558 | )
559 | if data_args.input_prefix_column is None:
560 | input_prefix_column = "input_prefix"
561 | else:
562 | input_prefix_column = data_args.input_prefix_column
563 | if input_prefix_column not in column_names:
564 | raise ValueError(
565 | f"--input_prefix_column' value '{data_args.input_prefix_column}' needs to be one of: {', '.join(column_names)}"
566 | )
567 | if data_args.output_column is None:
568 | output_column = "output"
569 | else:
570 | output_column = data_args.output_column
571 | if output_column not in column_names:
572 | raise ValueError(
573 | f"--output_column' value '{data_args.output_column}' needs to be one of: {', '.join(column_names)}"
574 | )
575 |
576 | # Temporarily set max_target_length for training.
577 | max_target_length = data_args.max_target_length
578 | padding = "max_length" if data_args.pad_to_max_length else False
579 |
580 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
581 | logger.warning(
582 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
583 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
584 | )
585 |
586 | def preprocess_function_kwargs_fn():
587 | return {
588 | "tokenizer": deepcopy(tokenizer),
589 | "prefix": prefix,
590 | "input_column": input_column,
591 | "input_prefix_column": input_prefix_column,
592 | "output_column": output_column,
593 | "max_source_length": data_args.max_source_length,
594 | "max_prefix_length": data_args.max_prefix_length,
595 | "max_target_length": max_target_length,
596 | "prefix_sep": PREFIX_DOC_SEP,
597 | "padding": padding,
598 | "ignore_pad_token_for_loss": data_args.ignore_pad_token_for_loss,
599 | "assign_zero_to_too_long_val_examples": data_args.assign_zero_to_too_long_val_examples,
600 | "trim_very_long_strings": data_args.trim_very_long_strings,
601 | "pad_prefix": data_args.pad_prefix
602 | }
603 |
604 | if training_args.do_train:
605 | if "train" not in seq2seq_dataset:
606 | raise ValueError("--do_train requires a train dataset")
607 | logger.info("")
608 | logger.info("Training examples before tokenization:")
609 | if input_prefix_column in column_names:
610 | logger.info(f"input_prefix #0: {seq2seq_dataset['train'][0][input_prefix_column]}")
611 | # logger.info(f"input #0: {seq2seq_dataset['train'][0]['input']}")
612 | # logger.info(f"output #0: {seq2seq_dataset['train'][0]['output']}")
613 | if input_prefix_column in column_names:
614 | logger.info(f"input_prefix #1: {seq2seq_dataset['train'][1][input_prefix_column]}")
615 | # logger.info(f"input #1: {seq2seq_dataset['train'][1]['input']}")
616 | # logger.info(f"output #1: {seq2seq_dataset['train'][1]['output']}")
617 | logger.info("")
618 | untokenized_train_dataset = seq2seq_dataset["train"]
619 | if data_args.max_train_samples is not None:
620 | untokenized_train_dataset = untokenized_train_dataset.select(range(data_args.max_train_samples))
621 |
622 | if DEBUG:
623 | # In debug mode, we want to recreate the data
624 | data_args.shared_storage = False
625 | data_args.overwrite_cache = True
626 | with training_args.main_process_first(
627 | local=not data_args.shared_storage, desc="train dataset map pre-processing"
628 | ):
629 |
630 | if data_args.oracle_training:
631 | logger.info("Using oracle training")
632 | oracle_processed_dir = f'oracle_input_{data_args.dataset_config_name}'
633 | if os.path.isdir(oracle_processed_dir):
634 | logger.info(f"Using oracle training from {oracle_processed_dir}")
635 | oracle_training_set = datasets.load_from_disk(oracle_processed_dir)
636 | else:
637 | rouge_scorer = datasets.load_metric('rouge')
638 | oracle_training_set = untokenized_train_dataset.map(
639 | extract_oracle_sent_batch,
640 | fn_kwargs={'max_length': data_args.max_source_length,
641 | 'tokenizer': tokenizer,
642 | 'rouge_scorer': rouge_scorer},
643 | batched=True,
644 | batch_size=1,
645 | num_proc=data_args.preprocessing_num_workers,
646 | load_from_cache_file=not data_args.overwrite_cache,
647 | desc="Extracting oracle sentences from every training example",
648 | )
649 | oracle_training_set.save_to_disk(oracle_processed_dir)
650 |
651 |
652 | if data_args.oracle_merge:
653 | untokenized_train_dataset = datasets.concatenate_datasets([untokenized_train_dataset, oracle_training_set])
654 | untokenized_train_dataset = untokenized_train_dataset.shuffle(seed=training_args.seed)
655 | else:
656 | untokenized_train_dataset = oracle_training_set
657 |
658 | train_dataset = untokenized_train_dataset.map(
659 | preprocess_function,
660 | fn_kwargs=preprocess_function_kwargs_fn(),
661 | batched=True,
662 | num_proc=data_args.preprocessing_num_workers,
663 | remove_columns=untokenized_train_dataset.column_names,
664 | load_from_cache_file=not data_args.overwrite_cache,
665 | desc="Running tokenizer on train dataset",
666 | )
667 |
668 | if data_args.chunked_training_size is not None:
669 | train_dataset = train_dataset.map(
670 | chunk_dataset_function,
671 | fn_kwargs={'chunk_size': data_args.chunked_training_size},
672 | batched=True,
673 | num_proc=data_args.preprocessing_num_workers,
674 | load_from_cache_file=not data_args.overwrite_cache,
675 | desc="Chunking train dataset source",
676 | )
677 | train_dataset = train_dataset.shuffle(seed=training_args.seed)
678 |
679 | if training_args.do_eval:
680 | max_target_length = data_args.val_max_target_length
681 | preprocess_function_kwargs = preprocess_function_kwargs_fn()
682 | preprocess_function_kwargs["max_target_length"] = max_target_length
683 | preprocess_function_kwargs['max_source_length'] = data_args.eval_max_source_length
684 | if "validation" not in seq2seq_dataset:
685 | raise ValueError("--do_eval requires a validation dataset")
686 | logger.info("")
687 | logger.info("Validation examples before tokenization:")
688 | if input_prefix_column in column_names:
689 | logger.info(f"input_prefix #0: {seq2seq_dataset['validation'][0][input_prefix_column]}")
690 | # logger.info(f"input #0: {seq2seq_dataset['validation'][0]['input']}")
691 | # logger.info(f"output #0: {seq2seq_dataset['validation'][0]['output']}")
692 | if input_prefix_column in column_names:
693 | logger.info(f"input_prefix #1: {seq2seq_dataset['validation'][1][input_prefix_column]}")
694 | # logger.info(f"input #1: {seq2seq_dataset['validation'][1]['input']}")
695 | # logger.info(f"output #1: {seq2seq_dataset['validation'][1]['output']}")
696 | logger.info("")
697 | untokenized_eval_dataset = seq2seq_dataset["validation"]
698 | if data_args.max_eval_samples is not None:
699 | untokenized_eval_dataset = untokenized_eval_dataset.select(range(data_args.max_eval_samples))
700 | if model_args.drop_duplicates_in_eval is True:
701 | untokenized_eval_dataset = drop_duplicates_in_input(untokenized_eval_dataset)
702 | untokenized_eval_dataset_orig = untokenized_eval_dataset
703 | assert training_args.eval_fraction > 0
704 | n = len(untokenized_eval_dataset)
705 | training_args = replace(training_args, eval_fraction = min(training_args.eval_fraction, n))
706 | if training_args.eval_fraction != 1:
707 | if training_args.eval_fraction > 1:
708 | assert training_args.eval_fraction == int(training_args.eval_fraction)
709 | logger.info(f'using predetermined absolute samples from eval set ({training_args.eval_fraction} )')
710 | training_args = replace(training_args, eval_fraction = training_args.eval_fraction / n)
711 | indices = np.random.permutation(n)[:int(np.ceil(max(1, training_args.eval_fraction * n)))]
712 | untokenized_eval_dataset = type(untokenized_eval_dataset).from_dict(untokenized_eval_dataset[indices])
713 | logger.info(f'During training, will only use {training_args.eval_fraction:.3%} samples of the eval set '
714 | f'which amounts to {len(untokenized_eval_dataset)} out of {n} samples')
715 |
716 | eval_dataset = process_eval_set(data_args, preprocess_function_kwargs, training_args, untokenized_eval_dataset)
717 | eval_dataset_orig = eval_dataset
718 | if training_args.eval_fraction < 1:
719 | eval_dataset_orig = process_eval_set(data_args, preprocess_function_kwargs, training_args,
720 | untokenized_eval_dataset_orig)
721 |
722 | if training_args.do_predict:
723 | max_target_length = data_args.val_max_target_length
724 | preprocess_function_kwargs = preprocess_function_kwargs_fn()
725 | preprocess_function_kwargs["max_target_length"] = max_target_length
726 | preprocess_function_kwargs['max_source_length'] = data_args.eval_max_source_length
727 | if "test" not in seq2seq_dataset:
728 | raise ValueError("--do_predict requires a test dataset")
729 | untokenized_predict_dataset = seq2seq_dataset["test"]
730 | if data_args.max_predict_samples is not None:
731 | untokenized_predict_dataset = untokenized_predict_dataset.select(range(data_args.max_predict_samples))
732 | if model_args.drop_duplicates_in_eval is True:
733 | untokenized_predict_dataset = drop_duplicates_in_input(untokenized_predict_dataset)
734 |
735 | if output_column in untokenized_predict_dataset.column_names:
736 | untokenized_predict_dataset = untokenized_predict_dataset.remove_columns(output_column)
737 |
738 | if data_args.test_start_ind is not None:
739 | sind = data_args.test_start_ind
740 | eind = -1 if data_args.test_end_ind is None else data_args.test_end_ind
741 | logger.info(f'Using only a subset of the test dataset [{sind}, {eind}]')
742 | untokenized_predict_dataset = type(untokenized_predict_dataset).from_dict(untokenized_predict_dataset[sind:eind])
743 |
744 | with training_args.main_process_first(
745 | local=not data_args.shared_storage, desc="prediction dataset map pre-processing"
746 | ):
747 | predict_dataset = untokenized_predict_dataset.map(
748 | preprocess_function,
749 | fn_kwargs=preprocess_function_kwargs,
750 | batched=True,
751 | num_proc=data_args.preprocessing_num_workers,
752 | remove_columns=untokenized_predict_dataset.column_names,
753 | load_from_cache_file=not data_args.overwrite_cache,
754 | desc="Running tokenizer on prediction dataset",
755 | )
756 |
757 | if data_args.preprocess_only:
758 | logger.info(f"With --preprocess_only, exiting after preprocessing the data")
759 | exit()
760 |
761 | # Data collator
762 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
763 | pad_to = 8 if training_args.fp16 and training_args.fp16_padding else None
764 |
765 |
766 | data_collator = DataCollatorForSeq2Seq(
767 | tokenizer,
768 | model=model,
769 | label_pad_token_id=label_pad_token_id,
770 | pad_to_multiple_of=pad_to,
771 | )
772 |
773 | # Metric
774 | compute_metrics = load_metric(data_args.metric_names, **locals())
775 | compute_metrics = load_extra_metrics(data_args.extra_metrics, compute_metrics)
776 |
777 | # Initialize our Trainer
778 | trainer = CustomTrainer(
779 | model=model,
780 | args=training_args,
781 | train_dataset=train_dataset if training_args.do_train else None,
782 | eval_dataset=eval_dataset if training_args.do_eval else None,
783 | untokenized_eval_dataset=untokenized_eval_dataset if training_args.do_eval else None,
784 | tokenizer=tokenizer,
785 | data_collator=data_collator,
786 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
787 | output_dir=training_args.output_dir,
788 | data_args=data_args,
789 | callbacks=[EarlyStoppingCallback(early_stopping_patience=data_args.patience)] if data_args.patience is not None else None,
790 | )
791 |
792 | # setup_cometml_trainer_callback(trainer)
793 |
794 | # Training
795 | if training_args.do_train:
796 | checkpoint = None
797 | if training_args.resume_from_checkpoint is not None:
798 | checkpoint = training_args.resume_from_checkpoint
799 | elif last_checkpoint is not None:
800 | checkpoint = last_checkpoint # look for checkpoints in the outdir
801 |
802 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
803 | logger.info('Done training')
804 | trainer.save_model() # Saves the tokenizer too for easy upload
805 |
806 | metrics = train_result.metrics
807 | max_train_samples = (
808 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
809 | )
810 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
811 |
812 | trainer.log_metrics("train", metrics)
813 | trainer.save_metrics("train", metrics)
814 | trainer.save_state()
815 |
816 | # Evaluation
817 | results = {}
818 | if training_args.do_eval:
819 | logger.info("*** Evaluate ***")
820 |
821 | if training_args.eval_fraction < 1:
822 | logger.info('setting the eval set back to the full one')
823 | trainer.eval_dataset = eval_dataset_orig
824 | trainer._untokenized_eval_dataset = untokenized_eval_dataset_orig
825 |
826 | metrics = trainer.evaluate(metric_key_prefix="eval", use_cache=True, length_penalty=data_args.length_penalty)
827 | logger.info('Done evaluating')
828 |
829 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
830 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
831 |
832 | trainer.log_metrics("eval", metrics)
833 | trainer.save_metrics("eval", metrics)
834 |
835 | if training_args.do_predict:
836 | logger.info("*** Predict ***")
837 | trainer.args.predict_with_generate = True # during prediction, we don't have labels
838 |
839 | # load last (and best) model, or the one specified if any
840 | logger.info("*** Loading model weights before the prediction ***")
841 | last_checkpoint = model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else _detect_last_checkpoint(training_args)
842 | if last_checkpoint is not None and os.path.isdir(last_checkpoint):
843 | logger.info(f'Loading weights from {last_checkpoint} for the prediction')
844 | state_dict = torch.load(os.path.join(last_checkpoint, WEIGHTS_NAME), map_location="cpu")
845 | # If the model is on the GPU, it still works!
846 | # trainer._load_state_dict_in_model(state_dict)
847 | # release memory
848 | del state_dict
849 | logger.info("*** Done loading weights ***")
850 | elif training_args.do_train:
851 | raise ValueError('Could not find a model to load for prediction')
852 | else:
853 | logger.info(f'Using {model_args.model_name_or_path} as the model for the prediction')
854 |
855 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", use_cache=True)
856 | logger.info('Done predicting')
857 |
858 | metrics = predict_results.metrics
859 | max_predict_samples = (
860 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
861 | )
862 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
863 |
864 | trainer.log_metrics("predict", metrics)
865 | trainer.save_metrics("predict", metrics)
866 |
867 | if trainer.is_world_process_zero():
868 | if training_args.predict_with_generate:
869 | id_to_prediction = {}
870 | for i, instance in enumerate(untokenized_predict_dataset):
871 | id_to_prediction[instance["id"]] = predict_results.predictions[i]
872 | predictions = decode(id_to_prediction, tokenizer, data_args)
873 | output_name = "generated_predictions.json"
874 | if data_args.test_start_ind is not None:
875 | output_name = f"generated_predictions_{data_args.test_start_ind}_{data_args.test_end_ind}.json"
876 | output_prediction_file = os.path.join(training_args.output_dir, output_name)
877 | with open(output_prediction_file, "w") as writer:
878 | json.dump(predictions, writer, indent=4)
879 |
880 | if training_args.push_to_hub:
881 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"}
882 | if data_args.dataset_name is not None:
883 | kwargs["dataset_tags"] = data_args.dataset_name
884 | if data_args.dataset_config_name is not None:
885 | kwargs["dataset_args"] = data_args.dataset_config_name
886 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
887 | else:
888 | kwargs["dataset"] = data_args.dataset_name
889 |
890 | trainer.push_to_hub(**kwargs)
891 |
892 | return results
893 |
894 | def _detect_last_checkpoint(training_args):
895 | last_checkpoint = None
896 | if os.path.isdir(training_args.output_dir) and training_args.do_train:
897 | if not training_args.overwrite_output_dir:
898 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
899 |
900 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
901 | logger.info(
902 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
903 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
904 | )
905 | return last_checkpoint
906 |
907 | def process_eval_set(data_args, preprocess_function_kwargs, training_args, untokenized_eval_dataset):
908 | with training_args.main_process_first(
909 | local=not data_args.shared_storage, desc="validation dataset map pre-processing"
910 | ):
911 | eval_dataset = untokenized_eval_dataset.map(
912 | preprocess_function,
913 | fn_kwargs=preprocess_function_kwargs,
914 | batched=True,
915 | num_proc=data_args.preprocessing_num_workers,
916 | remove_columns=untokenized_eval_dataset.column_names,
917 | load_from_cache_file=not data_args.overwrite_cache,
918 | desc="Running tokenizer on validation dataset",
919 | )
920 | return eval_dataset
921 |
922 |
923 | def _get_dataset(data_args, model_args, training_args):
924 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
925 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
926 | # (the dataset will be downloaded automatically from the datasets Hub).
927 | #
928 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the
929 | # summaries (unless you specify column names for this with the `input_column` and `output_column` arguments).
930 | #
931 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
932 | # download the dataset.
933 | data_files = None
934 | if data_args.train_file is not None or data_args.validation_file is not None or data_args.test_file is not None:
935 | data_files = {}
936 | if data_args.train_file is not None:
937 | data_files["train"] = data_args.train_file
938 | if data_args.validation_file is not None:
939 | data_files["validation"] = data_args.validation_file
940 | if data_args.test_file is not None:
941 | data_files["test"] = data_args.test_file
942 | # Downloading and loading a dataset from the hub/local script.
943 | seq2seq_dataset = load_dataset(
944 | data_args.dataset_name,
945 | data_args.dataset_config_name,
946 | verification_mode='no_checks',
947 | cache_dir=model_args.cache_dir,
948 | data_dir=data_args.data_dir,
949 | data_files=data_files,
950 | download_mode=data_args.download_mode,
951 | use_auth_token=training_args.use_auth_token
952 | )
953 | if training_args.do_train:
954 | training_args.apply_overrides(len(seq2seq_dataset['train']))
955 | if data_args.evaluate_on_training_data:
956 | seq2seq_dataset["validation"] = seq2seq_dataset["train"]
957 |
958 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
959 | # https://huggingface.co/docs/datasets/loading_datasets.html.
960 |
961 | return seq2seq_dataset
962 |
963 |
964 | def set_up_logging(training_args):
965 | logging.basicConfig(
966 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
967 | datefmt="%m/%d/%Y %H:%M:%S",
968 | handlers=[logging.StreamHandler(sys.stdout)],
969 | )
970 | log_level = training_args.get_process_log_level()
971 | logger.setLevel(log_level)
972 | datasets.utils.logging.set_verbosity(log_level)
973 | transformers.utils.logging.set_verbosity(log_level)
974 | transformers.utils.logging.enable_default_handler()
975 | transformers.utils.logging.enable_explicit_format()
976 | # Log on each process the small summary:
977 | logger.warning(
978 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
979 | + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
980 | )
981 | logger.info(f"Training/evaluation parameters {training_args}")
982 |
983 | def extract_oracle_sent_batch(examples, max_length, tokenizer, rouge_scorer):
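    | # Batched map function: replace each example's 'input' with the oracle sentence selection
    | # from extract_oracle_sentences; all other columns are copied through unchanged.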
984 | items = examples.data.items()
985 | keys = [item[0] for item in items]
986 | values = [item[1] for item in items]
987 | extracted = {k: [] for k in keys}
988 | input_str = 'input'
989 |
990 | for ex in zip(*values):
991 | ex = dict(zip(keys, ex))
992 | ex_input = ex[input_str]
993 | extracted_input = extract_oracle_sentences(ex_input, ex['output'], max_length, tokenizer, rouge_scorer)
994 | extracted[input_str].append(extracted_input)
995 | for k in set(keys) - {input_str}:
996 | extracted[k].append(ex[k])
997 | return extracted
998 |
999 | def extract_oracle_sentences(input_sequence, output, max_length, tokenizer, rouge_scorer, criterion='rouge/geometric_mean'):
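    | # Greedy oracle extraction: repeatedly add the sentence that most improves the product of
    | # ROUGE-1/2/Lsum F-measures against the reference, stopping when no sentence improves the
    | # score, the selection reaches max_length tokens, or 100 iterations have run.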
1000 | sentences = nltk.sent_tokenize(input_sequence)
1001 | selected_mask = [False for _ in sentences]
1002 |
1003 | max_rouge = 0.0
1004 | joined_selection = ''
1005 | counter = 0
1006 | while len(tokenizer(joined_selection)['input_ids']) < max_length and counter < 100:  # token count of the current selection
1007 | cur_max_rouge = max_rouge
1008 | max_index = -1
1009 |
1010 | cur_candidate_indices = []
1011 | cur_candidates = []
1012 | for i in range(len(sentences)):
1013 | if selected_mask[i]:
1014 | # We already selected this sentence
1015 | continue
1016 | candidate_mask = list(selected_mask)
1017 | candidate_mask[i] = True
1018 | candidate_prediction = ' '.join(sent for sent, mask in zip(sentences, candidate_mask) if mask)
1019 | cur_candidates.append(candidate_prediction)
1020 | cur_candidate_indices.append(i)
1021 |
1022 | rouge = rouge_scorer.compute(predictions=cur_candidates, references=[[output]] * len(cur_candidates), use_aggregator=False)
1023 | aggregated_rouge_types = [s1.fmeasure * s2.fmeasure * sL.fmeasure for s1, s2, sL in zip(rouge['rouge1'], rouge['rouge2'], rouge['rougeLsum'])]
1024 | max_index = np.argmax(aggregated_rouge_types)
1025 | cur_max_rouge = aggregated_rouge_types[max_index]
1026 |
1027 | if max_rouge >= cur_max_rouge:
1028 | # No sentence improves the score
1029 | break
1030 |
1031 | selected_mask[cur_candidate_indices[max_index]] = True
1032 | max_rouge = cur_max_rouge
1033 | joined_selection = ' '.join(sent for sent, mask in zip(sentences, selected_mask) if mask)
1034 | counter += 1
1035 |
1036 | return joined_selection
1037 |
1038 |
1039 | def chunk_dataset_function(examples, chunk_size):
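    | # Split each tokenized example into consecutive chunks of `chunk_size` input ids, dropping
    | # chunks with fewer than 10 attended tokens and duplicating the remaining columns for each kept chunk.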
1040 | input_ids_str = 'input_ids'
1041 | attention_mask_str = 'attention_mask'
1042 | items = examples.data.items()
1043 | keys = [item[0] for item in items]
1044 | values = [item[1] for item in items]
1045 | chunked = {k: [] for k in keys}
1046 | for ex in zip(*values):
1047 | ex = dict(zip(keys, ex))
1048 | for i in range(0, len(ex[input_ids_str]), chunk_size):
1049 | chunked_input_ids_st = ex[input_ids_str][i:i + chunk_size]
1050 | chunked_attention_mask = ex[attention_mask_str][i:i + chunk_size]
1051 |
1052 | if sum(chunked_attention_mask) < 10:
1053 | continue
1054 | chunked[input_ids_str].append(chunked_input_ids_st)
1055 | chunked[attention_mask_str].append(chunked_attention_mask)
1056 | for k in set(keys) - {input_ids_str, attention_mask_str}:
1057 | chunked[k].append(ex[k])
1058 | return chunked
1059 |
1060 |
1061 |
1062 | def preprocess_function(
1063 | examples,
1064 | tokenizer,
1065 | prefix,
1066 | input_column,
1067 | input_prefix_column,
1068 | output_column,
1069 | max_source_length,
1070 | max_prefix_length,
1071 | max_target_length,
1072 | prefix_sep,
1073 | padding,
1074 | ignore_pad_token_for_loss,
1075 | assign_zero_to_too_long_val_examples,
1076 | trim_very_long_strings,
1077 | pad_prefix
1078 | ):
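    | # Tokenize raw string inputs (optionally combining a separate input_prefix column);
    | # already-tokenized inputs are handled separately. Targets are then tokenized and each
    | # example's input length is recorded.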
1079 | if not isinstance(examples[input_column][0], str):
1080 | model_inputs = _preprocess_tokenized_inputs()
1081 | else:
1082 | model_inputs = _preprocess_raw_inputs(assign_zero_to_too_long_val_examples, examples, input_column, input_prefix_column,
1083 | max_source_length, padding, prefix, tokenizer, trim_very_long_strings, max_prefix_length,
1084 | prefix_sep, pad_prefix)
1085 |
1086 | _preprocess_targets(examples, ignore_pad_token_for_loss, max_target_length, model_inputs, output_column, padding, tokenizer)
1087 | model_inputs["length"] = [len(x) for x in model_inputs["input_ids"]]
1088 | return model_inputs
1089 |
1090 |
1091 | def _preprocess_raw_inputs(assign_zero_to_too_long_val_examples, examples, input_column, input_prefix_column,
1092 | max_source_length, padding, prefix, tokenizer, trim_very_long_strings, max_prefix_length,
1093 | prefix_sep, pad_prefix):
1094 | inputs = examples[input_column]
1095 |
1096 | # the given prefix is the one used in models like T5 (e.g. "summarize: ")
1097 | # if prefix exists, it is added to the input_prefixes
1098 | if input_prefix_column in examples.keys():
1099 | input_prefixes = [inp + prefix_sep for inp in examples[input_prefix_column]]
1100 | if prefix != "":
1101 | input_prefixes = [prefix + inp for inp in input_prefixes]
1102 | elif prefix != "":
1103 | inputs = [prefix + inp for inp in inputs]
1104 |
1105 | # tokenize the input prefix if it exists
1106 | model_prefix_inputs = None
1107 | if input_prefix_column in examples.keys():
1108 | if trim_very_long_strings:
1109 | input_prefixes = [inp[: max_prefix_length * 7] for inp in input_prefixes]
1110 | if pad_prefix:
1111 | model_prefix_inputs = tokenizer(input_prefixes, max_length=max_prefix_length, padding='max_length', truncation=True)
1112 | else:
1113 | # for led, we do not pad the prefix
1114 | model_prefix_inputs = tokenizer(input_prefixes, max_length=max_source_length, padding='do_not_pad', truncation=True)
1115 |
1116 | if trim_very_long_strings:
1117 | inputs = [inp[: max_source_length * 7] for inp in inputs]
1118 | model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
1119 |
1120 | if max_source_length is not None and assign_zero_to_too_long_val_examples:
1121 | model_inputs_untrimmed = tokenizer(inputs)
1122 | model_inputs["not_valid_for_eval"] = [
1123 | len(token_ids) > max_source_length for token_ids in model_inputs_untrimmed["input_ids"]
1124 | ]
1125 | else:
1126 | model_inputs["not_valid_for_eval"] = [False] * len(model_inputs["input_ids"])
1127 |
1128 | # now, concatenate the prefix to the input, trimming the result to max_source_length if given
1129 | if model_prefix_inputs is not None:
1130 | max_source_length = max_source_length or -1
1131 | model_inputs['input_ids'] = [(inp1+inp2)[:max_source_length] for inp1, inp2
1132 | in zip(model_prefix_inputs['input_ids'], model_inputs['input_ids'])]
1133 | model_inputs['attention_mask'] = [(inp1+inp2)[:max_source_length] for inp1, inp2
1134 | in zip(model_prefix_inputs['attention_mask'], model_inputs['attention_mask'])]
1135 | # add prefix_length
1136 | if pad_prefix:
1137 | # no need to go over them as they will all be of the same length
1138 | model_inputs['prefix_length'] = [max_prefix_length] * len(model_inputs['input_ids'])
1139 | else:
1140 | model_inputs['prefix_length'] = [len(inp) for inp in model_prefix_inputs['input_ids']]
1141 |
1142 | return model_inputs
1143 |
1144 | def _preprocess_targets(examples, ignore_pad_token_for_loss, max_target_length, model_inputs, output_column, padding, tokenizer):
1145 | targets = examples[output_column] if output_column in examples else None
1146 | if targets is not None:
1147 | if not isinstance(targets[0], str):
1148 | if max_target_length is not None:
1149 | targets = [target[:max_target_length] for target in targets]
1150 | model_inputs["labels"] = targets
1151 | else:
1152 | # Setup the tokenizer for targets
1153 | with tokenizer.as_target_tokenizer():
1154 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
1155 |
1156 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
1157 | # padding in the loss.
1158 | if padding == "max_length" and ignore_pad_token_for_loss:
1159 | labels["input_ids"] = [
1160 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
1161 | ]
1162 |
1163 | model_inputs["labels"] = labels["input_ids"]
1164 |
1165 | def load_extra_metrics(metric_names, loaded_metrics=None):
1166 | if loaded_metrics is None:
1167 | loaded_metrics = MetricCollection([])
1168 | if metric_names is not None:
1169 | for metric_name in metric_names:
1170 | if len(metric_name) > 0:
1171 | loaded_metrics._metrics.append(HFMetricWrapper(metric_name))
1172 | return loaded_metrics
1173 |
1174 | def _mp_fn(index):
1175 | # For xla_spawn (TPUs)
1176 | main()
1177 |
1178 |
1179 | if __name__ == "__main__":
1180 | main()
1181 |
--------------------------------------------------------------------------------
/src/run_generation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
18 | """
19 |
20 |
21 | import argparse
22 | import inspect
23 | import logging
24 |
25 | from dataclasses import dataclass, field
26 | from typing import Tuple, List, Optional, Union
27 |
28 | import numpy as np
29 | import torch
30 | import os
31 |
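    | # Debugging convenience: prefix every tensor's repr with its shape.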
32 | normal_repr = torch.Tensor.__repr__
33 | torch.Tensor.__repr__ = lambda self: f"{self.shape}_{normal_repr(self)}"
34 |
35 | from transformers import (
36 | AutoTokenizer,
37 | BloomForCausalLM,
38 | BloomTokenizerFast,
39 | CTRLLMHeadModel,
40 | CTRLTokenizer,
41 | GenerationMixin,
42 | GPT2LMHeadModel,
43 | GPT2Tokenizer,
44 | GPTJForCausalLM,
45 | HfArgumentParser,
46 | LlamaForCausalLM,
47 | LlamaTokenizer,
48 | OpenAIGPTLMHeadModel,
49 | OpenAIGPTTokenizer,
50 | OPTForCausalLM,
51 | TransfoXLLMHeadModel,
52 | TransfoXLTokenizer,
53 | XLMTokenizer,
54 | XLMWithLMHeadModel,
55 | XLNetLMHeadModel,
56 | XLNetTokenizer,
57 | TextStreamer,
58 | )
59 | from transformers.modeling_outputs import CausalLMOutputWithPast
60 |
61 | from unlimiformer import Unlimiformer
62 | from random_training_unlimiformer import RandomTrainingUnlimiformer
63 |
64 | @dataclass
65 | class UnlimiformerArguments:
66 | """
67 | Arguments controlling how Unlimiformer is applied at generation time: kNN attention layers, chunking, and datastore/index settings.
68 | """
69 | test_unlimiformer: Optional[bool] = field(
70 | default=False,
71 | metadata={
72 | "help": "whether to use KNN."
73 | },
74 | )
75 | unlimiformer_verbose: Optional[bool] = field(
76 | default=False,
77 | metadata={
78 | "help": "whether to print KNN intermediate predictions (mostly for debugging)."
79 | },
80 | )
81 | layer_begin: Optional[int] = field(
82 | default=0,
83 | metadata={"help": "The layer to begin applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. "
84 | "By default, it will be applied to all layers: [0:None]]"},
85 | )
86 | layer_end: Optional[int] = field(
87 | default=None,
88 | metadata={"help": "The layer to end applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. "
89 | "By default, it will be applied to all layers: [0:None]]"},
90 | )
91 | unlimiformer_chunk_overlap: Optional[float] = field(
92 | default=0.5,
93 | metadata={"help": "The fraction of overlap between input chunks"},
94 | )
95 | unlimiformer_chunk_size: Optional[int] = field(
96 | default=None,
97 | metadata={"help": "The size of each input chunk"},
98 | )
99 | unlimiformer_head_num: Optional[int] = field(
100 | default=None,
101 | metadata={"help": "The head to apply KNN to (if None, apply to all heads)"},
102 | )
103 | unlimiformer_exclude: Optional[bool] = field(
104 | default=False,
105 | metadata={
106 | "help": "If True, prioritize the inputs that are **not** in the standard attention window."
107 | },
108 | )
109 | random_unlimiformer_training: Optional[bool] = field(
110 | default=False,
111 | )
112 | unlimiformer_training: Optional[bool] = field(
113 | default=False,
114 | )
115 | index_devices: Optional[List[int]] = field(
116 | default_factory=lambda: (0,),
117 | )
118 | datastore_device: Optional[int] = field(
119 | default=0,
120 | )
121 | use_datastore: Optional[bool] = field(default=True)
122 | flat_index: Optional[bool] = field(default=True)
123 | test_datastore: Optional[bool] = field(default=False)
124 | reconstruct_embeddings: Optional[bool] = field(default=False)
125 | gpu_datastore: Optional[bool] = field(default=True)
126 | gpu_index: Optional[bool] = field(default=True)
127 |
128 |
129 | logging.basicConfig(
130 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
131 | datefmt="%m/%d/%Y %H:%M:%S",
132 | level=logging.INFO,
133 | )
134 | logger = logging.getLogger(__name__)
135 |
136 | MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
137 |
138 | MODEL_CLASSES = {
139 | "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
140 | "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
141 | "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
142 | "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
143 | "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
144 | "xlm": (XLMWithLMHeadModel, XLMTokenizer),
145 | "gptj": (GPTJForCausalLM, AutoTokenizer),
146 | "bloom": (BloomForCausalLM, BloomTokenizerFast),
147 | "llama": (LlamaForCausalLM, LlamaTokenizer),
148 | "opt": (OPTForCausalLM, GPT2Tokenizer),
149 | }
150 |
151 | # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
152 | # in https://github.com/rusiaaman/XLNet-gen#methodology
153 | # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
154 | PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
155 | (except for Alexei and Maria) are discovered.
156 | The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
157 | remainder of the story. 1883 Western Siberia,
158 | a young Grigori Rasputin is asked by his father and a group of men to perform magic.
159 | Rasputin has a vision and denounces one of the men as a horse thief. Although his
160 | father initially slaps him for making such an accusation, Rasputin watches as the
161 | man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
162 | the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
163 | with people, even a bishop, begging for his blessing. """
164 |
165 |
166 | def set_seed(args):
167 | np.random.seed(args.seed)
168 | torch.manual_seed(args.seed)
169 | if args.n_gpu > 0:
170 | torch.cuda.manual_seed_all(args.seed)
171 |
172 |
173 | #
174 | # Functions to prepare models' input
175 | #
176 |
177 |
178 | def prepare_ctrl_input(args, _, tokenizer, prompt_text):
179 | if args.temperature > 0.7:
180 | logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
181 |
182 | encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
183 | if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
184 | logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
185 | return prompt_text
186 |
187 |
188 | def prepare_xlm_input(args, model, tokenizer, prompt_text):
189 | # kwargs = {"language": None, "mask_token_id": None}
190 |
191 | # Set the language
192 | use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
193 | if hasattr(model.config, "lang2id") and use_lang_emb:
194 | available_languages = model.config.lang2id.keys()
195 | if args.xlm_language in available_languages:
196 | language = args.xlm_language
197 | else:
198 | language = None
199 | while language not in available_languages:
200 | language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
201 |
202 | model.config.lang_id = model.config.lang2id[language]
203 | # kwargs["language"] = tokenizer.lang2id[language]
204 |
205 | # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
206 | # XLM masked-language modeling (MLM) models need masked token
207 | # is_xlm_mlm = "mlm" in args.model_name_or_path
208 | # if is_xlm_mlm:
209 | # kwargs["mask_token_id"] = tokenizer.mask_token_id
210 |
211 | return prompt_text
212 |
213 |
214 | def prepare_xlnet_input(args, _, tokenizer, prompt_text):
215 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
216 | prompt_text = prefix + prompt_text
217 | return prompt_text
218 |
219 |
220 | def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
221 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
222 | prompt_text = prefix + prompt_text
223 | return prompt_text
224 |
225 |
226 | PREPROCESSING_FUNCTIONS = {
227 | "ctrl": prepare_ctrl_input,
228 | "xlm": prepare_xlm_input,
229 | "xlnet": prepare_xlnet_input,
230 | "transfo-xl": prepare_transfoxl_input,
231 | }
232 |
233 |
234 | def adjust_length_to_model(length, max_sequence_length):
235 | if length < 0 and max_sequence_length > 0:
236 | length = max_sequence_length
237 | elif 0 < max_sequence_length < length:
238 | length = max_sequence_length # No generation bigger than model size
239 | elif length < 0:
240 | length = MAX_LENGTH # avoid infinite loop
241 | return length
242 |
243 |
244 | def sparse_model_config(model_config):
245 | embedding_size = None
246 | if hasattr(model_config, "hidden_size"):
247 | embedding_size = model_config.hidden_size
248 | elif hasattr(model_config, "n_embed"):
249 | embedding_size = model_config.n_embed
250 | elif hasattr(model_config, "n_embd"):
251 | embedding_size = model_config.n_embd
252 |
253 | num_head = None
254 | if hasattr(model_config, "num_attention_heads"):
255 | num_head = model_config.num_attention_heads
256 | elif hasattr(model_config, "n_head"):
257 | num_head = model_config.n_head
258 |
259 | if embedding_size is None or num_head is None or num_head == 0:
260 | raise ValueError("Check the model config")
261 |
262 | num_embedding_size_per_head = int(embedding_size / num_head)
263 | if hasattr(model_config, "n_layer"):
264 | num_layer = model_config.n_layer
265 | elif hasattr(model_config, "num_hidden_layers"):
266 | num_layer = model_config.num_hidden_layers
267 | else:
268 | raise ValueError("Number of hidden layers couldn't be determined from the model config")
269 |
270 | return num_layer, num_head, num_embedding_size_per_head
271 |
272 |
273 | def generate_past_key_values(model, batch_size, seq_len):
274 | num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
275 | if model.config.model_type == "bloom":
276 | past_key_values = tuple(
277 | (
278 | torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len)
279 | .to(model.dtype)
280 | .to(model.device),
281 | torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head)
282 | .to(model.dtype)
283 | .to(model.device),
284 | )
285 | for _ in range(num_block_layers)
286 | )
287 | else:
288 | past_key_values = tuple(
289 | (
290 | torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
291 | .to(model.dtype)
292 | .to(model.device),
293 | torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
294 | .to(model.dtype)
295 | .to(model.device),
296 | )
297 | for _ in range(num_block_layers)
298 | )
299 | return past_key_values
300 |
301 |
302 | def prepare_jit_inputs(inputs, model, tokenizer):
303 | batch_size = len(inputs)
304 | dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
305 | dummy_input = dummy_input.to(model.device)
306 | if model.config.use_cache:
307 | dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1)
308 | dummy_input["attention_mask"] = torch.cat(
309 | [
310 | torch.zeros(dummy_input["attention_mask"].shape[0], 1)
311 | .to(dummy_input["attention_mask"].dtype)
312 | .to(model.device),
313 | dummy_input["attention_mask"],
314 | ],
315 | -1,
316 | )
317 | return dummy_input
318 |
319 |
320 | class _ModelFallbackWrapper(GenerationMixin):
321 | __slots__ = ("_optimized", "_default")
322 |
323 | def __init__(self, optimized, default):
324 | self._optimized = optimized
325 | self._default = default
326 |
327 | def __call__(self, *args, **kwargs):
328 | if kwargs["past_key_values"] is None and self._default.config.use_cache:
329 | kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0)
330 | kwargs.pop("position_ids", None)
331 | for k in list(kwargs.keys()):
332 | if kwargs[k] is None or isinstance(kwargs[k], bool):
333 | kwargs.pop(k)
334 | outputs = self._optimized(**kwargs)
335 | lm_logits = outputs[0]
336 | past_key_values = outputs[1]
337 | fixed_output = CausalLMOutputWithPast(
338 | loss=None,
339 | logits=lm_logits,
340 | past_key_values=past_key_values,
341 | hidden_states=None,
342 | attentions=None,
343 | )
344 | return fixed_output
345 |
346 | def __getattr__(self, item):
347 | return getattr(self._default, item)
348 |
349 | def prepare_inputs_for_generation(
350 | self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
351 | ):
352 | return self._default.prepare_inputs_for_generation(
353 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
354 | )
355 |
356 | def _reorder_cache(
357 | self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
358 | ) -> Tuple[Tuple[torch.Tensor]]:
359 | """
360 | This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
361 | [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
362 | beam_idx at every generation step.
363 | """
364 | return self._default._reorder_cache(past_key_values, beam_idx)
365 |
366 |
367 | def main():
368 | parser = argparse.ArgumentParser()
369 | parser.add_argument(
370 | "--model_type",
371 | default=None,
372 | type=str,
373 | required=True,
374 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
375 | )
376 | parser.add_argument(
377 | "--model_name_or_path",
378 | default=None,
379 | type=str,
380 | required=True,
381 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
382 | )
383 |
384 | parser.add_argument("--prompt", type=str, default="")
385 | parser.add_argument("--length", type=int, default=100)
386 | parser.add_argument("--num_hidden_layers", type=int, default=None)
387 | parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
388 |
389 | parser.add_argument(
390 | "--temperature",
391 | type=float,
392 | default=1.0,
393 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
394 | )
395 | parser.add_argument(
396 | "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
397 | )
398 | parser.add_argument("--k", type=int, default=0)
399 | parser.add_argument("--p", type=float, default=0.9)
400 |
401 | parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
402 | parser.add_argument("--suffix", type=str, default="", help="Text added after the input.")
403 | parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
404 | parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
405 |
406 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
407 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
408 | parser.add_argument("--stream_output", action="store_true")
409 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
410 | parser.add_argument(
411 | "--fp16",
412 | action="store_true",
413 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
414 | )
415 | parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
416 |
417 | # args = parser.parse_args()
418 | args, unknown_args = parser.parse_known_args()
419 |
420 | hf_parser = HfArgumentParser(UnlimiformerArguments)
421 | unlimiformer_args, unknown_unlimiformer_args = hf_parser.parse_known_args()
422 |
423 | if len(set(unknown_args) & set(unknown_unlimiformer_args)) > 0:
424 | raise ValueError(f"Unknown arguments detected: {set(unknown_args) & set(unknown_unlimiformer_args)}")
425 |
426 | args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
427 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
428 |
429 | logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")
430 |
431 | set_seed(args)
432 |
433 | # Initialize the model and tokenizer
434 | try:
435 | args.model_type = args.model_type.lower()
436 | model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
437 | except KeyError:
438 | raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
439 |
440 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
441 | if tokenizer.pad_token is None:
442 | tokenizer.pad_token = tokenizer.eos_token
443 | model_kwargs = {}
444 | if args.num_hidden_layers is not None:
445 | model_kwargs["num_hidden_layers"] = args.num_hidden_layers
446 | model = model_class.from_pretrained(args.model_name_or_path, **model_kwargs)
447 |
448 | if args.fp16:
449 | model.half()
450 | model.to(args.device)
451 |
452 | max_seq_length = getattr(model.config, "max_position_embeddings", 0)
453 | args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length)
454 | logger.info(args)
455 |
456 | if unlimiformer_args.test_unlimiformer:
457 | unlimiformer_kwargs = {
458 | 'layer_begin': unlimiformer_args.layer_begin,
459 | 'layer_end': unlimiformer_args.layer_end,
460 | 'unlimiformer_head_num': unlimiformer_args.unlimiformer_head_num,
461 | 'exclude_attention': unlimiformer_args.unlimiformer_exclude,
462 | 'chunk_overlap': unlimiformer_args.unlimiformer_chunk_overlap,
463 | 'model_encoder_max_len': unlimiformer_args.unlimiformer_chunk_size,
464 | 'verbose': unlimiformer_args.unlimiformer_verbose, 'tokenizer': tokenizer,
465 | 'unlimiformer_training': unlimiformer_args.unlimiformer_training,
466 | 'use_datastore': unlimiformer_args.use_datastore,
467 | 'flat_index': unlimiformer_args.flat_index,
468 | 'test_datastore': unlimiformer_args.test_datastore,
469 | 'reconstruct_embeddings': unlimiformer_args.reconstruct_embeddings,
470 | 'gpu_datastore': unlimiformer_args.gpu_datastore,
471 | 'gpu_index': unlimiformer_args.gpu_index,
472 | 'index_devices': unlimiformer_args.index_devices,
473 | 'datastore_device': unlimiformer_args.datastore_device,
474 | }
475 | if unlimiformer_args.random_unlimiformer_training:
476 | model = RandomTrainingUnlimiformer.convert_model(model, **unlimiformer_kwargs)
477 | else:
478 | model = Unlimiformer.convert_model(model, **unlimiformer_kwargs)
479 |
480 | prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
481 | # Check if prompt_text is a valid file name:
482 | if os.path.exists(prompt_text):
483 | with open(prompt_text, "r") as f:
484 | prompt_text = f.read()
485 |
486 | # Different models need different input formatting and/or extra arguments
487 | requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
488 | if requires_preprocessing:
489 | prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
490 | preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
491 |
492 | if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
493 | tokenizer_kwargs = {"add_space_before_punct_symbol": True}
494 | else:
495 | tokenizer_kwargs = {}
496 |
497 | encoded_prompt = tokenizer.encode(
498 | preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
499 | )
500 | else:
501 | # prefix = args.prefix if args.prefix else args.padding_text
502 | prompt_text = f'{args.prefix}{prompt_text}{args.suffix}'
503 | encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
504 |
505 | if not unlimiformer_args.test_unlimiformer:
506 | encoded_prompt = encoded_prompt[:, -2048:]
507 | encoded_prompt = encoded_prompt.to(args.device)
508 |
509 | if encoded_prompt.size()[-1] == 0:
510 | input_ids = None
511 | else:
512 | input_ids = encoded_prompt
513 |
514 | if args.jit:
515 | jit_input_texts = ["enable jit"]
516 | jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
517 | torch._C._jit_set_texpr_fuser_enabled(False)
518 | model.config.return_dict = False
519 | if hasattr(model, "forward"):
520 | sig = inspect.signature(model.forward)
521 | else:
522 | sig = inspect.signature(model.__call__)
523 | jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
524 | traced_model = torch.jit.trace(model, jit_inputs, strict=False)
525 | traced_model = torch.jit.freeze(traced_model.eval())
526 | traced_model(*jit_inputs)
527 | traced_model(*jit_inputs)
528 |
529 | model = _ModelFallbackWrapper(traced_model, model)
530 |
531 | model.eval()
532 | output_sequences = model.generate(
533 | input_ids=input_ids,
534 | # max_length=args.length + len(encoded_prompt[0]),
535 | max_new_tokens=args.length,
536 | temperature=args.temperature,
537 | top_k=args.k,
538 | top_p=args.p,
539 | repetition_penalty=args.repetition_penalty,
540 | do_sample=True,
541 | num_return_sequences=args.num_return_sequences,
542 | streamer=TextStreamer(tokenizer, skip_prompt=True) if args.stream_output else None,
543 | )
544 |
545 | # Remove the batch dimension when returning multiple sequences
546 | if len(output_sequences.shape) > 2:
547 | output_sequences.squeeze_()
548 |
549 | generated_sequences = []
550 |
551 | for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
552 | print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} (input length: {input_ids.shape[-1]}) ===")
553 | generated_sequence = generated_sequence.tolist()
554 | # generated_sequence = generated_sequence[len(encoded_prompt[0]):] + tokenizer.encode(' ') + generated_sequence[:len(encoded_prompt[0])]
555 |
556 | # Decode text
557 | # text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
558 | prompt_length = min(input_ids.shape[-1], model.unlimiformer.window_size()) if unlimiformer_args.test_unlimiformer else input_ids.shape[-1]
559 | completion = tokenizer.decode(generated_sequence[prompt_length:])
560 |
561 | # Remove all text after the stop token
562 | # text = text[: text.find(args.stop_token) if args.stop_token else None]
563 |
564 | # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
565 | total_sequence = (
566 | # prompt_text +
567 | '|||' + completion
568 | )
569 |
570 | generated_sequences.append(total_sequence)
571 | print(total_sequence)
572 |
573 | return generated_sequences
574 |
575 |
576 | if __name__ == "__main__":
577 | main()
578 |
--------------------------------------------------------------------------------
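A minimal sketch of the inference path that run_generation.py wires together, assuming a GPT-2 checkpoint and a long plain-text prompt file. The keyword arguments mirror the unlimiformer_kwargs dict built in main(); the flag values are only illustrative, and the gpu_* flags assume a CUDA device is available.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from unlimiformer import Unlimiformer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Mirrors the unlimiformer_kwargs built in main(); values are illustrative defaults.
model = Unlimiformer.convert_model(
    model,
    layer_begin=0, layer_end=None,
    unlimiformer_head_num=None, exclude_attention=False,
    chunk_overlap=0.5, model_encoder_max_len=None,
    verbose=False, tokenizer=tokenizer,
    unlimiformer_training=False,
    use_datastore=True, flat_index=True, test_datastore=False,
    reconstruct_embeddings=False, gpu_datastore=True, gpu_index=True,
    index_devices=(0,), datastore_device=0,
)
model.eval()

with open("long_prompt.txt") as f:  # hypothetical path to a long plain-text prompt
    prompt = f.read()
input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

with torch.no_grad():
    output = model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=True, top_p=0.9)

# As in main(), the prompt kept in the generated sequence is capped at the model's window size.
prompt_length = min(input_ids.shape[-1], model.unlimiformer.window_size())
print(tokenizer.decode(output[0][prompt_length:]))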
/src/usage.py:
--------------------------------------------------------------------------------
1 | from unlimiformer import Unlimiformer
2 | from random_training_unlimiformer import RandomTrainingUnlimiformer
3 |
4 | from dataclasses import dataclass, field
5 | from typing import List, Optional
6 |
7 |
8 | @dataclass
9 | class UnlimiformerArguments:
10 | """
11 | Arguments controlling how Unlimiformer is applied to the model: kNN attention layers, chunking, and datastore/index settings.
12 | """
13 | test_unlimiformer: Optional[bool] = field(
14 | default=True,
15 | metadata={
16 | "help": "whether to use KNN."
17 | },
18 | )
19 | unlimiformer_verbose: Optional[bool] = field(
20 | default=False,
21 | metadata={
22 | "help": "whether to print KNN intermediate predictions (mostly for debugging)."
23 | },
24 | )
25 | layer_begin: Optional[int] = field(
26 | default=0,
27 | metadata={"help": "The layer to begin applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. "
28 | "By default, it will be applied to all layers: [0:None]]"},
29 | )
30 | layer_end: Optional[int] = field(
31 | default=None,
32 | metadata={"help": "The layer to end applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. "
33 | "By default, it will be applied to all layers: [0:None]]"},
34 | )
35 | unlimiformer_chunk_overlap: Optional[float] = field(
36 | default=0.5,
37 | metadata={"help": "The fraction of overlap between input chunks"},
38 | )
39 | unlimiformer_chunk_size: Optional[int] = field(
40 | default=None,
41 | metadata={"help": "The size of each input chunk"},
42 | )
43 | unlimiformer_head_num: Optional[int] = field(
44 | default=None,
45 | metadata={"help": "The head to apply KNN to (if None, apply to all heads)"},
46 | )
47 | unlimiformer_exclude: Optional[bool] = field(
48 | default=False,
49 | metadata={
50 | "help": "If True, prioritize the inputs that are **not** in the standard attention window."
51 | },
52 | )
53 | random_unlimiformer_training: Optional[bool] = field(
54 | default=False,
55 | )
56 | unlimiformer_training: Optional[bool] = field(
57 | default=False,
58 | )
59 | use_datastore: Optional[bool] = field(default=False)
60 | flat_index: Optional[bool] = field(default=False)
61 | test_datastore: Optional[bool] = field(default=False)
62 | reconstruct_embeddings: Optional[bool] = field(default=False)
63 | gpu_datastore: Optional[bool] = field(default=True)
64 | gpu_index: Optional[bool] = field(default=True)
65 |
66 |
67 |
68 | # include these lines in your code somewhere before model training; this snippet assumes `unlimiformer_args`, `model`, and `tokenizer` are already defined in the surrounding training script
69 | def training_addin():
70 | if unlimiformer_args.test_unlimiformer:
71 | unlimiformer_kwargs = {
72 | 'layer_begin': unlimiformer_args.layer_begin,
73 | 'layer_end': unlimiformer_args.layer_end,
74 | 'unlimiformer_head_num': unlimiformer_args.unlimiformer_head_num,
75 | 'exclude_attention': unlimiformer_args.unlimiformer_exclude,
76 | 'chunk_overlap': unlimiformer_args.unlimiformer_chunk_overlap,
77 | 'model_encoder_max_len': unlimiformer_args.unlimiformer_chunk_size,
78 | 'verbose': unlimiformer_args.unlimiformer_verbose, 'tokenizer': tokenizer,
79 | 'unlimiformer_training': unlimiformer_args.unlimiformer_training,
80 | 'use_datastore': unlimiformer_args.use_datastore,
81 | 'flat_index': unlimiformer_args.flat_index,
82 | 'test_datastore': unlimiformer_args.test_datastore,
83 | 'reconstruct_embeddings': unlimiformer_args.reconstruct_embeddings,
84 | 'gpu_datastore': unlimiformer_args.gpu_datastore,
85 | 'gpu_index': unlimiformer_args.gpu_index
86 | }
87 | if unlimiformer_args.random_unlimiformer_training:
88 | model = RandomTrainingUnlimiformer.convert_model(model, **unlimiformer_kwargs)
89 | else:
90 | model = Unlimiformer.convert_model(model, **unlimiformer_kwargs)
91 |
92 |
--------------------------------------------------------------------------------
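training_addin above assumes that unlimiformer_args, model, and tokenizer already exist in the surrounding training script. A small sketch of how unlimiformer_args could be obtained, mirroring the HfArgumentParser pattern used in run_generation.py; it assumes this runs in the same module as the UnlimiformerArguments dataclass:

from transformers import HfArgumentParser

# Parse only the Unlimiformer flags; anything else on the command line is left
# for the rest of the training script to handle.
hf_parser = HfArgumentParser(UnlimiformerArguments)
unlimiformer_args, _unknown = hf_parser.parse_known_args()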
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abertsch72/unlimiformer/e38b0149488636da9528c7504d2befdfb76f6d98/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/config.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | def handle_args_to_ignore(args: List[str]):
5 | indices_to_remove = []
6 | for i, arg in enumerate(args):
7 | if "_ignore_" in arg:
8 | indices_to_remove.append(i)
9 | if not arg.startswith("-"):
10 | indices_to_remove.append(i - 1)
11 |
12 | for i in sorted(indices_to_remove, reverse=True):
13 | del args[i]
14 |
--------------------------------------------------------------------------------
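A small example of what handle_args_to_ignore does: a value tagged with "_ignore_" is dropped together with the flag immediately before it, while a flag that itself contains "_ignore_" is dropped on its own. Assuming the function is imported from this module:

args = ["--output_dir", "out", "--comment", "_ignore_debug_run", "--_ignore_scratch_flag"]
handle_args_to_ignore(args)
# args is now ["--output_dir", "out"]:
#   "_ignore_debug_run" and its preceding "--comment" were removed,
#   and "--_ignore_scratch_flag" was removed on its own.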
/src/utils/custom_hf_argument_parser.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import sys
4 | from typing import Tuple
5 |
6 | from transformers import HfArgumentParser
7 | from transformers.hf_argparser import DataClass
8 |
9 |
10 | class CustomHfArgumentParser(HfArgumentParser):
11 | def parse_dictionary_and_args(self) -> Tuple[DataClass, ...]:
12 | """
13 | Alternative helper method that first reads one or more json config files given at the start of the command line,
14 | then lets any explicitly passed command-line flags override them before populating the dataclass types.
15 | """
16 | args = []
17 | data = {}
18 | for i in range(1, len(sys.argv)):
19 | if not sys.argv[i].endswith('.json'):
20 | break
21 |
22 | with open(sys.argv[i]) as f:
23 | new_data = json.load(f)
24 | conflicting_keys = set(new_data.keys()).intersection(data.keys())
25 | if len(conflicting_keys) > 0:
26 | raise ValueError(f'There are conflicting keys in the config files: {conflicting_keys}')
27 | data.update(new_data)
28 |
29 | for k, v in data.items():
30 | # if any options were given explicitly through the CLA then they override anything defined in the config files
31 | if f'--{k}' in sys.argv:
32 | logging.info(f'While {k}={v} was given in a config file, a manual override was set through the CLA')
33 | continue
34 | args.extend(
35 | ["--" + k, *(v if isinstance(v, list) else [str(v)])]
36 | ) # add the file arguments first so that command line args have precedence
37 | args += sys.argv[i:]
38 |
39 | return self.parse_args_into_dataclasses(args=args, look_for_args_file=False)
--------------------------------------------------------------------------------
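A minimal sketch (with a hypothetical dataclass and file name, and assuming CustomHfArgumentParser is in scope) of the precedence parse_dictionary_and_args implements: JSON config files listed at the start of the command line are read first, and any flag passed explicitly afterwards wins.

import json
import sys
from dataclasses import dataclass, field

@dataclass
class DemoArguments:  # hypothetical stand-in for the real argument dataclasses
    learning_rate: float = field(default=1e-5)
    num_train_epochs: int = field(default=1)

with open("demo_config.json", "w") as f:
    json.dump({"learning_rate": 3e-5, "num_train_epochs": 10}, f)

# Simulates: python run.py demo_config.json --num_train_epochs 2
sys.argv = ["run.py", "demo_config.json", "--num_train_epochs", "2"]
(demo_args,) = CustomHfArgumentParser(DemoArguments).parse_dictionary_and_args()
# demo_args.learning_rate == 3e-5  (taken from the JSON file)
# demo_args.num_train_epochs == 2  (the explicit CLI flag overrides the file)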
/src/utils/custom_seq2seq_trainer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import os
4 | import time
5 | from collections import defaultdict
6 | from typing import Any, Dict, List, Optional, Tuple, Union
7 |
8 | import torch
9 | from datasets import Dataset
10 | from torch import nn
11 | from transformers.debug_utils import DebugOption
12 | from transformers.deepspeed import is_deepspeed_zero3_enabled
13 | from transformers.trainer_utils import speed_metrics
14 |
15 | from transformers.utils import logging
16 | from transformers import Seq2SeqTrainer, is_torch_tpu_available
17 |
18 | import gc
19 |
20 | if is_torch_tpu_available(check_device=False):
21 | import torch_xla.core.xla_model as xm
22 | import torch_xla.debug.metrics as met
23 |
24 |
25 | from utils.decoding import decode
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | def _clean_memory():
31 | gc.collect()
32 | torch.cuda.empty_cache()
33 |
34 | # This custom trainer is based on the trainer defined in https://github.com/huggingface/transformers/compare/main...eladsegal:public-transformers:scrolls
35 | class CustomTrainer(Seq2SeqTrainer):
36 | def __init__(
37 | self, *args, untokenized_eval_dataset=None, data_args=None, output_dir: Optional[str] = None, **kwargs
38 | ):
39 | super().__init__(*args, **kwargs)
40 | self._untokenized_eval_dataset = untokenized_eval_dataset
41 | self._max_length = data_args.val_max_target_length
42 | self._num_beams = data_args.num_beams
43 | self._output_dir = output_dir
44 | self._data_args = data_args
45 | self.mock_predictions_to_assign_zero_metric_score = self.tokenizer.encode("TOO_MANY_INPUT_TOKENS", return_tensors="np")[0]
46 |
47 | def prediction_step(
48 | self,
49 | model: nn.Module,
50 | inputs: Dict[str, Union[torch.Tensor, Any]],
51 | prediction_loss_only: bool,
52 | ignore_keys: Optional[List[str]] = None,
53 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
54 | """
55 | Perform an evaluation step on `model` using `inputs`.
56 |
57 | Subclass and override to inject custom behavior.
58 |
59 | Args:
60 | model (`nn.Module`):
61 | The model to evaluate.
62 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
63 | The inputs and targets of the model.
64 |
65 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
66 | argument `labels`. Check your model's documentation for all accepted arguments.
67 | prediction_loss_only (`bool`):
68 | Whether or not to return the loss only.
69 |
70 | Return:
71 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
72 | labels (each being optional).
73 | """
74 | if not ("labels" in inputs or 'decoder_input_ids' in inputs):
75 | if model.training:
76 | logger.warning('When computing loss, must give labels or decoder_input_ids. '
77 | 'If you only perform prediction, you can safely ignore this message')
78 | # This is an issue here because the input may be longer than the max-output length of the model,
79 | # and if nothing was given it will shift the input and use it to compute loss (and later discard it).
80 | # This may cause an indexing error when absolute embeddings are used (CUDA device side assert)
81 | inputs['decoder_input_ids'] = inputs['input_ids'][:,:2].clone() # dummy outputs
82 |
83 | if not self.args.predict_with_generate or prediction_loss_only:
84 | return super().prediction_step(
85 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
86 | )
87 |
88 | has_labels = "labels" in inputs
89 | inputs = self._prepare_inputs(inputs)
90 |
91 | # XXX: adapt synced_gpus for fairscale as well
92 | gen_kwargs = self._gen_kwargs.copy()
93 | gen_kwargs["max_length"] = (
94 | gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length
95 | )
96 | gen_kwargs["num_beams"] = (
97 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
98 | )
99 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
100 | gen_kwargs["synced_gpus"] = (
101 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
102 | )
103 |
104 | if "attention_mask" in inputs:
105 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
106 | if "global_attention_mask" in inputs:
107 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
108 |
109 | # --------------------- addition compared to the source file --------------------
110 | if 'prefix_length' in inputs:
111 | gen_kwargs['prefix_length'] = inputs['prefix_length']
112 | _clean_memory()
113 | # ------------------------------------------------------------------------------
114 |
115 | # prepare generation inputs
116 | # some encoder-decoder models can have varying encoder's and thus
117 | # varying model input names
118 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
119 | generation_inputs = inputs[self.model.encoder.main_input_name]
120 | else:
121 | generation_inputs = inputs[self.model.main_input_name]
122 |
123 | # Uri: to make sure we use cache even during mid-training evaluation, where this is disabled in general:
124 | gen_kwargs['use_cache'] = True
125 |
126 | generated_tokens = self.model.generate(
127 | generation_inputs,
128 | **gen_kwargs,
129 | )
130 | # --------------------- addition compared to the source file --------------------
131 | _clean_memory()
132 | # ------------------------------------------------------------------------------
133 | # in case the batch is shorter than max length, the output should be padded
134 | if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
135 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
136 |
137 | if has_labels:  # changed the order of the ifs here: there is no point running the model if there are no labels to compute the loss on
138 | with torch.no_grad():
139 | with self.compute_loss_context_manager():
140 | outputs = model(**inputs)
141 | if self.label_smoother is not None:
142 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
143 | else:
144 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
145 | else:
146 | loss = None
147 |
148 | if self.args.prediction_loss_only:
149 | return (loss, None, None)
150 |
151 | if has_labels:
152 | labels = inputs["labels"]
153 | if labels.shape[-1] < gen_kwargs["max_length"]:
154 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
155 | else:
156 | labels = None
157 |
158 | return (loss, generated_tokens, labels)
159 |
160 | @property
161 | def _restart_generator(self):
162 | if getattr(self, '_is_restart_generator', False):
163 | self._is_restart_generator = False
164 | return True
165 | return False
166 |
167 | def set_restart_generator(self):
168 | self._is_restart_generator = True
169 |
170 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
171 | sampler = super()._get_train_sampler()
172 | try:
173 | if self._restart_generator:
174 | sampler.generator.manual_seed(self._initial_seed)
175 | else:
176 | self._initial_seed = sampler.generator.initial_seed()
177 | except Exception as e:
178 | logger.warning(f'Cannot save or set the seed of the generator: {e}')
179 | return sampler
180 |
181 | def _post_process_function(self, untokenized_eval_dataset, predictions):
182 | id_to_prediction = {}
183 | id_to_label_ids = defaultdict(list)
184 |
185 | assert len(untokenized_eval_dataset) == len(self.eval_dataset)
186 |
187 | for i, (instance, not_valid_for_eval) in enumerate(zip(untokenized_eval_dataset, self.eval_dataset["not_valid_for_eval"])):
188 | if not_valid_for_eval:
189 | id_to_prediction[instance["id"]] = self.mock_predictions_to_assign_zero_metric_score
190 | else:
191 | id_to_prediction[instance["id"]] = predictions[i]
192 |
193 | if "outputs" in instance:
194 | id_to_label_ids[instance["id"]] = instance["outputs"]
195 | else:
196 | id_to_label_ids[instance["id"]].append(instance["output"])
197 |
198 | return id_to_prediction, id_to_label_ids
199 |
200 | def evaluate(
201 | self,
202 | eval_dataset: Optional[Dataset] = None,
203 | ignore_keys: Optional[List[str]] = None,
204 | metric_key_prefix: str = "eval",
205 | untokenized_eval_dataset: Optional[Dataset] = None,
206 | **gen_kwargs
207 | ) -> Dict[str, float]:
208 | """
209 | Run evaluation and returns metrics.
210 |
211 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
212 | (pass it to the init `compute_metrics` argument).
213 |
214 | You can also subclass and override this method to inject custom behavior.
215 |
216 | Args:
217 | eval_dataset (`Dataset`, *optional*):
218 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
219 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
220 | method.
221 | ignore_keys (`List[str]`, *optional*):
222 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
223 | gathering predictions.
224 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
225 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
226 | "eval_bleu" if the prefix is `"eval"` (default)
227 | max_length (`int`, *optional*):
228 | The maximum target length to use when predicting with the generate method.
229 | num_beams (`int`, *optional*):
230 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
231 | beam search.
232 | gen_kwargs:
233 | Additional `generate` specific kwargs.
234 |
235 | Returns:
236 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
237 | dictionary also contains the epoch number which comes from the training state.
238 | """
239 |
240 | gen_kwargs = gen_kwargs.copy()
241 | gen_kwargs["max_length"] = (
242 | gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
243 | )
244 | gen_kwargs["num_beams"] = (
245 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
246 | )
247 | self._gen_kwargs = gen_kwargs
248 |
249 | self._memory_tracker.start()
250 |
251 | eval_dataloader = self.get_eval_dataloader(eval_dataset)
252 | # ----------------------------------- Added -----------------------------------
253 | untokenized_eval_dataset = (
254 | self._untokenized_eval_dataset if untokenized_eval_dataset is None else untokenized_eval_dataset
255 | )
256 | compute_metrics = self.compute_metrics
257 | self.compute_metrics = None
258 | # -----------------------------------------------------------------------------
259 |
260 | start_time = time.time()
261 |
262 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
263 | try:
264 | output = eval_loop(
265 | eval_dataloader,
266 | description="Evaluation",
267 | # No point gathering the predictions if there are no metrics, otherwise we defer to
268 | # self.args.prediction_loss_only
269 | prediction_loss_only=None, # MODIFIED since we need the predictions
270 | ignore_keys=ignore_keys,
271 | metric_key_prefix=metric_key_prefix,
272 | )
273 | finally:
274 | # ----------------------------------- Added -----------------------------------
275 | # revert the compute metrics back
276 | self.compute_metrics = compute_metrics
277 | # -----------------------------------------------------------------------------
278 |
279 | # ----------------------------------- Added -----------------------------------
280 | # compute our metrics
281 | if output.predictions is not None:
282 | eval_preds = self._post_process_function(untokenized_eval_dataset, output.predictions)
283 |
284 | if self._output_dir is not None and self.is_world_process_zero():
285 | predictions = decode(eval_preds[0], self.tokenizer, self._data_args)
286 | output_prediction_file = os.path.join(
287 | self._output_dir, f"generated_predictions_eval_{self.state.global_step}.json"
288 | )
289 | with open(output_prediction_file, "w") as writer:
290 | json.dump(predictions, writer, indent=4)
291 |
292 | output_labels_file = os.path.join(
293 | self._output_dir, f"eval_labels.json"
294 | )
295 | if not os.path.isfile(output_labels_file):
296 | with open(output_labels_file, "w") as writer:
297 | json.dump(eval_preds[1], writer, indent=4)
298 |
299 | if self.compute_metrics is not None:
300 | output.metrics.update(self.compute_metrics(*eval_preds))
301 |
302 | # Prefix all keys with metric_key_prefix + '_'
303 | for key in list(output.metrics.keys()):
304 | if not key.startswith(f"{metric_key_prefix}_"):
305 | output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(key)
306 | # -----------------------------------------------------------------------------
307 |
308 | total_batch_size = self.args.eval_batch_size * self.args.world_size
309 | output.metrics.update(
310 | speed_metrics(
311 | metric_key_prefix,
312 | start_time,
313 | num_samples=output.num_samples,
314 | num_steps=math.ceil(output.num_samples / total_batch_size),
315 | )
316 | )
317 |
318 | self.log(output.metrics)
319 |
320 | if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
321 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
322 | xm.master_print(met.metrics_report())
323 |
324 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
325 |
326 | self._memory_tracker.stop_and_update_metrics(output.metrics)
327 |
328 | return output.metrics
329 |
--------------------------------------------------------------------------------
/src/utils/decoding.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import numpy as np
3 |
4 |
5 | def decode(id_to_something, tokenizer=None, data_args=None):
6 | decode_fn = None
7 | switch_case = None
8 | elem = next(iter(id_to_something.values()))
9 | if isinstance(elem, str):
10 | switch_case = -1
11 | decode_fn = lambda text: text.strip()
12 | elif isinstance(elem, list) and not isinstance(elem[0], int):
13 | if isinstance(elem[0], str):
14 | switch_case = 0
15 | decode_fn = lambda texts: [text.strip() for text in texts]
16 | else:
17 | switch_case = 1
18 | decode_fn = lambda token_ids_list: [
19 | text.strip()
20 | for text in partial(
21 | tokenizer.batch_decode, skip_special_tokens=True, clean_up_tokenization_spaces=True
22 | )(token_ids_list)
23 | ]
24 | else:
25 | switch_case = 2
26 | decode_fn = lambda token_ids: partial(
27 | tokenizer.decode, skip_special_tokens=True, clean_up_tokenization_spaces=True
28 | )(token_ids).strip()
29 |
30 | id_to_text = {}
31 | for id_, something in id_to_something.items():
32 | if switch_case == -1 or switch_case == 0:
33 | obj_to_decode = something
34 | else:
35 | if data_args is None:
36 | data_args = {}
37 | if not isinstance(data_args, dict):
38 | data_args = vars(data_args)
39 | if data_args.get("ignore_pad_token_for_loss", True):
40 | # Replace -100 in the token_ids as we can't decode them.
41 | if switch_case == 1:
42 | token_ids_list = something
43 | for i in range(len(token_ids_list)):
44 | token_ids_list[i] = _replace_padding(token_ids_list[i], tokenizer.pad_token_id)
45 | obj_to_decode = token_ids_list
46 | elif switch_case == 2:
47 | token_ids = something
48 | token_ids = _replace_padding(token_ids, tokenizer.pad_token_id)
49 | obj_to_decode = token_ids
50 | else:
51 | obj_to_decode = something
52 |
53 | id_to_text[id_] = decode_fn(obj_to_decode)
54 |
55 | return id_to_text
56 |
57 |
58 | def _replace_padding(token_ids: np.array, pad_token_id):
59 | return np.where(token_ids != -100, token_ids, pad_token_id)
60 |
--------------------------------------------------------------------------------
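Two small self-contained examples of decode and _replace_padding above: when the values are already strings, decode only strips whitespace; before decoding token ids, -100 label padding is mapped back to the pad token id. The import path follows the repo's own convention (running from src/).

import numpy as np
from utils.decoding import decode, _replace_padding

# Values are strings, so decode() only strips them.
print(decode({"q1": "  Paris  ", "q2": " 42 "}))
# -> {'q1': 'Paris', 'q2': '42'}

# -100 entries (the ignored label positions) are replaced by the pad token id before decoding.
print(_replace_padding(np.array([101, 2054, -100, -100]), pad_token_id=0))
# -> the two -100 entries become 0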
/src/utils/duplicates.py:
--------------------------------------------------------------------------------
1 | def drop_duplicates_in_input(untokenized_dataset):
2 | indices_to_keep = []
3 | id_to_idx = {}
4 | outputs = []
5 | for i, (id_, output) in enumerate(zip(untokenized_dataset["id"], untokenized_dataset["output"])):
6 | if id_ in id_to_idx:
7 | outputs[id_to_idx[id_]].append(output)
8 | continue
9 | indices_to_keep.append(i)
10 | id_to_idx[id_] = len(outputs)
11 | outputs.append([output])
12 | untokenized_dataset = untokenized_dataset.select(indices_to_keep).flatten_indices()
13 | untokenized_dataset = untokenized_dataset.remove_columns("output")
14 | untokenized_dataset = untokenized_dataset.add_column("outputs", outputs)
15 | return untokenized_dataset
16 |
--------------------------------------------------------------------------------
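A short example of drop_duplicates_in_input, assuming the datasets library and an import from this module: rows sharing an id are collapsed into one row whose "outputs" column collects all of their reference outputs.

from datasets import Dataset
from utils.duplicates import drop_duplicates_in_input

ds = Dataset.from_dict({
    "id": ["a", "a", "b"],
    "input": ["document A", "document A", "document B"],
    "output": ["reference 1", "reference 2", "reference 3"],
})
ds = drop_duplicates_in_input(ds)
# Two rows remain: id "a" with outputs ["reference 1", "reference 2"]
# and id "b" with outputs ["reference 3"]; the per-row "output" column is dropped.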
/src/utils/override_training_args.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 |
5 | import torch.cuda
6 | from transformers.utils import logging
7 |
8 | sys.path.insert(0, os.getcwd())
9 |
10 | from dataclasses import dataclass, field
11 |
12 | from transformers.trainer_utils import IntervalStrategy
13 | from transformers import Seq2SeqTrainingArguments
14 |
15 | logger = logging.get_logger('swed_logger')
16 |
17 | @dataclass
18 | class TrainingOverridesArguments(Seq2SeqTrainingArguments):
19 | """
20 | To use it, evaluation_strategy must be IntervalStrategy.STEPS
21 | """
22 | eval_steps_override: float = field(default=0, metadata={"help": "a fraction used to set eval_steps w.r.t. the number of steps in "
23 | "a single epoch. Changes eval_steps. 0 to disable (default)"})
24 | save_steps_override: float = field(default=0, metadata={"help": "a fraction used to set save_steps w.r.t. the number of steps in "
25 | "a single epoch. Changes save_steps. Must be a multiple of eval_steps"
26 | " (or eval_steps_override if given). 0 to disable (default)"})
27 |
28 | eval_fraction: float = field(default=1, metadata={
29 | "help": "A float in (0,1] that corresponds to how much of the eval set to use during evaluations "
30 | "(the same subset every time), or an integer >= 2 giving the absolute number of evaluation samples "
31 | "to use. 1 to disable it and use the entire eval set"})
32 |
33 | use_auth_token: bool = field(
34 | default=False,
35 | metadata={
36 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
37 | "with private models). If AUTH_TOKEN is set as an environment variable, would use that"
38 | },
39 | )
40 |
41 | fp16_padding: bool = field(
42 | default=False,
43 | metadata={"help": "Whether to use padding for fp16"},
44 | )
45 |
46 |
47 | def __post_init__(self):
48 | super(TrainingOverridesArguments, self).__post_init__()
49 | if self.eval_steps_override > 0 or self.save_steps_override > 0:
50 | if self.evaluation_strategy != IntervalStrategy.STEPS:
51 | raise ValueError(
52 | f"using eval/save steps override requires evaluation strategy to be {IntervalStrategy.STEPS}"
53 | )
54 | if self.save_steps_override == 0 or self.eval_steps_override == 0:
55 | raise ValueError(
56 | f"using eval/save steps override requires both overrides to be non zero"
57 | )
58 | diff = (self.save_steps_override / self.eval_steps_override) % 1
59 | if min(1-diff, diff) > 1e-5: # we do it like that to support fractions modulo as well, with loss of precision
60 | raise ValueError(
61 | f"using eval/save steps override requires save steps override to be a multiple of eval_steps_override"
62 | )
63 | if self.use_auth_token and 'AUTH_TOKEN' in os.environ:
64 | self.use_auth_token = os.getenv('AUTH_TOKEN')
65 |
66 | @property
67 | def effective_batch_size(self):
68 | if not hasattr(self, '_ebs'):
69 | n_gpu = self.n_gpu if torch.cuda.is_available() else 1 # may be on cpu
70 | self._ebs = self.per_device_train_batch_size * self.gradient_accumulation_steps * n_gpu
71 | logger.warning(f'Training with {self.per_device_train_batch_size} per_device_train_batch_size, {self.n_gpu} gpus and '
72 | f'{self.gradient_accumulation_steps} gradient accumulation steps, resulting in {self._ebs} effective batch size')
73 | return self._ebs
74 |
75 | def apply_overrides(self, dataset_size):
76 | # Uri: overrides are currently disabled; return early instead of adjusting eval/save steps
77 | return
78 |
79 | if self.eval_steps_override == 0:
80 | return
81 | es, ss = self.eval_steps, self.save_steps
82 | total_steps_per_epoch = dataset_size / self.effective_batch_size # note that this may not be an integer
83 | eval_steps = int(total_steps_per_epoch * self.eval_steps_override)
84 | if eval_steps >= self.logging_steps:
85 | if eval_steps % self.logging_steps != 0:
86 | logger.warning(f'Eval steps override would result in eval every {eval_steps} steps, but it is not a '
87 | f'multiple of logging steps ({self.logging_steps}) so changing to '
88 | f'{eval_steps + self.logging_steps - eval_steps % self.logging_steps}')
89 | eval_steps = eval_steps + self.logging_steps - eval_steps % self.logging_steps
90 | elif eval_steps < self.logging_steps:
91 | logger.warning(f'Eval steps override would result in eval every {eval_steps} steps, but it is not a '
92 | f'multiple of logging steps ({self.logging_steps}) so changing to {self.logging_steps}')
93 | eval_steps = self.logging_steps
94 | self.eval_steps = eval_steps
95 |
96 | save_steps = int(total_steps_per_epoch * self.save_steps_override)
97 | if save_steps < eval_steps or save_steps % eval_steps != 0:
98 | logger.warning(f'Save steps override would result in eval every {save_steps} steps, but it is not a '
99 | f'multiple of eval steps ({eval_steps}) so changing to '
100 | f'{save_steps + eval_steps - save_steps % self.eval_steps}')
101 | save_steps = save_steps + eval_steps - save_steps % self.eval_steps
102 | self.save_steps = save_steps
103 |
104 | logger.warning(f'Using overrides with dataset of size {dataset_size} and effective batch size of '
105 | f'{self.effective_batch_size}, moving from (eval_steps, save_steps) '
106 | f'of {(es, ss)} to {(self.eval_steps, self.save_steps)}')
107 |
--------------------------------------------------------------------------------
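A worked example of the numbers involved, following the logic above (note that apply_overrides currently short-circuits it with the early return): with per_device_train_batch_size=2, gradient_accumulation_steps=8 and 4 GPUs, effective_batch_size is 2 * 8 * 4 = 64. For a dataset of 6,400 examples that gives 100 optimizer steps per epoch, so eval_steps_override=0.5 would map to eval_steps=50 (rounded up to a multiple of logging_steps if needed) and save_steps_override=1.0 would map to save_steps=100.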