├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── examples
│   ├── README.md
│   ├── __init__.py
│   ├── seq2seq
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── configs
│   │   │   ├── data
│   │   │   │   ├── contract_nli.json
│   │   │   │   ├── gov_report.json
│   │   │   │   ├── hotpotqa.json
│   │   │   │   ├── hotpotqa_second_only.json
│   │   │   │   ├── narative_qa.json
│   │   │   │   ├── qasper.json
│   │   │   │   ├── qmsum.json
│   │   │   │   ├── quality.json
│   │   │   │   ├── squad.json
│   │   │   │   ├── squad_ordered_distractors.json
│   │   │   │   ├── squad_shuffled_distractors.json
│   │   │   │   └── summ_screen_fd.json
│   │   │   ├── model
│   │   │   │   ├── bart_base_sled.json
│   │   │   │   └── bart_large_sled.json
│   │   │   └── training
│   │   │       └── base_training_args.json
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   └── metrics.py
│   │   ├── run.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── config.py
│   │       ├── custom_hf_argument_parser.py
│   │       ├── custom_seq2seq_trainer.py
│   │       ├── decoding.py
│   │       ├── duplicates.py
│   │       └── override_training_args.py
│   └── usage_example.py
├── setup.py
├── sled
│   ├── __init__.py
│   ├── configuration_sled.py
│   ├── modeling_sled.py
│   ├── tokenization_sled.py
│   └── tokenization_sled_fast.py
└── tests
    ├── configs
    │   ├── bart_base_sled.json
    │   └── t5_base_sled.json
    └── test_sled.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Maor Ivgi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune tests/
2 | prune examples/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SLED
2 | The official repository for Efficient Long-Text Understanding with Short-Text Models [(Ivgi et al., 2022)](https://arxiv.org/abs/2208.00748.pdf), to appear in Transactions of the Association for Computational Linguistics (TACL) 2023.
3 |
4 | SLED models use pretrained, short-range encoder-decoder models and apply them to
5 | long-text inputs by splitting the input into multiple overlapping chunks, encoding each chunk independently, and performing fusion-in-decoder.
6 |
7 |
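As a rough intuition only (an illustrative sketch, not the actual SLED implementation; the chunk size and overlap simply mirror the example config shown further below):
```python
# Illustrative only: split a long token sequence into overlapping chunks, the way a
# SLED-style model feeds them to a short-range encoder before fusion-in-decoder.
def split_into_chunks(token_ids, context_size=256, window_fraction=0.5):
    stride = int(context_size * window_fraction)  # overlap between consecutive chunks
    return [token_ids[start:start + context_size]
            for start in range(0, max(1, len(token_ids) - context_size + stride), stride)]

print(len(split_into_chunks(list(range(1000)))))  # 7 overlapping chunks of up to 256 tokens
```
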
8 | ## Data
9 | The data for this paper is hosted on the dataset hub [here](https://huggingface.co/datasets/tau/sled).
10 | It is based on the [SCROLLS dataset](https://huggingface.co/datasets/tau/scrolls) ([paper](https://arxiv.org/pdf/2201.03533.pdf)), the [SQuAD 1.1 dataset](https://huggingface.co/datasets/squad) ([paper](https://arxiv.org/pdf/1606.05250.pdf)) and the [HotpotQA dataset](https://huggingface.co/datasets/hotpot_qa) ([paper](https://arxiv.org/pdf/1809.09600.pdf)).
11 | It doesn't contain any unpublished data, but includes the configuration needed for the paper.
12 |
13 | Usage example:
14 | ```python
15 | from datasets import load_dataset
16 | qasper = load_dataset("tau/sled","qasper")
17 | ```
18 |
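The examples follow the SCROLLS-style schema expected by [run.py](examples/seq2seq/run.py) (the field names below are the defaults assumed there):
```python
from datasets import load_dataset

qasper = load_dataset("tau/sled", "qasper")
example = qasper["validation"][0]
# run.py defaults to the "input" and "output" columns, plus "input_prefix" for QA-style tasks
print(example.keys())
print(example["input"][:200])
```
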
19 | ## Installation
20 |
21 | Make sure to install pytorch according to your machine spec. See installation options [here](https://pytorch.org/get-started/locally/).
22 |
23 | Installing SLED is easy with pip.
24 | ```
25 | pip install py-sled
26 | ```
27 |
28 | Some backbone models require additional dependencies. If you wish to work with T5, for example, you can install them with
29 | ```
30 | pip install py-sled[t5]
31 | ```
32 |
33 | If you wish to run the examples, install the required dependencies with
34 | ```
35 | pip install py-sled[examples]
36 | ```
37 |
38 | If you wish to continue developing this repository, install the full development requirements with
39 | ```
40 | pip install py-sled[dev]
41 | ```
42 |
43 | ## Usage
44 | Working with SLED is seamless when using HuggingFace's Transformers AutoClasses.
45 |
46 | A minimal usage example:
47 | ```python
48 | import sled # ** required so SLED would be properly registered by the AutoClasses **
49 | from transformers import AutoTokenizer, AutoModel
50 | tokenizer = AutoTokenizer.from_pretrained('tau/bart-base-sled')
51 | model = AutoModel.from_pretrained('tau/bart-base-sled')
52 | inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
53 | outputs = model(**inputs)
54 | last_hidden_states = outputs.last_hidden_state
55 | ```
56 |
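For conditional generation (the setting used for the tasks in the paper), the seq2seq class can be loaded the same way. A minimal sketch using standard HuggingFace generation (the decoding settings below are arbitrary):
```python
import sled  # ** required so SLED would be properly registered by the AutoClasses **
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained('tau/bart-base-sled')
model = AutoModelForSeq2SeqLM.from_pretrained('tau/bart-base-sled')

document = "a very long document ..."
inputs = tokenizer(document, return_tensors="pt")
generated_ids = model.generate(**inputs, max_length=128, num_beams=4)  # arbitrary decoding settings
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```
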
57 | _Important_: You need to `import sled` before using the AutoClass (e.g. `AutoModel.from_pretrained('tau/bart-base-sled')`) for it to work.
58 |
59 | A minimal working example can be found [here](examples/usage_example.py).
60 |
61 | To work with SCROLLS-like data as used in the paper, see [here](examples/seq2seq).
62 |
63 | ### Custom datasets
64 | For SLED to be able to prepend the prefix input to every chunk, it requires the additional input tensor `prefix_length`.
65 | If using a custom dataset, refer to [run.py](examples/seq2seq/run.py) for the correct way to preprocess the data; a rough sketch is also shown below.
66 |
67 | _Note_: Currently, HF's Seq2SeqTrainer doesn't pass the `prefix_length` tensor in the prediction loop, so you
68 | should use the [CustomSeq2SeqTrainer](examples/seq2seq/utils/custom_seq2seq_trainer.py) or something similar until it is
69 | fixed.
70 |
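As a rough illustration only (not the exact preprocessing used in [run.py](examples/seq2seq/run.py); special-token and padding bookkeeping is simplified, and the length limit is just the value used in the example configs), a custom dataset could attach `prefix_length` like this:
```python
import sled  # registers SLED with the AutoClasses
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('tau/bart-base-sled')
PREFIX_DOC_SEP = '\n\n'  # same separator constant as in run.py

def preprocess(question, document, max_source_length=16384):
    # Tokenize the prefix (e.g. a question) together with the document, then record how many
    # of the leading tokens belong to the prefix so SLED can prepend it to every chunk.
    model_inputs = tokenizer(question + PREFIX_DOC_SEP + document,
                             truncation=True, max_length=max_source_length)
    model_inputs['prefix_length'] = len(tokenizer(question + PREFIX_DOC_SEP)['input_ids'])
    return model_inputs
```
Remember that `prefix_length` must also reach the model during generation, which is why the note above recommends using `CustomSeq2SeqTrainer`.
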
71 | ### Backbone models
72 | There are multiple model cards available on the HuggingFace Hub, including:
73 | - [Bart-Base SLED](https://huggingface.co/tau/bart-base-sled) (model name `tau/bart-base-sled`)
74 | - [Bart-Large SLED](https://huggingface.co/tau/bart-large-sled) (model name `tau/bart-large-sled`)
75 | - [T5(v1.1)-base SLED](https://huggingface.co/tau/t5-v1_1-base-sled) (model name `tau/t5-v1_1-base-sled`)
76 | - [T5(v1.1)-large SLED](https://huggingface.co/tau/t5-v1_1-large-sled) (model name `tau/t5-v1_1-large-sled`)
77 |
78 | If you wish to use a custom model that is available as a model card (public or private) on the hub, or use
79 | different parameters for SLED, you can create a json config file like the one below, and change the underlying_config to your custom model card.
80 | ```json
81 | {
82 | "model_type": "tau/sled",
83 | "underlying_config": "facebook/bart-base",
84 | "context_size": 256,
85 | "window_fraction": 0.5,
86 | "prepend_prefix": true,
87 | "encode_prefix": true,
88 | "sliding_method": "dynamic"
89 | }
90 | ```
91 | You can then load it like below, pointing `from_pretrained` at your custom model card name or at the location where you saved the config above:
92 | ```python
93 | import sled
94 | from transformers import AutoModelForSeq2SeqLM
95 | custom_sled_model = AutoModelForSeq2SeqLM.from_pretrained('path/to/your/custom-sled-config')  # placeholder: your model card name or local config location
96 | ```
97 |
98 | ## Citation
99 |
100 | If you use this repository, please cite as below:
101 | ```
102 | @inproceedings{Ivgi2022EfficientLU,
103 | title={Efficient Long-Text Understanding with Short-Text Models},
104 | author={Maor Ivgi and Uri Shaham and Jonathan Berant},
105 | year={2022}
106 | }
107 | ```
108 |
109 |
110 | ## Disclaimer
111 | This repository is still under active development and may contain some unintended behavior.
112 | Please open an issue if any unexpected behavior occurs, and we will promptly try to fix it.
113 |
114 | The code was developed and tested with transformers version 4.21.0. Newer versions may break backward
115 | compatibility and cause unexpected behavior.
116 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 | Working with SLED is seamless when using HuggingFace's Transformers AutoClasses.
3 |
4 | _Important_: You need to `import sled` before using the AutoClass (e.g. `AutoModel.from_pretrained('tau/bart-base-sled')`) for it to work.
5 |
6 | A minimal working example can be found [here](usage_example.py); the same snippet is inlined below for quick reference.
7 |
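Copied from the main README (requires `pip install py-sled` and a working pytorch installation):
```python
import sled  # ** required so SLED would be properly registered by the AutoClasses **
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('tau/bart-base-sled')
model = AutoModel.from_pretrained('tau/bart-base-sled')
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
```
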
8 | To work with SCROLLS-like data as used in the paper, see [here](seq2seq).
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mivg/SLED/2991895d01e5ed5871de41676a14b25c534c523a/examples/__init__.py
--------------------------------------------------------------------------------
/examples/seq2seq/README.md:
--------------------------------------------------------------------------------
1 | # Seq2Seq finetuning
2 | In this example, you can find the scripts needed to finetune SLED models on SCROLLS-like data.
3 |
4 | The entrypoint script is [run.py](run.py)
5 |
6 | ## Usage
7 | After setting up your environment as described [here](https://github.com/Mivg/SLED#installation), you can run the script to finetune,
8 | evaluate and generate predictions for all the datasets in the [SLED dataset](https://huggingface.co/datasets/tau/sled)
9 | (based on SCROLLS).
10 |
11 | Like most recipes, you can view all settable parameters by running
12 | ```
13 | python run.py --help
14 | ```
15 |
16 | To run, you can either set the parameters with command-line arguments (e.g. `--model_name_or_path tau/bart-base-sled`)
17 | or use predefined json files to set recurring configurations. You can pass as many json files as you would like,
18 | but make sure you pass them before any other command-line argument. For example, you can do the following:
19 | ```
20 | python run.py configs/data/squad.json \
21 | configs/model/bart_base_sled.json \
22 | configs/training/base_training_args.json \
23 | --output_dir /tmp/output_sled \
24 | --learning_rate 2e-5
25 | ```
26 |
27 | Example json files are [here](https://github.com/Mivg/SLED/tree/main/examples/seq2seq/configs).
28 |
--------------------------------------------------------------------------------
/examples/seq2seq/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mivg/SLED/2991895d01e5ed5871de41676a14b25c534c523a/examples/seq2seq/__init__.py
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/contract_nli.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "contract_nli",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "generation_max_length": 8,
7 | "num_train_epochs": 20,
8 | "metric_names": ["exact_match"],
9 | "metric_for_best_model": "exact_match",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/gov_report.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "gov_report",
4 | "max_source_length": 16384,
5 | "generation_max_length": 1024,
6 | "max_prefix_length": 0,
7 | "num_train_epochs": 10,
8 | "metric_names": ["rouge"],
9 | "metric_for_best_model": "rouge/geometric_mean",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/hotpotqa.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "hotpotqa",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "generation_max_length": 128,
7 | "num_train_epochs": 9,
8 | "metric_names": ["f1", "exact_match"],
9 | "metric_for_best_model": "f1",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/hotpotqa_second_only.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "hotpotqa_second_only",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "generation_max_length": 128,
7 | "num_train_epochs": 9,
8 | "metric_names": ["f1", "exact_match"],
9 | "metric_for_best_model": "f1",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/narative_qa.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "narrative_qa",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "num_train_epochs": 2,
7 | "generation_max_length": 128,
8 | "metric_names": ["f1"],
9 | "metric_for_best_model": "f1",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/qasper.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "qasper",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "generation_max_length": 128,
7 | "num_train_epochs": 20,
8 | "metric_names": ["f1"],
9 | "metric_for_best_model": "f1",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/qmsum.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "qmsum",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "num_train_epochs": 20,
7 | "generation_max_length": 1024,
8 | "metric_names": ["rouge"],
9 | "metric_for_best_model": "rouge/geometric_mean",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/quality.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "quality",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 160,
6 | "num_train_epochs": 20,
7 | "generation_max_length": 128,
8 | "metric_names": ["exact_match"],
9 | "metric_for_best_model": "exact_match",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/squad.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "squad",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "num_train_epochs": 3,
7 | "generation_max_length": 128,
8 | "metric_names": ["f1", "exact_match"],
9 | "metric_for_best_model": "f1",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/squad_ordered_distractors.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "squad_ordered_distractors",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "num_train_epochs": 3,
7 | "generation_max_length": 128,
8 | "metric_names": ["f1", "exact_match"],
9 | "metric_for_best_model": "f1",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/squad_shuffled_distractors.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "squad_shuffled_distractors",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 64,
6 | "num_train_epochs": 3,
7 | "generation_max_length": 128,
8 | "metric_names": ["f1", "exact_match"],
9 | "metric_for_best_model": "f1",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/data/summ_screen_fd.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset_name": "tau/sled",
3 | "dataset_config_name": "summ_screen_fd",
4 | "max_source_length": 16384,
5 | "max_prefix_length": 0,
6 | "num_train_epochs": 10,
7 | "generation_max_length": 1024,
8 | "metric_names": ["rouge"],
9 | "metric_for_best_model": "rouge/geometric_mean",
10 | "greater_is_better": true
11 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/model/bart_base_sled.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name_or_path": "tau/bart-base-sled",
3 | "use_auth_token": false,
4 | "pad_prefix": true,
5 | "max_target_length": 1024,
6 | "fp16": true
7 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/model/bart_large_sled.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name_or_path": "tau/bart-large-sled",
3 | "use_auth_token": false,
4 | "pad_prefix": true,
5 | "max_target_length": 1024,
6 | "fp16": true
7 | }
--------------------------------------------------------------------------------
/examples/seq2seq/configs/training/base_training_args.json:
--------------------------------------------------------------------------------
1 | {
2 | "eval_steps_override": 0.5,
3 | "save_steps_override": 0.5,
4 | "evaluation_strategy": "steps",
5 | "eval_fraction": 1000,
6 | "predict_with_generate": true,
7 | "gradient_checkpointing": true,
8 | "do_train": true,
9 | "do_eval": true,
10 | "seed": 42,
11 | "warmup_ratio": 0.1,
12 | "save_total_limit": 2,
13 | "preprocessing_num_workers": 1,
14 | "load_best_model_at_end": true,
15 | "lr_scheduler": "linear",
16 | "adam_epsilon": 1e-6,
17 | "adam_beta1": 0.9,
18 | "adam_beta2": 0.98,
19 | "weight_decay": 0.001
20 | }
--------------------------------------------------------------------------------
/examples/seq2seq/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import load_metric
2 |
--------------------------------------------------------------------------------
/examples/seq2seq/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 | import os
3 | import importlib
4 | from abc import ABC, abstractmethod
5 | import inspect
6 | import shutil
7 |
8 | import numpy as np
9 |
10 | from utils.decoding import decode
11 | from datasets import load_metric as hf_load_metric
12 | from huggingface_hub import hf_hub_download
13 |
14 |
15 | class Metric(ABC):
16 | def __init__(self, **kwargs) -> None:
17 | super().__init__()
18 | self._kwargs = kwargs
19 |
20 | self.prefix = os.path.splitext(os.path.basename(inspect.getfile(self.__class__)))[0]
21 | self.requires_decoded = False
22 |
23 | def __call__(self, id_to_pred, id_to_labels, is_decoded=False):
24 | if self.requires_decoded and is_decoded is False:
25 | id_to_pred = self._decode(id_to_pred)
26 | id_to_labels = self._decode(id_to_labels)
27 | return self._compute_metrics(id_to_pred, id_to_labels)
28 |
29 | @abstractmethod
30 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]:
31 | return
32 |
33 | def _decode(self, id_to_something):
34 | tokenizer = self._kwargs.get("tokenizer")
35 | data_args = self._kwargs.get("data_args")
36 | return decode(id_to_something, tokenizer, data_args)
37 |
38 |
39 | class MetricCollection(Metric):
40 | def __init__(self, metrics: List[Metric], **kwargs):
41 | super().__init__(**kwargs)
42 | self._metrics = metrics
43 |
44 | def __call__(self, id_to_pred, id_to_labels):
45 | return self._compute_metrics(id_to_pred, id_to_labels)
46 |
47 | def _compute_metrics(self, id_to_pred, id_to_labels):
48 | results = {}
49 |
50 | id_to_pred_decoded = None
51 | id_to_labels_decoded = None
52 | for metric in self._metrics:
53 | metric_prefix = f"{metric.prefix}/" if metric.prefix else ""
54 | if metric.requires_decoded:
55 | if id_to_pred_decoded is None:
56 | id_to_pred_decoded = self._decode(id_to_pred)
57 | if id_to_labels_decoded is None:
58 | id_to_labels_decoded = self._decode(id_to_labels)
59 |
60 | result = metric(id_to_pred_decoded, id_to_labels_decoded, is_decoded=True)
61 | else:
62 | result = metric(id_to_pred, id_to_labels)
63 |
64 | results.update({f"{metric_prefix}{k}": v for k, v in result.items()})
65 |
66 | results["num_predicted"] = len(id_to_pred)
67 | results["mean_prediction_length_characters"] = np.mean([len(pred) for pred in id_to_pred_decoded.values()])
68 |
69 | elem = next(iter(id_to_pred.values()))
70 | if not ((isinstance(elem, list) and isinstance(elem[0], str)) or isinstance(elem, str)):
71 | tokenizer = self._kwargs["tokenizer"]
72 | results["mean_prediction_length_tokens"] = np.mean(
73 | [np.count_nonzero(np.array(pred) != tokenizer.pad_token_id) for pred in id_to_pred.values()]
74 | ) # includes BOS/EOS tokens
75 |
76 | results = {key: round(value, 4) for key, value in results.items()}
77 | return results
78 |
79 |
80 | def load_metric(paths: List[str], **kwargs):
81 | if paths is None or len(paths) == 0:
82 | return None
83 | if isinstance(paths, str):
84 | paths = [paths]
85 | else:
86 | paths = [path for path in paths]
87 |
88 | metric_cls_list = []
89 |
90 | scrolls_custom_metrics = []
91 | to_remove = []
92 | for i, path in enumerate(paths):
93 | if not os.path.isfile(path):
94 | scrolls_custom_metrics.append(path)
95 | to_remove.append(i)
96 | for i in sorted(to_remove, reverse=True):
97 | del paths[i]
98 | if len(scrolls_custom_metrics) > 0:
99 | scrolls_custom_metrics.insert(0, "") # In order to have an identifying comma in the beginning
100 | metric_cls_list.append(ScrollsWrapper(",".join(scrolls_custom_metrics), **kwargs))
101 |
102 | for path in paths:
103 | path = path.strip()
104 | if len(path) == 0:
105 | continue
106 | if os.path.isfile(path) is False:
107 | path = os.path.join("src", "metrics", f"{path}.py")
108 |
109 | module = path[:-3].replace(os.sep, ".")
110 |
111 | metric_cls = import_main_class(module)
112 | metric_cls_list.append(metric_cls(**kwargs))
113 |
114 | return MetricCollection(metric_cls_list, **kwargs)
115 |
116 |
117 | # Modified from datasets.load
118 | def import_main_class(module_path):
119 | """Import a module at module_path and return its main class"""
120 | module = importlib.import_module(module_path)
121 |
122 | main_cls_type = Metric
123 |
124 | # Find the main class in our imported module
125 | module_main_cls = None
126 | for name, obj in module.__dict__.items():
127 | if isinstance(obj, type) and issubclass(obj, main_cls_type):
128 | if inspect.isabstract(obj):
129 | continue
130 | module_main_cls = obj
131 | break
132 |
133 | return module_main_cls
134 |
135 |
136 | class ScrollsWrapper(Metric):
137 | def __init__(self, comma_separated_metric_names, **kwargs) -> None:
138 | super().__init__(**kwargs)
139 | self.prefix = None
140 |
141 | self._metric = hf_load_metric(download_metric(), comma_separated_metric_names, keep_in_memory=True)
142 |
143 | self.requires_decoded = True
144 |
145 | def _compute_metrics(self, id_to_pred, id_to_labels) -> Dict[str, float]:
146 | return self._metric.compute(**self._metric.convert_from_map_format(id_to_pred, id_to_labels))
147 |
148 |
149 | def download_metric():
150 | # here we load the custom metrics
151 |
152 | try:
153 | scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", repo_type="dataset", filename="metrics/scrolls.py")
154 | except:
155 | # support for backward compatibility
156 | scrolls_metric_path = hf_hub_download(repo_id="datasets/tau/scrolls", filename="metrics/scrolls.py")
157 |
158 | updated_scrolls_metric_path = (
159 | os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
160 | )
161 | shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
162 | return updated_scrolls_metric_path
163 |
--------------------------------------------------------------------------------
/examples/seq2seq/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 |
25 | import numpy as np
26 |
27 | # we import the logging frameworks before any other import to make sure all monkey patching for the logging are active
28 | from sled import SledConfig
29 |
30 | try:
31 | import comet_ml
32 | except ImportError:
33 | pass
34 | try:
35 | import wandb
36 | except ImportError:
37 | pass
38 | import torch
39 |
40 |
41 | sys.path.insert(0, os.path.dirname(__file__)) # seq2seq package path
42 | sys.path.insert(0, os.getcwd())
43 |
44 | from dataclasses import dataclass, field
45 | from typing import List, Optional
46 | import json
47 | from copy import deepcopy
48 | import torch.nn.functional as F
49 |
50 | import datasets
51 |
52 | import transformers
53 | from transformers import (
54 | AutoConfig,
55 | AutoModelForSeq2SeqLM,
56 | AutoTokenizer,
57 | set_seed, WEIGHTS_NAME,
58 | )
59 | from transformers.trainer_utils import get_last_checkpoint
60 | from transformers import DataCollatorForSeq2Seq
61 |
62 | from datasets import load_dataset
63 |
64 | # noinspection PyUnresolvedReferences
65 | import sled # *** required so that SledModels will be registered for the AutoClasses ***
66 |
67 | from utils.config import handle_args_to_ignore
68 | from utils.decoding import decode
69 | from metrics import load_metric
70 | from utils.duplicates import drop_duplicates_in_input
71 | from utils.override_training_args import TrainingOverridesArguments
72 | from utils.custom_seq2seq_trainer import CustomTrainer
73 | from utils.custom_hf_argument_parser import CustomHfArgumentParser
74 |
75 | logger = logging.getLogger('sled')
76 |
77 | PREFIX_DOC_SEP = '\n\n'
78 |
79 | DEBUG = os.environ.get('DEBUG', 'false').lower() in {'1', 'true', 'yes'} # If set, will set some configuration to help debug
80 | if DEBUG:
81 | assert not torch.cuda.is_available() or torch.cuda.device_count() == 1
82 |
83 |
84 | @dataclass
85 | class ModelArguments:
86 | """
87 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
88 | """
89 |
90 | model_name_or_path: str = field(
91 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
92 | )
93 | config_name: Optional[str] = field(
94 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
95 | )
96 | tokenizer_name: Optional[str] = field(
97 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
98 | )
99 | cache_dir: Optional[str] = field(
100 | default=None,
101 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
102 | )
103 | use_fast_tokenizer: bool = field(
104 | default=True,
105 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
106 | )
107 | model_revision: str = field(
108 | default="main",
109 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
110 | )
111 | drop_duplicates_in_eval: bool = field(
112 | default=True,
113 | )
114 |
115 | def __post_init__(self):
116 | pass
117 |
118 |
119 | @dataclass
120 | class DataTrainingArguments:
121 | """
122 | Arguments pertaining to what data we are going to input our model for training and eval.
123 | """
124 |
125 | dataset_name: Optional[str] = field(
126 | default=None,
127 | metadata={
128 | "help": "The name of the dataset to use (via the datasets library) or name of the file in src/data."
129 | },
130 | )
131 | dataset_config_name: Optional[str] = field(
132 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
133 | )
134 | metric_names: Optional[List[str]] = field(
135 | default=None,
136 | metadata={"help": "The name of the metric to use (from src/metrics)."},
137 | )
138 | input_column: Optional[str] = field(
139 | default=None,
140 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
141 | )
142 | input_prefix_column: Optional[str] = field(
143 | default=None,
144 | metadata={"help": "The name of the column in the datasets containing the input prefix (e.g. questions), when those exist."},
145 | )
146 | output_column: Optional[str] = field(
147 | default=None,
148 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
149 | )
150 | train_file: Optional[str] = field(
151 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
152 | )
153 | validation_file: Optional[str] = field(
154 | default=None,
155 | metadata={
156 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
157 | "(a jsonlines or csv file)."
158 | },
159 | )
160 | test_file: Optional[str] = field(
161 | default=None,
162 | metadata={
163 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
164 | },
165 | )
166 | overwrite_cache: bool = field(
167 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
168 | )
169 | preprocessing_num_workers: Optional[int] = field(
170 | default=None,
171 | metadata={"help": "The number of processes to use for the preprocessing."},
172 | )
173 | max_source_length: Optional[int] = field(
174 | default=None,
175 | metadata={
176 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
177 | "than this will be truncated, sequences shorter will be padded."
178 | },
179 | )
180 | max_prefix_length: Optional[int] = field(
181 | default=0,
182 | metadata={
183 | "help": "The maximum total input_prefix sequence length after tokenization. Sequences longer "
184 | "than this will be truncated, sequences shorter will be padded from the left "
185 | "(only used if prefixes are not merged)."
186 | },
187 | )
188 | max_target_length: Optional[int] = field(
189 | default=128,
190 | metadata={
191 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
192 | "than this will be truncated, sequences shorter will be padded."
193 | },
194 | )
195 | val_max_target_length: Optional[int] = field(
196 | default=None,
197 | metadata={
198 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
199 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
200 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
201 | "during ``evaluate`` and ``predict``."
202 | },
203 | )
204 | pad_to_max_length: bool = field(
205 | default=False,
206 | metadata={
207 | "help": "Whether to pad all samples to model maximum sentence length. "
208 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
209 | "efficient on GPU but very bad for TPU."
210 | },
211 | )
212 | max_train_samples: Optional[int] = field(
213 | default=None,
214 | metadata={
215 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
216 | "value if set."
217 | },
218 | )
219 | max_eval_samples: Optional[int] = field(
220 | default=None,
221 | metadata={
222 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
223 | "value if set."
224 | },
225 | )
226 | max_predict_samples: Optional[int] = field(
227 | default=None,
228 | metadata={
229 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
230 | "value if set."
231 | },
232 | )
233 | num_beams: Optional[int] = field(
234 | default=None,
235 | metadata={
236 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
237 | "which is used during ``evaluate`` and ``predict``."
238 | },
239 | )
240 | ignore_pad_token_for_loss: bool = field(
241 | default=True,
242 | metadata={
243 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
244 | },
245 | )
246 | source_prefix: Optional[str] = field(
247 | default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
248 | )
249 | data_dir: Optional[str] = field(
250 | default=None,
251 | metadata={"help": "Defining the data_dir of the dataset configuration."},
252 | )
253 | download_mode: Optional[str] = field(
254 | default=None,
255 | metadata={
256 | "help": "Defining the download_mode when loading the dataset. Options are `reuse_dataset_if_exists` (default), `reuse_cache_if_exists` and `force_redownload`."
257 | },
258 | )
259 | evaluate_on_training_data: bool = field(
260 | default=False,
261 | metadata={"help": "Whether to evaluate on training data or not, to make sure the model can overfit."},
262 | )
263 | folder_suffix: str = field(
264 | default="",
265 | metadata={"help": "args to be suffixes for the output folder of the run"},
266 | )
267 | preprocess_only: bool = field(
268 | default=False,
269 | metadata={"help": "Preprocess only: Don't start training, just do the things before"},
270 | )
271 | assign_zero_to_too_long_val_examples: bool = field(
272 | default=False,
273 | metadata={
274 | "help": "If true, all sequences longer then max_source_length will be assign a score of 0 in the metric evaluation"
275 | },
276 | )
277 | shared_storage: bool = field(
278 | default=True,
279 | metadata={"help": "Whether nodes share the same storage"},
280 | )
281 | trim_very_long_strings: bool = field(
282 | default=False,
283 | metadata={"help": "Whether to trim very long strings before tokenizing them"},
284 | )
285 | pad_prefix: bool = field(
286 | default=False,
287 | metadata={
288 | "help": "Whether to pad the prefix if it exists to max_prefix_length. "
289 | "Note - important if you are using a SLED model on an input that contains an input_prefix"
290 | },
291 | )
292 | test_start_ind: Optional[int] = field(
293 | default=None,
294 | metadata={"help": "if given, uses the test set starting from this index"},
295 | )
296 | test_end_ind: Optional[int] = field(
297 | default=None,
298 | metadata={"help": "if given, uses the test set ending at this index"},
299 | )
300 |
301 |
302 | def __post_init__(self):
303 | if self.val_max_target_length is None:
304 | self.val_max_target_length = self.max_target_length
305 | if self.pad_prefix and self.max_prefix_length == 0:
306 | raise ValueError('When padding prefix, you must set a max_prefix_length')
307 | assert self.max_prefix_length == 0 or self.max_prefix_length <= 0.5*self.max_source_length,\
308 | 'If max_prefix_length is given, it must be much shorter than the total input'
309 |
310 |
311 | def main():
312 | handle_args_to_ignore(sys.argv) # Just for sweeps
313 |
314 | # See all possible arguments in src/transformers/training_args.py
315 | # or by passing the --help flag to this script.
316 | # We now keep distinct sets of args, for a cleaner separation of concerns.
317 |
318 | parser = CustomHfArgumentParser((ModelArguments, DataTrainingArguments, TrainingOverridesArguments))
319 | model_args, data_args, training_args = parser.parse_dictionary_and_args()
320 |
321 | set_up_logging(training_args)
322 |
323 | # Used to find missing dependencies early on
324 | load_metric(data_args.metric_names, **locals())
325 |
326 | if data_args.source_prefix is None and model_args.model_name_or_path in [
327 | "t5-small",
328 | "t5-base",
329 | "t5-large",
330 | "t5-3b",
331 | "t5-11b",
332 | ]:
333 | logger.warning(
334 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
335 | "`--source_prefix 'summarize: ' `"
336 | )
337 |
338 | # Detecting last checkpoint.
339 | last_checkpoint = _detect_last_checkpoint(training_args)
340 |
341 | # Set seed before initializing model.
342 | set_seed(training_args.seed)
343 |
344 | seq2seq_dataset = _get_dataset(data_args, model_args, training_args)
345 |
346 | # Load pretrained model and tokenizer
347 | #
348 | # Distributed training:
349 | # The .from_pretrained methods guarantee that only one local process can concurrently
350 | # download model & vocab.
351 | config_name = None
352 | if model_args.config_name:
353 | config_name = model_args.config_name
354 | else:
355 | if os.path.isfile(model_args.model_name_or_path):
356 | config_name = os.path.dirname(model_args.model_name_or_path)
357 | else:
358 | config_name = model_args.model_name_or_path
359 |
360 | config_overrides = {}
361 | if training_args.gradient_checkpointing is not None:
362 | config_overrides["gradient_checkpointing"] = training_args.gradient_checkpointing
363 |
364 | config = AutoConfig.from_pretrained(
365 | config_name,
366 | cache_dir=model_args.cache_dir,
367 | revision=model_args.model_revision,
368 | use_auth_token=training_args.use_auth_token,
369 | **config_overrides
370 | )
371 | # override for sled models to make sure we are explicit in our request
372 | if isinstance(config, SledConfig) and (not data_args.pad_prefix or data_args.max_prefix_length == 0):
373 | logger.warning('Setting prepend_prefix to False if using a SLED model, as the input does not have a prefix or '
374 | 'pad_prefix is False (all prefixes must be of the same length for SLED). If you do not use SLED '
375 | 'or finetune on a dataset with no prefixes, ignore this warning')
376 | config.prepend_prefix = False
377 |
378 | if model_args.model_name_or_path is None:
379 | # Padding for divisibility by 8
380 | if config.vocab_size % 8 != 0 and training_args.fp16_padding:
381 | config.vocab_size += 8 - (config.vocab_size % 8)
382 |
383 | tokenizer_name = None
384 | if model_args.tokenizer_name:
385 | tokenizer_name = model_args.tokenizer_name
386 | else:
387 | if os.path.isfile(model_args.model_name_or_path):
388 | tokenizer_name = os.path.dirname(model_args.model_name_or_path)
389 | else:
390 | tokenizer_name = model_args.model_name_or_path
391 | tokenizer = AutoTokenizer.from_pretrained(
392 | tokenizer_name,
393 | cache_dir=model_args.cache_dir,
394 | use_fast=model_args.use_fast_tokenizer,
395 | revision=model_args.model_revision,
396 | use_auth_token=training_args.use_auth_token,
397 | )
398 | if model_args.model_name_or_path is not None:
399 | model = AutoModelForSeq2SeqLM.from_pretrained(
400 | model_args.model_name_or_path,
401 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
402 | config=config,
403 | cache_dir=model_args.cache_dir,
404 | revision=model_args.model_revision,
405 | use_auth_token=training_args.use_auth_token,
406 | )
407 | else:
408 | model = AutoModelForSeq2SeqLM.from_config(
409 | config,
410 | )
411 | if training_args.gradient_checkpointing and getattr(model.config, 'use_cache', False):
412 | logger.warning('Cannot use cache in models when using gradient checkpointing. Turning it off')
413 | model.config.use_cache = False
414 |
415 | model.resize_token_embeddings(len(tokenizer))
416 |
417 | if model.config.decoder_start_token_id is None:
418 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
419 |
420 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
421 |
422 | # Preprocessing the datasets.
423 | # We need to tokenize inputs and targets.
424 | if training_args.do_train:
425 | column_names = seq2seq_dataset["train"].column_names
426 | elif training_args.do_eval:
427 | column_names = seq2seq_dataset["validation"].column_names
428 | elif training_args.do_predict:
429 | column_names = seq2seq_dataset["test"].column_names
430 | else:
431 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
432 | return
433 |
434 | # Get the column names for input/target.
435 | if data_args.input_column is None:
436 | input_column = "input"
437 | else:
438 | input_column = data_args.input_column
439 | if input_column not in column_names:
440 | raise ValueError(
441 | f"--input_column' value '{data_args.input_column}' needs to be one of: {', '.join(column_names)}"
442 | )
443 | if data_args.input_prefix_column is None:
444 | input_prefix_column = "input_prefix"
445 | else:
446 | input_prefix_column = data_args.input_prefix_column
447 | if input_prefix_column not in column_names:
448 | raise ValueError(
449 | f"--input_prefix_column' value '{data_args.input_prefix_column}' needs to be one of: {', '.join(column_names)}"
450 | )
451 | if data_args.output_column is None:
452 | output_column = "output"
453 | else:
454 | output_column = data_args.output_column
455 | if output_column not in column_names:
456 | raise ValueError(
457 | f"--output_column' value '{data_args.output_column}' needs to be one of: {', '.join(column_names)}"
458 | )
459 |
460 | # Temporarily set max_target_length for training.
461 | max_target_length = data_args.max_target_length
462 | padding = "max_length" if data_args.pad_to_max_length else False
463 |
464 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
465 | logger.warning(
466 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
467 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
468 | )
469 |
470 | def preprocess_function_kwargs_fn():
471 | return {
472 | "tokenizer": deepcopy(tokenizer),
473 | "prefix": prefix,
474 | "input_column": input_column,
475 | "input_prefix_column": input_prefix_column,
476 | "output_column": output_column,
477 | "max_source_length": data_args.max_source_length,
478 | "max_prefix_length": data_args.max_prefix_length,
479 | "max_target_length": max_target_length,
480 | "prefix_sep": PREFIX_DOC_SEP,
481 | "padding": padding,
482 | "ignore_pad_token_for_loss": data_args.ignore_pad_token_for_loss,
483 | "assign_zero_to_too_long_val_examples": data_args.assign_zero_to_too_long_val_examples,
484 | "trim_very_long_strings": data_args.trim_very_long_strings,
485 | "pad_prefix": data_args.pad_prefix
486 | }
487 |
488 | if training_args.do_train:
489 | if "train" not in seq2seq_dataset:
490 | raise ValueError("--do_train requires a train dataset")
491 | logger.info("")
492 | logger.info("Training examples before tokenization:")
493 | if input_prefix_column in column_names:
494 | logger.info(f"input_prefix #0: {seq2seq_dataset['train'][0][input_prefix_column]}")
495 | logger.info(f"input #0: {seq2seq_dataset['train'][0]['input']}")
496 | logger.info(f"output #0: {seq2seq_dataset['train'][0]['output']}")
497 | if input_prefix_column in column_names:
498 | logger.info(f"input_prefix #1: {seq2seq_dataset['train'][1][input_prefix_column]}")
499 | logger.info(f"input #1: {seq2seq_dataset['train'][1]['input']}")
500 | logger.info(f"output #1: {seq2seq_dataset['train'][1]['output']}")
501 | logger.info("")
502 | untokenized_train_dataset = seq2seq_dataset["train"]
503 | if data_args.max_train_samples is not None:
504 | untokenized_train_dataset = untokenized_train_dataset.select(range(data_args.max_train_samples))
505 |
506 | if DEBUG:
507 | # In debug mode, we want to recreate the data
508 | data_args.shared_storage = False
509 | data_args.overwrite_cache = True
510 | with training_args.main_process_first(
511 | local=not data_args.shared_storage, desc="train dataset map pre-processing"
512 | ):
513 | train_dataset = untokenized_train_dataset.map(
514 | preprocess_function,
515 | fn_kwargs=preprocess_function_kwargs_fn(),
516 | batched=True,
517 | num_proc=data_args.preprocessing_num_workers,
518 | remove_columns=untokenized_train_dataset.column_names,
519 | load_from_cache_file=not data_args.overwrite_cache,
520 | desc="Running tokenizer on train dataset",
521 | )
522 |
523 | if training_args.do_eval:
524 | max_target_length = data_args.val_max_target_length
525 | preprocess_function_kwargs = preprocess_function_kwargs_fn()
526 | preprocess_function_kwargs["max_target_length"] = max_target_length
527 | if "validation" not in seq2seq_dataset:
528 | raise ValueError("--do_eval requires a validation dataset")
529 | logger.info("")
530 | logger.info("Validation examples before tokenization:")
531 | if input_prefix_column in column_names:
532 | logger.info(f"input_prefix #0: {seq2seq_dataset['validation'][0][input_prefix_column]}")
533 | logger.info(f"input #0: {seq2seq_dataset['validation'][0]['input']}")
534 | logger.info(f"output #0: {seq2seq_dataset['validation'][0]['output']}")
535 | if input_prefix_column in column_names:
536 | logger.info(f"input_prefix #1: {seq2seq_dataset['validation'][1][input_prefix_column]}")
537 | logger.info(f"input #1: {seq2seq_dataset['validation'][1]['input']}")
538 | logger.info(f"output #1: {seq2seq_dataset['validation'][1]['output']}")
539 | logger.info("")
540 | untokenized_eval_dataset = seq2seq_dataset["validation"]
541 | if data_args.max_eval_samples is not None:
542 | untokenized_eval_dataset = untokenized_eval_dataset.select(range(data_args.max_eval_samples))
543 | if model_args.drop_duplicates_in_eval is True:
544 | untokenized_eval_dataset = drop_duplicates_in_input(untokenized_eval_dataset)
545 | untokenized_eval_dataset_orig = untokenized_eval_dataset
546 | assert training_args.eval_fraction > 0
547 | n = len(untokenized_eval_dataset)
548 | training_args.eval_fraction = min(training_args.eval_fraction, n)
549 | if training_args.eval_fraction != 1:
550 | if training_args.eval_fraction > 1:
551 | assert training_args.eval_fraction == int(training_args.eval_fraction)
552 | logger.info(f'using predetermined absolute samples from eval set ({training_args.eval_fraction} )')
553 | training_args.eval_fraction = training_args.eval_fraction / n
554 | indices = np.random.permutation(n)[:int(np.ceil(max(1, training_args.eval_fraction * n)))]
555 | untokenized_eval_dataset = type(untokenized_eval_dataset).from_dict(untokenized_eval_dataset[indices])
556 | logger.info(f'During training, will only use {training_args.eval_fraction:.3%} samples of the eval set '
557 | f'which amounts to {len(untokenized_eval_dataset)} out of {n} samples')
558 |
559 | eval_dataset = process_eval_set(data_args, preprocess_function_kwargs, training_args, untokenized_eval_dataset)
560 | eval_dataset_orig = eval_dataset
561 | if training_args.eval_fraction < 1:
562 | eval_dataset_orig = process_eval_set(data_args, preprocess_function_kwargs, training_args,
563 | untokenized_eval_dataset_orig)
564 |
565 | if training_args.do_predict:
566 | max_target_length = data_args.val_max_target_length
567 | preprocess_function_kwargs = preprocess_function_kwargs_fn()
568 | preprocess_function_kwargs["max_target_length"] = max_target_length
569 | if "test" not in seq2seq_dataset:
570 | raise ValueError("--do_predict requires a test dataset")
571 | untokenized_predict_dataset = seq2seq_dataset["test"]
572 | if data_args.max_predict_samples is not None:
573 | untokenized_predict_dataset = untokenized_predict_dataset.select(range(data_args.max_predict_samples))
574 | if model_args.drop_duplicates_in_eval is True:
575 | untokenized_predict_dataset = drop_duplicates_in_input(untokenized_predict_dataset)
576 |
577 | if output_column in untokenized_predict_dataset.column_names:
578 | untokenized_predict_dataset = untokenized_predict_dataset.remove_columns(output_column)
579 |
580 | if data_args.test_start_ind is not None:
581 | sind = data_args.test_start_ind
582 | eind = -1 if data_args.test_end_ind is None else data_args.test_end_ind
583 | logger.info(f'Using only a subset of the test dataset [{sind}, {eind}]')
584 | untokenized_predict_dataset = type(untokenized_predict_dataset).from_dict(untokenized_predict_dataset[sind:eind])
585 |
586 | with training_args.main_process_first(
587 | local=not data_args.shared_storage, desc="prediction dataset map pre-processing"
588 | ):
589 | predict_dataset = untokenized_predict_dataset.map(
590 | preprocess_function,
591 | fn_kwargs=preprocess_function_kwargs,
592 | batched=True,
593 | num_proc=data_args.preprocessing_num_workers,
594 | remove_columns=untokenized_predict_dataset.column_names,
595 | load_from_cache_file=not data_args.overwrite_cache,
596 | desc="Running tokenizer on prediction dataset",
597 | )
598 |
599 | if data_args.preprocess_only:
600 | logger.info(f"With --preprocess_only, exiting after preprocess_on the data")
601 | exit()
602 |
603 | # Data collator
604 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
605 | pad_to = 8 if training_args.fp16 and training_args.fp16_padding else None
606 |
607 |
608 | data_collator = DataCollatorForSeq2Seq(
609 | tokenizer,
610 | model=model,
611 | label_pad_token_id=label_pad_token_id,
612 | pad_to_multiple_of=pad_to,
613 | )
614 |
615 | # Metric
616 | compute_metrics = load_metric(data_args.metric_names, **locals())
617 |
618 | # Initialize our Trainer
619 | trainer = CustomTrainer(
620 | model=model,
621 | args=training_args,
622 | train_dataset=train_dataset if training_args.do_train else None,
623 | eval_dataset=eval_dataset if training_args.do_eval else None,
624 | untokenized_eval_dataset=untokenized_eval_dataset if training_args.do_eval else None,
625 | tokenizer=tokenizer,
626 | data_collator=data_collator,
627 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
628 | output_dir=training_args.output_dir,
629 | data_args=data_args,
630 | )
631 |
632 | # setup_cometml_trainer_callback(trainer)
633 |
634 | # Training
635 | if training_args.do_train:
636 | checkpoint = None
637 | if training_args.resume_from_checkpoint is not None:
638 | checkpoint = training_args.resume_from_checkpoint
639 | elif last_checkpoint is not None:
640 | checkpoint = last_checkpoint # look for checkpoints in the outdir
641 |
642 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
643 | logger.info('Done training')
644 | trainer.save_model() # Saves the tokenizer too for easy upload
645 |
646 | metrics = train_result.metrics
647 | max_train_samples = (
648 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
649 | )
650 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
651 |
652 | trainer.log_metrics("train", metrics)
653 | trainer.save_metrics("train", metrics)
654 | trainer.save_state()
655 |
656 | # Evaluation
657 | results = {}
658 | if training_args.do_eval:
659 | logger.info("*** Evaluate ***")
660 |
661 | if training_args.eval_fraction < 1:
662 | logger.info('setting the eval set back to the full one')
663 | trainer.eval_dataset = eval_dataset_orig
664 | trainer._untokenized_eval_dataset = untokenized_eval_dataset_orig
665 |
666 | metrics = trainer.evaluate(metric_key_prefix="eval")
667 | logger.info('Done evaluating')
668 |
669 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
670 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
671 |
672 | trainer.log_metrics("eval", metrics)
673 | trainer.save_metrics("eval", metrics)
674 |
675 | if training_args.do_predict:
676 | logger.info("*** Predict ***")
677 | trainer.args.predict_with_generate = True # during prediction, we don't have labels
678 |
679 | # load last (and best) model, or the one specified if any
680 | logger.info("*** Loading model weights before the prediction ***")
681 | last_checkpoint = model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else _detect_last_checkpoint(training_args)
682 | if last_checkpoint is not None and os.path.isdir(last_checkpoint):
683 | logger.info(f'Loading weights from {last_checkpoint} for the prediction')
684 | state_dict = torch.load(os.path.join(last_checkpoint, WEIGHTS_NAME), map_location="cpu")
685 | # If the model is on the GPU, it still works!
686 | trainer._load_state_dict_in_model(state_dict)
687 | # release memory
688 | del state_dict
689 | logger.info("*** Done loading weights ***")
690 | elif training_args.do_train:
691 | raise ValueError('Could not find a model to load for prediction')
692 | else:
693 | logger.info(f'Using {model_args.model_name_or_path} as the model for the prediction')
694 |
695 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict")
696 | logger.info('Done predicting')
697 |
698 | metrics = predict_results.metrics
699 | max_predict_samples = (
700 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
701 | )
702 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
703 |
704 | trainer.log_metrics("predict", metrics)
705 | trainer.save_metrics("predict", metrics)
706 |
707 | if trainer.is_world_process_zero():
708 | if training_args.predict_with_generate:
709 | id_to_prediction = {}
710 | for i, instance in enumerate(untokenized_predict_dataset):
711 | id_to_prediction[instance["id"]] = predict_results.predictions[i]
712 | predictions = decode(id_to_prediction, tokenizer, data_args)
713 | output_name = "generated_predictions.json"
714 | if data_args.test_start_ind is not None:
715 | output_name = f"generated_predictions_{data_args.test_start_ind}_{data_args.test_end_ind}.json"
716 | output_prediction_file = os.path.join(training_args.output_dir, output_name)
717 | with open(output_prediction_file, "w") as writer:
718 | json.dump(predictions, writer, indent=4)
719 |
720 | if training_args.push_to_hub:
721 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"}
722 | if data_args.dataset_name is not None:
723 | kwargs["dataset_tags"] = data_args.dataset_name
724 | if data_args.dataset_config_name is not None:
725 | kwargs["dataset_args"] = data_args.dataset_config_name
726 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
727 | else:
728 | kwargs["dataset"] = data_args.dataset_name
729 |
730 | trainer.push_to_hub(**kwargs)
731 |
732 | return results
733 |
734 | def _detect_last_checkpoint(training_args):
735 | last_checkpoint = None
736 | if os.path.isdir(training_args.output_dir) and training_args.do_train:
737 | if not training_args.overwrite_output_dir:
738 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
739 |
740 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
741 | logger.info(
742 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
743 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
744 | )
745 | return last_checkpoint
746 |
747 | def process_eval_set(data_args, preprocess_function_kwargs, training_args, untokenized_eval_dataset):
748 | with training_args.main_process_first(
749 | local=not data_args.shared_storage, desc="validation dataset map pre-processing"
750 | ):
751 | eval_dataset = untokenized_eval_dataset.map(
752 | preprocess_function,
753 | fn_kwargs=preprocess_function_kwargs,
754 | batched=True,
755 | num_proc=data_args.preprocessing_num_workers,
756 | remove_columns=untokenized_eval_dataset.column_names,
757 | load_from_cache_file=not data_args.overwrite_cache,
758 | desc="Running tokenizer on validation dataset",
759 | )
760 | return eval_dataset
761 |
762 |
763 | def _get_dataset(data_args, model_args, training_args):
764 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
765 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
766 | # (the dataset will be downloaded automatically from the datasets Hub).
767 | #
768 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the
769 | # summaries (unless you specify column names for this with the `input_column` and `output_column` arguments).
770 | #
771 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
772 | # download the dataset.
773 | data_files = None
774 | if data_args.train_file is not None or data_args.validation_file is not None or data_args.test_file is not None:
775 | data_files = {}
776 | if data_args.train_file is not None:
777 | data_files["train"] = data_args.train_file
778 | if data_args.validation_file is not None:
779 | data_files["validation"] = data_args.validation_file
780 | if data_args.test_file is not None:
781 | data_files["test"] = data_args.test_file
782 | # Downloading and loading a dataset from the hub/local script.
783 | seq2seq_dataset = load_dataset(
784 | data_args.dataset_name,
785 | data_args.dataset_config_name,
786 | ignore_verifications=True,
787 | cache_dir=model_args.cache_dir,
788 | data_dir=data_args.data_dir,
789 | data_files=data_files,
790 | download_mode=data_args.download_mode,
791 | use_auth_token=training_args.use_auth_token
792 | )
793 | if training_args.do_train:
794 | training_args.apply_overrides(len(seq2seq_dataset['train']))
795 | if data_args.evaluate_on_training_data:
796 | seq2seq_dataset["validation"] = seq2seq_dataset["train"]
797 |
798 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
799 | # https://huggingface.co/docs/datasets/loading_datasets.html.
800 |
801 | return seq2seq_dataset
802 |
803 |
804 | def set_up_logging(training_args):
805 | logging.basicConfig(
806 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
807 | datefmt="%m/%d/%Y %H:%M:%S",
808 | handlers=[logging.StreamHandler(sys.stdout)],
809 | )
810 | log_level = training_args.get_process_log_level()
811 | logger.setLevel(log_level)
812 | datasets.utils.logging.set_verbosity(log_level)
813 | transformers.utils.logging.set_verbosity(log_level)
814 | transformers.utils.logging.enable_default_handler()
815 | transformers.utils.logging.enable_explicit_format()
816 | # Log on each process the small summary:
817 | logger.warning(
818 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
819 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
820 | )
821 | logger.info(f"Training/evaluation parameters {training_args}")
822 |
823 |
824 | def preprocess_function(
825 | examples,
826 | tokenizer,
827 | prefix,
828 | input_column,
829 | input_prefix_column,
830 | output_column,
831 | max_source_length,
832 | max_prefix_length,
833 | max_target_length,
834 | prefix_sep,
835 | padding,
836 | ignore_pad_token_for_loss,
837 | assign_zero_to_too_long_val_examples,
838 | trim_very_long_strings,
839 | pad_prefix
840 | ):
841 | if not isinstance(examples[input_column][0], str):
842 | model_inputs = _preprocess_tokenized_inputs()
843 | else:
844 | model_inputs = _preprocess_raw_inputs(assign_zero_to_too_long_val_examples, examples, input_column, input_prefix_column,
845 | max_source_length, padding, prefix, tokenizer, trim_very_long_strings, max_prefix_length,
846 | prefix_sep, pad_prefix)
847 |
848 | _preprocess_targets(examples, ignore_pad_token_for_loss, max_target_length, model_inputs, output_column, padding, tokenizer)
849 | model_inputs["length"] = [len(x) for x in model_inputs["input_ids"]]
850 | return model_inputs
851 |
852 |
853 | def _preprocess_raw_inputs(assign_zero_to_too_long_val_examples, examples, input_column, input_prefix_column,
854 | max_source_length, padding, prefix, tokenizer, trim_very_long_strings, max_prefix_length,
855 | prefix_sep, pad_prefix):
856 | inputs = examples[input_column]
857 |
858 | # the given prefix is what is used in models like T5 (e.g. "summarize: ")
859 | # if prefix exists, it is added to the input_prefixes
860 | if input_prefix_column in examples.keys():
861 | input_prefixes = [inp + prefix_sep for inp in examples[input_prefix_column]]
862 | if prefix != "":
863 | input_prefixes = [prefix + inp for inp in input_prefixes]
864 | elif prefix != "":
865 | inputs = [prefix + inp for inp in inputs]
866 |
867 | # tokenize the input prefix if it exists
868 | model_prefix_inputs = None
869 | if input_prefix_column in examples.keys():
870 | if trim_very_long_strings:
871 | input_prefixes = [inp[: max_prefix_length * 7] for inp in input_prefixes]
872 | if pad_prefix:
873 | model_prefix_inputs = tokenizer(input_prefixes, max_length=max_prefix_length, padding='max_length', truncation=True)
874 | else:
875 | # for LED, we do not pad the prefix
876 | model_prefix_inputs = tokenizer(input_prefixes, max_length=max_source_length, padding='do_not_pad', truncation=True)
877 |
878 | if trim_very_long_strings:
879 | inputs = [inp[: max_source_length * 7] for inp in inputs]
880 | model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
881 |
882 | if max_source_length is not None and assign_zero_to_too_long_val_examples:
883 | model_inputs_untrimmed = tokenizer(inputs)
884 | model_inputs["not_valid_for_eval"] = [
885 | len(token_ids) > max_source_length for token_ids in model_inputs_untrimmed["input_ids"]
886 | ]
887 | else:
888 | model_inputs["not_valid_for_eval"] = [False] * len(model_inputs["input_ids"])
889 |
890 | # now, concatenate the prefix to the input, trimming the result to max_source_length if given
891 | if model_prefix_inputs is not None:
892 | max_source_length = max_source_length or -1
893 | model_inputs['input_ids'] = [(inp1+inp2)[:max_source_length] for inp1, inp2
894 | in zip(model_prefix_inputs['input_ids'], model_inputs['input_ids'])]
895 | model_inputs['attention_mask'] = [(inp1+inp2)[:max_source_length] for inp1, inp2
896 | in zip(model_prefix_inputs['attention_mask'], model_inputs['attention_mask'])]
897 | # add prefix_length
898 | if pad_prefix:
899 | # no need to go over them as they will all be of the same length
900 | model_inputs['prefix_length'] = [max_prefix_length] * len(model_inputs['input_ids'])
901 | else:
902 | model_inputs['prefix_length'] = [len(inp) for inp in model_prefix_inputs['input_ids']]
903 |
904 | return model_inputs
905 |
906 |
907 | def _preprocess_tokenized_inputs():
908 | raise NotImplementedError('pre-tokenized inputs are not supported yet') # TODO - implement input_prefix support here as well
909 | input_ids = []
910 | attention_mask = []
911 | for token_ids in examples[input_column]:
912 | length = len(token_ids)
913 | if max_source_length is not None:
914 | length = min(max_source_length, length)
915 | input_ids.append(token_ids[:length])
916 | attention_mask.append([1 for _ in range(length)])
917 | if len(input_ids) == 0:
918 | input_ids = examples[input_column]
919 | model_inputs = {
920 | "input_ids": input_ids,
921 | "attention_mask": attention_mask,
922 | }
923 | if max_source_length is not None:
924 | model_inputs["not_valid_for_eval"] = [
925 | len(token_ids) > max_source_length and assign_zero_to_too_long_val_examples
926 | for token_ids in examples[input_column]
927 | ]
928 | else:
929 | model_inputs["not_valid_for_eval"] = [False for token_ids in examples[input_column]]
930 | return model_inputs
931 |
932 |
933 |
934 |
935 | def _preprocess_targets(examples, ignore_pad_token_for_loss, max_target_length, model_inputs, output_column, padding, tokenizer):
936 | targets = examples[output_column] if output_column in examples else None
937 | if targets is not None:
938 | if not isinstance(targets[0], str):
939 | if max_target_length is not None:
940 | targets = [target[:max_target_length] for target in targets]
941 | model_inputs["labels"] = targets
942 | else:
943 | # Setup the tokenizer for targets
944 | with tokenizer.as_target_tokenizer():
945 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
946 |
947 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
948 | # padding in the loss.
949 | if padding == "max_length" and ignore_pad_token_for_loss:
950 | labels["input_ids"] = [
951 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
952 | ]
953 |
954 | model_inputs["labels"] = labels["input_ids"]
955 |
956 |
957 | def _mp_fn(index):
958 | # For xla_spawn (TPUs)
959 | main()
960 |
961 |
962 | if __name__ == "__main__":
963 | main()
964 |
--------------------------------------------------------------------------------
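
A minimal sketch (not part of run.py; the token ids below are invented) of what _preprocess_raw_inputs produces when an input_prefix_column is present: the tokenized prefix is concatenated in front of the tokenized document, the combined sequence is trimmed to max_source_length, and the prefix length is recorded in prefix_length so the model can treat the prefix specially.

# Illustration only - dummy token ids stand in for real tokenizer output
prefix_ids = [0, 713, 16, 10, 19066, 2]    # hypothetical tokenized question/prefix
document_ids = [0, 20920, 9, 1437, 2]      # hypothetical tokenized document
max_source_length = 8

input_ids = (prefix_ids + document_ids)[:max_source_length]
attention_mask = [1] * len(input_ids)
prefix_length = len(prefix_ids)            # forwarded to the model as `prefix_length`

print(input_ids)       # [0, 713, 16, 10, 19066, 2, 0, 20920]
print(attention_mask)  # [1, 1, 1, 1, 1, 1, 1, 1]
print(prefix_length)   # 6
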
/examples/seq2seq/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Mivg/SLED/2991895d01e5ed5871de41676a14b25c534c523a/examples/seq2seq/utils/__init__.py
--------------------------------------------------------------------------------
/examples/seq2seq/utils/config.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | def handle_args_to_ignore(args: List[str]):
5 | indices_to_remove = []
6 | for i, arg in enumerate(args):
7 | if "_ignore_" in arg:
8 | indices_to_remove.append(i)
9 | if not arg.startswith("-"):
10 | indices_to_remove.append(i - 1)
11 |
12 | for i in sorted(indices_to_remove, reverse=True):
13 | del args[i]
14 |
--------------------------------------------------------------------------------
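
A possible invocation of the helper above (the argument values and the import path are assumptions for illustration): any token containing "_ignore_" is dropped, and when that token is a value rather than a flag, the flag preceding it is dropped as well.

from utils.config import handle_args_to_ignore  # assumed import path, mirroring the other utils imports

args = ["--model_name_or_path", "facebook/bart-base", "--comment", "_ignore_this_run", "--do_train"]
handle_args_to_ignore(args)  # mutates the list in place
print(args)  # ['--model_name_or_path', 'facebook/bart-base', '--do_train']
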
/examples/seq2seq/utils/custom_hf_argument_parser.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import sys
4 | from typing import Tuple
5 |
6 | from transformers import HfArgumentParser
7 | from transformers.hf_argparser import DataClass
8 |
9 |
10 | class CustomHfArgumentParser(HfArgumentParser):
11 | def parse_dictionary_and_args(self) -> Tuple[DataClass, ...]:
12 | """
13 | Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
14 | dataclass types.
15 | """
16 | args = []
17 | data = {}
18 | for i in range(1, len(sys.argv)):
19 | if not sys.argv[i].endswith('.json'):
20 | break
21 |
22 | with open(sys.argv[i]) as f:
23 | new_data = json.load(f)
24 | conflicting_keys = set(new_data.keys()).intersection(data.keys())
25 | if len(conflicting_keys) > 0:
26 | raise ValueError(f'There are conflicting keys in the config files: {conflicting_keys}')
27 | data.update(new_data)
28 |
29 | for k, v in data.items():
30 | # if any options were given explicitly through the CLA then they override anything defined in the config files
31 | if f'--{k}' in sys.argv:
32 | logging.info(f'While {k}={v} was given in a config file, a manual override was set through the CLA')
33 | continue
34 | args.extend(
35 | ["--" + k, *(v if isinstance(v, list) else [str(v)])]
36 | ) # add the file arguments first so command line args have precedence
37 | args += sys.argv[i:]
38 |
39 | return self.parse_args_into_dataclasses(args=args, look_for_args_file=False)
--------------------------------------------------------------------------------
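
A hedged usage sketch for the parser above: leading .json config files on the command line are merged (conflicting keys raise an error) and any explicit --flag given afterwards overrides them. The dataclass and file names below are stand-ins, not part of the repository.

# e.g. invoked as: python my_script.py my_args.json --num_beams 4
# where my_args.json contains something like {"output_dir": "/tmp/out"}
from dataclasses import dataclass, field

from utils.custom_hf_argument_parser import CustomHfArgumentParser  # assumed import path


@dataclass
class DummyArgs:  # stand-in; the real script uses its own argument dataclasses
    output_dir: str = field(default="out")
    num_beams: int = field(default=1)


if __name__ == "__main__":
    (args,) = CustomHfArgumentParser(DummyArgs).parse_dictionary_and_args()
    print(args)  # DummyArgs(output_dir='/tmp/out', num_beams=4)
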
/examples/seq2seq/utils/custom_seq2seq_trainer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import os
4 | import time
5 | from collections import defaultdict
6 | from typing import Any, Dict, List, Optional, Tuple, Union
7 |
8 | import torch
9 | from datasets import Dataset
10 | from torch import nn
11 | from transformers.debug_utils import DebugOption
12 | from transformers.deepspeed import is_deepspeed_zero3_enabled
13 | from transformers.trainer_utils import speed_metrics
14 |
15 | from transformers.utils import logging
16 | from transformers import Seq2SeqTrainer, is_torch_tpu_available
17 |
18 | import gc
19 |
20 | if is_torch_tpu_available(check_device=False):
21 | import torch_xla.core.xla_model as xm
22 | import torch_xla.debug.metrics as met
23 |
24 |
25 | from utils.decoding import decode
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | def _clean_memory():
31 | gc.collect()
32 | torch.cuda.empty_cache()
33 |
34 | # This custom trainer is based on the trainer defined in https://github.com/huggingface/transformers/compare/main...eladsegal:public-transformers:scrolls
35 | class CustomTrainer(Seq2SeqTrainer):
36 | def __init__(
37 | self, *args, untokenized_eval_dataset=None, data_args=None, output_dir: Optional[str] = None, **kwargs
38 | ):
39 | super().__init__(*args, **kwargs)
40 | self._untokenized_eval_dataset = untokenized_eval_dataset
41 | self._max_length = data_args.val_max_target_length
42 | self._num_beams = data_args.num_beams
43 | self._output_dir = output_dir
44 | self._data_args = data_args
45 | self.mock_predictions_to_assign_zero_metric_score = self.tokenizer.encode("TOO_MANY_INPUT_TOKENS",return_tensors="np")[0]
46 |
47 | def prediction_step(
48 | self,
49 | model: nn.Module,
50 | inputs: Dict[str, Union[torch.Tensor, Any]],
51 | prediction_loss_only: bool,
52 | ignore_keys: Optional[List[str]] = None,
53 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
54 | """
55 | Perform an evaluation step on `model` using `inputs`.
56 |
57 | Subclass and override to inject custom behavior.
58 |
59 | Args:
60 | model (`nn.Module`):
61 | The model to evaluate.
62 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
63 | The inputs and targets of the model.
64 |
65 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
66 | argument `labels`. Check your model's documentation for all accepted arguments.
67 | prediction_loss_only (`bool`):
68 | Whether or not to return the loss only.
69 |
70 | Return:
71 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
72 | labels (each being optional).
73 | """
74 | if not ("labels" in inputs or 'decoder_input_ids' in inputs):
75 | logger.warning('When computing loss, must give labels or decoder_input_ids. '
76 | 'If you only perform prediction, you can safely ignore this message')
77 | # This is an issue here because the input may be longer than the max-output length of the model,
78 | # and if nothing was given it will shift the input and use it to compute loss (and later discard it).
79 | # This may cause an indexing error when absolute embeddings are used (CUDA device side assert)
80 | inputs['decoder_input_ids'] = inputs['input_ids'][:,:2].clone() # dummy outputs
81 |
82 | if not self.args.predict_with_generate or prediction_loss_only:
83 | return super().prediction_step(
84 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
85 | )
86 |
87 | has_labels = "labels" in inputs
88 | inputs = self._prepare_inputs(inputs)
89 |
90 | # XXX: adapt synced_gpus for fairscale as well
91 | gen_kwargs = self._gen_kwargs.copy()
92 | gen_kwargs["max_length"] = (
93 | gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.model.config.max_length
94 | )
95 | gen_kwargs["num_beams"] = (
96 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
97 | )
98 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
99 | gen_kwargs["synced_gpus"] = (
100 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
101 | )
102 |
103 | if "attention_mask" in inputs:
104 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
105 | if "global_attention_mask" in inputs:
106 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
107 |
108 | # --------------------- addition compared to the source file --------------------
109 | if 'prefix_length' in inputs:
110 | gen_kwargs['prefix_length'] = inputs['prefix_length']
111 | _clean_memory()
112 | # ------------------------------------------------------------------------------
113 |
114 | # prepare generation inputs
115 | # some encoder-decoder models can have varying encoders and thus
116 | # varying model input names
117 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
118 | generation_inputs = inputs[self.model.encoder.main_input_name]
119 | else:
120 | generation_inputs = inputs[self.model.main_input_name]
121 |
122 | generated_tokens = self.model.generate(
123 | generation_inputs,
124 | **gen_kwargs,
125 | )
126 | # --------------------- addition compared to the source file --------------------
127 | _clean_memory()
128 | # ------------------------------------------------------------------------------
129 | # in case the batch is shorter than max length, the output should be padded
130 | if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
131 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
132 |
133 | if has_labels: # changed the order of the ifs here because there is no point going through the model if there are no labels to compute the loss on
134 | with torch.no_grad():
135 | with self.compute_loss_context_manager():
136 | outputs = model(**inputs)
137 | if self.label_smoother is not None:
138 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
139 | else:
140 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
141 | else:
142 | loss = None
143 |
144 | if self.args.prediction_loss_only:
145 | return (loss, None, None)
146 |
147 | if has_labels:
148 | labels = inputs["labels"]
149 | if labels.shape[-1] < gen_kwargs["max_length"]:
150 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
151 | else:
152 | labels = None
153 |
154 | return (loss, generated_tokens, labels)
155 |
156 | @property
157 | def _restart_generator(self):
158 | if getattr(self, '_is_restart_generator', False):
159 | self._is_restart_generator = False
160 | return True
161 | return False
162 |
163 | def set_restart_generator(self):
164 | self._is_restart_generator = True
165 |
166 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
167 | sampler = super()._get_train_sampler()
168 | try:
169 | if self._restart_generator:
170 | sampler.generator.manual_seed(self._initial_seed)
171 | else:
172 | self._initial_seed = sampler.generator.initial_seed()
173 | except Exception as e:
174 | logger.warning(f'Cannot save or set the seed of the generator: {e}')
175 | return sampler
176 |
177 | def _post_process_function(self, untokenized_eval_dataset, predictions):
178 | id_to_prediction = {}
179 | id_to_label_ids = defaultdict(list)
180 |
181 | assert len(untokenized_eval_dataset) == len(self.eval_dataset)
182 |
183 | for i, (instance, not_valid_for_eval) in enumerate(zip(untokenized_eval_dataset, self.eval_dataset["not_valid_for_eval"])):
184 | if not_valid_for_eval:
185 | id_to_prediction[instance["id"]] = self.mock_predictions_to_assign_zero_metric_score
186 | else:
187 | id_to_prediction[instance["id"]] = predictions[i]
188 |
189 | if "outputs" in instance:
190 | id_to_label_ids[instance["id"]] = instance["outputs"]
191 | else:
192 | id_to_label_ids[instance["id"]].append(instance["output"])
193 |
194 | return id_to_prediction, id_to_label_ids
195 |
196 | def evaluate(
197 | self,
198 | eval_dataset: Optional[Dataset] = None,
199 | ignore_keys: Optional[List[str]] = None,
200 | metric_key_prefix: str = "eval",
201 | untokenized_eval_dataset: Optional[Dataset] = None,
202 | **gen_kwargs
203 | ) -> Dict[str, float]:
204 | """
205 | Run evaluation and returns metrics.
206 |
207 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
208 | (pass it to the init `compute_metrics` argument).
209 |
210 | You can also subclass and override this method to inject custom behavior.
211 |
212 | Args:
213 | eval_dataset (`Dataset`, *optional*):
214 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
215 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
216 | method.
217 | ignore_keys (`List[str]`, *optional*):
218 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
219 | gathering predictions.
220 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
221 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
222 | "eval_bleu" if the prefix is `"eval"` (default)
223 | max_length (`int`, *optional*):
224 | The maximum target length to use when predicting with the generate method.
225 | num_beams (`int`, *optional*):
226 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
227 | beam search.
228 | gen_kwargs:
229 | Additional `generate` specific kwargs.
230 |
231 | Returns:
232 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
233 | dictionary also contains the epoch number which comes from the training state.
234 | """
235 |
236 | gen_kwargs = gen_kwargs.copy()
237 | gen_kwargs["max_length"] = (
238 | gen_kwargs["max_length"] if gen_kwargs.get("max_length") is not None else self.args.generation_max_length
239 | )
240 | gen_kwargs["num_beams"] = (
241 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
242 | )
243 | self._gen_kwargs = gen_kwargs
244 |
245 | self._memory_tracker.start()
246 |
247 | eval_dataloader = self.get_eval_dataloader(eval_dataset)
248 | # ----------------------------------- Added -----------------------------------
249 | untokenized_eval_dataset = (
250 | self._untokenized_eval_dataset if untokenized_eval_dataset is None else untokenized_eval_dataset
251 | )
252 | compute_metrics = self.compute_metrics
253 | self.compute_metrics = None
254 | # -----------------------------------------------------------------------------
255 |
256 | start_time = time.time()
257 |
258 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
259 | try:
260 | output = eval_loop(
261 | eval_dataloader,
262 | description="Evaluation",
263 | # No point gathering the predictions if there are no metrics, otherwise we defer to
264 | # self.args.prediction_loss_only
265 | prediction_loss_only=None, # MODIFIED since we need the predictions
266 | ignore_keys=ignore_keys,
267 | metric_key_prefix=metric_key_prefix,
268 | )
269 | finally:
270 | # ----------------------------------- Added -----------------------------------
271 | # revert the compute metrics back
272 | self.compute_metrics = compute_metrics
273 | # -----------------------------------------------------------------------------
274 |
275 | # ----------------------------------- Added -----------------------------------
276 | # compute our metrics
277 | if output.predictions is not None:
278 | eval_preds = self._post_process_function(untokenized_eval_dataset, output.predictions)
279 |
280 | if self._output_dir is not None and self.is_world_process_zero():
281 | predictions = decode(eval_preds[0], self.tokenizer, self._data_args)
282 | output_prediction_file = os.path.join(
283 | self._output_dir, f"generated_predictions_eval_{self.state.global_step}.json"
284 | )
285 | with open(output_prediction_file, "w") as writer:
286 | json.dump(predictions, writer, indent=4)
287 |
288 | output_labels_file = os.path.join(
289 | self._output_dir, f"eval_labels.json"
290 | )
291 | if not os.path.isfile(output_labels_file):
292 | with open(output_labels_file, "w") as writer:
293 | json.dump(eval_preds[1], writer, indent=4)
294 |
295 | if self.compute_metrics is not None:
296 | output.metrics.update(self.compute_metrics(*eval_preds))
297 |
298 | # Prefix all keys with metric_key_prefix + '_'
299 | for key in list(output.metrics.keys()):
300 | if not key.startswith(f"{metric_key_prefix}_"):
301 | output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(key)
302 | # -----------------------------------------------------------------------------
303 |
304 | total_batch_size = self.args.eval_batch_size * self.args.world_size
305 | output.metrics.update(
306 | speed_metrics(
307 | metric_key_prefix,
308 | start_time,
309 | num_samples=output.num_samples,
310 | num_steps=math.ceil(output.num_samples / total_batch_size),
311 | )
312 | )
313 |
314 | self.log(output.metrics)
315 |
316 | if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
317 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
318 | xm.master_print(met.metrics_report())
319 |
320 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
321 |
322 | self._memory_tracker.stop_and_update_metrics(output.metrics)
323 |
324 | return output.metrics
325 |
--------------------------------------------------------------------------------
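
For reference, evaluate() above hands compute_metrics the pair produced by _post_process_function: a dict from example id to the model's prediction and a dict from example id to the list of gold outputs. The sketch below only illustrates that signature - it assumes the predictions were already decoded to strings, whereas the repository's real metrics live in metrics/metrics.py (not shown here).

def dummy_exact_match(id_to_prediction, id_to_label_ids):
    # id_to_prediction: {example_id: prediction}, id_to_label_ids: {example_id: [gold_output, ...]}
    hits = sum(
        any(str(pred).strip() == str(ref).strip() for ref in id_to_label_ids[id_])
        for id_, pred in id_to_prediction.items()
    )
    return {"exact_match": 100.0 * hits / max(len(id_to_prediction), 1)}


print(dummy_exact_match({"q1": "yes"}, {"q1": ["yes", "Yes"]}))  # {'exact_match': 100.0}
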
/examples/seq2seq/utils/decoding.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import numpy as np
3 |
4 |
5 | def decode(id_to_something, tokenizer=None, data_args=None):
6 | decode_fn = None
7 | switch_case = None
8 | elem = next(iter(id_to_something.values()))
9 | if isinstance(elem, str):
10 | switch_case = -1
11 | decode_fn = lambda text: text.strip()
12 | elif isinstance(elem, list) and not isinstance(elem[0], int):
13 | if isinstance(elem[0], str):
14 | switch_case = 0
15 | decode_fn = lambda texts: [text.strip() for text in texts]
16 | else:
17 | switch_case = 1
18 | decode_fn = lambda token_ids_list: [
19 | text.strip()
20 | for text in partial(
21 | tokenizer.batch_decode, skip_special_tokens=True, clean_up_tokenization_spaces=True
22 | )(token_ids_list)
23 | ]
24 | else:
25 | switch_case = 2
26 | decode_fn = lambda token_ids: partial(
27 | tokenizer.decode, skip_special_tokens=True, clean_up_tokenization_spaces=True
28 | )(token_ids).strip()
29 |
30 | id_to_text = {}
31 | for id_, something in id_to_something.items():
32 | if switch_case == -1 or switch_case == 0:
33 | obj_to_decode = something
34 | else:
35 | if data_args is None:
36 | data_args = {}
37 | if not isinstance(data_args, dict):
38 | data_args = vars(data_args)
39 | if data_args.get("ignore_pad_token_for_loss", True):
40 | # Replace -100 in the token_ids as we can't decode them.
41 | if switch_case == 1:
42 | token_ids_list = something
43 | for i in range(len(token_ids_list)):
44 | token_ids_list[i] = _replace_padding(token_ids_list[i], tokenizer.pad_token_id)
45 | obj_to_decode = token_ids_list
46 | elif switch_case == 2:
47 | token_ids = something
48 | token_ids = _replace_padding(token_ids, tokenizer.pad_token_id)
49 | obj_to_decode = token_ids
50 | else:
51 | obj_to_decode = something
52 |
53 | id_to_text[id_] = decode_fn(obj_to_decode)
54 |
55 | return id_to_text
56 |
57 |
58 | def _replace_padding(token_ids: np.ndarray, pad_token_id):
59 | return np.where(token_ids != -100, token_ids, pad_token_id)
60 |
--------------------------------------------------------------------------------
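
A small usage sketch for decode() above (the ids and strings are invented, and downloading the facebook/bart-base tokenizer is only needed for the token-id case): the same call handles plain strings, lists of strings, and token-id sequences.

from transformers import AutoTokenizer

from utils.decoding import decode  # assumed import path, as used by the custom trainer

tok = AutoTokenizer.from_pretrained("facebook/bart-base")

print(decode({"q1": "  an answer  "}))                            # {'q1': 'an answer'}
print(decode({"q1": ["ref one ", " ref two"]}))                   # {'q1': ['ref one', 'ref two']}
print(decode({"q1": tok("an answer").input_ids}, tokenizer=tok))  # {'q1': 'an answer'}
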
/examples/seq2seq/utils/duplicates.py:
--------------------------------------------------------------------------------
1 | def drop_duplicates_in_input(untokenized_dataset):
2 | indices_to_keep = []
3 | id_to_idx = {}
4 | outputs = []
5 | for i, (id_, output) in enumerate(zip(untokenized_dataset["id"], untokenized_dataset["output"])):
6 | if id_ in id_to_idx:
7 | outputs[id_to_idx[id_]].append(output)
8 | continue
9 | indices_to_keep.append(i)
10 | id_to_idx[id_] = len(outputs)
11 | outputs.append([output])
12 | untokenized_dataset = untokenized_dataset.select(indices_to_keep).flatten_indices()
13 | untokenized_dataset = untokenized_dataset.remove_columns("output")
14 | untokenized_dataset = untokenized_dataset.add_column("outputs", outputs)
15 | return untokenized_dataset
16 |
--------------------------------------------------------------------------------
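
A tiny illustration of drop_duplicates_in_input above (the values are invented): rows sharing an "id" are collapsed into a single row whose new "outputs" column collects all of their "output" values.

from datasets import Dataset

from utils.duplicates import drop_duplicates_in_input  # assumed import path

ds = Dataset.from_dict({
    "id": ["a", "a", "b"],
    "input": ["doc a", "doc a", "doc b"],
    "output": ["ref 1", "ref 2", "ref 3"],
})
deduped = drop_duplicates_in_input(ds)
print(deduped["id"])       # ['a', 'b']
print(deduped["outputs"])  # [['ref 1', 'ref 2'], ['ref 3']]
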
/examples/seq2seq/utils/override_training_args.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 |
5 | import torch.cuda
6 | from transformers.utils import logging
7 |
8 | sys.path.insert(0, os.getcwd())
9 |
10 | from dataclasses import dataclass, field
11 |
12 | from transformers.trainer_utils import IntervalStrategy
13 | from transformers import Seq2SeqTrainingArguments
14 |
15 | logger = logging.get_logger('swed_logger')
16 |
17 | @dataclass
18 | class TrainingOverridesArguments(Seq2SeqTrainingArguments):
19 | """
20 | To use the eval/save step overrides, evaluation_strategy must be IntervalStrategy.STEPS
21 | """
22 | eval_steps_override: float = field(default=0, metadata={"help": "a fraction, to set eval_steps w.r.t. the number of steps in "
23 | "a single epoch. Changes eval_steps. 0 to disable (default)"})
24 | save_steps_override: float = field(default=0, metadata={"help": "a fraction, to set save_steps w.r.t. the number of steps in "
25 | "a single epoch. Changes save_steps. Must be a multiple of eval_steps"
26 | " (or eval_steps_override if given). 0 to disable (default)"})
27 |
28 | eval_fraction: float = field(default=1, metadata={
29 | "help": "A float in (0,1] that corresponds to how much of the eval set to use during evaluations "
30 | "(same subset all the time) or an integer >= 2 which amounts to the absolute number of training "
31 | "samples to use. 1. to disable it and use the entire eval set "})
32 |
33 | use_auth_token: bool = field(
34 | default=False,
35 | metadata={
36 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
37 | "with private models). If AUTH_TOKEN is set as an environment variable, would use that"
38 | },
39 | )
40 |
41 | fp16_padding: bool = field(
42 | default=False,
43 | metadata={"help": "Whether to use padding for fp16"},
44 | )
45 |
46 |
47 | def __post_init__(self):
48 | super(TrainingOverridesArguments, self).__post_init__()
49 | if self.eval_steps_override > 0 or self.save_steps_override > 0:
50 | if self.evaluation_strategy != IntervalStrategy.STEPS:
51 | raise ValueError(
52 | f"using eval/save steps override requires evaluation strategy to be {IntervalStrategy.STEPS}"
53 | )
54 | if self.save_steps_override == 0 or self.eval_steps_override == 0:
55 | raise ValueError(
56 | f"using eval/save steps override requires both overrides to be non zero"
57 | )
58 | diff = (self.save_steps_override / self.eval_steps_override) % 1
59 | if min(1-diff, diff) > 1e-5: # we do it like that to support fractions modulo as well, with loss of precision
60 | raise ValueError(
61 | f"using eval/save steps override requires save steps override to be a multiple of eval_steps_override"
62 | )
63 | if self.use_auth_token and 'AUTH_TOKEN' in os.environ:
64 | self.use_auth_token = os.getenv('AUTH_TOKEN')
65 |
66 | @property
67 | def effective_batch_size(self):
68 | if not hasattr(self, '_ebs'):
69 | n_gpu = self.n_gpu if torch.cuda.is_available() else 1 # may be on cpu
70 | self._ebs = self.per_device_train_batch_size * self.gradient_accumulation_steps * n_gpu
71 | logger.warning(f'Training with {self.per_device_train_batch_size} per_device_train_batch_size, {n_gpu} gpus and '
72 | f'{self.gradient_accumulation_steps} gradient accumulation steps, resulting in {self._ebs} effective batch size')
73 | return self._ebs
74 |
75 | def apply_overrides(self, dataset_size):
76 | if self.eval_steps_override == 0:
77 | return
78 | es, ss = self.eval_steps, self.save_steps
79 | total_steps_per_epoch = dataset_size / self.effective_batch_size # note that this may not be an integer
80 | eval_steps = int(total_steps_per_epoch * self.eval_steps_override)
81 | if eval_steps >= self.logging_steps:
82 | if eval_steps % self.logging_steps != 0:
83 | logger.warning(f'Eval steps override would result in eval every {eval_steps} steps, but it is not a '
84 | f'multiple of logging steps ({self.logging_steps}) so changing to '
85 | f'{eval_steps + self.logging_steps - eval_steps % self.logging_steps}')
86 | eval_steps = eval_steps + self.logging_steps - eval_steps % self.logging_steps
87 | elif eval_steps < self.logging_steps:
88 | logger.warning(f'Eval steps override would result in eval every {eval_steps} steps, but it is not a '
89 | f'multiple of logging steps ({self.logging_steps}) so changing to {self.logging_steps}')
90 | eval_steps = self.logging_steps
91 | self.eval_steps = eval_steps
92 |
93 | save_steps = int(total_steps_per_epoch * self.save_steps_override)
94 | if save_steps < eval_steps or save_steps % eval_steps != 0:
95 | logger.warning(f'Save steps override would result in eval every {save_steps} steps, but it is not a '
96 | f'multiple of eval steps ({eval_steps}) so changing to '
97 | f'{save_steps + eval_steps - save_steps % self.eval_steps}')
98 | save_steps = save_steps + eval_steps - save_steps % self.eval_steps
99 | self.save_steps = save_steps
100 |
101 | logger.warning(f'Using overrides with dataset of size {dataset_size} and effective batch size of '
102 | f'{self.effective_batch_size}, moving from (eval_steps, save_steps) '
103 | f'of {(es, ss)} to {(self.eval_steps, self.save_steps)}')
104 |
--------------------------------------------------------------------------------
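
A worked example (all numbers invented) of the arithmetic apply_overrides performs: the overrides are fractions of an epoch, converted to absolute step counts via the effective batch size, and then bumped up to multiples of logging_steps and eval_steps respectively when needed.

# effective batch size = per-device batch size * gradient accumulation * number of GPUs
per_device_train_batch_size, gradient_accumulation_steps, n_gpu = 8, 4, 1
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * n_gpu  # 32

dataset_size, logging_steps = 6400, 10
steps_per_epoch = dataset_size / effective_batch_size  # 200.0 (may be fractional)

eval_steps_override, save_steps_override = 0.25, 0.5
eval_steps = int(steps_per_epoch * eval_steps_override)  # 50, already a multiple of logging_steps
save_steps = int(steps_per_epoch * save_steps_override)  # 100, already a multiple of eval_steps
print(eval_steps, save_steps)  # 50 100
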
/examples/usage_example.py:
--------------------------------------------------------------------------------
1 | """
2 | Minimal working example to use an X-SLED model
3 | """
4 | import torch
5 | from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
6 | # noinspection PyUnresolvedReferences
7 | import sled # *** required so that SledModels will be registered for the AutoClasses ***
8 |
9 | if __name__ == '__main__':
10 | # Load the model and tokenizer for Bart-base-SLED
11 | bart_base_sled_model = AutoModel.from_pretrained('tau/bart-base-sled')
12 | tokenizer = AutoTokenizer.from_pretrained('tau/bart-base-sled')
13 | bart_base_sled_model.eval()
14 |
15 | # The below example is for cases where there is no prefix (e.g. a question) to use, such as summarization
16 | document_input_ids = tokenizer(
17 | "Studies have been shown that owning a dog is good for you", return_tensors="pt"
18 | ) # Batch size 1
19 | with torch.no_grad():
20 | final_representations = bart_base_sled_model(**document_input_ids, return_dict=None).last_hidden_state
21 |
22 | # Now, assume we do have a prefix (for example in question answering)
23 | prefix_input_ids = tokenizer(
24 | "Is owning a dog good for you?", return_tensors="pt"
25 | ).input_ids # Batch size 1
26 |
27 | # we concatenate them together, but tell SLED where the prefix is by setting the prefix_length tensor
28 | input_ids = torch.cat((prefix_input_ids, document_input_ids.input_ids), dim=-1)
29 | attention_mask = torch.ones_like(input_ids)
30 | prefix_length = torch.LongTensor([[prefix_input_ids.size(1)]])
31 | with torch.no_grad():
32 | final_representations = bart_base_sled_model(input_ids=input_ids, attention_mask=attention_mask,
33 | prefix_length=prefix_length, return_dict=None).last_hidden_state
34 |
35 | # However, we are dealing with a generative model here (encoder-decoder), so, we can use it to generate
36 | bart_base_sled_model = AutoModelForSeq2SeqLM.from_pretrained('tau/bart-base-sled')
37 | with torch.no_grad():
38 | generated_output = bart_base_sled_model.generate(input_ids=input_ids,
39 | prefix_length=prefix_length, return_dict=None)
40 |
--------------------------------------------------------------------------------
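
The example above stops at generate(); continuing from its tokenizer and generated_output variables (a possible continuation, not part of the original file), the generated token ids can be turned back into text as usual:

generated_text = tokenizer.batch_decode(generated_output, skip_special_tokens=True)
print(generated_text)  # a list with a single decoded string (batch size 1)
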
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | # This setup file was created with the help of https://github.com/MichaelKim0407/tutorial-pip-package
3 |
4 | extra_t5 = [
5 | 'protobuf<=3.20.1',
6 | 'sentencepiece'
7 | ]
8 | extra_examples = [
9 | 'nltk',
10 | 'datasets>=1.17.0',
11 | 'absl-py',
12 | 'rouge-score',
13 | 'pandas'
14 | ]
15 |
16 | extra_dev = [
17 | *extra_t5,
18 | *extra_examples,
19 | ]
20 |
21 | setup(
22 | name='py-sled',
23 | version='0.1.7',
24 |
25 | python_requires='>=3.7.0',
26 | install_requires=[
27 | 'transformers==4.21.0',
28 | 'makefun>=1.14.0',
29 | ],
30 | extras_require={
31 | 't5': extra_t5,
32 | 'examples': extra_examples,
33 | 'dev': extra_dev
34 | },
35 | description='SLED models use pretrained, short-range encoder-decoder models, and apply them over long-text inputs '\
36 | 'by splitting the input into multiple overlapping chunks, encoding each independently and '\
37 | 'performing fusion-in-decoder',
38 |
39 | url='https://github.com/Mivg/SLED',
40 | author='Maor Ivgi',
41 | author_email='maor.ivgi@cs.tau.ac.il',
42 |
43 | packages=find_packages(exclude=("tests*", "examples*")), # also added to manifest.in due to https://stackoverflow.com/a/46320848
44 |
45 | classifiers=[
46 | 'Intended Audience :: Developers',
47 | 'Programming Language :: Python',
48 | 'Programming Language :: Python :: 3',
49 | 'License :: OSI Approved :: MIT License',
50 | 'Natural Language :: English',
51 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
52 | 'Topic :: Text Processing'
53 | ],
54 | )
55 | # build with `python setup.py sdist`
56 | # upload with `python3 -m twine upload dist/*`
--------------------------------------------------------------------------------
/sled/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.1.7'
2 |
3 | try:
4 | # noinspection PyPackageRequirements
5 | import torch
6 | except ImportError:
7 | raise ImportError('Using sled requires torch. Please refer to https://pytorch.org/get-started/locally/ '
8 | 'to install the correct version for your setup')
9 |
10 | from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM
11 |
12 | from .configuration_sled import SledConfig
13 | # noinspection PyUnresolvedReferences
14 | from .modeling_sled import SledModel, SledForConditionalGeneration, PREFIX_KEY
15 | from .tokenization_sled import SledTokenizer
16 | from .tokenization_sled_fast import SledTokenizerFast
17 |
18 | AutoConfig.register('tau/sled', SledConfig)
19 | AutoModel.register(SledConfig, SledModel)
20 | AutoModelForSeq2SeqLM.register(SledConfig, SledForConditionalGeneration)
21 | AutoTokenizer.register(SledConfig, slow_tokenizer_class=SledTokenizer, fast_tokenizer_class=SledTokenizerFast)
--------------------------------------------------------------------------------
/sled/configuration_sled.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 | from transformers import PretrainedConfig, AutoConfig
4 |
5 |
6 | class SledConfig(PretrainedConfig):
7 | r"""
8 | """
9 | model_type = "tau/sled"
10 |
11 | def __init__(self, underlying_config="facebook/bart-base", context_size=256, window_fraction=0.5,
12 | prepend_prefix=True, encode_prefix=True, sliding_method='dynamic', **kwargs):
13 | super().__init__(**kwargs)
14 | parent_only_keys = set(super().to_dict().keys())
15 | self.underlying_config = underlying_config
16 | self.context_size = context_size
17 | self.window_fraction = window_fraction
18 | self.prepend_prefix = prepend_prefix
19 | self.encode_prefix = encode_prefix
20 | self.sliding_method = sliding_method
21 |
22 | # load underlying_config
23 | config = AutoConfig.from_pretrained(underlying_config, **kwargs)
24 | # update internal dict based on the underlying config, overriding everything EXCEPT what was explicitly set here
25 | ignore_keys = set(self.to_dict().keys()).union(type(self).__dict__.keys()).difference(parent_only_keys)
26 | self.__dict__.update({k: v for k, v in config.to_dict().items() if k not in ignore_keys})
27 |
--------------------------------------------------------------------------------
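
A short construction sketch (the values mirror the defaults above and are only illustrative): a SledConfig names the backbone whose configuration gets copied in, plus SLED's own chunking parameters.

from sled import SledConfig

config = SledConfig(
    underlying_config="facebook/bart-base",  # backbone whose config fields are copied in
    context_size=256,                        # tokens per chunk fed to the backbone encoder
    window_fraction=0.5,                     # roughly, the fraction of each chunk kept as its effective window
    prepend_prefix=True,
    encode_prefix=True,
    sliding_method="dynamic",
)
print(config.model_type, config.context_size)  # tau/sled 256
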
/sled/modeling_sled.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import inspect
3 | import os
4 | import warnings
5 | from typing import Dict, Any, Optional
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from makefun import create_function
10 | from requests.exceptions import HTTPError
11 | from torch import nn
12 | from transformers import PreTrainedModel, AutoModel, AutoModelForSeq2SeqLM, WEIGHTS_NAME
13 | from transformers.generation_utils import GenerationMixin
14 | from transformers.modeling_outputs import BaseModelOutput
15 | from transformers.utils import logging
16 |
17 |
18 | from .configuration_sled import SledConfig
19 |
20 | logger = logging.get_logger(__name__)
21 |
22 |
23 | PREFIX_KEY = 'prefix_length'
24 |
25 | def _legacy_download_weights(kwargs, pretrained_model_name_or_path):
26 | # noinspection PyUnresolvedReferences
27 | from transformers.utils import hf_bucket_url, cached_path, EntryNotFoundError
28 | archive_file = hf_bucket_url(
29 | pretrained_model_name_or_path,
30 | filename=WEIGHTS_NAME,
31 | revision=kwargs.pop("revision", None),
32 | mirror=kwargs.pop("mirror", None),
33 | subfolder=kwargs.pop("subfolder", None),
34 | )
35 | logger.info(f'Looking for pretrained weights on {archive_file}')
36 |
37 | try:
38 | # Load from URL or cache if already cached
39 | user_agent = {"file_type": "model", "framework": "pytorch",
40 | "from_auto_class": kwargs.pop("_from_auto", False)}
41 | from_pipeline = kwargs.pop("_from_pipeline", None)
42 | if from_pipeline is not None:
43 | user_agent["using_pipeline"] = from_pipeline
44 | resolved_archive_file = cached_path(
45 | archive_file,
46 | cache_dir=kwargs.pop("cache_dir", None),
47 | force_download=kwargs.pop("force_download", False),
48 | proxies=kwargs.pop("proxies", None),
49 | resume_download=kwargs.pop("resume_download", False),
50 | local_files_only=kwargs.pop("local_files_only", False),
51 | use_auth_token=kwargs.pop("use_auth_token", None),
52 | user_agent=user_agent,
53 | )
54 |
55 | logger.info(f'Successfully downloaded the weights to {resolved_archive_file}')
56 | return torch.load(resolved_archive_file, map_location="cpu")
57 | except EntryNotFoundError:
58 | logger.info('Did not find a SLED weights file to be loaded.'
59 | ' If this is unexpected, please reach out. '
60 | 'Note - loading a sharded weights file is not currently supported')
61 | return None
62 |
63 | def _download_weights(kwargs, pretrained_model_name_or_path):
64 | try:
65 | from transformers.utils import cached_file, EntryNotFoundError, HUGGINGFACE_CO_RESOLVE_ENDPOINT
66 | # Load from URL
67 | user_agent = {"file_type": "model", "framework": "pytorch",
68 | "from_auto_class": kwargs.pop("_from_auto", False)}
69 | cached_filename = cached_file(
70 | pretrained_model_name_or_path,
71 | WEIGHTS_NAME, # shard_filename,
72 | cache_dir=kwargs.pop("cache_dir", None),
73 | force_download=kwargs.pop("force_download", False),
74 | proxies=kwargs.pop("proxies", None),
75 | resume_download=kwargs.pop("resume_download", False),
76 | local_files_only=kwargs.pop("local_files_only", False),
77 | use_auth_token=kwargs.pop("use_auth_token", None),
78 | user_agent=user_agent,
79 | revision=kwargs.pop("revision", None),
80 | subfolder=kwargs.pop("subfolder", None),
81 | _commit_hash=kwargs.pop("_commit_hash", None),
82 | )
83 | logger.info(f'Successfully downloaded the weights to {cached_filename}')
84 | return torch.load(cached_filename, map_location="cpu")
85 | # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
86 | # we don't have to catch them here.
87 | except ImportError:
88 | # probably an older version of transformers. try to fallback on legacy load
89 | logger.info('Could not find a weights file with the new huggingface hub api. Attempting fallback to the old api')
90 | except (OSError, EntryNotFoundError) as e:
91 | logger.info('Did not find a SLED weights file to be loaded.'
92 | ' If this is unexpected, please reach out. '
93 | f'Note - loading a sharded weights file is not currently supported ({e})')
94 | return None
95 | except HTTPError:
96 | raise EnvironmentError(
97 | f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {WEIGHTS_NAME}. You should try"
98 | " again after checking your internet connection."
99 | )
100 |
101 | try:
102 | return _legacy_download_weights(kwargs, pretrained_model_name_or_path) # attempt fallback
103 | except ImportError:
104 | logger.error('Could not find a SLED weights file to be loaded due to a mismatched transformers version. '
105 | 'Please open an issue on https://github.com/Mivg/SLED/issues')
106 | raise
107 |
108 |
109 | def _find_tensor_inds_and_size(*args, **kwargs):
110 | args_tensor_inds = [i for i, v in enumerate(args) if isinstance(v, torch.Tensor)]
111 | kwargs_tensor_keys = [k for k, v in kwargs.items() if isinstance(v, torch.Tensor)]
112 |
113 | assert len(args_tensor_inds) + len(kwargs_tensor_keys) > 0, "no tensors were found"
114 | # tensor are 2d at this point with [N, s], N is the number of inputs, s is the (padded) sequence length
115 | size = args[args_tensor_inds[0]].size() if len(args_tensor_inds) > 0 else kwargs[kwargs_tensor_keys[0]].size()
116 | assert len(size) == 2, f"expected 2-d tensors but got tensors of shape: {size}"
117 | _, s = size
118 | # can also assert that all tensors here share the same size, but we will skip this for efficiency’s sake and make
119 | # this assumption
120 |
121 | return args_tensor_inds, kwargs_tensor_keys, s
122 |
123 |
124 | def _pad_if_needed(v, pad=None):
125 | if pad is None:
126 | return v
127 | return F.pad(v, (0, pad), "constant", 0)  # pad on the right of the sequence dimension
128 |
129 |
130 | def _concat_input_prefix_if_needed(v, start_ind, end_ind, prefix_length=None, pad=None):
131 | if prefix_length is None:
132 | v = v[:, start_ind:end_ind]
133 | if pad is not None and pad != 0:
134 | # assert v.dim() == 2
135 | assert pad >= 0, f'padding should be non-negative but it is negative ({pad})'
136 | v = F.pad(v, (0, pad), "constant", 0)
137 | # pad only from the right, hence (0, pad), and only along the last axis, hence a pad tuple of length 2 rather than 4
138 | return v
139 | return torch.cat((v[:, :prefix_length], v[:, prefix_length + start_ind:prefix_length + end_ind]), axis=1)
140 |
141 |
142 | def _fix_args(args, args_tensor_inds, start_ind, end_ind, prefix_length=None, pad=None):
143 | return tuple(v if i not in args_tensor_inds else
144 | _concat_input_prefix_if_needed(v, start_ind, end_ind, prefix_length, pad)
145 | for i, v in enumerate(args))
146 |
147 |
148 | def _fix_kwargs(kwargs, kwargs_tensor_keys, start_ind, end_ind, prefix_length=None, pad=None):
149 | return {
150 | k: v if (k not in kwargs_tensor_keys or k.startswith("decoder")) else
151 | _concat_input_prefix_if_needed(v, start_ind, end_ind, prefix_length, pad)
152 | for k, v in kwargs.items()
153 | }
154 |
155 |
156 | def _stack_args(stack_args, args_tensor_inds):
157 | return tuple(v if i not in args_tensor_inds else torch.cat(tuple(si[i] for si in stack_args))
158 | for i, v in enumerate(stack_args[0]))
159 |
160 |
161 | def _stack_kwargs(stack_kwargs, kwargs_tensor_keys):
162 | try:
163 | return {k: v if k not in kwargs_tensor_keys else torch.cat(tuple(si[k] for si in stack_kwargs))
164 | for k, v in stack_kwargs[0].items()}
165 | except RuntimeError as e:
166 | for k in kwargs_tensor_keys:
167 | logger.warning(f'problematic key={k}. size={tuple(si[k].size() for si in stack_kwargs)}')
168 | if str(e).startswith('Sizes of tensors must match except in dimension 0'):
169 | logger.warning('Most likely you passed in a non-padded batch. Make sure all examples in the batch have the same length')
170 | raise
171 |
172 |
173 | def _unstack_encoder_outputs(stacked_output, n, bs):
174 | if isinstance(stacked_output, tuple):
175 | return [tuple(v if not isinstance(v, torch.Tensor) else v[i * bs:(i + 1) * bs] for v in stacked_output)
176 | for i in range(n)]
177 | # works for dict as well as structured outputs
178 | return [type(stacked_output)(**{k: v if not isinstance(v, torch.Tensor) else v[i*bs:(i+1)*bs]
179 | for k, v in stacked_output.items()})
180 | for i in range(n)]
181 |
182 |
183 | def _extract_keyword_args(kwargs, arg_names, prefix=None):
184 | new_kwargs = {arg_name: kwargs.get(arg_name, None) for arg_name in arg_names}
185 | if prefix is not None:
186 | for arg_name in arg_names:
187 | new_arg_name = prefix + arg_name
188 | if new_arg_name in kwargs and new_arg_name not in new_kwargs:
189 | new_kwargs[arg_name] = kwargs[new_arg_name]
190 | return new_kwargs
191 |
192 |
193 | def _slice_tensor(v, start, end, prefix_length=None):
194 | prefix_length = prefix_length or 0
195 | return v[:, start+prefix_length:end+prefix_length]
196 |
197 |
198 | def _merge_encoder_outputs(encoder_outputs_list):
199 | # a list of 4-tuples, first value is the returned value from the encoder, then the start and end indices inside the
200 | # tensors that we should take, and finally prefix_length (None if was not used)
201 |
202 | # presumed order of returned tuple from encoders: last_hidden_state, hidden_states, attentions
203 |
204 | # the first output, as returned by the underlying model on the first window
205 | resulting_output = encoder_outputs_list[0][0]
206 | if isinstance(resulting_output, tuple): # not in return dict mode:
207 | resulting_list = []
208 | for i in range(len(resulting_output)):
209 | if resulting_output[i] is None:
210 | resulting_list.append(None)
211 | elif (
212 | isinstance(resulting_output[i], (int, float, tuple, MockTensorForConditionalGeneration))
213 | or resulting_output[i].dim() != 3
214 | ):
215 | continue
216 | else:
217 | assert isinstance(resulting_output[i], torch.Tensor)
218 | # tensors are of size (N, w, d), N the batch size, w the current window size and d the hidden
219 | # state size/logits dimension. These are the only parts in the encoder output that we need to merge
220 | # between windows
221 | resulting_list.append(
222 | torch.cat(tuple(_slice_tensor(out[i], start, end, prefix_length)
223 | for out, start, end, prefix_length in encoder_outputs_list), dim=1)
224 | ) # requires extra GPU memory because it doesn't dump the old copy of the tensors yet
225 | resulting_output = tuple(resulting_list)
226 | else:
227 | for key in resulting_output.keys():
228 | if resulting_output[key] is None:
229 | continue
230 | if isinstance(resulting_output[key], tuple):
231 | resulting_output[key] = None # encoder outputs are not tuples, only the decoders
232 | else:
233 | assert isinstance(resulting_output[key], torch.Tensor)
234 | if resulting_output[key].dim() != 3:
235 | continue # decoder outputs may be 4d tensors
236 | # tensors are of size (N, w, d), N the batch size, w the current window size and d the hidden
237 | # state size/logits dimension
238 | resulting_output[key] = torch.cat(
239 | tuple(_slice_tensor(out[key], start, end, prefix_length)
240 | for out, start, end, prefix_length in encoder_outputs_list), dim=1
241 | )
242 |
243 | return resulting_output
244 |
245 |
246 | class MockDecoder(nn.Module):
247 | def forward(self, *_, **__):
248 | return tuple()
249 |
250 | def to(self, *_, **__):
251 | return self
252 |
253 |
254 | class SledPretrainedModel(PreTrainedModel, metaclass=abc.ABCMeta):
255 | config_class = SledConfig
256 | auto_model_loader = AutoModel
257 | IGNORE_CONFIG_KEYS = {'model_type', '_name_or_path'} # config keys we allow to be mismatched between the
258 | # SledConfig and the underlying model's config
259 |
260 | def __init__(self, underlying_model: PreTrainedModel, config: SledConfig):
261 | """
262 |
263 | :param underlying_model: The backbone model to use.
264 | Warning - once given, it should not be used directly, as it may cause unexpected behaviours
265 | :param config: A SledConfig holding the chunking and prefix-handling parameters
266 | """
267 | super(SledPretrainedModel, self).__init__(config)
268 | self._underlying_model = (
269 | underlying_model # crucial that this is set before any calls to members that live in the base model
270 | )
271 |
272 | self._window_fraction = config.window_fraction
273 | self._context_size = config.context_size
274 | self._window_margin = int(config.context_size * (1 - config.window_fraction) / 2)
275 | self._sliding_method = config.sliding_method or 'dynamic'
276 | assert self._sliding_method in {'loop', 'stacked', 'dynamic', 'decoderonly'}
277 |
278 | for override_member in ['is_parallelizable', 'supports_gradient_checkpointing']:
279 | setattr(self, override_member, getattr(underlying_model, override_member))
280 |
281 | # setting the base_model_prefix to return the correct underlying model and link to some methods
282 | # implemented in the base
283 | self.base_model_prefix = "sled_base_model_prefix"
284 | self.sled_base_model_prefix = self._underlying_model.base_model
285 |
286 | # override generation preparation functions that may be overridden by underlying models but will be
287 | # found in our wrapper. We would have liked to do it as follows:
288 | # for method_name, _ in inspect.getmembers(PreTrainedModel, predicate=inspect.isfunction):
289 | # if method_name not in {"_replicate_for_data_parallel", 'modules'}:
290 | # setattr(self, method_name, getattr(underlying_model, method_name))
291 | # However, the above is too broad and dangerous, so we will do it directly
292 | for method_name in {"_init_weights", "prepare_inputs_for_generation"}:
293 | if hasattr(underlying_model, method_name):
294 | setattr(self, method_name, getattr(underlying_model, method_name))
295 |
296 | # set the resize_token_embeddings
297 | vocab_size = underlying_model.get_input_embeddings().weight.size(0)
298 | assert hasattr(self.config, 'vocab_size'), 'Underlying models must have a vocab_size config'
299 | assert underlying_model.config.vocab_size == vocab_size
300 | self.resize_token_embeddings(vocab_size) # the underlying model may have a different vocab size compared to its base config
301 |
302 | self._verified_config_consistency = False
303 | self._verify_config_consistency()
304 | self._verified_config_consistency = False # We would like to do it later again (before the first forward)
305 |
306 | self._prepend_prefix = config.prepend_prefix
307 | self._encode_prefix = config.encode_prefix
308 |
309 | # now, let's create the forward function
310 | self._create_forward_function()
311 |
312 | @classmethod
313 | def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
314 | assert isinstance(pretrained_model_name_or_path, str), 'pretrained_model_name_or_path must be a path to a ' \
315 | 'checkpoint or a local config file (json)'
316 | config = kwargs.pop('config', None) or \
317 | cls.config_class.from_pretrained(pretrained_model_name_or_path,
318 | use_auth_token=kwargs.get('use_auth_token', False))
319 | state_dict = kwargs.pop('state_dict', None)
320 | if os.path.exists(pretrained_model_name_or_path):
321 |             # pretrained_model_name_or_path is a local path (either a config json or a saved checkpoint)
322 | if pretrained_model_name_or_path.endswith('json'):
323 | underlying_model = cls.auto_model_loader.from_pretrained(config.underlying_config, *model_args, **kwargs)
324 | else:
325 |                 # otherwise, it is a saved checkpoint directory (config + weights). Note SLED doesn't have any
326 |                 # weights of its own, so the state dict is only for the underlying model
327 | backbone_state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), map_location="cpu")
328 | underlying_model = cls.auto_model_loader.from_pretrained(config.underlying_config, *model_args,
329 | **kwargs)
330 | model = cls(underlying_model, config)
331 | cls._load_state_dict(backbone_state_dict, model, underlying_model)
332 | return model
333 |
334 | else:
335 |             # assume it is a model card on the hub
336 | underlying_model = cls.auto_model_loader.from_pretrained(config.underlying_config, *model_args, **kwargs)
337 |
338 | state_dict = cls._load_remote_state_dict(kwargs, pretrained_model_name_or_path, state_dict)
339 |
340 | sled_model = cls(underlying_model, config)
341 | if state_dict is not None:
342 | logger.info('Updating SLED model weights from state dict')
343 | cls._load_state_dict(state_dict, sled_model)
344 | return sled_model
345 |
346 | @classmethod
347 | def _load_state_dict(cls, backbone_state_dict, model, underlying_model=None):
348 | load_result = model.load_state_dict(backbone_state_dict, strict=False)
349 | # Known issue - when loading a model checkpoint of type AutoModelForSeq2SeqLM with AutoModel,
350 | # the state dict will not be loaded correctly. Documented [here](https://github.com/Mivg/SLED/issues/4)
351 | if len(load_result.missing_keys) != 0:
352 | if underlying_model is not None and \
353 | model._keys_to_ignore_on_save is not None and \
354 | set(load_result.missing_keys) == set(model._keys_to_ignore_on_save):
355 | underlying_model.tie_weights()
356 | else:
357 | logger.warn(
358 | f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
359 | if len(load_result.unexpected_keys) != 0:
360 | logger.warn(
361 | f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.")
362 |
363 | @classmethod
364 | def _load_remote_state_dict(cls, kwargs, pretrained_model_name_or_path, state_dict):
365 | if state_dict is None:
366 | state_dict = _download_weights(kwargs, pretrained_model_name_or_path)
367 | return state_dict
368 |
369 | def _verify_config_consistency(self):
370 | if not self._verified_config_consistency:
371 | # SLED models are built in one of the following ways:
372 | # 1. Explicitly (given a SledConfig and an underlying model)
373 | # 2. from_pretrained (local_config, hub config, saved checkpoint)
374 |             # There are 2 cases where the underlying_config and the SledConfig may mismatch - case 1, and case 2 with a saved checkpoint
375 | # Instead of deciding whether to update the underlying model's config or vice versa while we cannot know which
376 | # one is correct, it is better to raise an exception. The only key we were willing to tolerate is the
377 | # vocab size, and we dealt with it above
378 |             # Note - we will only do it once, on the first forward pass. Setting a config in the model after it
379 |             # has been used is not advisable
380 | config_dict = self.config.to_dict()
381 | underlying_config_dict = self.underlying_model.config.to_dict()
382 | matching_keys = set(config_dict.keys()).intersection(underlying_config_dict.keys()).difference(self._ignore_keys)
383 | inconsistent_keys = {k: (config_dict[k], underlying_config_dict[k])
384 | for k in matching_keys if config_dict[k] != underlying_config_dict[k]}
385 |
386 | if len(inconsistent_keys) > 0:
387 | # raise ValueError(f'SledConfig and the underlying_model config has mismatching configurations on: {inconsistent_keys}')
388 | # if we loaded the config with overrides, there may still be some conflicts
389 | logger.warning(
390 |                     f'SledConfig and the underlying_model config have mismatching configurations on: {inconsistent_keys}, '
391 |                     f'probably due to config_overrides. Setting the underlying config to match SLED\'s')
392 | for k in inconsistent_keys:
393 | setattr(self.underlying_model.config, k, getattr(self.config, k))
394 |
395 | config = self.config
396 | if self._window_fraction != config.window_fraction or \
397 | self._context_size != config.context_size or\
398 | self._window_margin != int(config.context_size * (1 - config.window_fraction) / 2) or \
399 | self._sliding_method != config.sliding_method:
400 | raise RuntimeError('SLED does not support changing its configuration after it is initialized. '
401 | 'Try reloading the model with overrides')
402 |
403 | self._verified_config_consistency = True
404 |
405 | @property
406 | def underlying_model(self):
407 | return self._underlying_model
408 |
409 | def resize_token_embeddings(self, new_num_tokens=None):
410 | res = self.underlying_model.resize_token_embeddings(new_num_tokens)
411 | self.config.vocab_size = self.underlying_model.vocab_size # sync them
412 | return res
413 |
414 | @property
415 | def _ignore_keys(self):
416 | return self.IGNORE_CONFIG_KEYS
417 |
418 | def _replicate_for_data_parallel(self):
419 | replica = super()._replicate_for_data_parallel()
420 | replica.forward = create_function(self._signature, replica._forward)
421 | return replica
422 |
423 | def _create_forward_function(self):
424 | # https://stackoverflow.com/a/15638720
425 | self._underlying_model_signature = inspect.signature(self._underlying_model.forward)
426 | self._forward_kwargs_names = [param.name for param in self._underlying_model_signature.parameters.values()]
427 | assert PREFIX_KEY not in self._forward_kwargs_names
428 | # if we want to prepend questions in every window, we need to set the forward signature to expect the
429 | # input_prefix (e.g. question) as a separate input sequence
430 |
431 |         # we want to remove any typing information as it may cause issues in the custom function build due to
432 |         # non-imported modules. It is ugly and shouldn't be done like this, but it works.
433 | params = [self._underlying_model_signature.parameters[p].replace(annotation=inspect.Parameter.empty)
434 | for p in self._underlying_model_signature.parameters]
435 | params.append(inspect.Parameter(name=PREFIX_KEY, default=None, kind=params[-1].kind))
436 | self._signature = str(self._underlying_model_signature.replace(parameters=params,
437 | return_annotation=inspect.Signature.empty))
438 |
439 | # HF trainer uses the signature to choose which parts to take from a dataset, so we need to make sure our
440 | # wrapped forward function has the correct signature (dynamically creating it here)
441 | self.forward = create_function(self._signature, self._forward)
442 |
443 | def __getattr__(self, item):
444 | try:
445 | return super().__getattr__(item)
446 | except AttributeError:
447 | try:
448 | return self._underlying_model.__getattribute__(item)
449 | except AttributeError:
450 | return self._underlying_model.__getattr__(item)
451 |
452 | @abc.abstractmethod
453 | def _forward(self, *args, **kwargs):
454 | # the actual forward implementation of the model.
455 | raise NotImplementedError
456 |
457 | def _set_underlying_model_attr(self, attr_name, new_val):
458 | if hasattr(self._underlying_model, attr_name):
459 | setattr(self._underlying_model, attr_name, new_val)
460 | elif hasattr(self._underlying_model, "model") and hasattr(self._underlying_model.model, attr_name):
461 | setattr(self._underlying_model.model, attr_name, new_val)
462 | else:
463 | raise ValueError(f"Cannot use this model as we cannot set its {attr_name}")
464 |
465 | def _run_sliding_window_forward(self, args_tensor_inds, kwargs_tensor_keys, s, *args,
466 | prefix_length=None, **kwargs):
467 | sm = self._sliding_method if self._sliding_method != 'dynamic' else \
468 | ('loop' if not self.training else 'stacked')
469 | try:
470 | if sm == 'decoderonly':
471 | return self._skip_forward_for_decoder_only(args_tensor_inds, kwargs_tensor_keys, s, *args,
472 | prefix_length=prefix_length, **kwargs)
473 | if sm == 'loop':
474 | return self._run_sliding_window_forward_loop(args_tensor_inds, kwargs_tensor_keys, s, *args,
475 | prefix_length=prefix_length, **kwargs)
476 | return self._run_sliding_window_forward_stacked(args_tensor_inds, kwargs_tensor_keys, s, *args,
477 | prefix_length=prefix_length, **kwargs)
478 | finally:
479 | # so that if the model crashes halfway through it will be restored to working order
480 | pass
481 |
482 | def _skip_forward_for_decoder_only(self, args_tensor_inds, kwargs_tensor_keys, s, *args,
483 | prefix_length=None, **kwargs):
484 |         # NOTE - this will probably work only with BART.
485 | embeder = self if hasattr(self, 'embed_tokens') else self.get_encoder() # account for sled encoder
486 | return (embeder.embed_tokens(kwargs['input_ids']), )
487 |
488 |
489 | def _run_sliding_window_forward_loop(self, args_tensor_inds, kwargs_tensor_keys, s, *args,
490 | prefix_length=None, **kwargs):
491 | forward_kwargs = _extract_keyword_args(kwargs, self._forward_kwargs_names, None)
492 | encoder_outputs_list = []
493 | if prefix_length is not None and self._prepend_prefix:
494 |             # we were given prefixes in the input, and we are expected to handle them
495 | prefix_length, s = self._handle_prefix(prefix_length, s)
496 |
497 | if self._encode_prefix:
498 | # encode the question as well, if needed
499 | context_start_ind, context_end_ind, update_start_ind, update_end_ind = 0, prefix_length, 0, prefix_length
500 |
501 | encoder_outputs = self._underlying_model.forward(
502 | *_fix_args(args, args_tensor_inds, context_start_ind, context_end_ind, None),
503 | **_fix_kwargs(forward_kwargs, kwargs_tensor_keys, context_start_ind, context_end_ind, None),
504 | )
505 | encoder_outputs_list.append((encoder_outputs, update_start_ind, update_end_ind, None))
506 | # we will need to make sure all input tensors will also drop everything with the prefix
507 | else:
508 | prefix_length = None # we need to ignore the prefix and treat the entire input as one long document
509 |
510 | for context_start_ind, context_end_ind, update_start_ind, update_end_ind in self._window_indices(s):
511 | encoder_outputs = self._underlying_model.forward(
512 | *_fix_args(args, args_tensor_inds, context_start_ind, context_end_ind, prefix_length),
513 | **_fix_kwargs(forward_kwargs, kwargs_tensor_keys, context_start_ind, context_end_ind, prefix_length),
514 | )
515 | encoder_outputs_list.append((encoder_outputs, update_start_ind, update_end_ind, prefix_length))
516 |
517 | return _merge_encoder_outputs(encoder_outputs_list)
518 |
519 | def _handle_prefix(self, prefix_length, s):
520 | prefix_length_ = prefix_length[0].detach().cpu().item()
521 | assert torch.all(prefix_length == prefix_length_).item(), \
522 | 'Using different length prefixes in the same batch is not supported. Either group your batch by ' \
523 | 'prefix length, or pad the prefixes to match in length (and do not forget to set the attention ' \
524 | 'mask to 0 where appropriate)'
525 | if hasattr(self.underlying_model.config, 'max_position_embeddings'):
526 | assert self._context_size + prefix_length_ <= self.underlying_model.config.max_position_embeddings, \
527 |                 'The prefix length + the SLED chunk size must be at most the max length that the backbone model can handle'
528 | return prefix_length_, s-prefix_length_
529 |
530 | def _run_sliding_window_forward_stacked(self, args_tensor_inds, kwargs_tensor_keys, s, *args,
531 | prefix_length=None, **kwargs):
532 | forward_kwargs = _extract_keyword_args(kwargs, self._forward_kwargs_names, None)
533 | stacks_args = []
534 | stacks_kwargs = []
535 | stacks_info = []
536 |
537 | if prefix_length is not None and self._prepend_prefix:
538 |             # we were given prefixes in the input, and we are expected to handle them
539 | prefix_length, s = self._handle_prefix(prefix_length, s)
540 |
541 | if self._encode_prefix:
542 | # encode the question as well, if needed
543 | context_start_ind, context_end_ind, update_start_ind, update_end_ind = 0, prefix_length, 0, prefix_length
544 | # need to pad it to match the seq len of the rest
545 |                 # samples may also be too short, so we don't want to pad too much
546 | pad = min(s, self._context_size)
547 | assert pad >= 0, f'We have a weird situation. pad={pad}, s={s}, ' \
548 | f'prefix_length={prefix_length} and self._context_size={self._context_size}'
549 | stacks_args.append(_fix_args(args, args_tensor_inds, context_start_ind, context_end_ind, None, pad))
550 | stacks_kwargs.append(_fix_kwargs(forward_kwargs, kwargs_tensor_keys, context_start_ind,
551 | context_end_ind, None, pad))
552 | stacks_info.append([None, update_start_ind, update_end_ind, None])
553 | else:
554 | prefix_length = None # we need to ignore the prefix and treat the entire input as one long document
555 |
556 | for context_start_ind, context_end_ind, update_start_ind, update_end_ind in self._window_indices(s):
557 | stacks_args.append(_fix_args(args, args_tensor_inds, context_start_ind, context_end_ind, prefix_length))
558 | stacks_kwargs.append(_fix_kwargs(forward_kwargs, kwargs_tensor_keys, context_start_ind,
559 | context_end_ind, prefix_length))
560 | stacks_info.append([None, update_start_ind, update_end_ind, prefix_length])
561 |
562 | encoder_outputs2 = self._underlying_model.forward(
563 | *_stack_args(stacks_args, args_tensor_inds),
564 | **_stack_kwargs(stacks_kwargs, kwargs_tensor_keys))
565 | bs = forward_kwargs[kwargs_tensor_keys[0]].size()[0] if len(kwargs_tensor_keys) > 0 else \
566 | args[args_tensor_inds[0]].size()[0]
567 | for si, eo in zip(stacks_info, _unstack_encoder_outputs(encoder_outputs2, len(stacks_info), bs)):
568 | si[0] = eo
569 | res = _merge_encoder_outputs(stacks_info)
570 |
571 | return res
572 |
573 | def _window_indices(self, total_seq_len):
574 | """
575 | when total_seq_len is smaller than our desired context length, we do not do sliding window at all.
576 | However, if it is longer, then we ALWAYS require the context length to be maximal, even if some windows have
577 | a lot of overlap.
578 | Also, first window will always update from the start, and last window will always update until the end.
579 |         when applied, returns a generator that in each iteration produces four numbers:
580 | context_start_ind, context_end_ind, update_start_ind, update_end_ind
581 |
582 | context_start_ind, context_end_ind are indices in [0, total_seq_len],
583 |         where context_end_ind > context_start_ind and when
584 |         total_seq_len > context_length then always context_end_ind = context_start_ind+context_length.
585 |         The sequence of context_start_ind is strictly monotonic and the same holds for context_end_ind.
586 |         The first context_start_ind is always 0 and
587 |         the last context_end_ind is always total_seq_len.
588 | Gives us what token indices to take from the long input.
589 |
590 | update_start_ind, update_end_ind are indices in [0, min(total_seq_len, context_length)],
591 | where update_end_ind > update_start_ind
592 | and for all windows that are not in the edges (i.e. first/last window) we have
593 | update_end_ind-update_start_ind=context_length*window_fraction.
594 | For first window update_start_ind is always 0, and for last window,
595 | update_end_ind is always min(total_seq_len, context_length).
596 |         They represent the start and end indices, within the selected window, of the
597 |         tokens that should be taken out for the final encoding
598 |
599 |         When doing a full iteration, accounting for the fact that
600 |         update_start_ind, update_end_ind are shifted by context_start_ind, we should get that all indices in
601 |         [0, total_seq_len) are covered exactly once
602 |
603 | Examples
604 | >>> from transformers import T5Tokenizer, T5Model
605 | >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
606 | >>> model_ = T5Model.from_pretrained('t5-small')
607 |         >>> model = SledModel(model_, SledConfig('t5-small', context_size=512, window_fraction=0.5)) # testing with padding of 50% and context of 512
608 |
609 | >>> list(model._window_indices(256)) # List of: (context_start, context_end, update_start, update_end). short sequence
610 | [(0, 256, 0, 256)]
611 | >>> list(model._window_indices(510)) # another short sequence
612 | [(0, 510, 0, 510)]
613 | >>> list(model._window_indices(512)) # sequence of exactly the context size
614 | [(0, 512, 0, 512)]
615 | >>> list(model._window_indices(514)) # sequence of slightly more than the context size
616 | [(0, 512, 0, 384), (2, 514, 382, 512)]
617 | >>> list(model._window_indices(766)) # long sequence that does not require a full stride (update in the last chunk is smaller than what is possible)
618 | [(0, 512, 0, 384), (254, 766, 130, 512)]
619 | >>> list(model._window_indices(768)) # long sequence for exactly two perfect chunks
620 | [(0, 512, 0, 384), (256, 768, 128, 512)]
621 | >>> list(model._window_indices(780)) # very long sequence that does not require a full stride (update in the last chunk is smaller than what is possible)
622 | [(0, 512, 0, 384), (256, 768, 128, 384), (268, 780, 372, 512)]
623 | >>> windows = list(model._window_indices(1050))
624 | >>> windows
625 | [(0, 512, 0, 384), (256, 768, 128, 384), (512, 1024, 128, 384), (538, 1050, 358, 512)]
626 | >>> windows = sum([list(range(us+cs, ue+cs)) for cs, _, us, ue in windows], []) # verify it covers exactly all the indices, each once
627 | >>> windows[:10]
628 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
629 | >>> windows[500:510]
630 | [500, 501, 502, 503, 504, 505, 506, 507, 508, 509]
631 | >>> len(windows)
632 | 1050
633 | >>> len(set(windows))
634 | 1050
635 |         >>> model = SledModel(model_, SledConfig('t5-small', context_size=256, window_fraction=0.75)) # now testing with padding of 25% and context of 256
636 |
637 | >>> list(model._window_indices(128)) # List of: (context_start, context_end, update_start, update_end). short sequence
638 | [(0, 128, 0, 128)]
639 | >>> list(model._window_indices(254)) # another short sequence
640 | [(0, 254, 0, 254)]
641 | >>> list(model._window_indices(256)) # sequence of exactly the context size
642 | [(0, 256, 0, 256)]
643 | >>> list(model._window_indices(258)) # sequence of slightly more than the context size. margin is 256/8 -> 32
644 | [(0, 256, 0, 224), (2, 258, 222, 256)]
645 | >>> list(model._window_indices(446)) # long sequence that does not require a full stride (update in the last chunk is smaller than what is possible). stride should be 256-64=192
646 | [(0, 256, 0, 224), (190, 446, 34, 256)]
647 | >>> list(model._window_indices(448)) # long sequence for exactly two perfect chunks
648 | [(0, 256, 0, 224), (192, 448, 32, 256)]
649 | >>> list(model._window_indices(500)) # very long sequence that does not require a full stride (update in the last chunk is smaller than what is possible)
650 | [(0, 256, 0, 224), (192, 448, 32, 224), (244, 500, 172, 256)]
651 | >>> windows = list(model._window_indices(1050))
652 | >>> windows
653 | [(0, 256, 0, 224), (192, 448, 32, 224), (384, 640, 32, 224), (576, 832, 32, 224), (768, 1024, 32, 224), (794, 1050, 198, 256)]
654 | >>> windows = sum([list(range(us+cs, ue+cs)) for cs, _, us, ue in windows], []) # verify it covers exactly all the indices, each once
655 | >>> windows[:10]
656 | [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
657 | >>> windows[500:510]
658 | [500, 501, 502, 503, 504, 505, 506, 507, 508, 509]
659 | >>> len(windows)
660 | 1050
661 | >>> len(set(windows))
662 | 1050
663 | """
664 | if total_seq_len <= self._context_size:
665 | yield 0, total_seq_len, 0, total_seq_len
666 | else:
667 | stride = self._context_size - 2 * self._window_margin
668 | context_start = update_start_ind = 0
669 | context_end = self._context_size
670 | update_end_ind = context_end - self._window_margin
671 | yield context_start, context_end, update_start_ind, update_end_ind # first window always should update from the beginning
672 | while context_end < total_seq_len:
673 | context_end = min(total_seq_len, context_end + stride)
674 | context_start = (
675 | context_start + stride if context_end < total_seq_len else total_seq_len - self._context_size
676 | )
677 | update_start_ind = max(update_start_ind + stride, update_end_ind)
678 | # last window always should update until the end
679 | update_end_ind = (
680 | min(total_seq_len, update_end_ind + stride) if context_end < total_seq_len else total_seq_len
681 | )
682 |
683 | cs, ce, us, ue = context_start, context_end, update_start_ind - context_start, \
684 | update_end_ind - context_start
685 |
686 | yield cs, ce, us, ue
687 |
688 | def _fill_prefix_inputs(self, kwargs, kwargs_tensor_keys):
689 | prefix_inputs = {}
690 | k = PREFIX_KEY
691 | if PREFIX_KEY in kwargs:
692 | if self._prepend_prefix:
693 | if k not in kwargs_tensor_keys:
694 | warnings.warn(f'{k} is missing from kwargs_tensor_keys (though expected for SLED prefix prepending)')
695 | else:
696 | kwargs_tensor_keys.remove(k)
697 | prefix_inputs[k] = kwargs.pop(k)
698 | elif k in kwargs_tensor_keys:
699 | warnings.warn(f'{k} is given in kwargs_tensor_keys even though sled should not prepend prefix, '
700 |                           f'which means the prefix would be ignored and the entire input would be treated '
701 | f'as a single long document, which is probably not what you meant')
702 | return prefix_inputs
703 |
704 | @staticmethod
705 | def _prep_attention_mask_for_cross_attention(encode_prefix, attention_mask, prefix_length=None):
706 | # if we need to drop the prefix encodings, we also need to adjust the attention mask before decoding
707 | if not encode_prefix and prefix_length is not None:
708 | prefix_length = int(prefix_length[0])
709 | return attention_mask[..., prefix_length:]
710 | return attention_mask
711 |
712 |
713 | class SledModel(SledPretrainedModel):
714 | """
715 | >>> from transformers import T5Tokenizer, T5Model, BartModel, BartTokenizer
716 |
717 | >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
718 | >>> model_ = T5Model.from_pretrained('t5-small')
719 |     >>> model = SledModel(model_, SledConfig('t5-small', context_size=4))
720 |
721 | >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
722 | >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
723 | >>> outputs = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone())
724 | >>> outputs = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True)
725 | >>> outputs = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=False)
726 | """
727 |
728 | def __init__(self, underlying_model: PreTrainedModel, config: SledConfig):
729 | super(SledModel, self).__init__(underlying_model, config)
730 | # validate the model can be used
731 | self._decoder_attr_name = getattr(underlying_model, "get_decoder_attr_name", lambda: "decoder")()
732 | self._encoder_attr_name = getattr(underlying_model, "get_encoder_attr_name", lambda: "encoder")()
733 | self._set_underlying_model_attr(self._decoder_attr_name, self.get_decoder())
734 | self._mock_decoder = MockDecoder()
735 | assert "return_dict" in self._forward_kwargs_names
736 | assert "encoder_outputs" in self._forward_kwargs_names
737 |
738 | def _forward(self, *args, **kwargs):
739 | self._verify_config_consistency()
740 | kwargs, args = _fill_kwargs_with_args(self._forward_kwargs_names, *args, **kwargs)
741 | kwargs.setdefault("encoder_outputs", None)
742 | return_dict = kwargs.setdefault("return_dict", None)
743 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
744 | labels = kwargs.get("labels", None)
745 | kwargs["labels"] = None
746 | kwargs["return_dict"] = False
747 | kwargs.setdefault("labels", None)
748 | args_tensor_inds, kwargs_tensor_keys, s = _find_tensor_inds_and_size(*args, **kwargs)
749 | prefix_inputs = self._fill_prefix_inputs(kwargs, kwargs_tensor_keys)
750 |
751 | forward_kwargs = _extract_keyword_args(kwargs, self._forward_kwargs_names, None)
752 | if forward_kwargs["encoder_outputs"] is None:
753 | # encode, but first let's set decoder to be a mock, no reason to apply it over partial windows
754 |             self._prep_for_encoding()  # todo - add try/except every time we 'prep' something to revert the state on failure?
755 |
756 | forward_kwargs["encoder_outputs"] = self._run_sliding_window_forward(
757 | args_tensor_inds, kwargs_tensor_keys, s, *args, **prefix_inputs, **forward_kwargs
758 | )
759 | forward_kwargs['attention_mask'] = self._prep_attention_mask_for_cross_attention(self._encode_prefix,
760 | forward_kwargs['attention_mask'], prefix_inputs.get('prefix_length', None))
761 |
762 | # now, let's decode
763 | forward_kwargs["return_dict"] = return_dict
764 | forward_kwargs["labels"] = labels
765 | self._fix_post_encoding()
766 | if 'decoder_input_ids' in self._forward_kwargs_names and \
767 | forward_kwargs.get('decoder_input_ids', None) is None and \
768 | hasattr(self, 'prepare_decoder_input_ids_from_labels') :
769 | logger.warning('Passing a batch through the model without the decoder_input_ids is likely to cause issues. '
770 | 'If you encounter cuda errors, make sure you use the prepare_decoder_input_ids_from_labels '
771 | 'function of the model correctly before passing the input. '
772 | 'If you are only performing prediction without training, you can safely ignore this message')
773 | res = self._underlying_model.forward(
774 | *args, **_extract_keyword_args(forward_kwargs, self._forward_kwargs_names)
775 | )
776 |
777 | return res
778 |
779 | def _prep_for_encoding(self):
780 | if not getattr(self, '_preped_for_encoding', False):
781 | self._preped_for_encoding = True
782 | self._decoder = self.get_decoder()
783 | self._mock_decoder.first_device = getattr(self._decoder, "first_device", None)
784 | self._set_underlying_model_attr(self._decoder_attr_name, self._mock_decoder)
785 |
786 | def _fix_post_encoding(self):
787 | assert self._preped_for_encoding
788 | self._preped_for_encoding = False
789 | self._set_underlying_model_attr(self._decoder_attr_name, self._decoder)
790 |
791 |
792 | class MockTensorForConditionalGeneration:
793 | def __add__(self, other):
794 | return tuple()
795 |
796 | def __mul__(self, other):
797 | return tuple()
798 |
799 | def to(self, *_, **__):
800 | return self
801 |
802 |
803 | class MockDecoderForConditionalGeneration(nn.Module):
804 | pad_value = 0
805 |
806 | def forward(self, *_, **__):
807 | return (MockTensorForConditionalGeneration(),)
808 |
809 | def to(self, *_, **__):
810 | return self
811 |
812 |
813 | class MockLMHeadForConditionalGeneration(nn.Module):
814 | def forward(self, *_, **__):
815 | return MockTensorForConditionalGeneration()
816 |
817 | def to(self, *_, **__):
818 | return self
819 |
820 | def __getattr__(self, item):
821 | try:
822 | return super().__getattr__(item)
823 | except AttributeError:
824 | return self
825 |
826 |
827 | def _fill_kwargs_with_args(forward_param_names, *args, **kwargs):
828 | kwargs.update({arg_name: arg_value for arg_name, arg_value in zip(forward_param_names, args)})
829 | return kwargs, tuple()
830 |
831 |
832 | class SledEncoder(SledPretrainedModel):
833 | def __init__(self, underlying_model: PreTrainedModel, config: SledConfig):
834 | super(SledEncoder, self).__init__(underlying_model, config)
835 |
836 | @property
837 | def _ignore_keys(self):
838 | return super()._ignore_keys | {'use_cache', 'is_encoder_decoder'}
839 | # Encoder models are not encoder-decoder but are meant for internal use by the SLED models
840 |
841 |
842 | def _forward(self, *args, **kwargs):
843 | kwargs, args = _fill_kwargs_with_args(self._forward_kwargs_names, *args, **kwargs)
844 | args_tensor_inds, kwargs_tensor_keys, s = _find_tensor_inds_and_size(*args, **kwargs)
845 | prefix_inputs = self._fill_prefix_inputs(kwargs, kwargs_tensor_keys)
846 | return self._run_sliding_window_forward(args_tensor_inds, kwargs_tensor_keys, s, *args, **prefix_inputs,
847 | **kwargs)
848 |
849 | def _skip_forward_for_decoder_only(self, args_tensor_inds, kwargs_tensor_keys, s, *args,
850 | prefix_length=None, **kwargs):
851 | encoder_outputs = super()._skip_forward_for_decoder_only(args_tensor_inds, kwargs_tensor_keys, s,
852 |                                                                   *args, prefix_length=prefix_length, **kwargs)
853 | return BaseModelOutput(encoder_outputs[0])
854 |
855 |
856 | class SledForConditionalGeneration(SledModel):
857 | """
858 | >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
859 |
860 | >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
861 | >>> model_ = T5ForConditionalGeneration.from_pretrained('t5-small')
862 | >>> shared = model_.shared
863 |     >>> model = SledForConditionalGeneration(model_, SledConfig('t5-small', context_size=4))
864 | >>> model._underlying_model == model_ # make sure the decoration works
865 | >>> model.shared == shared # make sure the decoration works
866 | >>> input_ids = tokenizer('The walks in park', return_tensors='pt').input_ids
867 | >>> labels = tokenizer(' cute dog the ', return_tensors='pt').input_ids
868 | >>> outputs = model(input_ids=input_ids, labels=labels)
869 | >>> outputs = model(input_ids=input_ids, labels=labels, return_dict=None)
870 | >>> outputs = model(input_ids=input_ids, labels=labels, return_dict=False)
871 | >>> outputs = model(input_ids=input_ids, labels=labels, return_dict=True)
872 | >>> loss = outputs.loss
873 | >>> logits = outputs.logits
874 | >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1
875 | >>> outputs = model.generate(input_ids)
876 | """
877 | OVERRIDDEN_METHODS = {'generate'}
878 | auto_model_loader = AutoModelForSeq2SeqLM
879 |
880 | def __init__(self, underlying_model: PreTrainedModel, config: SledConfig):
881 | super(SledForConditionalGeneration, self).__init__(underlying_model, config)
882 | self._mock_decoder = MockDecoderForConditionalGeneration()
883 | self._base_encoder = self.get_encoder()
884 | self._sled_encoder = SledEncoder(self._base_encoder, config)
885 |
886 | # override generation preparation functions that may be overridden by underlying models but will be found in our wrapper
887 | for method_name, _ in inspect.getmembers(GenerationMixin(), predicate=inspect.ismethod):
888 | if method_name not in SledForConditionalGeneration.OVERRIDDEN_METHODS:
889 | setattr(self, method_name, getattr(underlying_model, method_name))
890 |
891 | # NOTE - the below affects the given underlying model, which means generating with it
892 | # directly may not work anymore
893 | self._underlying_model._prepare_encoder_decoder_kwargs_for_generation = \
894 | self._get__prepare_encoder_decoder_kwargs_for_generation_func_override()
895 | self._underlying_model._validate_model_kwargs = _validate_model_kwargs # see hack details below
896 |
897 | def _get__prepare_encoder_decoder_kwargs_for_generation_func_override(self):
898 | # _prepare_encoder_decoder_kwargs_for_generation(
899 | # self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
900 | # ) -> Dict[str, Any]:
901 | f = self._underlying_model._prepare_encoder_decoder_kwargs_for_generation
902 | encode_prefix = self._encode_prefix
903 |
904 | def _prepare_encoder_decoder_kwargs_for_generation(
905 | inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None) -> Dict[str, Any]:
906 | # override needed for underlying model in a conditional generation mode with prefix prepending and dropping
907 | model_kwargs = f(inputs_tensor, model_kwargs, model_input_name)
908 | model_kwargs['attention_mask'] = SledPretrainedModel._prep_attention_mask_for_cross_attention(
909 | encode_prefix, model_kwargs['attention_mask'], model_kwargs.get('prefix_length', None))
910 | return model_kwargs
911 |
912 | return _prepare_encoder_decoder_kwargs_for_generation
913 |
914 | def _prep_for_encoding(self):
915 | was_preped = getattr(self, '_preped_for_encoding', False)
916 | super(SledForConditionalGeneration, self)._prep_for_encoding()
917 | if not was_preped:
918 | self._lm_head = getattr(self._underlying_model, "lm_head", None)
919 | setattr(self._underlying_model, "lm_head", MockLMHeadForConditionalGeneration())
920 |
921 | def _fix_post_encoding(self):
922 | super(SledForConditionalGeneration, self)._fix_post_encoding()
923 | setattr(self._underlying_model, "lm_head", self._lm_head)
924 |
925 | def generate(self, *args, **kwargs):
926 | self._set_underlying_model_attr(self._encoder_attr_name, self._sled_encoder)
927 | try:
928 | res = self._underlying_model.generate(*args, **kwargs)
929 | finally:
930 | self._set_underlying_model_attr(self._encoder_attr_name, self._base_encoder)
931 | return res
932 |
933 | def _validate_model_kwargs(self, *args, **kwargs):
934 | # Newer versions of HF perform a check on the input args for generate and raise an exception when passing
935 | # prefix_length for example to this model because it doesn't list it explicitly.
936 |         # This is a hack to support newer HF versions until the generate() signature is created dynamically to
937 |         # include all the keyword args, including prefix_length
938 | pass
939 |
--------------------------------------------------------------------------------
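Note on usage: the classes in modeling_sled.py are meant to be combined end to end - a SledConfig names an underlying_config, the wrapper instantiates that backbone, long inputs are encoded window by window, and a single decoding pass runs over the merged encodings. The sketch below is not part of the repository; it assumes the facebook/bart-base backbone and mirrors the prefix handling exercised in tests/test_sled.py and the class doctests above, so treat it as illustrative rather than canonical.

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

from sled.configuration_sled import SledConfig
from sled.modeling_sled import SledForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
backbone = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# context_size bounds each encoder window, window_fraction is the share of every window
# whose encodings are kept, and prepend_prefix makes each window start with the prefix
# (e.g. the question) that is delimited via the extra prefix_length input below.
config = SledConfig("facebook/bart-base", context_size=256, window_fraction=0.5,
                    prepend_prefix=True, encode_prefix=True, sliding_method="dynamic")
model = SledForConditionalGeneration(backbone, config)

prefix_ids = tokenizer("What did studies show?\n\n", return_tensors="pt").input_ids
document_ids = tokenizer("Studies have shown that owning a dog is good for you",
                         return_tensors="pt").input_ids
input_ids = torch.cat((prefix_ids, document_ids), dim=-1)
prefix_length = torch.LongTensor([[prefix_ids.size(1)]])  # must be uniform across the batch

generated = model.generate(input_ids, prefix_length=prefix_length)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))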
/sled/tokenization_sled.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from transformers import PreTrainedTokenizer, AutoTokenizer, AutoConfig
4 |
5 |
6 | class SledTokenizer(PreTrainedTokenizer):
7 | auto_tokenizer_loader = AutoTokenizer
8 | auto_config_loader = AutoConfig
9 |
10 | def __init__(self, *args, **kwargs):
11 | super(SledTokenizer, self).__init__(*args, **kwargs)
12 |
13 | @classmethod
14 | def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
15 | assert isinstance(pretrained_model_name_or_path, str), 'pretrained_model_name_or_path must be a path to a ' \
16 | 'checkpoint or a local config file (json)'
17 | if os.path.exists(pretrained_model_name_or_path):
18 |             # pretrained_model_name_or_path is a local path (config json or saved checkpoint)
19 | config = kwargs.pop('config', None)
20 | if pretrained_model_name_or_path.endswith('json'):
21 | config = config or cls.auto_config_loader.from_pretrained(pretrained_model_name_or_path)
22 | return cls.auto_tokenizer_loader.from_pretrained(config.underlying_config, *init_inputs, **kwargs)
23 | else:
24 |                 # otherwise, it is a saved checkpoint
25 | raise NotImplementedError('loading a pretrained saved sled checkpoint is not yet implemented')
26 | else:
27 | # assume it is a model card on the hub
28 | config = kwargs.pop('config', None)
29 | config = config or cls.auto_config_loader.from_pretrained(
30 | pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token', False))
31 | kwargs['use_fast'] = False
32 | return cls.auto_tokenizer_loader.from_pretrained(config.underlying_config, *init_inputs, **kwargs)
--------------------------------------------------------------------------------
/sled/tokenization_sled_fast.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from transformers import AutoTokenizer, AutoConfig, PreTrainedTokenizerFast
4 |
5 | from .tokenization_sled import SledTokenizer
6 |
7 |
8 | class SledTokenizerFast(PreTrainedTokenizerFast):
9 | auto_tokenizer_loader = AutoTokenizer
10 | auto_config_loader = AutoConfig
11 | slow_tokenizer_class = SledTokenizer
12 |
13 | def __init__(self, *args, **kwargs):
14 | super(SledTokenizerFast, self).__init__(*args, **kwargs)
15 |
16 | @classmethod
17 | def from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs):
18 | assert isinstance(pretrained_model_name_or_path, str), 'pretrained_model_name_or_path must be a path to a ' \
19 | 'checkpoint or a local config file (json)'
20 | if os.path.exists(pretrained_model_name_or_path):
21 |             # pretrained_model_name_or_path is a local path (config json or saved checkpoint)
22 | config = kwargs.pop('config', None)
23 | if pretrained_model_name_or_path.endswith('json'):
24 | config = config or cls.auto_config_loader.from_pretrained(pretrained_model_name_or_path)
25 | return cls.auto_tokenizer_loader.from_pretrained(config.underlying_config, *init_inputs, **kwargs)
26 | else:
27 |                 # otherwise, it is a saved checkpoint
28 | raise NotImplementedError('loading a pretrained saved sled checkpoint is not yet implemented')
29 | else:
30 | # assume it is a model card on the hub
31 | config = kwargs.pop('config', None)
32 | config = config or cls.auto_config_loader.from_pretrained(
33 | pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token', False))
34 | kwargs['use_fast'] = True
35 | return cls.auto_tokenizer_loader.from_pretrained(config.underlying_config, *init_inputs, **kwargs)
--------------------------------------------------------------------------------
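Both tokenizer wrappers above do the same thing: resolve underlying_config from the SLED config and delegate to AutoTokenizer, with the slow class forcing use_fast=False and the fast class forcing use_fast=True. A short sketch of what that implies in practice, assuming the tau/bart-base-sled model card used in the tests is reachable:

from transformers import AutoTokenizer, BartTokenizer, BartTokenizerFast

from sled.tokenization_sled import SledTokenizer
from sled.tokenization_sled_fast import SledTokenizerFast

# Loading "through" SLED hands back the backbone's own tokenizer classes,
# mirroring the assertions in tests/test_sled.py.
slow_tok = SledTokenizer.from_pretrained("tau/bart-base-sled", use_fast=False)
fast_tok = SledTokenizerFast.from_pretrained("tau/bart-base-sled")
auto_tok = AutoTokenizer.from_pretrained("tau/bart-base-sled")

assert isinstance(slow_tok, BartTokenizer)      # not a SledTokenizer instance
assert isinstance(fast_tok, BartTokenizerFast)
assert isinstance(auto_tok, BartTokenizerFast)  # AutoTokenizer defaults to the fast one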
/tests/configs/bart_base_sled.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "tau/sled",
3 | "underlying_config": "facebook/bart-base",
4 | "context_size": 256,
5 | "window_fraction": 0.5,
6 | "prepend_prefix": true,
7 | "encode_prefix": true,
8 | "sliding_method": "dynamic"
9 | }
--------------------------------------------------------------------------------
/tests/configs/t5_base_sled.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_type": "tau/sled",
3 | "underlying_config": "google/t5-v1_1-base",
4 | "context_size": 256,
5 | "window_fraction": 0.5,
6 | "prepend_prefix": true,
7 | "encode_prefix": true,
8 | "sliding_method": "dynamic"
9 | }
--------------------------------------------------------------------------------
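The two JSON files above are self-contained SLED configs (the model_type, the backbone to wrap, and the sliding-window parameters), and from_pretrained accepts such a local json path directly, which is exactly what test_all_loading_options below exercises. A brief sketch of that loading path, assuming it is run from the repository root so the files resolve under tests/configs/:

from sled.modeling_sled import SledForConditionalGeneration, SledModel

# A local config json builds a fresh SLED wrapper around the backbone named by
# "underlying_config"; only the backbone's pretrained weights are loaded.
bart_sled = SledModel.from_pretrained("tests/configs/bart_base_sled.json")
t5_sled = SledForConditionalGeneration.from_pretrained("tests/configs/t5_base_sled.json")

print(type(bart_sled.underlying_model).__name__)  # BartModel
print(type(t5_sled.underlying_model).__name__)    # T5ForConditionalGeneration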
/tests/test_sled.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import sys
4 |
5 | from torch import nn
6 |
7 | from sled.tokenization_sled import SledTokenizer
8 | from sled.tokenization_sled_fast import SledTokenizerFast
9 |
10 | BART_BASE_SLED_URL = 'tau/bart-base-sled'
11 | BART_BASE_SLED_GOV_URL = 'tau/bart-base-sled-govreport'
12 | T5_BASE_SLED_URL = 'tau/t5-v1_1-base-sled'
13 |
14 | sys.path.insert(0, os.getcwd())
15 |
16 | import inspect
17 | import unittest
18 |
19 | import torch
20 | from transformers import (
21 | T5Tokenizer,
22 | T5Model,
23 | BartModel,
24 | BartTokenizer,
25 | T5ForConditionalGeneration,
26 | BartForConditionalGeneration, AutoConfig, AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM,
27 | BartTokenizerFast, T5TokenizerFast,
28 | )
29 | from transformers.testing_utils import require_torch
30 |
31 |
32 | from sled.modeling_sled import SledModel, SledForConditionalGeneration, PREFIX_KEY
33 | from sled.configuration_sled import SledConfig
34 |
35 | use_auth_token = False
36 |
37 | def _compare_tuple_of_tensors(test_case, expected_tuple, got_tuple, prev_got_tuple, rtol=1e-5):
38 | for i, (exp_tensor, got_tensor) in enumerate(zip(expected_tuple, got_tuple)):
39 | if isinstance(exp_tensor, torch.Tensor):
40 | test_case.assertTrue(torch.allclose(exp_tensor, got_tensor, rtol=rtol))
41 | # we can't expect the values to be the same when different context length, but at least can verify shapes
42 | if prev_got_tuple is not None:
43 | test_case.assertTrue(exp_tensor.size() == prev_got_tuple[i].size())
44 | elif isinstance(exp_tensor, tuple):
45 | prev_got_tuple_i = prev_got_tuple[i] if prev_got_tuple is not None else None
46 | _compare_tuple_of_tensors(test_case, exp_tensor, got_tensor, prev_got_tuple_i, rtol=rtol)
47 |
48 |
49 | @require_torch
50 | class SLEDModelTest(unittest.TestCase):
51 | def _run_sled_model_test_case(self, model_, tokenizer, underlying_config: str, rtol=1e-5):
52 | model = SledModel(model_, SledConfig(underlying_config, context_size=4))
53 | model.eval() # only change the model to be in eval (inference) mode,
54 | # thus not changing layer_norm params and removing dropout
55 |
56 | input_ids = tokenizer(
57 | "Studies have been shown that owning a dog is good for you", return_tensors="pt"
58 | ).input_ids # Batch size 1
59 | decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
60 |
61 | # simple verification there are no failures in the flow itself
62 | with torch.no_grad():
63 | _ = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone())
64 | _ = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=None)
65 | outputs_dict = model(
66 | input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True
67 | )
68 | outputs_no_dict = model(
69 | input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=False
70 | )
71 |
72 | # let's verify that if the sequence is short, we behave exactly as the base model
73 | model = SledModel(model_, SledConfig(underlying_config, context_size=512))
74 | with torch.no_grad():
75 | output_expected = model_(
76 | input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=False
77 | )
78 |
79 | # on non dict return type
80 | with torch.no_grad():
81 | output_got = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=False)
82 | self.assertEqual(type(output_expected), type(output_got))
83 | self.assertEqual(type(output_expected), type(outputs_no_dict))
84 | self.assertEqual(len(output_got), len(output_expected)) # should be tuple so it's ok
85 | self.assertEqual(len(output_got), len(outputs_no_dict)) # should be tuple so it's ok
86 | _compare_tuple_of_tensors(self, output_expected, output_got, outputs_no_dict, rtol=rtol)
87 |
88 | # on dict return type
89 | with torch.no_grad():
90 | output_expected = model_(
91 | input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True
92 | )
93 | output_got = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True)
94 | self.compare_outputs_dict(output_expected, output_got, outputs_dict, rtol=rtol)
95 |
96 | def compare_outputs_dict(self, output_expected, output_got, outputs_dict=None, rtol=1e-5):
97 | if output_got is None and outputs_dict is not None:
98 | output_got = output_expected
99 | if outputs_dict is None:
100 | outputs_dict = output_expected
101 | self.assertEqual(type(output_expected), type(output_got))
102 | self.assertEqual(type(output_expected), type(outputs_dict))
103 | self.assertListEqual(list(output_got.keys()), list(output_expected.keys()))
104 | self.assertListEqual(list(output_got.keys()), list(outputs_dict.keys()))
105 | for key in output_got.keys():
106 | if isinstance(output_got[key], torch.Tensor):
107 | self.assertTrue(torch.allclose(output_got[key], output_expected[key], rtol=rtol))
108 | self.assertFalse(output_got[key].requires_grad)
109 | self.assertFalse(output_expected[key].requires_grad)
110 | self.assertFalse(outputs_dict[key].requires_grad)
111 |
112 | # we can't expect the values to be the same when different context length, but at least can verify shapes
113 | self.assertTrue(output_got[key].size() == outputs_dict[key].size())
114 | elif isinstance(output_got[key], tuple):
115 | _compare_tuple_of_tensors(self, output_got[key], output_expected[key], outputs_dict[key], rtol=rtol)
116 |
117 | def test_facade_get_attr_behavior(self):
118 | base_model = T5Model.from_pretrained("t5-small")
119 | sled_model = SledModel(base_model, SledConfig("t5-small", context_size=4))
120 | self.assertEqual(
121 | base_model.shared, sled_model.shared
122 |         )  # sled model does not have a 'shared' attribute, only its base model does
123 |
124 | def _verify_loaded_sled_model(self, bart_sled_model, config_class, underlying_config_class,
125 | underlying_config_path="facebook/bart-base", expected_underlying_model=None):
126 | self.assertIsInstance(bart_sled_model, config_class)
127 | self.assertIsInstance(bart_sled_model._underlying_model, underlying_config_class)
128 | # make sure it has the pretrained weights and not random ones
129 | expected_underlying_model = expected_underlying_model or \
130 | underlying_config_class.from_pretrained(underlying_config_path)
131 | # noinspection PyTypeChecker
132 | self.assertTrue(torch.all(
133 | expected_underlying_model.get_encoder().state_dict()['embed_tokens.weight'] ==
134 | bart_sled_model.get_encoder().state_dict()['embed_tokens.weight']).item())
135 |
136 | def test_load_from_pretrained(self):
137 | config_path = BART_BASE_SLED_URL
138 | another_config_path = 'configs/t5_base_sled.json'
139 | another_config_path_hub = T5_BASE_SLED_URL
140 |
141 | # first, we test the config
142 | # start by loading it explicitly
143 | bart_base_sled_config = SledConfig.from_pretrained(config_path, use_auth_token=use_auth_token)
144 | self.assertIsInstance(bart_base_sled_config, SledConfig)
145 |
146 | # now, with auto classes
147 | bart_base_sled_config2 = AutoConfig.from_pretrained(config_path, use_auth_token=use_auth_token)
148 | self.assertNotEqual(bart_base_sled_config, bart_base_sled_config2) # the explicit one didn't have a name or path
149 | bart_base_sled_config._name_or_path = bart_base_sled_config2._name_or_path
150 | self.assertEqual(bart_base_sled_config, bart_base_sled_config2)
151 |
152 | # Now, let's check the model loading
153 | self._verify_loaded_sled_model(SledModel.from_pretrained(config_path, use_auth_token=use_auth_token),
154 | SledModel, BartModel) # explicit load
155 |
156 | # now, with auto classes
157 | self._verify_loaded_sled_model(AutoModel.from_pretrained(config_path, use_auth_token=use_auth_token),
158 | SledModel, BartModel) # auto load
159 |
160 |         # now, let's assert that "ForConditionalGeneration" also works as expected
161 | self._verify_loaded_sled_model(SledForConditionalGeneration.from_pretrained(config_path,
162 | use_auth_token=use_auth_token),
163 | SledForConditionalGeneration, BartForConditionalGeneration) # explicit
164 |
165 | # now, with auto classes
166 | self._verify_loaded_sled_model(
167 | AutoModelForSeq2SeqLM.from_pretrained(config_path, use_auth_token=use_auth_token),
168 | SledForConditionalGeneration, BartForConditionalGeneration) # auto
169 |
170 |         # Finally, let's verify it also works with another model
171 | self._verify_loaded_sled_model(AutoModelForSeq2SeqLM.from_pretrained(
172 | another_config_path, use_auth_token=use_auth_token), SledForConditionalGeneration,
173 | T5ForConditionalGeneration, "google/t5-v1_1-base")
174 |
175 | self._verify_loaded_sled_model(AutoModelForSeq2SeqLM.from_pretrained(
176 | another_config_path_hub, use_auth_token=use_auth_token), SledForConditionalGeneration,
177 | T5ForConditionalGeneration, "google/t5-v1_1-base")
178 |
179 | def test_config_overrides(self):
180 | # Load the base model, and verify its consistency with the underlying config
181 | bart_base_model_sled = AutoModel.from_pretrained(BART_BASE_SLED_URL, use_auth_token=use_auth_token)
182 | self.assertFalse(bart_base_model_sled.config.gradient_checkpointing)
183 | self.assertFalse(bart_base_model_sled.underlying_model.config.gradient_checkpointing)
184 | self.assertTrue(bart_base_model_sled.config.use_cache)
185 | self.assertTrue(bart_base_model_sled.underlying_model.config.use_cache)
186 |
187 | # now, supply overrides and make sure they are consistent
188 | bart_base_model_sled = AutoModel.from_pretrained(BART_BASE_SLED_URL, gradient_checkpointing=True, use_cache=False,
189 | use_auth_token=use_auth_token)
190 | self.assertTrue(bart_base_model_sled.config.gradient_checkpointing)
191 | self.assertTrue(bart_base_model_sled.underlying_model.config.gradient_checkpointing)
192 | self.assertFalse(bart_base_model_sled.config.use_cache)
193 | self.assertFalse(bart_base_model_sled.underlying_model.config.use_cache)
194 |
195 | # Finally, set the config after load and make sure it is synced properly
196 | bart_base_model_sled.config.gradient_checkpointing = False
197 | bart_base_model_sled.config.use_cache = True
198 | self.assertFalse(bart_base_model_sled.config.gradient_checkpointing)
199 | self.assertTrue(bart_base_model_sled.config.use_cache)
200 |         # the underlying config will not be synced until the model is first used
201 | self.assertTrue(bart_base_model_sled.underlying_model.config.gradient_checkpointing)
202 | self.assertFalse(bart_base_model_sled.underlying_model.config.use_cache)
203 | # now, pass something through the model (even though it would fail)
204 | try:
205 | bart_base_model_sled(None)
206 | except:
207 | pass
208 | self.assertFalse(bart_base_model_sled.underlying_model.config.gradient_checkpointing)
209 | self.assertTrue(bart_base_model_sled.underlying_model.config.use_cache)
210 |
211 | def test_all_loading_options(self):
212 | # test load from local config, from a saved checkpoint and from a URL model card
213 | local_bart_base_sled_model = SledModel.from_pretrained('configs/bart_base_sled.json')
214 | self._verify_loaded_sled_model(local_bart_base_sled_model, SledModel, BartModel) # explicit load
215 |
216 | hub_bart_base_sled_model = SledModel.from_pretrained(BART_BASE_SLED_URL, use_auth_token=use_auth_token)
217 | self._verify_loaded_sled_model(hub_bart_base_sled_model, SledModel, BartModel) # explicit load
218 |
219 | # now, save and reload
220 | cache_dir = os.environ.get('XDG_CACHE_HOME', '/tmp/cache')
221 | out_dir = f'{cache_dir}/test_save_checkpoint'
222 | os.makedirs(out_dir, exist_ok=True)
223 | shutil.rmtree(out_dir) # cleanup previous checkpoints
224 |         # let's change the model a little before saving it to make sure it works correctly and doesn't load the
225 |         # checkpoint from the hub
226 | local_bart_base_sled_model.get_encoder().state_dict()['embed_tokens.weight'] += 1
227 | # make sure they were indeed changed
228 | self.assertRaises(AssertionError, self._verify_loaded_sled_model, local_bart_base_sled_model, SledModel,
229 | BartModel)
230 |
231 | # now save and test
232 | local_bart_base_sled_model.save_pretrained(out_dir)
233 | self.assertTrue(os.path.isfile(os.path.join(out_dir, 'pytorch_model.bin')))
234 | self.assertTrue(os.path.isfile(os.path.join(out_dir, 'config.json')))
235 | loaded_bart_base_sled_model = SledModel.from_pretrained(out_dir, use_auth_token=use_auth_token)
236 | self._verify_loaded_sled_model(loaded_bart_base_sled_model, SledModel, BartModel,
237 | expected_underlying_model=local_bart_base_sled_model)
238 |
239 |
240 | def test_load_tokenizer_from_pretrained(self):
241 | config_path = BART_BASE_SLED_URL
242 | another_config_path = T5_BASE_SLED_URL
243 |
244 | # slow tokenizer
245 | # explicit load, should actually return a BartTokenizer (the default is a fast tokenizer)
246 | self.assertIsInstance(SledTokenizer.from_pretrained(config_path, use_auth_token=use_auth_token, use_fast=False),
247 | BartTokenizer)
248 |
249 | # autoload, should actually return a BartTokenizer
250 | self.assertIsInstance(AutoTokenizer.from_pretrained(config_path, use_auth_token=use_auth_token, use_fast=False),
251 | BartTokenizer)
252 |
253 | # fast tokenizer
254 | # explicit load, should actually return a BartTokenizerFast
255 | self.assertIsInstance(SledTokenizerFast.from_pretrained(config_path, use_auth_token=use_auth_token),
256 | BartTokenizerFast)
257 | # autoload, should actually return a BartTokenizerFast
258 | self.assertIsInstance(AutoTokenizer.from_pretrained(config_path, use_auth_token=use_auth_token),
259 | BartTokenizerFast)
260 |
261 | # and now with T5
262 | self.assertIsInstance(SledTokenizer.from_pretrained(another_config_path, use_auth_token=use_auth_token,
263 | use_fast=False), T5Tokenizer)
264 | self.assertIsInstance(AutoTokenizer.from_pretrained(another_config_path, use_auth_token=use_auth_token,
265 | use_fast=False), T5Tokenizer)
266 | self.assertIsInstance(SledTokenizerFast.from_pretrained(another_config_path, use_auth_token=use_auth_token,
267 | use_fast=True), T5TokenizerFast)
268 | self.assertIsInstance(AutoTokenizer.from_pretrained(another_config_path, use_auth_token=use_auth_token,
269 | use_fast=True), T5TokenizerFast)
270 |
271 | def test_load_finetuned_model(self):
272 | bart_base_sled = AutoModel.from_pretrained(BART_BASE_SLED_URL)
273 | bart_base_sled_gov = SledModel.from_pretrained(BART_BASE_SLED_GOV_URL)
274 |         # test embeddings have changed
275 | assert not torch.all(bart_base_sled.get_input_embeddings().weight ==
276 | bart_base_sled_gov.get_input_embeddings().weight).item()
277 | # test decoder weights have changed
278 | assert not torch.all(
279 | bart_base_sled.state_dict()['_underlying_model.decoder.layers.0.encoder_attn.k_proj.weight'] ==
280 | bart_base_sled_gov.state_dict()['_underlying_model.decoder.layers.0.encoder_attn.k_proj.weight']
281 | ).item()
282 | # test encoder weights have changed
283 | assert not torch.all(
284 | bart_base_sled.state_dict()['_underlying_model.encoder.layers.0.self_attn.k_proj.weight'] ==
285 | bart_base_sled_gov.state_dict()['_underlying_model.encoder.layers.0.self_attn.k_proj.weight']
286 | ).item()
287 |
288 |
289 | def test_sled_on_t5(self):
290 | self._run_sled_model_test_case(T5Model.from_pretrained("t5-small"), T5Tokenizer.from_pretrained("t5-small"),
291 | "t5-small")
292 |
293 | def test_sled_on_bart(self):
294 | self._run_sled_model_test_case(
295 | BartModel.from_pretrained("facebook/bart-base"), BartTokenizer.from_pretrained("facebook/bart-base"),
296 | "facebook/bart-base"
297 | )
298 |
299 | def test_forward_signature(self):
300 | # HF trainer uses the signature to choose which parts to take from a dataset, so we need to make sure our wrapped forward
301 | # function has the correct signature
302 | _model = BartModel.from_pretrained("facebook/bart-base")
303 | expected_sig = [param for param in inspect.signature(_model.forward).parameters.keys()]
304 | got_sig = [param for param in inspect.signature(SledModel(_model, SledConfig(context_size=128)).forward).parameters.keys()]
305 | self.assertListEqual(expected_sig, got_sig[:-1])
306 | self.assertEqual(str(got_sig[-1]), PREFIX_KEY)
307 |
308 | def test_resize_token_embeddings(self):
309 | _model = BartModel.from_pretrained("facebook/bart-base")
310 | orig_vocab_size = _model.config.vocab_size
311 | self.assertNotEqual(orig_vocab_size, 512)
312 | _model.resize_token_embeddings(512)
313 | self.assertEqual(_model.config.vocab_size, 512)
314 | model = SledModel(_model, SledConfig("facebook/bart-base", context_size=128))
315 | self.assertEqual(model.config.vocab_size, 512)
316 | self.assertEqual(_model.get_input_embeddings().weight.size()[0], 512)
317 | self.assertEqual(model.get_input_embeddings().weight.size()[0], 512)
318 | model.resize_token_embeddings(1024)
319 | self.assertEqual(model.config.vocab_size, 1024)
320 | self.assertEqual(_model.config.vocab_size, 1024)
321 | self.assertEqual(_model.get_input_embeddings().weight.size()[0], 1024)
322 | self.assertEqual(model.get_input_embeddings().weight.size()[0], 1024)
323 |
324 | def test_sled_model_parallel(self):
325 | assert torch.cuda.device_count() > 1
326 | model = SledModel(T5Model.from_pretrained("t5-small"), SledConfig("t5-small", context_size=512)).to("cuda:0")
327 | model.eval()
328 | assert model.is_parallelizable # bart is not, only t5
329 | assert not model.model_parallel
330 | assert model.device_map is None
331 |
332 | model2 = SledModel(T5Model.from_pretrained("t5-small"), SledConfig("t5-small", context_size=512)).to("cuda:0")
333 | model2.eval()
334 | model2.parallelize()
335 | assert model2.model_parallel
336 | assert len(model2.device_map) == min(3, torch.cuda.device_count())
337 |
338 | tokenizer = T5Tokenizer.from_pretrained("t5-small")
339 | input_ids = tokenizer([" ".join(list("1234567890"))] * 16, return_tensors="pt").input_ids.to("cuda:0")
340 | assert input_ids.size() == (16, 12) # Batch size 16, input length 10 + BOS + EOS
341 | decoder_input_ids = tokenizer([" ".join(list("hello"))] * 16, return_tensors="pt").input_ids.to("cuda:0")
342 |         assert decoder_input_ids.size() == (16, 11) # Batch size 16, input length 5 + BOS + 4 spaces + EOS
343 |
344 | # simple verification there are no failures in the flow itself
345 | with torch.no_grad():
346 | outputs_expected = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone())
347 | outputs_got = model2(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone())
348 | self.compare_outputs_dict(
349 | outputs_expected, None, outputs_got, rtol=1e-5
350 | ) # the values may differ, but we need to make sure the sizes are correct
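
    # Illustrative sketch (editor's addition): `parallelize()` above lets the backbone derive its own
    # device map; T5 also accepts an explicit map of device index -> encoder block indices. The layout
    # below is a hypothetical split of t5-small's 6 blocks over 2 GPUs and is not asserted by the test.
    def _sketch_explicit_device_map(self):
        return {0: [0, 1, 2], 1: [3, 4, 5]}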
351 |
352 | def test_sled_with_data_parallel(self):
353 | assert torch.cuda.device_count() > 1
354 | model_ = SledModel(BartModel.from_pretrained("facebook/bart-base"),
355 | SledConfig("facebook/bart-base", context_size=128)).to("cuda:0")
356 | model = nn.DataParallel(model_)
357 | model.eval()
358 | assert isinstance(model, nn.DataParallel)
359 | replicas = model.replicate(model.module, model.device_ids)
360 | assert len(replicas) == torch.cuda.device_count()
361 | for i, rep in enumerate(replicas):
362 | assert isinstance(rep, SledModel)
363 | assert rep.device.index == i
364 | forward = replicas[0].forward
365 | replicas[0].forward = "hello"
366 | assert replicas[0].forward == "hello"
367 | assert replicas[1].forward != "hello"
368 | replicas[0].forward = forward
369 |
370 | tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
371 | input_ids = tokenizer([" ".join(list("1234567890"))] * 16, return_tensors="pt").input_ids.to("cuda:0")
372 | assert input_ids.size() == (16, 12) # Batch size 16, input length 10 + BOS + EOS
373 | decoder_input_ids = tokenizer([" ".join(list("hello"))] * 16, return_tensors="pt").input_ids.to("cuda:0")
374 |         assert decoder_input_ids.size() == (16, 7) # Batch size 16, input length 5 + BOS + EOS
375 |
376 | # simple verification there are no failures in the flow itself
377 | with torch.no_grad():
378 | outputs_expected = model_(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone())
379 | outputs_got = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone())
380 | self.compare_outputs_dict(
381 | outputs_expected, None, outputs_got, rtol=1e-5
382 | ) # the values may differ, but we need to make sure the sizes are correct
383 |
384 | def test_sled_with_input_prefix(self):
385 | rtol = 1e-5
386 | model_, tokenizer = BartModel.from_pretrained("facebook/bart-base"), BartTokenizer.from_pretrained("facebook/bart-base")
387 | model = SledModel(model_, SledConfig("facebook/bart-base", context_size=16, prepend_prefix=True))
388 |         model.eval() # only switch the model to eval (inference) mode, which disables dropout without changing any weights (e.g. layer_norm params)
389 |
390 | document_input_ids = tokenizer(
391 | "Studies have been shown that owning a dog is good for you", return_tensors="pt"
392 | ).input_ids # Batch size 1
393 | input_prefix_ids = tokenizer("What did studies show?\n\n", return_tensors="pt").input_ids # Batch size 1
394 | decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
395 |
396 | input_ids = torch.concat((input_prefix_ids, document_input_ids), dim=-1)
397 | prefix_length = torch.LongTensor([[input_prefix_ids.size(1)]])
398 |
399 | # simple verification there are no failures in the flow itself
400 | with torch.no_grad():
401 | _ = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(),
402 | decoder_input_ids=decoder_input_ids.clone())
403 | _ = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(),
404 | decoder_input_ids=decoder_input_ids.clone(),
405 | return_dict=None)
406 | outputs_dict = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(),
407 | decoder_input_ids=decoder_input_ids.clone(), return_dict=True)
408 | outputs_no_dict = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(),
409 | decoder_input_ids=decoder_input_ids.clone(), return_dict=False)
410 |
411 | # let's verify that if the sequence is short, we behave exactly as the base model
412 | model = SledModel(model_, SledConfig("facebook/bart-base", context_size=512, prepend_prefix=False))
413 |         # treat the whole input as a single document and make sure the context size is large enough to contain it entirely
414 | model.eval()
415 | with torch.no_grad():
416 | output_expected = model_(
417 | input_ids=input_ids.clone(),
418 | decoder_input_ids=decoder_input_ids.clone(), return_dict=False
419 | )
420 |
421 | # on non dict return type
422 | with torch.no_grad():
423 | output_got = model(input_ids=input_ids.clone(), prefix_length=prefix_length.clone(),
424 | decoder_input_ids=decoder_input_ids.clone(), return_dict=False)
425 | self.assertEqual(type(output_expected), type(output_got))
426 | self.assertEqual(type(output_expected), type(outputs_no_dict))
427 | self.assertEqual(len(output_got), len(output_expected)) # should be tuple so it's ok
428 | self.assertEqual(len(output_got), len(outputs_no_dict)) # should be tuple so it's ok
429 | _compare_tuple_of_tensors(self, output_expected, output_got, outputs_no_dict, rtol=rtol)
430 |
431 | # on dict return type
432 | with torch.no_grad():
433 | output_expected = model_(input_ids=input_ids.clone(),
434 | decoder_input_ids=decoder_input_ids.clone(), return_dict=True
435 | )
436 | output_got = model(input_ids=input_ids.clone(), decoder_input_ids=decoder_input_ids.clone(), return_dict=True)
437 | self.compare_outputs_dict(output_expected, output_got, outputs_dict, rtol=rtol)
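
    # Illustrative sketch (editor's addition): the test above passes a single (1, 1) `prefix_length`
    # tensor; for a batch, one assumed construction is a (batch, 1) tensor holding the tokenized
    # length of each example's prefix. The prefix strings here are hypothetical.
    def _sketch_batched_prefix_lengths(self, tokenizer):
        prefixes = ["What did studies show?\n\n", "Summarize the following report:\n\n"]
        prefix_ids = [tokenizer(p, return_tensors="pt").input_ids for p in prefixes]
        return torch.LongTensor([[ids.size(1)] for ids in prefix_ids])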
438 |
439 | @unittest.expectedFailure
440 | def test_sled_with_variable_sized_input_prefix(self):
441 | raise NotImplementedError
442 |
443 | @unittest.expectedFailure
444 | def test_sled_overhead(self):
445 | raise NotImplementedError
446 |
447 | @unittest.expectedFailure
448 | def test_sled_multiple_update_steps(self):
449 | raise NotImplementedError
450 |
451 | @unittest.expectedFailure
452 | def test_eval_mode(self):
453 | raise NotImplementedError
454 |         # TODO - assert that eval() works by looking at the dropout effect? Run three forward passes with the model in train mode,
455 |         # then three in eval mode, and check that the train outputs differ while the eval outputs do not
456 |
457 | @unittest.expectedFailure
458 | def test_prefix_prepending(self):
459 | # TODO - first, test the expected versions (no prepending+ no prefix given or w/ prepending+prefix given)
460 | # todo - test sled that should prepend but doesn't get it
461 | # todo - test sled that shouldn't prepend but gets a prefix length (make sure it ignores it)
462 | raise NotImplementedError
463 |
464 | @unittest.expectedFailure
465 | def test_drop_prefix_encoding(self):
466 | raise NotImplementedError
467 |
468 |
469 |
470 | @require_torch
471 | class SLEDForConditionalGenerationTest(unittest.TestCase):
472 | def test_facade_get_attr_behavior(self):
473 | base_model = T5ForConditionalGeneration.from_pretrained("t5-small")
474 | sled_model = SledForConditionalGeneration(base_model, SledConfig("t5-small", context_size=4))
475 | self.assertEqual(
476 | base_model.shared, sled_model.shared
477 |         ) # sled model does not have a 'shared' attribute, only its base model does
478 |
479 | def test_sled_for_cg_on_t5(self):
480 | self._run_sled_for_cg_model_test_case(
481 | T5ForConditionalGeneration.from_pretrained("t5-small"), T5Tokenizer.from_pretrained("t5-small"), "t5-small"
482 | )
483 |
484 | def test_sled_for_cg_on_bart(self):
485 | self._run_sled_for_cg_model_test_case(
486 | BartForConditionalGeneration.from_pretrained("facebook/bart-base"),
487 | BartTokenizer.from_pretrained("facebook/bart-base"), "facebook/bart-base"
488 | )
489 |
490 | def test_sled_for_cg_generate_on_t5(self):
491 | self._run_sled_for_cg_model_generate_test_case(
492 | T5ForConditionalGeneration.from_pretrained("t5-small"), T5Tokenizer.from_pretrained("t5-small"), "t5-small"
493 | )
494 |
495 | def test_sled_for_cg_generate_on_bart(self):
496 | self._run_sled_for_cg_model_generate_test_case(
497 | BartForConditionalGeneration.from_pretrained("facebook/bart-base"),
498 | BartTokenizer.from_pretrained("facebook/bart-base"), "facebook/bart-base"
499 | )
500 |
501 | def _run_sled_for_cg_model_test_case(self, model_, tokenizer, underlying_config: str, rtol=1e-5):
502 | model = SledForConditionalGeneration(model_, SledConfig(underlying_config, context_size=4))
503 | model.eval()
504 |
505 | input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids
506 | labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids
507 |
508 | # simple verification there are no failures in the flow itself
509 | with torch.no_grad():
510 | _ = model(input_ids=input_ids.clone(), labels=labels.clone())
511 | _ = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=None)
512 | outputs_dict = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=True)
513 | outputs_no_dict = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=False)
514 |
515 | # let's verify that if the sequence is short, we behave exactly as the base model
516 | model = SledForConditionalGeneration(model_, SledConfig(underlying_config, context_size=512))
517 | output_expected = model_(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=False)
518 |
519 | # on non dict return type
520 | with torch.no_grad():
521 | output_got = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=False)
522 | self.assertEqual(type(output_expected), type(output_got))
523 | self.assertEqual(type(output_expected), type(outputs_no_dict))
524 | self.assertEqual(len(output_got), len(output_expected)) # should be tuple so it's ok
525 | self.assertEqual(len(output_got), len(outputs_no_dict)) # should be tuple so it's ok
526 | _compare_tuple_of_tensors(self, output_expected, output_got, outputs_no_dict, rtol=rtol)
527 |
528 | # on dict return type
529 | with torch.no_grad():
530 | output_expected = model_(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=True)
531 | output_got = model(input_ids=input_ids.clone(), labels=labels.clone(), return_dict=True)
532 | self.assertEqual(type(output_expected), type(output_got))
533 | self.assertEqual(type(output_expected), type(outputs_dict))
534 | self.assertListEqual(list(output_got.keys()), list(output_expected.keys()))
535 | self.assertListEqual(list(output_got.keys()), list(outputs_dict.keys()))
536 | for key in output_got.keys():
537 | if isinstance(output_got[key], torch.Tensor):
538 | self.assertTrue(torch.allclose(output_got[key], output_expected[key], rtol=rtol))
539 |                 # we can't expect the values to match under a different context length, but we can at least verify the shapes
540 | self.assertTrue(output_got[key].size() == outputs_dict[key].size())
541 | elif isinstance(output_got[key], tuple):
542 | _compare_tuple_of_tensors(self, output_got[key], output_expected[key], outputs_dict[key], rtol=rtol)
543 |
544 | def _run_sled_for_cg_model_generate_test_case(self, model_, tokenizer, underlying_config: str, rtol=1e-5):
545 | input_ids = tokenizer(
546 | "summarize: studies have shown that owning a dog is good for you ", return_tensors="pt"
547 | ).input_ids # Batch size 1
548 | model_.eval()
549 |         # compute the expected output first - once the base model is wrapped by SledForConditionalGeneration, it cannot generate directly anymore
550 | output_expected = model_.generate(input_ids.clone())
551 |
552 | model = SledForConditionalGeneration(model_, SledConfig(underlying_config, context_size=4))
553 | model.eval()
554 |
555 | # simple verification there are no failures in the flow itself
556 | _ = model.generate(
557 | torch.cat((input_ids.clone(), input_ids.clone()))
558 | ) # just to make sure can generate over two sequences
559 | _ = model.generate(input_ids.clone())
560 | outputs_no_dict = model.generate(input_ids.clone())
561 | assert outputs_no_dict.dim() == 2
562 | assert outputs_no_dict.size()[0] == 1
563 |
564 | # let's verify that if the sequence is short, we behave exactly as the base model
565 | model = SledForConditionalGeneration(model_, SledConfig(underlying_config, context_size=512))
566 | model.eval()
567 |
568 | output_got = model.generate(input_ids.clone())
569 | self.assertEqual(type(output_expected), type(output_got))
570 | self.assertEqual(type(output_expected), type(outputs_no_dict))
571 |         self.assertEqual(len(output_got), len(output_expected)) # generate returns a tensor, so len() compares the batch dimension
572 |         self.assertEqual(len(output_got), len(outputs_no_dict)) # generate returns a tensor, so len() compares the batch dimension
573 |         # no point comparing the full shapes, as the generated lengths may differ
574 | _compare_tuple_of_tensors(self, (output_expected,), (output_got,), None, rtol=rtol)
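
    # Illustrative sketch (editor's addition): the tests above only compare types and shapes of the
    # generated ids; decoding them back to text is the usual final step and relies on the standard
    # tokenizer API, assumed here rather than exercised by the suite.
    def _sketch_decode_generated(self, model, tokenizer, input_ids):
        generated_ids = model.generate(input_ids)
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)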
575 |
--------------------------------------------------------------------------------