├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── example_inputs ├── harry_potter.txt └── harry_potter_full.txt ├── requirements.txt └── src ├── __init__.py ├── configs ├── data │ ├── contract_nli.json │ ├── gov_report.json │ ├── hotpotqa.json │ ├── hotpotqa_second_only.json │ ├── narative_qa.json │ ├── qasper.json │ ├── qmsum.json │ ├── quality.json │ ├── squad.json │ ├── squad_ordered_distractors.json │ ├── squad_shuffled_distractors.json │ └── summ_screen_fd.json ├── model │ ├── bart_base_sled.json │ ├── bart_large_sled.json │ └── primera_govreport_sled.json └── training │ └── base_training_args.json ├── index_building.py ├── inference-example.py ├── metrics ├── __init__.py └── metrics.py ├── random_training_unlimiformer.py ├── run.py ├── run_generation.py ├── unlimiformer.py ├── usage.py └── utils ├── __init__.py ├── config.py ├── custom_hf_argument_parser.py ├── custom_seq2seq_trainer.py ├── decoding.py ├── duplicates.py └── override_training_args.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .vscode/ 131 | wandb/ 132 | output* 133 | local* 134 | *.pdf -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | @article{bertsch2023unlimiformer, 2 | title={Unlimiformer: Long-Range Transformers with Unlimited Length Input}, 3 | author={Bertsch, Amanda and Alon, Uri and Neubig, Graham and Gormley, Matthew R}, 4 | journal={arXiv preprint arXiv:2305.01625}, 5 | year={2023} 6 | } 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Amanda Bertsch, Uri Alon, Graham Neubig, and Matthew R. Gormley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unlimiformer: Long-Range Transformers with Unlimited Length Input (NeurIPS 2023) 2 | ![unlimiformer_diagram3_with_overlaps](https://github.com/abertsch72/unlimiformer/assets/15002544/55c5e623-b4de-48a5-b717-fe6ead95e66c) 3 | 4 | This is the official implementation of the paper: 5 | 6 | [Amanda Bertsch](https://www.cs.cmu.edu/~abertsch/), [Uri Alon](https://urialon.ml/), [Graham Neubig](http://www.phontron.com/), and [Matthew R. Gormley](http://www.cs.cmu.edu/~mgormley/): 7 | [Unlimiformer: Long-Range Transformers with Unlimited Length Input](https://arxiv.org/pdf/2305.01625) (to appear in **NeurIPS 2023**) 8 | 9 | Unlimiformer is a method for augmenting pretrained encoder-decoder models with retrieval-based attention, without changing the mathematical definition of attention. 10 | This allows the use of unlimited length inputs with any pretrained encoder-decoder! 
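To give a sense of the idea, here is a minimal, self-contained sketch of top-k (k-nearest-neighbor) cross-attention. It is only an illustration, not the repository's implementation (which hooks into the pretrained model's attention layers and retrieves from a datastore of encoder hidden states); all names and shapes below are illustrative.

```python
import torch

def knn_cross_attention(query, keys, values, k=16):
    """Standard dot-product attention, restricted to the top-k retrieved keys.

    query: (num_heads, head_dim) for a single decoding step
    keys, values: (source_len, num_heads, head_dim) hidden states of the full long input
    """
    scores = torch.einsum('hd,shd->hs', query, keys)          # one score per head per input token
    top_scores, top_idx = scores.topk(k, dim=-1)              # keep only the k highest-scoring tokens per head
    probs = torch.softmax(top_scores / keys.shape[-1] ** 0.5, dim=-1)
    top_values = values.permute(1, 0, 2).gather(              # gather the matching values: (num_heads, k, head_dim)
        1, top_idx.unsqueeze(-1).expand(-1, -1, values.shape[-1]))
    return torch.einsum('hk,hkd->hd', probs, top_values)      # the usual attention formula, over k keys only
```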
11 | See also our [**Tweet**](https://twitter.com/abertsch72/status/1654110919977324545?s=20). 12 | 13 | Unlimiformer can be used to improve the performance of an already-trained model. For best results, the model can be trained with Unlimiformer training. 14 | 15 | If you have any questions about this work, please open a [GitHub issue](https://github.com/abertsch72/unlimiformer/issues) or email the authors at ```abertsch@cs.cmu.edu, ualon@cs.cmu.edu``` 16 | 17 | ## **_October 2023_** - Unlimiformer will appear at NeurIPS 2023! 18 | 19 | ## **_August 2023_** - Unlimiformer now supports **Llama-2** (and all its derivatives)! 20 | To prompt Llama-2 with extremely long inputs, for example, the content of an *entire book*, use: 21 | ```bash 22 | python src/run_generation.py --model_type llama --model_name_or_path meta-llama/Llama-2-13b-chat-hf \ 23 | --prefix "[INST] <<SYS>>\n You are a helpful assistant. Answer with detailed responses according to the entire instruction or question. \n<</SYS>>\n\n Summarize the following book: " \ 24 | --prompt example_inputs/harry_potter_full.txt \ 25 | --suffix " [/INST]" --test_unlimiformer --fp16 --length 200 --layer_begin 16 \ 26 | --index_devices 1 --datastore_device 1 27 | ``` 28 | * The final prompt will be a concatenation of the content of the flags: `--prefix`, `--prompt`, `--suffix`. 29 | * The flag `--prompt` may contain either a path to a text file (e.g., `example_inputs/harry_potter_full.txt`) or the concrete prompt string. 30 | * The flag `--test_unlimiformer` is required to enable Unlimiformer. 31 | * The flag `--length` determines the desired output length. 32 | * The flag `--layer_begin` determines the layer from which Unlimiformer will start to be applied. For example, if we set `--layer_begin 20`, the first 20 layers of the model will perform standard attention over the last `context_window_size` tokens of the prompt as usual, and the 21st layer and above will attend to the _entire long input_. From our initial experiments, the value of `--layer_begin` should be more than half of the total number of layers in the model, and tuning it dramatically changes the quality of the output. 33 | * The flags `--datastore_device N` and `--index_devices N1 N2 N3 ...` specify on which GPUs to store Unlimiformer's datastore and index (the base model will be stored on GPU #0). 34 | * Add the flag `--stream_output` to make the generated tokens appear one by one as they are generated. 35 | 36 | 37 | ## Getting Started 38 | 39 | ### General Instructions 40 | Copy the files from `src` into your source code folder. 41 | 42 | You'll need to set values for the Unlimiformer-specific arguments outlined in [`usage.py`](https://github.com/abertsch72/unlimiformer/blob/main/src/usage.py) - you can add these arguments wherever you usually process hyperparameters. To use the model, you must set `test_unlimiformer=True`. For datastore usage, the model must be in evaluation mode (e.g. call ```model.eval()``` before inference). 43 | 44 | [`inference-example.py`](https://github.com/abertsch72/unlimiformer/blob/main/src/inference-example.py) outlines a minimal example for running a sequence through an Unlimiformer model, using the default arguments. 45 | 46 | [`run.py`](https://github.com/abertsch72/unlimiformer/blob/main/src/run.py) is an example of a full training setup that integrates Unlimiformer, adapted from [SLED](https://github.com/Mivg/SLED). See full command lines below. 
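For orientation, the conversion step looks roughly like the following condensed sketch of [`inference-example.py`](https://github.com/abertsch72/unlimiformer/blob/main/src/inference-example.py). Here `long_document` is a placeholder for your own input text, and every Unlimiformer argument is left at its default from `usage.py`.

```python
import torch
from transformers import BartForConditionalGeneration, AutoTokenizer
from unlimiformer import Unlimiformer
from usage import UnlimiformerArguments

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("abertsch/unlimiformer-bart-govreport-earlyk")

defaults = UnlimiformerArguments()  # Unlimiformer-specific hyperparameters, all at their default values
model = Unlimiformer.convert_model(
    model,
    tokenizer=tokenizer,
    layer_begin=defaults.layer_begin,
    layer_end=defaults.layer_end,
    unlimiformer_head_num=defaults.unlimiformer_head_num,
    exclude_attention=defaults.unlimiformer_exclude,
    chunk_overlap=defaults.unlimiformer_chunk_overlap,
    model_encoder_max_len=defaults.unlimiformer_chunk_size,
    verbose=defaults.unlimiformer_verbose,
    unlimiformer_training=defaults.unlimiformer_training,
    use_datastore=defaults.use_datastore,
    flat_index=defaults.flat_index,
    test_datastore=defaults.test_datastore,
    reconstruct_embeddings=defaults.reconstruct_embeddings,
    gpu_datastore=defaults.gpu_datastore,
    gpu_index=defaults.gpu_index,
)
model.eval()   # datastore usage requires evaluation mode
model.to(device)

long_document = "..."  # placeholder: the (arbitrarily long) input text
inputs = tokenizer(long_document, truncation=False, return_tensors="pt").to(device)
summary = tokenizer.batch_decode(model.generate(**inputs, max_length=512), skip_special_tokens=True)[0]
```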
47 | 48 | ### Reproducing the Experiments from the Paper - Command Lines 49 | 50 | To run a standard finetuning + evaluation of BART-base on the GovReport dataset (as an example), use: 51 | ```bash 52 | python src/run.py \ 53 | src/configs/training/base_training_args.json \ 54 | src/configs/data/gov_report.json \ 55 | --output_dir output_train_bart_base_local/ \ 56 | --learning_rate 1e-5 \ 57 | --model_name_or_path facebook/bart-base \ 58 | --max_source_length 1024 \ 59 | --eval_max_source_length 1024 --do_eval=True \ 60 | --eval_steps 1000 --save_steps 1000 \ 61 | --per_device_eval_batch_size 1 --per_device_train_batch_size 2 \ 62 | --extra_metrics bertscore 63 | ``` 64 | 65 | * To use Unlimiformer at **training** time (called "Retrieval training" in the paper), use: `--unlimiformer_training --max_source_length 16384` 66 | * In this case, you might want to use Unlimiformer at **test**/validation time as well, and also add: `--test_unlimiformer --eval_max_source_length 999999` 67 | * Alternatively, to use the computationally cheaper "Random-encoded" setting at **training** time, use `--random_unlimiformer_training --max_source_length 16384` 68 | * To alternate between "retrieval training" and "random-encoded training", use both flags: `--unlimiformer_training --random_unlimiformer_training --max_source_length 16384` 69 | 70 | For additional flags and options, see [`usage.py`](https://github.com/abertsch72/unlimiformer/blob/main/src/usage.py) 71 | 72 | 73 | ## Recommended settings 74 | 75 | ### To evaluate with Unlimiformer 76 | At evaluation time, we recommend the default value for each setting. 77 | 78 | ### To train with Unlimiformer 79 | For an inexpensive method, we recommend training as usual and using Unlimiformer during early stopping. To do so, set ```knn=True``` and leave all other values at default. 80 | 81 | 82 | For best performance, there are 3 expensive settings for training. The best one varies by dataset. 83 | 1. Set ```random_unlimiformer_training=True```: this is the *random-encoded training* setting from the paper 84 | 2. Set ```unlimiformer_training=True```: this is the *retrieval training* setting from the paper 85 | 3. Set ```random_unlimiformer_training=True``` AND ```unlimiformer_training=True```: this is the *alternating training* setting from the paper 86 | 87 | See Table 5 in the paper for a more detailed breakdown of relative training costs. 88 | 89 | ## Tips for very large inputs 90 | ### For training 91 | * you may need to truncate your inputs at training time, e.g. to 8k or 16k tokens. You can use the full inputs at evaluation time 92 | * you can also try splitting your inputs into 16k-token chunks and training on each one as its own example 93 | ### For evaluation (including early stopping) 94 | * if you're consistently running out of CUDA memory, set ```use_datastore=True``` to use a Faiss datastore to store hidden states. 95 | * if you're still having issues, set ```gpu_datastore=False``` or ```gpu_index=False```, but note that this will degrade performance 96 | 97 | ## Trained models 98 | The following models from the paper are available on Hugging Face. Please note that you must add the Unlimiformer-specific files to your repository, and load these models with ```test_unlimiformer=True```. 
*If you download these models from Hugging Face, they may not use Unlimiformer by default!* 99 | 100 | ### Table 3: low-cost training methods 101 | | Dataset | Method | Hugging Face link | 102 | | ------------- | ------------- | ------------- | 103 | | GovReport | Baseline: BART-base | [abertsch/bart-base-govreport](https://huggingface.co/abertsch/bart-base-govreport) | 104 | | GovReport | BART-base + Unlimiformer early stopping | [abertsch/unlimiformer-bart-govreport-earlyk](https://huggingface.co/abertsch/unlimiformer-bart-govreport-earlyk) | 105 | | SummScreen | Baseline: BART-base | [abertsch/bart-base-summscreen](https://huggingface.co/abertsch/bart-base-summscreen) | 106 | | SummScreen | BART-base + Unlimiformer early stopping | [abertsch/unlimiformer-bart-summscreen-earlyk](https://huggingface.co/abertsch/unlimiformer-bart-summscreen-earlyk) | 107 | 108 | 109 | ### Table 4: Long-range training methods 110 | | Dataset | Method | Hugging Face link | 111 | | ------------- | ------------- | ------------- | 112 | | GovReport | BART + Unlimiformer (alternating training) | [abertsch/unlimiformer-bart-govreport-alternating](https://huggingface.co/abertsch/unlimiformer-bart-govreport-alternating) | 113 | | SummScreen | BART + Unlimiformer (retrieval training) | [abertsch/unlimiformer-bart-summscreen-retrieval](https://huggingface.co/abertsch/unlimiformer-bart-summscreen-retrieval) | 114 | 115 | ## Table 5: BookSum 116 | | Dataset | Method | Hugging Face link | 117 | | ------------- | ------------- | ------------- | 118 | | BookSum | Baseline: BART-base | [abertsch/bart-base-booksum](https://huggingface.co/abertsch/bart-base-booksum) | 119 | | BookSum | BART-base + Unlimiformer early stopping | [abertsch/unlimiformer-bart-booksum-earlyk](https://huggingface.co/abertsch/unlimiformer-bart-booksum-earlyk) | 120 | | Booksum | BART-base + Unlimiformer (random-encoding training) | [abertsch/unlimiformer-bart-booksum-random-encoding](https://huggingface.co/abertsch/unlimiformer-bart-booksum-random-encoding) | 121 | | Booksum | BART-base + Unlimiformer (alternating training) | [abertsch/unlimiformer-bart-booksum-alternating](https://huggingface.co/abertsch/unlimiformer-bart-booksum-alternating) | 122 | 123 | ## Results 124 | 125 | image 126 | image 127 | image 128 | 129 | 130 | ## Citation 131 | If you use our method or models, please cite [our paper](https://arxiv.org/abs/2305.01625): 132 | ``` 133 | @article{bertsch2023unlimiformer, 134 | title={Unlimiformer: Long-Range Transformers with Unlimited Length Input}, 135 | author={Bertsch, Amanda and Alon, Uri and Neubig, Graham and Gormley, Matthew R}, 136 | journal={arXiv preprint arXiv:2305.01625}, 137 | year={2023} 138 | } 139 | ``` 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /example_inputs/harry_potter.txt: -------------------------------------------------------------------------------- 1 | Harry Potter and the Sorcerer's Stone 2 | 3 | 4 | CHAPTER ONE 5 | 6 | THE BOY WHO LIVED 7 | 8 | Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say 9 | that they were perfectly normal, thank you very much. They were the last 10 | people you'd expect to be involved in anything strange or mysterious, 11 | because they just didn't hold with such nonsense. 12 | 13 | Mr. Dursley was the director of a firm called Grunnings, which made 14 | drills. He was a big, beefy man with hardly any neck, although he did 15 | have a very large mustache. Mrs. 
Dursley was thin and blonde and had 16 | nearly twice the usual amount of neck, which came in very useful as she 17 | spent so much of her time craning over garden fences, spying on the 18 | neighbors. The Dursleys had a small son called Dudley and in their 19 | opinion there was no finer boy anywhere. 20 | 21 | The Dursleys had everything they wanted, but they also had a secret, and 22 | their greatest fear was that somebody would discover it. They didn't 23 | think they could bear it if anyone found out about the Potters. Mrs. 24 | Potter was Mrs. Dursley's sister, but they hadn't met for several years; 25 | in fact, Mrs. Dursley pretended she didn't have a sister, because her 26 | sister and her good-for-nothing husband were as unDursleyish as it was 27 | possible to be. The Dursleys shuddered to think what the neighbors would 28 | say if the Potters arrived in the street. The Dursleys knew that the 29 | Potters had a small son, too, but they had never even seen him. This boy 30 | was another good reason for keeping the Potters away; they didn't want 31 | Dudley mixing with a child like that. 32 | 33 | When Mr. and Mrs. Dursley woke up on the dull, gray Tuesday our story 34 | starts, there was nothing about the cloudy sky outside to suggest that 35 | strange and mysterious things would soon be happening all over the 36 | country. Mr. Dursley hummed as he picked out his most boring tie for 37 | work, and Mrs. Dursley gossiped away happily as she wrestled a screaming 38 | Dudley into his high chair. 39 | 40 | None of them noticed a large, tawny owl flutter past the window. 41 | 42 | At half past eight, Mr. Dursley picked up his briefcase, pecked Mrs. 43 | Dursley on the cheek, and tried to kiss Dudley good-bye but missed, 44 | because Dudley was now having a tantrum and throwing his cereal at the 45 | walls. "Little tyke," chortled Mr. Dursley as he left the house. He got 46 | into his car and backed out of number four's drive. 47 | 48 | It was on the corner of the street that he noticed the first sign of 49 | something peculiar -- a cat reading a map. For a second, Mr. Dursley 50 | didn't realize what he had seen -- then he jerked his head around to 51 | look again. There was a tabby cat standing on the corner of Privet 52 | Drive, but there wasn't a map in sight. What could he have been thinking 53 | of? It must have been a trick of the light. Mr. Dursley blinked and 54 | stared at the cat. It stared back. As Mr. Dursley drove around the 55 | corner and up the road, he watched the cat in his mirror. It was now 56 | reading the sign that said Privet Drive -- no, looking at the sign; cats 57 | couldn't read maps or signs. Mr. Dursley gave himself a little shake and 58 | put the cat out of his mind. As he drove toward town he thought of 59 | nothing except a large order of drills he was hoping to get that day. 60 | 61 | But on the edge of town, drills were driven out of his mind by something 62 | else. As he sat in the usual morning traffic jam, he couldn't help 63 | noticing that there seemed to be a lot of strangely dressed people 64 | about. People in cloaks. Mr. Dursley couldn't bear people who dressed in 65 | funny clothes -- the getups you saw on young people! He supposed this 66 | was some stupid new fashion. He drummed his fingers on the steering 67 | wheel and his eyes fell on a huddle of these weirdos standing quite 68 | close by. They were whispering excitedly together. Mr. 
Dursley was 69 | enraged to see that a couple of them weren't young at all; why, that man 70 | had to be older than he was, and wearing an emerald-green cloak! The 71 | nerve of him! But then it struck Mr. Dursley that this was probably some 72 | silly stunt -- these people were obviously collecting for something... 73 | yes, that would be it. The traffic moved on and a few minutes later, Mr. 74 | Dursley arrived in the Grunnings parking lot, his mind back on drills. 75 | 76 | Mr. Dursley always sat with his back to the window in his office on the 77 | ninth floor. If he hadn't, he might have found it harder to concentrate 78 | on drills that morning. He didn't see the owls swoop ing past in broad 79 | daylight, though people down in the street did; they pointed and gazed 80 | open- mouthed as owl after owl sped overhead. Most of them had never 81 | seen an owl even at nighttime. Mr. Dursley, however, had a perfectly 82 | normal, owl-free morning. He yelled at five different people. He made 83 | several important telephone calls and shouted a bit more. He was in a 84 | very good mood until lunchtime, when he thought he'd stretch his legs 85 | and walk across the road to buy himself a bun from the bakery. 86 | 87 | He'd forgotten all about the people in cloaks until he passed a group of 88 | them next to the baker's. He eyed them angrily as he passed. He didn't 89 | know why, but they made him uneasy. This bunch were whispering 90 | excitedly, too, and he couldn't see a single collecting tin. It was on 91 | his way back past them, clutching a large doughnut in a bag, that he 92 | caught a few words of what they were saying. 93 | 94 | "The Potters, that's right, that's what I heard yes, their son, Harry" 95 | 96 | Mr. Dursley stopped dead. Fear flooded him. He looked back at the 97 | whisperers as if he wanted to say something to them, but thought better 98 | of it. 99 | 100 | He dashed back across the road, hurried up to his office, snapped at his 101 | secretary not to disturb him, seized his telephone, and had almost 102 | finished dialing his home number when he changed his mind. He put the 103 | receiver back down and stroked his mustache, thinking... no, he was 104 | being stupid. Potter wasn't such an unusual name. He was sure there were 105 | lots of people called Potter who had a son called Harry. Come to think 106 | of it, he wasn't even sure his nephew was called Harry. He'd never even 107 | seen the boy. It might have been Harvey. Or Harold. There was no point 108 | in worrying Mrs. Dursley; she always got so upset at any mention of her 109 | sister. He didn't blame her -- if he'd had a sister like that... but all 110 | the same, those people in cloaks... 111 | 112 | He found it a lot harder to concentrate on drills that afternoon and 113 | when he left the building at five o'clock, he was still so worried that 114 | he walked straight into someone just outside the door. 115 | 116 | "Sorry," he grunted, as the tiny old man stumbled and almost fell. It 117 | was a few seconds before Mr. Dursley realized that the man was wearing a 118 | violet cloak. He didn't seem at all upset at being almost knocked to the 119 | ground. On the contrary, his face split into a wide smile and he said in 120 | a squeaky voice that made passersby stare, "Don't be sorry, my dear sir, 121 | for nothing could upset me today! Rejoice, for You-Know-Who has gone at 122 | last! Even Muggles like yourself should be celebrating, this happy, 123 | happy day!" 124 | 125 | And the old man hugged Mr. 
Dursley around the 126 | middle and walked off. 127 | 128 | Mr. Dursley stood rooted to the spot. He had been 129 | hugged by a complete stranger. He also thought he 130 | had been called a Muggle, whatever that was. He was 131 | rattled. He hurried to his car and set off for home, 132 | hoping he was imagining things, which he had never 133 | hoped before, because he didn’t approve of 134 | imagination. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sentencepiece 2 | protobuf<=3.20.1 3 | nltk 4 | datasets>=1.17.0 5 | absl-py 6 | rouge-score 7 | pandas 8 | transformers>=4.27.0 9 | wandb 10 | makefun>=1.14.0 11 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abertsch72/unlimiformer/e38b0149488636da9528c7504d2befdfb76f6d98/src/__init__.py -------------------------------------------------------------------------------- /src/configs/data/contract_nli.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "contract_nli", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "pad_prefix": true, 7 | "generation_max_length": 8, 8 | "num_train_epochs": 20, 9 | "metric_names": ["exact_match"], 10 | "metric_for_best_model": "exact_match", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/gov_report.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "gov_report", 4 | "max_source_length": 16384, 5 | "generation_max_length": 1024, 6 | "max_prefix_length": 0, 7 | "pad_prefix": false, 8 | "num_train_epochs": 10, 9 | "metric_names": ["rouge"], 10 | "metric_for_best_model": "rouge/geometric_mean", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/hotpotqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "hotpotqa", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "pad_prefix": true, 7 | "generation_max_length": 128, 8 | "num_train_epochs": 9, 9 | "metric_names": ["f1", "exact_match"], 10 | "metric_for_best_model": "f1", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/hotpotqa_second_only.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "hotpotqa_second_only", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "pad_prefix": true, 7 | "generation_max_length": 128, 8 | "num_train_epochs": 9, 9 | "metric_names": ["f1", "exact_match"], 10 | "metric_for_best_model": "f1", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/narative_qa.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "narrative_qa", 4 | "max_source_length": 16384, 5 | 
"max_prefix_length": 64, 6 | "pad_prefix": true, 7 | "num_train_epochs": 2, 8 | "generation_max_length": 128, 9 | "metric_names": ["f1"], 10 | "metric_for_best_model": "f1", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/qasper.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "qasper", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "pad_prefix": true, 7 | "generation_max_length": 128, 8 | "num_train_epochs": 20, 9 | "metric_names": ["f1"], 10 | "metric_for_best_model": "f1", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/qmsum.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "qmsum", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "pad_prefix": true, 7 | "num_train_epochs": 20, 8 | "generation_max_length": 1024, 9 | "metric_names": ["rouge"], 10 | "metric_for_best_model": "rouge/geometric_mean", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/quality.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "quality", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 160, 6 | "pad_prefix": true, 7 | "num_train_epochs": 20, 8 | "generation_max_length": 128, 9 | "metric_names": ["exact_match"], 10 | "metric_for_best_model": "exact_match", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/squad.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "squad", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "pad_prefix": true, 7 | "num_train_epochs": 3, 8 | "generation_max_length": 128, 9 | "metric_names": ["f1", "exact_match"], 10 | "metric_for_best_model": "f1", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/squad_ordered_distractors.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "squad_ordered_distractors", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "pad_prefix": true, 7 | "num_train_epochs": 3, 8 | "generation_max_length": 128, 9 | "metric_names": ["f1", "exact_match"], 10 | "metric_for_best_model": "f1", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/data/squad_shuffled_distractors.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "squad_shuffled_distractors", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "pad_prefix": true, 7 | "num_train_epochs": 3, 8 | "generation_max_length": 128, 9 | "metric_names": ["f1", "exact_match"], 10 | "metric_for_best_model": "f1", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- 
/src/configs/data/summ_screen_fd.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "summ_screen_fd", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 0, 6 | "pad_prefix": false, 7 | "num_train_epochs": 10, 8 | "generation_max_length": 1024, 9 | "metric_names": ["rouge"], 10 | "metric_for_best_model": "rouge/geometric_mean", 11 | "greater_is_better": true 12 | } -------------------------------------------------------------------------------- /src/configs/model/bart_base_sled.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "tau/bart-base-sled", 3 | "use_auth_token": false, 4 | "max_target_length": 1024, 5 | "fp16": true 6 | } -------------------------------------------------------------------------------- /src/configs/model/bart_large_sled.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "tau/bart-large-sled", 3 | "use_auth_token": false, 4 | "max_target_length": 1024, 5 | "fp16": true 6 | } -------------------------------------------------------------------------------- /src/configs/model/primera_govreport_sled.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "tau/sled", 3 | "underlying_config": "allenai/PRIMERA", 4 | "context_size": 4096, 5 | "window_fraction": 0.5, 6 | "prepend_prefix": true, 7 | "encode_prefix": true, 8 | "sliding_method": "dynamic" 9 | } -------------------------------------------------------------------------------- /src/configs/training/base_training_args.json: -------------------------------------------------------------------------------- 1 | { 2 | "eval_steps_override": 0.5, 3 | "save_steps_override": 0.5, 4 | "evaluation_strategy": "steps", 5 | "eval_fraction": 1000, 6 | "predict_with_generate": true, 7 | "gradient_checkpointing": true, 8 | "do_train": true, 9 | "do_eval": true, 10 | "seed": 42, 11 | "warmup_ratio": 0.1, 12 | "save_total_limit": 2, 13 | "preprocessing_num_workers": 1, 14 | "load_best_model_at_end": true, 15 | "lr_scheduler": "linear", 16 | "adam_epsilon": 1e-6, 17 | "adam_beta1": 0.9, 18 | "adam_beta2": 0.98, 19 | "weight_decay": 0.001, 20 | "patience": 10, 21 | "extra_metrics": "bertscore" 22 | } -------------------------------------------------------------------------------- /src/index_building.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import faiss.contrib.torch_utils 3 | import time 4 | import logging 5 | 6 | import torch 7 | import numpy as np 8 | 9 | code_size = 64 10 | 11 | class DatastoreBatch(): 12 | def __init__(self, dim, batch_size, flat_index=False, gpu_index=False, verbose=False, index_device=None) -> None: 13 | self.indices = [] 14 | self.batch_size = batch_size 15 | self.device = index_device if index_device is not None else torch.device('cuda' if gpu_index else 'cpu') 16 | for i in range(batch_size): 17 | self.indices.append(Datastore(dim, use_flat_index=flat_index, gpu_index=gpu_index, verbose=verbose, device=self.device)) 18 | 19 | def move_to_gpu(self): 20 | for i in range(self.batch_size): 21 | self.indices[i].move_to_gpu() 22 | 23 | def add_keys(self, keys, num_keys_to_add_at_a_time=100000): 24 | for i in range(self.batch_size): 25 | self.indices[i].add_keys(keys[i], num_keys_to_add_at_a_time) 26 | 27 | def train_index(self, keys): 28 | for index, 
example_keys in zip(self.indices, keys): 29 | index.train_index(example_keys) 30 | 31 | def search(self, queries, k): 32 | found_scores, found_values = [], [] 33 | for i in range(self.batch_size): 34 | scores, values = self.indices[i].search(queries[i], k) 35 | found_scores.append(scores) 36 | found_values.append(values) 37 | return torch.stack(found_scores, dim=0), torch.stack(found_values, dim=0) 38 | 39 | def search_and_reconstruct(self, queries, k): 40 | found_scores, found_values = [], [] 41 | found_vectors = [] 42 | for i in range(self.batch_size): 43 | scores, values, vectors = self.indices[i].search_and_reconstruct(queries[i], k) 44 | found_scores.append(scores) 45 | found_values.append(values) 46 | found_vectors.append(vectors) 47 | return torch.stack(found_scores, dim=0), torch.stack(found_values, dim=0), torch.stack(found_vectors, dim=0) 48 | 49 | class Datastore(): 50 | def __init__(self, dim, use_flat_index=False, gpu_index=False, verbose=False, device=None) -> None: 51 | self.dimension = dim 52 | self.device = device if device is not None else torch.device('cuda' if gpu_index else 'cpu') 53 | self.logger = logging.getLogger('index_building') 54 | self.logger.setLevel(20) 55 | self.use_flat_index = use_flat_index 56 | self.gpu_index = gpu_index 57 | 58 | # Initialize faiss index 59 | # TODO: is preprocessing efficient enough to spend time on? 60 | if not use_flat_index: 61 | self.index = faiss.IndexFlatIP(self.dimension) # inner product index because we use IP attention 62 | 63 | # need to wrap in index ID map to enable add_with_ids 64 | # self.index = faiss.IndexIDMap(self.index) 65 | 66 | self.index_size = 0 67 | # if self.gpu_index: 68 | # self.move_to_gpu() 69 | 70 | def move_to_gpu(self): 71 | if self.use_flat_index: 72 | # self.keys = self.keys.to(self.device) 73 | return 74 | else: 75 | co = faiss.GpuClonerOptions() 76 | co.useFloat16 = True 77 | self.index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), self.device.index, self.index, co) 78 | 79 | def train_index(self, keys): 80 | if self.use_flat_index: 81 | self.add_keys(keys=keys, index_is_trained=True) 82 | else: 83 | keys = keys.cpu().float() 84 | ncentroids = int(keys.shape[0] / 128) 85 | self.index = faiss.IndexIVFPQ(self.index, self.dimension, 86 | ncentroids, code_size, 8) 87 | self.index.nprobe = min(32, ncentroids) 88 | # if not self.gpu_index: 89 | # keys = keys.cpu() 90 | 91 | self.logger.info('Training index') 92 | start_time = time.time() 93 | self.index.train(keys) 94 | self.logger.info(f'Training took {time.time() - start_time} s') 95 | self.add_keys(keys=keys, index_is_trained=True) 96 | # self.keys = None 97 | if self.gpu_index: 98 | self.move_to_gpu() 99 | 100 | def add_keys(self, keys, num_keys_to_add_at_a_time=1000000, index_is_trained=False): 101 | self.keys = keys 102 | if not self.use_flat_index and index_is_trained: 103 | start = 0 104 | while start < keys.shape[0]: 105 | end = min(len(keys), start + num_keys_to_add_at_a_time) 106 | to_add = keys[start:end] 107 | # if not self.gpu_index: 108 | # to_add = to_add.cpu() 109 | # self.index.add_with_ids(to_add, torch.arange(start+self.index_size, end+self.index_size)) 110 | self.index.add(to_add) 111 | self.index_size += end - start 112 | start += end 113 | if (start % 1000000) == 0: 114 | self.logger.info(f'Added {start} tokens so far') 115 | # else: 116 | # self.keys.append(keys) 117 | 118 | # self.logger.info(f'Adding total {start} keys') 119 | # self.logger.info(f'Adding took {time.time() - start_time} s') 120 | 121 | def 
search_and_reconstruct(self, queries, k): 122 | if len(queries.shape) == 1: # searching for only 1 vector, add one extra dim 123 | self.logger.info("Searching for a single vector; unsqueezing") 124 | queries = queries.unsqueeze(0) 125 | # self.logger.info("Searching with reconstruct") 126 | assert queries.shape[-1] == self.dimension # query vectors are same shape as "key" vectors 127 | scores, values, vectors = self.index.index.search_and_reconstruct(queries.cpu().detach(), k) 128 | # self.logger.info("Searching done") 129 | return scores, values, vectors 130 | 131 | def search(self, queries, k): 132 | # model_device = queries.device 133 | # model_dtype = queries.dtype 134 | if len(queries.shape) == 1: # searching for only 1 vector, add one extra dim 135 | self.logger.info("Searching for a single vector; unsqueezing") 136 | queries = queries.unsqueeze(0) 137 | assert queries.shape[-1] == self.dimension # query vectors are same shape as "key" vectors 138 | # if not self.gpu_index: 139 | # queries = queries.cpu() 140 | # else: 141 | # queries = queries.to(self.device) 142 | if self.use_flat_index: 143 | if self.gpu_index: 144 | scores, values = faiss.knn_gpu(faiss.StandardGpuResources(), queries, self.keys, k, 145 | metric=faiss.METRIC_INNER_PRODUCT, device=self.device.index) 146 | else: 147 | scores, values = faiss.knn(queries, self.keys, k, metric=faiss.METRIC_INNER_PRODUCT) 148 | scores = torch.from_numpy(scores).to(queries.dtype) 149 | values = torch.from_numpy(values) #.to(model_dtype) 150 | else: 151 | scores, values = self.index.search(queries.float(), k) 152 | 153 | # avoid returning -1 as a value 154 | # TODO: get a handle on the attention mask and mask the values that were -1 155 | values = torch.where(torch.logical_or(values < 0, values >= self.keys.shape[0]), torch.zeros_like(values), values) 156 | # self.logger.info("Searching done") 157 | # return scores.to(model_dtype).to(model_device), values.to(model_device) 158 | return scores, values 159 | 160 | 161 | -------------------------------------------------------------------------------- /src/inference-example.py: -------------------------------------------------------------------------------- 1 | from unlimiformer import Unlimiformer 2 | from random_training_unlimiformer import RandomTrainingUnlimiformer 3 | from usage import UnlimiformerArguments, training_addin 4 | 5 | from transformers import BartForConditionalGeneration, AutoTokenizer 6 | from datasets import load_dataset 7 | import torch 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | # example using govreport 12 | modelname = "abertsch/unlimiformer-bart-govreport-alternating" 13 | dataset = load_dataset("urialon/gov_report_validation") 14 | 15 | tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") 16 | model = BartForConditionalGeneration.from_pretrained(modelname) 17 | 18 | example_input = dataset['validation'][0]['input'] 19 | 20 | example = tokenizer(example_input, truncation=False, return_tensors="pt") 21 | truncated_example = tokenizer(example_input, truncation=True, max_length=1024, return_tensors="pt") 22 | 23 | example.to(device) 24 | truncated_example.to(device) 25 | 26 | print(f"INPUT LENGTH (tokens): {example['input_ids'].shape[-1]}") 27 | 28 | 29 | defaults = UnlimiformerArguments() 30 | unlimiformer_kwargs = { 31 | 'layer_begin': defaults.layer_begin, 32 | 'layer_end': defaults.layer_end, 33 | 'unlimiformer_head_num': defaults.unlimiformer_head_num, 34 | 'exclude_attention': defaults.unlimiformer_exclude, 35 | 
'chunk_overlap': defaults.unlimiformer_chunk_overlap, 36 | 'model_encoder_max_len': defaults.unlimiformer_chunk_size, 37 | 'verbose': defaults.unlimiformer_verbose, 'tokenizer': tokenizer, 38 | 'unlimiformer_training': defaults.unlimiformer_training, 39 | 'use_datastore': defaults.use_datastore, 40 | 'flat_index': defaults.flat_index, 41 | 'test_datastore': defaults.test_datastore, 42 | 'reconstruct_embeddings': defaults.reconstruct_embeddings, 43 | 'gpu_datastore': defaults.gpu_datastore, 44 | 'gpu_index': defaults.gpu_index 45 | } 46 | 47 | model.to(device) 48 | # the output of the model /without/ using unlimiformer 49 | truncated_out = tokenizer.batch_decode(model.generate(**truncated_example, max_length=512)) 50 | 51 | model = Unlimiformer.convert_model(model, **unlimiformer_kwargs) 52 | model.eval() 53 | model.to(device) 54 | 55 | # the output of the model /with/ unlimiformer 56 | unlimiformer_out = tokenizer.batch_decode(model.generate(**example, max_length=512), ignore_special_tokens=True)[0] 57 | print(unlimiformer_out) 58 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import load_metric, download_metric 2 | -------------------------------------------------------------------------------- /src/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | import os 3 | import importlib 4 | from abc import ABC, abstractmethod 5 | import inspect 6 | import shutil 7 | 8 | import numpy as np 9 | 10 | from utils.decoding import decode 11 | from datasets import load_metric as hf_load_metric 12 | from huggingface_hub import hf_hub_download 13 | 14 | 15 | class Metric(ABC): 16 | def __init__(self, **kwargs) -> None: 17 | super().__init__() 18 | self._kwargs = kwargs 19 | 20 | self.prefix = os.path.splitext(os.path.basename(inspect.getfile(self.__class__)))[0] 21 | self.requires_decoded = False 22 | 23 | def __call__(self, id_to_pred, id_to_labels, is_decoded=False): 24 | if self.requires_decoded and is_decoded is False: 25 | id_to_pred = self._decode(id_to_pred) 26 | id_to_labels = self._decode(id_to_labels) 27 | return self._compute_metrics(id_to_pred, id_to_labels) 28 | 29 | @abstractmethod 30 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]: 31 | return 32 | 33 | def _decode(self, id_to_something): 34 | tokenizer = self._kwargs.get("tokenizer") 35 | data_args = self._kwargs.get("data_args") 36 | return decode(id_to_something, tokenizer, data_args) 37 | 38 | 39 | class MetricCollection(Metric): 40 | def __init__(self, metrics: List[Metric], **kwargs): 41 | super().__init__(**kwargs) 42 | self._metrics = metrics 43 | 44 | def __call__(self, id_to_pred, id_to_labels): 45 | return self._compute_metrics(id_to_pred, id_to_labels) 46 | 47 | def _compute_metrics(self, id_to_pred, id_to_labels): 48 | results = {} 49 | 50 | id_to_pred_decoded = None 51 | id_to_labels_decoded = None 52 | for metric in self._metrics: 53 | metric_prefix = f"{metric.prefix}/" if metric.prefix else "" 54 | if metric.requires_decoded: 55 | if id_to_pred_decoded is None: 56 | id_to_pred_decoded = self._decode(id_to_pred) 57 | if id_to_labels_decoded is None: 58 | id_to_labels_decoded = self._decode(id_to_labels) 59 | 60 | result = metric(id_to_pred_decoded, id_to_labels_decoded, is_decoded=True) 61 | else: 62 | result = metric(id_to_pred, id_to_labels) 63 | 64 | 
results.update({f"{metric_prefix}{k}": np.mean(v) if type(v) is list else v for k, v in result.items() if type(v) is not str}) 65 | 66 | results["num_predicted"] = len(id_to_pred) 67 | results["mean_prediction_length_characters"] = np.mean([len(pred) for pred in id_to_pred_decoded.values()]) 68 | 69 | elem = next(iter(id_to_pred.values())) 70 | if not ((isinstance(elem, list) and isinstance(elem[0], str)) or isinstance(elem, str)): 71 | tokenizer = self._kwargs["tokenizer"] 72 | results["mean_prediction_length_tokens"] = np.mean( 73 | [np.count_nonzero(np.array(pred) != tokenizer.pad_token_id) for pred in id_to_pred.values()] 74 | ) # includes BOS/EOS tokens 75 | 76 | results = {key: round(value, 4) for key, value in results.items()} 77 | return results 78 | 79 | 80 | def load_metric(paths: List[str], **kwargs): 81 | if paths is None or len(paths) == 0: 82 | return None 83 | if isinstance(paths, str): 84 | paths = [paths] 85 | else: 86 | paths = [path for path in paths] 87 | 88 | metric_cls_list = [] 89 | 90 | scrolls_custom_metrics = [] 91 | to_remove = [] 92 | for i, path in enumerate(paths): 93 | if not os.path.isfile(path): 94 | scrolls_custom_metrics.append(path) 95 | to_remove.append(i) 96 | for i in sorted(to_remove, reverse=True): 97 | del paths[i] 98 | if len(scrolls_custom_metrics) > 0: 99 | scrolls_custom_metrics.insert(0, "") # In order to have an identifying comma in the beginning 100 | metric_cls_list.append(ScrollsWrapper(",".join(scrolls_custom_metrics), **kwargs)) 101 | 102 | for path in paths: 103 | path = path.strip() 104 | if len(path) == 0: 105 | continue 106 | if os.path.isfile(path) is False: 107 | path = os.path.join("src", "metrics", f"{path}.py") 108 | 109 | module = path[:-3].replace(os.sep, ".") 110 | 111 | metric_cls = import_main_class(module) 112 | metric_cls_list.append(metric_cls(**kwargs)) 113 | 114 | return MetricCollection(metric_cls_list, **kwargs) 115 | 116 | 117 | # Modified from datasets.load 118 | def import_main_class(module_path): 119 | """Import a module at module_path and return its main class""" 120 | module = importlib.import_module(module_path) 121 | 122 | main_cls_type = Metric 123 | 124 | # Find the main class in our imported module 125 | module_main_cls = None 126 | for name, obj in module.__dict__.items(): 127 | if isinstance(obj, type) and issubclass(obj, main_cls_type): 128 | if inspect.isabstract(obj): 129 | continue 130 | module_main_cls = obj 131 | break 132 | 133 | return module_main_cls 134 | 135 | 136 | class ScrollsWrapper(Metric): 137 | def __init__(self, comma_separated_metric_names, **kwargs) -> None: 138 | super().__init__(**kwargs) 139 | self.prefix = None 140 | 141 | self._metric = hf_load_metric(download_metric(), comma_separated_metric_names, keep_in_memory=True) 142 | 143 | self.requires_decoded = True 144 | 145 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]: 146 | return self._metric.compute(**self._metric.convert_from_map_format(id_to_pred, id_to_labels)) 147 | 148 | class HFMetricWrapper(Metric): 149 | def __init__(self, metric_name, **kwargs) -> None: 150 | super().__init__(**kwargs) 151 | self._metric = hf_load_metric(metric_name) 152 | self.kwargs = HFMetricWrapper.metric_specific_kwargs.get(metric_name, {}) 153 | self.requires_decoded = True 154 | self.prefix = metric_name 155 | self.requires_decoded = True 156 | 157 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]: 158 | return self._metric.compute(**self.convert_from_map_format(id_to_pred, id_to_labels), 
**self.kwargs) 159 | 160 | def convert_from_map_format(self, id_to_pred, id_to_labels): 161 | index_to_id = list(id_to_pred.keys()) 162 | predictions = [id_to_pred[id_] for id_ in index_to_id] 163 | references = [id_to_labels[id_] for id_ in index_to_id] 164 | return {"predictions": predictions, "references": references} 165 | 166 | metric_specific_kwargs = { 167 | 'bertscore': { 168 | # 'model_type': 'microsoft/deberta-large-mnli' or the larger 'microsoft/deberta-xlarge-mnli' 169 | 'model_type': 'facebook/bart-large-mnli', # has context window of 1024, 170 | 'num_layers': 11 # according to: https://docs.google.com/spreadsheets/d/1RKOVpselB98Nnh_EOC4A2BYn8_201tmPODpNWu4w7xI/edit#gid=0 171 | } 172 | } 173 | 174 | 175 | def download_metric(): 176 | # here we load the custom metrics 177 | scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type='dataset') 178 | updated_scrolls_metric_path = ( 179 | os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py" 180 | ) 181 | shutil.copy(scrolls_metric_path, updated_scrolls_metric_path) 182 | return updated_scrolls_metric_path 183 | -------------------------------------------------------------------------------- /src/random_training_unlimiformer.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from enum import Enum, auto 6 | from unlimiformer import Unlimiformer, ModelType, UnlimiformerBART, UnlimiformerT5, UnlimiformerLED 7 | from transformers import BartModel, BartForConditionalGeneration, \ 8 | T5Model, T5ForConditionalGeneration, \ 9 | LEDModel, LEDForConditionalGeneration, \ 10 | AutoModelForSeq2SeqLM 11 | 12 | class RandomTrainingUnlimiformer(Unlimiformer[ModelType]): 13 | def __init__(self, model: ModelType, *args, **kwargs): 14 | super().__init__(model, *args, **kwargs) 15 | self.training_hooks_injected = False 16 | self.train_step = 0 17 | 18 | @classmethod 19 | def convert_model(cls, model, *args, **kwargs): 20 | # model_clone = AutoModelForSeq2SeqLM.from_config(model.config) 21 | # model_clone.load_state_dict(model.state_dict()).to(args.device) 22 | type_to_class = { 23 | BartModel: RandomUnlimiformerBART, 24 | BartForConditionalGeneration: RandomUnlimiformerBART, 25 | T5Model: RandomUnlimiformerT5, 26 | T5ForConditionalGeneration: RandomUnlimiformerT5, 27 | LEDModel: RandomUnlimiformerLED, 28 | LEDForConditionalGeneration: RandomUnlimiformerLED, 29 | } 30 | type_to_class[type(model)](model, *args, **kwargs) 31 | return model 32 | 33 | def pre_eval_hook(self): 34 | self.remove_training_hooks(self.model) 35 | self.inject_hooks(self.model) 36 | self.original_model_eval_func() 37 | 38 | def pre_train_hook(self, mode=True): 39 | # mode=True means model.train() is called 40 | # mode=False means model.eval() is called 41 | torch.cuda.empty_cache() 42 | if mode is True: 43 | self.break_out(self.model) 44 | self.remove_training_hooks(self.model) 45 | if self.unlimiformer_training and self.train_step % 2 == 0: 46 | super().inject_training_hooks(self.model) 47 | else: 48 | self.inject_training_hooks(self.model) 49 | self.train_step += 1 50 | self.original_model_train_func(mode) 51 | 52 | def inject_training_hooks(self, model): 53 | if self.training_hooks_injected: 54 | return 55 | # self.original_forward_func = model.forward 56 | model.forward = self.random_inputs_forward_hook 57 | 58 | decoder_layers_to_run = 
self.attention_layer_to_run(self.layer_begin, self.layer_end) 59 | 60 | self.original_decoder_layer_self_attn_forward_funcs = [] 61 | for decoder_layer in decoder_layers_to_run: 62 | attention = self.self_attention(decoder_layer) 63 | self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward) 64 | attention.forward = self.create_self_attn_random_pre_forward_hook(attention.forward) 65 | 66 | self.original_decoder_layer_forward_funcs = [] 67 | for decoder_layer in decoder_layers_to_run: 68 | self.original_decoder_layer_forward_funcs.append(decoder_layer.forward) 69 | decoder_layer.forward = self.create_decoder_layer_random_func(decoder_layer.forward, decoder_layer) 70 | 71 | self.original_decoder_layer_cross_attn_forward_funcs = [] 72 | for i, decoder_layer in enumerate(decoder_layers_to_run): 73 | attention = self.cross_attention(decoder_layer) 74 | self.original_decoder_layer_cross_attn_forward_funcs.append(attention.forward) 75 | 76 | self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run) 77 | 78 | self.training_hooks_injected = True 79 | 80 | def create_self_attn_random_pre_forward_hook(self, original_self_attn_forward_func): 81 | def self_attention_pre_forward_hook(*args, **kwargs): 82 | kwargs['past_key_value'] = None 83 | return original_self_attn_forward_func(*args, **kwargs) 84 | 85 | return self_attention_pre_forward_hook 86 | 87 | def create_decoder_layer_random_func(self, decoder_layer_original_forward_func, decoder_layer): 88 | def checkpointed_decoder_layer( 89 | hidden_states: torch.Tensor, 90 | attention_mask=None, 91 | encoder_hidden_states=None, 92 | encoder_attention_mask=None, 93 | layer_head_mask=None, 94 | cross_attn_layer_head_mask=None, 95 | past_key_value=None, 96 | output_attentions=False, 97 | position_bias=None, 98 | encoder_decoder_position_bias=None, 99 | use_cache=True): 100 | 101 | 102 | 103 | def sample_and_forward(hidden_states, attention_mask, 104 | encoder_hidden_states, encoder_attention_mask, layer_head_mask, 105 | cross_attn_layer_head_mask, past_key_value, 106 | output_attentions, use_cache, long_inputs, long_inputs_mask, rand_indices, 107 | position_bias, encoder_decoder_position_bias): 108 | 109 | sampled_input, _ = self.sample_long_input(long_inputs, long_inputs_mask, rand_indices) 110 | key, value = self.create_key_value(sampled_input, decoder_layer) 111 | decoder_layer_args = self.create_decoder_layer_args( 112 | hidden_states=hidden_states, 113 | attention_mask=attention_mask, 114 | encoder_hidden_states=encoder_hidden_states, 115 | encoder_attention_mask=encoder_attention_mask, 116 | layer_head_mask=layer_head_mask, 117 | cross_attn_layer_head_mask=cross_attn_layer_head_mask, 118 | past_key_value=past_key_value, 119 | output_attentions=output_attentions, 120 | position_bias=position_bias, 121 | encoder_decoder_position_bias=encoder_decoder_position_bias, 122 | use_cache=use_cache, 123 | key=key,value=value 124 | ) 125 | return decoder_layer_original_forward_func(**decoder_layer_args) 126 | 127 | 128 | with torch.no_grad(): 129 | # This sampling must be done outside of the checkpoint, to ensure that the same sampling happens 130 | # both in "forward" and "backward" passes 131 | rand_indices = self.sample_random_indices() 132 | 133 | return torch.utils.checkpoint.checkpoint( 134 | sample_and_forward, hidden_states, attention_mask, 135 | encoder_hidden_states, encoder_attention_mask, layer_head_mask, 136 | cross_attn_layer_head_mask, None, 137 | output_attentions, use_cache, self.long_inputs_encoded, 
self.long_inputs_mask, rand_indices, 138 | position_bias, encoder_decoder_position_bias) 139 | 140 | return checkpointed_decoder_layer 141 | 142 | def sample_random_indices(self): 143 | rand_indices_list = [] 144 | seq_lens = self.long_inputs_mask.sum(-1).tolist() 145 | for seq_len in seq_lens: 146 | if seq_len < self.actual_model_window_size: 147 | rand_indices = torch.arange(self.actual_model_window_size).to(self.device) 148 | rand_indices_list.append(rand_indices) 149 | continue 150 | 151 | rand_indices = torch.torch.randperm(seq_len)[:self.actual_model_window_size].to(self.device) 152 | if seq_len < self.actual_model_window_size: 153 | padding = max(self.actual_model_window_size - seq_len, 0) 154 | rand_indices = torch.cat([rand_indices, torch.arange(padding).to(self.device) + seq_len], axis=-1).to(self.device) 155 | rand_indices_list.append(rand_indices) 156 | rand_indices = torch.stack(rand_indices_list, dim=0) 157 | return rand_indices 158 | 159 | def random_inputs_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs): 160 | self.model.base_model.decoder.gradient_checkpointing = False 161 | self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask) 162 | 163 | # TODO: should the inputs be sampled or the truncated beginning? 164 | # if self.random_knn_initial_inputs: 165 | # encoded_inputs, encoded_inputs_mask = self.sample_long_input(self.long_inputs_encoded, self.long_inputs_mask) 166 | # else: 167 | encoded_inputs = self.long_inputs_encoded[:, :self.actual_model_window_size] 168 | encoded_inputs_mask = self.long_inputs_mask[:, :self.actual_model_window_size] 169 | return self.original_forward_func(encoder_outputs=(encoded_inputs, ), labels=labels, attention_mask=encoded_inputs_mask, **kwargs) 170 | 171 | def sample_long_input(self, long_inputs_encoded, long_inputs_mask, random_indices=None): 172 | if long_inputs_mask.shape[-1] < self.actual_model_window_size: 173 | return long_inputs_encoded, long_inputs_mask 174 | batch_size = long_inputs_encoded.shape[0] 175 | 176 | if random_indices is None: 177 | random_indices = self.sample_random_indices() 178 | random_mask = torch.zeros_like(long_inputs_mask).to(self.device) \ 179 | .scatter_(dim=-1, index=random_indices, src=torch.ones_like(random_indices)).bool().to(self.device) 180 | sampled_input = long_inputs_encoded[random_mask].reshape(batch_size, self.actual_model_window_size, -1).to(self.device) 181 | sampled_mask = long_inputs_mask[random_mask].reshape(batch_size, self.actual_model_window_size).to(self.device) 182 | return sampled_input, sampled_mask 183 | 184 | def chunked_encode_input(self, input_ids, attention_mask): 185 | long_inputs_encoded = [] 186 | long_inputs_mask = [] 187 | window_indices = self.window_indices(input_ids.shape[-1]) 188 | 189 | self.is_input_encoding_pass = True 190 | for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices: 191 | chunk = input_ids[:, context_start_ind:context_end_ind] 192 | chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind] 193 | output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True) 194 | encoder_last_hidden_state = output.last_hidden_state # (batch, time, dim) 195 | 196 | # list of (batch, head, chunked_time, dim) 197 | encoder_last_hidden_state = encoder_last_hidden_state[:, update_start_ind:update_end_ind] # (batch, chunked_time, dim) 198 | chunk_attention_mask = 
chunk_attention_mask[:, update_start_ind:update_end_ind] # (batch, chunked_time) 199 | 200 | long_inputs_encoded.append(encoder_last_hidden_state) # (batch, chunked_source_len, dim) 201 | long_inputs_mask.append(chunk_attention_mask) # (batch, chunked_source_len) 202 | 203 | long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1) # (batch, source_len, dim) 204 | long_inputs_mask = torch.cat(long_inputs_mask, dim=1) # (batch, source_len) 205 | 206 | self.is_input_encoding_pass = False 207 | if self.verbose: 208 | print(f'Input: ' 209 | f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| ' 210 | f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}') 211 | print() 212 | return long_inputs_encoded, long_inputs_mask 213 | 214 | class RandomUnlimiformerBART(RandomTrainingUnlimiformer[BartModel], UnlimiformerBART): 215 | def __init__(self, model: BartModel, *args, **kwargs): 216 | super().__init__(model, *args, **kwargs) 217 | 218 | class RandomUnlimiformerT5(RandomTrainingUnlimiformer[T5Model], UnlimiformerT5): 219 | def __init__(self, model: T5Model, *args, **kwargs): 220 | super().__init__(model, *args, **kwargs) 221 | 222 | class RandomUnlimiformerLED(RandomTrainingUnlimiformer[LEDModel], UnlimiformerLED): 223 | def __init__(self, model: LEDModel, *args, **kwargs): 224 | super().__init__(model, *args, **kwargs) -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
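# A minimal launch sketch (all paths and values are placeholders): the JSON files under
# src/configs/ are assumed to be consumable as positional config arguments by the
# CustomHfArgumentParser used below, and the --test_unlimiformer / --unlimiformer_training
# flags come from the UnlimiformerArguments dataclass defined later in this file:
#
#   python src/run.py \
#       src/configs/training/base_training_args.json \
#       src/configs/data/gov_report.json \
#       --model_name_or_path facebook/bart-base \
#       --output_dir output_bart_base_govreport \
#       --max_source_length 1024 \
#       --do_train --do_eval \
#       --test_unlimiformer --unlimiformer_training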
20 | import logging 21 | import os 22 | import sys 23 | 24 | import numpy as np 25 | from unlimiformer import Unlimiformer 26 | from random_training_unlimiformer import RandomTrainingUnlimiformer 27 | 28 | import nltk 29 | 30 | # we import the logging frameworks before any other import to make sure all monkey patching for the logging are active 31 | # from sled import SledConfig 32 | 33 | import wandb 34 | import torch 35 | 36 | sys.path.insert(0, os.path.dirname(__file__)) # seq2seq package path 37 | sys.path.insert(0, os.getcwd()) 38 | 39 | from dataclasses import dataclass, field, replace 40 | from typing import List, Optional 41 | import json 42 | from copy import deepcopy 43 | import torch.nn.functional as F 44 | 45 | import datasets 46 | 47 | import transformers 48 | from transformers import ( 49 | AutoConfig, 50 | AutoModelForSeq2SeqLM, 51 | AutoTokenizer, 52 | EarlyStoppingCallback, 53 | set_seed, WEIGHTS_NAME, 54 | ) 55 | from transformers.trainer_utils import get_last_checkpoint 56 | from transformers import DataCollatorForSeq2Seq 57 | 58 | from datasets import load_dataset 59 | 60 | # noinspection PyUnresolvedReferences 61 | # import sled # *** required so that SledModels will be registered for the AutoClasses *** 62 | 63 | from utils.config import handle_args_to_ignore 64 | from utils.decoding import decode 65 | from metrics import load_metric 66 | from utils.duplicates import drop_duplicates_in_input 67 | from utils.override_training_args import TrainingOverridesArguments 68 | from utils.custom_seq2seq_trainer import CustomTrainer 69 | from utils.custom_hf_argument_parser import CustomHfArgumentParser 70 | from metrics.metrics import HFMetricWrapper, MetricCollection 71 | 72 | logger = logging.getLogger('sled') 73 | 74 | PREFIX_DOC_SEP = '\n\n' 75 | 76 | DEBUG = os.environ.get('DEBUG', 'false').lower() in {'1', 'true', 'yes'} # If set, will set some configuration to help debug 77 | if DEBUG: 78 | assert not torch.cuda.is_available() or torch.cuda.device_count() == 1 79 | 80 | 81 | @dataclass 82 | class ModelArguments: 83 | """ 84 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 85 | """ 86 | 87 | model_name_or_path: str = field( 88 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 89 | ) 90 | config_name: Optional[str] = field( 91 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 92 | ) 93 | tokenizer_name: Optional[str] = field( 94 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 95 | ) 96 | cache_dir: Optional[str] = field( 97 | default=None, 98 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 99 | ) 100 | use_fast_tokenizer: bool = field( 101 | default=True, 102 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 103 | ) 104 | model_revision: str = field( 105 | default="main", 106 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 107 | ) 108 | drop_duplicates_in_eval: bool = field( 109 | default=True, 110 | ) 111 | 112 | def __post_init__(self): 113 | pass 114 | 115 | 116 | 117 | @dataclass 118 | class DataTrainingArguments: 119 | """ 120 | Arguments pertaining to what data we are going to input our model for training and eval. 
121 | """ 122 | 123 | dataset_name: Optional[str] = field( 124 | default=None, 125 | metadata={ 126 | "help": "The name of the dataset to use (via the datasets library) or name of the file in src/data." 127 | }, 128 | ) 129 | dataset_config_name: Optional[str] = field( 130 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 131 | ) 132 | metric_names: Optional[List[str]] = field( 133 | default=None, 134 | metadata={"help": "The name of the metric to use (from src/metrics)."}, 135 | ) 136 | input_column: Optional[str] = field( 137 | default=None, 138 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 139 | ) 140 | input_prefix_column: Optional[str] = field( 141 | default=None, 142 | metadata={"help": "The name of the column in the datasets containing the input prefix (e.g. questions), when those exist."}, 143 | ) 144 | output_column: Optional[str] = field( 145 | default=None, 146 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 147 | ) 148 | train_file: Optional[str] = field( 149 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 150 | ) 151 | validation_file: Optional[str] = field( 152 | default=None, 153 | metadata={ 154 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 155 | "(a jsonlines or csv file)." 156 | }, 157 | ) 158 | test_file: Optional[str] = field( 159 | default=None, 160 | metadata={ 161 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 162 | }, 163 | ) 164 | overwrite_cache: bool = field( 165 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 166 | ) 167 | preprocessing_num_workers: Optional[int] = field( 168 | default=None, 169 | metadata={"help": "The number of processes to use for the preprocessing."}, 170 | ) 171 | max_source_length: Optional[int] = field( 172 | default=None, 173 | metadata={ 174 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 175 | "than this will be truncated, sequences shorter will be padded." 176 | }, 177 | ) 178 | eval_max_source_length: Optional[int] = field( 179 | default=None, 180 | metadata={"help": "if None, will be same as max_source_length"}, 181 | ) 182 | max_prefix_length: Optional[int] = field( 183 | default=0, 184 | metadata={ 185 | "help": "The maximum total input_prefix sequence length after tokenization. Sequences longer " 186 | "than this will be truncated, sequences shorter will be padded from the left " 187 | "(only used if prefixes are not merged)." 188 | }, 189 | ) 190 | max_target_length: Optional[int] = field( 191 | default=128, 192 | metadata={ 193 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 194 | "than this will be truncated, sequences shorter will be padded." 195 | }, 196 | ) 197 | val_max_target_length: Optional[int] = field( 198 | default=None, 199 | metadata={ 200 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 201 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 202 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 203 | "during ``evaluate`` and ``predict``." 
204 | }, 205 | ) 206 | pad_to_max_length: bool = field( 207 | default=False, 208 | metadata={ 209 | "help": "Whether to pad all samples to model maximum sentence length. " 210 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 211 | "efficient on GPU but very bad for TPU." 212 | }, 213 | ) 214 | max_train_samples: Optional[int] = field( 215 | default=None, 216 | metadata={ 217 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 218 | "value if set." 219 | }, 220 | ) 221 | max_eval_samples: Optional[int] = field( 222 | default=None, 223 | metadata={ 224 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 225 | "value if set." 226 | }, 227 | ) 228 | max_predict_samples: Optional[int] = field( 229 | default=None, 230 | metadata={ 231 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 232 | "value if set." 233 | }, 234 | ) 235 | num_beams: Optional[int] = field( 236 | default=None, 237 | metadata={ 238 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 239 | "which is used during ``evaluate`` and ``predict``." 240 | }, 241 | ) 242 | ignore_pad_token_for_loss: bool = field( 243 | default=True, 244 | metadata={ 245 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 246 | }, 247 | ) 248 | source_prefix: Optional[str] = field( 249 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 250 | ) 251 | data_dir: Optional[str] = field( 252 | default=None, 253 | metadata={"help": "Defining the data_dir of the dataset configuration."}, 254 | ) 255 | download_mode: Optional[str] = field( 256 | default=None, 257 | metadata={ 258 | "help": "Defining the download_mode when loading the dataset. Options are `reuse_dataset_if_exists` (default), `reuse_cache_if_exists` and `force_redownload`." 259 | }, 260 | ) 261 | evaluate_on_training_data: bool = field( 262 | default=False, 263 | metadata={"help": "Whether to evaluate on training data or not, to make sure the model can overfit."}, 264 | ) 265 | folder_suffix: str = field( 266 | default="", 267 | metadata={"help": "args to be suffixes for the output folder of the run"}, 268 | ) 269 | preprocess_only: bool = field( 270 | default=False, 271 | metadata={"help": "Preprocess only: Don't start training, just do the things before"}, 272 | ) 273 | assign_zero_to_too_long_val_examples: bool = field( 274 | default=False, 275 | metadata={ 276 | "help": "If true, all sequences longer then max_source_length will be assign a score of 0 in the metric evaluation" 277 | }, 278 | ) 279 | shared_storage: bool = field( 280 | default=True, 281 | metadata={"help": "Whether nodes share the same storage"}, 282 | ) 283 | trim_very_long_strings: bool = field( 284 | default=False, 285 | metadata={"help": "Whether to trim very long strings before tokenizing them"}, 286 | ) 287 | pad_prefix: bool = field( 288 | default=False, 289 | metadata={ 290 | "help": "Whether to pad the prefix if it exists to max_prefix_length. 
" 291 | "Note - important if you are using a SLED model on an input that contains an input_prefix" 292 | }, 293 | ) 294 | test_start_ind: Optional[int] = field( 295 | default=None, 296 | metadata={"help": "if given, uses the test set starting from this index"}, 297 | ) 298 | test_end_ind: Optional[int] = field( 299 | default=None, 300 | metadata={"help": "if given, uses the test set ending at this index"}, 301 | ) 302 | # Uri: 303 | patience: Optional[int] = field( 304 | default=None, 305 | ) 306 | length_penalty: Optional[float] = field( 307 | default=1.0, 308 | ) 309 | extra_metrics: Optional[List[str]] = field( 310 | default=None, 311 | metadata={"help": "The name of the metric to use (from src/metrics)."}, 312 | ) 313 | chunked_training_size: Optional[int] = field( 314 | default=None, 315 | ) 316 | oracle_training: Optional[bool] = field( 317 | default=False, 318 | metadata={"help": "If True, train on the input sentences that provide the highest ROUGE score with the labels"} 319 | ) 320 | oracle_merge: Optional[bool] = field( 321 | default=False, 322 | metadata={"help": "If True, merge the oracle dataset and the standard training dataset"} 323 | ) 324 | def __post_init__(self): 325 | if self.val_max_target_length is None: 326 | self.val_max_target_length = self.max_target_length 327 | if self.pad_prefix and self.max_prefix_length == 0: 328 | raise ValueError('When padding prefix, you must set a max_prefix_length') 329 | assert self.max_prefix_length == 0 or self.max_prefix_length <= 0.5*self.max_source_length,\ 330 | 'If max_prefix_length is given, it must be much shorter than the total input' 331 | # Uri: 332 | if self.eval_max_source_length is None: 333 | self.eval_max_source_length = self.max_source_length 334 | 335 | 336 | @dataclass 337 | class UnlimiformerArguments: 338 | """ 339 | Arguments pertaining to what data we are going to input our model for training and eval. 340 | """ 341 | test_unlimiformer: Optional[bool] = field( 342 | default=False, 343 | metadata={ 344 | "help": "whether to use KNN." 345 | }, 346 | ) 347 | unlimiformer_verbose: Optional[bool] = field( 348 | default=False, 349 | metadata={ 350 | "help": "whether to print KNN intermediate predictions (mostly for debugging)." 351 | }, 352 | ) 353 | layer_begin: Optional[int] = field( 354 | default=0, 355 | metadata={"help": "The layer to begin applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. " 356 | "By default, it will be applied to all layers: [0:None]]"}, 357 | ) 358 | layer_end: Optional[int] = field( 359 | default=None, 360 | metadata={"help": "The layer to end applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. " 361 | "By default, it will be applied to all layers: [0:None]]"}, 362 | ) 363 | unlimiformer_chunk_overlap: Optional[float] = field( 364 | default=0.5, 365 | metadata={"help": "The fraction of overlap between input chunks"}, 366 | ) 367 | unlimiformer_chunk_size: Optional[int] = field( 368 | default=None, 369 | metadata={"help": "The size of each input chunk"}, 370 | ) 371 | unlimiformer_head_num: Optional[int] = field( 372 | default=None, 373 | metadata={"help": "The head to apply KNN to (if None, apply to all heads)"}, 374 | ) 375 | unlimiformer_exclude: Optional[bool] = field( 376 | default=False, 377 | metadata={ 378 | "help": "If True, prioritize the inputs that are **not** in the standard attention window." 
379 | }, 380 | ) 381 | random_unlimiformer_training: Optional[bool] = field( 382 | default=False, 383 | ) 384 | unlimiformer_training: Optional[bool] = field( 385 | default=False, 386 | ) 387 | use_datastore: Optional[bool] = field(default=False) 388 | flat_index: Optional[bool] = field(default=False) 389 | test_datastore: Optional[bool] = field(default=False) 390 | reconstruct_embeddings: Optional[bool] = field(default=False) 391 | gpu_datastore: Optional[bool] = field(default=True) 392 | gpu_index: Optional[bool] = field(default=True) 393 | 394 | 395 | def main(): 396 | handle_args_to_ignore(sys.argv) # Just for sweeps 397 | 398 | # See all possible arguments in src/transformers/training_args.py 399 | # or by passing the --help flag to this script. 400 | # We now keep distinct sets of args, for a cleaner separation of concerns. 401 | 402 | parser = CustomHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingOverridesArguments, UnlimiformerArguments)) 403 | model_args, data_args, training_args, unlimiformer_args = parser.parse_dictionary_and_args() 404 | 405 | set_up_logging(training_args) 406 | logger.info(f"Training Arguments: {training_args}") 407 | logger.info(f"Data Arguments: {data_args}") 408 | logger.info(f"Model Arguments: {model_args}") 409 | logger.info(f"Unlimiformer Arguments: {unlimiformer_args}") 410 | 411 | 412 | # Added to avoid wandb.errors.UsageError: Error communicating with wandb process 413 | wandb.init(settings=wandb.Settings(start_method="fork"), name=training_args.output_dir) 414 | 415 | # Used to find missing dependencies early on 416 | load_metric(data_args.metric_names, **locals()) 417 | load_extra_metrics(data_args.extra_metrics) 418 | 419 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 420 | "t5-small", 421 | "t5-base", 422 | "t5-large", 423 | "t5-3b", 424 | "t5-11b", 425 | ]: 426 | logger.warning( 427 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 428 | "`--source_prefix 'summarize: ' `" 429 | ) 430 | 431 | # Detecting last checkpoint. 432 | last_checkpoint = _detect_last_checkpoint(training_args) 433 | 434 | # Set seed before initializing model. 435 | set_seed(training_args.seed) 436 | 437 | seq2seq_dataset = _get_dataset(data_args, model_args, training_args) 438 | 439 | # Load pretrained model and tokenizer 440 | # 441 | # Distributed training: 442 | # The .from_pretrained methods guarantee that only one local process can concurrently 443 | # download model & vocab. 
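# Resolution order below: an explicit --config_name (and, further down, --tokenizer_name) wins;
# otherwise, when --model_name_or_path points at a file, its containing directory is used;
# otherwise the model identifier itself is reused. For example, passing only
# --model_name_or_path facebook/bart-base loads the config, the tokenizer, and the weights
# from that single identifier.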
444 | config_name = None 445 | if model_args.config_name: 446 | config_name = model_args.config_name 447 | else: 448 | if os.path.isfile(model_args.model_name_or_path): 449 | config_name = os.path.dirname(model_args.model_name_or_path) 450 | else: 451 | config_name = model_args.model_name_or_path 452 | 453 | config_overrides = {} 454 | if training_args.gradient_checkpointing is not None: 455 | config_overrides["gradient_checkpointing"] = training_args.gradient_checkpointing 456 | 457 | config = AutoConfig.from_pretrained( 458 | config_name, 459 | cache_dir=model_args.cache_dir, 460 | revision=model_args.model_revision, 461 | use_auth_token=training_args.use_auth_token, 462 | **config_overrides 463 | ) 464 | # override for sled models to make sure we are explicit in our request 465 | # if isinstance(config, SledConfig) and (not data_args.pad_prefix or data_args.max_prefix_length == 0): 466 | # logger.warning('Setting prepend_prefix to False if using a SLED model, as the input does not have a prefix or ' 467 | # 'pad_prefix is False (all prefixes must be of the same length for SLED). If you do not use SLED ' 468 | # 'or finetune on a dataset with no prefixes, ignore this warning') 469 | # config.prepend_prefix = False 470 | 471 | if model_args.model_name_or_path is None: 472 | # Padding for divisibility by 8 473 | if config.vocab_size % 8 != 0 and training_args.fp16_padding: 474 | config.vocab_size += 8 - (config.vocab_size % 8) 475 | 476 | tokenizer_name = None 477 | if model_args.tokenizer_name: 478 | tokenizer_name = model_args.tokenizer_name 479 | else: 480 | if os.path.isfile(model_args.model_name_or_path): 481 | tokenizer_name = os.path.dirname(model_args.model_name_or_path) 482 | else: 483 | tokenizer_name = model_args.model_name_or_path 484 | tokenizer = AutoTokenizer.from_pretrained( 485 | tokenizer_name, 486 | cache_dir=model_args.cache_dir, 487 | use_fast=model_args.use_fast_tokenizer, 488 | revision=model_args.model_revision, 489 | use_auth_token=training_args.use_auth_token, 490 | ) 491 | if model_args.model_name_or_path is not None: 492 | model = AutoModelForSeq2SeqLM.from_pretrained( 493 | model_args.model_name_or_path, 494 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 495 | config=config, 496 | cache_dir=model_args.cache_dir, 497 | revision=model_args.model_revision, 498 | use_auth_token=training_args.use_auth_token, 499 | ) 500 | else: 501 | model = AutoModelForSeq2SeqLM.from_config( 502 | config, 503 | ) 504 | if unlimiformer_args.test_unlimiformer: 505 | unlimiformer_kwargs = { 506 | 'layer_begin': unlimiformer_args.layer_begin, 507 | 'layer_end': unlimiformer_args.layer_end, 508 | 'unlimiformer_head_num': unlimiformer_args.unlimiformer_head_num, 509 | 'exclude_attention': unlimiformer_args.unlimiformer_exclude, 510 | 'chunk_overlap': unlimiformer_args.unlimiformer_chunk_overlap, 511 | 'model_encoder_max_len': unlimiformer_args.unlimiformer_chunk_size, 512 | 'verbose': unlimiformer_args.unlimiformer_verbose, 'tokenizer': tokenizer, 513 | 'unlimiformer_training': unlimiformer_args.unlimiformer_training, 514 | 'use_datastore': unlimiformer_args.use_datastore, 515 | 'flat_index': unlimiformer_args.flat_index, 516 | 'test_datastore': unlimiformer_args.test_datastore, 517 | 'reconstruct_embeddings': unlimiformer_args.reconstruct_embeddings, 518 | 'gpu_datastore': unlimiformer_args.gpu_datastore, 519 | 'gpu_index': unlimiformer_args.gpu_index 520 | } 521 | if unlimiformer_args.random_unlimiformer_training: 522 | model = RandomTrainingUnlimiformer.convert_model(model, 
**unlimiformer_kwargs) 523 | else: 524 | model = Unlimiformer.convert_model(model, **unlimiformer_kwargs) 525 | 526 | model.config.use_cache = True 527 | if training_args.gradient_checkpointing and getattr(model.config, 'use_cache', False) and training_args.do_train: 528 | logger.warning('Cannot use cache in models when using gradient checkpointing. turning it off') 529 | model.config.use_cache = False 530 | 531 | model.resize_token_embeddings(len(tokenizer)) 532 | 533 | if model.config.decoder_start_token_id is None: 534 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 535 | 536 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 537 | 538 | # Preprocessing the datasets. 539 | # We need to tokenize inputs and targets. 540 | if training_args.do_train: 541 | column_names = seq2seq_dataset["train"].column_names 542 | elif training_args.do_eval: 543 | column_names = seq2seq_dataset["validation"].column_names 544 | elif training_args.do_predict: 545 | column_names = seq2seq_dataset["test"].column_names 546 | else: 547 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 548 | return 549 | 550 | # Get the column names for input/target. 551 | if data_args.input_column is None: 552 | input_column = "input" 553 | else: 554 | input_column = data_args.input_column 555 | if input_column not in column_names: 556 | raise ValueError( 557 | f"--input_column' value '{data_args.input_column}' needs to be one of: {', '.join(column_names)}" 558 | ) 559 | if data_args.input_prefix_column is None: 560 | input_prefix_column = "input_prefix" 561 | else: 562 | input_prefix_column = data_args.input_prefix_column 563 | if input_prefix_column not in column_names: 564 | raise ValueError( 565 | f"--input_prefix_column' value '{data_args.input_prefix_column}' needs to be one of: {', '.join(column_names)}" 566 | ) 567 | if data_args.output_column is None: 568 | output_column = "output" 569 | else: 570 | output_column = data_args.output_column 571 | if output_column not in column_names: 572 | raise ValueError( 573 | f"--output_column' value '{data_args.output_column}' needs to be one of: {', '.join(column_names)}" 574 | ) 575 | 576 | # Temporarily set max_target_length for training. 577 | max_target_length = data_args.max_target_length 578 | padding = "max_length" if data_args.pad_to_max_length else False 579 | 580 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 581 | logger.warning( 582 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 583 | f"`{model.__class__.__name__}`. 
This will lead to loss being calculated twice and will take up more memory" 584 | ) 585 | 586 | def preprocess_function_kwargs_fn(): 587 | return { 588 | "tokenizer": deepcopy(tokenizer), 589 | "prefix": prefix, 590 | "input_column": input_column, 591 | "input_prefix_column": input_prefix_column, 592 | "output_column": output_column, 593 | "max_source_length": data_args.max_source_length, 594 | "max_prefix_length": data_args.max_prefix_length, 595 | "max_target_length": max_target_length, 596 | "prefix_sep": PREFIX_DOC_SEP, 597 | "padding": padding, 598 | "ignore_pad_token_for_loss": data_args.ignore_pad_token_for_loss, 599 | "assign_zero_to_too_long_val_examples": data_args.assign_zero_to_too_long_val_examples, 600 | "trim_very_long_strings": data_args.trim_very_long_strings, 601 | "pad_prefix": data_args.pad_prefix 602 | } 603 | 604 | if training_args.do_train: 605 | if "train" not in seq2seq_dataset: 606 | raise ValueError("--do_train requires a train dataset") 607 | logger.info("") 608 | logger.info("Training examples before tokenization:") 609 | if input_prefix_column in column_names: 610 | logger.info(f"input_prefix #0: {seq2seq_dataset['train'][0][input_prefix_column]}") 611 | # logger.info(f"input #0: {seq2seq_dataset['train'][0]['input']}") 612 | # logger.info(f"output #0: {seq2seq_dataset['train'][0]['output']}") 613 | if input_prefix_column in column_names: 614 | logger.info(f"input_prefix #1: {seq2seq_dataset['train'][1][input_prefix_column]}") 615 | # logger.info(f"input #1: {seq2seq_dataset['train'][1]['input']}") 616 | # logger.info(f"output #1: {seq2seq_dataset['train'][1]['output']}") 617 | logger.info("") 618 | untokenized_train_dataset = seq2seq_dataset["train"] 619 | if data_args.max_train_samples is not None: 620 | untokenized_train_dataset = untokenized_train_dataset.select(range(data_args.max_train_samples)) 621 | 622 | if DEBUG: 623 | # In debug mode, we want to recreate the data 624 | data_args.shared_storage = False 625 | data_args.overwrite_cache = True 626 | with training_args.main_process_first( 627 | local=not data_args.shared_storage, desc="train dataset map pre-processing" 628 | ): 629 | 630 | if data_args.oracle_training: 631 | logger.info("Using oracle training") 632 | oracle_processed_dir = f'oracle_input_{data_args.dataset_config_name}' 633 | if os.path.isdir(oracle_processed_dir): 634 | logger.info(f"Using oracle training from {oracle_processed_dir}") 635 | oracle_training_set = datasets.load_from_disk(oracle_processed_dir) 636 | else: 637 | rouge_scorer = datasets.load_metric('rouge') 638 | oracle_training_set = untokenized_train_dataset.map( 639 | extract_oracle_sent_batch, 640 | fn_kwargs={'max_length': data_args.max_source_length, 641 | 'tokenizer': tokenizer, 642 | 'rouge_scorer': rouge_scorer}, 643 | batched=True, 644 | batch_size=1, 645 | num_proc=data_args.preprocessing_num_workers, 646 | load_from_cache_file=not data_args.overwrite_cache, 647 | desc="Extracting oracle sentences from every training example", 648 | ) 649 | oracle_training_set.save_to_disk(oracle_processed_dir) 650 | 651 | 652 | if data_args.oracle_merge: 653 | untokenized_train_dataset = datasets.concatenate_datasets([untokenized_train_dataset, oracle_training_set]) 654 | untokenized_train_dataset = untokenized_train_dataset.shuffle(seed=training_args.seed) 655 | else: 656 | untokenized_train_dataset = oracle_training_set 657 | 658 | train_dataset = untokenized_train_dataset.map( 659 | preprocess_function, 660 | fn_kwargs=preprocess_function_kwargs_fn(), 661 | batched=True, 662 | 
num_proc=data_args.preprocessing_num_workers, 663 | remove_columns=untokenized_train_dataset.column_names, 664 | load_from_cache_file=not data_args.overwrite_cache, 665 | desc="Running tokenizer on train dataset", 666 | ) 667 | 668 | if data_args.chunked_training_size is not None: 669 | train_dataset = train_dataset.map( 670 | chunk_dataset_function, 671 | fn_kwargs={'chunk_size': data_args.chunked_training_size}, 672 | batched=True, 673 | num_proc=data_args.preprocessing_num_workers, 674 | load_from_cache_file=not data_args.overwrite_cache, 675 | desc="Chunking train dataset source", 676 | ) 677 | train_dataset = train_dataset.shuffle(seed=training_args.seed) 678 | 679 | if training_args.do_eval: 680 | max_target_length = data_args.val_max_target_length 681 | preprocess_function_kwargs = preprocess_function_kwargs_fn() 682 | preprocess_function_kwargs["max_target_length"] = max_target_length 683 | preprocess_function_kwargs['max_source_length'] = data_args.eval_max_source_length 684 | if "validation" not in seq2seq_dataset: 685 | raise ValueError("--do_eval requires a validation dataset") 686 | logger.info("") 687 | logger.info("Validation examples before tokenization:") 688 | if input_prefix_column in column_names: 689 | logger.info(f"input_prefix #0: {seq2seq_dataset['validation'][0][input_prefix_column]}") 690 | # logger.info(f"input #0: {seq2seq_dataset['validation'][0]['input']}") 691 | # logger.info(f"output #0: {seq2seq_dataset['validation'][0]['output']}") 692 | if input_prefix_column in column_names: 693 | logger.info(f"input_prefix #1: {seq2seq_dataset['validation'][1][input_prefix_column]}") 694 | # logger.info(f"input #1: {seq2seq_dataset['validation'][1]['input']}") 695 | # logger.info(f"output #1: {seq2seq_dataset['validation'][1]['output']}") 696 | logger.info("") 697 | untokenized_eval_dataset = seq2seq_dataset["validation"] 698 | if data_args.max_eval_samples is not None: 699 | untokenized_eval_dataset = untokenized_eval_dataset.select(range(data_args.max_eval_samples)) 700 | if model_args.drop_duplicates_in_eval is True: 701 | untokenized_eval_dataset = drop_duplicates_in_input(untokenized_eval_dataset) 702 | untokenized_eval_dataset_orig = untokenized_eval_dataset 703 | assert training_args.eval_fraction > 0 704 | n = len(untokenized_eval_dataset) 705 | training_args = replace(training_args, eval_fraction = min(training_args.eval_fraction, n)) 706 | if training_args.eval_fraction != 1: 707 | if training_args.eval_fraction > 1: 708 | assert training_args.eval_fraction == int(training_args.eval_fraction) 709 | logger.info(f'using predetermined absolute samples from eval set ({training_args.eval_fraction} )') 710 | training_args = replace(training_args, eval_fraction = training_args.eval_fraction / n) 711 | indices = np.random.permutation(n)[:int(np.ceil(max(1, training_args.eval_fraction * n)))] 712 | untokenized_eval_dataset = type(untokenized_eval_dataset).from_dict(untokenized_eval_dataset[indices]) 713 | logger.info(f'During training, will only use {training_args.eval_fraction:.3%} samples of the eval set ' 714 | f'which amounts to {len(untokenized_eval_dataset)} out of {n} samples') 715 | 716 | eval_dataset = process_eval_set(data_args, preprocess_function_kwargs, training_args, untokenized_eval_dataset) 717 | eval_dataset_orig = eval_dataset 718 | if training_args.eval_fraction < 1: 719 | eval_dataset_orig = process_eval_set(data_args, preprocess_function_kwargs, training_args, 720 | untokenized_eval_dataset_orig) 721 | 722 | if training_args.do_predict: 723 | 
max_target_length = data_args.val_max_target_length 724 | preprocess_function_kwargs = preprocess_function_kwargs_fn() 725 | preprocess_function_kwargs["max_target_length"] = max_target_length 726 | preprocess_function_kwargs['max_source_length'] = data_args.eval_max_source_length 727 | if "test" not in seq2seq_dataset: 728 | raise ValueError("--do_predict requires a test dataset") 729 | untokenized_predict_dataset = seq2seq_dataset["test"] 730 | if data_args.max_predict_samples is not None: 731 | untokenized_predict_dataset = untokenized_predict_dataset.select(range(data_args.max_predict_samples)) 732 | if model_args.drop_duplicates_in_eval is True: 733 | untokenized_predict_dataset = drop_duplicates_in_input(untokenized_predict_dataset) 734 | 735 | if output_column in untokenized_predict_dataset.column_names: 736 | untokenized_predict_dataset = untokenized_predict_dataset.remove_columns(output_column) 737 | 738 | if data_args.test_start_ind is not None: 739 | sind = data_args.test_start_ind 740 | eind = -1 if data_args.test_end_ind is None else data_args.test_end_ind 741 | logger.info(f'Using only a subset of the test dataset [{sind}, {eind}]') 742 | untokenized_predict_dataset = type(untokenized_predict_dataset).from_dict(untokenized_predict_dataset[sind:eind]) 743 | 744 | with training_args.main_process_first( 745 | local=not data_args.shared_storage, desc="prediction dataset map pre-processing" 746 | ): 747 | predict_dataset = untokenized_predict_dataset.map( 748 | preprocess_function, 749 | fn_kwargs=preprocess_function_kwargs, 750 | batched=True, 751 | num_proc=data_args.preprocessing_num_workers, 752 | remove_columns=untokenized_predict_dataset.column_names, 753 | load_from_cache_file=not data_args.overwrite_cache, 754 | desc="Running tokenizer on prediction dataset", 755 | ) 756 | 757 | if data_args.preprocess_only: 758 | logger.info(f"With --preprocess_only, exiting after preprocess_on the data") 759 | exit() 760 | 761 | # Data collator 762 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 763 | pad_to = 8 if training_args.fp16 and training_args.fp16_padding else None 764 | 765 | 766 | data_collator = DataCollatorForSeq2Seq( 767 | tokenizer, 768 | model=model, 769 | label_pad_token_id=label_pad_token_id, 770 | pad_to_multiple_of=pad_to, 771 | ) 772 | 773 | # Metric 774 | compute_metrics = load_metric(data_args.metric_names, **locals()) 775 | compute_metrics = load_extra_metrics(data_args.extra_metrics, compute_metrics) 776 | 777 | # Initialize our Trainer 778 | trainer = CustomTrainer( 779 | model=model, 780 | args=training_args, 781 | train_dataset=train_dataset if training_args.do_train else None, 782 | eval_dataset=eval_dataset if training_args.do_eval else None, 783 | untokenized_eval_dataset=untokenized_eval_dataset if training_args.do_eval else None, 784 | tokenizer=tokenizer, 785 | data_collator=data_collator, 786 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 787 | output_dir=training_args.output_dir, 788 | data_args=data_args, 789 | callbacks=[EarlyStoppingCallback(early_stopping_patience=data_args.patience)] if data_args.patience is not None else None, 790 | ) 791 | 792 | # setup_cometml_trainer_callback(trainer) 793 | 794 | # Training 795 | if training_args.do_train: 796 | checkpoint = None 797 | if training_args.resume_from_checkpoint is not None: 798 | checkpoint = training_args.resume_from_checkpoint 799 | elif last_checkpoint is not None: 800 | checkpoint = last_checkpoint # look for 
checkpoints in the outdir 801 | 802 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 803 | logger.info('Done training') 804 | trainer.save_model() # Saves the tokenizer too for easy upload 805 | 806 | metrics = train_result.metrics 807 | max_train_samples = ( 808 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 809 | ) 810 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 811 | 812 | trainer.log_metrics("train", metrics) 813 | trainer.save_metrics("train", metrics) 814 | trainer.save_state() 815 | 816 | # Evaluation 817 | results = {} 818 | if training_args.do_eval: 819 | logger.info("*** Evaluate ***") 820 | 821 | if training_args.eval_fraction < 1: 822 | logger.info('setting the eval set back to the full one') 823 | trainer.eval_dataset = eval_dataset_orig 824 | trainer._untokenized_eval_dataset = untokenized_eval_dataset_orig 825 | 826 | metrics = trainer.evaluate(metric_key_prefix="eval", use_cache=True, length_penalty=data_args.length_penalty) 827 | logger.info('Done evaluating') 828 | 829 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 830 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 831 | 832 | trainer.log_metrics("eval", metrics) 833 | trainer.save_metrics("eval", metrics) 834 | 835 | if training_args.do_predict: 836 | logger.info("*** Predict ***") 837 | trainer.args.predict_with_generate = True # during prediction, we don't have labels 838 | 839 | # load last (and best) model, or the one specified if any 840 | logger.info("*** Loading model weights before the prediction ***") 841 | last_checkpoint = model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else _detect_last_checkpoint(training_args) 842 | if last_checkpoint is not None and os.path.isdir(last_checkpoint): 843 | logger.info(f'Loading weights from {last_checkpoint} for the prediction') 844 | state_dict = torch.load(os.path.join(last_checkpoint, WEIGHTS_NAME), map_location="cpu") 845 | # If the model is on the GPU, it still works! 
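# Note: the call that would copy `state_dict` into the live model is commented out on the
# next line, so the weights loaded above are read and then released without being applied.
# If applying them here is desired, one plain (hypothetical) option is
# `trainer.model.load_state_dict(state_dict)` before the `del state_dict` below.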
846 | # trainer._load_state_dict_in_model(state_dict) 847 | # release memory 848 | del state_dict 849 | logger.info("*** Done loading weights ***") 850 | elif training_args.do_train: 851 | raise ValueError('Could not find a model to load for prediction') 852 | else: 853 | logger.info(f'Using {model_args.model_name_or_path} as the model for the prediction') 854 | 855 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", use_cache=True) 856 | logger.info('Done predicting') 857 | 858 | metrics = predict_results.metrics 859 | max_predict_samples = ( 860 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 861 | ) 862 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 863 | 864 | trainer.log_metrics("predict", metrics) 865 | trainer.save_metrics("predict", metrics) 866 | 867 | if trainer.is_world_process_zero(): 868 | if training_args.predict_with_generate: 869 | id_to_prediction = {} 870 | for i, instance in enumerate(untokenized_predict_dataset): 871 | id_to_prediction[instance["id"]] = predict_results.predictions[i] 872 | predictions = decode(id_to_prediction, tokenizer, data_args) 873 | output_name = "generated_predictions.json" 874 | if data_args.test_start_ind is not None: 875 | output_name = f"generated_predictions_{data_args.test_start_ind}_{data_args.test_end_ind}.json" 876 | output_prediction_file = os.path.join(training_args.output_dir, output_name) 877 | with open(output_prediction_file, "w") as writer: 878 | json.dump(predictions, writer, indent=4) 879 | 880 | if training_args.push_to_hub: 881 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 882 | if data_args.dataset_name is not None: 883 | kwargs["dataset_tags"] = data_args.dataset_name 884 | if data_args.dataset_config_name is not None: 885 | kwargs["dataset_args"] = data_args.dataset_config_name 886 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 887 | else: 888 | kwargs["dataset"] = data_args.dataset_name 889 | 890 | trainer.push_to_hub(**kwargs) 891 | 892 | return results 893 | 894 | def _detect_last_checkpoint(training_args): 895 | last_checkpoint = None 896 | if os.path.isdir(training_args.output_dir) and training_args.do_train: 897 | if not training_args.overwrite_output_dir: 898 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 899 | 900 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 901 | logger.info( 902 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 903 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 
904 | ) 905 | return last_checkpoint 906 | 907 | def process_eval_set(data_args, preprocess_function_kwargs, training_args, untokenized_eval_dataset): 908 | with training_args.main_process_first( 909 | local=not data_args.shared_storage, desc="validation dataset map pre-processing" 910 | ): 911 | eval_dataset = untokenized_eval_dataset.map( 912 | preprocess_function, 913 | fn_kwargs=preprocess_function_kwargs, 914 | batched=True, 915 | num_proc=data_args.preprocessing_num_workers, 916 | remove_columns=untokenized_eval_dataset.column_names, 917 | load_from_cache_file=not data_args.overwrite_cache, 918 | desc="Running tokenizer on validation dataset", 919 | ) 920 | return eval_dataset 921 | 922 | 923 | def _get_dataset(data_args, model_args, training_args): 924 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 925 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 926 | # (the dataset will be downloaded automatically from the datasets Hub). 927 | # 928 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 929 | # summaries (unless you specify column names for this with the `input_column` and `output_column` arguments). 930 | # 931 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 932 | # download the dataset. 933 | data_files = None 934 | if data_args.train_file is not None or data_args.validation_file is not None or data_args.test_file is not None: 935 | data_files = {} 936 | if data_args.train_file is not None: 937 | data_files["train"] = data_args.train_file 938 | if data_args.validation_file is not None: 939 | data_files["validation"] = data_args.validation_file 940 | if data_args.test_file is not None: 941 | data_files["test"] = data_args.test_file 942 | # Downloading and loading a dataset from the hub/local script. 943 | seq2seq_dataset = load_dataset( 944 | data_args.dataset_name, 945 | data_args.dataset_config_name, 946 | verification_mode='no_checks', 947 | cache_dir=model_args.cache_dir, 948 | data_dir=data_args.data_dir, 949 | data_files=data_files, 950 | download_mode=data_args.download_mode, 951 | use_auth_token=training_args.use_auth_token 952 | ) 953 | if training_args.do_train: 954 | training_args.apply_overrides(len(seq2seq_dataset['train'])) 955 | if data_args.evaluate_on_training_data: 956 | seq2seq_dataset["validation"] = seq2seq_dataset["train"] 957 | 958 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 959 | # https://huggingface.co/docs/datasets/loading_datasets.html. 
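# Illustrative row shape for a local --train_file / --validation_file / --test_file in
# jsonlines form, using the default column names consumed later in this script ("input",
# "output", optional "input_prefix", and an "id" used when writing predictions); the field
# values here are placeholders only:
#   {"id": "ex-0", "input_prefix": "What is the report about?", "input": "<long source document>", "output": "<reference summary>"}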
960 | 961 | return seq2seq_dataset 962 | 963 | 964 | def set_up_logging(training_args): 965 | logging.basicConfig( 966 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 967 | datefmt="%m/%d/%Y %H:%M:%S", 968 | handlers=[logging.StreamHandler(sys.stdout)], 969 | ) 970 | log_level = training_args.get_process_log_level() 971 | logger.setLevel(log_level) 972 | datasets.utils.logging.set_verbosity(log_level) 973 | transformers.utils.logging.set_verbosity(log_level) 974 | transformers.utils.logging.enable_default_handler() 975 | transformers.utils.logging.enable_explicit_format() 976 | # Log on each process the small summary: 977 | logger.warning( 978 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 979 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 980 | ) 981 | logger.info(f"Training/evaluation parameters {training_args}") 982 | 983 | def extract_oracle_sent_batch(examples, max_length, tokenizer, rouge_scorer): 984 | items = examples.data.items() 985 | keys = [item[0] for item in items] 986 | values = [item[1] for item in items] 987 | extracted = {k: [] for k in keys} 988 | input_str = 'input' 989 | 990 | for ex in zip(*values): 991 | ex = dict(zip(keys, ex)) 992 | ex_input = ex[input_str] 993 | extracted_input = extract_oracle_sentences(ex_input, ex['output'], max_length, tokenizer, rouge_scorer) 994 | extracted[input_str].append(extracted_input) 995 | for k in set(keys) - {input_str}: 996 | extracted[k].append(ex[k]) 997 | return extracted 998 | 999 | def extract_oracle_sentences(input_sequence, output, max_length, tokenizer, rouge_scorer, criterion='rouge/geometric_mean'): 1000 | sentences = nltk.sent_tokenize(input_sequence) 1001 | selected_mask = [False for _ in sentences] 1002 | 1003 | max_rouge = 0.0 1004 | joined_selection = '' 1005 | counter = 0 1006 | while len(tokenizer(joined_selection)) < max_length and counter < 100: 1007 | cur_max_rouge = max_rouge 1008 | max_index = -1 1009 | 1010 | cur_candidate_indices = [] 1011 | cur_candidates = [] 1012 | for i in range(len(sentences)): 1013 | if selected_mask[i]: 1014 | # We already selected this sentence 1015 | continue 1016 | candidate_mask = list(selected_mask) 1017 | candidate_mask[i] = True 1018 | candidate_prediction = ' '.join(sent for sent, mask in zip(sentences, candidate_mask) if mask) 1019 | cur_candidates.append(candidate_prediction) 1020 | cur_candidate_indices.append(i) 1021 | 1022 | rouge = rouge_scorer.compute(predictions=cur_candidates, references=[[output]] * len(cur_candidates), use_aggregator=False) 1023 | aggregated_rouge_types = [s1.fmeasure * s2.fmeasure * sL.fmeasure for s1, s2, sL in zip(rouge['rouge1'], rouge['rouge2'], rouge['rougeLsum'])] 1024 | max_index = np.argmax(aggregated_rouge_types) 1025 | cur_max_rouge = aggregated_rouge_types[max_index] 1026 | 1027 | if max_rouge >= cur_max_rouge: 1028 | # No sentence improves the score 1029 | break 1030 | 1031 | selected_mask[cur_candidate_indices[max_index]] = True 1032 | max_rouge = cur_max_rouge 1033 | joined_selection = ' '.join(sent for sent, mask in zip(sentences, selected_mask) if mask) 1034 | counter += 1 1035 | 1036 | return joined_selection 1037 | 1038 | 1039 | def chunk_dataset_function(examples, chunk_size): 1040 | input_ids_str = 'input_ids' 1041 | attention_mask_str = 'attention_mask' 1042 | items = examples.data.items() 1043 | keys = [item[0] for item in items] 1044 | values = [item[1] for item in items] 1045 | chunked = 
{k: [] for k in keys} 1046 | for ex in zip(*values): 1047 | ex = dict(zip(keys, ex)) 1048 | for i in range(0, len(ex[input_ids_str]), chunk_size): 1049 | chunked_input_ids_st = ex[input_ids_str][i:i + chunk_size] 1050 | chunked_attention_mask = ex[attention_mask_str][i:i + chunk_size] 1051 | 1052 | if sum(chunked_attention_mask) < 10: 1053 | continue 1054 | chunked[input_ids_str].append(chunked_input_ids_st) 1055 | chunked[attention_mask_str].append(chunked_attention_mask) 1056 | for k in set(keys) - {input_ids_str, attention_mask_str}: 1057 | chunked[k].append(ex[k]) 1058 | return chunked 1059 | 1060 | 1061 | 1062 | def preprocess_function( 1063 | examples, 1064 | tokenizer, 1065 | prefix, 1066 | input_column, 1067 | input_prefix_column, 1068 | output_column, 1069 | max_source_length, 1070 | max_prefix_length, 1071 | max_target_length, 1072 | prefix_sep, 1073 | padding, 1074 | ignore_pad_token_for_loss, 1075 | assign_zero_to_too_long_val_examples, 1076 | trim_very_long_strings, 1077 | pad_prefix 1078 | ): 1079 | if not isinstance(examples[input_column][0], str): 1080 | model_inputs = _preprocess_tokenized_inputs() 1081 | else: 1082 | model_inputs = _preprocess_raw_inputs(assign_zero_to_too_long_val_examples, examples, input_column, input_prefix_column, 1083 | max_source_length, padding, prefix, tokenizer, trim_very_long_strings, max_prefix_length, 1084 | prefix_sep, pad_prefix) 1085 | 1086 | _preprocess_targets(examples, ignore_pad_token_for_loss, max_target_length, model_inputs, output_column, padding, tokenizer) 1087 | model_inputs["length"] = [len(x) for x in model_inputs["input_ids"]] 1088 | return model_inputs 1089 | 1090 | 1091 | def _preprocess_raw_inputs(assign_zero_to_too_long_val_examples, examples, input_column, input_prefix_column, 1092 | max_source_length, padding, prefix, tokenizer, trim_very_long_strings, max_prefix_length, 1093 | prefix_sep, pad_prefix): 1094 | inputs = examples[input_column] 1095 | 1096 | # the given prefix is what used in models like T5 (e.g. 
"summarize: ") 1097 | # if prefix exists, it is added to the input_prefixes 1098 | if input_prefix_column in examples.keys(): 1099 | input_prefixes = [inp + prefix_sep for inp in examples[input_prefix_column]] 1100 | if prefix != "": 1101 | input_prefixes = [prefix + inp for inp in input_prefixes] 1102 | elif prefix != "": 1103 | inputs = [prefix + inp for inp in inputs] 1104 | 1105 | # tokenize the input prefix if it exists 1106 | model_prefix_inputs = None 1107 | if input_prefix_column in examples.keys(): 1108 | if trim_very_long_strings: 1109 | input_prefixes = [inp[: max_prefix_length * 7] for inp in input_prefixes] 1110 | if pad_prefix: 1111 | model_prefix_inputs = tokenizer(input_prefixes, max_length=max_prefix_length, padding='max_length', truncation=True) 1112 | else: 1113 | # for led, we do not pad the prefix 1114 | model_prefix_inputs = tokenizer(input_prefixes, max_length=max_source_length, padding='do_not_pad', truncation=True) 1115 | 1116 | if trim_very_long_strings: 1117 | inputs = [inp[: max_source_length * 7] for inp in inputs] 1118 | model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True) 1119 | 1120 | if max_source_length is not None and assign_zero_to_too_long_val_examples: 1121 | model_inputs_untrimmed = tokenizer(inputs) 1122 | model_inputs["not_valid_for_eval"] = [ 1123 | len(token_ids) > max_source_length for token_ids in model_inputs_untrimmed["input_ids"] 1124 | ] 1125 | else: 1126 | model_inputs["not_valid_for_eval"] = [False] * len(model_inputs["input_ids"]) 1127 | 1128 | # now, combine the concat prefix to the input, trimming it to max_source_length if given 1129 | if model_prefix_inputs is not None: 1130 | max_source_length = max_source_length or -1 1131 | model_inputs['input_ids'] = [(inp1+inp2)[:max_source_length] for inp1, inp2 1132 | in zip(model_prefix_inputs['input_ids'], model_inputs['input_ids'])] 1133 | model_inputs['attention_mask'] = [(inp1+inp2)[:max_source_length] for inp1, inp2 1134 | in zip(model_prefix_inputs['attention_mask'], model_inputs['attention_mask'])] 1135 | # add prefix_length 1136 | if pad_prefix: 1137 | # no need to go over them as they will all be of the same length 1138 | model_inputs['prefix_length'] = [max_prefix_length] * len(model_inputs['input_ids']) 1139 | else: 1140 | model_inputs['prefix_length'] = [len(inp) for inp in model_prefix_inputs['input_ids']] 1141 | 1142 | return model_inputs 1143 | 1144 | def _preprocess_targets(examples, ignore_pad_token_for_loss, max_target_length, model_inputs, output_column, padding, tokenizer): 1145 | targets = examples[output_column] if output_column in examples else None 1146 | if targets is not None: 1147 | if not isinstance(targets[0], str): 1148 | if max_target_length is not None: 1149 | targets = [target[:max_target_length] for target in targets] 1150 | model_inputs["labels"] = targets 1151 | else: 1152 | # Setup the tokenizer for targets 1153 | with tokenizer.as_target_tokenizer(): 1154 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 1155 | 1156 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 1157 | # padding in the loss. 
1158 | if padding == "max_length" and ignore_pad_token_for_loss: 1159 | labels["input_ids"] = [ 1160 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 1161 | ] 1162 | 1163 | model_inputs["labels"] = labels["input_ids"] 1164 | 1165 | def load_extra_metrics(metric_names, loaded_metrics=None): 1166 | if loaded_metrics is None: 1167 | loaded_metrics = MetricCollection([]) 1168 | if metric_names is not None: 1169 | for metric_name in metric_names: 1170 | if len(metric_name) > 0: 1171 | loaded_metrics._metrics.append(HFMetricWrapper(metric_name)) 1172 | return loaded_metrics 1173 | 1174 | def _mp_fn(index): 1175 | # For xla_spawn (TPUs) 1176 | main() 1177 | 1178 | 1179 | if __name__ == "__main__": 1180 | main() 1181 | -------------------------------------------------------------------------------- /src/run_generation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) 18 | """ 19 | 20 | 21 | import argparse 22 | import inspect 23 | import logging 24 | 25 | from dataclasses import dataclass, field 26 | from typing import Tuple, List, Optional, Union 27 | 28 | import numpy as np 29 | import torch 30 | import os 31 | 32 | normal_repr = torch.Tensor.__repr__ 33 | torch.Tensor.__repr__ = lambda self: f"{self.shape}_{normal_repr(self)}" 34 | 35 | from transformers import ( 36 | AutoTokenizer, 37 | BloomForCausalLM, 38 | BloomTokenizerFast, 39 | CTRLLMHeadModel, 40 | CTRLTokenizer, 41 | GenerationMixin, 42 | GPT2LMHeadModel, 43 | GPT2Tokenizer, 44 | GPTJForCausalLM, 45 | HfArgumentParser, 46 | LlamaForCausalLM, 47 | LlamaTokenizer, 48 | OpenAIGPTLMHeadModel, 49 | OpenAIGPTTokenizer, 50 | OPTForCausalLM, 51 | TransfoXLLMHeadModel, 52 | TransfoXLTokenizer, 53 | XLMTokenizer, 54 | XLMWithLMHeadModel, 55 | XLNetLMHeadModel, 56 | XLNetTokenizer, 57 | TextStreamer, 58 | ) 59 | from transformers.modeling_outputs import CausalLMOutputWithPast 60 | 61 | from unlimiformer import Unlimiformer 62 | from random_training_unlimiformer import RandomTrainingUnlimiformer 63 | 64 | @dataclass 65 | class UnlimiformerArguments: 66 | """ 67 | Arguments pertaining to what data we are going to input our model for training and eval. 68 | """ 69 | test_unlimiformer: Optional[bool] = field( 70 | default=False, 71 | metadata={ 72 | "help": "whether to use KNN." 73 | }, 74 | ) 75 | unlimiformer_verbose: Optional[bool] = field( 76 | default=False, 77 | metadata={ 78 | "help": "whether to print KNN intermediate predictions (mostly for debugging)." 
79 | }, 80 | ) 81 | layer_begin: Optional[int] = field( 82 | default=0, 83 | metadata={"help": "The layer to begin applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. " 84 | "By default, it will be applied to all layers: [0:None]]"}, 85 | ) 86 | layer_end: Optional[int] = field( 87 | default=None, 88 | metadata={"help": "The layer to end applying KNN to. KNN will be applied to layers[knn_layer_begin:layer_end]. " 89 | "By default, it will be applied to all layers: [0:None]]"}, 90 | ) 91 | unlimiformer_chunk_overlap: Optional[float] = field( 92 | default=0.5, 93 | metadata={"help": "The fraction of overlap between input chunks"}, 94 | ) 95 | unlimiformer_chunk_size: Optional[int] = field( 96 | default=None, 97 | metadata={"help": "The size of each input chunk"}, 98 | ) 99 | unlimiformer_head_num: Optional[int] = field( 100 | default=None, 101 | metadata={"help": "The head to apply KNN to (if None, apply to all heads)"}, 102 | ) 103 | unlimiformer_exclude: Optional[bool] = field( 104 | default=False, 105 | metadata={ 106 | "help": "If True, prioritize the inputs that are **not** in the standard attention window." 107 | }, 108 | ) 109 | random_unlimiformer_training: Optional[bool] = field( 110 | default=False, 111 | ) 112 | unlimiformer_training: Optional[bool] = field( 113 | default=False, 114 | ) 115 | index_devices: Optional[List[int]] = field( 116 | default_factory=lambda: (0,), 117 | ) 118 | datastore_device: Optional[int] = field( 119 | default=0, 120 | ) 121 | use_datastore: Optional[bool] = field(default=True) 122 | flat_index: Optional[bool] = field(default=True) 123 | test_datastore: Optional[bool] = field(default=False) 124 | reconstruct_embeddings: Optional[bool] = field(default=False) 125 | gpu_datastore: Optional[bool] = field(default=True) 126 | gpu_index: Optional[bool] = field(default=True) 127 | 128 | 129 | logging.basicConfig( 130 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 131 | datefmt="%m/%d/%Y %H:%M:%S", 132 | level=logging.INFO, 133 | ) 134 | logger = logging.getLogger(__name__) 135 | 136 | MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop 137 | 138 | MODEL_CLASSES = { 139 | "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), 140 | "ctrl": (CTRLLMHeadModel, CTRLTokenizer), 141 | "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 142 | "xlnet": (XLNetLMHeadModel, XLNetTokenizer), 143 | "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), 144 | "xlm": (XLMWithLMHeadModel, XLMTokenizer), 145 | "gptj": (GPTJForCausalLM, AutoTokenizer), 146 | "bloom": (BloomForCausalLM, BloomTokenizerFast), 147 | "llama": (LlamaForCausalLM, LlamaTokenizer), 148 | "opt": (OPTForCausalLM, GPT2Tokenizer), 149 | } 150 | 151 | # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia 152 | # in https://github.com/rusiaaman/XLNet-gen#methodology 153 | # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e 154 | PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family 155 | (except for Alexei and Maria) are discovered. 156 | The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the 157 | remainder of the story. 1883 Western Siberia, 158 | a young Grigori Rasputin is asked by his father and a group of men to perform magic. 159 | Rasputin has a vision and denounces one of the men as a horse thief. Although his 160 | father initially slaps him for making such an accusation, Rasputin watches as the 161 | man is chased outside and beaten. 
Twenty years later, Rasputin sees a vision of 162 | the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, 163 | with people, even a bishop, begging for his blessing. """ 164 | 165 | 166 | def set_seed(args): 167 | np.random.seed(args.seed) 168 | torch.manual_seed(args.seed) 169 | if args.n_gpu > 0: 170 | torch.cuda.manual_seed_all(args.seed) 171 | 172 | 173 | # 174 | # Functions to prepare models' input 175 | # 176 | 177 | 178 | def prepare_ctrl_input(args, _, tokenizer, prompt_text): 179 | if args.temperature > 0.7: 180 | logger.info("CTRL typically works better with lower temperatures (and lower top_k).") 181 | 182 | encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) 183 | if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): 184 | logger.info("WARNING! You are not starting your generation from a control code so you won't get good results") 185 | return prompt_text 186 | 187 | 188 | def prepare_xlm_input(args, model, tokenizer, prompt_text): 189 | # kwargs = {"language": None, "mask_token_id": None} 190 | 191 | # Set the language 192 | use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb 193 | if hasattr(model.config, "lang2id") and use_lang_emb: 194 | available_languages = model.config.lang2id.keys() 195 | if args.xlm_language in available_languages: 196 | language = args.xlm_language 197 | else: 198 | language = None 199 | while language not in available_languages: 200 | language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ") 201 | 202 | model.config.lang_id = model.config.lang2id[language] 203 | # kwargs["language"] = tokenizer.lang2id[language] 204 | 205 | # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers 206 | # XLM masked-language modeling (MLM) models need masked token 207 | # is_xlm_mlm = "mlm" in args.model_name_or_path 208 | # if is_xlm_mlm: 209 | # kwargs["mask_token_id"] = tokenizer.mask_token_id 210 | 211 | return prompt_text 212 | 213 | 214 | def prepare_xlnet_input(args, _, tokenizer, prompt_text): 215 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX 216 | prompt_text = prefix + prompt_text 217 | return prompt_text 218 | 219 | 220 | def prepare_transfoxl_input(args, _, tokenizer, prompt_text): 221 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX 222 | prompt_text = prefix + prompt_text 223 | return prompt_text 224 | 225 | 226 | PREPROCESSING_FUNCTIONS = { 227 | "ctrl": prepare_ctrl_input, 228 | "xlm": prepare_xlm_input, 229 | "xlnet": prepare_xlnet_input, 230 | "transfo-xl": prepare_transfoxl_input, 231 | } 232 | 233 | 234 | def adjust_length_to_model(length, max_sequence_length): 235 | if length < 0 and max_sequence_length > 0: 236 | length = max_sequence_length 237 | elif 0 < max_sequence_length < length: 238 | length = max_sequence_length # No generation bigger than model size 239 | elif length < 0: 240 | length = MAX_LENGTH # avoid infinite loop 241 | return length 242 | 243 | 244 | def sparse_model_config(model_config): 245 | embedding_size = None 246 | if hasattr(model_config, "hidden_size"): 247 | embedding_size = model_config.hidden_size 248 | elif hasattr(model_config, "n_embed"): 249 | embedding_size = model_config.n_embed 250 | elif hasattr(model_config, "n_embd"): 251 | embedding_size = model_config.n_embd 252 | 253 | num_head = None 254 | if hasattr(model_config, 
"num_attention_heads"): 255 | num_head = model_config.num_attention_heads 256 | elif hasattr(model_config, "n_head"): 257 | num_head = model_config.n_head 258 | 259 | if embedding_size is None or num_head is None or num_head == 0: 260 | raise ValueError("Check the model config") 261 | 262 | num_embedding_size_per_head = int(embedding_size / num_head) 263 | if hasattr(model_config, "n_layer"): 264 | num_layer = model_config.n_layer 265 | elif hasattr(model_config, "num_hidden_layers"): 266 | num_layer = model_config.num_hidden_layers 267 | else: 268 | raise ValueError("Number of hidden layers couldn't be determined from the model config") 269 | 270 | return num_layer, num_head, num_embedding_size_per_head 271 | 272 | 273 | def generate_past_key_values(model, batch_size, seq_len): 274 | num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config) 275 | if model.config.model_type == "bloom": 276 | past_key_values = tuple( 277 | ( 278 | torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len) 279 | .to(model.dtype) 280 | .to(model.device), 281 | torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head) 282 | .to(model.dtype) 283 | .to(model.device), 284 | ) 285 | for _ in range(num_block_layers) 286 | ) 287 | else: 288 | past_key_values = tuple( 289 | ( 290 | torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head) 291 | .to(model.dtype) 292 | .to(model.device), 293 | torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head) 294 | .to(model.dtype) 295 | .to(model.device), 296 | ) 297 | for _ in range(num_block_layers) 298 | ) 299 | return past_key_values 300 | 301 | 302 | def prepare_jit_inputs(inputs, model, tokenizer): 303 | batch_size = len(inputs) 304 | dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt") 305 | dummy_input = dummy_input.to(model.device) 306 | if model.config.use_cache: 307 | dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1) 308 | dummy_input["attention_mask"] = torch.cat( 309 | [ 310 | torch.zeros(dummy_input["attention_mask"].shape[0], 1) 311 | .to(dummy_input["attention_mask"].dtype) 312 | .to(model.device), 313 | dummy_input["attention_mask"], 314 | ], 315 | -1, 316 | ) 317 | return dummy_input 318 | 319 | 320 | class _ModelFallbackWrapper(GenerationMixin): 321 | __slots__ = ("_optimized", "_default") 322 | 323 | def __init__(self, optimized, default): 324 | self._optimized = optimized 325 | self._default = default 326 | 327 | def __call__(self, *args, **kwargs): 328 | if kwargs["past_key_values"] is None and self._default.config.use_cache: 329 | kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0) 330 | kwargs.pop("position_ids", None) 331 | for k in list(kwargs.keys()): 332 | if kwargs[k] is None or isinstance(kwargs[k], bool): 333 | kwargs.pop(k) 334 | outputs = self._optimized(**kwargs) 335 | lm_logits = outputs[0] 336 | past_key_values = outputs[1] 337 | fixed_output = CausalLMOutputWithPast( 338 | loss=None, 339 | logits=lm_logits, 340 | past_key_values=past_key_values, 341 | hidden_states=None, 342 | attentions=None, 343 | ) 344 | return fixed_output 345 | 346 | def __getattr__(self, item): 347 | return getattr(self._default, item) 348 | 349 | def prepare_inputs_for_generation( 350 | self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs 351 | ): 352 | return 
self._default.prepare_inputs_for_generation( 353 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs 354 | ) 355 | 356 | def _reorder_cache( 357 | self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor 358 | ) -> Tuple[Tuple[torch.Tensor]]: 359 | """ 360 | This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or 361 | [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 362 | beam_idx at every generation step. 363 | """ 364 | return self._default._reorder_cache(past_key_values, beam_idx) 365 | 366 | 367 | def main(): 368 | parser = argparse.ArgumentParser() 369 | parser.add_argument( 370 | "--model_type", 371 | default=None, 372 | type=str, 373 | required=True, 374 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), 375 | ) 376 | parser.add_argument( 377 | "--model_name_or_path", 378 | default=None, 379 | type=str, 380 | required=True, 381 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()), 382 | ) 383 | 384 | parser.add_argument("--prompt", type=str, default="") 385 | parser.add_argument("--length", type=int, default=100) 386 | parser.add_argument("--num_hidden_layers", type=int, default=None) 387 | parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") 388 | 389 | parser.add_argument( 390 | "--temperature", 391 | type=float, 392 | default=1.0, 393 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling", 394 | ) 395 | parser.add_argument( 396 | "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" 397 | ) 398 | parser.add_argument("--k", type=int, default=0) 399 | parser.add_argument("--p", type=float, default=0.9) 400 | 401 | parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") 402 | parser.add_argument("--suffix", type=str, default="", help="Text added after the input.") 403 | parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.") 404 | parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") 405 | 406 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 407 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 408 | parser.add_argument("--stream_output", action="store_true") 409 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") 410 | parser.add_argument( 411 | "--fp16", 412 | action="store_true", 413 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 414 | ) 415 | parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference") 416 | 417 | # args = parser.parse_args() 418 | args, unknown_args = parser.parse_known_args() 419 | 420 | hf_parser = HfArgumentParser(UnlimiformerArguments) 421 | unlimiformer_args, unknown_unlimiformer_args = hf_parser.parse_known_args() 422 | 423 | if len(set(unknown_args) & set(unknown_unlimiformer_args)) > 0: 424 | raise ValueError(f"Unknown arguments detected: {set(unknown_args) & set(unknown_unlimiformer_args)}") 425 | 426 | args.device = torch.device("cuda" if torch.cuda.is_available() and not 
args.no_cuda else "cpu") 427 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() 428 | 429 | logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}") 430 | 431 | set_seed(args) 432 | 433 | # Initialize the model and tokenizer 434 | try: 435 | args.model_type = args.model_type.lower() 436 | model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 437 | except KeyError: 438 | raise KeyError(f"the model {args.model_type} you specified is not supported. You are welcome to add it and open a PR :)") 439 | 440 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) 441 | if tokenizer.pad_token is None: 442 | tokenizer.pad_token = tokenizer.eos_token 443 | model_kwargs = {} 444 | if args.num_hidden_layers is not None: 445 | model_kwargs["num_hidden_layers"] = args.num_hidden_layers 446 | model = model_class.from_pretrained(args.model_name_or_path, **model_kwargs) 447 | 448 | if args.fp16: 449 | model.half() 450 | model.to(args.device) 451 | 452 | max_seq_length = getattr(model.config, "max_position_embeddings", 0) 453 | args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length) 454 | logger.info(args) 455 | 456 | if unlimiformer_args.test_unlimiformer: 457 | unlimiformer_kwargs = { 458 | 'layer_begin': unlimiformer_args.layer_begin, 459 | 'layer_end': unlimiformer_args.layer_end, 460 | 'unlimiformer_head_num': unlimiformer_args.unlimiformer_head_num, 461 | 'exclude_attention': unlimiformer_args.unlimiformer_exclude, 462 | 'chunk_overlap': unlimiformer_args.unlimiformer_chunk_overlap, 463 | 'model_encoder_max_len': unlimiformer_args.unlimiformer_chunk_size, 464 | 'verbose': unlimiformer_args.unlimiformer_verbose, 'tokenizer': tokenizer, 465 | 'unlimiformer_training': unlimiformer_args.unlimiformer_training, 466 | 'use_datastore': unlimiformer_args.use_datastore, 467 | 'flat_index': unlimiformer_args.flat_index, 468 | 'test_datastore': unlimiformer_args.test_datastore, 469 | 'reconstruct_embeddings': unlimiformer_args.reconstruct_embeddings, 470 | 'gpu_datastore': unlimiformer_args.gpu_datastore, 471 | 'gpu_index': unlimiformer_args.gpu_index, 472 | 'index_devices': unlimiformer_args.index_devices, 473 | 'datastore_device': unlimiformer_args.datastore_device, 474 | } 475 | if unlimiformer_args.random_unlimiformer_training: 476 | model = RandomTrainingUnlimiformer.convert_model(model, **unlimiformer_kwargs) 477 | else: 478 | model = Unlimiformer.convert_model(model, **unlimiformer_kwargs) 479 | 480 | prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") 481 | # Check if prompt_text is a valid file name: 482 | if os.path.exists(prompt_text): 483 | with open(prompt_text, "r") as f: 484 | prompt_text = f.read() 485 | 486 | # Different models need different input formatting and/or extra arguments 487 | requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() 488 | if requires_preprocessing: 489 | prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) 490 | preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) 491 | 492 | if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: 493 | tokenizer_kwargs = {"add_space_before_punct_symbol": True} 494 | else: 495 | tokenizer_kwargs = {} 496 | 497 | encoded_prompt = tokenizer.encode( 498 | preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs 499 | ) 500 | else: 501 | # prefix = args.prefix if args.prefix else args.padding_text 502 | prompt_text =
f'{args.prefix}{prompt_text}{args.suffix}' 503 | encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt") 504 | 505 | if not unlimiformer_args.test_unlimiformer: 506 | encoded_prompt = encoded_prompt[:, -2048:] 507 | encoded_prompt = encoded_prompt.to(args.device) 508 | 509 | if encoded_prompt.size()[-1] == 0: 510 | input_ids = None 511 | else: 512 | input_ids = encoded_prompt 513 | 514 | if args.jit: 515 | jit_input_texts = ["enable jit"] 516 | jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer) 517 | torch._C._jit_set_texpr_fuser_enabled(False) 518 | model.config.return_dict = False 519 | if hasattr(model, "forward"): 520 | sig = inspect.signature(model.forward) 521 | else: 522 | sig = inspect.signature(model.__call__) 523 | jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None) 524 | traced_model = torch.jit.trace(model, jit_inputs, strict=False) 525 | traced_model = torch.jit.freeze(traced_model.eval()) 526 | traced_model(*jit_inputs) 527 | traced_model(*jit_inputs) 528 | 529 | model = _ModelFallbackWrapper(traced_model, model) 530 | 531 | model.eval() 532 | output_sequences = model.generate( 533 | input_ids=input_ids, 534 | # max_length=args.length + len(encoded_prompt[0]), 535 | max_new_tokens=args.length, 536 | temperature=args.temperature, 537 | top_k=args.k, 538 | top_p=args.p, 539 | repetition_penalty=args.repetition_penalty, 540 | do_sample=True, 541 | num_return_sequences=args.num_return_sequences, 542 | streamer=TextStreamer(tokenizer, skip_prompt=True) if args.stream_output else None, 543 | ) 544 | 545 | # Remove the batch dimension when returning multiple sequences 546 | if len(output_sequences.shape) > 2: 547 | output_sequences.squeeze_() 548 | 549 | generated_sequences = [] 550 | 551 | for generated_sequence_idx, generated_sequence in enumerate(output_sequences): 552 | print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} (input length: {input_ids.shape[-1]}) ===") 553 | generated_sequence = generated_sequence.tolist() 554 | # generated_sequence = generated_sequence[len(encoded_prompt[0]):] + tokenizer.encode(' ') + generated_sequence[:len(encoded_prompt[0])] 555 | 556 | # Decode text 557 | # text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) 558 | prompt_length = min(input_ids.shape[-1], model.unlimiformer.window_size()) if unlimiformer_args.test_unlimiformer else input_ids.shape[-1] 559 | completion = tokenizer.decode(generated_sequence[prompt_length:]) 560 | 561 | # Remove all text after the stop token 562 | # text = text[: text.find(args.stop_token) if args.stop_token else None] 563 | 564 | # Add the prompt at the beginning of the sequence. 
Remove the excess text that was used for pre-processing 565 | total_sequence = ( 566 | # prompt_text + 567 | '|||' + completion 568 | ) 569 | 570 | generated_sequences.append(total_sequence) 571 | print(total_sequence) 572 | 573 | return generated_sequences 574 | 575 | 576 | if __name__ == "__main__": 577 | main() 578 | -------------------------------------------------------------------------------- /src/usage.py: -------------------------------------------------------------------------------- 1 | from unlimiformer import Unlimiformer 2 | from random_training_unlimiformer import RandomTrainingUnlimiformer 3 | 4 | from dataclasses import dataclass, field 5 | from typing import List, Optional 6 | 7 | 8 | @dataclass 9 | class UnlimiformerArguments: 10 | """ 11 | Arguments pertaining to what data we are going to input our model for training and eval. 12 | """ 13 | test_unlimiformer: Optional[bool] = field( 14 | default=True, 15 | metadata={ 16 | "help": "whether to use KNN." 17 | }, 18 | ) 19 | unlimiformer_verbose: Optional[bool] = field( 20 | default=False, 21 | metadata={ 22 | "help": "whether to print KNN intermediate predictions (mostly for debugging)." 23 | }, 24 | ) 25 | layer_begin: Optional[int] = field( 26 | default=0, 27 | metadata={"help": "The layer to begin applying KNN to. KNN will be applied to layers[layer_begin:layer_end]. " 28 | "By default, it will be applied to all layers: [0:None]"}, 29 | ) 30 | layer_end: Optional[int] = field( 31 | default=None, 32 | metadata={"help": "The layer to end applying KNN to. KNN will be applied to layers[layer_begin:layer_end]. " 33 | "By default, it will be applied to all layers: [0:None]"}, 34 | ) 35 | unlimiformer_chunk_overlap: Optional[float] = field( 36 | default=0.5, 37 | metadata={"help": "The fraction of overlap between input chunks"}, 38 | ) 39 | unlimiformer_chunk_size: Optional[int] = field( 40 | default=None, 41 | metadata={"help": "The size of each input chunk"}, 42 | ) 43 | unlimiformer_head_num: Optional[int] = field( 44 | default=None, 45 | metadata={"help": "The head to apply KNN to (if None, apply to all heads)"}, 46 | ) 47 | unlimiformer_exclude: Optional[bool] = field( 48 | default=False, 49 | metadata={ 50 | "help": "If True, prioritize the inputs that are **not** in the standard attention window."
51 | }, 52 | ) 53 | random_unlimiformer_training: Optional[bool] = field( 54 | default=False, 55 | ) 56 | unlimiformer_training: Optional[bool] = field( 57 | default=False, 58 | ) 59 | use_datastore: Optional[bool] = field(default=False) 60 | flat_index: Optional[bool] = field(default=False) 61 | test_datastore: Optional[bool] = field(default=False) 62 | reconstruct_embeddings: Optional[bool] = field(default=False) 63 | gpu_datastore: Optional[bool] = field(default=True) 64 | gpu_index: Optional[bool] = field(default=True) 65 | 66 | 67 | 68 | # include these lines in your code somewhere before model training 69 | def training_addin(): 70 | if unlimiformer_args.test_unlimiformer: 71 | unlimiformer_kwargs = { 72 | 'layer_begin': unlimiformer_args.layer_begin, 73 | 'layer_end': unlimiformer_args.layer_end, 74 | 'unlimiformer_head_num': unlimiformer_args.unlimiformer_head_num, 75 | 'exclude_attention': unlimiformer_args.unlimiformer_exclude, 76 | 'chunk_overlap': unlimiformer_args.unlimiformer_chunk_overlap, 77 | 'model_encoder_max_len': unlimiformer_args.unlimiformer_chunk_size, 78 | 'verbose': unlimiformer_args.unlimiformer_verbose, 'tokenizer': tokenizer, 79 | 'unlimiformer_training': unlimiformer_args.unlimiformer_training, 80 | 'use_datastore': unlimiformer_args.use_datastore, 81 | 'flat_index': unlimiformer_args.flat_index, 82 | 'test_datastore': unlimiformer_args.test_datastore, 83 | 'reconstruct_embeddings': unlimiformer_args.reconstruct_embeddings, 84 | 'gpu_datastore': unlimiformer_args.gpu_datastore, 85 | 'gpu_index': unlimiformer_args.gpu_index 86 | } 87 | if unlimiformer_args.random_unlimiformer_training: 88 | model = RandomTrainingUnlimiformer.convert_model(model, **unlimiformer_kwargs) 89 | else: 90 | model = Unlimiformer.convert_model(model, **unlimiformer_kwargs) 91 | 92 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abertsch72/unlimiformer/e38b0149488636da9528c7504d2befdfb76f6d98/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | def handle_args_to_ignore(args: List[str]): 5 | indices_to_remove = [] 6 | for i, arg in enumerate(args): 7 | if "_ignore_" in arg: 8 | indices_to_remove.append(i) 9 | if not arg.startswith("-"): 10 | indices_to_remove.append(i - 1) 11 | 12 | for i in sorted(indices_to_remove, reverse=True): 13 | del args[i] 14 | -------------------------------------------------------------------------------- /src/utils/custom_hf_argument_parser.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | from typing import Tuple 5 | 6 | from transformers import HfArgumentParser 7 | from transformers.hf_argparser import DataClass 8 | 9 | 10 | class CustomHfArgumentParser(HfArgumentParser): 11 | def parse_dictionary_and_args(self) -> Tuple[DataClass, ...]: 12 | """ 13 | Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the 14 | dataclass types. 
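Config files are given as leading command-line arguments ending in `.json`; their keys are merged (conflicting keys across files raise an error), and any explicit `--flag value` arguments that follow take precedence over values read from the files.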
15 | """ 16 | args = [] 17 | data = {} 18 | for i in range(1, len(sys.argv)): 19 | if not sys.argv[i].endswith('.json'): 20 | break 21 | 22 | with open(sys.argv[i]) as f: 23 | new_data = json.load(f) 24 | conflicting_keys = set(new_data.keys()).intersection(data.keys()) 25 | if len(conflicting_keys) > 0: 26 | raise ValueError(f'There are conflicting keys in the config files: {conflicting_keys}') 27 | data.update(new_data) 28 | 29 | for k, v in data.items(): 30 | # if any options were given explicitly through the CLA then they override anything defined in the config files 31 | if f'--{k}' in sys.argv: 32 | logging.info(f'While {k}={v} was given in a config file, a manual override was set through the CLA') 33 | continue 34 | args.extend( 35 | ["--" + k, *(v if isinstance(v, list) else [str(v)])] 36 | ) # add the file arguments first so command line args has precedence 37 | args += sys.argv[i:] 38 | 39 | return self.parse_args_into_dataclasses(args=args, look_for_args_file=False) -------------------------------------------------------------------------------- /src/utils/custom_seq2seq_trainer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import time 5 | from collections import defaultdict 6 | from typing import Any, Dict, List, Optional, Tuple, Union 7 | 8 | import torch 9 | from datasets import Dataset 10 | from torch import nn 11 | from transformers.debug_utils import DebugOption 12 | from transformers.deepspeed import is_deepspeed_zero3_enabled 13 | from transformers.trainer_utils import speed_metrics 14 | 15 | from transformers.utils import logging 16 | from transformers import Seq2SeqTrainer, is_torch_tpu_available 17 | 18 | import gc 19 | 20 | if is_torch_tpu_available(check_device=False): 21 | import torch_xla.core.xla_model as xm 22 | import torch_xla.debug.metrics as met 23 | 24 | 25 | from utils.decoding import decode 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | def _clean_memory(): 31 | gc.collect() 32 | torch.cuda.empty_cache() 33 | 34 | # This custom trainer is based on the trainer defined in https://github.com/huggingface/transformers/compare/main...eladsegal:public-transformers:scrolls 35 | class CustomTrainer(Seq2SeqTrainer): 36 | def __init__( 37 | self, *args, untokenized_eval_dataset=None, data_args=None, output_dir: Optional[str] = None, **kwargs 38 | ): 39 | super().__init__(*args, **kwargs) 40 | self._untokenized_eval_dataset = untokenized_eval_dataset 41 | self._max_length = data_args.val_max_target_length 42 | self._num_beams = data_args.num_beams 43 | self._output_dir = output_dir 44 | self._data_args = data_args 45 | self.mock_predictions_to_assign_zero_metric_score = self.tokenizer.encode("TOO_MANY_INPUT_TOKENS",return_tensors="np")[0] 46 | 47 | def prediction_step( 48 | self, 49 | model: nn.Module, 50 | inputs: Dict[str, Union[torch.Tensor, Any]], 51 | prediction_loss_only: bool, 52 | ignore_keys: Optional[List[str]] = None, 53 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 54 | """ 55 | Perform an evaluation step on `model` using `inputs`. 56 | 57 | Subclass and override to inject custom behavior. 58 | 59 | Args: 60 | model (`nn.Module`): 61 | The model to evaluate. 62 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 63 | The inputs and targets of the model. 64 | 65 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 66 | argument `labels`. 
Check your model's documentation for all accepted arguments. 67 | prediction_loss_only (`bool`): 68 | Whether or not to return the loss only. 69 | 70 | Return: 71 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 72 | labels (each being optional). 73 | """ 74 | if not ("labels" in inputs or 'decoder_input_ids' in inputs): 75 | if model.training: 76 | logger.warning('When computing loss, must give labels or decoder_input_ids. ' 77 | 'If you only perform prediction, you can safely ignore this message') 78 | # This is an issue here because the input may be longer than the max-output length of the model, 79 | # and if nothing was given it will shift the input and use it to compute loss (and later discard it). 80 | # This may cause an indexing error when absolute embeddings are used (CUDA device side assert) 81 | inputs['decoder_input_ids'] = inputs['input_ids'][:,:2].clone() # dummy outputs 82 | 83 | if not self.args.predict_with_generate or prediction_loss_only: 84 | return super().prediction_step( 85 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 86 | ) 87 | 88 | has_labels = "labels" in inputs 89 | inputs = self._prepare_inputs(inputs) 90 | 91 | # XXX: adapt synced_gpus for fairscale as well 92 | gen_kwargs = self._gen_kwargs.copy() 93 | gen_kwargs["max_length"] = ( 94 | gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length 95 | ) 96 | gen_kwargs["num_beams"] = ( 97 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 98 | ) 99 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 100 | gen_kwargs["synced_gpus"] = ( 101 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 102 | ) 103 | 104 | if "attention_mask" in inputs: 105 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 106 | if "global_attention_mask" in inputs: 107 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 108 | 109 | # --------------------- addition compared to the source file -------------------- 110 | if 'prefix_length' in inputs: 111 | gen_kwargs['prefix_length'] = inputs['prefix_length'] 112 | _clean_memory() 113 | # ------------------------------------------------------------------------------ 114 | 115 | # prepare generation inputs 116 | # some encoder-decoder models can have varying encoders and thus 117 | # varying model input names 118 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 119 | generation_inputs = inputs[self.model.encoder.main_input_name] 120 | else: 121 | generation_inputs = inputs[self.model.main_input_name] 122 | 123 | # Uri: to make sure we use cache even during mid-training evaluation, where this is disabled in general: 124 | gen_kwargs['use_cache'] = True 125 | 126 | generated_tokens = self.model.generate( 127 | generation_inputs, 128 | **gen_kwargs, 129 | ) 130 | # --------------------- addition compared to the source file -------------------- 131 | _clean_memory() 132 | # ------------------------------------------------------------------------------ 133 | # in case the batch is shorter than max length, the output should be padded 134 | if generated_tokens.shape[-1] < gen_kwargs["max_length"]: 135 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 136 | 137 | if has_labels: # changed the order of
the if's here because there is no point going through the model if there are no labels to compute the loss on.. 138 | with torch.no_grad(): 139 | with self.compute_loss_context_manager(): 140 | outputs = model(**inputs) 141 | if self.label_smoother is not None: 142 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 143 | else: 144 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 145 | else: 146 | loss = None 147 | 148 | if self.args.prediction_loss_only: 149 | return (loss, None, None) 150 | 151 | if has_labels: 152 | labels = inputs["labels"] 153 | if labels.shape[-1] < gen_kwargs["max_length"]: 154 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 155 | else: 156 | labels = None 157 | 158 | return (loss, generated_tokens, labels) 159 | 160 | @property 161 | def _restart_generator(self): 162 | if getattr(self, '_is_restart_generator', False): 163 | self._is_restart_generator = False 164 | return True 165 | return False 166 | 167 | def set_restart_generator(self): 168 | self._is_restart_generator = True 169 | 170 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 171 | sampler = super()._get_train_sampler() 172 | try: 173 | if self._restart_generator: 174 | sampler.generator.manual_seed(self._initial_seed) 175 | else: 176 | self._initial_seed = sampler.generator.initial_seed() 177 | except Exception as e: 178 | logger.warning(f'Cannot save or set the seed of the generator: {e}') 179 | return sampler 180 | 181 | def _post_process_function(self, untokenized_eval_dataset, predictions): 182 | id_to_prediction = {} 183 | id_to_label_ids = defaultdict(list) 184 | 185 | assert len(untokenized_eval_dataset) == len(self.eval_dataset) 186 | 187 | for i, (instance, not_valid_for_eval) in enumerate(zip(untokenized_eval_dataset, self.eval_dataset["not_valid_for_eval"])): 188 | if not_valid_for_eval: 189 | id_to_prediction[instance["id"]] = self.mock_predictions_to_assign_zero_metric_score 190 | else: 191 | id_to_prediction[instance["id"]] = predictions[i] 192 | 193 | if "outputs" in instance: 194 | id_to_label_ids[instance["id"]] = instance["outputs"] 195 | else: 196 | id_to_label_ids[instance["id"]].append(instance["output"]) 197 | 198 | return id_to_prediction, id_to_label_ids 199 | 200 | def evaluate( 201 | self, 202 | eval_dataset: Optional[Dataset] = None, 203 | ignore_keys: Optional[List[str]] = None, 204 | metric_key_prefix: str = "eval", 205 | untokenized_eval_dataset: Optional[Dataset] = None, 206 | **gen_kwargs 207 | ) -> Dict[str, float]: 208 | """ 209 | Run evaluation and returns metrics. 210 | 211 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 212 | (pass it to the init `compute_metrics` argument). 213 | 214 | You can also subclass and override this method to inject custom behavior. 215 | 216 | Args: 217 | eval_dataset (`Dataset`, *optional*): 218 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns 219 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` 220 | method. 221 | ignore_keys (`List[str]`, *optional*): 222 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 223 | gathering predictions. 224 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 225 | An optional prefix to be used as the metrics key prefix. 
For example the metrics "bleu" will be named 226 | "eval_bleu" if the prefix is `"eval"` (default) 227 | max_length (`int`, *optional*): 228 | The maximum target length to use when predicting with the generate method. 229 | num_beams (`int`, *optional*): 230 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 231 | beam search. 232 | gen_kwargs: 233 | Additional `generate` specific kwargs. 234 | 235 | Returns: 236 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 237 | dictionary also contains the epoch number which comes from the training state. 238 | """ 239 | 240 | gen_kwargs = gen_kwargs.copy() 241 | gen_kwargs["max_length"] = ( 242 | gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length 243 | ) 244 | gen_kwargs["num_beams"] = ( 245 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 246 | ) 247 | self._gen_kwargs = gen_kwargs 248 | 249 | self._memory_tracker.start() 250 | 251 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 252 | # ----------------------------------- Added ----------------------------------- 253 | untokenized_eval_dataset = ( 254 | self._untokenized_eval_dataset if untokenized_eval_dataset is None else untokenized_eval_dataset 255 | ) 256 | compute_metrics = self.compute_metrics 257 | self.compute_metrics = None 258 | # ----------------------------------------------------------------------------- 259 | 260 | start_time = time.time() 261 | 262 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop 263 | try: 264 | output = eval_loop( 265 | eval_dataloader, 266 | description="Evaluation", 267 | # No point gathering the predictions if there are no metrics, otherwise we defer to 268 | # self.args.prediction_loss_only 269 | prediction_loss_only=None, # MODIFIED since we need the predictions 270 | ignore_keys=ignore_keys, 271 | metric_key_prefix=metric_key_prefix, 272 | ) 273 | finally: 274 | # ----------------------------------- Added ----------------------------------- 275 | # revert the compute metrics back 276 | self.compute_metrics = compute_metrics 277 | # ----------------------------------------------------------------------------- 278 | 279 | # ----------------------------------- Added ----------------------------------- 280 | # compute our metrics 281 | if output.predictions is not None: 282 | eval_preds = self._post_process_function(untokenized_eval_dataset, output.predictions) 283 | 284 | if self._output_dir is not None and self.is_world_process_zero(): 285 | predictions = decode(eval_preds[0], self.tokenizer, self._data_args) 286 | output_prediction_file = os.path.join( 287 | self._output_dir, f"generated_predictions_eval_{self.state.global_step}.json" 288 | ) 289 | with open(output_prediction_file, "w") as writer: 290 | json.dump(predictions, writer, indent=4) 291 | 292 | output_labels_file = os.path.join( 293 | self._output_dir, f"eval_labels.json" 294 | ) 295 | if not os.path.isfile(output_labels_file): 296 | with open(output_labels_file, "w") as writer: 297 | json.dump(eval_preds[1], writer, indent=4) 298 | 299 | if self.compute_metrics is not None: 300 | output.metrics.update(self.compute_metrics(*eval_preds)) 301 | 302 | # Prefix all keys with metric_key_prefix + '_' 303 | for key in list(output.metrics.keys()): 304 | if not key.startswith(f"{metric_key_prefix}_"): 305 | 
output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(key) 306 | # ----------------------------------------------------------------------------- 307 | 308 | total_batch_size = self.args.eval_batch_size * self.args.world_size 309 | output.metrics.update( 310 | speed_metrics( 311 | metric_key_prefix, 312 | start_time, 313 | num_samples=output.num_samples, 314 | num_steps=math.ceil(output.num_samples / total_batch_size), 315 | ) 316 | ) 317 | 318 | self.log(output.metrics) 319 | 320 | if DebugOption.TPU_METRICS_DEBUG in self.args.debug: 321 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 322 | xm.master_print(met.metrics_report()) 323 | 324 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) 325 | 326 | self._memory_tracker.stop_and_update_metrics(output.metrics) 327 | 328 | return output.metrics 329 | -------------------------------------------------------------------------------- /src/utils/decoding.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | 4 | 5 | def decode(id_to_something, tokenizer=None, data_args=None): 6 | decode_fn = None 7 | switch_case = None 8 | elem = next(iter(id_to_something.values())) 9 | if isinstance(elem, str): 10 | switch_case = -1 11 | decode_fn = lambda text: text.strip() 12 | elif isinstance(elem, list) and not isinstance(elem[0], int): 13 | if isinstance(elem[0], str): 14 | switch_case = 0 15 | decode_fn = lambda texts: [text.strip() for text in texts] 16 | else: 17 | switch_case = 1 18 | decode_fn = lambda token_ids_list: [ 19 | text.strip() 20 | for text in partial( 21 | tokenizer.batch_decode, skip_special_tokens=True, clean_up_tokenization_spaces=True 22 | )(token_ids_list) 23 | ] 24 | else: 25 | switch_case = 2 26 | decode_fn = lambda token_ids: partial( 27 | tokenizer.decode, skip_special_tokens=True, clean_up_tokenization_spaces=True 28 | )(token_ids).strip() 29 | 30 | id_to_text = {} 31 | for id_, something in id_to_something.items(): 32 | if switch_case == -1 or switch_case == 0: 33 | obj_to_decode = something 34 | else: 35 | if data_args is None: 36 | data_args = {} 37 | if not isinstance(data_args, dict): 38 | data_args = vars(data_args) 39 | if data_args.get("ignore_pad_token_for_loss", True): 40 | # Replace -100 in the token_ids as we can't decode them. 
41 | if switch_case == 1: 42 | token_ids_list = something 43 | for i in range(len(token_ids_list)): 44 | token_ids_list[i] = _replace_padding(token_ids_list[i], tokenizer.pad_token_id) 45 | obj_to_decode = token_ids_list 46 | elif switch_case == 2: 47 | token_ids = something 48 | token_ids = _replace_padding(token_ids, tokenizer.pad_token_id) 49 | obj_to_decode = token_ids 50 | else: 51 | obj_to_decode = something 52 | 53 | id_to_text[id_] = decode_fn(obj_to_decode) 54 | 55 | return id_to_text 56 | 57 | 58 | def _replace_padding(token_ids: np.array, pad_token_id): 59 | return np.where(token_ids != -100, token_ids, pad_token_id) 60 | -------------------------------------------------------------------------------- /src/utils/duplicates.py: -------------------------------------------------------------------------------- 1 | def drop_duplicates_in_input(untokenized_dataset): 2 | indices_to_keep = [] 3 | id_to_idx = {} 4 | outputs = [] 5 | for i, (id_, output) in enumerate(zip(untokenized_dataset["id"], untokenized_dataset["output"])): 6 | if id_ in id_to_idx: 7 | outputs[id_to_idx[id_]].append(output) 8 | continue 9 | indices_to_keep.append(i) 10 | id_to_idx[id_] = len(outputs) 11 | outputs.append([output]) 12 | untokenized_dataset = untokenized_dataset.select(indices_to_keep).flatten_indices() 13 | untokenized_dataset = untokenized_dataset.remove_columns("output") 14 | untokenized_dataset = untokenized_dataset.add_column("outputs", outputs) 15 | return untokenized_dataset 16 | -------------------------------------------------------------------------------- /src/utils/override_training_args.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | import torch.cuda 6 | from transformers.utils import logging 7 | 8 | sys.path.insert(0, os.getcwd()) 9 | 10 | from dataclasses import dataclass, field 11 | 12 | from transformers.trainer_utils import IntervalStrategy 13 | from transformers import Seq2SeqTrainingArguments 14 | 15 | logger = logging.get_logger('swed_logger') 16 | 17 | @dataclass 18 | class TrainingOverridesArguments(Seq2SeqTrainingArguments): 19 | """ 20 | To use these overrides, evaluation_strategy must be IntervalStrategy.STEPS 21 | """ 22 | eval_steps_override: float = field(default=0, metadata={"help": "a fraction, to set the eval_steps w.r.t. the number of steps in " 23 | "a single epoch. changes eval_steps. 0 to disable (default)"}) 24 | save_steps_override: float = field(default=0, metadata={"help": "a fraction, to set the save_steps w.r.t. the number of steps in " 25 | "a single epoch. changes save_steps. must be a multiple of eval_steps" 26 | " (or eval_steps_override if given). 0 to disable (default)"}) 27 | 28 | eval_fraction: float = field(default=1, metadata={ 29 | "help": "A float in (0,1] that corresponds to how much of the eval set to use during evaluations " 30 | "(same subset all the time) or an integer >= 2 which amounts to the absolute number of evaluation " 31 | "samples to use. 1. to disable it and use the entire eval set "}) 32 | 33 | use_auth_token: bool = field( 34 | default=False, 35 | metadata={ 36 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 37 | "with private models).
If AUTH_TOKEN is set as an environment variable, would use that" 38 | }, 39 | ) 40 | 41 | fp16_padding: bool = field( 42 | default=False, 43 | metadata={"help": "Whether to use padding for fp16"}, 44 | ) 45 | 46 | 47 | def __post_init__(self): 48 | super(TrainingOverridesArguments, self).__post_init__() 49 | if self.eval_steps_override > 0 or self.save_steps_override > 0: 50 | if self.evaluation_strategy != IntervalStrategy.STEPS: 51 | raise ValueError( 52 | f"using eval/save steps override requires evaluation strategy to be {IntervalStrategy.STEPS}" 53 | ) 54 | if self.save_steps_override == 0 or self.eval_steps_override == 0: 55 | raise ValueError( 56 | f"using eval/save steps override requires both overrides to be non zero" 57 | ) 58 | diff = (self.save_steps_override / self.eval_steps_override) % 1 59 | if min(1-diff, diff) > 1e-5: # we do it like that to support fractions modulo as well, with loss of precision 60 | raise ValueError( 61 | f"using eval/save steps override requires save steps override to be a multiple of eval_steps_override" 62 | ) 63 | if self.use_auth_token and 'AUTH_TOKEN' in os.environ: 64 | self.use_auth_token = os.getenv('AUTH_TOKEN') 65 | 66 | @property 67 | def effective_batch_size(self): 68 | if not hasattr(self, '_ebs'): 69 | n_gpu = self.n_gpu if torch.cuda.is_available() else 1 # may be on cpu 70 | self._ebs = self.per_device_train_batch_size * self.gradient_accumulation_steps * n_gpu 71 | logger.warning(f'Training with {self.per_device_train_batch_size} per_device_train_size, {self.n_gpu} gpus and ' 72 | f'{self.gradient_accumulation_steps} gradient accumulation steps, resulting in {self._ebs} effective batch size') 73 | return self._ebs 74 | 75 | def apply_overrides(self, dataset_size): 76 | # Uri: 77 | return 78 | 79 | if self.eval_steps_override == 0: 80 | return 81 | es, ss = self.eval_steps, self.save_steps 82 | total_steps_per_epoch = dataset_size / self.effective_batch_size # note that this may not be an integer 83 | eval_steps = int(total_steps_per_epoch * self.eval_steps_override) 84 | if eval_steps >= self.logging_steps: 85 | if eval_steps % self.logging_steps != 0: 86 | logger.warning(f'Eval steps override would result in eval every {eval_steps} steps, but it is not a ' 87 | f'multiple of logging steps ({self.logging_steps}) so changing to ' 88 | f'{eval_steps + self.logging_steps - eval_steps % self.logging_steps}') 89 | eval_steps = eval_steps + self.logging_steps - eval_steps % self.logging_steps 90 | elif eval_steps < self.logging_steps: 91 | logger.warning(f'Eval steps override would result in eval every {eval_steps} steps, but it is not a ' 92 | f'multiple of logging steps ({self.logging_steps}) so changing to {self.logging_steps}') 93 | eval_steps = self.logging_steps 94 | self.eval_steps = eval_steps 95 | 96 | save_steps = int(total_steps_per_epoch * self.save_steps_override) 97 | if save_steps < eval_steps or save_steps % eval_steps != 0: 98 | logger.warning(f'Save steps override would result in eval every {save_steps} steps, but it is not a ' 99 | f'multiple of eval steps ({eval_steps}) so changing to ' 100 | f'{save_steps + eval_steps - save_steps % self.eval_steps}') 101 | save_steps = save_steps + eval_steps - save_steps % self.eval_steps 102 | self.save_steps = save_steps 103 | 104 | logger.warning(f'Using overrides with dataset of size {dataset_size} and effective batch size of ' 105 | f'{self.effective_batch_size}, moving from (eval_steps, save_steps) ' 106 | f'of {(es, ss)} to {(self.eval_steps, self.save_steps)}') 107 | 
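# Worked example (hypothetical numbers; note the early return above currently disables apply_overrides): with dataset_size=6400, per_device_train_batch_size=8, gradient_accumulation_steps=4 and a single GPU, the effective batch size is 32, giving 200 steps per epoch; eval_steps_override=0.1 would then set eval_steps to 20 (already a multiple of logging_steps=10), and save_steps_override=0.2 would set save_steps to 40.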
--------------------------------------------------------------------------------