├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── examples ├── README.md ├── __init__.py ├── seq2seq │ ├── README.md │ ├── __init__.py │ ├── configs │ │ ├── data │ │ │ ├── contract_nli.json │ │ │ ├── gov_report.json │ │ │ ├── hotpotqa.json │ │ │ ├── hotpotqa_second_only.json │ │ │ ├── narative_qa.json │ │ │ ├── qasper.json │ │ │ ├── qmsum.json │ │ │ ├── quality.json │ │ │ ├── squad.json │ │ │ ├── squad_ordered_distractors.json │ │ │ ├── squad_shuffled_distractors.json │ │ │ └── summ_screen_fd.json │ │ ├── model │ │ │ ├── bart_base_sled.json │ │ │ └── bart_large_sled.json │ │ └── training │ │ │ └── base_training_args.json │ ├── metrics │ │ ├── __init__.py │ │ └── metrics.py │ ├── run.py │ └── utils │ │ ├── __init__.py │ │ ├── config.py │ │ ├── custom_hf_argument_parser.py │ │ ├── custom_seq2seq_trainer.py │ │ ├── decoding.py │ │ ├── duplicates.py │ │ └── override_training_args.py └── usage_example.py ├── setup.py ├── sled ├── __init__.py ├── configuration_sled.py ├── modeling_sled.py ├── tokenization_sled.py └── tokenization_sled_fast.py └── tests ├── configs ├── bart_base_sled.json └── t5_base_sled.json └── test_sled.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Maor Ivgi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune tests/ 2 | prune examples/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SLED 2 | The official repository for Efficient Long-Text Understanding with Short-Text Models [(Ivgi et al., 2022)](https://arxiv.org/abs/2208.00748.pdf), to appear in Transactions of the Association for Computational Linguistics (TACL) 2023. 3 | 4 | SLED models use pretrained, short-range encoder-decoder models and apply them over 5 | long-text inputs by splitting the input into multiple overlapping chunks, encoding each chunk independently, and performing fusion-in-decoder. 6 | 7 | 8 | ## Data 9 | The data for this paper is hosted on the dataset hub [here](https://huggingface.co/datasets/tau/sled). 10 | It is based on the [SCROLLS dataset](https://huggingface.co/datasets/tau/scrolls) ([paper](https://arxiv.org/pdf/2201.03533.pdf)), the [SQuAD 1.1 dataset](https://huggingface.co/datasets/squad) ([paper](https://arxiv.org/pdf/1606.05250.pdf)) and the [HotpotQA dataset](https://huggingface.co/datasets/hotpot_qa) ([paper](https://arxiv.org/pdf/1809.09600.pdf)). 11 | It doesn't contain any unpublished data, but includes the configuration needed for the paper.
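To see which subsets are available before loading one, you can list the dataset's configurations (a minimal sketch; `get_dataset_config_names` is a standard utility of the `datasets` library, and the names correspond to the data configs under `examples/seq2seq/configs/data`, e.g. `qasper` or `gov_report`):
```python
from datasets import get_dataset_config_names

# List the available configurations of the tau/sled dataset (e.g. qasper, gov_report, ...)
print(get_dataset_config_names("tau/sled"))
```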
12 | 13 | Usage example: 14 | ```python 15 | from datasets import load_dataset 16 | qasper = load_dataset("tau/sled","qasper") 17 | ``` 18 | 19 | ## Installation 20 | 21 | Make sure to install PyTorch according to your machine spec. See installation options [here](https://pytorch.org/get-started/locally/). 22 | 23 | Installing SLED is easy with pip: 24 | ``` 25 | pip install py-sled 26 | ``` 27 | 28 | Some backbone models require additional dependencies. If you wish to work with T5, for example, you can install using: 29 | ``` 30 | pip install py-sled[t5] 31 | ``` 32 | 33 | If you wish to run the examples, install the required dependencies with: 34 | ``` 35 | pip install py-sled[examples] 36 | ``` 37 | 38 | If you wish to continue developing this repository, install the full development requirements with: 39 | ``` 40 | pip install py-sled[dev] 41 | ``` 42 | 43 | ## Usage 44 | Working with SLED is seamless when using HuggingFace's Transformers AutoClasses. 45 | 46 | A minimal usage example: 47 | ```python 48 | import sled  # ** required so SLED is properly registered by the AutoClasses ** 49 | from transformers import AutoTokenizer, AutoModel 50 | tokenizer = AutoTokenizer.from_pretrained('tau/bart-base-sled') 51 | model = AutoModel.from_pretrained('tau/bart-base-sled') 52 | inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 53 | outputs = model(**inputs) 54 | last_hidden_states = outputs.last_hidden_state 55 | ``` 56 | 57 | _Important_: You need to `import sled` before using the AutoClass (e.g. `AutoModel.from_pretrained('tau/bart-base-sled')`) for it to work. 58 | 59 | A minimal working example can be found [here](examples/usage_example.py). 60 | 61 | To work with SCROLLS-like data, as used for the paper, see [here](examples/seq2seq). 62 | 63 | ### Custom datasets 64 | For SLED to be able to prepend the prefix input to every chunk, it requires the input tensor `prefix_length`. 65 | If using a custom dataset, you can refer to [run.py](examples/seq2seq/run.py) for the correct way to preprocess the data. 66 | 67 | _Note_: Currently, HF's Seq2SeqTrainer doesn't pass the `prefix_length` tensor in the prediction loop, so you 68 | should use the [CustomSeq2SeqTrainer](examples/seq2seq/utils/custom_seq2seq_trainer.py) or something similar until it is 69 | fixed. 70 | 71 | ### Backbone models 72 | There are multiple model cards available on the Hugging Face Hub, including 73 | - [Bart-Base SLED](https://huggingface.co/tau/bart-base-sled) (model name `tau/bart-base-sled`) 74 | - [Bart-Large SLED](https://huggingface.co/tau/bart-large-sled) (model name `tau/bart-large-sled`) 75 | - [T5(v1.1)-base SLED](https://huggingface.co/tau/t5-v1_1-base-sled) (model name `tau/t5-v1_1-base-sled`) 76 | - [T5(v1.1)-large SLED](https://huggingface.co/tau/t5-v1_1-large-sled) (model name `tau/t5-v1_1-large-sled`) 77 | 78 | If you wish to use a custom model that is available as a model card (public or private) on the hub, or use 79 | different parameters for SLED, you can create a json config file like the one below and change `underlying_config` to your custom model card.
80 | ```json 81 | { 82 | "model_type": "tau/sled", 83 | "underlying_config": "facebook/bart-base", 84 | "context_size": 256, 85 | "window_fraction": 0.5, 86 | "prepend_prefix": true, 87 | "encode_prefix": true, 88 | "sliding_method": "dynamic" 89 | } 90 | ``` 91 | You can then load it like below, where `PATH_TO_CUSTOM_CONFIG` is a placeholder for wherever you saved the json config above 92 | ```python 93 | import sled 94 | from transformers import AutoModelForSeq2SeqLM 95 | custom_sled_model = AutoModelForSeq2SeqLM.from_pretrained(PATH_TO_CUSTOM_CONFIG) 96 | ``` 97 | 98 | ## Citation 99 | 100 | If you use this repository, please cite as below: 101 | ``` 102 | @inproceedings{Ivgi2022EfficientLU, 103 | title={Efficient Long-Text Understanding with Short-Text Models}, 104 | author={Maor Ivgi and Uri Shaham and Jonathan Berant}, 105 | year={2022} 106 | } 107 | ``` 108 | 109 | 110 | ## Disclaimer 111 | This repository is still under active development, and may contain some unintended behavior. 112 | Please open an issue if any unexpected behavior occurs, and we will promptly try to fix it. 113 | 114 | The code was developed and tested with transformers version 4.21.0. Newer versions may break backward 115 | compatibility and cause unexpected behavior. 116 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | Working with SLED is seamless when using HuggingFace's Transformers AutoClasses. 3 | 4 | _Important_: You need to `import sled` before using the AutoClass (e.g. `AutoModel.from_pretrained('tau/bart-base-sled')`) for it to work. 5 | 6 | A minimal working example can be found [here](usage_example.py). 7 | 8 | To work with SCROLLS-like data, as used for the paper, see [here](seq2seq). -------------------------------------------------------------------------------- /examples/seq2seq/README.md: -------------------------------------------------------------------------------- 1 | # Seq2Seq finetuning 2 | In this example, you can find the scripts needed to finetune SLED models on SCROLLS-like data. 3 | 4 | The entrypoint script is [run.py](run.py). 5 | 6 | ## Usage 7 | After setting up your environment as described [here](https://github.com/Mivg/SLED#installation), you can run the script to finetune, 8 | evaluate and generate predictions for all the datasets in the [SLED dataset](https://huggingface.co/datasets/tau/sled) 9 | (based on SCROLLS). 10 | 11 | Like most recipes, you can view all possible settable parameters by running 12 | ``` 13 | python run.py --help 14 | ``` 15 | 16 | To run, you can either set the parameters with command-line arguments (e.g. `--model_name_or_path tau/bart-base-sled`) 17 | or use some predefined json files to set recurring configurations. You can pass as many json files as you would like, 18 | but make sure you pass them before any other command-line argument. For example, you can do the following: 19 | ``` 20 | python run.py configs/data/squad.json \ 21 | configs/model/bart_base_sled.json \ 22 | configs/training/base_training_args.json \ 23 | --output_dir /tmp/output_sled \ 24 | --learning_rate 2e-5 25 | ``` 26 | 27 | Example json files are [here](https://github.com/Mivg/SLED/tree/main/examples/seq2seq/configs).
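If you wish to run on your own data rather than one of the `tau/sled` subsets, the same mechanism can point `run.py` at local files through a data config. The sketch below is only an illustration: the file paths are placeholders, and it assumes JSON-lines files whose columns match the `input_column`, `input_prefix_column` and `output_column` arguments handled by `run.py`.
```json
{
  "dataset_name": "json",
  "train_file": "data/my_train.jsonl",
  "validation_file": "data/my_val.jsonl",
  "input_column": "input",
  "input_prefix_column": "input_prefix",
  "output_column": "output",
  "max_source_length": 16384,
  "max_prefix_length": 64,
  "pad_prefix": true,
  "generation_max_length": 128,
  "num_train_epochs": 3,
  "metric_names": ["f1"],
  "metric_for_best_model": "f1",
  "greater_is_better": true
}
```
Such a config would be passed exactly like the built-in ones, e.g. `python run.py my_data_config.json configs/model/bart_base_sled.json configs/training/base_training_args.json --output_dir /tmp/output_sled`.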
28 | -------------------------------------------------------------------------------- /examples/seq2seq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mivg/SLED/2991895d01e5ed5871de41676a14b25c534c523a/examples/seq2seq/__init__.py -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/contract_nli.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "contract_nli", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "generation_max_length": 8, 7 | "num_train_epochs": 20, 8 | "metric_names": ["exact_match"], 9 | "metric_for_best_model": "exact_match", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/gov_report.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "gov_report", 4 | "max_source_length": 16384, 5 | "generation_max_length": 1024, 6 | "max_prefix_length": 0, 7 | "num_train_epochs": 10, 8 | "metric_names": ["rouge"], 9 | "metric_for_best_model": "rouge/geometric_mean", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/hotpotqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "hotpotqa", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "generation_max_length": 128, 7 | "num_train_epochs": 9, 8 | "metric_names": ["f1", "exact_match"], 9 | "metric_for_best_model": "f1", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/hotpotqa_second_only.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "hotpotqa_second_only", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "generation_max_length": 128, 7 | "num_train_epochs": 9, 8 | "metric_names": ["f1", "exact_match"], 9 | "metric_for_best_model": "f1", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/narative_qa.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "narrative_qa", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "num_train_epochs": 2, 7 | "generation_max_length": 128, 8 | "metric_names": ["f1"], 9 | "metric_for_best_model": "f1", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/qasper.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "qasper", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "generation_max_length": 128, 7 | "num_train_epochs": 20, 8 | "metric_names": ["f1"], 9 | "metric_for_best_model": "f1", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- 
/examples/seq2seq/configs/data/qmsum.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "qmsum", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "num_train_epochs": 20, 7 | "generation_max_length": 1024, 8 | "metric_names": ["rouge"], 9 | "metric_for_best_model": "rouge/geometric_mean", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/quality.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "quality", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 160, 6 | "num_train_epochs": 20, 7 | "generation_max_length": 128, 8 | "metric_names": ["exact_match"], 9 | "metric_for_best_model": "exact_match", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/squad.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "squad", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "num_train_epochs": 3, 7 | "generation_max_length": 128, 8 | "metric_names": ["f1", "exact_match"], 9 | "metric_for_best_model": "f1", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/squad_ordered_distractors.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "squad_ordered_distractors", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "num_train_epochs": 3, 7 | "generation_max_length": 128, 8 | "metric_names": ["f1", "exact_match"], 9 | "metric_for_best_model": "f1", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/squad_shuffled_distractors.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "squad_shuffled_distractors", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 64, 6 | "num_train_epochs": 3, 7 | "generation_max_length": 128, 8 | "metric_names": ["f1", "exact_match"], 9 | "metric_for_best_model": "f1", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/data/summ_screen_fd.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "tau/sled", 3 | "dataset_config_name": "summ_screen_fd", 4 | "max_source_length": 16384, 5 | "max_prefix_length": 0, 6 | "num_train_epochs": 10, 7 | "generation_max_length": 1024, 8 | "metric_names": ["rouge"], 9 | "metric_for_best_model": "rouge/geometric_mean", 10 | "greater_is_better": true 11 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/model/bart_base_sled.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "tau/bart-base-sled", 3 | "use_auth_token": false, 4 | "pad_prefix": true, 5 | "max_target_length": 1024, 6 | "fp16": true 7 | } 
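The repository ships model configs only for the BART backbones (bart_base_sled.json and bart_large_sled.json); a matching config for the T5(v1.1)-base SLED card mentioned in the top-level README would presumably look very similar. A hypothetical sketch (it requires the T5 extras, i.e. `pip install py-sled[t5]`; `fp16` is omitted here since T5 v1.1 checkpoints are often numerically unstable in fp16):
```json
{
  "model_name_or_path": "tau/t5-v1_1-base-sled",
  "use_auth_token": false,
  "pad_prefix": true,
  "max_target_length": 1024
}
```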
-------------------------------------------------------------------------------- /examples/seq2seq/configs/model/bart_large_sled.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "tau/bart-large-sled", 3 | "use_auth_token": false, 4 | "pad_prefix": true, 5 | "max_target_length": 1024, 6 | "fp16": true 7 | } -------------------------------------------------------------------------------- /examples/seq2seq/configs/training/base_training_args.json: -------------------------------------------------------------------------------- 1 | { 2 | "eval_steps_override": 0.5, 3 | "save_steps_override": 0.5, 4 | "evaluation_strategy": "steps", 5 | "eval_fraction": 1000, 6 | "predict_with_generate": true, 7 | "gradient_checkpointing": true, 8 | "do_train": true, 9 | "do_eval": true, 10 | "seed": 42, 11 | "warmup_ratio": 0.1, 12 | "save_total_limit": 2, 13 | "preprocessing_num_workers": 1, 14 | "load_best_model_at_end": true, 15 | "lr_scheduler": "linear", 16 | "adam_epsilon": 1e-6, 17 | "adam_beta1": 0.9, 18 | "adam_beta2": 0.98, 19 | "weight_decay": 0.001 20 | } -------------------------------------------------------------------------------- /examples/seq2seq/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import load_metric 2 | -------------------------------------------------------------------------------- /examples/seq2seq/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | import os 3 | import importlib 4 | from abc import ABC, abstractmethod 5 | import inspect 6 | import shutil 7 | 8 | import numpy as np 9 | 10 | from utils.decoding import decode 11 | from datasets import load_metric as hf_load_metric 12 | from huggingface_hub import hf_hub_download 13 | 14 | 15 | class Metric(ABC): 16 | def __init__(self, **kwargs) -> None: 17 | super().__init__() 18 | self._kwargs = kwargs 19 | 20 | self.prefix = os.path.splitext(os.path.basename(inspect.getfile(self.__class__)))[0] 21 | self.requires_decoded = False 22 | 23 | def __call__(self, id_to_pred, id_to_labels, is_decoded=False): 24 | if self.requires_decoded and is_decoded is False: 25 | id_to_pred = self._decode(id_to_pred) 26 | id_to_labels = self._decode(id_to_labels) 27 | return self._compute_metrics(id_to_pred, id_to_labels) 28 | 29 | @abstractmethod 30 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]: 31 | return 32 | 33 | def _decode(self, id_to_something): 34 | tokenizer = self._kwargs.get("tokenizer") 35 | data_args = self._kwargs.get("data_args") 36 | return decode(id_to_something, tokenizer, data_args) 37 | 38 | 39 | class MetricCollection(Metric): 40 | def __init__(self, metrics: List[Metric], **kwargs): 41 | super().__init__(**kwargs) 42 | self._metrics = metrics 43 | 44 | def __call__(self, id_to_pred, id_to_labels): 45 | return self._compute_metrics(id_to_pred, id_to_labels) 46 | 47 | def _compute_metrics(self, id_to_pred, id_to_labels): 48 | results = {} 49 | 50 | id_to_pred_decoded = None 51 | id_to_labels_decoded = None 52 | for metric in self._metrics: 53 | metric_prefix = f"{metric.prefix}/" if metric.prefix else "" 54 | if metric.requires_decoded: 55 | if id_to_pred_decoded is None: 56 | id_to_pred_decoded = self._decode(id_to_pred) 57 | if id_to_labels_decoded is None: 58 | id_to_labels_decoded = self._decode(id_to_labels) 59 | 60 | result = metric(id_to_pred_decoded, 
id_to_labels_decoded, is_decoded=True) 61 | else: 62 | result = metric(id_to_pred, id_to_labels) 63 | 64 | results.update({f"{metric_prefix}{k}": v for k, v in result.items()}) 65 | 66 | results["num_predicted"] = len(id_to_pred) 67 | results["mean_prediction_length_characters"] = np.mean([len(pred) for pred in id_to_pred_decoded.values()]) 68 | 69 | elem = next(iter(id_to_pred.values())) 70 | if not ((isinstance(elem, list) and isinstance(elem[0], str)) or isinstance(elem, str)): 71 | tokenizer = self._kwargs["tokenizer"] 72 | results["mean_prediction_length_tokens"] = np.mean( 73 | [np.count_nonzero(np.array(pred) != tokenizer.pad_token_id) for pred in id_to_pred.values()] 74 | ) # includes BOS/EOS tokens 75 | 76 | results = {key: round(value, 4) for key, value in results.items()} 77 | return results 78 | 79 | 80 | def load_metric(paths: List[str], **kwargs): 81 | if paths is None or len(paths) == 0: 82 | return None 83 | if isinstance(paths, str): 84 | paths = [paths] 85 | else: 86 | paths = [path for path in paths] 87 | 88 | metric_cls_list = [] 89 | 90 | scrolls_custom_metrics = [] 91 | to_remove = [] 92 | for i, path in enumerate(paths): 93 | if not os.path.isfile(path): 94 | scrolls_custom_metrics.append(path) 95 | to_remove.append(i) 96 | for i in sorted(to_remove, reverse=True): 97 | del paths[i] 98 | if len(scrolls_custom_metrics) > 0: 99 | scrolls_custom_metrics.insert(0, "") # In order to have an identifying comma in the beginning 100 | metric_cls_list.append(ScrollsWrapper(",".join(scrolls_custom_metrics), **kwargs)) 101 | 102 | for path in paths: 103 | path = path.strip() 104 | if len(path) == 0: 105 | continue 106 | if os.path.isfile(path) is False: 107 | path = os.path.join("src", "metrics", f"{path}.py") 108 | 109 | module = path[:-3].replace(os.sep, ".") 110 | 111 | metric_cls = import_main_class(module) 112 | metric_cls_list.append(metric_cls(**kwargs)) 113 | 114 | return MetricCollection(metric_cls_list, **kwargs) 115 | 116 | 117 | # Modified from datasets.load 118 | def import_main_class(module_path): 119 | """Import a module at module_path and return its main class""" 120 | module = importlib.import_module(module_path) 121 | 122 | main_cls_type = Metric 123 | 124 | # Find the main class in our imported module 125 | module_main_cls = None 126 | for name, obj in module.__dict__.items(): 127 | if isinstance(obj, type) and issubclass(obj, main_cls_type): 128 | if inspect.isabstract(obj): 129 | continue 130 | module_main_cls = obj 131 | break 132 | 133 | return module_main_cls 134 | 135 | 136 | class ScrollsWrapper(Metric): 137 | def __init__(self, comma_separated_metric_names, **kwargs) -> None: 138 | super().__init__(**kwargs) 139 | self.prefix = None 140 | 141 | self._metric = hf_load_metric(download_metric(), comma_separated_metric_names, keep_in_memory=True) 142 | 143 | self.requires_decoded = True 144 | 145 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]: 146 | return self._metric.compute(**self._metric.convert_from_map_format(id_to_pred, id_to_labels)) 147 | 148 | 149 | def download_metric(): 150 | # here we load the custom metrics 151 | 152 | try: 153 | scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", repo_type="dataset", filename="metrics/scrolls.py") 154 | except: 155 | # support for backward compatibility 156 | scrolls_metric_path = hf_hub_download(repo_id="datasets/tau/scrolls", filename="metrics/scrolls.py") 157 | 158 | updated_scrolls_metric_path = ( 159 | os.path.dirname(scrolls_metric_path) + 
os.path.basename(scrolls_metric_path).replace(".", "_") + ".py" 160 | ) 161 | shutil.copy(scrolls_metric_path, updated_scrolls_metric_path) 162 | return updated_scrolls_metric_path 163 | -------------------------------------------------------------------------------- /examples/seq2seq/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | 25 | import numpy as np 26 | 27 | # we import the logging frameworks before any other import to make sure all monkey patching for the logging are active 28 | from sled import SledConfig 29 | 30 | try: 31 | import comet_ml 32 | except ImportError: 33 | pass 34 | try: 35 | import wandb 36 | except ImportError: 37 | pass 38 | import torch 39 | 40 | 41 | sys.path.insert(0, os.path.dirname(__file__)) # seq2seq package path 42 | sys.path.insert(0, os.getcwd()) 43 | 44 | from dataclasses import dataclass, field 45 | from typing import List, Optional 46 | import json 47 | from copy import deepcopy 48 | import torch.nn.functional as F 49 | 50 | import datasets 51 | 52 | import transformers 53 | from transformers import ( 54 | AutoConfig, 55 | AutoModelForSeq2SeqLM, 56 | AutoTokenizer, 57 | set_seed, WEIGHTS_NAME, 58 | ) 59 | from transformers.trainer_utils import get_last_checkpoint 60 | from transformers import DataCollatorForSeq2Seq 61 | 62 | from datasets import load_dataset 63 | 64 | # noinspection PyUnresolvedReferences 65 | import sled # *** required so that SledModels will be registered for the AutoClasses *** 66 | 67 | from utils.config import handle_args_to_ignore 68 | from utils.decoding import decode 69 | from metrics import load_metric 70 | from utils.duplicates import drop_duplicates_in_input 71 | from utils.override_training_args import TrainingOverridesArguments 72 | from utils.custom_seq2seq_trainer import CustomTrainer 73 | from utils.custom_hf_argument_parser import CustomHfArgumentParser 74 | 75 | logger = logging.getLogger('sled') 76 | 77 | PREFIX_DOC_SEP = '\n\n' 78 | 79 | DEBUG = os.environ.get('DEBUG', 'false').lower() in {'1', 'true', 'yes'} # If set, will set some configuration to help debug 80 | if DEBUG: 81 | assert not torch.cuda.is_available() or torch.cuda.device_count() == 1 82 | 83 | 84 | @dataclass 85 | class ModelArguments: 86 | """ 87 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
88 | """ 89 | 90 | model_name_or_path: str = field( 91 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 92 | ) 93 | config_name: Optional[str] = field( 94 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 95 | ) 96 | tokenizer_name: Optional[str] = field( 97 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 98 | ) 99 | cache_dir: Optional[str] = field( 100 | default=None, 101 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 102 | ) 103 | use_fast_tokenizer: bool = field( 104 | default=True, 105 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 106 | ) 107 | model_revision: str = field( 108 | default="main", 109 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 110 | ) 111 | drop_duplicates_in_eval: bool = field( 112 | default=True, 113 | ) 114 | 115 | def __post_init__(self): 116 | pass 117 | 118 | 119 | @dataclass 120 | class DataTrainingArguments: 121 | """ 122 | Arguments pertaining to what data we are going to input our model for training and eval. 123 | """ 124 | 125 | dataset_name: Optional[str] = field( 126 | default=None, 127 | metadata={ 128 | "help": "The name of the dataset to use (via the datasets library) or name of the file in src/data." 129 | }, 130 | ) 131 | dataset_config_name: Optional[str] = field( 132 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 133 | ) 134 | metric_names: Optional[List[str]] = field( 135 | default=None, 136 | metadata={"help": "The name of the metric to use (from src/metrics)."}, 137 | ) 138 | input_column: Optional[str] = field( 139 | default=None, 140 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 141 | ) 142 | input_prefix_column: Optional[str] = field( 143 | default=None, 144 | metadata={"help": "The name of the column in the datasets containing the input prefix (e.g. questions), when those exist."}, 145 | ) 146 | output_column: Optional[str] = field( 147 | default=None, 148 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 149 | ) 150 | train_file: Optional[str] = field( 151 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 152 | ) 153 | validation_file: Optional[str] = field( 154 | default=None, 155 | metadata={ 156 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 157 | "(a jsonlines or csv file)." 158 | }, 159 | ) 160 | test_file: Optional[str] = field( 161 | default=None, 162 | metadata={ 163 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 164 | }, 165 | ) 166 | overwrite_cache: bool = field( 167 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 168 | ) 169 | preprocessing_num_workers: Optional[int] = field( 170 | default=None, 171 | metadata={"help": "The number of processes to use for the preprocessing."}, 172 | ) 173 | max_source_length: Optional[int] = field( 174 | default=None, 175 | metadata={ 176 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 177 | "than this will be truncated, sequences shorter will be padded." 
178 | }, 179 | ) 180 | max_prefix_length: Optional[int] = field( 181 | default=0, 182 | metadata={ 183 | "help": "The maximum total input_prefix sequence length after tokenization. Sequences longer " 184 | "than this will be truncated, sequences shorter will be padded from the left " 185 | "(only used if prefixes are not merged)." 186 | }, 187 | ) 188 | max_target_length: Optional[int] = field( 189 | default=128, 190 | metadata={ 191 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 192 | "than this will be truncated, sequences shorter will be padded." 193 | }, 194 | ) 195 | val_max_target_length: Optional[int] = field( 196 | default=None, 197 | metadata={ 198 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 199 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 200 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 201 | "during ``evaluate`` and ``predict``." 202 | }, 203 | ) 204 | pad_to_max_length: bool = field( 205 | default=False, 206 | metadata={ 207 | "help": "Whether to pad all samples to model maximum sentence length. " 208 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 209 | "efficient on GPU but very bad for TPU." 210 | }, 211 | ) 212 | max_train_samples: Optional[int] = field( 213 | default=None, 214 | metadata={ 215 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 216 | "value if set." 217 | }, 218 | ) 219 | max_eval_samples: Optional[int] = field( 220 | default=None, 221 | metadata={ 222 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 223 | "value if set." 224 | }, 225 | ) 226 | max_predict_samples: Optional[int] = field( 227 | default=None, 228 | metadata={ 229 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 230 | "value if set." 231 | }, 232 | ) 233 | num_beams: Optional[int] = field( 234 | default=None, 235 | metadata={ 236 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 237 | "which is used during ``evaluate`` and ``predict``." 238 | }, 239 | ) 240 | ignore_pad_token_for_loss: bool = field( 241 | default=True, 242 | metadata={ 243 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 244 | }, 245 | ) 246 | source_prefix: Optional[str] = field( 247 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 248 | ) 249 | data_dir: Optional[str] = field( 250 | default=None, 251 | metadata={"help": "Defining the data_dir of the dataset configuration."}, 252 | ) 253 | download_mode: Optional[str] = field( 254 | default=None, 255 | metadata={ 256 | "help": "Defining the download_mode when loading the dataset. Options are `reuse_dataset_if_exists` (default), `reuse_cache_if_exists` and `force_redownload`." 
257 | }, 258 | ) 259 | evaluate_on_training_data: bool = field( 260 | default=False, 261 | metadata={"help": "Whether to evaluate on training data or not, to make sure the model can overfit."}, 262 | ) 263 | folder_suffix: str = field( 264 | default="", 265 | metadata={"help": "args to be suffixes for the output folder of the run"}, 266 | ) 267 | preprocess_only: bool = field( 268 | default=False, 269 | metadata={"help": "Preprocess only: Don't start training, just do the things before"}, 270 | ) 271 | assign_zero_to_too_long_val_examples: bool = field( 272 | default=False, 273 | metadata={ 274 | "help": "If true, all sequences longer then max_source_length will be assign a score of 0 in the metric evaluation" 275 | }, 276 | ) 277 | shared_storage: bool = field( 278 | default=True, 279 | metadata={"help": "Whether nodes share the same storage"}, 280 | ) 281 | trim_very_long_strings: bool = field( 282 | default=False, 283 | metadata={"help": "Whether to trim very long strings before tokenizing them"}, 284 | ) 285 | pad_prefix: bool = field( 286 | default=False, 287 | metadata={ 288 | "help": "Whether to pad the prefix if it exists to max_prefix_length. " 289 | "Note - important if you are using a SLED model on an input that contains an input_prefix" 290 | }, 291 | ) 292 | test_start_ind: Optional[int] = field( 293 | default=None, 294 | metadata={"help": "if given, uses the test set starting from this index"}, 295 | ) 296 | test_end_ind: Optional[int] = field( 297 | default=None, 298 | metadata={"help": "if given, uses the test set ending at this index"}, 299 | ) 300 | 301 | 302 | def __post_init__(self): 303 | if self.val_max_target_length is None: 304 | self.val_max_target_length = self.max_target_length 305 | if self.pad_prefix and self.max_prefix_length == 0: 306 | raise ValueError('When padding prefix, you must set a max_prefix_length') 307 | assert self.max_prefix_length == 0 or self.max_prefix_length <= 0.5*self.max_source_length,\ 308 | 'If max_prefix_length is given, it must be much shorter than the total input' 309 | 310 | 311 | def main(): 312 | handle_args_to_ignore(sys.argv) # Just for sweeps 313 | 314 | # See all possible arguments in src/transformers/training_args.py 315 | # or by passing the --help flag to this script. 316 | # We now keep distinct sets of args, for a cleaner separation of concerns. 317 | 318 | parser = CustomHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingOverridesArguments)) 319 | model_args, data_args, training_args = parser.parse_dictionary_and_args() 320 | 321 | set_up_logging(training_args) 322 | 323 | # Used to find missing dependencies early on 324 | load_metric(data_args.metric_names, **locals()) 325 | 326 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 327 | "t5-small", 328 | "t5-base", 329 | "t5-large", 330 | "t5-3b", 331 | "t5-11b", 332 | ]: 333 | logger.warning( 334 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 335 | "`--source_prefix 'summarize: ' `" 336 | ) 337 | 338 | # Detecting last checkpoint. 339 | last_checkpoint = _detect_last_checkpoint(training_args) 340 | 341 | # Set seed before initializing model. 342 | set_seed(training_args.seed) 343 | 344 | seq2seq_dataset = _get_dataset(data_args, model_args, training_args) 345 | 346 | # Load pretrained model and tokenizer 347 | # 348 | # Distributed training: 349 | # The .from_pretrained methods guarantee that only one local process can concurrently 350 | # download model & vocab. 
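    # Note on the block below: when `model_name_or_path` points to a file (e.g. a checkpoint file),
    # the enclosing directory is used to resolve the config and the tokenizer; otherwise the value is
    # treated as a hub id or a local directory. In addition, for SLED configs, `prepend_prefix` is
    # forced to False whenever `pad_prefix` is False or `max_prefix_length` is 0, since SLED requires
    # all prefixes to have the same length.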
351 | config_name = None 352 | if model_args.config_name: 353 | config_name = model_args.config_name 354 | else: 355 | if os.path.isfile(model_args.model_name_or_path): 356 | config_name = os.path.dirname(model_args.model_name_or_path) 357 | else: 358 | config_name = model_args.model_name_or_path 359 | 360 | config_overrides = {} 361 | if training_args.gradient_checkpointing is not None: 362 | config_overrides["gradient_checkpointing"] = training_args.gradient_checkpointing 363 | 364 | config = AutoConfig.from_pretrained( 365 | config_name, 366 | cache_dir=model_args.cache_dir, 367 | revision=model_args.model_revision, 368 | use_auth_token=training_args.use_auth_token, 369 | **config_overrides 370 | ) 371 | # override for sled models to make sure we are explicit in our request 372 | if isinstance(config, SledConfig) and (not data_args.pad_prefix or data_args.max_prefix_length == 0): 373 | logger.warning('Setting prepend_prefix to False if using a SLED model, as the input does not have a prefix or ' 374 | 'pad_prefix is False (all prefixes must be of the same length for SLED). If you do not use SLED ' 375 | 'or finetune on a dataset with no prefixes, ignore this warning') 376 | config.prepend_prefix = False 377 | 378 | if model_args.model_name_or_path is None: 379 | # Padding for divisibility by 8 380 | if config.vocab_size % 8 != 0 and training_args.fp16_padding: 381 | config.vocab_size += 8 - (config.vocab_size % 8) 382 | 383 | tokenizer_name = None 384 | if model_args.tokenizer_name: 385 | tokenizer_name = model_args.tokenizer_name 386 | else: 387 | if os.path.isfile(model_args.model_name_or_path): 388 | tokenizer_name = os.path.dirname(model_args.model_name_or_path) 389 | else: 390 | tokenizer_name = model_args.model_name_or_path 391 | tokenizer = AutoTokenizer.from_pretrained( 392 | tokenizer_name, 393 | cache_dir=model_args.cache_dir, 394 | use_fast=model_args.use_fast_tokenizer, 395 | revision=model_args.model_revision, 396 | use_auth_token=training_args.use_auth_token, 397 | ) 398 | if model_args.model_name_or_path is not None: 399 | model = AutoModelForSeq2SeqLM.from_pretrained( 400 | model_args.model_name_or_path, 401 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 402 | config=config, 403 | cache_dir=model_args.cache_dir, 404 | revision=model_args.model_revision, 405 | use_auth_token=training_args.use_auth_token, 406 | ) 407 | else: 408 | model = AutoModelForSeq2SeqLM.from_config( 409 | config, 410 | ) 411 | if training_args.gradient_checkpointing and getattr(model.config, 'use_cache', False): 412 | logger.warning('Cannot use cache in models when using gradient checkpointing. turning it off') 413 | model.config.use_cache = False 414 | 415 | model.resize_token_embeddings(len(tokenizer)) 416 | 417 | if model.config.decoder_start_token_id is None: 418 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 419 | 420 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 421 | 422 | # Preprocessing the datasets. 423 | # We need to tokenize inputs and targets. 424 | if training_args.do_train: 425 | column_names = seq2seq_dataset["train"].column_names 426 | elif training_args.do_eval: 427 | column_names = seq2seq_dataset["validation"].column_names 428 | elif training_args.do_predict: 429 | column_names = seq2seq_dataset["test"].column_names 430 | else: 431 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 432 | return 433 | 434 | # Get the column names for input/target. 
435 | if data_args.input_column is None: 436 | input_column = "input" 437 | else: 438 | input_column = data_args.input_column 439 | if input_column not in column_names: 440 | raise ValueError( 441 | f"--input_column' value '{data_args.input_column}' needs to be one of: {', '.join(column_names)}" 442 | ) 443 | if data_args.input_prefix_column is None: 444 | input_prefix_column = "input_prefix" 445 | else: 446 | input_prefix_column = data_args.input_prefix_column 447 | if input_prefix_column not in column_names: 448 | raise ValueError( 449 | f"--input_prefix_column' value '{data_args.input_prefix_column}' needs to be one of: {', '.join(column_names)}" 450 | ) 451 | if data_args.output_column is None: 452 | output_column = "output" 453 | else: 454 | output_column = data_args.output_column 455 | if output_column not in column_names: 456 | raise ValueError( 457 | f"--output_column' value '{data_args.output_column}' needs to be one of: {', '.join(column_names)}" 458 | ) 459 | 460 | # Temporarily set max_target_length for training. 461 | max_target_length = data_args.max_target_length 462 | padding = "max_length" if data_args.pad_to_max_length else False 463 | 464 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 465 | logger.warning( 466 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 467 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 468 | ) 469 | 470 | def preprocess_function_kwargs_fn(): 471 | return { 472 | "tokenizer": deepcopy(tokenizer), 473 | "prefix": prefix, 474 | "input_column": input_column, 475 | "input_prefix_column": input_prefix_column, 476 | "output_column": output_column, 477 | "max_source_length": data_args.max_source_length, 478 | "max_prefix_length": data_args.max_prefix_length, 479 | "max_target_length": max_target_length, 480 | "prefix_sep": PREFIX_DOC_SEP, 481 | "padding": padding, 482 | "ignore_pad_token_for_loss": data_args.ignore_pad_token_for_loss, 483 | "assign_zero_to_too_long_val_examples": data_args.assign_zero_to_too_long_val_examples, 484 | "trim_very_long_strings": data_args.trim_very_long_strings, 485 | "pad_prefix": data_args.pad_prefix 486 | } 487 | 488 | if training_args.do_train: 489 | if "train" not in seq2seq_dataset: 490 | raise ValueError("--do_train requires a train dataset") 491 | logger.info("") 492 | logger.info("Training examples before tokenization:") 493 | if input_prefix_column in column_names: 494 | logger.info(f"input_prefix #0: {seq2seq_dataset['train'][0][input_prefix_column]}") 495 | logger.info(f"input #0: {seq2seq_dataset['train'][0]['input']}") 496 | logger.info(f"output #0: {seq2seq_dataset['train'][0]['output']}") 497 | if input_prefix_column in column_names: 498 | logger.info(f"input_prefix #1: {seq2seq_dataset['train'][1][input_prefix_column]}") 499 | logger.info(f"input #1: {seq2seq_dataset['train'][1]['input']}") 500 | logger.info(f"output #1: {seq2seq_dataset['train'][1]['output']}") 501 | logger.info("") 502 | untokenized_train_dataset = seq2seq_dataset["train"] 503 | if data_args.max_train_samples is not None: 504 | untokenized_train_dataset = untokenized_train_dataset.select(range(data_args.max_train_samples)) 505 | 506 | if DEBUG: 507 | # In debug mode, we want ot recreate the data 508 | data_args.shared_storage = False 509 | data_args.overwrite_cache = True 510 | with training_args.main_process_first( 511 | local=not data_args.shared_storage, 
desc="train dataset map pre-processing" 512 | ): 513 | train_dataset = untokenized_train_dataset.map( 514 | preprocess_function, 515 | fn_kwargs=preprocess_function_kwargs_fn(), 516 | batched=True, 517 | num_proc=data_args.preprocessing_num_workers, 518 | remove_columns=untokenized_train_dataset.column_names, 519 | load_from_cache_file=not data_args.overwrite_cache, 520 | desc="Running tokenizer on train dataset", 521 | ) 522 | 523 | if training_args.do_eval: 524 | max_target_length = data_args.val_max_target_length 525 | preprocess_function_kwargs = preprocess_function_kwargs_fn() 526 | preprocess_function_kwargs["max_target_length"] = max_target_length 527 | if "validation" not in seq2seq_dataset: 528 | raise ValueError("--do_eval requires a validation dataset") 529 | logger.info("") 530 | logger.info("Validation examples before tokenization:") 531 | if input_prefix_column in column_names: 532 | logger.info(f"input_prefix #0: {seq2seq_dataset['validation'][0][input_prefix_column]}") 533 | logger.info(f"input #0: {seq2seq_dataset['validation'][0]['input']}") 534 | logger.info(f"output #0: {seq2seq_dataset['validation'][0]['output']}") 535 | if input_prefix_column in column_names: 536 | logger.info(f"input_prefix #1: {seq2seq_dataset['validation'][1][input_prefix_column]}") 537 | logger.info(f"input #1: {seq2seq_dataset['validation'][1]['input']}") 538 | logger.info(f"output #1: {seq2seq_dataset['validation'][1]['output']}") 539 | logger.info("") 540 | untokenized_eval_dataset = seq2seq_dataset["validation"] 541 | if data_args.max_eval_samples is not None: 542 | untokenized_eval_dataset = untokenized_eval_dataset.select(range(data_args.max_eval_samples)) 543 | if model_args.drop_duplicates_in_eval is True: 544 | untokenized_eval_dataset = drop_duplicates_in_input(untokenized_eval_dataset) 545 | untokenized_eval_dataset_orig = untokenized_eval_dataset 546 | assert training_args.eval_fraction > 0 547 | n = len(untokenized_eval_dataset) 548 | training_args.eval_fraction = min(training_args.eval_fraction, n) 549 | if training_args.eval_fraction != 1: 550 | if training_args.eval_fraction > 1: 551 | assert training_args.eval_fraction == int(training_args.eval_fraction) 552 | logger.info(f'using predetermined absolute samples from eval set ({training_args.eval_fraction} )') 553 | training_args.eval_fraction = training_args.eval_fraction / n 554 | indices = np.random.permutation(n)[:int(np.ceil(max(1, training_args.eval_fraction * n)))] 555 | untokenized_eval_dataset = type(untokenized_eval_dataset).from_dict(untokenized_eval_dataset[indices]) 556 | logger.info(f'During training, will only use {training_args.eval_fraction:.3%} samples of the eval set ' 557 | f'which amounts to {len(untokenized_eval_dataset)} out of {n} samples') 558 | 559 | eval_dataset = process_eval_set(data_args, preprocess_function_kwargs, training_args, untokenized_eval_dataset) 560 | eval_dataset_orig = eval_dataset 561 | if training_args.eval_fraction < 1: 562 | eval_dataset_orig = process_eval_set(data_args, preprocess_function_kwargs, training_args, 563 | untokenized_eval_dataset_orig) 564 | 565 | if training_args.do_predict: 566 | max_target_length = data_args.val_max_target_length 567 | preprocess_function_kwargs = preprocess_function_kwargs_fn() 568 | preprocess_function_kwargs["max_target_length"] = max_target_length 569 | if "test" not in seq2seq_dataset: 570 | raise ValueError("--do_predict requires a test dataset") 571 | untokenized_predict_dataset = seq2seq_dataset["test"] 572 | if data_args.max_predict_samples is 
not None: 573 | untokenized_predict_dataset = untokenized_predict_dataset.select(range(data_args.max_predict_samples)) 574 | if model_args.drop_duplicates_in_eval is True: 575 | untokenized_predict_dataset = drop_duplicates_in_input(untokenized_predict_dataset) 576 | 577 | if output_column in untokenized_predict_dataset.column_names: 578 | untokenized_predict_dataset = untokenized_predict_dataset.remove_columns(output_column) 579 | 580 | if data_args.test_start_ind is not None: 581 | sind = data_args.test_start_ind 582 | eind = -1 if data_args.test_end_ind is None else data_args.test_end_ind 583 | logger.info(f'Using only a subset of the test dataset [{sind}, {eind}]') 584 | untokenized_predict_dataset = type(untokenized_predict_dataset).from_dict(untokenized_predict_dataset[sind:eind]) 585 | 586 | with training_args.main_process_first( 587 | local=not data_args.shared_storage, desc="prediction dataset map pre-processing" 588 | ): 589 | predict_dataset = untokenized_predict_dataset.map( 590 | preprocess_function, 591 | fn_kwargs=preprocess_function_kwargs, 592 | batched=True, 593 | num_proc=data_args.preprocessing_num_workers, 594 | remove_columns=untokenized_predict_dataset.column_names, 595 | load_from_cache_file=not data_args.overwrite_cache, 596 | desc="Running tokenizer on prediction dataset", 597 | ) 598 | 599 | if data_args.preprocess_only: 600 | logger.info(f"With --preprocess_only, exiting after preprocess_on the data") 601 | exit() 602 | 603 | # Data collator 604 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 605 | pad_to = 8 if training_args.fp16 and training_args.fp16_padding else None 606 | 607 | 608 | data_collator = DataCollatorForSeq2Seq( 609 | tokenizer, 610 | model=model, 611 | label_pad_token_id=label_pad_token_id, 612 | pad_to_multiple_of=pad_to, 613 | ) 614 | 615 | # Metric 616 | compute_metrics = load_metric(data_args.metric_names, **locals()) 617 | 618 | # Initialize our Trainer 619 | trainer = CustomTrainer( 620 | model=model, 621 | args=training_args, 622 | train_dataset=train_dataset if training_args.do_train else None, 623 | eval_dataset=eval_dataset if training_args.do_eval else None, 624 | untokenized_eval_dataset=untokenized_eval_dataset if training_args.do_eval else None, 625 | tokenizer=tokenizer, 626 | data_collator=data_collator, 627 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 628 | output_dir=training_args.output_dir, 629 | data_args=data_args, 630 | ) 631 | 632 | # setup_cometml_trainer_callback(trainer) 633 | 634 | # Training 635 | if training_args.do_train: 636 | checkpoint = None 637 | if training_args.resume_from_checkpoint is not None: 638 | checkpoint = training_args.resume_from_checkpoint 639 | elif last_checkpoint is not None: 640 | checkpoint = last_checkpoint # look for checkpoints in the outdir 641 | 642 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 643 | logger.info('Done training') 644 | trainer.save_model() # Saves the tokenizer too for easy upload 645 | 646 | metrics = train_result.metrics 647 | max_train_samples = ( 648 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 649 | ) 650 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 651 | 652 | trainer.log_metrics("train", metrics) 653 | trainer.save_metrics("train", metrics) 654 | trainer.save_state() 655 | 656 | # Evaluation 657 | results = {} 658 | if training_args.do_eval: 659 | logger.info("*** Evaluate ***") 660 
| 661 | if training_args.eval_fraction < 1: 662 | logger.info('setting the eval set back to the full one') 663 | trainer.eval_dataset = eval_dataset_orig 664 | trainer._untokenized_eval_dataset = untokenized_eval_dataset_orig 665 | 666 | metrics = trainer.evaluate(metric_key_prefix="eval") 667 | logger.info('Done evaluating') 668 | 669 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 670 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 671 | 672 | trainer.log_metrics("eval", metrics) 673 | trainer.save_metrics("eval", metrics) 674 | 675 | if training_args.do_predict: 676 | logger.info("*** Predict ***") 677 | trainer.args.predict_with_generate = True # during prediction, we don't have labels 678 | 679 | # load last (and best) model, or the one specified if any 680 | logger.info("*** Loading model weights before the prediction ***") 681 | last_checkpoint = model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else _detect_last_checkpoint(training_args) 682 | if last_checkpoint is not None and os.path.isdir(last_checkpoint): 683 | logger.info(f'Loading weights from {last_checkpoint} for the prediction') 684 | state_dict = torch.load(os.path.join(last_checkpoint, WEIGHTS_NAME), map_location="cpu") 685 | # If the model is on the GPU, it still works! 686 | trainer._load_state_dict_in_model(state_dict) 687 | # release memory 688 | del state_dict 689 | logger.info("*** Done loading weights ***") 690 | elif training_args.do_train: 691 | raise ValueError('Could not find a model to load for prediction') 692 | else: 693 | logger.info(f'Using {model_args.model_name_or_path} as the model for the prediction') 694 | 695 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict") 696 | logger.info('Done predicting') 697 | 698 | metrics = predict_results.metrics 699 | max_predict_samples = ( 700 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 701 | ) 702 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 703 | 704 | trainer.log_metrics("predict", metrics) 705 | trainer.save_metrics("predict", metrics) 706 | 707 | if trainer.is_world_process_zero(): 708 | if training_args.predict_with_generate: 709 | id_to_prediction = {} 710 | for i, instance in enumerate(untokenized_predict_dataset): 711 | id_to_prediction[instance["id"]] = predict_results.predictions[i] 712 | predictions = decode(id_to_prediction, tokenizer, data_args) 713 | output_name = "generated_predictions.json" 714 | if data_args.test_start_ind is not None: 715 | output_name = f"generated_predictions_{data_args.test_start_ind}_{data_args.test_end_ind}.json" 716 | output_prediction_file = os.path.join(training_args.output_dir, output_name) 717 | with open(output_prediction_file, "w") as writer: 718 | json.dump(predictions, writer, indent=4) 719 | 720 | if training_args.push_to_hub: 721 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 722 | if data_args.dataset_name is not None: 723 | kwargs["dataset_tags"] = data_args.dataset_name 724 | if data_args.dataset_config_name is not None: 725 | kwargs["dataset_args"] = data_args.dataset_config_name 726 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 727 | else: 728 | kwargs["dataset"] = data_args.dataset_name 729 | 730 | trainer.push_to_hub(**kwargs) 731 | 732 | return results 733 | 734 | def 
_detect_last_checkpoint(training_args): 735 | last_checkpoint = None 736 | if os.path.isdir(training_args.output_dir) and training_args.do_train: 737 | if not training_args.overwrite_output_dir: 738 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 739 | 740 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 741 | logger.info( 742 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 743 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 744 | ) 745 | return last_checkpoint 746 | 747 | def process_eval_set(data_args, preprocess_function_kwargs, training_args, untokenized_eval_dataset): 748 | with training_args.main_process_first( 749 | local=not data_args.shared_storage, desc="validation dataset map pre-processing" 750 | ): 751 | eval_dataset = untokenized_eval_dataset.map( 752 | preprocess_function, 753 | fn_kwargs=preprocess_function_kwargs, 754 | batched=True, 755 | num_proc=data_args.preprocessing_num_workers, 756 | remove_columns=untokenized_eval_dataset.column_names, 757 | load_from_cache_file=not data_args.overwrite_cache, 758 | desc="Running tokenizer on validation dataset", 759 | ) 760 | return eval_dataset 761 | 762 | 763 | def _get_dataset(data_args, model_args, training_args): 764 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 765 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 766 | # (the dataset will be downloaded automatically from the datasets Hub). 767 | # 768 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 769 | # summaries (unless you specify column names for this with the `input_column` and `output_column` arguments). 770 | # 771 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 772 | # download the dataset. 773 | data_files = None 774 | if data_args.train_file is not None or data_args.validation_file is not None or data_args.test_file is not None: 775 | data_files = {} 776 | if data_args.train_file is not None: 777 | data_files["train"] = data_args.train_file 778 | if data_args.validation_file is not None: 779 | data_files["validation"] = data_args.validation_file 780 | if data_args.test_file is not None: 781 | data_files["test"] = data_args.test_file 782 | # Downloading and loading a dataset from the hub/local script. 783 | seq2seq_dataset = load_dataset( 784 | data_args.dataset_name, 785 | data_args.dataset_config_name, 786 | ignore_verifications=True, 787 | cache_dir=model_args.cache_dir, 788 | data_dir=data_args.data_dir, 789 | data_files=data_files, 790 | download_mode=data_args.download_mode, 791 | use_auth_token=training_args.use_auth_token 792 | ) 793 | if training_args.do_train: 794 | training_args.apply_overrides(len(seq2seq_dataset['train'])) 795 | if data_args.evaluate_on_training_data: 796 | seq2seq_dataset["validation"] = seq2seq_dataset["train"] 797 | 798 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 799 | # https://huggingface.co/docs/datasets/loading_datasets.html. 
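    # Illustrative sketch (not part of the original script; file names are hypothetical): supplying
    # --train_file / --validation_file is roughly equivalent to calling the `datasets` library directly, e.g.
    #     from datasets import load_dataset
    #     raw = load_dataset("json", data_files={"train": "my_train.json", "validation": "my_val.json"})
    # In this script those files are instead routed through the `data_files` argument of the call above.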
800 | 801 | return seq2seq_dataset 802 | 803 | 804 | def set_up_logging(training_args): 805 | logging.basicConfig( 806 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 807 | datefmt="%m/%d/%Y %H:%M:%S", 808 | handlers=[logging.StreamHandler(sys.stdout)], 809 | ) 810 | log_level = training_args.get_process_log_level() 811 | logger.setLevel(log_level) 812 | datasets.utils.logging.set_verbosity(log_level) 813 | transformers.utils.logging.set_verbosity(log_level) 814 | transformers.utils.logging.enable_default_handler() 815 | transformers.utils.logging.enable_explicit_format() 816 | # Log on each process the small summary: 817 | logger.warning( 818 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 819 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 820 | ) 821 | logger.info(f"Training/evaluation parameters {training_args}") 822 | 823 | 824 | def preprocess_function( 825 | examples, 826 | tokenizer, 827 | prefix, 828 | input_column, 829 | input_prefix_column, 830 | output_column, 831 | max_source_length, 832 | max_prefix_length, 833 | max_target_length, 834 | prefix_sep, 835 | padding, 836 | ignore_pad_token_for_loss, 837 | assign_zero_to_too_long_val_examples, 838 | trim_very_long_strings, 839 | pad_prefix 840 | ): 841 | if not isinstance(examples[input_column][0], str): 842 | model_inputs = _preprocess_tokenized_inputs() 843 | else: 844 | model_inputs = _preprocess_raw_inputs(assign_zero_to_too_long_val_examples, examples, input_column, input_prefix_column, 845 | max_source_length, padding, prefix, tokenizer, trim_very_long_strings, max_prefix_length, 846 | prefix_sep, pad_prefix) 847 | 848 | _preprocess_targets(examples, ignore_pad_token_for_loss, max_target_length, model_inputs, output_column, padding, tokenizer) 849 | model_inputs["length"] = [len(x) for x in model_inputs["input_ids"]] 850 | return model_inputs 851 | 852 | 853 | def _preprocess_raw_inputs(assign_zero_to_too_long_val_examples, examples, input_column, input_prefix_column, 854 | max_source_length, padding, prefix, tokenizer, trim_very_long_strings, max_prefix_length, 855 | prefix_sep, pad_prefix): 856 | inputs = examples[input_column] 857 | 858 | # the given prefix is what used in models like T5 (e.g. 
"summarize: ") 859 | # if prefix exists, it is added to the input_prefixes 860 | if input_prefix_column in examples.keys(): 861 | input_prefixes = [inp + prefix_sep for inp in examples[input_prefix_column]] 862 | if prefix != "": 863 | input_prefixes = [prefix + inp for inp in input_prefixes] 864 | elif prefix != "": 865 | inputs = [prefix + inp for inp in inputs] 866 | 867 | # tokenize the input prefix if it exists 868 | model_prefix_inputs = None 869 | if input_prefix_column in examples.keys(): 870 | if trim_very_long_strings: 871 | input_prefixes = [inp[: max_prefix_length * 7] for inp in input_prefixes] 872 | if pad_prefix: 873 | model_prefix_inputs = tokenizer(input_prefixes, max_length=max_prefix_length, padding='max_length', truncation=True) 874 | else: 875 | # for led, we do not pad the prefix 876 | model_prefix_inputs = tokenizer(input_prefixes, max_length=max_source_length, padding='do_not_pad', truncation=True) 877 | 878 | if trim_very_long_strings: 879 | inputs = [inp[: max_source_length * 7] for inp in inputs] 880 | model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True) 881 | 882 | if max_source_length is not None and assign_zero_to_too_long_val_examples: 883 | model_inputs_untrimmed = tokenizer(inputs) 884 | model_inputs["not_valid_for_eval"] = [ 885 | len(token_ids) > max_source_length for token_ids in model_inputs_untrimmed["input_ids"] 886 | ] 887 | else: 888 | model_inputs["not_valid_for_eval"] = [False] * len(model_inputs["input_ids"]) 889 | 890 | # now, combine the concat prefix to the input, trimming it to max_source_length if given 891 | if model_prefix_inputs is not None: 892 | max_source_length = max_source_length or -1 893 | model_inputs['input_ids'] = [(inp1+inp2)[:max_source_length] for inp1, inp2 894 | in zip(model_prefix_inputs['input_ids'], model_inputs['input_ids'])] 895 | model_inputs['attention_mask'] = [(inp1+inp2)[:max_source_length] for inp1, inp2 896 | in zip(model_prefix_inputs['attention_mask'], model_inputs['attention_mask'])] 897 | # add prefix_length 898 | if pad_prefix: 899 | # no need to go over them as they will all be of the same length 900 | model_inputs['prefix_length'] = [max_prefix_length] * len(model_inputs['input_ids']) 901 | else: 902 | model_inputs['prefix_length'] = [len(inp) for inp in model_prefix_inputs['input_ids']] 903 | 904 | return model_inputs 905 | 906 | 907 | def _preprocess_tokenized_inputs(): 908 | raise NotImplementedError('did not get here it') # TODO - implement input_prefix support here as well 909 | input_ids = [] 910 | attention_mask = [] 911 | for token_ids in examples[input_column]: 912 | length = len(token_ids) 913 | if max_source_length is not None: 914 | length = min(max_source_length, length) 915 | input_ids.append(token_ids[:length]) 916 | attention_mask.append([1 for _ in range(length)]) 917 | if len(input_ids) == 0: 918 | input_ids = examples[input_column] 919 | model_inputs = { 920 | "input_ids": input_ids, 921 | "attention_mask": attention_mask, 922 | } 923 | if max_source_length is not None: 924 | model_inputs["not_valid_for_eval"] = [ 925 | len(token_ids) > max_source_length and assign_zero_to_too_long_val_examples 926 | for token_ids in examples[input_column] 927 | ] 928 | else: 929 | model_inputs["not_valid_for_eval"] = [False for token_ids in examples[input_column]] 930 | return model_inputs 931 | 932 | 933 | 934 | 935 | def _preprocess_targets(examples, ignore_pad_token_for_loss, max_target_length, model_inputs, output_column, padding, tokenizer): 936 | 
targets = examples[output_column] if output_column in examples else None 937 | if targets is not None: 938 | if not isinstance(targets[0], str): 939 | if max_target_length is not None: 940 | targets = [target[:max_target_length] for target in targets] 941 | model_inputs["labels"] = targets 942 | else: 943 | # Setup the tokenizer for targets 944 | with tokenizer.as_target_tokenizer(): 945 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 946 | 947 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 948 | # padding in the loss. 949 | if padding == "max_length" and ignore_pad_token_for_loss: 950 | labels["input_ids"] = [ 951 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 952 | ] 953 | 954 | model_inputs["labels"] = labels["input_ids"] 955 | 956 | 957 | def _mp_fn(index): 958 | # For xla_spawn (TPUs) 959 | main() 960 | 961 | 962 | if __name__ == "__main__": 963 | main() 964 | -------------------------------------------------------------------------------- /examples/seq2seq/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mivg/SLED/2991895d01e5ed5871de41676a14b25c534c523a/examples/seq2seq/utils/__init__.py -------------------------------------------------------------------------------- /examples/seq2seq/utils/config.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | def handle_args_to_ignore(args: List[str]): 5 | indices_to_remove = [] 6 | for i, arg in enumerate(args): 7 | if "_ignore_" in arg: 8 | indices_to_remove.append(i) 9 | if not arg.startswith("-"): 10 | indices_to_remove.append(i - 1) 11 | 12 | for i in sorted(indices_to_remove, reverse=True): 13 | del args[i] 14 | -------------------------------------------------------------------------------- /examples/seq2seq/utils/custom_hf_argument_parser.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | from typing import Tuple 5 | 6 | from transformers import HfArgumentParser 7 | from transformers.hf_argparser import DataClass 8 | 9 | 10 | class CustomHfArgumentParser(HfArgumentParser): 11 | def parse_dictionary_and_args(self) -> Tuple[DataClass, ...]: 12 | """ 13 | Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the 14 | dataclass types. 
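        For illustration (hypothetical paths), an invocation could look like:
            python run.py configs/model/bart_base_sled.json configs/data/qasper.json --output_dir /tmp/out
        JSON config files are listed first; any --flag passed explicitly on the command line overrides the
        values read from those files.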
15 | """ 16 | args = [] 17 | data = {} 18 | for i in range(1, len(sys.argv)): 19 | if not sys.argv[i].endswith('.json'): 20 | break 21 | 22 | with open(sys.argv[i]) as f: 23 | new_data = json.load(f) 24 | conflicting_keys = set(new_data.keys()).intersection(data.keys()) 25 | if len(conflicting_keys) > 0: 26 | raise ValueError(f'There are conflicting keys in the config files: {conflicting_keys}') 27 | data.update(new_data) 28 | 29 | for k, v in data.items(): 30 | # if any options were given explicitly through the CLA then they override anything defined in the config files 31 | if f'--{k}' in sys.argv: 32 | logging.info(f'While {k}={v} was given in a config file, a manual override was set through the CLA') 33 | continue 34 | args.extend( 35 | ["--" + k, *(v if isinstance(v, list) else [str(v)])] 36 | ) # add the file arguments first so command line args has precedence 37 | args += sys.argv[i:] 38 | 39 | return self.parse_args_into_dataclasses(args=args, look_for_args_file=False) -------------------------------------------------------------------------------- /examples/seq2seq/utils/custom_seq2seq_trainer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import time 5 | from collections import defaultdict 6 | from typing import Any, Dict, List, Optional, Tuple, Union 7 | 8 | import torch 9 | from datasets import Dataset 10 | from torch import nn 11 | from transformers.debug_utils import DebugOption 12 | from transformers.deepspeed import is_deepspeed_zero3_enabled 13 | from transformers.trainer_utils import speed_metrics 14 | 15 | from transformers.utils import logging 16 | from transformers import Seq2SeqTrainer, is_torch_tpu_available 17 | 18 | import gc 19 | 20 | if is_torch_tpu_available(check_device=False): 21 | import torch_xla.core.xla_model as xm 22 | import torch_xla.debug.metrics as met 23 | 24 | 25 | from utils.decoding import decode 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | def _clean_memory(): 31 | gc.collect() 32 | torch.cuda.empty_cache() 33 | 34 | # This custom trainer is based on the trainer defined in https://github.com/huggingface/transformers/compare/main...eladsegal:public-transformers:scrolls 35 | class CustomTrainer(Seq2SeqTrainer): 36 | def __init__( 37 | self, *args, untokenized_eval_dataset=None, data_args=None, output_dir: Optional[str] = None, **kwargs 38 | ): 39 | super().__init__(*args, **kwargs) 40 | self._untokenized_eval_dataset = untokenized_eval_dataset 41 | self._max_length = data_args.val_max_target_length 42 | self._num_beams = data_args.num_beams 43 | self._output_dir = output_dir 44 | self._data_args = data_args 45 | self.mock_predictions_to_assign_zero_metric_score = self.tokenizer.encode("TOO_MANY_INPUT_TOKENS",return_tensors="np")[0] 46 | 47 | def prediction_step( 48 | self, 49 | model: nn.Module, 50 | inputs: Dict[str, Union[torch.Tensor, Any]], 51 | prediction_loss_only: bool, 52 | ignore_keys: Optional[List[str]] = None, 53 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 54 | """ 55 | Perform an evaluation step on `model` using `inputs`. 56 | 57 | Subclass and override to inject custom behavior. 58 | 59 | Args: 60 | model (`nn.Module`): 61 | The model to evaluate. 62 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 63 | The inputs and targets of the model. 64 | 65 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 66 | argument `labels`. 
Check your model's documentation for all accepted arguments. 67 | prediction_loss_only (`bool`): 68 | Whether or not to return the loss only. 69 | 70 | Return: 71 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 72 | labels (each being optional). 73 | """ 74 | if not ("labels" in inputs or 'decoder_input_ids' in inputs): 75 | logger.warning('When computing loss, must give labels or decoder_input_ids. ' 76 | 'If you only perform prediction, you can safely ignore this message') 77 | # This is an issue here because the input may be longer than the max-output length of the model, 78 | # and if nothing was given it will shift the input and use it to compute loss (and later discard it). 79 | # This may cause an indexing error when absolute embeddings are used (CUDA device side assert) 80 | inputs['decoder_input_ids'] = inputs['input_ids'][:,:2].clone() # dummy outputs 81 | 82 | if not self.args.predict_with_generate or prediction_loss_only: 83 | return super().prediction_step( 84 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 85 | ) 86 | 87 | has_labels = "labels" in inputs 88 | inputs = self._prepare_inputs(inputs) 89 | 90 | # XXX: adapt synced_gpus for fairscale as well 91 | gen_kwargs = self._gen_kwargs.copy() 92 | gen_kwargs["max_length"] = ( 93 | gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length 94 | ) 95 | gen_kwargs["num_beams"] = ( 96 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams 97 | ) 98 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 99 | gen_kwargs["synced_gpus"] = ( 100 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 101 | ) 102 | 103 | if "attention_mask" in inputs: 104 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) 105 | if "global_attention_mask" in inputs: 106 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None) 107 | 108 | # --------------------- addition compared to the source file -------------------- 109 | if 'prefix_length' in inputs: 110 | gen_kwargs['prefix_length'] = inputs['prefix_length'] 111 | _clean_memory() 112 | # ------------------------------------------------------------------------------ 113 | 114 | # prepare generation inputs 115 | # some encoder-decoder models can have varying encoder's and thus 116 | # varying model input names 117 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name: 118 | generation_inputs = inputs[self.model.encoder.main_input_name] 119 | else: 120 | generation_inputs = inputs[self.model.main_input_name] 121 | 122 | generated_tokens = self.model.generate( 123 | generation_inputs, 124 | **gen_kwargs, 125 | ) 126 | # --------------------- addition compared to the source file -------------------- 127 | _clean_memory() 128 | # ------------------------------------------------------------------------------ 129 | # in case the batch is shorter than max length, the output should be padded 130 | if generated_tokens.shape[-1] < gen_kwargs["max_length"]: 131 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) 132 | 133 | if has_labels: # changed the order of the if's here because there is no point going through the model if there are no labels to compute the loss on.. 
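            # Computing the loss requires a separate teacher-forced forward pass (model(**inputs)) in addition to
            # generate() above; when no labels are present, the loss is simply left as None.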
134 | with torch.no_grad(): 135 | with self.compute_loss_context_manager(): 136 | outputs = model(**inputs) 137 | if self.label_smoother is not None: 138 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 139 | else: 140 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 141 | else: 142 | loss = None 143 | 144 | if self.args.prediction_loss_only: 145 | return (loss, None, None) 146 | 147 | if has_labels: 148 | labels = inputs["labels"] 149 | if labels.shape[-1] < gen_kwargs["max_length"]: 150 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) 151 | else: 152 | labels = None 153 | 154 | return (loss, generated_tokens, labels) 155 | 156 | @property 157 | def _restart_generator(self): 158 | if getattr(self, '_is_restart_generator', False): 159 | self._is_restart_generator = False 160 | return True 161 | return False 162 | 163 | def set_restart_generator(self): 164 | self._is_restart_generator = True 165 | 166 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 167 | sampler = super()._get_train_sampler() 168 | try: 169 | if self._restart_generator: 170 | sampler.generator.manual_seed(self._initial_seed) 171 | else: 172 | self._initial_seed = sampler.generator.initial_seed() 173 | except Exception as e: 174 | logger.warning(f'Cannot save or set the seed of the generator: {e}') 175 | return sampler 176 | 177 | def _post_process_function(self, untokenized_eval_dataset, predictions): 178 | id_to_prediction = {} 179 | id_to_label_ids = defaultdict(list) 180 | 181 | assert len(untokenized_eval_dataset) == len(self.eval_dataset) 182 | 183 | for i, (instance, not_valid_for_eval) in enumerate(zip(untokenized_eval_dataset, self.eval_dataset["not_valid_for_eval"])): 184 | if not_valid_for_eval: 185 | id_to_prediction[instance["id"]] = self.mock_predictions_to_assign_zero_metric_score 186 | else: 187 | id_to_prediction[instance["id"]] = predictions[i] 188 | 189 | if "outputs" in instance: 190 | id_to_label_ids[instance["id"]] = instance["outputs"] 191 | else: 192 | id_to_label_ids[instance["id"]].append(instance["output"]) 193 | 194 | return id_to_prediction, id_to_label_ids 195 | 196 | def evaluate( 197 | self, 198 | eval_dataset: Optional[Dataset] = None, 199 | ignore_keys: Optional[List[str]] = None, 200 | metric_key_prefix: str = "eval", 201 | untokenized_eval_dataset: Optional[Dataset] = None, 202 | **gen_kwargs 203 | ) -> Dict[str, float]: 204 | """ 205 | Run evaluation and returns metrics. 206 | 207 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 208 | (pass it to the init `compute_metrics` argument). 209 | 210 | You can also subclass and override this method to inject custom behavior. 211 | 212 | Args: 213 | eval_dataset (`Dataset`, *optional*): 214 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns 215 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` 216 | method. 217 | ignore_keys (`List[str]`, *optional*): 218 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 219 | gathering predictions. 220 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 221 | An optional prefix to be used as the metrics key prefix. 
For example the metrics "bleu" will be named 222 | "eval_bleu" if the prefix is `"eval"` (default) 223 | max_length (`int`, *optional*): 224 | The maximum target length to use when predicting with the generate method. 225 | num_beams (`int`, *optional*): 226 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 227 | beam search. 228 | gen_kwargs: 229 | Additional `generate` specific kwargs. 230 | 231 | Returns: 232 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 233 | dictionary also contains the epoch number which comes from the training state. 234 | """ 235 | 236 | gen_kwargs = gen_kwargs.copy() 237 | gen_kwargs["max_length"] = ( 238 | gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length 239 | ) 240 | gen_kwargs["num_beams"] = ( 241 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams 242 | ) 243 | self._gen_kwargs = gen_kwargs 244 | 245 | self._memory_tracker.start() 246 | 247 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 248 | # ----------------------------------- Added ----------------------------------- 249 | untokenized_eval_dataset = ( 250 | self._untokenized_eval_dataset if untokenized_eval_dataset is None else untokenized_eval_dataset 251 | ) 252 | compute_metrics = self.compute_metrics 253 | self.compute_metrics = None 254 | # ----------------------------------------------------------------------------- 255 | 256 | start_time = time.time() 257 | 258 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop 259 | try: 260 | output = eval_loop( 261 | eval_dataloader, 262 | description="Evaluation", 263 | # No point gathering the predictions if there are no metrics, otherwise we defer to 264 | # self.args.prediction_loss_only 265 | prediction_loss_only=None, # MODIFIED since we need the predictions 266 | ignore_keys=ignore_keys, 267 | metric_key_prefix=metric_key_prefix, 268 | ) 269 | finally: 270 | # ----------------------------------- Added ----------------------------------- 271 | # revert the compute metrics back 272 | self.compute_metrics = compute_metrics 273 | # ----------------------------------------------------------------------------- 274 | 275 | # ----------------------------------- Added ----------------------------------- 276 | # compute our metrics 277 | if output.predictions is not None: 278 | eval_preds = self._post_process_function(untokenized_eval_dataset, output.predictions) 279 | 280 | if self._output_dir is not None and self.is_world_process_zero(): 281 | predictions = decode(eval_preds[0], self.tokenizer, self._data_args) 282 | output_prediction_file = os.path.join( 283 | self._output_dir, f"generated_predictions_eval_{self.state.global_step}.json" 284 | ) 285 | with open(output_prediction_file, "w") as writer: 286 | json.dump(predictions, writer, indent=4) 287 | 288 | output_labels_file = os.path.join( 289 | self._output_dir, f"eval_labels.json" 290 | ) 291 | if not os.path.isfile(output_labels_file): 292 | with open(output_labels_file, "w") as writer: 293 | json.dump(eval_preds[1], writer, indent=4) 294 | 295 | if self.compute_metrics is not None: 296 | output.metrics.update(self.compute_metrics(*eval_preds)) 297 | 298 | # Prefix all keys with metric_key_prefix + '_' 299 | for key in list(output.metrics.keys()): 300 | if not key.startswith(f"{metric_key_prefix}_"): 301 | 
output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(key) 302 | # ----------------------------------------------------------------------------- 303 | 304 | total_batch_size = self.args.eval_batch_size * self.args.world_size 305 | output.metrics.update( 306 | speed_metrics( 307 | metric_key_prefix, 308 | start_time, 309 | num_samples=output.num_samples, 310 | num_steps=math.ceil(output.num_samples / total_batch_size), 311 | ) 312 | ) 313 | 314 | self.log(output.metrics) 315 | 316 | if DebugOption.TPU_METRICS_DEBUG in self.args.debug: 317 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 318 | xm.master_print(met.metrics_report()) 319 | 320 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) 321 | 322 | self._memory_tracker.stop_and_update_metrics(output.metrics) 323 | 324 | return output.metrics 325 | -------------------------------------------------------------------------------- /examples/seq2seq/utils/decoding.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | 4 | 5 | def decode(id_to_something, tokenizer=None, data_args=None): 6 | decode_fn = None 7 | switch_case = None 8 | elem = next(iter(id_to_something.values())) 9 | if isinstance(elem, str): 10 | switch_case = -1 11 | decode_fn = lambda text: text.strip() 12 | elif isinstance(elem, list) and not isinstance(elem[0], int): 13 | if isinstance(elem[0], str): 14 | switch_case = 0 15 | decode_fn = lambda texts: [text.strip() for text in texts] 16 | else: 17 | switch_case = 1 18 | decode_fn = lambda token_ids_list: [ 19 | text.strip() 20 | for text in partial( 21 | tokenizer.batch_decode, skip_special_tokens=True, clean_up_tokenization_spaces=True 22 | )(token_ids_list) 23 | ] 24 | else: 25 | switch_case = 2 26 | decode_fn = lambda token_ids: partial( 27 | tokenizer.decode, skip_special_tokens=True, clean_up_tokenization_spaces=True 28 | )(token_ids).strip() 29 | 30 | id_to_text = {} 31 | for id_, something in id_to_something.items(): 32 | if switch_case == -1 or switch_case == 0: 33 | obj_to_decode = something 34 | else: 35 | if data_args is None: 36 | data_args = {} 37 | if not isinstance(data_args, dict): 38 | data_args = vars(data_args) 39 | if data_args.get("ignore_pad_token_for_loss", True): 40 | # Replace -100 in the token_ids as we can't decode them. 
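                # (-100 is the ignore-index used for padded label positions during training; _replace_padding
                # below swaps it back to the tokenizer's pad token id, e.g. [0, 42, -100, -100] becomes
                # [0, 42, <pad>, <pad>], so that decoding with skip_special_tokens works.)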
41 | if switch_case == 1: 42 | token_ids_list = something 43 | for i in range(len(token_ids_list)): 44 | token_ids_list[i] = _replace_padding(token_ids_list[i], tokenizer.pad_token_id) 45 | obj_to_decode = token_ids_list 46 | elif switch_case == 2: 47 | token_ids = something 48 | token_ids = _replace_padding(token_ids, tokenizer.pad_token_id) 49 | obj_to_decode = token_ids 50 | else: 51 | obj_to_decode = something 52 | 53 | id_to_text[id_] = decode_fn(obj_to_decode) 54 | 55 | return id_to_text 56 | 57 | 58 | def _replace_padding(token_ids: np.array, pad_token_id): 59 | return np.where(token_ids != -100, token_ids, pad_token_id) 60 | -------------------------------------------------------------------------------- /examples/seq2seq/utils/duplicates.py: -------------------------------------------------------------------------------- 1 | def drop_duplicates_in_input(untokenized_dataset): 2 | indices_to_keep = [] 3 | id_to_idx = {} 4 | outputs = [] 5 | for i, (id_, output) in enumerate(zip(untokenized_dataset["id"], untokenized_dataset["output"])): 6 | if id_ in id_to_idx: 7 | outputs[id_to_idx[id_]].append(output) 8 | continue 9 | indices_to_keep.append(i) 10 | id_to_idx[id_] = len(outputs) 11 | outputs.append([output]) 12 | untokenized_dataset = untokenized_dataset.select(indices_to_keep).flatten_indices() 13 | untokenized_dataset = untokenized_dataset.remove_columns("output") 14 | untokenized_dataset = untokenized_dataset.add_column("outputs", outputs) 15 | return untokenized_dataset 16 | -------------------------------------------------------------------------------- /examples/seq2seq/utils/override_training_args.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | import torch.cuda 6 | from transformers.utils import logging 7 | 8 | sys.path.insert(0, os.getcwd()) 9 | 10 | from dataclasses import dataclass, field 11 | 12 | from transformers.trainer_utils import IntervalStrategy 13 | from transformers import Seq2SeqTrainingArguments 14 | 15 | logger = logging.get_logger('swed_logger') 16 | 17 | @dataclass 18 | class TrainingOverridesArguments(Seq2SeqTrainingArguments): 19 | """ 20 | To use if, it requires evaluation_strategy == IntervalStrategy.STEPS 21 | """ 22 | eval_steps_override: float = field(default=0, metadata={"help": "a fraction, to set the the save_steps w.r.t to number of steps in " 23 | "a single epoch. changes eval_steps. 0 to disable (default)"}) 24 | save_steps_override: float = field(default=0, metadata={"help": "a fraction, to set the the save_steps w.r.t to number of steps in " 25 | "a single epoch. changes save_steps. must be a multiple of eval_steps" 26 | " (or eval_steps_override if given). 0 to disable (default)"}) 27 | 28 | eval_fraction: float = field(default=1, metadata={ 29 | "help": "A float in (0,1] that corresponds to how much of the eval set to use during evaluations " 30 | "(same subset all the time) or an integer >= 2 which amounts to the absolute number of training " 31 | "samples to use. 1. to disable it and use the entire eval set "}) 32 | 33 | use_auth_token: bool = field( 34 | default=False, 35 | metadata={ 36 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 37 | "with private models). 
If AUTH_TOKEN is set as an environment variable, would use that" 38 | }, 39 | ) 40 | 41 | fp16_padding: bool = field( 42 | default=False, 43 | metadata={"help": "Whether to use padding for fp16"}, 44 | ) 45 | 46 | 47 | def __post_init__(self): 48 | super(TrainingOverridesArguments, self).__post_init__() 49 | if self.eval_steps_override > 0 or self.save_steps_override > 0: 50 | if self.evaluation_strategy != IntervalStrategy.STEPS: 51 | raise ValueError( 52 | f"using eval/save steps override requires evaluation strategy to be {IntervalStrategy.STEPS}" 53 | ) 54 | if self.save_steps_override == 0 or self.eval_steps_override == 0: 55 | raise ValueError( 56 | f"using eval/save steps override requires both overrides to be non zero" 57 | ) 58 | diff = (self.save_steps_override / self.eval_steps_override) % 1 59 | if min(1-diff, diff) > 1e-5: # we do it like that to support fractions modulo as well, with loss of precision 60 | raise ValueError( 61 | f"using eval/save steps override requires save steps override to be a multiple of eval_steps_override" 62 | ) 63 | if self.use_auth_token and 'AUTH_TOKEN' in os.environ: 64 | self.use_auth_token = os.getenv('AUTH_TOKEN') 65 | 66 | @property 67 | def effective_batch_size(self): 68 | if not hasattr(self, '_ebs'): 69 | n_gpu = self.n_gpu if torch.cuda.is_available() else 1 # may be on cpu 70 | self._ebs = self.per_device_train_batch_size * self.gradient_accumulation_steps * n_gpu 71 | logger.warning(f'Training with {self.per_device_train_batch_size} per_device_train_size, {self.n_gpu} gpus and ' 72 | f'{self.gradient_accumulation_steps} gradient accumulation steps, resulting in {self._ebs} effective batch size') 73 | return self._ebs 74 | 75 | def apply_overrides(self, dataset_size): 76 | if self.eval_steps_override == 0: 77 | return 78 | es, ss = self.eval_steps, self.save_steps 79 | total_steps_per_epoch = dataset_size / self.effective_batch_size # note that this may not be an integer 80 | eval_steps = int(total_steps_per_epoch * self.eval_steps_override) 81 | if eval_steps >= self.logging_steps: 82 | if eval_steps % self.logging_steps != 0: 83 | logger.warning(f'Eval steps override would result in eval every {eval_steps} steps, but it is not a ' 84 | f'multiple of logging steps ({self.logging_steps}) so changing to ' 85 | f'{eval_steps + self.logging_steps - eval_steps % self.logging_steps}') 86 | eval_steps = eval_steps + self.logging_steps - eval_steps % self.logging_steps 87 | elif eval_steps < self.logging_steps: 88 | logger.warning(f'Eval steps override would result in eval every {eval_steps} steps, but it is not a ' 89 | f'multiple of logging steps ({self.logging_steps}) so changing to {self.logging_steps}') 90 | eval_steps = self.logging_steps 91 | self.eval_steps = eval_steps 92 | 93 | save_steps = int(total_steps_per_epoch * self.save_steps_override) 94 | if save_steps < eval_steps or save_steps % eval_steps != 0: 95 | logger.warning(f'Save steps override would result in eval every {save_steps} steps, but it is not a ' 96 | f'multiple of eval steps ({eval_steps}) so changing to ' 97 | f'{save_steps + eval_steps - save_steps % self.eval_steps}') 98 | save_steps = save_steps + eval_steps - save_steps % self.eval_steps 99 | self.save_steps = save_steps 100 | 101 | logger.warning(f'Using overrides with dataset of size {dataset_size} and effective batch size of ' 102 | f'{self.effective_batch_size}, moving from (eval_steps, save_steps) ' 103 | f'of {(es, ss)} to {(self.eval_steps, self.save_steps)}') 104 | 
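        # Worked example (illustrative numbers only): with 10,000 training samples, per_device_train_batch_size=8,
        # gradient_accumulation_steps=4 and one GPU, the effective batch size is 32, so one epoch is 312.5 steps.
        # With eval_steps_override=0.5 and logging_steps=10, eval_steps = int(312.5 * 0.5) = 156, rounded up to the
        # next multiple of logging_steps -> 160. With save_steps_override=1.0, save_steps = int(312.5) = 312,
        # rounded up to the next multiple of eval_steps (160) -> 320.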
-------------------------------------------------------------------------------- /examples/usage_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Minimal working example to use an X-SLED model 3 | """ 4 | import torch 5 | from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM 6 | # noinspection PyUnresolvedReferences 7 | import sled # *** required so that SledModels will be registered for the AutoClasses *** 8 | 9 | if __name__ == '__main__': 10 | # Load the model and tokenizer for Bart-base-SLED 11 | bart_base_sled_model = AutoModel.from_pretrained('tau/bart-base-sled') 12 | tokenizer = AutoTokenizer.from_pretrained('tau/bart-base-sled') 13 | bart_base_sled_model.eval() 14 | 15 | # The below example is for cases where there are no prefix (e.g. question) to use, such as summarization 16 | document_input_ids = tokenizer( 17 | "Studies have been shown that owning a dog is good for you", return_tensors="pt" 18 | ) # Batch size 1 19 | with torch.no_grad(): 20 | final_representations = bart_base_sled_model(**document_input_ids, return_dict=None).last_hidden_state 21 | 22 | # Now, assume we do have a prefix (for example in question answering) 23 | prefix_input_ids = tokenizer( 24 | "Is owning a dog good for you?", return_tensors="pt" 25 | ).input_ids # Batch size 1 26 | 27 | # we concatenate them together, but tell SLED where is the prefix by setting the prefix_length tensor 28 | input_ids = torch.cat((prefix_input_ids, document_input_ids.input_ids), dim=-1) 29 | attention_mask = torch.ones_like(input_ids) 30 | prefix_length = torch.LongTensor([[prefix_input_ids.size(1)]]) 31 | with torch.no_grad(): 32 | final_representations = bart_base_sled_model(input_ids=input_ids, attention_mask=attention_mask, 33 | prefix_length=prefix_length, return_dict=None).last_hidden_state 34 | 35 | # However, we are dealing with a generative model here (encoder-decoder), so, we can use it to generate 36 | bart_base_sled_model = AutoModelForSeq2SeqLM.from_pretrained('tau/bart-base-sled') 37 | with torch.no_grad(): 38 | generated_output = bart_base_sled_model.generate(input_ids=input_ids, 39 | prefix_length=prefix_length, return_dict=None) 40 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | # This setup file was created with the help of https://github.com/MichaelKim0407/tutorial-pip-package 3 | 4 | extra_t5 = [ 5 | 'protobuf<=3.20.1', 6 | 'sentencepiece' 7 | ] 8 | extra_examples = [ 9 | 'nltk', 10 | 'datasets>=1.17.0', 11 | 'absl-py', 12 | 'rouge-score', 13 | 'pandas' 14 | ] 15 | 16 | extra_dev = [ 17 | *extra_t5, 18 | *extra_examples, 19 | ] 20 | 21 | setup( 22 | name='py-sled', 23 | version='0.1.7', 24 | 25 | python_requires='>=3.7.0', 26 | install_requires=[ 27 | 'transformers==4.21.0', 28 | 'makefun>=1.14.0', 29 | ], 30 | extras_require={ 31 | 't5': extra_t5, 32 | 'examples': extra_examples, 33 | 'dev': extra_dev 34 | }, 35 | description='SLED models use pretrained, short-range encoder-decoder models, and apply them over long-text inputs '\ 36 | 'by splitting the input into multiple overlapping chunks, encoding each independently and '\ 37 | 'perform fusion-in-decoder', 38 | 39 | url='https://github.com/Mivg/SLED', 40 | author='Maor Ivgi', 41 | author_email='maor.ivgi@cs.tau.ac.il', 42 | 43 | packages=find_packages(exclude=("tests*", "examples*")), # also added to 
manifest.in due to https://stackoverflow.com/a/46320848 44 | 45 | classifiers=[ 46 | 'Intended Audience :: Developers', 47 | 'Programming Language :: Python', 48 | 'Programming Language :: Python :: 3', 49 | 'License :: OSI Approved :: MIT License', 50 | 'Natural Language :: English', 51 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 52 | 'Topic :: Text Processing' 53 | ], 54 | ) 55 | # build with `python setup.py sdist` 56 | # upload with `python3 -m twine upload dist/` -------------------------------------------------------------------------------- /sled/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.7' 2 | 3 | try: 4 | # noinspection PyPackageRequirements 5 | import torch 6 | except ImportError: 7 | raise ImportError('Using sled requires torch. Please refer to https://pytorch.org/get-started/locally/ ' 8 | 'to install the correct version for your setup') 9 | 10 | from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM 11 | 12 | from .configuration_sled import SledConfig 13 | # noinspection PyUnresolvedReferences 14 | from .modeling_sled import SledModel, SledForConditionalGeneration, PREFIX_KEY 15 | from .tokenization_sled import SledTokenizer 16 | from .tokenization_sled_fast import SledTokenizerFast 17 | 18 | AutoConfig.register('tau/sled', SledConfig) 19 | AutoModel.register(SledConfig, SledModel) 20 | AutoModelForSeq2SeqLM.register(SledConfig, SledForConditionalGeneration) 21 | AutoTokenizer.register(SledConfig, slow_tokenizer_class=SledTokenizer, fast_tokenizer_class=SledTokenizerFast) -------------------------------------------------------------------------------- /sled/configuration_sled.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | from transformers import PretrainedConfig, AutoConfig 4 | 5 | 6 | class SledConfig(PretrainedConfig): 7 | r""" 8 | """ 9 | model_type = "tau/sled" 10 | 11 | def __init__(self, underlying_config="facebook/bart-base", context_size=256, window_fraction=0.5, 12 | prepend_prefix=True, encode_prefix=True, sliding_method='dynamic', **kwargs): 13 | super().__init__(**kwargs) 14 | parent_only_keys = set(super().to_dict().keys()) 15 | self.underlying_config = underlying_config 16 | self.context_size = context_size 17 | self.window_fraction = window_fraction 18 | self.prepend_prefix = prepend_prefix 19 | self.encode_prefix = encode_prefix 20 | self.sliding_method = sliding_method 21 | 22 | # load underlying_config 23 | config = AutoConfig.from_pretrained(underlying_config, **kwargs) 24 | # update internal dict based on the underlying config, overriding everything EXCEPT what was explicitly set here 25 | ignore_keys = set(self.to_dict().keys()).union(type(self).__dict__.keys()).difference(parent_only_keys) 26 | self.__dict__.update({k: v for k, v in config.to_dict().items() if k not in ignore_keys}) 27 | -------------------------------------------------------------------------------- /sled/modeling_sled.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import inspect 3 | import os 4 | import warnings 5 | from typing import Dict, Any, Optional 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from makefun import create_function 10 | from requests.exceptions import HTTPError 11 | from torch import nn 12 | from transformers import PreTrainedModel, AutoModel, AutoModelForSeq2SeqLM, WEIGHTS_NAME 13 | from 
transformers.generation_utils import GenerationMixin 14 | from transformers.modeling_outputs import BaseModelOutput 15 | from transformers.utils import logging 16 | 17 | 18 | from .configuration_sled import SledConfig 19 | 20 | logger = logging.get_logger(__name__) 21 | 22 | 23 | PREFIX_KEY = 'prefix_length' 24 | 25 | def _legacy_download_weights(kwargs, pretrained_model_name_or_path): 26 | # noinspection PyUnresolvedReferences 27 | from transformers.utils import hf_bucket_url, cached_path, EntryNotFoundError 28 | archive_file = hf_bucket_url( 29 | pretrained_model_name_or_path, 30 | filename=WEIGHTS_NAME, 31 | revision=kwargs.pop("revision", None), 32 | mirror=kwargs.pop("mirror", None), 33 | subfolder=kwargs.pop("subfolder", None), 34 | ) 35 | logger.info(f'Looking for pretrained weights on {archive_file}') 36 | 37 | try: 38 | # Load from URL or cache if already cached 39 | user_agent = {"file_type": "model", "framework": "pytorch", 40 | "from_auto_class": kwargs.pop("_from_auto", False)} 41 | from_pipeline = kwargs.pop("_from_pipeline", None) 42 | if from_pipeline is not None: 43 | user_agent["using_pipeline"] = from_pipeline 44 | resolved_archive_file = cached_path( 45 | archive_file, 46 | cache_dir=kwargs.pop("cache_dir", None), 47 | force_download=kwargs.pop("force_download", False), 48 | proxies=kwargs.pop("proxies", None), 49 | resume_download=kwargs.pop("resume_download", False), 50 | local_files_only=kwargs.pop("local_files_only", False), 51 | use_auth_token=kwargs.pop("use_auth_token", None), 52 | user_agent=user_agent, 53 | ) 54 | 55 | logger.info(f'Successfully downloaded the weights to {resolved_archive_file}') 56 | return torch.load(resolved_archive_file, map_location="cpu") 57 | except EntryNotFoundError: 58 | logger.info('Did not find a SLED weights file to be loaded.' 59 | ' If this is not unexpected, please reach out. ' 60 | 'Note - loading sharded weight file is not currently supported') 61 | return None 62 | 63 | def _download_weights(kwargs, pretrained_model_name_or_path): 64 | try: 65 | from transformers.utils import cached_file, EntryNotFoundError, HUGGINGFACE_CO_RESOLVE_ENDPOINT 66 | # Load from URL 67 | user_agent = {"file_type": "model", "framework": "pytorch", 68 | "from_auto_class": kwargs.pop("_from_auto", False)} 69 | cached_filename = cached_file( 70 | pretrained_model_name_or_path, 71 | WEIGHTS_NAME, # shard_filename, 72 | cache_dir=kwargs.pop("cache_dir", None), 73 | force_download=kwargs.pop("force_download", False), 74 | proxies=kwargs.pop("proxies", None), 75 | resume_download=kwargs.pop("resume_download", False), 76 | local_files_only=kwargs.pop("local_files_only", False), 77 | use_auth_token=kwargs.pop("use_auth_token", None), 78 | user_agent=user_agent, 79 | revision=kwargs.pop("revision", None), 80 | subfolder=kwargs.pop("subfolder", None), 81 | _commit_hash=kwargs.pop("_commit_hash", None), 82 | ) 83 | logger.info(f'Successfully downloaded the weights to {cached_filename}') 84 | return torch.load(cached_filename, map_location="cpu") 85 | # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so 86 | # we don't have to catch them here. 87 | except ImportError: 88 | # probably an older version of transformers. try to fallback on legacy load 89 | logger.info('Could not find a weights file with the new huggingface api hub. Attempting fallback to the old api') 90 | except (OSError, EntryNotFoundError) as e: 91 | logger.info('Did not find a SLED weights file to be loaded.' 
92 | ' If this is not unexpected, please reach out. ' 93 | f'Note - loading sharded weight file is not currently supported ({e})') 94 | return None 95 | except HTTPError: 96 | raise EnvironmentError( 97 | f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {WEIGHTS_NAME}. You should try" 98 | " again after checking your internet connection." 99 | ) 100 | 101 | try: 102 | return _legacy_download_weights(kwargs, pretrained_model_name_or_path) # attempt fallback 103 | except ImportError: 104 | logger.error('Could not find a SLED weights file to be loaded due to unmatching transformers version. ' 105 | 'Please open as issue on https://github.com/Mivg/SLED/issues') 106 | raise 107 | 108 | 109 | def _find_tensor_inds_and_size(*args, **kwargs): 110 | args_tensor_inds = [i for i, v in enumerate(args) if isinstance(v, torch.Tensor)] 111 | kwargs_tensor_keys = [k for k, v in kwargs.items() if isinstance(v, torch.Tensor)] 112 | 113 | assert len(args_tensor_inds) + len(kwargs_tensor_keys) > 0, "no tensors were found" 114 | # tensor are 2d at this point with [N, s], N is the number of inputs, s is the (padded) sequence length 115 | size = args[args_tensor_inds[0]].size() if len(args_tensor_inds) > 0 else kwargs[kwargs_tensor_keys[0]].size() 116 | assert len(size) == 2, f"expected 2-d tensors but got tensors of shape: {size}" 117 | _, s = size 118 | # can also assert that all tensors here share the same size, but we will skip this for efficiency’s sake and make 119 | # this assumption 120 | 121 | return args_tensor_inds, kwargs_tensor_keys, s 122 | 123 | 124 | def _pad_if_needed(v, pad=None): 125 | if pad is None: 126 | return v 127 | return torch.pad 128 | 129 | 130 | def _concat_input_prefix_if_needed(v, start_ind, end_ind, prefix_length=None, pad=None): 131 | if prefix_length is None: 132 | v = v[:, start_ind:end_ind] 133 | if pad is not None and pad != 0: 134 | # assert v.dim() == 2 135 | assert pad >= 0, f'padding should be non negative but it is negative ({pad})' 136 | v = F.pad(v, (0, pad), "constant", 0) 137 | # padding only from the right so (0,pad), and only in the second axis so not length 4 138 | return v 139 | return torch.cat((v[:, :prefix_length], v[:, prefix_length + start_ind:prefix_length + end_ind]), axis=1) 140 | 141 | 142 | def _fix_args(args, args_tensor_inds, start_ind, end_ind, prefix_length=None, pad=None): 143 | return tuple(v if i not in args_tensor_inds else 144 | _concat_input_prefix_if_needed(v, start_ind, end_ind, prefix_length, pad) 145 | for i, v in enumerate(args)) 146 | 147 | 148 | def _fix_kwargs(kwargs, kwargs_tensor_keys, start_ind, end_ind, prefix_length=None, pad=None): 149 | return { 150 | k: v if (k not in kwargs_tensor_keys or k.startswith("decoder")) else 151 | _concat_input_prefix_if_needed(v, start_ind, end_ind, prefix_length, pad) 152 | for k, v in kwargs.items() 153 | } 154 | 155 | 156 | def _stack_args(stack_args, args_tensor_inds): 157 | return tuple(v if i not in args_tensor_inds else torch.cat(tuple(si[i] for si in stack_args)) 158 | for i, v in enumerate(stack_args[0])) 159 | 160 | 161 | def _stack_kwargs(stack_kwargs, kwargs_tensor_keys): 162 | try: 163 | return {k: v if k not in kwargs_tensor_keys else torch.cat(tuple(si[k] for si in stack_kwargs)) 164 | for k, v in stack_kwargs[0].items()} 165 | except RuntimeError as e: 166 | for k in kwargs_tensor_keys: 167 | logger.warning(f'problematic key={k}. 
size={tuple(si[k].size() for si in stack_kwargs)}') 168 | if str(e).startswith('Sizes of tensors must match except in dimension 0'): 169 | logger.warning('Most likely you passed in non-padded batch. make sure all examples in the batch have the same length') 170 | raise 171 | 172 | 173 | def _unstack_encoder_outputs(stacked_output, n, bs): 174 | if isinstance(stacked_output, tuple): 175 | return [tuple(v if not isinstance(v, torch.Tensor) else v[i * bs:(i + 1) * bs] for v in stacked_output) 176 | for i in range(n)] 177 | # works for dict as well as structured outputs 178 | return [type(stacked_output)(**{k: v if not isinstance(v, torch.Tensor) else v[i*bs:(i+1)*bs] 179 | for k, v in stacked_output.items()}) 180 | for i in range(n)] 181 | 182 | 183 | def _extract_keyword_args(kwargs, arg_names, prefix=None): 184 | new_kwargs = {arg_name: kwargs.get(arg_name, None) for arg_name in arg_names} 185 | if prefix is not None: 186 | for arg_name in arg_names: 187 | new_arg_name = prefix + arg_name 188 | if new_arg_name in kwargs and new_arg_name not in new_kwargs: 189 | new_kwargs[arg_name] = kwargs[new_arg_name] 190 | return new_kwargs 191 | 192 | 193 | def _slice_tensor(v, start, end, prefix_length=None): 194 | prefix_length = prefix_length or 0 195 | return v[:, start+prefix_length:end+prefix_length] 196 | 197 | 198 | def _merge_encoder_outputs(encoder_outputs_list): 199 | # a list of 4-tuples, first value is the returned value from the encoder, then the start and end indices inside the 200 | # tensors that we should take, and finally prefix_length (None if was not used) 201 | 202 | # presumed order of returned tuple from encoders: last_hidden_state, hidden_states, attentions 203 | 204 | # the first output, as returned by the underlying model on the first window 205 | resulting_output = encoder_outputs_list[0][0] 206 | if isinstance(resulting_output, tuple): # not in return dict mode: 207 | resulting_list = [] 208 | for i in range(len(resulting_output)): 209 | if resulting_output[i] is None: 210 | resulting_list.append(None) 211 | elif ( 212 | isinstance(resulting_output[i], (int, float, tuple, MockTensorForConditionalGeneration)) 213 | or resulting_output[i].dim() != 3 214 | ): 215 | continue 216 | else: 217 | assert isinstance(resulting_output[i], torch.Tensor) 218 | # tensors are of of size (N, w, d), N the batch size, w the current window size and d the hidden 219 | # state size/logits dimension these are the only parts in the encoder output that we need to merge 220 | # between windows 221 | resulting_list.append( 222 | torch.cat(tuple(_slice_tensor(out[i], start, end, prefix_length) 223 | for out, start, end, prefix_length in encoder_outputs_list), dim=1) 224 | ) # requires extra GPU memory because it doesn't dump the old copy of the tensors yet 225 | resulting_output = tuple(resulting_list) 226 | else: 227 | for key in resulting_output.keys(): 228 | if resulting_output[key] is None: 229 | continue 230 | if isinstance(resulting_output[key], tuple): 231 | resulting_output[key] = None # encoder outputs are not tuples, only the decoders 232 | else: 233 | assert isinstance(resulting_output[key], torch.Tensor) 234 | if resulting_output[key].dim() != 3: 235 | continue # decoder outputs may be 4d tensors 236 | # tensors are of of size (N, w, d), N the batch size, w the current window size and d the hidden 237 | # state size/logits dimension 238 | resulting_output[key] = torch.cat( 239 | tuple(_slice_tensor(out[key], start, end, prefix_length) 240 | for out, start, end, prefix_length in 
encoder_outputs_list), dim=1 241 | ) 242 | 243 | return resulting_output 244 | 245 | 246 | class MockDecoder(nn.Module): 247 | def forward(self, *_, **__): 248 | return tuple() 249 | 250 | def to(self, *_, **__): 251 | return self 252 | 253 | 254 | class SledPretrainedModel(PreTrainedModel, metaclass=abc.ABCMeta): 255 | config_class = SledConfig 256 | auto_model_loader = AutoModel 257 | IGNORE_CONFIG_KEYS = {'model_type', '_name_or_path'} # config keys we allow to be mismatched between the 258 | # SledConfig and the underlying model's config 259 | 260 | def __init__(self, underlying_model: PreTrainedModel, config: SledConfig): 261 | """ 262 | 263 | :param underlying_model: The backbone model to use. 264 | Warning - once given, it should not be used directly, as it may cause unexpected behaviours 265 | :param config: 266 | """ 267 | super(SledPretrainedModel, self).__init__(config) 268 | self._underlying_model = ( 269 | underlying_model # crucial this will be before any calls to members that is in the base model 270 | ) 271 | 272 | self._window_fraction = config.window_fraction 273 | self._context_size = config.context_size 274 | self._window_margin = int(config.context_size * (1 - config.window_fraction) / 2) 275 | self._sliding_method = config.sliding_method or 'dynamic' 276 | assert self._sliding_method in {'loop', 'stacked', 'dynamic', 'decoderonly'} 277 | 278 | for override_member in ['is_parallelizable', 'supports_gradient_checkpointing']: 279 | setattr(self, override_member, getattr(underlying_model, override_member)) 280 | 281 | # setting the base_model_prefix to return the correct underlying model and link to some methods 282 | # implemented in the base 283 | self.base_model_prefix = "sled_base_model_prefix" 284 | self.sled_base_model_prefix = self._underlying_model.base_model 285 | 286 | # override generation preparation functions that may be overridden by underlying models but will be 287 | # found in our wrapper. 
We wished we could do it a follows: 288 | # for method_name, _ in inspect.getmembers(PreTrainedModel, predicate=inspect.isfunction): 289 | # if method_name not in {"_replicate_for_data_parallel", 'modules'}: 290 | # setattr(self, method_name, getattr(underlying_model, method_name)) 291 | # However, the above is too broad and dangerous, so we will do it directly 292 | for method_name in {"_init_weights", "prepare_inputs_for_generation"}: 293 | if hasattr(underlying_model, method_name): 294 | setattr(self, method_name, getattr(underlying_model, method_name)) 295 | 296 | # set the resize_token_embeddings 297 | vocab_size = underlying_model.get_input_embeddings().weight.size(0) 298 | assert hasattr(self.config, 'vocab_size'), 'Underlying models must have a vocab_size config' 299 | assert underlying_model.config.vocab_size == vocab_size 300 | self.resize_token_embeddings(vocab_size) # the underlying model may have a different vocab size compared to its base config 301 | 302 | self._verified_config_consistency = False 303 | self._verify_config_consistency() 304 | self._verified_config_consistency = False # We would like to do it later again (before the first forward) 305 | 306 | self._prepend_prefix = config.prepend_prefix 307 | self._encode_prefix = config.encode_prefix 308 | 309 | # now, let's create the forward function 310 | self._create_forward_function() 311 | 312 | @classmethod 313 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): 314 | assert isinstance(pretrained_model_name_or_path, str), 'pretrained_model_name_or_path must be a path to a ' \ 315 | 'checkpoint or a local config file (json)' 316 | config = kwargs.pop('config', None) or \ 317 | cls.config_class.from_pretrained(pretrained_model_name_or_path, 318 | use_auth_token=kwargs.get('use_auth_token', False)) 319 | state_dict = kwargs.pop('state_dict', None) 320 | if os.path.exists(pretrained_model_name_or_path): 321 | # if pretrained_model_name_or_path is a saved checkpoint 322 | if pretrained_model_name_or_path.endswith('json'): 323 | underlying_model = cls.auto_model_loader.from_pretrained(config.underlying_config, *model_args, **kwargs) 324 | else: 325 | # otherwise, it is a config json path + weights. 
Note LSED doesn't have any weights of its own, 326 | # so the state dict is only for the underlying model 327 | backbone_state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), map_location="cpu") 328 | underlying_model = cls.auto_model_loader.from_pretrained(config.underlying_config, *model_args, 329 | **kwargs) 330 | model = cls(underlying_model, config) 331 | cls._load_state_dict(backbone_state_dict, model, underlying_model) 332 | return model 333 | 334 | else: 335 | # assume it is a model card on the hub ) 336 | underlying_model = cls.auto_model_loader.from_pretrained(config.underlying_config, *model_args, **kwargs) 337 | 338 | state_dict = cls._load_remote_state_dict(kwargs, pretrained_model_name_or_path, state_dict) 339 | 340 | sled_model = cls(underlying_model, config) 341 | if state_dict is not None: 342 | logger.info('Updating SLED model weights from state dict') 343 | cls._load_state_dict(state_dict, sled_model) 344 | return sled_model 345 | 346 | @classmethod 347 | def _load_state_dict(cls, backbone_state_dict, model, underlying_model=None): 348 | load_result = model.load_state_dict(backbone_state_dict, strict=False) 349 | # Known issue - when loading a model checkpoint of type AutoModelForSeq2SeqLM with AutoModel, 350 | # the state dict will not be loaded correctly. Documented [here](https://github.com/Mivg/SLED/issues/4) 351 | if len(load_result.missing_keys) != 0: 352 | if underlying_model is not None and \ 353 | model._keys_to_ignore_on_save is not None and \ 354 | set(load_result.missing_keys) == set(model._keys_to_ignore_on_save): 355 | underlying_model.tie_weights() 356 | else: 357 | logger.warn( 358 | f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.") 359 | if len(load_result.unexpected_keys) != 0: 360 | logger.warn( 361 | f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.") 362 | 363 | @classmethod 364 | def _load_remote_state_dict(cls, kwargs, pretrained_model_name_or_path, state_dict): 365 | if state_dict is None: 366 | state_dict = _download_weights(kwargs, pretrained_model_name_or_path) 367 | return state_dict 368 | 369 | def _verify_config_consistency(self): 370 | if not self._verified_config_consistency: 371 | # SLED models are built in one of the following ways: 372 | # 1. Explicitly (given a SledConfig and an underlying model) 373 | # 2. from_pretrained (local_config, hub config, saved checkpoint) 374 | # There are 2 cases where the underlying_config and the SledConfig may mismatch - 1, and 2.saved_checkpoint 375 | # Instead of deciding whether to update the underlying model's config or vice versa while we cannot know which 376 | # one is correct, it is better to raise an exception. The only key we were willing to tolerate is the 377 | # vocab size, and we dealt with it above 378 | # Note - we will only do it once on the first forward pass. 
Setting a config in the model after it has 379 | # been used is not advisable 380 | config_dict = self.config.to_dict() 381 | underlying_config_dict = self.underlying_model.config.to_dict() 382 | matching_keys = set(config_dict.keys()).intersection(underlying_config_dict.keys()).difference(self._ignore_keys) 383 | inconsistent_keys = {k: (config_dict[k], underlying_config_dict[k]) 384 | for k in matching_keys if config_dict[k] != underlying_config_dict[k]} 385 | 386 | if len(inconsistent_keys) > 0: 387 | # raise ValueError(f'SledConfig and the underlying_model config have mismatching configurations on: {inconsistent_keys}') 388 | # if we loaded the config with overrides, there may still be some conflicts 389 | logger.warning( 390 | f'SledConfig and the underlying_model config have mismatching configurations on: {inconsistent_keys}, ' 391 | f'probably due to config_overrides. Setting the underlying config to match SLEDs') 392 | for k in inconsistent_keys: 393 | setattr(self.underlying_model.config, k, getattr(self.config, k)) 394 | 395 | config = self.config 396 | if self._window_fraction != config.window_fraction or \ 397 | self._context_size != config.context_size or\ 398 | self._window_margin != int(config.context_size * (1 - config.window_fraction) / 2) or \ 399 | self._sliding_method != config.sliding_method: 400 | raise RuntimeError('SLED does not support changing its configuration after it is initialized. ' 401 | 'Try reloading the model with overrides') 402 | 403 | self._verified_config_consistency = True 404 | 405 | @property 406 | def underlying_model(self): 407 | return self._underlying_model 408 | 409 | def resize_token_embeddings(self, new_num_tokens=None): 410 | res = self.underlying_model.resize_token_embeddings(new_num_tokens) 411 | self.config.vocab_size = self.underlying_model.vocab_size # sync them 412 | return res 413 | 414 | @property 415 | def _ignore_keys(self): 416 | return self.IGNORE_CONFIG_KEYS 417 | 418 | def _replicate_for_data_parallel(self): 419 | replica = super()._replicate_for_data_parallel() 420 | replica.forward = create_function(self._signature, replica._forward) 421 | return replica 422 | 423 | def _create_forward_function(self): 424 | # https://stackoverflow.com/a/15638720 425 | self._underlying_model_signature = inspect.signature(self._underlying_model.forward) 426 | self._forward_kwargs_names = [param.name for param in self._underlying_model_signature.parameters.values()] 427 | assert PREFIX_KEY not in self._forward_kwargs_names 428 | # if we want to prepend questions in every window, we need to set the forward signature to expect the 429 | # input_prefix (e.g. question) as a separate input sequence 430 | 431 | # we want to remove any typing information as it may cause issues in the custom function build due to 432 | # non-imported modules. It is ugly and shouldn't be done like that, but it works.
433 | params = [self._underlying_model_signature.parameters[p].replace(annotation=inspect.Parameter.empty) 434 | for p in self._underlying_model_signature.parameters] 435 | params.append(inspect.Parameter(name=PREFIX_KEY, default=None, kind=params[-1].kind)) 436 | self._signature = str(self._underlying_model_signature.replace(parameters=params, 437 | return_annotation=inspect.Signature.empty)) 438 | 439 | # HF trainer uses the signature to choose which parts to take from a dataset, so we need to make sure our 440 | # wrapped forward function has the correct signature (dynamically creating it here) 441 | self.forward = create_function(self._signature, self._forward) 442 | 443 | def __getattr__(self, item): 444 | try: 445 | return super().__getattr__(item) 446 | except AttributeError: 447 | try: 448 | return self._underlying_model.__getattribute__(item) 449 | except AttributeError: 450 | return self._underlying_model.__getattr__(item) 451 | 452 | @abc.abstractmethod 453 | def _forward(self, *args, **kwargs): 454 | # the actual forward implementation of the model. 455 | raise NotImplementedError 456 | 457 | def _set_underlying_model_attr(self, attr_name, new_val): 458 | if hasattr(self._underlying_model, attr_name): 459 | setattr(self._underlying_model, attr_name, new_val) 460 | elif hasattr(self._underlying_model, "model") and hasattr(self._underlying_model.model, attr_name): 461 | setattr(self._underlying_model.model, attr_name, new_val) 462 | else: 463 | raise ValueError(f"Cannot use this model as we cannot set its {attr_name}") 464 | 465 | def _run_sliding_window_forward(self, args_tensor_inds, kwargs_tensor_keys, s, *args, 466 | prefix_length=None, **kwargs): 467 | sm = self._sliding_method if self._sliding_method != 'dynamic' else \ 468 | ('loop' if not self.training else 'stacked') 469 | try: 470 | if sm == 'decoderonly': 471 | return self._skip_forward_for_decoder_only(args_tensor_inds, kwargs_tensor_keys, s, *args, 472 | prefix_length=prefix_length, **kwargs) 473 | if sm == 'loop': 474 | return self._run_sliding_window_forward_loop(args_tensor_inds, kwargs_tensor_keys, s, *args, 475 | prefix_length=prefix_length, **kwargs) 476 | return self._run_sliding_window_forward_stacked(args_tensor_inds, kwargs_tensor_keys, s, *args, 477 | prefix_length=prefix_length, **kwargs) 478 | finally: 479 | # so that if the model crashes halfway through it will be restored to working order 480 | pass 481 | 482 | def _skip_forward_for_decoder_only(self, args_tensor_inds, kwargs_tensor_keys, s, *args, 483 | prefix_length=None, **kwargs): 484 | # NOTE - this will work probably only with BART. 
485 | embeder = self if hasattr(self, 'embed_tokens') else self.get_encoder() # account for sled encoder 486 | return (embeder.embed_tokens(kwargs['input_ids']), ) 487 | 488 | 489 | def _run_sliding_window_forward_loop(self, args_tensor_inds, kwargs_tensor_keys, s, *args, 490 | prefix_length=None, **kwargs): 491 | forward_kwargs = _extract_keyword_args(kwargs, self._forward_kwargs_names, None) 492 | encoder_outputs_list = [] 493 | if prefix_length is not None and self._prepend_prefix: 494 | # we were given prefixes in the input, and we are expected to treat them 495 | prefix_length, s = self._handle_prefix(prefix_length, s) 496 | 497 | if self._encode_prefix: 498 | # encode the question as well, if needed 499 | context_start_ind, context_end_ind, update_start_ind, update_end_ind = 0, prefix_length, 0, prefix_length 500 | 501 | encoder_outputs = self._underlying_model.forward( 502 | *_fix_args(args, args_tensor_inds, context_start_ind, context_end_ind, None), 503 | **_fix_kwargs(forward_kwargs, kwargs_tensor_keys, context_start_ind, context_end_ind, None), 504 | ) 505 | encoder_outputs_list.append((encoder_outputs, update_start_ind, update_end_ind, None)) 506 | # we will need to make sure all input tensors will also drop everything with the prefix 507 | else: 508 | prefix_length = None # we need to ignore the prefix and treat the entire input as one long document 509 | 510 | for context_start_ind, context_end_ind, update_start_ind, update_end_ind in self._window_indices(s): 511 | encoder_outputs = self._underlying_model.forward( 512 | *_fix_args(args, args_tensor_inds, context_start_ind, context_end_ind, prefix_length), 513 | **_fix_kwargs(forward_kwargs, kwargs_tensor_keys, context_start_ind, context_end_ind, prefix_length), 514 | ) 515 | encoder_outputs_list.append((encoder_outputs, update_start_ind, update_end_ind, prefix_length)) 516 | 517 | return _merge_encoder_outputs(encoder_outputs_list) 518 | 519 | def _handle_prefix(self, prefix_length, s): 520 | prefix_length_ = prefix_length[0].detach().cpu().item() 521 | assert torch.all(prefix_length == prefix_length_).item(), \ 522 | 'Using different length prefixes in the same batch is not supported. 
Either group your batch by ' \ 523 | 'prefix length, or pad the prefixes to match in length (and do not forget to set the attention ' \ 524 | 'mask to 0 where appropriate)' 525 | if hasattr(self.underlying_model.config, 'max_position_embeddings'): 526 | assert self._context_size + prefix_length_ <= self.underlying_model.config.max_position_embeddings, \ 527 | f'The prefix length + SLEDs chunk size must be at most the max length that the backbone model can handle' 528 | return prefix_length_, s-prefix_length_ 529 | 530 | def _run_sliding_window_forward_stacked(self, args_tensor_inds, kwargs_tensor_keys, s, *args, 531 | prefix_length=None, **kwargs): 532 | forward_kwargs = _extract_keyword_args(kwargs, self._forward_kwargs_names, None) 533 | stacks_args = [] 534 | stacks_kwargs = [] 535 | stacks_info = [] 536 | 537 | if prefix_length is not None and self._prepend_prefix: 538 | # we were given prefixes in the input, and we are expected to treat them 539 | prefix_length, s = self._handle_prefix(prefix_length, s) 540 | 541 | if self._encode_prefix: 542 | # encode the question as well, if needed 543 | context_start_ind, context_end_ind, update_start_ind, update_end_ind = 0, prefix_length, 0, prefix_length 544 | # need to pad it to match the seq len of the rest 545 | # we may have too short samples as well so don't want to pad too much 546 | pad = min(s, self._context_size) 547 | assert pad >= 0, f'We have a weird situation. pad={pad}, s={s}, ' \ 548 | f'prefix_length={prefix_length} and self._context_size={self._context_size}' 549 | stacks_args.append(_fix_args(args, args_tensor_inds, context_start_ind, context_end_ind, None, pad)) 550 | stacks_kwargs.append(_fix_kwargs(forward_kwargs, kwargs_tensor_keys, context_start_ind, 551 | context_end_ind, None, pad)) 552 | stacks_info.append([None, update_start_ind, update_end_ind, None]) 553 | else: 554 | prefix_length = None # we need to ignore the prefix and treat the entire input as one long document 555 | 556 | for context_start_ind, context_end_ind, update_start_ind, update_end_ind in self._window_indices(s): 557 | stacks_args.append(_fix_args(args, args_tensor_inds, context_start_ind, context_end_ind, prefix_length)) 558 | stacks_kwargs.append(_fix_kwargs(forward_kwargs, kwargs_tensor_keys, context_start_ind, 559 | context_end_ind, prefix_length)) 560 | stacks_info.append([None, update_start_ind, update_end_ind, prefix_length]) 561 | 562 | encoder_outputs2 = self._underlying_model.forward( 563 | *_stack_args(stacks_args, args_tensor_inds), 564 | **_stack_kwargs(stacks_kwargs, kwargs_tensor_keys)) 565 | bs = forward_kwargs[kwargs_tensor_keys[0]].size()[0] if len(kwargs_tensor_keys) > 0 else \ 566 | args[args_tensor_inds[0]].size()[0] 567 | for si, eo in zip(stacks_info, _unstack_encoder_outputs(encoder_outputs2, len(stacks_info), bs)): 568 | si[0] = eo 569 | res = _merge_encoder_outputs(stacks_info) 570 | 571 | return res 572 | 573 | def _window_indices(self, total_seq_len): 574 | """ 575 | when total_seq_len is smaller than our desired context length, we do not do sliding window at all. 576 | However, if it is longer, then we ALWAYS require the context length to be maximal, even if some windows have 577 | a lot of overlap. 578 | Also, first window will always update from the start, and last window will always update until the end. 
579 | when applied, returns a generator that in each iteration produces four numbers: 580 | context_start_ind, context_end_ind, update_start_ind, update_end_ind 581 | 582 | context_start_ind, context_end_ind are indices in [0, total_seq_len], 583 | where context_end_ind > context_start_ind and when 584 | total_seq_len >= context_length then always context_end_ind = context_start_ind+context_length. 585 | The sequence of context_start_ind is strictly monotonic and same for context_end_ind. 586 | context_start_ind always starts at 0 and 587 | context_end_ind will always end at total_seq_len. 588 | They tell us which token indices to take from the long input. 589 | 590 | update_start_ind, update_end_ind are indices in [0, min(total_seq_len, context_length)], 591 | where update_end_ind > update_start_ind 592 | and for all windows that are not at the edges (i.e. first/last window) we have 593 | update_end_ind-update_start_ind=context_length*window_fraction. 594 | For the first window update_start_ind is always 0, and for the last window, 595 | update_end_ind is always min(total_seq_len, context_length). 596 | They represent the start and end indices within the selected window of 597 | the tokens that should be taken out for the final encoding. 598 | 599 | When doing a full iteration, accounting for the fact that 600 | update_start_ind, update_end_ind are shifted by context_start_ind, we should get that all indices in 601 | [0, total_seq_len] are covered exactly once. 602 | 603 | Examples 604 | >>> from transformers import T5Tokenizer, T5Model 605 | >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') 606 | >>> model_ = T5Model.from_pretrained('t5-small') 607 | >>> model = SledModel(model_, 512) # testing with padding of 50% and context of 512 608 | 609 | >>> list(model._window_indices(256)) # List of: (context_start, context_end, update_start, update_end). 
short sequence 610 | [(0, 256, 0, 256)] 611 | >>> list(model._window_indices(510)) # another short sequence 612 | [(0, 510, 0, 510)] 613 | >>> list(model._window_indices(512)) # sequence of exactly the context size 614 | [(0, 512, 0, 512)] 615 | >>> list(model._window_indices(514)) # sequence of slightly more than the context size 616 | [(0, 512, 0, 384), (2, 514, 382, 512)] 617 | >>> list(model._window_indices(766)) # long sequence that does not require a full stride (update in the last chunk is smaller than what is possible) 618 | [(0, 512, 0, 384), (254, 766, 130, 512)] 619 | >>> list(model._window_indices(768)) # long sequence for exactly two perfect chunks 620 | [(0, 512, 0, 384), (256, 768, 128, 512)] 621 | >>> list(model._window_indices(780)) # very long sequence that does not require a full stride (update in the last chunk is smaller than what is possible) 622 | [(0, 512, 0, 384), (256, 768, 128, 384), (268, 780, 372, 512)] 623 | >>> windows = list(model._window_indices(1050)) 624 | >>> windows 625 | [(0, 512, 0, 384), (256, 768, 128, 384), (512, 1024, 128, 384), (538, 1050, 358, 512)] 626 | >>> windows = sum([list(range(us+cs, ue+cs)) for cs, _, us, ue in windows], []) # verify it covers exactly all the indices, each once 627 | >>> windows[:10] 628 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 629 | >>> windows[500:510] 630 | [500, 501, 502, 503, 504, 505, 506, 507, 508, 509] 631 | >>> len(windows) 632 | 1050 633 | >>> len(set(windows)) 634 | 1050 635 | >>> model = SledModel(model_, 256, window_fraction=0.75) # now testing with padding of 25% and context of 256 636 | 637 | >>> list(model._window_indices(128)) # List of: (context_start, context_end, update_start, update_end). short sequence 638 | [(0, 128, 0, 128)] 639 | >>> list(model._window_indices(254)) # another short sequence 640 | [(0, 254, 0, 254)] 641 | >>> list(model._window_indices(256)) # sequence of exactly the context size 642 | [(0, 256, 0, 256)] 643 | >>> list(model._window_indices(258)) # sequence of slightly more than the context size. margin is 256/8 -> 32 644 | [(0, 256, 0, 224), (2, 258, 222, 256)] 645 | >>> list(model._window_indices(446)) # long sequence that does not require a full stride (update in the last chunk is smaller than what is possible). 
stride should be 256-64=192 646 | [(0, 256, 0, 224), (190, 446, 34, 256)] 647 | >>> list(model._window_indices(448)) # long sequence for exactly two perfect chunks 648 | [(0, 256, 0, 224), (192, 448, 32, 256)] 649 | >>> list(model._window_indices(500)) # very long sequence that does not require a full stride (update in the last chunk is smaller than what is possible) 650 | [(0, 256, 0, 224), (192, 448, 32, 224), (244, 500, 172, 256)] 651 | >>> windows = list(model._window_indices(1050)) 652 | >>> windows 653 | [(0, 256, 0, 224), (192, 448, 32, 224), (384, 640, 32, 224), (576, 832, 32, 224), (768, 1024, 32, 224), (794, 1050, 198, 256)] 654 | >>> windows = sum([list(range(us+cs, ue+cs)) for cs, _, us, ue in windows], []) # verify it covers exactly all the indices, each once 655 | >>> windows[:10] 656 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 657 | >>> windows[500:510] 658 | [500, 501, 502, 503, 504, 505, 506, 507, 508, 509] 659 | >>> len(windows) 660 | 1050 661 | >>> len(set(windows)) 662 | 1050 663 | """ 664 | if total_seq_len <= self._context_size: 665 | yield 0, total_seq_len, 0, total_seq_len 666 | else: 667 | stride = self._context_size - 2 * self._window_margin 668 | context_start = update_start_ind = 0 669 | context_end = self._context_size 670 | update_end_ind = context_end - self._window_margin 671 | yield context_start, context_end, update_start_ind, update_end_ind # first window always should update from the beginning 672 | while context_end < total_seq_len: 673 | context_end = min(total_seq_len, context_end + stride) 674 | context_start = ( 675 | context_start + stride if context_end < total_seq_len else total_seq_len - self._context_size 676 | ) 677 | update_start_ind = max(update_start_ind + stride, update_end_ind) 678 | # last window always should update until the end 679 | update_end_ind = ( 680 | min(total_seq_len, update_end_ind + stride) if context_end < total_seq_len else total_seq_len 681 | ) 682 | 683 | cs, ce, us, ue = context_start, context_end, update_start_ind - context_start, \ 684 | update_end_ind - context_start 685 | 686 | yield cs, ce, us, ue 687 | 688 | def _fill_prefix_inputs(self, kwargs, kwargs_tensor_keys): 689 | prefix_inputs = {} 690 | k = PREFIX_KEY 691 | if PREFIX_KEY in kwargs: 692 | if self._prepend_prefix: 693 | if k not in kwargs_tensor_keys: 694 | warnings.warn(f'{k} is missing from kwargs_tensor_keys (though expected for SLED prefix prepending)') 695 | else: 696 | kwargs_tensor_keys.remove(k) 697 | prefix_inputs[k] = kwargs.pop(k) 698 | elif k in kwargs_tensor_keys: 699 | warnings.warn(f'{k} is given in kwargs_tensor_keys even though sled should not prepend prefix, ' 700 | f'that would mean the prefix would be ignored and the entire input will be treated ' 701 | f'as a single long document, which is probably not what you meant') 702 | return prefix_inputs 703 | 704 | @staticmethod 705 | def _prep_attention_mask_for_cross_attention(encode_prefix, attention_mask, prefix_length=None): 706 | # if we need to drop the prefix encodings, we also need to adjust the attention mask before decoding 707 | if not encode_prefix and prefix_length is not None: 708 | prefix_length = int(prefix_length[0]) 709 | return attention_mask[..., prefix_length:] 710 | return attention_mask 711 | 712 | 713 | class SledModel(SledPretrainedModel): 714 | """ 715 | >>> from transformers import T5Tokenizer, T5Model, BartModel, BartTokenizer 716 | 717 | >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') 718 | >>> model_ = T5Model.from_pretrained('t5-small') 719 | >>> model = 
SledModel(model_, 4) 720 | 721 | >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 722 | >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 723 | >>> outputs = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone()) 724 | >>> outputs = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True) 725 | >>> outputs = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=False) 726 | """ 727 | 728 | def __init__(self, underlying_model: PreTrainedModel, config: SledConfig): 729 | super(SledModel, self).__init__(underlying_model, config) 730 | # validate the model can be used 731 | self._decoder_attr_name = getattr(underlying_model, "get_decoder_attr_name", lambda: "decoder")() 732 | self._encoder_attr_name = getattr(underlying_model, "get_encoder_attr_name", lambda: "encoder")() 733 | self._set_underlying_model_attr(self._decoder_attr_name, self.get_decoder()) 734 | self._mock_decoder = MockDecoder() 735 | assert "return_dict" in self._forward_kwargs_names 736 | assert "encoder_outputs" in self._forward_kwargs_names 737 | 738 | def _forward(self, *args, **kwargs): 739 | self._verify_config_consistency() 740 | kwargs, args = _fill_kwargs_with_args(self._forward_kwargs_names, *args, **kwargs) 741 | kwargs.setdefault("encoder_outputs", None) 742 | return_dict = kwargs.setdefault("return_dict", None) 743 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 744 | labels = kwargs.get("labels", None) 745 | kwargs["labels"] = None 746 | kwargs["return_dict"] = False 747 | kwargs.setdefault("labels", None) 748 | args_tensor_inds, kwargs_tensor_keys, s = _find_tensor_inds_and_size(*args, **kwargs) 749 | prefix_inputs = self._fill_prefix_inputs(kwargs, kwargs_tensor_keys) 750 | 751 | forward_kwargs = _extract_keyword_args(kwargs, self._forward_kwargs_names, None) 752 | if forward_kwargs["encoder_outputs"] is None: 753 | # encode, but first let's set decoder to be a mock, no reason to apply it over partial windows 754 | self._prep_for_encoding() # todo - add try catch every time we 'prep' something to rever the state on fail? 755 | 756 | forward_kwargs["encoder_outputs"] = self._run_sliding_window_forward( 757 | args_tensor_inds, kwargs_tensor_keys, s, *args, **prefix_inputs, **forward_kwargs 758 | ) 759 | forward_kwargs['attention_mask'] = self._prep_attention_mask_for_cross_attention(self._encode_prefix, 760 | forward_kwargs['attention_mask'], prefix_inputs.get('prefix_length', None)) 761 | 762 | # now, let's decode 763 | forward_kwargs["return_dict"] = return_dict 764 | forward_kwargs["labels"] = labels 765 | self._fix_post_encoding() 766 | if 'decoder_input_ids' in self._forward_kwargs_names and \ 767 | forward_kwargs.get('decoder_input_ids', None) is None and \ 768 | hasattr(self, 'prepare_decoder_input_ids_from_labels') : 769 | logger.warning('Passing a batch through the model without the decoder_input_ids is likely to cause issues. ' 770 | 'If you encounter cuda errors, make sure you use the prepare_decoder_input_ids_from_labels ' 771 | 'function of the model correctly before passing the input. 
' 772 | 'If you are only performing prediction without training, you can safely ignore this message') 773 | res = self._underlying_model.forward( 774 | *args, **_extract_keyword_args(forward_kwargs, self._forward_kwargs_names) 775 | ) 776 | 777 | return res 778 | 779 | def _prep_for_encoding(self): 780 | if not getattr(self, '_preped_for_encoding', False): 781 | self._preped_for_encoding = True 782 | self._decoder = self.get_decoder() 783 | self._mock_decoder.first_device = getattr(self._decoder, "first_device", None) 784 | self._set_underlying_model_attr(self._decoder_attr_name, self._mock_decoder) 785 | 786 | def _fix_post_encoding(self): 787 | assert self._preped_for_encoding 788 | self._preped_for_encoding = False 789 | self._set_underlying_model_attr(self._decoder_attr_name, self._decoder) 790 | 791 | 792 | class MockTensorForConditionalGeneration: 793 | def __add__(self, other): 794 | return tuple() 795 | 796 | def __mul__(self, other): 797 | return tuple() 798 | 799 | def to(self, *_, **__): 800 | return self 801 | 802 | 803 | class MockDecoderForConditionalGeneration(nn.Module): 804 | pad_value = 0 805 | 806 | def forward(self, *_, **__): 807 | return (MockTensorForConditionalGeneration(),) 808 | 809 | def to(self, *_, **__): 810 | return self 811 | 812 | 813 | class MockLMHeadForConditionalGeneration(nn.Module): 814 | def forward(self, *_, **__): 815 | return MockTensorForConditionalGeneration() 816 | 817 | def to(self, *_, **__): 818 | return self 819 | 820 | def __getattr__(self, item): 821 | try: 822 | return super().__getattr__(item) 823 | except AttributeError: 824 | return self 825 | 826 | 827 | def _fill_kwargs_with_args(forward_param_names, *args, **kwargs): 828 | kwargs.update({arg_name: arg_value for arg_name, arg_value in zip(forward_param_names, args)}) 829 | return kwargs, tuple() 830 | 831 | 832 | class SledEncoder(SledPretrainedModel): 833 | def __init__(self, underlying_model: PreTrainedModel, config: SledConfig): 834 | super(SledEncoder, self).__init__(underlying_model, config) 835 | 836 | @property 837 | def _ignore_keys(self): 838 | return super()._ignore_keys | {'use_cache', 'is_encoder_decoder'} 839 | # Encoder models are not encoder-decoder but are meant for internal use by the SLED models 840 | 841 | 842 | def _forward(self, *args, **kwargs): 843 | kwargs, args = _fill_kwargs_with_args(self._forward_kwargs_names, *args, **kwargs) 844 | args_tensor_inds, kwargs_tensor_keys, s = _find_tensor_inds_and_size(*args, **kwargs) 845 | prefix_inputs = self._fill_prefix_inputs(kwargs, kwargs_tensor_keys) 846 | return self._run_sliding_window_forward(args_tensor_inds, kwargs_tensor_keys, s, *args, **prefix_inputs, 847 | **kwargs) 848 | 849 | def _skip_forward_for_decoder_only(self, args_tensor_inds, kwargs_tensor_keys, s, *args, 850 | prefix_length=None, **kwargs): 851 | encoder_outputs = super()._skip_forward_for_decoder_only(args_tensor_inds, kwargs_tensor_keys, s, 852 | *args, prefix_length, **kwargs) 853 | return BaseModelOutput(encoder_outputs[0]) 854 | 855 | 856 | class SledForConditionalGeneration(SledModel): 857 | """ 858 | >>> from transformers import T5Tokenizer, T5ForConditionalGeneration 859 | 860 | >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') 861 | >>> model_ = T5ForConditionalGeneration.from_pretrained('t5-small') 862 | >>> shared = model_.shared 863 | >>> model = SledForConditionalGeneration(model_, 4) 864 | >>> model._underlying_model == model_ # make sure the decoration works 865 | >>> model.shared == shared # make sure the decoration works 
866 | >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids 867 | >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids 868 | >>> outputs = model(input_ids=input_ids, labels=labels) 869 | >>> outputs = model(input_ids=input_ids, labels=labels, return_dict=None) 870 | >>> outputs = model(input_ids=input_ids, labels=labels, return_dict=False) 871 | >>> outputs = model(input_ids=input_ids, labels=labels, return_dict=True) 872 | >>> loss = outputs.loss 873 | >>> logits = outputs.logits 874 | >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1 875 | >>> outputs = model.generate(input_ids) 876 | """ 877 | OVERRIDDEN_METHODS = {'generate'} 878 | auto_model_loader = AutoModelForSeq2SeqLM 879 | 880 | def __init__(self, underlying_model: PreTrainedModel, config: SledConfig): 881 | super(SledForConditionalGeneration, self).__init__(underlying_model, config) 882 | self._mock_decoder = MockDecoderForConditionalGeneration() 883 | self._base_encoder = self.get_encoder() 884 | self._sled_encoder = SledEncoder(self._base_encoder, config) 885 | 886 | # override generation preparation functions that may be overridden by underlying models but will be found in our wrapper 887 | for method_name, _ in inspect.getmembers(GenerationMixin(), predicate=inspect.ismethod): 888 | if method_name not in SledForConditionalGeneration.OVERRIDDEN_METHODS: 889 | setattr(self, method_name, getattr(underlying_model, method_name)) 890 | 891 | # NOTE - the below affects the given underlying model, which means generating with it 892 | # directly may not work anymore 893 | self._underlying_model._prepare_encoder_decoder_kwargs_for_generation = \ 894 | self._get__prepare_encoder_decoder_kwargs_for_generation_func_override() 895 | self._underlying_model._validate_model_kwargs = _validate_model_kwargs # see hack details below 896 | 897 | def _get__prepare_encoder_decoder_kwargs_for_generation_func_override(self): 898 | # _prepare_encoder_decoder_kwargs_for_generation( 899 | # self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None 900 | # ) -> Dict[str, Any]: 901 | f = self._underlying_model._prepare_encoder_decoder_kwargs_for_generation 902 | encode_prefix = self._encode_prefix 903 | 904 | def _prepare_encoder_decoder_kwargs_for_generation( 905 | inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None) -> Dict[str, Any]: 906 | # override needed for underlying model in a conditional generation mode with prefix prepending and dropping 907 | model_kwargs = f(inputs_tensor, model_kwargs, model_input_name) 908 | model_kwargs['attention_mask'] = SledPretrainedModel._prep_attention_mask_for_cross_attention( 909 | encode_prefix, model_kwargs['attention_mask'], model_kwargs.get('prefix_length', None)) 910 | return model_kwargs 911 | 912 | return _prepare_encoder_decoder_kwargs_for_generation 913 | 914 | def _prep_for_encoding(self): 915 | was_preped = getattr(self, '_preped_for_encoding', False) 916 | super(SledForConditionalGeneration, self)._prep_for_encoding() 917 | if not was_preped: 918 | self._lm_head = getattr(self._underlying_model, "lm_head", None) 919 | setattr(self._underlying_model, "lm_head", MockLMHeadForConditionalGeneration()) 920 | 921 | def _fix_post_encoding(self): 922 | super(SledForConditionalGeneration, self)._fix_post_encoding() 923 | setattr(self._underlying_model, "lm_head", self._lm_head) 924 | 925 | def generate(self, *args, 
**kwargs): 926 | self._set_underlying_model_attr(self._encoder_attr_name, self._sled_encoder) 927 | try: 928 | res = self._underlying_model.generate(*args, **kwargs) 929 | finally: 930 | self._set_underlying_model_attr(self._encoder_attr_name, self._base_encoder) 931 | return res 932 | 933 | def _validate_model_kwargs(self, *args, **kwargs): 934 | # Newer versions of HF perform a check on the input args for generate and raise an exception when passing 935 | # prefix_length for example to this model because it doesn't list it explicitly. 936 | # This is a hack to support newer HF models until the generate() signature will be created dynamically to 937 | # include all the keyword args including prefix_length 938 | pass 939 | -------------------------------------------------------------------------------- /sled/tokenization_sled.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from transformers import PreTrainedTokenizer, AutoTokenizer, AutoConfig 4 | 5 | 6 | class SledTokenizer(PreTrainedTokenizer): 7 | auto_tokenizer_loader = AutoTokenizer 8 | auto_config_loader = AutoConfig 9 | 10 | def __init__(self, *args, **kwargs): 11 | super(SledTokenizer, self).__init__(*args, **kwargs) 12 | 13 | @classmethod 14 | def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): 15 | assert isinstance(pretrained_model_name_or_path, str), 'pretrained_model_name_or_path must be a path to a ' \ 16 | 'checkpoint or a local config file (json)' 17 | if os.path.exists(pretrained_model_name_or_path): 18 | # if pretrained_model_name_or_path is a saved checkpoint 19 | config = kwargs.pop('config', None) 20 | if pretrained_model_name_or_path.endswith('json'): 21 | config = config or cls.auto_config_loader.from_pretrained(pretrained_model_name_or_path) 22 | return cls.auto_tokenizer_loader.from_pretrained(config.underlying_config, *init_inputs, **kwargs) 23 | else: 24 | # otherwise, it is a config json path 25 | raise NotImplementedError('loading a pretrained saved sled checkpoint is not yet implemented') 26 | else: 27 | # assume it is a model card on the hub 28 | config = kwargs.pop('config', None) 29 | config = config or cls.auto_config_loader.from_pretrained( 30 | pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token', False)) 31 | kwargs['use_fast'] = False 32 | return cls.auto_tokenizer_loader.from_pretrained(config.underlying_config, *init_inputs, **kwargs) -------------------------------------------------------------------------------- /sled/tokenization_sled_fast.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from transformers import AutoTokenizer, AutoConfig, PreTrainedTokenizerFast 4 | 5 | from .tokenization_sled import SledTokenizer 6 | 7 | 8 | class SledTokenizerFast(PreTrainedTokenizerFast): 9 | auto_tokenizer_loader = AutoTokenizer 10 | auto_config_loader = AutoConfig 11 | slow_tokenizer_class = SledTokenizer 12 | 13 | def __init__(self, *args, **kwargs): 14 | super(SledTokenizerFast, self).__init__(*args, **kwargs) 15 | 16 | @classmethod 17 | def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs): 18 | assert isinstance(pretrained_model_name_or_path, str), 'pretrained_model_name_or_path must be a path to a ' \ 19 | 'checkpoint or a local config file (json)' 20 | if os.path.exists(pretrained_model_name_or_path): 21 | # if pretrained_model_name_or_path is a saved checkpoint 22 | config = kwargs.pop('config', None) 23 | if 
pretrained_model_name_or_path.endswith('json'): 24 | config = config or cls.auto_config_loader.from_pretrained(pretrained_model_name_or_path) 25 | return cls.auto_tokenizer_loader.from_pretrained(config.underlying_config, *init_inputs, **kwargs) 26 | else: 27 | # otherwise, it is a config json path 28 | raise NotImplementedError('loading a pretrained saved sled checkpoint is not yet implemented') 29 | else: 30 | # assume it is a model card on the hub 31 | config = kwargs.pop('config', None) 32 | config = config or cls.auto_config_loader.from_pretrained( 33 | pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token', False)) 34 | kwargs['use_fast'] = True 35 | return cls.auto_tokenizer_loader.from_pretrained(config.underlying_config, *init_inputs, **kwargs) -------------------------------------------------------------------------------- /tests/configs/bart_base_sled.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "tau/sled", 3 | "underlying_config": "facebook/bart-base", 4 | "context_size": 256, 5 | "window_fraction": 0.5, 6 | "prepend_prefix": true, 7 | "encode_prefix": true, 8 | "sliding_method": "dynamic" 9 | } -------------------------------------------------------------------------------- /tests/configs/t5_base_sled.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "tau/sled", 3 | "underlying_config": "google/t5-v1_1-base", 4 | "context_size": 256, 5 | "window_fraction": 0.5, 6 | "prepend_prefix": true, 7 | "encode_prefix": true, 8 | "sliding_method": "dynamic" 9 | } -------------------------------------------------------------------------------- /tests/test_sled.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | 5 | from torch import nn 6 | 7 | from sled.tokenization_sled import SledTokenizer 8 | from sled.tokenization_sled_fast import SledTokenizerFast 9 | 10 | BART_BASE_SLED_URL = 'tau/bart-base-sled' 11 | BART_BASE_SLED_GOV_URL = 'tau/bart-base-sled-govreport' 12 | T5_BASE_SLED_URL = 'tau/t5-v1_1-base-sled' 13 | 14 | sys.path.insert(0, os.getcwd()) 15 | 16 | import inspect 17 | import unittest 18 | 19 | import torch 20 | from transformers import ( 21 | T5Tokenizer, 22 | T5Model, 23 | BartModel, 24 | BartTokenizer, 25 | T5ForConditionalGeneration, 26 | BartForConditionalGeneration, AutoConfig, AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM, 27 | BartTokenizerFast, T5TokenizerFast, 28 | ) 29 | from transformers.testing_utils import require_torch 30 | 31 | 32 | from sled.modeling_sled import SledModel, SledForConditionalGeneration, PREFIX_KEY 33 | from sled.configuration_sled import SledConfig 34 | 35 | use_auth_token = False 36 | 37 | def _compare_tuple_of_tensors(test_case, expected_tuple, got_tuple, prev_got_tuple, rtol=1e-5): 38 | for i, (exp_tensor, got_tensor) in enumerate(zip(expected_tuple, got_tuple)): 39 | if isinstance(exp_tensor, torch.Tensor): 40 | test_case.assertTrue(torch.allclose(exp_tensor, got_tensor, rtol=rtol)) 41 | # we can't expect the values to be the same when different context length, but at least can verify shapes 42 | if prev_got_tuple is not None: 43 | test_case.assertTrue(exp_tensor.size() == prev_got_tuple[i].size()) 44 | elif isinstance(exp_tensor, tuple): 45 | prev_got_tuple_i = prev_got_tuple[i] if prev_got_tuple is not None else None 46 | _compare_tuple_of_tensors(test_case, exp_tensor, got_tensor, prev_got_tuple_i, rtol=rtol) 
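# ---------------------------------------------------------------------------
# Editor's note (not part of the original test suite): a minimal, hedged usage
# sketch of the pattern the tests below exercise - wrapping a short-range
# backbone with SledModel + SledConfig and passing `prefix_length` so that the
# prefix (e.g. a question) is prepended to every chunk. The backbone choice
# ('facebook/bart-base'), context_size=16 and the example strings are
# illustrative assumptions mirroring the tests, not a prescribed API.
def _example_sled_forward_with_prefix():
    import torch
    from transformers import BartModel, BartTokenizer

    from sled.configuration_sled import SledConfig
    from sled.modeling_sled import SledModel

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    backbone = BartModel.from_pretrained("facebook/bart-base")
    model = SledModel(backbone, SledConfig("facebook/bart-base", context_size=16, prepend_prefix=True))
    model.eval()

    # the prefix is concatenated before the document; prefix_length tells SLED where it ends
    prefix_ids = tokenizer("What did studies show?\n\n", return_tensors="pt").input_ids
    document_ids = tokenizer("Studies have been shown that owning a dog is good for you",
                             return_tensors="pt").input_ids
    input_ids = torch.cat((prefix_ids, document_ids), dim=-1)
    prefix_length = torch.LongTensor([[prefix_ids.size(1)]])
    decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids

    with torch.no_grad():
        outputs = model(input_ids=input_ids, prefix_length=prefix_length,
                        decoder_input_ids=decoder_input_ids, return_dict=True)
    return outputs.last_hidden_state.shape
# ---------------------------------------------------------------------------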
47 | 48 | 49 | @require_torch 50 | class SLEDModelTest(unittest.TestCase): 51 | def _run_sled_model_test_case(self, model_, tokenizer, underlying_config: str, rtol=1e-5): 52 | model = SledModel(model_, SledConfig(underlying_config, context_size=4)) 53 | model.eval() # only change the model to be in eval (inference) mode, 54 | # thus not changing layer_norm params and removing dropout 55 | 56 | input_ids = tokenizer( 57 | "Studies have been shown that owning a dog is good for you", return_tensors="pt" 58 | ).input_ids # Batch size 1 59 | decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 60 | 61 | # simple verification there are no failures in the flow itself 62 | with torch.no_grad(): 63 | _ = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone()) 64 | _ = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=None) 65 | outputs_dict = model( 66 | input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True 67 | ) 68 | outputs_no_dict = model( 69 | input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=False 70 | ) 71 | 72 | # let's verify that if the sequence is short, we behave exactly as the base model 73 | model = SledModel(model_, SledConfig(underlying_config, context_size=512)) 74 | with torch.no_grad(): 75 | output_expected = model_( 76 | input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=False 77 | ) 78 | 79 | # on non dict return type 80 | with torch.no_grad(): 81 | output_got = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=False) 82 | self.assertEqual(type(output_expected), type(output_got)) 83 | self.assertEqual(type(output_expected), type(outputs_no_dict)) 84 | self.assertEqual(len(output_got), len(output_expected)) # should be tuple so it's ok 85 | self.assertEqual(len(output_got), len(outputs_no_dict)) # should be tuple so it's ok 86 | _compare_tuple_of_tensors(self, output_expected, output_got, outputs_no_dict, rtol=rtol) 87 | 88 | # on dict return type 89 | with torch.no_grad(): 90 | output_expected = model_( 91 | input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True 92 | ) 93 | output_got = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True) 94 | self.compare_outputs_dict(output_expected, output_got, outputs_dict, rtol=rtol) 95 | 96 | def compare_outputs_dict(self, output_expected, output_got, outputs_dict=None, rtol=1e-5): 97 | if output_got is None and outputs_dict is not None: 98 | output_got = output_expected 99 | if outputs_dict is None: 100 | outputs_dict = output_expected 101 | self.assertEqual(type(output_expected), type(output_got)) 102 | self.assertEqual(type(output_expected), type(outputs_dict)) 103 | self.assertListEqual(list(output_got.keys()), list(output_expected.keys())) 104 | self.assertListEqual(list(output_got.keys()), list(outputs_dict.keys())) 105 | for key in output_got.keys(): 106 | if isinstance(output_got[key], torch.Tensor): 107 | self.assertTrue(torch.allclose(output_got[key], output_expected[key], rtol=rtol)) 108 | self.assertFalse(output_got[key].requires_grad) 109 | self.assertFalse(output_expected[key].requires_grad) 110 | self.assertFalse(outputs_dict[key].requires_grad) 111 | 112 | # we can't expect the values to be the same when different context length, but at least can verify shapes 113 | 
self.assertTrue(output_got[key].size() == outputs_dict[key].size()) 114 | elif isinstance(output_got[key], tuple): 115 | _compare_tuple_of_tensors(self, output_got[key], output_expected[key], outputs_dict[key], rtol=rtol) 116 | 117 | def test_facade_get_attr_behavior(self): 118 | base_model = T5Model.from_pretrained("t5-small") 119 | sled_model = SledModel(base_model, SledConfig("t5-small", context_size=4)) 120 | self.assertEqual( 121 | base_model.shared, sled_model.shared 122 | ) # sled model does not have a 'shared' attribute, only it's base model does 123 | 124 | def _verify_loaded_sled_model(self, bart_sled_model, config_class, underlying_config_class, 125 | underlying_config_path="facebook/bart-base", expected_underlying_model=None): 126 | self.assertIsInstance(bart_sled_model, config_class) 127 | self.assertIsInstance(bart_sled_model._underlying_model, underlying_config_class) 128 | # make sure it has the pretrained weights and not random ones 129 | expected_underlying_model = expected_underlying_model or \ 130 | underlying_config_class.from_pretrained(underlying_config_path) 131 | # noinspection PyTypeChecker 132 | self.assertTrue(torch.all( 133 | expected_underlying_model.get_encoder().state_dict()['embed_tokens.weight'] == 134 | bart_sled_model.get_encoder().state_dict()['embed_tokens.weight']).item()) 135 | 136 | def test_load_from_pretrained(self): 137 | config_path = BART_BASE_SLED_URL 138 | another_config_path = 'configs/t5_base_sled.json' 139 | another_config_path_hub = T5_BASE_SLED_URL 140 | 141 | # first, we test the config 142 | # start by loading it explicitly 143 | bart_base_sled_config = SledConfig.from_pretrained(config_path, use_auth_token=use_auth_token) 144 | self.assertIsInstance(bart_base_sled_config, SledConfig) 145 | 146 | # now, with auto classes 147 | bart_base_sled_config2 = AutoConfig.from_pretrained(config_path, use_auth_token=use_auth_token) 148 | self.assertNotEqual(bart_base_sled_config, bart_base_sled_config2) # the explicit one didn't have a name or path 149 | bart_base_sled_config._name_or_path = bart_base_sled_config2._name_or_path 150 | self.assertEqual(bart_base_sled_config, bart_base_sled_config2) 151 | 152 | # Now, let's check the model loading 153 | self._verify_loaded_sled_model(SledModel.from_pretrained(config_path, use_auth_token=use_auth_token), 154 | SledModel, BartModel) # explicit load 155 | 156 | # now, with auto classes 157 | self._verify_loaded_sled_model(AutoModel.from_pretrained(config_path, use_auth_token=use_auth_token), 158 | SledModel, BartModel) # auto load 159 | 160 | # now, lets assert that "forConditionalGeneration" Also works as expected 161 | self._verify_loaded_sled_model(SledForConditionalGeneration.from_pretrained(config_path, 162 | use_auth_token=use_auth_token), 163 | SledForConditionalGeneration, BartForConditionalGeneration) # explicit 164 | 165 | # now, with auto classes 166 | self._verify_loaded_sled_model( 167 | AutoModelForSeq2SeqLM.from_pretrained(config_path, use_auth_token=use_auth_token), 168 | SledForConditionalGeneration, BartForConditionalGeneration) # auto 169 | 170 | # Finally, let's verify it also work with another model 171 | self._verify_loaded_sled_model(AutoModelForSeq2SeqLM.from_pretrained( 172 | another_config_path, use_auth_token=use_auth_token), SledForConditionalGeneration, 173 | T5ForConditionalGeneration, "google/t5-v1_1-base") 174 | 175 | self._verify_loaded_sled_model(AutoModelForSeq2SeqLM.from_pretrained( 176 | another_config_path_hub, use_auth_token=use_auth_token), 
SledForConditionalGeneration, 177 | T5ForConditionalGeneration, "google/t5-v1_1-base") 178 | 179 | def test_config_overrides(self): 180 | # Load the base model, and verify its consistency with the underlying config 181 | bart_base_model_sled = AutoModel.from_pretrained(BART_BASE_SLED_URL, use_auth_token=use_auth_token) 182 | self.assertFalse(bart_base_model_sled.config.gradient_checkpointing) 183 | self.assertFalse(bart_base_model_sled.underlying_model.config.gradient_checkpointing) 184 | self.assertTrue(bart_base_model_sled.config.use_cache) 185 | self.assertTrue(bart_base_model_sled.underlying_model.config.use_cache) 186 | 187 | # now, supply overrides and make sure they are consistent 188 | bart_base_model_sled = AutoModel.from_pretrained(BART_BASE_SLED_URL, gradient_checkpointing=True, use_cache=False, 189 | use_auth_token=use_auth_token) 190 | self.assertTrue(bart_base_model_sled.config.gradient_checkpointing) 191 | self.assertTrue(bart_base_model_sled.underlying_model.config.gradient_checkpointing) 192 | self.assertFalse(bart_base_model_sled.config.use_cache) 193 | self.assertFalse(bart_base_model_sled.underlying_model.config.use_cache) 194 | 195 | # Finally, set the config after load and make sure it is synced properly 196 | bart_base_model_sled.config.gradient_checkpointing = False 197 | bart_base_model_sled.config.use_cache = True 198 | self.assertFalse(bart_base_model_sled.config.gradient_checkpointing) 199 | self.assertTrue(bart_base_model_sled.config.use_cache) 200 | # this wouldn't have been fixed before the model was first used 201 | self.assertTrue(bart_base_model_sled.underlying_model.config.gradient_checkpointing) 202 | self.assertFalse(bart_base_model_sled.underlying_model.config.use_cache) 203 | # now, pass something through the model (even though it would fail) 204 | try: 205 | bart_base_model_sled(None) 206 | except: 207 | pass 208 | self.assertFalse(bart_base_model_sled.underlying_model.config.gradient_checkpointing) 209 | self.assertTrue(bart_base_model_sled.underlying_model.config.use_cache) 210 | 211 | def test_all_loading_options(self): 212 | # test load from local config, from a saved checkpoint and from a URL model card 213 | local_bart_base_sled_model = SledModel.from_pretrained('configs/bart_base_sled.json') 214 | self._verify_loaded_sled_model(local_bart_base_sled_model, SledModel, BartModel) # explicit load 215 | 216 | hub_bart_base_sled_model = SledModel.from_pretrained(BART_BASE_SLED_URL, use_auth_token=use_auth_token) 217 | self._verify_loaded_sled_model(hub_bart_base_sled_model, SledModel, BartModel) # explicit load 218 | 219 | # now, save and reload 220 | cache_dir = os.environ.get('XDG_CACHE_HOME', '/tmp/cache') 221 | out_dir = f'{cache_dir}/test_save_checkpoint' 222 | os.makedirs(out_dir, exist_ok=True) 223 | shutil.rmtree(out_dir) # cleanup previous checkpoints 224 | # let's change the model a little before saving it to make sure it works correctly and doesn't load the c 225 | # checkpoint on the hub 226 | local_bart_base_sled_model.get_encoder().state_dict()['embed_tokens.weight'] += 1 227 | # make sure they were indeed changed 228 | self.assertRaises(AssertionError, self._verify_loaded_sled_model, local_bart_base_sled_model, SledModel, 229 | BartModel) 230 | 231 | # now save and test 232 | local_bart_base_sled_model.save_pretrained(out_dir) 233 | self.assertTrue(os.path.isfile(os.path.join(out_dir, 'pytorch_model.bin'))) 234 | self.assertTrue(os.path.isfile(os.path.join(out_dir, 'config.json'))) 235 | loaded_bart_base_sled_model = 
SledModel.from_pretrained(out_dir, use_auth_token=use_auth_token) 236 | self._verify_loaded_sled_model(loaded_bart_base_sled_model, SledModel, BartModel, 237 | expected_underlying_model=local_bart_base_sled_model) 238 | 239 | 240 | def test_load_tokenizer_from_pretrained(self): 241 | config_path = BART_BASE_SLED_URL 242 | another_config_path = T5_BASE_SLED_URL 243 | 244 | # slow tokenizer 245 | # explicit load, should actually return a BartTokenizer (the default is a fast tokenizer) 246 | self.assertIsInstance(SledTokenizer.from_pretrained(config_path, use_auth_token=use_auth_token, use_fast=False), 247 | BartTokenizer) 248 | 249 | # autoload, should actually return a BartTokenizer 250 | self.assertIsInstance(AutoTokenizer.from_pretrained(config_path, use_auth_token=use_auth_token, use_fast=False), 251 | BartTokenizer) 252 | 253 | # fast tokenizer 254 | # explicit load, should actually return a BartTokenizerFast 255 | self.assertIsInstance(SledTokenizerFast.from_pretrained(config_path, use_auth_token=use_auth_token), 256 | BartTokenizerFast) 257 | # autoload, should actually return a BartTokenizerFast 258 | self.assertIsInstance(AutoTokenizer.from_pretrained(config_path, use_auth_token=use_auth_token), 259 | BartTokenizerFast) 260 | 261 | # and now with T5 262 | self.assertIsInstance(SledTokenizer.from_pretrained(another_config_path, use_auth_token=use_auth_token, 263 | use_fast=False), T5Tokenizer) 264 | self.assertIsInstance(AutoTokenizer.from_pretrained(another_config_path, use_auth_token=use_auth_token, 265 | use_fast=False), T5Tokenizer) 266 | self.assertIsInstance(SledTokenizerFast.from_pretrained(another_config_path, use_auth_token=use_auth_token, 267 | use_fast=True), T5TokenizerFast) 268 | self.assertIsInstance(AutoTokenizer.from_pretrained(another_config_path, use_auth_token=use_auth_token, 269 | use_fast=True), T5TokenizerFast) 270 | 271 | def test_load_finetuned_model(self): 272 | bart_base_sled = AutoModel.from_pretrained(BART_BASE_SLED_URL) 273 | bart_base_sled_gov = SledModel.from_pretrained(BART_BASE_SLED_GOV_URL) 274 | # testing embeedings have changed 275 | assert not torch.all(bart_base_sled.get_input_embeddings().weight == 276 | bart_base_sled_gov.get_input_embeddings().weight).item() 277 | # test decoder weights have changed 278 | assert not torch.all( 279 | bart_base_sled.state_dict()['_underlying_model.decoder.layers.0.encoder_attn.k_proj.weight'] == 280 | bart_base_sled_gov.state_dict()['_underlying_model.decoder.layers.0.encoder_attn.k_proj.weight'] 281 | ).item() 282 | # test encoder weights have changed 283 | assert not torch.all( 284 | bart_base_sled.state_dict()['_underlying_model.encoder.layers.0.self_attn.k_proj.weight'] == 285 | bart_base_sled_gov.state_dict()['_underlying_model.encoder.layers.0.self_attn.k_proj.weight'] 286 | ).item() 287 | 288 | 289 | def test_sled_on_t5(self): 290 | self._run_sled_model_test_case(T5Model.from_pretrained("t5-small"), T5Tokenizer.from_pretrained("t5-small"), 291 | "t5-small") 292 | 293 | def test_sled_on_bart(self): 294 | self._run_sled_model_test_case( 295 | BartModel.from_pretrained("facebook/bart-base"), BartTokenizer.from_pretrained("facebook/bart-base"), 296 | "facebook/bart-base" 297 | ) 298 | 299 | def test_forward_signature(self): 300 | # HF trainer uses the signature to choose which parts to take from a dataset, so we need to make sure our wrapped forward 301 | # function has the correct signature 302 | _model = BartModel.from_pretrained("facebook/bart-base") 303 | expected_sig = [param for param in 
inspect.signature(_model.forward).parameters.keys()] 304 | got_sig = [param for param in inspect.signature(SledModel(_model, SledConfig(context_size=128)).forward).parameters.keys()] 305 | self.assertListEqual(expected_sig, got_sig[:-1]) 306 | self.assertEqual(str(got_sig[-1]), PREFIX_KEY) 307 | 308 | def test_resize_token_embeddings(self): 309 | _model = BartModel.from_pretrained("facebook/bart-base") 310 | orig_vocab_size = _model.config.vocab_size 311 | self.assertNotEqual(orig_vocab_size, 512) 312 | _model.resize_token_embeddings(512) 313 | self.assertEqual(_model.config.vocab_size, 512) 314 | model = SledModel(_model, SledConfig("facebook/bart-base", context_size=128)) 315 | self.assertEqual(model.config.vocab_size, 512) 316 | self.assertEqual(_model.get_input_embeddings().weight.size()[0], 512) 317 | self.assertEqual(model.get_input_embeddings().weight.size()[0], 512) 318 | model.resize_token_embeddings(1024) 319 | self.assertEqual(model.config.vocab_size, 1024) 320 | self.assertEqual(_model.config.vocab_size, 1024) 321 | self.assertEqual(_model.get_input_embeddings().weight.size()[0], 1024) 322 | self.assertEqual(model.get_input_embeddings().weight.size()[0], 1024) 323 | 324 | def test_sled_model_parallel(self): 325 | assert torch.cuda.device_count() > 1 326 | model = SledModel(T5Model.from_pretrained("t5-small"), SledConfig("t5-small", context_size=512)).to("cuda:0") 327 | model.eval() 328 | assert model.is_parallelizable # bart is not, only t5 329 | assert not model.model_parallel 330 | assert model.device_map is None 331 | 332 | model2 = SledModel(T5Model.from_pretrained("t5-small"), SledConfig("t5-small", context_size=512)).to("cuda:0") 333 | model2.eval() 334 | model2.parallelize() 335 | assert model2.model_parallel 336 | assert len(model2.device_map) == min(3, torch.cuda.device_count()) 337 | 338 | tokenizer = T5Tokenizer.from_pretrained("t5-small") 339 | input_ids = tokenizer([" ".join(list("1234567890"))] * 16, return_tensors="pt").input_ids.to("cuda:0") 340 | assert input_ids.size() == (16, 12) # Batch size 16, input length 10 + BOS + EOS 341 | decoder_input_ids = tokenizer([" ".join(list("hello"))] * 16, return_tensors="pt").input_ids.to("cuda:0") 342 | assert decoder_input_ids.size() == (16, 11) # Batch size 16, inputs length 5 + BOS + 4 spaces + EOS 343 | 344 | # simple verification there are no failures in the flow itself 345 | with torch.no_grad(): 346 | outputs_expected = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone()) 347 | outputs_got = model2(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone()) 348 | self.compare_outputs_dict( 349 | outputs_expected, None, outputs_got, rtol=1e-5 350 | ) # the values may differ, but we need to make sure the sizes are correct 351 | 352 | def test_sled_with_data_parallel(self): 353 | assert torch.cuda.device_count() > 1 354 | model_ = SledModel(BartModel.from_pretrained("facebook/bart-base"), 355 | SledConfig("facebook/bart-base", context_size=128)).to("cuda:0") 356 | model = nn.DataParallel(model_) 357 | model.eval() 358 | assert isinstance(model, nn.DataParallel) 359 | replicas = model.replicate(model.module, model.device_ids) 360 | assert len(replicas) == torch.cuda.device_count() 361 | for i, rep in enumerate(replicas): 362 | assert isinstance(rep, SledModel) 363 | assert rep.device.index == i 364 | forward = replicas[0].forward 365 | replicas[0].forward = "hello" 366 | assert replicas[0].forward == "hello" 367 | assert replicas[1].forward != "hello" 368 | replicas[0].forward = 
forward 369 | 370 | tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 371 | input_ids = tokenizer([" ".join(list("1234567890"))] * 16, return_tensors="pt").input_ids.to("cuda:0") 372 | assert input_ids.size() == (16, 12) # Batch size 16, input length 10 + BOS + EOS 373 | decoder_input_ids = tokenizer([" ".join(list("hello"))] * 16, return_tensors="pt").input_ids.to("cuda:0") 374 | assert decoder_input_ids.size() == (16, 7) # Batch size 16, inputs length 5 + BOS + EOS 375 | 376 | # simple verification there are no failures in the flow itself 377 | with torch.no_grad(): 378 | outputs_expected = model_(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone()) 379 | outputs_got = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone()) 380 | self.compare_outputs_dict( 381 | outputs_expected, None, outputs_got, rtol=1e-5 382 | ) # the values may differ, but we need to make sure the sizes are correct 383 | 384 | def test_sled_with_input_prefix(self): 385 | rtol = 1e-5 386 | model_, tokenizer = BartModel.from_pretrained("facebook/bart-base"), BartTokenizer.from_pretrained("facebook/bart-base") 387 | model = SledModel(model_, SledConfig("facebook/bart-base", context_size=16, prepend_prefix=True)) 388 | model.eval() # only change the model to be in eval (inference) mode, thus not changing layer_norm params and removing dropout 389 | 390 | document_input_ids = tokenizer( 391 | "Studies have been shown that owning a dog is good for you", return_tensors="pt" 392 | ).input_ids # Batch size 1 393 | input_prefix_ids = tokenizer("What did studies show?\n\n", return_tensors="pt").input_ids # Batch size 1 394 | decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 395 | 396 | input_ids = torch.concat((input_prefix_ids, document_input_ids), dim=-1) 397 | prefix_length = torch.LongTensor([[input_prefix_ids.size(1)]]) 398 | 399 | # simple verification there are no failures in the flow itself 400 | with torch.no_grad(): 401 | _ = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(), 402 | decoder_input_ids=decoder_input_ids.clone()) 403 | _ = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(), 404 | decoder_input_ids=decoder_input_ids.clone(), 405 | return_dict=None) 406 | outputs_dict = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(), 407 | decoder_input_ids=decoder_input_ids.clone(), return_dict=True) 408 | outputs_no_dict = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(), 409 | decoder_input_ids=decoder_input_ids.clone(), return_dict=False) 410 | 411 | # let's verify that if the sequence is short, we behave exactly as the base model 412 | model = SledModel(model_, SledConfig("facebook/bart-base", context_size=512, prepend_prefix=False)) 413 | # treat the while input as a single document and make sure the context size is large enough to contain it wholly 414 | model.eval() 415 | with torch.no_grad(): 416 | output_expected = model_( 417 | input_ids=input_ids.clone(), 418 | decoder_input_ids=decoder_input_ids.clone(), return_dict=False 419 | ) 420 | 421 | # on non dict return type 422 | with torch.no_grad(): 423 | output_got = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(), 424 | decoder_input_ids=decoder_input_ids.clone(), return_dict=False) 425 | self.assertEqual(type(output_expected), type(output_got)) 426 | self.assertEqual(type(output_expected), type(outputs_no_dict)) 427 | 
self.assertEqual(len(output_got), len(output_expected)) # should be tuple so it's ok 428 | self.assertEqual(len(output_got), len(outputs_no_dict)) # should be tuple so it's ok 429 | _compare_tuple_of_tensors(self, output_expected, output_got, outputs_no_dict, rtol=rtol) 430 | 431 | # on dict return type 432 | with torch.no_grad(): 433 | output_expected = model_(input_ids=input_ids.clone(), 434 | decoder_input_ids=decoder_input_ids.clone(), return_dict=True 435 | ) 436 | output_got = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True) 437 | self.compare_outputs_dict(output_expected, output_got, outputs_dict, rtol=rtol) 438 | 439 | @unittest.expectedFailure 440 | def test_sled_with_variable_sized_input_prefix(self): 441 | raise NotImplementedError 442 | 443 | @unittest.expectedFailure 444 | def test_sled_overhead(self): 445 | raise NotImplementedError 446 | 447 | @unittest.expectedFailure 448 | def test_sled_multiple_update_steps(self): 449 | raise NotImplementedError 450 | 451 | @unittest.expectedFailure 452 | def test_eval_mode(self): 453 | raise NotImplementedError 454 | # TODO - assert that eval() works by looking at the dropout rate? Three forward passes with the model in train, 455 | # then three in eval and see that the train outputs differ, but the eval outputs do not? 456 | 457 | @unittest.expectedFailure 458 | def test_prefix_prepending(self): 459 | # TODO - first, test the expected versions (no prepending + no prefix given, or w/ prepending + prefix given) 460 | # TODO - test sled that should prepend but doesn't get it 461 | # TODO - test sled that shouldn't prepend but gets a prefix length (make sure it ignores it) 462 | raise NotImplementedError 463 | 464 | @unittest.expectedFailure 465 | def test_drop_prefix_encoding(self): 466 | raise NotImplementedError 467 | 468 | 469 | 470 | @require_torch 471 | class SLEDForConditionalGenerationTest(unittest.TestCase): 472 | def test_facade_get_attr_behavior(self): 473 | base_model = T5ForConditionalGeneration.from_pretrained("t5-small") 474 | sled_model = SledForConditionalGeneration(base_model, SledConfig("t5-small", context_size=4)) 475 | self.assertEqual( 476 | base_model.shared, sled_model.shared 477 | ) # sled model does not have a 'shared' attribute, only its base model does 478 | 479 | def test_sled_for_cg_on_t5(self): 480 | self._run_sled_for_cg_model_test_case( 481 | T5ForConditionalGeneration.from_pretrained("t5-small"), T5Tokenizer.from_pretrained("t5-small"), "t5-small" 482 | ) 483 | 484 | def test_sled_for_cg_on_bart(self): 485 | self._run_sled_for_cg_model_test_case( 486 | BartForConditionalGeneration.from_pretrained("facebook/bart-base"), 487 | BartTokenizer.from_pretrained("facebook/bart-base"), "facebook/bart-base" 488 | ) 489 | 490 | def test_sled_for_cg_generate_on_t5(self): 491 | self._run_sled_for_cg_model_generate_test_case( 492 | T5ForConditionalGeneration.from_pretrained("t5-small"), T5Tokenizer.from_pretrained("t5-small"), "t5-small" 493 | ) 494 | 495 | def test_sled_for_cg_generate_on_bart(self): 496 | self._run_sled_for_cg_model_generate_test_case( 497 | BartForConditionalGeneration.from_pretrained("facebook/bart-base"), 498 | BartTokenizer.from_pretrained("facebook/bart-base"), "facebook/bart-base" 499 | ) 500 | 501 | def _run_sled_for_cg_model_test_case(self, model_, tokenizer, underlying_config: str, rtol=1e-5): 502 | model = SledForConditionalGeneration(model_, SledConfig(underlying_config, context_size=4)) 503 | model.eval() 504 | 505 | input_ids = tokenizer("The walks in park", 
return_tensors="pt").input_ids 506 | labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids 507 | 508 | # simple verification there are no failures in the flow itself 509 | with torch.no_grad(): 510 | _ = model(input_ids=input_ids.clone(), labels=labels.clone()) 511 | _ = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=None) 512 | outputs_dict = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=True) 513 | outputs_no_dict = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=False) 514 | 515 | # let's verify that if the sequence is short, we behave exactly as the base model 516 | model = SledForConditionalGeneration(model_, SledConfig(underlying_config, context_size=512)) 517 | output_expected = model_(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=False) 518 | 519 | # on non dict return type 520 | with torch.no_grad(): 521 | output_got = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=False) 522 | self.assertEqual(type(output_expected), type(output_got)) 523 | self.assertEqual(type(output_expected), type(outputs_no_dict)) 524 | self.assertEqual(len(output_got), len(output_expected)) # should be tuple so it's ok 525 | self.assertEqual(len(output_got), len(outputs_no_dict)) # should be tuple so it's ok 526 | _compare_tuple_of_tensors(self, output_expected, output_got, outputs_no_dict, rtol=rtol) 527 | 528 | # on dict return type 529 | with torch.no_grad(): 530 | output_expected = model_(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=True) 531 | output_got = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=True) 532 | self.assertEqual(type(output_expected), type(output_got)) 533 | self.assertEqual(type(output_expected), type(outputs_dict)) 534 | self.assertListEqual(list(output_got.keys()), list(output_expected.keys())) 535 | self.assertListEqual(list(output_got.keys()), list(outputs_dict.keys())) 536 | for key in output_got.keys(): 537 | if isinstance(output_got[key], torch.Tensor): 538 | self.assertTrue(torch.allclose(output_got[key], output_expected[key], rtol=rtol)) 539 | # we can't expect the values to be the same when different context length, but at least can verify shapes 540 | self.assertTrue(output_got[key].size() == outputs_dict[key].size()) 541 | elif isinstance(output_got[key], tuple): 542 | _compare_tuple_of_tensors(self, output_got[key], output_expected[key], outputs_dict[key], rtol=rtol) 543 | 544 | def _run_sled_for_cg_model_generate_test_case(self, model_, tokenizer, underlying_config: str, rtol=1e-5): 545 | input_ids = tokenizer( 546 | "summarize: studies have shown that owning a dog is good for you ", return_tensors="pt" 547 | ).input_ids # Batch size 1 548 | model_.eval() 549 | # once it is used in SledForConditionalGeneration, it cannot generate directly anymore 550 | output_expected = model_.generate(input_ids.clone()) 551 | 552 | model = SledForConditionalGeneration(model_, SledConfig(underlying_config, context_size=4)) 553 | model.eval() 554 | 555 | # simple verification there are no failures in the flow itself 556 | _ = model.generate( 557 | torch.cat((input_ids.clone(), input_ids.clone())) 558 | ) # just to make sure can generate over two sequences 559 | _ = model.generate(input_ids.clone()) 560 | outputs_no_dict = model.generate(input_ids.clone()) 561 | assert outputs_no_dict.dim() == 2 562 | assert outputs_no_dict.size()[0] == 1 563 | 564 | # let's verify that if the sequence is short, we behave exactly as the base 
model 565 | model = SledForConditionalGeneration(model_, SledConfig(underlying_config, context_size=512)) 566 | model.eval() 567 | 568 | output_got = model.generate(input_ids.clone()) 569 | self.assertEqual(type(output_expected), type(output_got)) 570 | self.assertEqual(type(output_expected), type(outputs_no_dict)) 571 | self.assertEqual(len(output_got), len(output_expected)) # generate returns a tensor here, so len() compares the batch dimension 572 | self.assertEqual(len(output_got), len(outputs_no_dict)) # generate returns a tensor here, so len() compares the batch dimension 573 | # no point checking the dim as the generation length may be different 574 | _compare_tuple_of_tensors(self, (output_expected,), (output_got,), None, rtol=rtol) 575 | --------------------------------------------------------------------------------
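The test cases above double as the closest thing to an API walkthrough for the SLED facades. Below is a minimal usage sketch distilled from `test_sled_with_input_prefix` and `_run_sled_for_cg_model_generate_test_case`; it is illustrative only, and the top-level `from sled import ...` path is an assumption based on the package layout (`sled/__init__.py`) rather than something the tests show.

```python
# Illustrative sketch only -- distilled from tests/test_sled.py.
# The `from sled import ...` re-exports are assumed from the package layout.
import torch
from transformers import BartForConditionalGeneration, BartModel, BartTokenizer

from sled import SledConfig, SledForConditionalGeneration, SledModel  # assumed re-exports

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

# Wrap a short-range encoder-decoder for generation; inputs longer than
# context_size are split into overlapping chunks and fused in the decoder.
gen_model = SledForConditionalGeneration(
    BartForConditionalGeneration.from_pretrained("facebook/bart-base"),
    SledConfig("facebook/bart-base", context_size=16),
)
gen_model.eval()
input_ids = tokenizer(
    "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
).input_ids
summary_ids = gen_model.generate(input_ids)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))

# Prefix handling (e.g. a question prepended to every chunk), mirroring
# test_sled_with_input_prefix: the prefix is concatenated before the document
# and its length is passed separately via prefix_length.
model = SledModel(
    BartModel.from_pretrained("facebook/bart-base"),
    SledConfig("facebook/bart-base", context_size=16, prepend_prefix=True),
)
model.eval()
prefix_ids = tokenizer("What did studies show?\n\n", return_tensors="pt").input_ids
document_ids = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt").input_ids
with torch.no_grad():
    outputs = model(
        input_ids=torch.cat((prefix_ids, document_ids), dim=-1),
        prefix_length=torch.LongTensor([[prefix_ids.size(1)]]),
        decoder_input_ids=tokenizer("Studies show that", return_tensors="pt").input_ids,
    )
```

As `test_facade_get_attr_behavior` shows, the wrappers act as facades: attributes not defined on the SLED object (e.g. `shared`) fall through to the underlying model. For the repository's own end-to-end example, see `examples/usage_example.py`.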