├── .gitattributes ├── .gitignore ├── LICENSE.md ├── README.md ├── configs ├── ablations │ ├── base_model │ │ └── roberta.json │ ├── varying_attention_heads │ │ └── roberta_2heads.json │ ├── varying_depth │ │ ├── roberta_2layers.json │ │ ├── roberta_4layers.json │ │ ├── roberta_6layers.json │ │ └── roberta_8layers.json │ └── varying_width │ │ ├── roberta_156.json │ │ ├── roberta_24.json │ │ ├── roberta_36.json │ │ ├── roberta_384.json │ │ └── roberta_72.json └── roberta.json ├── env.yaml ├── images └── multiplexing.gif ├── models ├── __init__.py ├── multiplexing.py ├── trainer.py └── utils.py ├── run_glue.py ├── run_glue.sh ├── run_job.sh ├── run_ner.py ├── run_ner.sh └── vision └── vision_multiplexing.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | .ipynb linguist-documentation 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | #wand 141 | wandb/ 142 | checkpoints/ 143 | checkpoints 144 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | Copyright (c) [2022] [The Trustees of Princeton University] 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted for academic and research use only (subject to the limitations in the disclaimer below) provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | * Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from this 17 | software without specific prior written permission. 18 | 19 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 20 | THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 21 | CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 23 | PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 24 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 25 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 26 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 27 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER 28 | IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 29 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 30 | POSSIBILITY OF SUCH DAMAGE. NO COMMERCIAL USE IS PERMITTED UNDER THIS LICENSE. 31 | 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DataMUX ## 2 | 3 | PyTorch implementation for the paper: 4 | 5 | **[DataMUX: Data Multiplexing for Neural Networks](https://princeton-nlp.github.io/DataMUX/)** 6 | [Vishvak Murahari](https://vishvakmurahari.com/), [Carlos E. Jimenez](https://www.carlosejimenez.com/), [Runzhe Yang](https://runzhe-yang.science/), [Karthik Narasimhan](https://www.cs.princeton.edu/~karthikn/) 7 | 8 | ![models](images/multiplexing.gif) 9 | 10 | This repository contains code for reproducing results. 
We provide pretrained model weights and associated configs to run inference or train these models from scratch. If you find this work useful in your research, please cite: 11 | 12 | ``` 13 | @inproceedings{ 14 | murahari2022datamux, 15 | title={Data{MUX}: Data Multiplexing for Neural Networks}, 16 | author={Vishvak Murahari and Carlos E Jimenez and Runzhe Yang and Karthik R Narasimhan}, 17 | booktitle={Thirty-Sixth Conference on Neural Information Processing Systems}, 18 | year={2022}, 19 | url={https://openreview.net/forum?id=UdgtTVTdswg} 20 | } 21 | ``` 22 | 23 | ### Table of Contents 24 | 25 | * [Setup and Dependencies](#setup-and-dependencies) 26 | * [Usage](#usage) 27 | * [Overview](#overview) 28 | * [Pre-trained checkpoints](#pre-trained-checkpoints) 29 | * [Training settings](#settings) 30 | * [Vision Tasks](#vision) 31 | * [Reference](#reference) 32 | * [License](#license) 33 | 34 | ### Setup and Dependencies 35 | 36 | Our code is implemented in PyTorch. To set up, do the following: 37 | 38 | 1. Install [Python 3.6](https://www.python.org/downloads/release/python-365/) 39 | 2. Get the source: 40 | ``` 41 | git clone https://github.com/princeton-nlp/DataMUX.git datamux 42 | ``` 43 | 3. Install requirements into the `datamux` conda environment, using [Anaconda](https://anaconda.org/anaconda/python): 44 | ``` 45 | conda env create -f env.yaml 46 | ``` 47 | 48 | ### Usage 49 | 50 | #### Overview 51 | For sentence-level classification tasks, refer to `run_glue.py` and `run_glue.sh`. For token-level classification tasks, refer to `run_ner.py` and `run_ner.sh`. 52 | #### Pre-trained checkpoints 53 | We release all the pretrained checkpoints on the Hugging Face [model hub](https://huggingface.co/princeton-nlp). The checkpoints are listed below; append the number of multiplexed instances (2, 5, 10, 20, or 40) to the model name, e.g., `princeton-nlp/datamux-mnli-2`. 54 | 55 | | Task | Model name on hub | Full path | 56 | | ----------------|:-------------------|---------:| 57 | | Retrieval Warmup| datamux-retrieval- | princeton-nlp/datamux-retrieval-| 58 | | MNLI | datamux-mnli- | princeton-nlp/datamux-mnli-| 59 | | QNLI | datamux-qnli- | princeton-nlp/datamux-qnli-| 60 | | QQP | datamux-qqp- | princeton-nlp/datamux-qqp-| 61 | | SST2 | datamux-sst2- | princeton-nlp/datamux-sst2-| 62 | | NER | datamux-ner- | princeton-nlp/datamux-ner-| 63 |
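To run inference directly from Python, the sketch below shows one way to load a multiplexed checkpoint with the model classes in `models/multiplexing.py`. This is an illustrative sketch rather than part of the released scripts: it assumes the hub config already carries the DataMUX-specific fields (`num_instances`, `muxing_variant`, `demuxing_variant`, ...), that the checkpoints pair with the standard `roberta-base` tokenizer, and the premise/hypothesis pairs are placeholders.

```
# Illustrative sketch (not part of the released scripts); run from the repo root.
# Assumes the hub config.json carries the DataMUX-specific fields
# (num_instances, muxing_variant, demuxing_variant, ...) and that the
# checkpoint pairs with the standard roberta-base tokenizer.
import torch
from transformers import AutoTokenizer
from models.multiplexing import RobertaSequenceClassificationMuxed

N = 2  # number of multiplexed instances; must match the checkpoint suffix
model = RobertaSequenceClassificationMuxed.from_pretrained(
    f"princeton-nlp/datamux-mnli-{N}"
).eval()
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# The batch size should be a multiple of N; each consecutive group of N
# examples is multiplexed into a single forward pass.
premises = [
    "A soccer game with multiple males playing.",
    "An older and younger man are smiling.",
]
hypotheses = [
    "Some men are playing a sport.",
    "Two men are smiling at cats playing on the floor.",
]
inputs = tokenizer(premises, hypotheses, padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"])
print(outputs.logits.argmax(dim=-1))  # one MNLI prediction per multiplexed instance
```

For the token-level checkpoints, `RobertaTokenClassificationMuxed` from the same module should play the analogous role.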
64 | #### Settings 65 | The bash scripts `run_ner.sh` and `run_glue.sh` take the following arguments: 66 | 67 | 68 | | Argument | Flag | Explanation | Argument Choices | 69 | | ------------- |:-----|-----------------------------:|-----------------| 70 | | NUM_INSTANCES | -N --num_instances | Number of multiplexing instances | 2,5,10,20,40 | 71 | | DEMUXING | -d --demuxing | Demultiplexing architecture| "index", "mlp" | 72 | | MUXING | -m --muxing | Multiplexing architecture | "gaussian_hadamard", "binary_hadamard", "random_ortho"| 73 | | SETTING | -s --setting | Training setting | "baseline", "finetuning", "retrieval_pretraining"| 74 | | TASK_NAME | --task | Task name during finetuning | "mnli", "qnli", "sst2", "qqp" for `run_glue.py` or "ner" for `run_ner.py` | 75 | | LEARNING_RATE | --lr | Learning rate for optimization| Any float, but we use either 2e-5 or 5e-5| 76 | | BATCH_SIZE | --batch_size | Batch size (after multiplexing); note that the *effective* batch size is BATCH_SIZE * NUM_INSTANCES | Any integer. If left unset, it is set automatically based on the value of N| 77 | | CONFIG_NAME | --config_name | Config path for the backbone Transformer model| Any config file in the `configs` directory | 78 | | MODEL_PATH | --model_path | Model path when continuing training from a checkpoint or initializing from a retrieval-pretrained checkpoint| Path to a local checkpoint or a model on the [hub](https://huggingface.co/princeton-nlp) | 79 | | LEARN_MUXING | --learn_muxing | Whether to learn instance embeddings during multiplexing| | 80 | | DO_TRAIN | --do_train | Pass flag to do training | | 81 | | DO_EVAL | --do_eval | Pass flag to do eval | | 82 | 83 | Below we list example commands for the different training settings: 84 | 85 | #### Retrieval pretraining 86 | This command runs retrieval pretraining for N=2: 87 | ``` 88 | sh run_glue.sh \ 89 | -N 2 \ 90 | -d index \ 91 | -m gaussian_hadamard \ 92 | -s retrieval_pretraining \ 93 | --config_name configs/ablations/base_model/roberta.json \ 94 | --lr 5e-5 \ 95 | --do_train \ 96 | --do_eval 97 | ``` 98 | 99 | #### Finetuning 100 | This command finetunes from a retrieval-pretrained checkpoint with N=2: 101 | ``` 102 | sh run_glue.sh \ 103 | -N 2 \ 104 | -d index \ 105 | -m gaussian_hadamard \ 106 | -s finetuning \ 107 | --config_name configs/ablations/base_model/roberta.json \ 108 | --lr 5e-5 \ 109 | --task mnli \ 110 | --model_path princeton-nlp/datamux-retrieval-2 \ 111 | --do_train \ 112 | --do_eval 113 | ``` 114 | 115 | Similarly, to run token-level classification tasks like NER, change `run_glue.sh` to `run_ner.sh`: 116 | ``` 117 | sh run_ner.sh \ 118 | -N 2 \ 119 | -d index \ 120 | -m gaussian_hadamard \ 121 | -s finetuning \ 122 | --config_name configs/ablations/base_model/roberta.json \ 123 | --lr 5e-5 \ 124 | --task ner \ 125 | --model_path princeton-nlp/datamux-retrieval-2 \ 126 | --do_train \ 127 | --do_eval 128 | ``` 129 | 130 | #### Baselines 131 | For the non-multiplexed baselines, run the following command: 132 | ``` 133 | sh run_glue.sh \ 134 | -N 1 \ 135 | -s baseline \ 136 | --config_name configs/ablations/base_model/roberta.json \ 137 | --lr 2e-5 \ 138 | --task mnli 139 | ``` 140 | 141 | #### Vision 142 | To reproduce results on the vision tasks for MLPs and CNNs, please use this [notebook](https://github.com/princeton-nlp/DataMUX/blob/main/vision/vision_multiplexing.ipynb). 143 | 144 | ### Reference 145 | ``` 146 | @inproceedings{ 147 | murahari2022datamux, 148 | title={Data{MUX}: Data Multiplexing for Neural Networks}, 149 | author={Vishvak Murahari and Carlos E Jimenez and Runzhe Yang and Karthik R Narasimhan}, 150 | booktitle={Thirty-Sixth Conference on Neural Information Processing Systems}, 151 | year={2022}, 152 | url={https://openreview.net/forum?id=UdgtTVTdswg} 153 | } 154 | ``` 155 | ### License 156 | Check `LICENSE.md` 157 | -------------------------------------------------------------------------------- /configs/ablations/base_model/roberta.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 |
"vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_attention_heads/roberta_2heads.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 2, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_depth/roberta_2layers.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 2, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_depth/roberta_4layers.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 4, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_depth/roberta_6layers.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 6, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_depth/roberta_8layers.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 
12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 8, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_width/roberta_156.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 156, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 624, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_width/roberta_24.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 24, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 96, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_width/roberta_36.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 36, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 144, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_width/roberta_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 384, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 1536, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/ablations/varying_width/roberta_72.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | 
"RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 72, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 288, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /configs/roberta.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "RobertaForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "bos_token_id": 0, 7 | "eos_token_id": 2, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 768, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 3072, 13 | "layer_norm_eps": 1e-05, 14 | "max_position_embeddings": 514, 15 | "model_type": "roberta", 16 | "num_attention_heads": 12, 17 | "num_hidden_layers": 12, 18 | "pad_token_id": 1, 19 | "type_vocab_size": 1, 20 | "vocab_size": 50265 21 | } 22 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: datamux 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h7b6447c_0 10 | - ca-certificates=2021.10.26=h06a4308_2 11 | - certifi=2021.10.8=py39h06a4308_0 12 | - cudatoolkit=11.3.1=h2bc3f7f_2 13 | - ffmpeg=4.3=hf484d3e_0 14 | - freetype=2.11.0=h70c0345_0 15 | - giflib=5.2.1=h7b6447c_0 16 | - gmp=6.2.1=h2531618_2 17 | - gnutls=3.6.15=he1e5248_0 18 | - intel-openmp=2021.4.0=h06a4308_3561 19 | - jpeg=9d=h7f8727e_0 20 | - lame=3.100=h7b6447c_0 21 | - lcms2=2.12=h3be6417_0 22 | - ld_impl_linux-64=2.35.1=h7274673_9 23 | - libffi=3.3=he6710b0_2 24 | - libgcc-ng=9.3.0=h5101ec6_17 25 | - libgomp=9.3.0=h5101ec6_17 26 | - libiconv=1.15=h63c8f33_5 27 | - libidn2=2.3.2=h7f8727e_0 28 | - libpng=1.6.37=hbc83047_0 29 | - libstdcxx-ng=9.3.0=hd4cf53a_17 30 | - libtasn1=4.16.0=h27cfd23_0 31 | - libtiff=4.2.0=h85742a9_0 32 | - libunistring=0.9.10=h27cfd23_0 33 | - libuv=1.40.0=h7b6447c_0 34 | - libwebp=1.2.0=h89dd481_0 35 | - libwebp-base=1.2.0=h27cfd23_0 36 | - lz4-c=1.9.3=h295c915_1 37 | - mkl=2021.4.0=h06a4308_640 38 | - mkl-service=2.4.0=py39h7f8727e_0 39 | - mkl_fft=1.3.1=py39hd3c417c_0 40 | - mkl_random=1.2.2=py39h51133e4_0 41 | - ncurses=6.3=heee7806_1 42 | - nettle=3.7.3=hbbd107a_1 43 | - numpy=1.21.2=py39h20f2e39_0 44 | - numpy-base=1.21.2=py39h79a1101_0 45 | - olefile=0.46=pyhd3eb1b0_0 46 | - openh264=2.1.0=hd408876_0 47 | - openssl=1.1.1l=h7f8727e_0 48 | - pillow=8.4.0=py39h5aabda8_0 49 | - pip=21.2.4=py39h06a4308_0 50 | - python=3.9.7=h12debd9_1 51 | - pytorch=1.10.0=py3.9_cuda11.3_cudnn8.2.0_0 52 | - pytorch-mutex=1.0=cuda 53 | - readline=8.1=h27cfd23_0 54 | - setuptools=58.0.4=py39h06a4308_0 55 | - six=1.16.0=pyhd3eb1b0_0 56 | - sqlite=3.36.0=hc218d9a_0 57 | - tk=8.6.11=h1ccaba5_0 58 | - torchaudio=0.10.0=py39_cu113 59 | - torchvision=0.11.1=py39_cu113 60 | - typing_extensions=3.10.0.2=pyh06a4308_0 61 | - tzdata=2021e=hda174b7_0 62 | - wheel=0.37.0=pyhd3eb1b0_1 63 | - xz=5.2.5=h7b6447c_0 64 | - zlib=1.2.11=h7b6447c_3 65 | - zstd=1.4.9=haebb681_0 66 | - pip: 67 | - absl-py==0.15.0 68 | - 
astunparse==1.6.3 69 | - cachetools==4.2.4 70 | - cffi==1.15.0 71 | - charset-normalizer==2.0.7 72 | - click==8.0.3 73 | - configparser==5.1.0 74 | - datasets==1.5.0 75 | - dill==0.3.4 76 | - docker-pycreds==0.4.0 77 | - filelock==3.3.2 78 | - flatbuffers==2.0 79 | - fsspec==2021.11.0 80 | - gast==0.4.0 81 | - gitdb==4.0.9 82 | - gitpython==3.1.24 83 | - google-auth==2.3.3 84 | - google-auth-oauthlib==0.4.6 85 | - google-pasta==0.2.0 86 | - grpcio==1.41.1 87 | - h5py==3.5.0 88 | - huggingface-hub==0.0.19 89 | - idna==3.3 90 | - joblib==1.1.0 91 | - keras==2.7.0 92 | - keras-preprocessing==1.1.2 93 | - libclang==12.0.0 94 | - markdown==3.3.4 95 | - multiprocess==0.70.12.2 96 | - oauthlib==3.1.1 97 | - opt-einsum==3.3.0 98 | - packaging==21.2 99 | - pandas==1.3.4 100 | - pathtools==0.1.2 101 | - promise==2.3 102 | - protobuf==3.19.1 103 | - psutil==5.8.0 104 | - pyarrow==6.0.0 105 | - pyasn1==0.4.8 106 | - pyasn1-modules==0.2.8 107 | - pycparser==2.21 108 | - pyparsing==2.4.7 109 | - python-dateutil==2.8.2 110 | - pytz==2021.3 111 | - pyyaml==6.0 112 | - regex==2021.11.2 113 | - requests==2.26.0 114 | - requests-oauthlib==1.3.0 115 | - rsa==4.7.2 116 | - sacremoses==0.0.46 117 | - scikit-learn==1.0.1 118 | - scipy==1.7.2 119 | - sentry-sdk==1.4.3 120 | - seqeval==1.2.2 121 | - shortuuid==1.0.4 122 | - sklearn==0.0 123 | - smmap==5.0.0 124 | - subprocess32==3.5.4 125 | - tensorboard==2.7.0 126 | - tensorboard-data-server==0.6.1 127 | - tensorboard-plugin-wit==1.8.0 128 | - tensorflow==2.7.0 129 | - tensorflow-estimator==2.7.0 130 | - tensorflow-io-gcs-filesystem==0.21.0 131 | - termcolor==1.1.0 132 | - threadpoolctl==3.0.0 133 | - tokenizers==0.10.3 134 | - tqdm==4.49.0 135 | - transformers==4.4.2 136 | - urllib3==1.26.7 137 | - wandb==0.12.6 138 | - werkzeug==2.0.2 139 | - wrapt==1.13.3 140 | - xxhash==2.0.2 141 | - yaspin==2.1.0 142 | - nvidia-ml-py3 -------------------------------------------------------------------------------- /images/multiplexing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DataMUX/3e45c0b50070ff11c0d936ae4bc816489d842e12/images/multiplexing.gif -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DataMUX/3e45c0b50070ff11c0d936ae4bc816489d842e12/models/__init__.py -------------------------------------------------------------------------------- /models/multiplexing.py: -------------------------------------------------------------------------------- 1 | from re import X 2 | from dataclasses import dataclass 3 | 4 | from transformers.utils import logging 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import CrossEntropyLoss, MSELoss 8 | from transformers import ( 9 | RobertaModel, 10 | ) 11 | from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput 12 | from transformers.models.roberta.modeling_roberta import ( 13 | RobertaPreTrainedModel, 14 | ) 15 | from typing import Optional, Tuple 16 | from transformers.activations import gelu 17 | import math 18 | import numpy as np 19 | from .utils import ( 20 | random_encoding, 21 | binary_encoding, 22 | ) 23 | from scipy.stats import ortho_group, special_ortho_group 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | ####### SEQUENCE CLASSIFICATION CLASSES 28 | 29 | class 
RobertaSequenceClassificationMuxed(RobertaPreTrainedModel): 30 | _keys_to_ignore_on_load_missing = [r"position_ids"] 31 | 32 | def __init__(self, config): 33 | super().__init__(config) 34 | self.num_labels = config.num_labels 35 | self.config = config 36 | 37 | self.num_instances = config.num_instances 38 | self.muxing_variant = config.muxing_variant 39 | self.demuxing_variant = config.demuxing_variant 40 | self.retrieval_loss_coeff = config.retrieval_loss_coeff 41 | self.task_loss_coeff = config.task_loss_coeff 42 | 43 | self.roberta = RobertaModel(config, add_pooling_layer=False) 44 | if config.demuxing_variant == "index": 45 | self.demultiplexer = RobertaIndexDemultiplexerSequenceClassification(config) 46 | self.retrieval_head = RetrievalHeadIndexDemultiplexing(config) 47 | elif config.demuxing_variant == "mlp": 48 | self.demultiplexer = RobertaMLPDemultiplexerSequenceClassification(config) 49 | self.retrieval_head = RetrievalHeadMLPDemultiplexing(config) 50 | else: 51 | raise NotImplementedError("demuxing_varaint argument (%s) not recognized." % config.demuxing_variant) 52 | 53 | self.init_weights() 54 | 55 | d_model = config.hidden_size 56 | instance_embedding = None 57 | 58 | if self.muxing_variant == "gaussian_hadamard": 59 | instance_embedding = random_encoding( 60 | self.num_instances, d_model, norm=config.gaussian_hadamard_norm 61 | ) 62 | elif self.muxing_variant == "random_ortho": 63 | instance_embedding = [ 64 | torch.from_numpy(ortho_group.rvs(config.hidden_size)).float() 65 | for _ in range(self.num_instances) 66 | ] 67 | instance_embedding = torch.stack(instance_embedding, dim=0) 68 | elif self.muxing_variant == "binary_hadamard": 69 | instance_embedding = binary_encoding( 70 | self.num_instances, d_model, epsilon=config.binary_hadamard_epsilon 71 | ) 72 | else: 73 | raise NotImplementedError("muxing_variant argument (%s) not recognized." 
% config.muxing_variant) 74 | 75 | if instance_embedding is not None: 76 | self.instance_embedding = torch.nn.Parameter(instance_embedding) 77 | else: 78 | instance_embedding = random_encoding( 79 | self.num_instances, d_model, norm=self.gaussian_hadamard_norm 80 | ) 81 | 82 | if not config.learn_muxing: 83 | self.instance_embedding.requires_grad = False 84 | else: 85 | self.instance_embedding.requires_grad = True 86 | 87 | def forward( 88 | self, 89 | input_ids=None, 90 | attention_mask=None, 91 | token_type_ids=None, 92 | position_ids=None, 93 | inputs_embeds=None, 94 | labels=None, 95 | return_dict=None, 96 | ): 97 | return_dict = ( 98 | return_dict if return_dict is not None else self.config.use_return_dict 99 | ) 100 | # get input embeddings and average over N instances 101 | input_shape = input_ids.size() 102 | 103 | batch_size, seq_length = input_shape 104 | num_instances = self.num_instances 105 | past_key_values_length = 0 106 | 107 | modified_batch_size = batch_size // num_instances 108 | modified_seq_length = None 109 | special_tokens_end_position=None 110 | if self.demuxing_variant == "index": 111 | 112 | # add the prefix 113 | # [CLS1, , , , ] 114 | # [, CLS2, , , ] 115 | # [, , CLS3, , ] 116 | # [, , , CLS4, ] 117 | # [, , , , CLS5] 118 | # let us just assume the last 5 tokens barring the masked token 119 | # are the cls tokens (easiest way to make use of existing vocab) 120 | 121 | # prefix 5 x 5 122 | 123 | prefix = torch.full((num_instances, num_instances), 50000, device=input_ids.device) 124 | prefix[ 125 | torch.arange(num_instances, device=input_ids.device), 126 | torch.arange(num_instances, device=input_ids.device) 127 | ] = ( 128 | -(torch.arange(num_instances, device=input_ids.device) + 2) 129 | + self.roberta.embeddings.word_embeddings.weight.shape[0] 130 | ) 131 | 132 | # [-2 , , , ] 133 | # [, -3, , , ] 134 | # [, , -4, , ] 135 | # [, , , -5, ] 136 | # [, , , , -6] 137 | # + size of vocab 138 | cls_tokens = torch.full((num_instances, 1), 49923, device=input_ids.device) 139 | prefix = torch.cat([prefix, cls_tokens], dim=1) 140 | 141 | prefix = prefix.repeat(modified_batch_size, 1) 142 | input_ids = input_ids[: (modified_batch_size * num_instances)] 143 | input_ids = torch.cat([prefix, input_ids], dim=1) 144 | modified_seq_length = seq_length + num_instances + 1 145 | special_tokens_end_position = num_instances + 1 146 | 147 | elif self.demuxing_variant == "mlp": 148 | cls_tokens = torch.full((num_instances, 1), 49923, device=input_ids.device) 149 | cls_tokens = cls_tokens.repeat(modified_batch_size, 1) 150 | # prefix = prefix.repeat(modified_batch_size, 1) 151 | input_ids = input_ids[: (modified_batch_size * num_instances)] 152 | input_ids[:, 0:1] = cls_tokens 153 | modified_seq_length = seq_length 154 | special_tokens_end_position = 1 155 | 156 | else: 157 | raise NotImplementedError("demuxing_variant (%s) not recognized." 
% self.demuxing_variant) 158 | 159 | # concatenate 160 | embedding_output = self.roberta.embeddings( 161 | input_ids=input_ids, 162 | position_ids=position_ids, 163 | token_type_ids=token_type_ids, 164 | inputs_embeds=inputs_embeds, 165 | past_key_values_length=past_key_values_length, 166 | ) 167 | _, _, embedding_dim = embedding_output.shape 168 | if self.muxing_variant == "random_ortho": 169 | embedding_output = embedding_output.view( 170 | modified_batch_size, 171 | num_instances, 172 | modified_seq_length, 173 | embedding_dim, 174 | ) 175 | embedding_output = torch.matmul( 176 | self.instance_embedding, embedding_output.permute(0, 1, 3, 2) 177 | ) 178 | # swap the last 2 dimensions again 179 | embedding_output = embedding_output.permute(0, 1, 3, 2) 180 | # average across the instances 181 | embedding_output = torch.sum(embedding_output, dim=1) / math.sqrt( 182 | self.num_instances 183 | ) 184 | else: 185 | embedding_output = embedding_output.view( 186 | modified_batch_size, 187 | num_instances, 188 | modified_seq_length, 189 | embedding_dim, 190 | ) 191 | 192 | # extract relevant instance embeddings 193 | instance_embed = self.instance_embedding[:num_instances, :] 194 | instance_embed = instance_embed.unsqueeze(1).expand( 195 | num_instances, modified_seq_length, embedding_dim 196 | ) 197 | embedding_output = embedding_output * instance_embed.unsqueeze(0) 198 | 199 | embedding_output = torch.mean(embedding_output, dim=1) 200 | 201 | outputs = self.roberta( 202 | input_ids=None, 203 | attention_mask=None, 204 | token_type_ids=None, 205 | position_ids=position_ids, 206 | inputs_embeds=embedding_output, 207 | return_dict=return_dict, 208 | ) 209 | sequence_output = outputs[0] 210 | # fancy indexing to get the instance position embedding 211 | 212 | logits, demuxed_representations = self.demultiplexer(sequence_output) 213 | if labels is not None: 214 | 215 | labels = labels[: (modified_batch_size * num_instances)] 216 | instance_labels = torch.full( 217 | (modified_batch_size, modified_seq_length), 218 | 0, 219 | device=input_ids.device, 220 | ).long() 221 | # skip the cls and prefix tokens 222 | instance_labels[:, special_tokens_end_position :] = torch.randint( 223 | num_instances, (modified_batch_size, modified_seq_length - special_tokens_end_position), device=input_ids.device) 224 | 225 | # index into input ids to get the corresponding labels 226 | input_ids = input_ids.view(modified_batch_size, num_instances, -1) 227 | input_ids = input_ids.permute(0, 2, 1) 228 | 229 | retrieval_labels = input_ids[ 230 | torch.arange(modified_batch_size, device=input_ids.device) 231 | .unsqueeze(1) 232 | .expand(modified_batch_size, modified_seq_length), 233 | torch.arange(modified_seq_length, device=input_ids.device) 234 | .unsqueeze(0) 235 | .expand(modified_batch_size, modified_seq_length), 236 | instance_labels, 237 | ] 238 | retrieval_labels[:, :special_tokens_end_position] = -100 239 | 240 | pad_mask = retrieval_labels == 1 241 | # wipe of 1 - (0.1 * retrieval percentage) of pad tokens 242 | pad_mask_wipe = pad_mask 243 | non_pad_mask_wipe = ~pad_mask & torch.bernoulli( 244 | torch.full(retrieval_labels.shape, 1 - self.config.retrieval_percentage, device=input_ids.device) 245 | ).bool() 246 | retrieval_labels[non_pad_mask_wipe] = -100 247 | 248 | retrieval_labels[pad_mask_wipe] = -100 249 | 250 | retrieval_predictions = self.retrieval_head(sequence_output, instance_labels) 251 | 252 | retrieval_loss = None 253 | task_loss = None 254 | loss = None 255 | if labels is not None: 256 | loss_fct = 
CrossEntropyLoss() 257 | task_loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 258 | retrieval_loss = loss_fct( 259 | retrieval_predictions.view(-1, self.config.vocab_size), 260 | retrieval_labels.view(-1), 261 | ) 262 | loss = (self.task_loss_coeff * task_loss) + ( 263 | self.retrieval_loss_coeff * retrieval_loss 264 | ) 265 | 266 | if not return_dict: 267 | output = (logits,) + outputs[2:] 268 | return ((loss,) + output) if loss is not None else output 269 | 270 | return SequenceClassifierOutputMuxed( 271 | loss=loss, 272 | logits=logits, 273 | hidden_states=demuxed_representations, 274 | task_loss=task_loss, 275 | retrieval_loss=retrieval_loss, 276 | ) 277 | 278 | ####### TOKEN CLASSIFICATION CLASSES 279 | 280 | class RobertaTokenClassificationMuxed(RobertaPreTrainedModel): 281 | _keys_to_ignore_on_load_missing = [r"position_ids"] 282 | 283 | def __init__(self, config): 284 | super().__init__(config) 285 | self.num_labels = config.num_labels 286 | self.config = config 287 | 288 | self.num_instances = config.num_instances 289 | self.muxing_variant = config.muxing_variant 290 | self.demuxing_variant = config.demuxing_variant 291 | self.retrieval_loss_coeff = config.retrieval_loss_coeff 292 | self.task_loss_coeff = config.task_loss_coeff 293 | 294 | self.roberta = RobertaModel(config, add_pooling_layer=False) 295 | if config.demuxing_variant == "index": 296 | self.demultiplexer = RobertaIndexDemultiplexerTokenClassification(config) 297 | self.retrieval_head = RetrievalHeadIndexDemultiplexing(config) 298 | elif config.demuxing_variant == "mlp": 299 | self.demultiplexer = RobertaMLPDemultiplexerTokenClassification(config) 300 | self.retrieval_head = RetrievalHeadMLPDemultiplexing(config) 301 | else: 302 | raise NotImplementedError("demuxing_variant (%s) not recognized." % config.demuxing_variant) 303 | 304 | self.init_weights() 305 | 306 | d_model = config.hidden_size 307 | instance_embedding = None 308 | 309 | if self.muxing_variant == "gaussian_hadamard": 310 | instance_embedding = random_encoding( 311 | self.num_instances, d_model, norm=config.gaussian_hadamard_norm 312 | ) 313 | elif self.muxing_variant == "random_ortho": 314 | instance_embedding = [ 315 | torch.from_numpy(ortho_group.rvs(config.hidden_size)).float() 316 | for _ in range(self.num_instances) 317 | ] 318 | instance_embedding = torch.stack(instance_embedding, dim=0) 319 | elif self.muxing_variant == "binary_hadamard": 320 | instance_embedding = binary_encoding( 321 | self.num_instances, d_model, epsilon=config.binary_hadamard_epsilon 322 | ) 323 | else: 324 | raise NotImplementedError("muxing_variant (%s) not recognized." 
% config.muxing_variant) 325 | 326 | if instance_embedding is not None: 327 | self.instance_embedding = torch.nn.Parameter(instance_embedding) 328 | else: 329 | instance_embedding = random_encoding( 330 | self.num_instances, d_model, norm=self.gaussian_hadamard_norm 331 | ) 332 | 333 | if not config.learn_muxing: 334 | self.instance_embedding.requires_grad = False 335 | else: 336 | self.instance_embedding.requires_grad = True 337 | 338 | def forward( 339 | self, 340 | input_ids=None, 341 | attention_mask=None, 342 | token_type_ids=None, 343 | position_ids=None, 344 | inputs_embeds=None, 345 | labels=None, 346 | return_dict=None, 347 | ): 348 | 349 | return_dict = ( 350 | return_dict if return_dict is not None else self.config.use_return_dict 351 | ) 352 | # get input embeddings and average over N instances 353 | input_shape = input_ids.size() 354 | 355 | batch_size, seq_length = input_shape 356 | num_instances = self.num_instances 357 | past_key_values_length = 0 358 | 359 | modified_batch_size = batch_size // num_instances 360 | modified_seq_length = None 361 | special_tokens_end_position=None 362 | if self.demuxing_variant == "index": 363 | 364 | # add the prefix 365 | # [CLS1, , , , ] 366 | # [, CLS2, , , ] 367 | # [, , CLS3, , ] 368 | # [, , , CLS4, ] 369 | # [, , , , CLS5] 370 | # let us just assume the last 5 tokens barring the masked token 371 | # are the cls tokens (easiest way to make use of existing vocab) 372 | 373 | # prefix 5 x 5 374 | 375 | prefix = torch.full((num_instances, num_instances), 50000, device=input_ids.device) 376 | prefix[ 377 | torch.arange(num_instances, device=input_ids.device), 378 | torch.arange(num_instances, device=input_ids.device), 379 | ] = ( 380 | -(torch.arange(num_instances, device=input_ids.device) + 2) 381 | + self.roberta.embeddings.word_embeddings.weight.shape[0] 382 | ) 383 | 384 | # [-2 , , , ] 385 | # [, -3, , , ] 386 | # [, , -4, , ] 387 | # [, , , -5, ] 388 | # [, , , , -6] 389 | # + size of vocab 390 | cls_tokens = torch.full((num_instances, 1), 49923, device=input_ids.device) 391 | prefix = torch.cat([prefix, cls_tokens], dim=1) 392 | 393 | prefix = prefix.repeat(modified_batch_size, 1) 394 | input_ids = input_ids[: (modified_batch_size * num_instances)] 395 | input_ids = torch.cat([prefix, input_ids], dim=1) 396 | modified_seq_length = seq_length + num_instances + 1 397 | special_tokens_end_position = num_instances + 1 398 | 399 | elif self.demuxing_variant == "mlp": 400 | cls_tokens = torch.full((num_instances, 1), 49923, device=input_ids.device) 401 | cls_tokens = cls_tokens.repeat(modified_batch_size, 1) 402 | # prefix = prefix.repeat(modified_batch_size, 1) 403 | input_ids = input_ids[: (modified_batch_size * num_instances)] 404 | input_ids[:, 0:1] = cls_tokens 405 | modified_seq_length = seq_length 406 | special_tokens_end_position = 0 407 | 408 | else: 409 | raise NotImplementedError("demuxing_variant (%s) not recognized." 
% self.demuxing_variant) 410 | 411 | # concatenate 412 | embedding_output = self.roberta.embeddings( 413 | input_ids=input_ids, 414 | position_ids=position_ids, 415 | token_type_ids=token_type_ids, 416 | inputs_embeds=inputs_embeds, 417 | past_key_values_length=past_key_values_length, 418 | ) 419 | _, _, embedding_dim = embedding_output.shape 420 | if self.muxing_variant == "random_ortho": 421 | embedding_output = embedding_output.view( 422 | modified_batch_size, 423 | num_instances, 424 | modified_seq_length, 425 | embedding_dim, 426 | ) 427 | embedding_output = torch.matmul( 428 | self.instance_embedding, embedding_output.permute(0, 1, 3, 2) 429 | ) 430 | # swap the last 2 dimensions again 431 | embedding_output = embedding_output.permute(0, 1, 3, 2) 432 | # average across the instances 433 | embedding_output = torch.sum(embedding_output, dim=1) / math.sqrt( 434 | self.num_instances 435 | ) 436 | else: 437 | embedding_output = embedding_output.view( 438 | modified_batch_size, 439 | num_instances, 440 | modified_seq_length, 441 | embedding_dim, 442 | ) 443 | 444 | # extract relevant instance embeddings 445 | instance_embed = self.instance_embedding[:num_instances, :] 446 | instance_embed = instance_embed.unsqueeze(1).expand( 447 | num_instances, modified_seq_length, embedding_dim 448 | ) 449 | embedding_output = embedding_output * instance_embed.unsqueeze(0) 450 | 451 | embedding_output = torch.mean(embedding_output, dim=1) 452 | 453 | outputs = self.roberta( 454 | input_ids=None, 455 | attention_mask=None, 456 | token_type_ids=None, 457 | position_ids=position_ids, 458 | inputs_embeds=embedding_output, 459 | return_dict=return_dict, 460 | ) 461 | sequence_output = outputs[0] 462 | # fancy indexing to get the instance position embedding 463 | 464 | logits, demuxed_representations = self.demultiplexer(sequence_output) 465 | if labels is not None: 466 | # retrieval loss calculation 467 | labels = labels[: (modified_batch_size * num_instances)] 468 | instance_labels = torch.full( 469 | (modified_batch_size, modified_seq_length), 470 | 0, 471 | device=input_ids.device, 472 | ).long() 473 | # skip the cls and prefix tokens 474 | instance_labels[:, special_tokens_end_position :] = torch.randint( 475 | num_instances, (modified_batch_size, modified_seq_length - special_tokens_end_position), device=input_ids.device 476 | ) 477 | 478 | # index into input ids to get the corresponding labels 479 | input_ids = input_ids.view(modified_batch_size, num_instances, -1) 480 | input_ids = input_ids.permute(0, 2, 1) 481 | 482 | retrieval_labels = input_ids[ 483 | torch.arange(modified_batch_size, device=input_ids.device) 484 | .unsqueeze(1) 485 | .expand(modified_batch_size, modified_seq_length), 486 | torch.arange(modified_seq_length, device=input_ids.device) 487 | .unsqueeze(0) 488 | .expand(modified_batch_size, modified_seq_length), 489 | instance_labels, 490 | ] 491 | retrieval_labels[:, :special_tokens_end_position] = -100 492 | 493 | pad_mask = retrieval_labels == 1 494 | # wipe of 1 - (0.1 * retrieval percentage) of pad tokens 495 | pad_mask_wipe = pad_mask 496 | non_pad_mask_wipe = ~pad_mask & torch.bernoulli( 497 | torch.full(retrieval_labels.shape, 1 - self.config.retrieval_percentage, device=input_ids.device) 498 | ).bool() 499 | retrieval_labels[non_pad_mask_wipe] = -100 500 | 501 | retrieval_labels[pad_mask_wipe] = -100 502 | 503 | retrieval_predictions = self.retrieval_head(sequence_output, instance_labels) 504 | 505 | retrieval_loss = None 506 | task_loss = None 507 | loss = None 508 | if 
labels is not None: 509 | if attention_mask is not None: 510 | loss_fct = CrossEntropyLoss() 511 | active_loss = attention_mask.view(-1) == 1 512 | logits = logits[:, special_tokens_end_position:, :] 513 | 514 | active_logits = logits.reshape(-1, self.num_labels) 515 | active_labels = torch.where( 516 | active_loss, 517 | labels.view(-1), 518 | torch.tensor(loss_fct.ignore_index).type_as(labels), 519 | ) 520 | task_loss = loss_fct(active_logits, active_labels) 521 | retrieval_loss = loss_fct( 522 | retrieval_predictions.view(-1, self.config.vocab_size), 523 | retrieval_labels.view(-1), 524 | ) 525 | loss = (self.task_loss_coeff * task_loss) + ( 526 | self.retrieval_loss_coeff * retrieval_loss 527 | ) 528 | 529 | if not return_dict: 530 | output = (logits,) + outputs[2:] 531 | return ((loss,) + output) if loss is not None else output 532 | return TokenClassifierOutputMuxed( 533 | loss=loss, 534 | logits=logits, 535 | hidden_states=demuxed_representations, 536 | task_loss=task_loss, 537 | retrieval_loss=retrieval_loss, 538 | ) 539 | 540 | ####### INDEX DEMUXING CLASSES ######### 541 | class RobertaIndexDemultiplexerSequenceClassification(nn.Module): 542 | """Head for sequence-level classification tasks.""" 543 | 544 | def __init__(self, config): 545 | super().__init__() 546 | self.num_instances = config.num_instances 547 | self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size) 548 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 549 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 550 | self.dense_before_out_proj = nn.Linear(config.hidden_size, config.hidden_size) 551 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 552 | 553 | def forward(self, features, **kwargs): 554 | _, _, _ = features.shape 555 | positional_embeds = features[ 556 | :, : self.num_instances, : 557 | ] # take token (equiv. 
to [CLS]) 558 | positional_embeds = positional_embeds.reshape( 559 | -1, positional_embeds.shape[-1] 560 | ) 561 | # extract the added [CLS] token during inference 562 | x = features[:, self.num_instances, :] 563 | x = x.unsqueeze(1).repeat(1, self.num_instances, 1) 564 | x = x.view(-1, x.shape[-1]) 565 | 566 | x = torch.cat([positional_embeds, x], dim=1) 567 | x = self.dense(x) 568 | x = gelu(x) 569 | demuxed_feat = self.layer_norm(x) 570 | x = self.dense_before_out_proj(demuxed_feat) 571 | x = gelu(x) 572 | x = self.out_proj(x) 573 | return x, demuxed_feat 574 | 575 | class RetrievalHeadIndexDemultiplexing(nn.Module): 576 | 577 | def __init__(self, config): 578 | super().__init__() 579 | self.num_instances = config.num_instances 580 | self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size) 581 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 582 | 583 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size) 584 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 585 | self.decoder.bias = self.bias 586 | 587 | def forward(self, features, instance_labels, **kwargs): 588 | # extract the first representations and concatenate with the right word 589 | batch, seqlength, _ = features.shape 590 | positional_representations = features[:, : self.num_instances, :] 591 | # concatenate features with the instance representations based on instance labels 592 | instance_labels_copy = instance_labels.clone() 593 | instance_labels_copy[instance_labels == -100] = 0 594 | positional_embeds = positional_representations[ 595 | torch.arange(batch, device=features.device).unsqueeze(1).repeat(1, seqlength), 596 | instance_labels_copy, 597 | ] 598 | features = torch.cat([positional_embeds, features], dim=2) 599 | x = self.dense(features) 600 | x = gelu(x) 601 | x = self.layer_norm(x) 602 | 603 | # project back to size of vocabulary with bias 604 | x = self.decoder(x) 605 | 606 | return x 607 | 608 | def _tie_weights(self): 609 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized) 610 | self.bias = self.decoder.bias 611 | 612 | class RobertaIndexDemultiplexerTokenClassification(nn.Module): 613 | """Roberta Head for masked language modeling.""" 614 | 615 | def __init__(self, config): 616 | super().__init__() 617 | self.num_instances = config.num_instances 618 | self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size) 619 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 620 | 621 | self.decoder = nn.Linear(config.hidden_size, config.num_labels) 622 | self.bias = nn.Parameter(torch.zeros(config.num_labels)) 623 | self.decoder.bias = self.bias 624 | 625 | def forward(self, features, **kwargs): 626 | 627 | # extract the first representations and concatenate with the right word 628 | batch, seqlength, feature_dim = features.shape 629 | positional_representations = features[:, : self.num_instances, :] 630 | # concatenate features with the sentence representations based on sentence_labels 631 | # don't overwrite sentence labels !! 
632 | 633 | # need to expand the batch to the original size, need to make predictions 634 | # on the original 635 | positional_representations = positional_representations.unsqueeze(2).expand( 636 | batch, self.num_instances, seqlength, feature_dim 637 | ) 638 | features = features.unsqueeze(1).expand( 639 | batch, self.num_instances, seqlength, feature_dim 640 | ) 641 | features = torch.cat([positional_representations, features], dim=3) 642 | # increase the batch size by collapsing the first 2 dimensions 643 | features = features.view(-1, seqlength, 2 * feature_dim) 644 | x = self.dense(features) 645 | x = gelu(x) 646 | demuxed_feat = self.layer_norm(x) 647 | x = self.decoder(demuxed_feat) 648 | return x, demuxed_feat 649 | 650 | def _tie_weights(self): 651 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized) 652 | self.bias = self.decoder.bias 653 | 654 | ####### MLP DEMUXING CLASSES ######### 655 | 656 | class RobertaMLPDemuxModule(nn.Module): 657 | def __init__(self, config): 658 | super().__init__() 659 | self.num_instances = config.num_instances 660 | # initialize different MLPs for different instances 661 | for sent_id in range(self.num_instances): 662 | setattr( 663 | self, 664 | f"dense_{sent_id}", 665 | nn.Linear(config.hidden_size, config.hidden_size), 666 | ) 667 | 668 | setattr( 669 | self, 670 | f"layer_norm_{sent_id}", 671 | nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), 672 | ) 673 | 674 | setattr(self, f"dropout_{sent_id}", nn.Dropout(config.hidden_dropout_prob)) 675 | 676 | class RobertaMLPDemultiplexerSequenceClassification(nn.Module): 677 | """Head for sequence-level classification tasks.""" 678 | 679 | def __init__(self, config): 680 | super().__init__() 681 | self.num_instances = config.num_instances 682 | self.demux_module = RobertaMLPDemuxModule(config) 683 | self.dense_before_out_proj = nn.Linear(config.hidden_size, config.hidden_size) 684 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 685 | self.layernorm_presoftmax = nn.LayerNorm( 686 | config.hidden_size, eps=config.layer_norm_eps 687 | ) 688 | 689 | def forward(self, features): 690 | # extract the first representations and concatenate with the right word 691 | _, _, _ = features.shape 692 | all_feats = [] 693 | for sent_id in range(self.num_instances): 694 | cur_dense1 = getattr(self.demux_module, f"dense_{sent_id}") 695 | cur_layer_norm = getattr(self.demux_module, f"layer_norm_{sent_id}") 696 | dropout = getattr(self.demux_module, f"dropout_{sent_id}") 697 | 698 | cls_feat = features[:, 0, :] 699 | x = dropout(cls_feat) 700 | x = cur_dense1(x) 701 | x = gelu(x) 702 | x = cur_layer_norm(x) 703 | 704 | all_feats.append(x) 705 | 706 | all_feats = torch.stack(all_feats, dim=1) 707 | demuxed_representations = all_feats.view(-1, all_feats.shape[-1]) 708 | x = self.dense_before_out_proj(demuxed_representations) 709 | x = gelu(x) 710 | x = self.layernorm_presoftmax(x) 711 | x = self.out_proj(x) 712 | 713 | return x, demuxed_representations 714 | 715 | class RetrievalHeadMLPDemultiplexing(nn.Module): 716 | 717 | def __init__(self, config): 718 | super().__init__() 719 | self.num_instances = config.num_instances 720 | # initialize different MLPs for different instances 721 | self.demux_module = RobertaMLPDemuxModule(config) 722 | 723 | # shared vocab layers across different instances 724 | self.dense_pre_vocab = nn.Linear(config.hidden_size, config.hidden_size) 725 | self.layer_norm_pre_vocab = nn.LayerNorm( 726 | config.hidden_size, 
eps=config.layer_norm_eps 727 | ) 728 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size) 729 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 730 | self.decoder.bias = self.bias 731 | 732 | def forward(self, features, instance_labels, **kwargs): 733 | # extract the first representations and concatenate with the right word 734 | batch, seqlength, _ = features.shape 735 | all_feats = torch.zeros_like(features) 736 | all_feats = all_feats.view(-1, features.shape[-1]) 737 | 738 | for sent_id in range(self.num_instances): 739 | cur_dense1 = getattr(self.demux_module, f"dense_{sent_id}") 740 | cur_layer_norm = getattr(self.demux_module, f"layer_norm_{sent_id}") 741 | dropout = getattr(self.demux_module, f"dropout_{sent_id}") 742 | 743 | cur_sent_mask = instance_labels == sent_id 744 | cur_sent_feats = features[cur_sent_mask] 745 | 746 | x = dropout(cur_sent_feats) 747 | x = cur_dense1(x) 748 | x = gelu(x) 749 | x = cur_layer_norm(x) 750 | 751 | all_feats[cur_sent_mask.view(-1), :] = x 752 | 753 | # reshape into B x L x V 754 | all_feats = all_feats.view(batch, seqlength, -1) 755 | # project back to size of vocabulary with bias 756 | x = self.dense_pre_vocab(all_feats) 757 | x = gelu(x) 758 | x = self.layer_norm_pre_vocab(x) 759 | x = self.decoder(x) 760 | 761 | return x 762 | 763 | def _tie_weights(self): 764 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized) 765 | self.bias = self.decoder.bias 766 | 767 | class RobertaMLPDemultiplexerTokenClassification(nn.Module): 768 | """Head for sentence-level classification tasks.""" 769 | 770 | def __init__(self, config): 771 | super().__init__() 772 | self.num_instances = config.num_instances 773 | # initialize different MLPs for different sentences 774 | self.demux_module = RobertaMLPDemuxModule(config) 775 | # shared vocab layers across different sentences 776 | self.dense_before_out_proj = nn.Linear(config.hidden_size, config.hidden_size) 777 | self.out_proj = nn.Linear(config.hidden_size, config.num_labels) 778 | self.layernorm_presoftmax = nn.LayerNorm( 779 | config.hidden_size, eps=config.layer_norm_eps 780 | ) 781 | 782 | def forward(self, features): 783 | # extract the first representations and concatenate with the right word 784 | _, seq_length, feature_dim = features.shape 785 | all_feats = [] 786 | for sent_id in range(self.num_instances): 787 | cur_dense1 = getattr(self.demux_module, f"dense_{sent_id}") 788 | cur_layer_norm = getattr(self.demux_module, f"layer_norm_{sent_id}") 789 | dropout = getattr(self.demux_module, f"dropout_{sent_id}") 790 | inp_feat = features 791 | x = dropout(inp_feat) 792 | x = cur_dense1(x) 793 | x = gelu(x) 794 | x = cur_layer_norm(x) 795 | all_feats.append(x.unsqueeze(1)) 796 | 797 | # B x L x dim 798 | # stack to get B x N X L X dim 799 | all_feats = torch.cat(all_feats, dim=1) 800 | # collapse the first 2 dimensions 801 | demuxed_representations = all_feats.view(-1, seq_length, feature_dim) 802 | 803 | x = self.dense_before_out_proj(demuxed_representations) 804 | x = gelu(x) 805 | x = self.layernorm_presoftmax(x) 806 | x = self.out_proj(x) 807 | 808 | return x, demuxed_representations 809 | 810 | ###### DATA CLASSES ####### 811 | 812 | @dataclass 813 | class SequenceClassifierOutputMuxed(SequenceClassifierOutput): 814 | loss: Optional[torch.FloatTensor] = None 815 | logits: torch.FloatTensor = None 816 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 817 | attentions: Optional[Tuple[torch.FloatTensor]] = None 818 | task_loss: 
Optional[torch.FloatTensor] = None 819 | retrieval_loss: Optional[torch.FloatTensor] = None 820 | retrieval_predictions: Optional[torch.FloatTensor] = None 821 | retrieval_instance_labels: Optional[torch.FloatTensor] = None 822 | 823 | @dataclass 824 | class TokenClassifierOutputMuxed(ModelOutput): 825 | 826 | loss: Optional[torch.FloatTensor] = None 827 | logits: torch.FloatTensor = None 828 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 829 | attentions: Optional[Tuple[torch.FloatTensor]] = None 830 | task_loss: Optional[torch.FloatTensor] = None 831 | retrieval_loss: Optional[torch.FloatTensor] = None 832 | -------------------------------------------------------------------------------- /models/trainer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import gc 3 | import inspect 4 | import math 5 | from multiprocessing.spawn import import_main_path 6 | import os 7 | import re 8 | import shutil 9 | import sys 10 | import time 11 | import warnings 12 | from logging import StreamHandler 13 | from pathlib import Path 14 | from typing import ( 15 | TYPE_CHECKING, 16 | Any, 17 | Callable, 18 | Counter, 19 | Dict, 20 | List, 21 | Optional, 22 | Tuple, 23 | Union, 24 | ) 25 | from torch import optim 26 | import nvidia_smi 27 | 28 | # Integrations must be imported before ML frameworks: 29 | from transformers.integrations import ( # isort: split 30 | default_hp_search_backend, 31 | get_reporting_integration_callbacks, 32 | hp_params, 33 | is_fairscale_available, 34 | is_optuna_available, 35 | is_ray_tune_available, 36 | run_hp_search_optuna, 37 | run_hp_search_ray, 38 | init_deepspeed, 39 | ) 40 | 41 | import numpy as np 42 | import pandas as pd 43 | import torch 44 | from packaging import version 45 | from torch import nn 46 | from torch.utils.data.dataloader import DataLoader 47 | from torch.utils.data.dataset import Dataset 48 | from torch.utils.data.distributed import DistributedSampler 49 | from torch.utils.data.sampler import RandomSampler, SequentialSampler 50 | from tqdm import tqdm 51 | from transformers.data.data_collator import ( 52 | DataCollator, 53 | DataCollatorWithPadding, 54 | default_data_collator, 55 | ) 56 | from transformers.file_utils import ( 57 | WEIGHTS_NAME, 58 | is_apex_available, 59 | is_datasets_available, 60 | is_in_notebook, 61 | is_sagemaker_distributed_available, 62 | is_torch_tpu_available, 63 | is_training_run_on_sagemaker, 64 | ) 65 | from transformers.modeling_utils import PreTrainedModel, unwrap_model 66 | from transformers.optimization import Adafactor, AdamW, get_scheduler 67 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 68 | from transformers.trainer_callback import ( 69 | CallbackHandler, 70 | DefaultFlowCallback, 71 | PrinterCallback, 72 | ProgressCallback, 73 | TrainerCallback, 74 | TrainerControl, 75 | TrainerState, 76 | ) 77 | from transformers.trainer_pt_utils import ( 78 | DistributedLengthGroupedSampler, 79 | DistributedSamplerWithLoop, 80 | DistributedTensorGatherer, 81 | LabelSmoother, 82 | LengthGroupedSampler, 83 | SequentialDistributedSampler, 84 | distributed_broadcast_scalars, 85 | distributed_concat, 86 | get_parameter_names, 87 | nested_concat, 88 | nested_detach, 89 | nested_numpify, 90 | nested_xla_mesh_reduce, 91 | reissue_pt_warnings, 92 | ) 93 | from transformers import Trainer 94 | from transformers.trainer_utils import ( 95 | PREFIX_CHECKPOINT_DIR, 96 | BestRun, 97 | EvalPrediction, 98 | HPSearchBackend, 99 | PredictionOutput, 100 
| ShardedDDPOption, 101 | TrainerMemoryTracker, 102 | TrainOutput, 103 | default_compute_objective, 104 | default_hp_space, 105 | denumpify_detensorize, 106 | get_last_checkpoint, 107 | set_seed, 108 | speed_metrics, 109 | ) 110 | from transformers.training_args import ParallelMode, TrainingArguments 111 | from transformers.utils import logging 112 | from transformers.utils.modeling_auto_mapping import ( 113 | MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, 114 | ) 115 | from transformers import Trainer 116 | from transformers.integrations import WandbCallback, rewrite_logs 117 | import wandb 118 | from sklearn.decomposition import PCA 119 | from sklearn.manifold import TSNE 120 | import matplotlib.pyplot as plt 121 | import seaborn as sns 122 | 123 | _is_native_amp_available = False 124 | 125 | DEFAULT_CALLBACKS = [DefaultFlowCallback] 126 | DEFAULT_PROGRESS_CALLBACK = ProgressCallback 127 | 128 | if is_in_notebook(): 129 | from transformers.utils.notebook import NotebookProgressCallback 130 | 131 | DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback 132 | 133 | if is_apex_available(): 134 | from apex import amp 135 | 136 | if version.parse(torch.__version__) >= version.parse("1.6"): 137 | _is_native_amp_available = True 138 | from torch.cuda.amp import autocast 139 | 140 | if is_datasets_available(): 141 | import datasets 142 | 143 | if is_torch_tpu_available(): 144 | import torch_xla.core.xla_model as xm 145 | import torch_xla.debug.metrics as met 146 | import torch_xla.distributed.parallel_loader as pl 147 | 148 | if is_fairscale_available(): 149 | import fairscale 150 | from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP 151 | from fairscale.optim import OSS 152 | from fairscale.optim.grad_scaler import ShardedGradScaler 153 | 154 | if version.parse(fairscale.__version__) >= version.parse("0.3"): 155 | from fairscale.nn.data_parallel import ( 156 | FullyShardedDataParallel as FullyShardedDDP, 157 | ) 158 | from fairscale.nn.wrap import auto_wrap 159 | else: 160 | FullyShardedDDP = None 161 | 162 | if is_sagemaker_distributed_available(): 163 | import smdistributed.dataparallel.torch.distributed as dist 164 | from smdistributed.dataparallel.torch.parallel.distributed import ( 165 | DistributedDataParallel as DDP, 166 | ) 167 | else: 168 | import torch.distributed as dist 169 | 170 | if is_training_run_on_sagemaker(): 171 | logging.add_handler(StreamHandler(sys.stdout)) 172 | 173 | 174 | if TYPE_CHECKING: 175 | import optuna 176 | 177 | logger = logging.get_logger(__name__) 178 | 179 | 180 | class WandbCallbackThreadFix(WandbCallback): 181 | def setup(self, args, state, model, reinit, **kwargs): 182 | """ 183 | Setup the optional Weights & Biases (`wandb`) integration. 184 | 185 | One can subclass and override this method to customize the setup if needed. Find more information `here 186 | `__. You can also override the following environment variables: 187 | 188 | Environment: 189 | WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`): 190 | Whether or not to log model as artifact at the end of training. 191 | WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`): 192 | Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient 193 | logging or :obj:`"all"` to log gradients and parameters. 194 | WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`): 195 | Set this to a custom string to store results in a different project. 
196 | WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`): 197 | Whether or not to disable wandb entirely. Set `WANDB_DISABLED=true` to disable. 198 | """ 199 | if self._wandb is None: 200 | return 201 | self._initialized = True 202 | if state.is_world_process_zero: 203 | logger.info( 204 | 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' 205 | ) 206 | combined_dict = {**args.to_sanitized_dict()} 207 | 208 | if hasattr(model, "config") and model.config is not None: 209 | model_config = model.config.to_dict() 210 | combined_dict = {**model_config, **combined_dict} 211 | trial_name = state.trial_name 212 | init_args = {} 213 | if trial_name is not None: 214 | run_name = trial_name 215 | init_args["group"] = args.run_name 216 | else: 217 | run_name = args.run_name 218 | init_args["settings"] = wandb.Settings(start_method="fork") 219 | self._wandb.init( 220 | project=os.getenv("WANDB_PROJECT", "huggingface"), 221 | config=combined_dict, 222 | name=run_name, 223 | reinit=reinit, 224 | **init_args, 225 | ) 226 | 227 | # keep track of model topology and gradients, unsupported on TPU 228 | if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": 229 | self._wandb.watch( 230 | model, 231 | log=os.getenv("WANDB_WATCH", "gradients"), 232 | log_freq=max(100, args.logging_steps), 233 | ) 234 | 235 | def on_log(self, args, state, control, model=None, logs=None, **kwargs): 236 | if self._wandb is None: 237 | return 238 | if not self._initialized: 239 | self.setup(args, state, model, reinit=False) 240 | 241 | is_table = len(logs) == 1 242 | 243 | if state.is_world_process_zero: 244 | if is_table: 245 | self._wandb.log(logs) 246 | else: 247 | use_global_step = logs.pop("use_global_step", True) 248 | logs = rewrite_logs(logs) 249 | 250 | if use_global_step: 251 | self._wandb.log(logs, step=state.global_step) 252 | else: 253 | self._wandb.log(logs) 254 | 255 | class MuxTrainer(Trainer): 256 | def __init__( 257 | self, 258 | model: Union[PreTrainedModel, torch.nn.Module] = None, 259 | args: TrainingArguments = None, 260 | data_collator: Optional[DataCollator] = None, 261 | train_dataset: Optional[Dataset] = None, 262 | eval_dataset: Optional[Dataset] = None, 263 | tokenizer: Optional["PreTrainedTokenizerBase"] = None, 264 | model_init: Callable[[], PreTrainedModel] = None, 265 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 266 | callbacks: Optional[List[TrainerCallback]] = None, 267 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 268 | None, 269 | None, 270 | ), 271 | ): 272 | if args is None: 273 | output_dir = "tmp_trainer" 274 | logger.info( 275 | f"No `TrainingArguments` passed, using `output_dir={output_dir}`." 
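# --- Editor's note (illustrative sketch, not part of this file) ---------------
# WandbCallbackThreadFix above differs from the stock WandbCallback mainly in
# pinning wandb's start method to "fork" (the thread issue the class name refers
# to) and in letting on_log honour an optional "use_global_step" key, so that
# evaluation tables can be logged without being tied to trainer.global_step.
# The on_log dispatch, stripped of the wandb/rewrite_logs calls, is roughly:
def dispatch_log(logs: dict, global_step: int):
    """Return (payload, wandb step) the way on_log above decides to emit them."""
    if len(logs) == 1:                             # a lone wandb.Table entry is logged as-is, no step
        return logs, None
    use_global_step = logs.pop("use_global_step", True)
    return logs, (global_step if use_global_step else None)

payload, step = dispatch_log({"loss": 0.7, "use_global_step": False}, global_step=120)
print(payload, step)  # {'loss': 0.7} None
# ------------------------------------------------------------------------------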
276 | ) 277 | args = TrainingArguments(output_dir=output_dir) 278 | self.args = args 279 | # Seed must be set before instantiating the model when using model 280 | set_seed(self.args.seed) 281 | self.hp_name = None 282 | self.deepspeed = None 283 | self.is_in_train = False 284 | 285 | # memory metrics - must set up as early as possible 286 | self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) 287 | self._memory_tracker.start() 288 | 289 | # force device and distributed setup init explicitly 290 | args._setup_devices 291 | 292 | if model is None: 293 | if model_init is not None: 294 | self.model_init = model_init 295 | model = self.call_model_init() 296 | else: 297 | raise RuntimeError( 298 | "`Trainer` requires either a `model` or `model_init` argument" 299 | ) 300 | else: 301 | if model_init is not None: 302 | warnings.warn( 303 | "`Trainer` requires either a `model` or `model_init` argument, but not both. " 304 | "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.", 305 | FutureWarning, 306 | ) 307 | self.model_init = model_init 308 | 309 | if ( 310 | hasattr(model, "is_parallelizable") 311 | and model.is_parallelizable 312 | and model.model_parallel 313 | ): 314 | self.is_model_parallel = True 315 | else: 316 | self.is_model_parallel = False 317 | 318 | # Setup Sharded DDP training 319 | self.sharded_ddp = None 320 | if len(args.sharded_ddp) > 0: 321 | if args.deepspeed: 322 | raise ValueError( 323 | "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." 324 | ) 325 | 326 | if args.local_rank == -1: 327 | raise ValueError( 328 | "Using sharded DDP only works in distributed training." 329 | ) 330 | elif not is_fairscale_available(): 331 | raise ImportError( 332 | "Sharded DDP training requires fairscale: `pip install fairscale`." 333 | ) 334 | elif ( 335 | ShardedDDPOption.SIMPLE not in args.sharded_ddp 336 | and FullyShardedDDP is None 337 | ): 338 | raise ImportError( 339 | "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " 340 | f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." 341 | ) 342 | elif ShardedDDPOption.SIMPLE in args.sharded_ddp: 343 | self.sharded_ddp = ShardedDDPOption.SIMPLE 344 | elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: 345 | self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 346 | elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: 347 | self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 348 | 349 | # one place to sort out whether to place the model on device or not 350 | self.place_model_on_device = args.place_model_on_device 351 | if ( 352 | self.is_model_parallel 353 | or (args.deepspeed and args.do_train) 354 | or (args.fp16_full_eval and not args.do_train) 355 | or ( 356 | self.sharded_ddp 357 | in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3] 358 | ) 359 | ): 360 | self.place_model_on_device = False 361 | 362 | default_collator = ( 363 | default_data_collator 364 | if tokenizer is None 365 | else DataCollatorWithPadding(tokenizer) 366 | ) 367 | self.data_collator = ( 368 | data_collator if data_collator is not None else default_collator 369 | ) 370 | self.train_dataset = train_dataset 371 | self.eval_dataset = eval_dataset 372 | self.tokenizer = tokenizer 373 | 374 | # postpone switching model to cuda when: 375 | # 1. MP - since we are trying to fit a much bigger than 1 gpu model 376 | # 2. 
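# --- Editor's note (illustrative sketch, not part of this file) ---------------
# The place_model_on_device logic above collapses to a single predicate. The
# args fields below are the same TrainingArguments fields used above; the
# string values standing in for ShardedDDPOption members are illustrative only.
def should_place_model_on_device(args, is_model_parallel: bool, sharded_ddp) -> bool:
    zero_dp = sharded_ddp in ("zero_dp_2", "zero_dp_3")     # stand-ins for ShardedDDPOption members
    keep_off_device = (
        is_model_parallel                                    # model parallelism manages devices itself
        or (args.deepspeed and args.do_train)                # DeepSpeed moves/partitions the model
        or (args.fp16_full_eval and not args.do_train)       # full-fp16 eval calls .half().to() later
        or zero_dp                                           # fairscale ZeRO-DP shards parameters
    )
    return bool(args.place_model_on_device and not keep_off_device)
# ------------------------------------------------------------------------------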
fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, 377 | # and we only use deepspeed for training at the moment 378 | if self.place_model_on_device: 379 | model = model.to(args.device) 380 | 381 | # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs 382 | if self.is_model_parallel: 383 | self.args._n_gpu = 1 384 | 385 | # later use `self.model is self.model_wrapped` to check if it's wrapped or not 386 | self.model_wrapped = model 387 | self.model = model 388 | 389 | self.compute_metrics = compute_metrics 390 | self.optimizer, self.lr_scheduler = optimizers 391 | if model_init is not None and ( 392 | self.optimizer is not None or self.lr_scheduler is not None 393 | ): 394 | raise RuntimeError( 395 | "Passing a `model_init` is incompatible with providing the `optimizers` argument." 396 | "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." 397 | ) 398 | # default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks( 399 | # self.args.report_to 400 | # ) 401 | default_callbacks = DEFAULT_CALLBACKS + [WandbCallbackThreadFix] 402 | callbacks = ( 403 | default_callbacks if callbacks is None else default_callbacks + callbacks 404 | ) 405 | self.callback_handler = CallbackHandler( 406 | callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler 407 | ) 408 | self.add_callback( 409 | PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK 410 | ) 411 | 412 | # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. 413 | self._loggers_initialized = False 414 | 415 | # Create output directory if needed 416 | if self.is_world_process_zero(): 417 | os.makedirs(self.args.output_dir, exist_ok=True) 418 | if not callable(self.data_collator) and callable( 419 | getattr(self.data_collator, "collate_batch", None) 420 | ): 421 | raise ValueError( 422 | "The `data_collator` should be a simple callable (function, class with `__call__`)." 
423 | ) 424 | 425 | if args.max_steps > 0: 426 | logger.info( 427 | "max_steps is given, it will override any value given in num_train_epochs" 428 | ) 429 | 430 | # Enforce rules on using datasets with no __len__ 431 | if ( 432 | train_dataset is not None 433 | and not isinstance(train_dataset, collections.abc.Sized) 434 | and args.max_steps <= 0 435 | ): 436 | raise ValueError( 437 | "train_dataset does not implement __len__, max_steps has to be specified" 438 | ) 439 | if eval_dataset is not None and not isinstance( 440 | eval_dataset, collections.abc.Sized 441 | ): 442 | raise ValueError("eval_dataset must implement __len__") 443 | 444 | self._signature_columns = None 445 | if is_datasets_available(): 446 | if isinstance(train_dataset, datasets.Dataset): 447 | self._remove_unused_columns(self.train_dataset, description="training") 448 | if isinstance(eval_dataset, datasets.Dataset): 449 | self._remove_unused_columns(self.eval_dataset, description="evaluation") 450 | 451 | # Mixed precision setup 452 | self.use_apex = False 453 | self.use_amp = False 454 | self.fp16_backend = None 455 | 456 | if args.fp16: 457 | if args.fp16_backend == "auto": 458 | self.fp16_backend = "amp" if _is_native_amp_available else "apex" 459 | else: 460 | self.fp16_backend = args.fp16_backend 461 | logger.info(f"Using {self.fp16_backend} fp16 backend") 462 | 463 | if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16 464 | if self.fp16_backend == "amp": 465 | self.use_amp = True 466 | self.scaler = ( 467 | ShardedGradScaler() 468 | if self.sharded_ddp is not None 469 | else torch.cuda.amp.GradScaler() 470 | ) 471 | else: 472 | if not is_apex_available(): 473 | raise ImportError( 474 | "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex." 475 | ) 476 | self.use_apex = True 477 | 478 | # Label smoothing 479 | if self.args.label_smoothing_factor != 0: 480 | self.label_smoother = LabelSmoother( 481 | epsilon=self.args.label_smoothing_factor 482 | ) 483 | else: 484 | self.label_smoother = None 485 | 486 | self.state = TrainerState() 487 | self.control = TrainerControl() 488 | # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the 489 | # state at each call to self.log. 490 | self._total_flos = None 491 | self.hp_search_backend = None 492 | self.use_tune_checkpoints = False 493 | default_label_names = ( 494 | ["start_positions", "end_positions"] 495 | if type(self.model).__name__ 496 | in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values() 497 | else ["labels"] 498 | ) 499 | self.label_names = ( 500 | default_label_names 501 | if self.args.label_names is None 502 | else self.args.label_names 503 | ) 504 | self.control = self.callback_handler.on_init_end( 505 | self.args, self.state, self.control 506 | ) 507 | 508 | # very last 509 | self._memory_tracker.stop_and_update_metrics() 510 | 511 | def train( 512 | self, 513 | resume_from_checkpoint: Optional[Union[str, bool]] = None, 514 | trial: Union["optuna.Trial", Dict[str, Any]] = None, 515 | **kwargs, 516 | ): 517 | """ 518 | Main training entry point. 519 | 520 | Args: 521 | resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`): 522 | If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of 523 | :class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in 524 | `args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. 
If present, 525 | training will resume from the model/optimizer/scheduler states loaded here. 526 | trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`): 527 | The trial run or the hyperparameter dictionary for hyperparameter search. 528 | kwargs: 529 | Additional keyword arguments used to hide deprecated arguments 530 | """ 531 | 532 | # memory metrics - must set up as early as possible 533 | self._memory_tracker.start() 534 | 535 | self.is_in_train = True 536 | 537 | if "model_path" in kwargs: 538 | resume_from_checkpoint = kwargs.pop("model_path") 539 | warnings.warn( 540 | "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " 541 | "instead.", 542 | FutureWarning, 543 | ) 544 | if len(kwargs) > 0: 545 | raise TypeError( 546 | f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}." 547 | ) 548 | # This might change the seed so needs to run first. 549 | self._hp_search_setup(trial) 550 | 551 | # Model re-init 552 | model_reloaded = False 553 | if self.model_init is not None: 554 | # Seed must be set before instantiating the model when using model_init. 555 | set_seed(self.args.seed) 556 | self.model = self.call_model_init(trial) 557 | model_reloaded = True 558 | # Reinitializes optimizer and scheduler 559 | self.optimizer, self.lr_scheduler = None, None 560 | 561 | # Load potential model checkpoint 562 | if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: 563 | resume_from_checkpoint = get_last_checkpoint(self.args.output_dir) 564 | if resume_from_checkpoint is None: 565 | raise ValueError( 566 | f"No valid checkpoint found in output directory ({self.args.output_dir})" 567 | ) 568 | 569 | if resume_from_checkpoint is not None and os.path.isfile( 570 | os.path.join(resume_from_checkpoint, WEIGHTS_NAME) 571 | ): 572 | logger.info(f"Loading model from {resume_from_checkpoint}).") 573 | if isinstance(self.model, PreTrainedModel): 574 | self.model = self.model.from_pretrained(resume_from_checkpoint) 575 | model_reloaded = True 576 | else: 577 | state_dict = torch.load( 578 | os.path.join(resume_from_checkpoint, WEIGHTS_NAME) 579 | ) 580 | self.model.load_state_dict(state_dict) 581 | 582 | # If model was re-initialized, put it on the right device and update self.model_wrapped 583 | if model_reloaded: 584 | if self.place_model_on_device: 585 | self.model = self.model.to(self.args.device) 586 | self.model_wrapped = self.model 587 | 588 | # Keeping track whether we can can len() on the dataset or not 589 | train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized) 590 | 591 | # Data loader and number of training steps 592 | train_dataloader = self.get_train_dataloader() 593 | 594 | # Setting up training control variables: 595 | # number of training epochs: num_train_epochs 596 | # number of training steps per epoch: num_update_steps_per_epoch 597 | # total number of training steps to execute: max_steps 598 | if train_dataset_is_sized: 599 | num_update_steps_per_epoch = ( 600 | len(train_dataloader) // self.args.gradient_accumulation_steps 601 | ) 602 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) 603 | if self.args.max_steps > 0: 604 | max_steps = self.args.max_steps 605 | num_train_epochs = ( 606 | self.args.max_steps // num_update_steps_per_epoch 607 | + int(self.args.max_steps % num_update_steps_per_epoch > 0) 608 | ) 609 | else: 610 | max_steps = math.ceil( 611 | self.args.num_train_epochs * num_update_steps_per_epoch 612 | ) 613 | 
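# --- Editor's note (illustrative worked example, not part of this file) -------
# Concrete numbers for the step bookkeeping here: with 10,000 multiplexed
# batches per epoch, gradient_accumulation_steps=4 and num_train_epochs=3
# (and args.max_steps <= 0), the quantities derived above/below come out as:
import math

len_dataloader, grad_accum, num_epochs = 10_000, 4, 3.0
num_update_steps_per_epoch = max(len_dataloader // grad_accum, 1)        # 2500 optimizer steps per epoch
max_steps = math.ceil(num_epochs * num_update_steps_per_epoch)           # 7500 optimizer steps total
print(num_update_steps_per_epoch, max_steps)  # 2500 7500
# If args.max_steps were set to 6000 instead, num_train_epochs would come out as
# 6000 // 2500 + int(6000 % 2500 > 0) = 2 + 1 = 3, i.e. training stops inside epoch 3.
# ------------------------------------------------------------------------------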
num_train_epochs = math.ceil(self.args.num_train_epochs) 614 | else: 615 | # see __init__. max_steps is set when the dataset has no __len__ 616 | max_steps = self.args.max_steps 617 | num_train_epochs = 1 618 | num_update_steps_per_epoch = max_steps 619 | 620 | delay_optimizer_creation = ( 621 | self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE 622 | ) 623 | if self.args.deepspeed: 624 | model, optimizer, lr_scheduler = init_deepspeed( 625 | self, num_training_steps=max_steps 626 | ) 627 | self.model = model.module 628 | self.model_wrapped = model # will get further wrapped in DDP 629 | self.deepspeed = model # DeepSpeedEngine object 630 | self.optimizer = optimizer 631 | self.lr_scheduler = lr_scheduler 632 | elif not delay_optimizer_creation: 633 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 634 | 635 | self.state = TrainerState() 636 | self.state.is_hyper_param_search = trial is not None 637 | 638 | model = self._wrap_model(self.model_wrapped) 639 | 640 | # for the rest of this function `model` is the outside model, whether it was wrapped or not 641 | if model is not self.model: 642 | self.model_wrapped = model 643 | 644 | if delay_optimizer_creation: 645 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 646 | 647 | # Check if saved optimizer or scheduler states exist 648 | self._load_optimizer_and_scheduler(resume_from_checkpoint) 649 | 650 | # important: at this point: 651 | # self.model is the Transformers Model 652 | # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. 653 | 654 | # Train! 655 | if is_torch_tpu_available(): 656 | world_size = xm.xrt_world_size() 657 | elif self.args.local_rank != -1: 658 | world_size = dist.get_world_size() 659 | else: 660 | world_size = 1 661 | 662 | total_train_batch_size = ( 663 | self.args.train_batch_size 664 | * self.args.gradient_accumulation_steps 665 | * world_size 666 | ) 667 | num_examples = ( 668 | self.num_examples(train_dataloader) 669 | if train_dataset_is_sized 670 | else total_train_batch_size * self.args.max_steps 671 | ) 672 | 673 | logger.info("***** Running training *****") 674 | logger.info(f" Num examples = {num_examples}") 675 | logger.info(f" Num Epochs = {num_train_epochs}") 676 | logger.info( 677 | f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}" 678 | ) 679 | logger.info( 680 | f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}" 681 | ) 682 | logger.info( 683 | f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}" 684 | ) 685 | logger.info(f" Total optimization steps = {max_steps}") 686 | 687 | self.state.epoch = 0 688 | start_time = time.time() 689 | epochs_trained = 0 690 | steps_trained_in_current_epoch = 0 691 | 692 | # Check if continuing training from a checkpoint 693 | if resume_from_checkpoint is not None and os.path.isfile( 694 | os.path.join(resume_from_checkpoint, "trainer_state.json") 695 | ): 696 | self.state = TrainerState.load_from_json( 697 | os.path.join(resume_from_checkpoint, "trainer_state.json") 698 | ) 699 | epochs_trained = self.state.global_step // num_update_steps_per_epoch 700 | if not self.args.ignore_data_skip: 701 | steps_trained_in_current_epoch = self.state.global_step % ( 702 | num_update_steps_per_epoch 703 | ) 704 | steps_trained_in_current_epoch *= self.args.gradient_accumulation_steps 705 | else: 706 | steps_trained_in_current_epoch = 0 707 | 708 | logger.info( 709 | " Continuing training from checkpoint, will skip to saved global_step" 710 | ) 711 | logger.info(f" Continuing training from epoch {epochs_trained}") 712 | logger.info( 713 | f" Continuing training from global step {self.state.global_step}" 714 | ) 715 | if not self.args.ignore_data_skip: 716 | logger.info( 717 | f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " 718 | "batches in the first epoch." 719 | ) 720 | 721 | # Update the references 722 | self.callback_handler.model = self.model 723 | self.callback_handler.optimizer = self.optimizer 724 | self.callback_handler.lr_scheduler = self.lr_scheduler 725 | self.callback_handler.train_dataloader = train_dataloader 726 | self.state.trial_name = ( 727 | self.hp_name(trial) if self.hp_name is not None else None 728 | ) 729 | self.state.trial_params = hp_params(trial) if trial is not None else None 730 | # This should be the same if the state has been saved but in case the training arguments changed, it's safer 731 | # to set this after the load. 732 | self.state.max_steps = max_steps 733 | self.state.num_train_epochs = num_train_epochs 734 | self.state.is_local_process_zero = self.is_local_process_zero() 735 | self.state.is_world_process_zero = self.is_world_process_zero() 736 | 737 | # tr_loss is a tensor to avoid synchronization of TPUs through .item() 738 | tr_loss = torch.tensor(0.0).to(self.args.device) 739 | tr_task_loss = torch.tensor(0.0).to(self.args.device) 740 | tr_retrieval_loss = torch.tensor(0.0).to(self.args.device) 741 | 742 | # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses 743 | self._total_loss_scalar = 0.0 744 | self._globalstep_last_logged = self.state.global_step 745 | self._total_flos = self.state.total_flos 746 | model.zero_grad() 747 | 748 | self.control = self.callback_handler.on_train_begin( 749 | self.args, self.state, self.control 750 | ) 751 | 752 | # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. 753 | if not self.args.ignore_data_skip: 754 | for epoch in range(epochs_trained): 755 | # We just need to begin an iteration to create the randomization of the sampler. 
756 | for _ in train_dataloader: 757 | break 758 | 759 | for epoch in range(epochs_trained, num_train_epochs): 760 | if isinstance(train_dataloader, DataLoader) and isinstance( 761 | train_dataloader.sampler, DistributedSampler 762 | ): 763 | train_dataloader.sampler.set_epoch(epoch) 764 | 765 | if is_torch_tpu_available(): 766 | parallel_loader = pl.ParallelLoader( 767 | train_dataloader, [self.args.device] 768 | ).per_device_loader(self.args.device) 769 | epoch_iterator = parallel_loader 770 | else: 771 | epoch_iterator = train_dataloader 772 | 773 | # Reset the past mems state at the beginning of each epoch if necessary. 774 | if self.args.past_index >= 0: 775 | self._past = None 776 | 777 | steps_in_epoch = ( 778 | len(epoch_iterator) 779 | if train_dataset_is_sized 780 | else self.args.max_steps * self.args.gradient_accumulation_steps 781 | ) 782 | self.control = self.callback_handler.on_epoch_begin( 783 | self.args, self.state, self.control 784 | ) 785 | 786 | for step, inputs in enumerate(epoch_iterator): 787 | 788 | # Skip past any already trained steps if resuming training 789 | if steps_trained_in_current_epoch > 0: 790 | steps_trained_in_current_epoch -= 1 791 | continue 792 | 793 | if (step + 1) % self.args.gradient_accumulation_steps == 0: 794 | self.control = self.callback_handler.on_step_begin( 795 | self.args, self.state, self.control 796 | ) 797 | 798 | if ( 799 | ((step + 1) % self.args.gradient_accumulation_steps != 0) 800 | and self.args.local_rank != -1 801 | and self.args._no_sync_in_gradient_accumulation 802 | ): 803 | # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. 804 | with model.no_sync(): 805 | ( 806 | cur_tr_loss, 807 | cur_task_loss, 808 | cur_retrieval_loss, 809 | ) = self.training_step(model, inputs) 810 | tr_loss += cur_tr_loss 811 | if cur_task_loss is not None: 812 | tr_task_loss += cur_task_loss 813 | if cur_retrieval_loss is not None: 814 | tr_retrieval_loss += cur_retrieval_loss 815 | else: 816 | cur_tr_loss, cur_task_loss, cur_retrieval_loss = self.training_step( 817 | model, inputs 818 | ) 819 | tr_loss += cur_tr_loss 820 | if cur_task_loss is not None: 821 | tr_task_loss += cur_task_loss 822 | if cur_retrieval_loss is not None: 823 | tr_retrieval_loss += cur_retrieval_loss 824 | 825 | self._total_flos += float(self.floating_point_ops(inputs)) 826 | 827 | # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps 828 | if self.deepspeed: 829 | self.deepspeed.step() 830 | 831 | if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( 832 | # last step in epoch but step is always smaller than gradient_accumulation_steps 833 | steps_in_epoch <= self.args.gradient_accumulation_steps 834 | and (step + 1) == steps_in_epoch 835 | ): 836 | # Gradient clipping 837 | if ( 838 | self.args.max_grad_norm is not None 839 | and self.args.max_grad_norm > 0 840 | and not self.deepspeed 841 | ): 842 | # deepspeed does its own clipping 843 | 844 | if self.use_amp: 845 | # AMP: gradients need unscaling 846 | self.scaler.unscale_(self.optimizer) 847 | 848 | if hasattr(self.optimizer, "clip_grad_norm"): 849 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 850 | self.optimizer.clip_grad_norm(self.args.max_grad_norm) 851 | elif hasattr(model, "clip_grad_norm_"): 852 | # Some models (like FullyShardedDDP) have a specific way to do gradient clipping 853 | model.clip_grad_norm_(self.args.max_grad_norm) 854 | else: 855 | # 
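# --- Editor's note (illustrative sketch, not part of this file) ---------------
# The accumulation logic above runs forward/backward on every step but only
# synchronises DDP gradients and calls optimizer.step() on accumulation
# boundaries (or, for epochs shorter than grad_accum, on the epoch's last
# step). A stripped-down skeleton of that control flow; `run_epoch` and the
# `training_step` callable are hypothetical stand-ins, not the trainer's API:
from contextlib import nullcontext

def run_epoch(model, optimizer, epoch_iterator, training_step, grad_accum: int, is_distributed: bool):
    steps_in_epoch = len(epoch_iterator)
    for step, inputs in enumerate(epoch_iterator):
        boundary = (step + 1) % grad_accum == 0 or (
            steps_in_epoch <= grad_accum and (step + 1) == steps_in_epoch
        )
        # skip DDP's gradient all-reduce on the steps that will not call optimizer.step()
        sync_ctx = model.no_sync() if (is_distributed and (step + 1) % grad_accum != 0) else nullcontext()
        with sync_ctx:
            training_step(model, inputs)      # forward + backward; loss already scaled by 1/grad_accum
        if boundary:
            optimizer.step()
            model.zero_grad()
# ------------------------------------------------------------------------------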
Revert to normal clipping otherwise, handling Apex or full precision 856 | torch.nn.utils.clip_grad_norm_( 857 | amp.master_params(self.optimizer) 858 | if self.use_apex 859 | else model.parameters(), 860 | self.args.max_grad_norm, 861 | ) 862 | 863 | # Optimizer step 864 | if self.deepspeed: 865 | pass # called outside the loop 866 | elif is_torch_tpu_available(): 867 | xm.optimizer_step(self.optimizer) 868 | elif self.use_amp: 869 | self.scaler.step(self.optimizer) 870 | self.scaler.update() 871 | else: 872 | self.optimizer.step() 873 | 874 | if not self.deepspeed: 875 | self.lr_scheduler.step() 876 | 877 | model.zero_grad() 878 | self.state.global_step += 1 879 | self.state.epoch = epoch + (step + 1) / steps_in_epoch 880 | self.control = self.callback_handler.on_step_end( 881 | self.args, self.state, self.control 882 | ) 883 | 884 | self._maybe_log_save_evaluate( 885 | tr_loss, tr_task_loss, tr_retrieval_loss, model, trial, epoch 886 | ) 887 | 888 | if self.control.should_epoch_stop or self.control.should_training_stop: 889 | break 890 | 891 | self.control = self.callback_handler.on_epoch_end( 892 | self.args, self.state, self.control 893 | ) 894 | self._maybe_log_save_evaluate( 895 | tr_loss, tr_task_loss, tr_retrieval_loss, model, trial, epoch 896 | ) 897 | 898 | if self.args.tpu_metrics_debug or self.args.debug: 899 | if is_torch_tpu_available(): 900 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 901 | xm.master_print(met.metrics_report()) 902 | else: 903 | logger.warning( 904 | "You enabled PyTorch/XLA debug metrics but you don't have a TPU " 905 | "configured. Check your training configuration if this is unexpected." 906 | ) 907 | if self.control.should_training_stop: 908 | break 909 | 910 | if self.args.past_index and hasattr(self, "_past"): 911 | # Clean the state at the end of training 912 | delattr(self, "_past") 913 | 914 | logger.info( 915 | "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n" 916 | ) 917 | if ( 918 | self.args.load_best_model_at_end 919 | and self.state.best_model_checkpoint is not None 920 | ): 921 | # Wait for everyone to get here so we are sur the model has been saved by process 0. 922 | if is_torch_tpu_available(): 923 | xm.rendezvous("load_best_model_at_end") 924 | elif self.args.local_rank != -1: 925 | dist.barrier() 926 | 927 | logger.info( 928 | f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." 
929 | ) 930 | if isinstance(self.model, PreTrainedModel): 931 | self.model = self.model.from_pretrained( 932 | self.state.best_model_checkpoint 933 | ) 934 | if self.place_model_on_device: 935 | self.model = self.model.to(self.args.device) 936 | else: 937 | state_dict = torch.load( 938 | os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) 939 | ) 940 | self.model.load_state_dict(state_dict) 941 | 942 | if self.deepspeed: 943 | self.deepspeed.load_checkpoint( 944 | self.state.best_model_checkpoint, 945 | load_optimizer_states=False, 946 | load_lr_scheduler_states=False, 947 | ) 948 | 949 | metrics = speed_metrics("train", start_time, self.state.max_steps) 950 | if self._total_flos is not None: 951 | self.store_flos() 952 | metrics["total_flos"] = self.state.total_flos 953 | self.log(metrics) 954 | 955 | self.control = self.callback_handler.on_train_end( 956 | self.args, self.state, self.control 957 | ) 958 | # add remaining tr_loss 959 | self._total_loss_scalar += tr_loss.item() 960 | 961 | if self.deepspeed: 962 | # free up any memory that might be useful for eval 963 | self.deepspeed = None 964 | self.optimizer = None 965 | self.lr_scheduler = None 966 | self.model_wrapped = self.model 967 | gc.collect() # force memory release 968 | # to restore normal behavior outside of train replay the place_model_on_device logic w/o deepspeed 969 | self.place_model_on_device = self.args.place_model_on_device 970 | if self.is_model_parallel: 971 | self.place_model_on_device = False 972 | 973 | self.is_in_train = False 974 | 975 | self._memory_tracker.stop_and_update_metrics(metrics) 976 | 977 | return TrainOutput( 978 | self.state.global_step, 979 | self._total_loss_scalar / self.state.global_step, 980 | metrics, 981 | ) 982 | 983 | def _maybe_log_save_evaluate( 984 | self, tr_loss, task_loss, retrieval_loss, model, trial, epoch 985 | ): 986 | if self.control.should_log: 987 | logs: Dict[str, float] = {} 988 | tr_loss_scalar = tr_loss.item() 989 | task_loss_scalar = task_loss.item() if task_loss is not None else None 990 | retrieval_loss_scalar = ( 991 | retrieval_loss.item() if retrieval_loss is not None else None 992 | ) 993 | # reset tr_loss to zero 994 | tr_loss -= tr_loss 995 | task_loss -= task_loss 996 | retrieval_loss -= retrieval_loss 997 | 998 | logs["loss"] = round( 999 | tr_loss_scalar 1000 | / (self.state.global_step - self._globalstep_last_logged), 1001 | 4, 1002 | ) 1003 | if task_loss_scalar is not None: 1004 | logs["task_loss"] = round( 1005 | task_loss_scalar 1006 | / (self.state.global_step - self._globalstep_last_logged), 1007 | 4, 1008 | ) 1009 | if retrieval_loss_scalar is not None: 1010 | logs["retrieval_loss"] = round( 1011 | retrieval_loss_scalar 1012 | / (self.state.global_step - self._globalstep_last_logged), 1013 | 4, 1014 | ) 1015 | logs["learning_rate"] = self._get_learning_rate() 1016 | 1017 | self._total_loss_scalar += tr_loss_scalar 1018 | self._globalstep_last_logged = self.state.global_step 1019 | 1020 | self.log(logs) 1021 | 1022 | metrics = None 1023 | if self.control.should_evaluate: 1024 | metrics = self.evaluate() 1025 | self._report_to_hp_search(trial, epoch, metrics) 1026 | 1027 | if self.control.should_save: 1028 | self._save_checkpoint(model, trial, metrics=metrics) 1029 | self.control = self.callback_handler.on_save( 1030 | self.args, self.state, self.control 1031 | ) 1032 | 1033 | def training_step( 1034 | self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]] 1035 | ) -> torch.Tensor: 1036 | """ 1037 | Perform a training step on a 
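# --- Editor's note (illustrative sketch, not part of this file) ---------------
# _maybe_log_save_evaluate above reports each of the three running losses
# (total, task, retrieval) as an average over the steps since the previous log
# and then zeroes the running tensor in place (`tr_loss -= tr_loss`), so the
# caller's accumulator is reset too since the tensor is shared by reference.
# In plain numbers:
import torch

tr_loss = torch.tensor(12.8)               # running sum of per-step losses since the last log
global_step, last_logged = 250, 200
logged = round(tr_loss.item() / (global_step - last_logged), 4)
tr_loss -= tr_loss                         # in-place reset; same tensor object keeps accumulating next interval
print(logged, tr_loss.item())              # 0.256 0.0
# ------------------------------------------------------------------------------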
batch of inputs. 1038 | 1039 | Subclass and override to inject custom behavior. 1040 | 1041 | Args: 1042 | model (:obj:`nn.Module`): 1043 | The model to train. 1044 | inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): 1045 | The inputs and targets of the model. 1046 | 1047 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 1048 | argument :obj:`labels`. Check your model's documentation for all accepted arguments. 1049 | 1050 | Return: 1051 | :obj:`torch.Tensor`: The tensor with training loss on this batch. 1052 | """ 1053 | model.train() 1054 | inputs = self._prepare_inputs(inputs) 1055 | 1056 | if self.use_amp: 1057 | with autocast(): 1058 | ( 1059 | loss, 1060 | task_loss, 1061 | retrieval_loss, 1062 | retrieval_logits, 1063 | retrieval_instance_labels, 1064 | ) = self.compute_loss(model, inputs) 1065 | else: 1066 | ( 1067 | loss, 1068 | task_loss, 1069 | retrieval_loss, 1070 | retrieval_logits, 1071 | retrieval_instance_labels, 1072 | ) = self.compute_loss(model, inputs) 1073 | 1074 | if self.args.n_gpu > 1: 1075 | loss = loss.mean() # mean() to average on multi-gpu parallel training 1076 | task_loss = task_loss.mean() if task_loss is not None else None 1077 | retrieval_loss = ( 1078 | retrieval_loss.mean() if retrieval_loss is not None else None 1079 | ) 1080 | 1081 | if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: 1082 | # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` 1083 | loss = loss / self.args.gradient_accumulation_steps 1084 | task_loss = ( 1085 | task_loss / self.args.gradient_accumulation_steps 1086 | if task_loss is not None 1087 | else None 1088 | ) 1089 | retrieval_loss = ( 1090 | retrieval_loss / self.args.gradient_accumulation_steps 1091 | if retrieval_loss is not None 1092 | else None 1093 | ) 1094 | 1095 | if self.use_amp: 1096 | self.scaler.scale(loss).backward() 1097 | elif self.use_apex: 1098 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 1099 | scaled_loss.backward() 1100 | elif self.deepspeed: 1101 | # loss gets scaled under gradient_accumulation_steps in deepspeed 1102 | loss = self.deepspeed.backward(loss) 1103 | else: 1104 | loss.backward() 1105 | 1106 | task_loss = task_loss.detach() if task_loss is not None else None 1107 | retrieval_loss = retrieval_loss.detach() if retrieval_loss is not None else None 1108 | return loss.detach(), task_loss, retrieval_loss 1109 | 1110 | def compute_loss(self, model, inputs, return_outputs=False): 1111 | """ 1112 | How the loss is computed by Trainer. By default, all models return the loss in the first element. 1113 | 1114 | Subclass and override for custom behavior. 1115 | """ 1116 | if self.label_smoother is not None and "labels" in inputs: 1117 | labels = inputs.pop("labels") 1118 | else: 1119 | labels = None 1120 | outputs = model(**inputs) 1121 | # Save past state if it exists 1122 | # TODO: this needs to be fixed and made cleaner later. 1123 | if self.args.past_index >= 0: 1124 | self._past = outputs[self.args.past_index] 1125 | 1126 | task_loss = None 1127 | retrieval_loss = None 1128 | 1129 | if labels is not None: 1130 | loss = self.label_smoother(outputs, labels) 1131 | else: 1132 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 
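# --- Editor's note (illustrative sketch, not part of this file) ---------------
# training_step above applies the same two adjustments to the total loss and to
# the auxiliary task/retrieval losses: average over DataParallel replicas, then
# divide by gradient_accumulation_steps unless DeepSpeed handles that scaling.
# The per-loss normalisation, as a tiny hypothetical helper:
from typing import Optional
import torch

def normalize(loss: Optional[torch.Tensor], n_gpu: int, grad_accum: int, uses_deepspeed: bool):
    if loss is None:                       # task/retrieval losses may be absent for some heads
        return None
    if n_gpu > 1:
        loss = loss.mean()                 # average the per-replica losses under DataParallel
    if grad_accum > 1 and not uses_deepspeed:
        loss = loss / grad_accum           # DeepSpeed applies this scaling in its own backward()
    return loss

print(normalize(torch.tensor([0.4, 0.8]), n_gpu=2, grad_accum=4, uses_deepspeed=False))  # tensor(0.1500)
# ------------------------------------------------------------------------------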
1133 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 1134 | task_loss = outputs["task_loss"] if "task_loss" in outputs else None 1135 | retrieval_loss = ( 1136 | outputs["retrieval_loss"] if "retrieval_loss" in outputs else None 1137 | ) 1138 | retrieval_logits = ( 1139 | outputs["retrieval_predictions"] 1140 | if "retrieval_predictions" in outputs 1141 | else None 1142 | ) 1143 | retrieval_instance_labels = ( 1144 | outputs["retrieval_instance_labels"] 1145 | if "retrieval_instance_labels" in outputs 1146 | else None 1147 | ) 1148 | 1149 | return ( 1150 | ( 1151 | loss, 1152 | task_loss, 1153 | retrieval_loss, 1154 | retrieval_logits, 1155 | retrieval_instance_labels, 1156 | outputs, 1157 | ) 1158 | if return_outputs 1159 | else ( 1160 | loss, 1161 | task_loss, 1162 | retrieval_loss, 1163 | retrieval_logits, 1164 | retrieval_instance_labels, 1165 | ) 1166 | ) 1167 | 1168 | def evaluate( 1169 | self, 1170 | eval_dataset: Optional[Dataset] = None, 1171 | ignore_keys: Optional[List[str]] = None, 1172 | metric_key_prefix: str = "eval", 1173 | speed_metrics=False, 1174 | interference_report=False 1175 | ) -> Dict[str, float]: 1176 | """ 1177 | Run evaluation and returns metrics. 1178 | 1179 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 1180 | (pass it to the init :obj:`compute_metrics` argument). 1181 | 1182 | You can also subclass and override this method to inject custom behavior. 1183 | 1184 | Args: 1185 | eval_dataset (:obj:`Dataset`, `optional`): 1186 | Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, 1187 | columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the 1188 | :obj:`__len__` method. 1189 | ignore_keys (:obj:`Lst[str]`, `optional`): 1190 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 1191 | gathering predictions. 1192 | metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`): 1193 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 1194 | "eval_bleu" if the prefix is "eval" (default) 1195 | 1196 | Returns: 1197 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 1198 | dictionary also contains the epoch number which comes from the training state. 
1199 | """ 1200 | # memory metrics - must set up as early as possible 1201 | self._memory_tracker.start() 1202 | nvidia_smi.nvmlInit() 1203 | if eval_dataset is not None and not isinstance( 1204 | eval_dataset, collections.abc.Sized 1205 | ): 1206 | raise ValueError("eval_dataset must implement __len__") 1207 | 1208 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 1209 | start_time = time.time() 1210 | 1211 | output = self.prediction_loop( 1212 | eval_dataloader, 1213 | description="Evaluation", 1214 | # No point gathering the predictions if there are no metrics, otherwise we defer to 1215 | # self.args.prediction_loss_only 1216 | prediction_loss_only=True if self.compute_metrics is None else None, 1217 | ignore_keys=ignore_keys, 1218 | metric_key_prefix=metric_key_prefix, 1219 | ) 1220 | 1221 | n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset) 1222 | # output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples)) 1223 | 1224 | self.log(output.metrics, use_global_step=False) 1225 | 1226 | if self.args.tpu_metrics_debug or self.args.debug: 1227 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 1228 | xm.master_print(met.metrics_report()) 1229 | 1230 | self.control = self.callback_handler.on_evaluate( 1231 | self.args, self.state, self.control, output.metrics 1232 | ) 1233 | if speed_metrics: 1234 | 1235 | train_dataloader = self.get_train_dataloader() 1236 | model = self._wrap_model(self.model, training=False) 1237 | model.eval() 1238 | tot_samples = 0 1239 | total_infer_time = 0 1240 | average_gpu_memory = 0 1241 | batch_ctr = 0 1242 | 1243 | with torch.no_grad(): 1244 | for _, inputs in enumerate(tqdm(train_dataloader)): 1245 | inputs = self._prepare_inputs(inputs) 1246 | inputs.pop('labels') 1247 | start_time = time.time() 1248 | _ = model(**inputs) 1249 | torch.cuda.synchronize() 1250 | end_time = time.time() 1251 | total_infer_time += (end_time - start_time) 1252 | tot_samples += inputs['input_ids'].shape[0] 1253 | batch_ctr += 1 1254 | # gpu memory calculations 1255 | handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0) 1256 | info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle) 1257 | average_gpu_memory += info.used / 1e+9 1258 | if batch_ctr > 50: 1259 | break 1260 | throughput = tot_samples / total_infer_time 1261 | average_gpu_memory = average_gpu_memory / batch_ctr 1262 | # update metrics 1263 | output.metrics[f"{metric_key_prefix}_throughput"] = throughput 1264 | output.metrics[f"{metric_key_prefix}_speed_tot_samples"] = tot_samples 1265 | output.metrics[f"{metric_key_prefix}_inference_time"] = total_infer_time 1266 | output.metrics[f"{metric_key_prefix}_average_memory"] = average_gpu_memory 1267 | 1268 | if interference_report: 1269 | num_anchors = 10 1270 | num_batches = 30 1271 | train_dataloader = self.get_train_dataloader() 1272 | model = self._wrap_model(self.model, training=False) 1273 | model.eval() 1274 | anchors = None 1275 | anchor_representations = {i: [] for i in range(num_anchors)} 1276 | with torch.no_grad(): 1277 | for batch_id, inputs in enumerate(tqdm(train_dataloader)): 1278 | if batch_id > num_batches: 1279 | break 1280 | inputs = self._prepare_inputs(inputs) 1281 | if anchors is None: 1282 | anchors = inputs["input_ids"][:num_anchors] 1283 | continue 1284 | # replace with anchor and get corresponding demux representations 1285 | for anchor_id in range(num_anchors): 1286 | anchor = anchors[anchor_id] 1287 | inputs["input_ids"][0] = anchor 1288 | inputs["return_dict"] 
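# --- Editor's note (illustrative sketch, not part of this file) ---------------
# The speed_metrics branch above times roughly 50 forward passes and reports
# throughput = samples / wall-clock inference time plus the average GPU memory
# seen by NVML. The same measurement pattern reduced to its essentials, with
# generic `model` / `dataloader` / `device` stand-ins and without the NVML part:
import time
import torch

def measure_throughput(model, dataloader, device, max_batches: int = 50):
    model.eval()
    samples, infer_time = 0, 0.0
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            start = time.time()
            model(**batch)
            if torch.cuda.is_available():
                torch.cuda.synchronize()    # wait for queued CUDA kernels before stopping the clock
            infer_time += time.time() - start
            samples += batch["input_ids"].shape[0]
            if i + 1 >= max_batches:
                break
    return samples / infer_time             # samples/sec; packing N instances per forward pass raises this
# ------------------------------------------------------------------------------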
= True 1289 | outputs = model(**inputs) 1290 | anchor_representations[anchor_id].append(outputs["hidden_states"][0]) 1291 | # t-sne plot 1292 | anchor_representations_stacked = [] 1293 | for anchor_id in range(num_anchors): 1294 | anchor_representations_stacked.append(torch.stack(anchor_representations[anchor_id])) 1295 | anchor_representations_stacked = torch.cat(anchor_representations_stacked) 1296 | anchor_representations_stacked = anchor_representations_stacked.cpu().numpy() 1297 | average_cos_similarity = 0 1298 | for anchor_i in range(num_anchors): 1299 | for anchor_j in range(anchor_i +1, num_anchors): 1300 | anchor_i_representations = torch.stack(anchor_representations[anchor_i]) 1301 | anchor_j_representations = torch.stack(anchor_representations[anchor_j]) 1302 | average_cos_similarity += (torch.matmul(anchor_i_representations, anchor_j_representations.t()) / (torch.norm(anchor_i_representations, dim=1) * torch.norm(anchor_j_representations, dim=1))).mean() 1303 | average_cos_similarity /= (num_anchors * (num_anchors - 1) * 0.5) 1304 | pca_50 = PCA(n_components=200) 1305 | pca_result_50 = pca_50.fit_transform(anchor_representations_stacked) 1306 | tsne = TSNE(n_components=2, verbose=1, n_iter=500) 1307 | tsne_pca_results = tsne.fit_transform(pca_result_50) 1308 | df = pd.DataFrame() 1309 | df["tsne_1"] = tsne_pca_results[:, 0] 1310 | df["tsne_2"] = tsne_pca_results[:, 1] 1311 | df["sample"] = np.repeat(np.arange(num_anchors), len(anchor_representations[0])) 1312 | df["sample"] = df["sample"].apply(lambda i: str(i)) 1313 | sns.scatterplot( 1314 | x="tsne_1", y="tsne_2", 1315 | hue='sample', 1316 | palette=sns.color_palette("bright", num_anchors), 1317 | data=df, 1318 | legend="full", 1319 | alpha=0.3, 1320 | ) 1321 | plt.xlabel('x', fontsize=12) 1322 | plt.ylabel('y', fontsize=12) 1323 | plt.title(f'Interference analysis: N = {self.model.config.num_instances}', fontsize=16) 1324 | plt.legend(loc="lower right", fontsize=8) 1325 | plt.savefig(f"interference_fig_{self.model.config.num_instances}.png") 1326 | df.to_csv(f"interference_fig_{self.model.config.num_instances}.csv") 1327 | print(f"average cos similarity: {average_cos_similarity}") 1328 | 1329 | self._memory_tracker.stop_and_update_metrics(output.metrics) 1330 | 1331 | return output.metrics 1332 | 1333 | def predict( 1334 | self, 1335 | test_dataset: Dataset, 1336 | ignore_keys: Optional[List[str]] = None, 1337 | metric_key_prefix: str = "eval", 1338 | ) -> PredictionOutput: 1339 | """ 1340 | Run prediction and returns predictions and potential metrics. 1341 | 1342 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method 1343 | will also return metrics, like in :obj:`evaluate()`. 1344 | 1345 | Args: 1346 | test_dataset (:obj:`Dataset`): 1347 | Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the 1348 | ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__` 1349 | ignore_keys (:obj:`Lst[str]`, `optional`): 1350 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 1351 | gathering predictions. 1352 | metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`): 1353 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 1354 | "eval_bleu" if the prefix is "eval" (default) 1355 | 1356 | .. 
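# --- Editor's note (illustrative sketch, not part of this file) ---------------
# The interference report above tracks how a fixed "anchor" input's demuxed
# representation drifts as it is multiplexed with different companions: it
# collects the anchor's hidden states over many batches, summarises how similar
# different anchors' representation clouds are, and projects everything with
# PCA + t-SNE for the scatter plot it saves. A compact variant of those two
# analysis steps on dummy data; the sizes, PCA dimensionality and the cosine
# computation here are illustrative choices, not the exact ones used above:
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

reps = torch.randn(10, 30, 64)             # [num_anchors, num_batches, hidden] dummy representations

# mean pairwise cosine similarity between the representation clouds of different anchors
sims = []
for i in range(reps.shape[0]):
    for j in range(i + 1, reps.shape[0]):
        sims.append(F.cosine_similarity(reps[i].unsqueeze(1), reps[j].unsqueeze(0), dim=-1).mean())
avg_cos = torch.stack(sims).mean()

# 2-D embedding for the scatter plot: PCA first, then t-SNE on the reduced points
flat = reps.reshape(-1, reps.shape[-1]).numpy()           # [10*30, 64]
xy = TSNE(n_components=2, init="random", perplexity=30).fit_transform(
    PCA(n_components=20).fit_transform(flat)
)
print(round(avg_cos.item(), 3), xy.shape)                 # near 0.0 for random data, (300, 2)
# ------------------------------------------------------------------------------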
note:: 1357 | 1358 | If your predictions or labels have different sequence length (for instance because you're doing dynamic 1359 | padding in a token classification task) the predictions will be padded (on the right) to allow for 1360 | concatenation into one array. The padding index is -100. 1361 | 1362 | Returns: `NamedTuple` A namedtuple with the following keys: 1363 | 1364 | - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`. 1365 | - label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some). 1366 | - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset 1367 | contained labels). 1368 | """ 1369 | # memory metrics - must set up as early as possible 1370 | self._memory_tracker.start() 1371 | 1372 | if test_dataset is not None and not isinstance( 1373 | test_dataset, collections.abc.Sized 1374 | ): 1375 | raise ValueError("test_dataset must implement __len__") 1376 | 1377 | test_dataloader = self.get_test_dataloader(test_dataset) 1378 | start_time = time.time() 1379 | 1380 | output = self.prediction_loop( 1381 | test_dataloader, 1382 | description="Prediction", 1383 | ignore_keys=ignore_keys, 1384 | metric_key_prefix=metric_key_prefix, 1385 | ) 1386 | # output.metrics.update( 1387 | # speed_metrics(metric_key_prefix, start_time, len(test_dataset)) 1388 | # ) 1389 | 1390 | self._memory_tracker.stop_and_update_metrics(output.metrics) 1391 | 1392 | return output 1393 | 1394 | def prediction_loop( 1395 | self, 1396 | dataloader: DataLoader, 1397 | description: str, 1398 | prediction_loss_only: Optional[bool] = None, 1399 | ignore_keys: Optional[List[str]] = None, 1400 | metric_key_prefix: str = "eval", 1401 | ) -> PredictionOutput: 1402 | """ 1403 | Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`. 1404 | 1405 | Works both with or without labels. 
1406 | """ 1407 | if not isinstance(dataloader.dataset, collections.abc.Sized): 1408 | raise ValueError("dataset must implement __len__") 1409 | prediction_loss_only = ( 1410 | prediction_loss_only 1411 | if prediction_loss_only is not None 1412 | else self.args.prediction_loss_only 1413 | ) 1414 | 1415 | if self.args.deepspeed and not self.args.do_train: 1416 | # no harm, but flagging to the user that deepspeed config is ignored for eval 1417 | # flagging only for when --do_train wasn't passed as only then it's redundant 1418 | logger.info( 1419 | "Detected the deepspeed argument but it will not be used for evaluation" 1420 | ) 1421 | 1422 | model = self._wrap_model(self.model, training=False) 1423 | 1424 | # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while 1425 | # ``train`` is running, half it first and then put on device 1426 | if not self.is_in_train and self.args.fp16_full_eval: 1427 | model = model.half().to(self.args.device) 1428 | 1429 | batch_size = dataloader.batch_size 1430 | num_examples = self.num_examples(dataloader) 1431 | num_examples = ( 1432 | (num_examples // batch_size) * batch_size 1433 | if self.args.dataloader_drop_last 1434 | else num_examples 1435 | ) 1436 | 1437 | logger.info("***** Running %s *****", description) 1438 | logger.info(" Num examples = %d", num_examples) 1439 | logger.info(" Batch size = %d", batch_size) 1440 | losses_host: torch.Tensor = None 1441 | task_losses_host: torch.Tensor = None 1442 | retrieval_losses_host: torch.Tensor = None 1443 | retrieval_accs_host: torch.Tensor = None 1444 | 1445 | preds_host: Union[torch.Tensor, List[torch.Tensor]] = None 1446 | labels_host: Union[torch.Tensor, List[torch.Tensor]] = None 1447 | 1448 | world_size = max(1, self.args.world_size) 1449 | 1450 | eval_losses_gatherer = DistributedTensorGatherer( 1451 | world_size, num_examples, make_multiple_of=batch_size 1452 | ) 1453 | eval_task_losses_gatherer = DistributedTensorGatherer( 1454 | world_size, num_examples, make_multiple_of=batch_size 1455 | ) 1456 | eval_retrieval_losses_gatherer = DistributedTensorGatherer( 1457 | world_size, num_examples, make_multiple_of=batch_size 1458 | ) 1459 | eval_retrieval_acc_gatherer = DistributedTensorGatherer( 1460 | world_size, num_examples, make_multiple_of=batch_size 1461 | ) 1462 | 1463 | if not prediction_loss_only: 1464 | # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass 1465 | # a batch size to the sampler) 1466 | make_multiple_of = None 1467 | if hasattr(dataloader, "sampler") and isinstance( 1468 | dataloader.sampler, SequentialDistributedSampler 1469 | ): 1470 | make_multiple_of = dataloader.sampler.batch_size 1471 | preds_gatherer = DistributedTensorGatherer( 1472 | world_size, num_examples, make_multiple_of=make_multiple_of 1473 | ) 1474 | labels_gatherer = DistributedTensorGatherer( 1475 | world_size, num_examples, make_multiple_of=make_multiple_of 1476 | ) 1477 | 1478 | model.eval() 1479 | 1480 | if is_torch_tpu_available(): 1481 | dataloader = pl.ParallelLoader( 1482 | dataloader, [self.args.device] 1483 | ).per_device_loader(self.args.device) 1484 | 1485 | if self.args.past_index >= 0: 1486 | self._past = None 1487 | 1488 | self.callback_handler.eval_dataloader = dataloader 1489 | 1490 | column_names = ["Token", "Count", "Correct", "Accuracy"] 1491 | table = wandb.Table(columns=column_names) 1492 | token_acc_counter = {} 1493 | for step, inputs in enumerate(dataloader): 1494 | ( 1495 | loss, 1496 | logits, 1497 | 
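# --- Editor's note (illustrative worked example, not part of this file) -------
# With dataloader_drop_last=True, the loop above trims num_examples down to a
# whole number of batches before sizing the DistributedTensorGatherer buffers:
batch_size, raw_examples = 32, 1000
num_examples = (raw_examples // batch_size) * batch_size
print(num_examples)  # 992 -- the trailing 8 samples are dropped before sizing the gatherers
# ------------------------------------------------------------------------------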
labels, 1498 | task_loss, 1499 | retrieval_loss, 1500 | retrieval_acc, 1501 | token_acc_counter, 1502 | ) = self.prediction_step( 1503 | model, 1504 | inputs, 1505 | prediction_loss_only, 1506 | ignore_keys=ignore_keys, 1507 | token_acc_counter=token_acc_counter, 1508 | ) 1509 | # print("token acc counter", token_acc_counter) 1510 | if loss is not None: 1511 | losses = loss.repeat(batch_size) 1512 | losses_host = ( 1513 | losses 1514 | if losses_host is None 1515 | else torch.cat((losses_host, losses), dim=0) 1516 | ) 1517 | if task_loss is not None: 1518 | task_losses = task_loss.repeat(batch_size) 1519 | task_losses_host = ( 1520 | task_losses 1521 | if task_losses_host is None 1522 | else torch.cat((task_losses_host, task_losses), dim=0) 1523 | ) 1524 | if retrieval_loss is not None: 1525 | retrieval_losses = retrieval_loss.repeat(batch_size) 1526 | retrieval_losses_host = ( 1527 | retrieval_losses 1528 | if retrieval_losses_host is None 1529 | else torch.cat((retrieval_losses_host, retrieval_losses), dim=0) 1530 | ) 1531 | if retrieval_acc is not None: 1532 | retrieval_accs = retrieval_acc.repeat(batch_size) 1533 | retrieval_accs_host = ( 1534 | retrieval_accs 1535 | if retrieval_accs_host is None 1536 | else torch.cat((retrieval_accs_host, retrieval_accs), dim=0) 1537 | ) 1538 | if logits is not None: 1539 | preds_host = ( 1540 | logits 1541 | if preds_host is None 1542 | else nested_concat(preds_host, logits, padding_index=-100) 1543 | ) 1544 | if labels is not None: 1545 | labels_host = ( 1546 | labels 1547 | if labels_host is None 1548 | else nested_concat(labels_host, labels, padding_index=-100) 1549 | ) 1550 | self.control = self.callback_handler.on_prediction_step( 1551 | self.args, self.state, self.control 1552 | ) 1553 | 1554 | # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. 
1555 | if ( 1556 | self.args.eval_accumulation_steps is not None 1557 | and (step + 1) % self.args.eval_accumulation_steps == 0 1558 | ): 1559 | eval_losses_gatherer.add_arrays( 1560 | self._gather_and_numpify(losses_host, "eval_losses") 1561 | ) 1562 | if task_losses_host is not None: 1563 | eval_task_losses_gatherer.add_arrays( 1564 | self._gather_and_numpify(task_losses_host, "eval_task_losses") 1565 | ) 1566 | 1567 | if retrieval_losses_host is not None: 1568 | eval_retrieval_losses_gatherer.add_arrays( 1569 | self._gather_and_numpify( 1570 | retrieval_losses_host, "eval_retrieval_losses" 1571 | ) 1572 | ) 1573 | 1574 | if not prediction_loss_only: 1575 | preds_gatherer.add_arrays( 1576 | self._gather_and_numpify(preds_host, "eval_preds") 1577 | ) 1578 | labels_gatherer.add_arrays( 1579 | self._gather_and_numpify(labels_host, "eval_label_ids") 1580 | ) 1581 | 1582 | # Set back to None to begin a new accumulation 1583 | ( 1584 | losses_host, 1585 | preds_host, 1586 | labels_host, 1587 | task_losses_host, 1588 | retrieval_losses_host, 1589 | ) = (None, None, None, None, None) 1590 | 1591 | # add to wandb table 1592 | for token in token_acc_counter: 1593 | table.add_data( 1594 | token, 1595 | token_acc_counter[token]["tot_count"], 1596 | token_acc_counter[token]["correct"] 1597 | if "correct" in token_acc_counter[token] 1598 | else 0, 1599 | token_acc_counter[token]["correct"] 1600 | / token_acc_counter[token]["tot_count"] 1601 | if "correct" in token_acc_counter[token] 1602 | else 0, 1603 | ) 1604 | # log table 1605 | self.callback_handler.on_log( 1606 | self.args, self.state, self.control, {str(self.state.global_step): table} 1607 | ) 1608 | if self.args.past_index and hasattr(self, "_past"): 1609 | # Clean the state at the end of the evaluation loop 1610 | delattr(self, "_past") 1611 | 1612 | # Gather all remaining tensors and put them back on the CPU 1613 | eval_losses_gatherer.add_arrays( 1614 | self._gather_and_numpify(losses_host, "eval_losses") 1615 | ) 1616 | if task_losses_host is not None: 1617 | eval_task_losses_gatherer.add_arrays( 1618 | self._gather_and_numpify(task_losses_host, "eval_task_losses") 1619 | ) 1620 | 1621 | if retrieval_losses_host is not None: 1622 | eval_retrieval_losses_gatherer.add_arrays( 1623 | self._gather_and_numpify(retrieval_losses_host, "eval_retrieval_losses") 1624 | ) 1625 | if retrieval_accs_host is not None: 1626 | eval_retrieval_acc_gatherer.add_arrays( 1627 | self._gather_and_numpify(retrieval_accs_host, "eval_retrieval_accs") 1628 | ) 1629 | 1630 | if not prediction_loss_only: 1631 | preds_gatherer.add_arrays( 1632 | self._gather_and_numpify(preds_host, "eval_preds") 1633 | ) 1634 | labels_gatherer.add_arrays( 1635 | self._gather_and_numpify(labels_host, "eval_label_ids") 1636 | ) 1637 | 1638 | eval_loss = eval_losses_gatherer.finalize() 1639 | eval_task_loss = ( 1640 | eval_task_losses_gatherer.finalize() 1641 | if task_losses_host is not None 1642 | else None 1643 | ) 1644 | eval_retrieval_loss = ( 1645 | eval_retrieval_losses_gatherer.finalize() 1646 | if retrieval_losses_host is not None 1647 | else None 1648 | ) 1649 | eval_retrieval_accs = ( 1650 | eval_retrieval_acc_gatherer.finalize() 1651 | if retrieval_accs_host is not None 1652 | else None 1653 | ) 1654 | preds = preds_gatherer.finalize() if not prediction_loss_only else None 1655 | label_ids = labels_gatherer.finalize() if not prediction_loss_only else None 1656 | if ( 1657 | self.compute_metrics is not None 1658 | and preds is not None 1659 | and label_ids is not None 1660 | ): 
1661 | metrics = self.compute_metrics( 1662 | EvalPrediction(predictions=preds, label_ids=label_ids) 1663 | ) 1664 | else: 1665 | metrics = {} 1666 | 1667 | # To be JSON-serializable, we need to remove numpy types or zero-d tensors 1668 | metrics = denumpify_detensorize(metrics) 1669 | 1670 | if eval_loss is not None: 1671 | metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() 1672 | if eval_task_loss is not None: 1673 | metrics[f"{metric_key_prefix}_task_loss"] = eval_task_loss.mean().item() 1674 | if eval_retrieval_loss is not None: 1675 | metrics[ 1676 | f"{metric_key_prefix}_retrieval_loss" 1677 | ] = eval_retrieval_loss.mean().item() 1678 | if eval_retrieval_accs is not None: 1679 | metrics[ 1680 | f"{metric_key_prefix}_retrieval_acc" 1681 | ] = eval_retrieval_accs.mean().item() 1682 | 1683 | # Prefix all keys with metric_key_prefix + '_' 1684 | for key in list(metrics.keys()): 1685 | if not key.startswith(f"{metric_key_prefix}_"): 1686 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) 1687 | return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) 1688 | 1689 | def prediction_step( 1690 | self, 1691 | model: nn.Module, 1692 | inputs: Dict[str, Union[torch.Tensor, Any]], 1693 | prediction_loss_only: bool, 1694 | ignore_keys: Optional[List[str]] = None, 1695 | token_acc_counter=None, 1696 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 1697 | """ 1698 | Perform an evaluation step on :obj:`model` using obj:`inputs`. 1699 | 1700 | Subclass and override to inject custom behavior. 1701 | 1702 | Args: 1703 | model (:obj:`nn.Module`): 1704 | The model to evaluate. 1705 | inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): 1706 | The inputs and targets of the model. 1707 | 1708 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 1709 | argument :obj:`labels`. Check your model's documentation for all accepted arguments. 1710 | prediction_loss_only (:obj:`bool`): 1711 | Whether or not to return the loss only. 1712 | ignore_keys (:obj:`Lst[str]`, `optional`): 1713 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 1714 | gathering predictions. 1715 | 1716 | Return: 1717 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and 1718 | labels (each being optional). 1719 | """ 1720 | 1721 | has_labels = all(inputs.get(k) is not None for k in self.label_names) 1722 | inputs = self._prepare_inputs(inputs) 1723 | if ignore_keys is None: 1724 | if hasattr(self.model, "config"): 1725 | ignore_keys = getattr( 1726 | self.model.config, "keys_to_ignore_at_inference", [] 1727 | ) 1728 | else: 1729 | ignore_keys = [] 1730 | 1731 | # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. 
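        # Why the labels are grabbed up front: when label smoothing is enabled,
        # compute_loss typically pops "labels" out of `inputs` before the forward
        # pass, so they would no longer be available for metric computation
        # afterwards. self.label_names defaults to ["labels"], in which case the
        # detach below reduces to roughly (illustrative simplification):
        #
        #   labels = inputs.get("labels").detach()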
1732 | if has_labels: 1733 | labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) 1734 | if len(labels) == 1: 1735 | labels = labels[0] 1736 | else: 1737 | labels = None 1738 | 1739 | logits = None 1740 | retrieval_loss = None 1741 | task_loss = None 1742 | retrieval_predictions=None 1743 | retrieval_instance_labels=None 1744 | with torch.no_grad(): 1745 | 1746 | if has_labels: 1747 | ( 1748 | loss, 1749 | task_loss, 1750 | retrieval_loss, 1751 | retrieval_predictions, 1752 | retrieval_instance_labels, 1753 | outputs, 1754 | ) = self.compute_loss(model, inputs, return_outputs=True) 1755 | loss = loss.mean().detach() 1756 | if isinstance(outputs, dict): 1757 | # logits = tuple( 1758 | # v 1759 | # for k, v in outputs.items() 1760 | # if k 1761 | # not in ignore_keys + ["loss", "task_loss", "retrieval_loss"] 1762 | # ) 1763 | logits = outputs["logits"] if "logits" in outputs else None 1764 | else: 1765 | logits = outputs[1:] 1766 | if "retrieval_loss" in outputs: 1767 | retrieval_loss = outputs["retrieval_loss"] 1768 | retrieval_loss = retrieval_loss.mean().detach() 1769 | 1770 | if "task_loss" in outputs: 1771 | task_loss = outputs["task_loss"] 1772 | task_loss = task_loss.mean().detach() 1773 | 1774 | else: 1775 | loss = None 1776 | if self.use_amp: 1777 | with autocast(): 1778 | outputs = model(**inputs) 1779 | else: 1780 | outputs = model(**inputs) 1781 | if isinstance(outputs, dict): 1782 | # logits = tuple( 1783 | # v for k, v in outputs.items() if k not in ignore_keys 1784 | # ) 1785 | logits = outputs["logits"] if "logits" in outputs else None 1786 | else: 1787 | logits = outputs 1788 | if self.args.past_index >= 0: 1789 | self._past = outputs[self.args.past_index - 1] 1790 | 1791 | retrieval_acc = None 1792 | 1793 | if retrieval_predictions is not None: 1794 | retrieval_predictions = nested_detach(retrieval_predictions) 1795 | retrieval_instance_labels = retrieval_instance_labels.detach() 1796 | # calculate retrieval acc from the retrieval logits and the sentence labels 1797 | with torch.no_grad(): 1798 | ignore_predictions = retrieval_instance_labels == -100 1799 | retrieval_correct_predictions = ( 1800 | retrieval_predictions == retrieval_instance_labels 1801 | ) & ~ignore_predictions 1802 | retrieval_acc = torch.sum(retrieval_correct_predictions) / torch.sum( 1803 | ~ignore_predictions 1804 | ) 1805 | retrieval_acc = retrieval_acc.detach() 1806 | # add token level information in the token counter object 1807 | if token_acc_counter is not None: 1808 | correct_tokens = retrieval_predictions[ 1809 | retrieval_correct_predictions 1810 | ] 1811 | correct_tokens_unique, correct_tokens_unique_counts = torch.unique( 1812 | correct_tokens, return_counts=True 1813 | ) 1814 | # iterate through vocab ids and get the exact token 1815 | for c_token_id, c_token in enumerate( 1816 | correct_tokens_unique.tolist() 1817 | ): 1818 | c_token_decoded = self.tokenizer.decode([c_token]) 1819 | if c_token_decoded not in token_acc_counter: 1820 | token_acc_counter[c_token_decoded] = { 1821 | "correct": 0, 1822 | "tot_count": 0, 1823 | } 1824 | token_acc_counter[c_token_decoded][ 1825 | "correct" 1826 | ] += correct_tokens_unique_counts[c_token_id].item() 1827 | 1828 | all_tokens_unique, all_tokens_count = torch.unique( 1829 | retrieval_instance_labels[~ignore_predictions], 1830 | return_counts=True, 1831 | ) 1832 | for a_token_id, a_token in enumerate(all_tokens_unique.tolist()): 1833 | 1834 | a_token_decoded = self.tokenizer.decode([a_token]) 1835 | 1836 | if a_token_decoded not 
in token_acc_counter: 1837 | token_acc_counter[a_token_decoded] = { 1838 | "correct": 0, 1839 | "tot_count": 0, 1840 | } 1841 | token_acc_counter[a_token_decoded][ 1842 | "tot_count" 1843 | ] += all_tokens_count[a_token_id].item() 1844 | 1845 | if prediction_loss_only: 1846 | return ( 1847 | loss, 1848 | None, 1849 | None, 1850 | task_loss, 1851 | retrieval_loss, 1852 | retrieval_acc, 1853 | token_acc_counter, 1854 | ) 1855 | if logits is not None: 1856 | logits = nested_detach(logits) 1857 | if len(logits) == 1: 1858 | logits = logits[0] 1859 | 1860 | return ( 1861 | loss, 1862 | logits, 1863 | labels, 1864 | task_loss, 1865 | retrieval_loss, 1866 | retrieval_acc, 1867 | token_acc_counter, 1868 | ) 1869 | def log(self, logs: Dict[str, float], use_global_step=True) -> None: 1870 | """ 1871 | Log :obj:`logs` on the various objects watching training. 1872 | 1873 | Subclass and override this method to inject custom behavior. 1874 | 1875 | Args: 1876 | logs (:obj:`Dict[str, float]`): 1877 | The values to log. 1878 | """ 1879 | if self.state.epoch is not None: 1880 | logs["epoch"] = round(self.state.epoch, 2) 1881 | output = {**logs, **{"step": self.state.global_step}} 1882 | self.state.log_history.append(output) 1883 | logs["use_global_step"] = use_global_step 1884 | self.control = self.callback_handler.on_log( 1885 | self.args, self.state, self.control, logs 1886 | ) 1887 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import re 4 | import os 5 | PREFIX_CHECKPOINT_DIR = "checkpoint" 6 | _re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$") 7 | 8 | def get_last_checkpoint_trainerstate_robust(folder): 9 | content = os.listdir(folder) 10 | checkpoints = [ 11 | path 12 | for path in content 13 | if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path)) and os.path.exists(os.path.join(folder, path, "trainer_state.json")) 14 | ] 15 | if len(checkpoints) == 0: 16 | return 17 | return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0]))) 18 | 19 | 20 | def random_encoding(max_positions, d_model, norm=1): 21 | 22 | gauss = torch.randn((max_positions, d_model)) 23 | gauss = gauss / torch.norm(gauss, dim=1).unsqueeze(1) 24 | gauss *= norm 25 | return gauss 26 | 27 | def topk( 28 | logits, 29 | gt_classes, 30 | k_list, 31 | ): 32 | assert len(logits.shape) == 2 33 | assert len(gt_classes.shape) == 1 34 | batch, _ = logits.shape 35 | max_k = max(k_list) 36 | top_labels_max_k = torch.topk(logits, max_k, dim=1)[1] 37 | return [ 38 | torch.sum(top_labels_max_k[:, :k] == gt_classes.unsqueeze(1)) / batch 39 | for k in k_list 40 | ] 41 | 42 | 43 | def gen_attn_mask(sequence_length, len=None): 44 | batch_size = sequence_length.size(0) 45 | seq_range = torch.arange(len) 46 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, len) 47 | seq_range_expand = seq_range_expand.to(sequence_length.device) 48 | seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) 49 | return seq_range_expand < seq_length_expand 50 | 51 | 52 | def binary_encoding(max_position, d_model, epsilon=0.3): 53 | assert epsilon <= 1 and epsilon >= 0, "epsilon value should lie in [0,1)" 54 | chunk_size = d_model // max_position 55 | start_of_chunks = chunk_size * torch.arange(max_position) 56 | end_of_chunks = start_of_chunks + chunk_size 57 | 
end_of_chunks[-1] = d_model 58 | # tweak start and end states to account for epsilon 59 | num_intersection = (epsilon / 2) * chunk_size 60 | start_of_chunks[1:] = start_of_chunks[1:] - num_intersection 61 | end_of_chunks[:-1] = end_of_chunks[:-1] + num_intersection 62 | 63 | # for loop here :( , not worth vectorizing, only called once 64 | binary_embeds = torch.zeros(max_position, d_model) 65 | for pos in range(max_position): 66 | binary_embeds[pos, start_of_chunks[pos] : end_of_chunks[pos]] = 1 67 | return binary_embeds 68 | 69 | def count_params_hf(model): 70 | params = {k: v for k, v in model.named_parameters()} 71 | return sum([math.prod(v.shape) for _, v in params.items()]) 72 | -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE.""" 17 | # You can also adapt this script on your own text classification task. Pointers for this are left as comments. 18 | 19 | import logging 20 | import os 21 | import random 22 | import sys 23 | from dataclasses import dataclass, field 24 | from typing import Optional 25 | 26 | import numpy as np 27 | from datasets import load_dataset, load_metric 28 | import torch 29 | import transformers 30 | from transformers import ( 31 | AutoConfig, 32 | AutoModelForSequenceClassification, 33 | AutoTokenizer, 34 | DataCollatorWithPadding, 35 | EvalPrediction, 36 | HfArgumentParser, 37 | PretrainedConfig, 38 | TrainingArguments, 39 | default_data_collator, 40 | set_seed, 41 | ) 42 | from models.multiplexing import RobertaSequenceClassificationMuxed 43 | from models.utils import get_last_checkpoint_trainerstate_robust 44 | from models.trainer import MuxTrainer 45 | import re 46 | 47 | task_to_keys = { 48 | "cola": ("sentence", None), 49 | "mnli": ("premise", "hypothesis"), 50 | "mrpc": ("sentence1", "sentence2"), 51 | "qnli": ("question", "sentence"), 52 | "qqp": ("question1", "question2"), 53 | "rte": ("sentence1", "sentence2"), 54 | "sst2": ("sentence", None), 55 | "stsb": ("sentence1", "sentence2"), 56 | "wnli": ("sentence1", "sentence2"), 57 | } 58 | 59 | logger = logging.getLogger(__name__) 60 | 61 | 62 | @dataclass 63 | class DataTrainingArguments: 64 | """ 65 | Arguments pertaining to what data we are going to input our model for training and eval. 66 | Using `HfArgumentParser` we can turn this class 67 | into argparse arguments to be able to specify them on 68 | the command line. 
69 | """ 70 | 71 | task_name: Optional[str] = field( 72 | default=None, 73 | metadata={ 74 | "help": "The name of the task to train on: " 75 | + ", ".join(task_to_keys.keys()) 76 | }, 77 | ) 78 | dataset_name: Optional[str] = field( 79 | default=None, 80 | metadata={"help": "The name of the dataset to use (via the datasets library)."}, 81 | ) 82 | dataset_config_name: Optional[str] = field( 83 | default=None, 84 | metadata={ 85 | "help": "The configuration name of the dataset to use (via the datasets library)." 86 | }, 87 | ) 88 | max_seq_length: int = field( 89 | default=128, 90 | metadata={ 91 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 92 | "than this will be truncated, sequences shorter will be padded." 93 | }, 94 | ) 95 | overwrite_cache: bool = field( 96 | default=False, 97 | metadata={"help": "Overwrite the cached preprocessed datasets or not."}, 98 | ) 99 | pad_to_max_length: bool = field( 100 | default=True, 101 | metadata={ 102 | "help": "Whether to pad all samples to `max_seq_length`. " 103 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 104 | }, 105 | ) 106 | max_train_samples: Optional[int] = field( 107 | default=None, 108 | metadata={ 109 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 110 | "value if set." 111 | }, 112 | ) 113 | max_eval_samples: Optional[int] = field( 114 | default=None, 115 | metadata={ 116 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 117 | "value if set." 118 | }, 119 | ) 120 | max_predict_samples: Optional[int] = field( 121 | default=None, 122 | metadata={ 123 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 124 | "value if set." 125 | }, 126 | ) 127 | train_file: Optional[str] = field( 128 | default=None, 129 | metadata={"help": "A csv or a json file containing the training data."}, 130 | ) 131 | validation_file: Optional[str] = field( 132 | default=None, 133 | metadata={"help": "A csv or a json file containing the validation data."}, 134 | ) 135 | test_file: Optional[str] = field( 136 | default=None, 137 | metadata={"help": "A csv or a json file containing the test data."}, 138 | ) 139 | 140 | def __post_init__(self): 141 | if self.task_name is not None: 142 | self.task_name = self.task_name.lower() 143 | if self.task_name not in task_to_keys.keys(): 144 | raise ValueError( 145 | "Unknown task, you should pick one in " 146 | + ",".join(task_to_keys.keys()) 147 | ) 148 | elif self.dataset_name is not None: 149 | pass 150 | elif self.train_file is None or self.validation_file is None: 151 | raise ValueError( 152 | "Need either a GLUE task, a training/validation file or a dataset name." 153 | ) 154 | else: 155 | train_extension = self.train_file.split(".")[-1] 156 | assert train_extension in [ 157 | "csv", 158 | "json", 159 | ], "`train_file` should be a csv or a json file." 160 | validation_extension = self.validation_file.split(".")[-1] 161 | assert ( 162 | validation_extension == train_extension 163 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 164 | 165 | 166 | @dataclass 167 | class ModelArguments: 168 | """ 169 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
170 | """ 171 | 172 | model_name_or_path: str = field( 173 | default=None, 174 | metadata={ 175 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 176 | } 177 | ) 178 | config_name: Optional[str] = field( 179 | default=None, 180 | metadata={ 181 | "help": "Pretrained config name or path if not the same as model_name" 182 | }, 183 | ) 184 | tokenizer_name: Optional[str] = field( 185 | default=None, 186 | metadata={ 187 | "help": "Pretrained tokenizer name or path if not the same as model_name" 188 | }, 189 | ) 190 | cache_dir: Optional[str] = field( 191 | default=None, 192 | metadata={ 193 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co" 194 | }, 195 | ) 196 | use_fast_tokenizer: bool = field( 197 | default=True, 198 | metadata={ 199 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 200 | }, 201 | ) 202 | model_revision: str = field( 203 | default="main", 204 | metadata={ 205 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 206 | }, 207 | ) 208 | use_auth_token: bool = field( 209 | default=False, 210 | metadata={ 211 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 212 | "with private models)." 213 | }, 214 | ) 215 | # multi instance arguments 216 | num_instances: Optional[int] = field( 217 | default=5, 218 | metadata={"help": "Number of instances i.e. N"}, 219 | ) 220 | muxing_variant: Optional[str] = field( 221 | default="gaussian_hadamard", 222 | metadata={"help": "muxing variant; choose from gaussian_hadamard or random_ortho or binary_hadamard"}, 223 | ) 224 | demuxing_variant: Optional[str] = field( 225 | default="index", 226 | metadata={"help": "demuxing variant, choose from 'index' or 'mlp'"}, 227 | ) 228 | should_mux: Optional[int] = field( 229 | default=1, 230 | metadata={"help": "whether to mux, turn off for non-multiplexed baselines"}, 231 | ) 232 | retrieval_percentage: Optional[float] = field( 233 | default=1.0, 234 | metadata={"help": "percentage of tokens to retrieve during inference"}, 235 | ) 236 | retrieval_pretraining: Optional[int] = field( 237 | default=0, 238 | metadata={"help": "Retrieval Pretraining"}, 239 | ) 240 | gaussian_hadamard_norm: Optional[float] = field( 241 | default=1, 242 | metadata={"help": "Norm of sentence embeddings if we use random projections"}, 243 | ) 244 | binary_hadamard_epsilon: Optional[float] = field( 245 | default=0, 246 | metadata={"help": "Percentage intersection among binary vectors, default is no intersection"}, 247 | ) 248 | retrieval_loss_coeff: Optional[float] = field( 249 | default=0.1, 250 | metadata={"help": "Coefficient for retrieval loss"}, 251 | ) 252 | task_loss_coeff: Optional[float] = field( 253 | default=0.9, 254 | metadata={"help": "Coefficient for task loss"}, 255 | ) 256 | learn_muxing: Optional[int] = field( 257 | default=0, 258 | metadata={"help": "whether instance embeddings are learnt or not"}, 259 | ) 260 | 261 | 262 | def main(): 263 | # See all possible arguments in src/transformers/training_args.py 264 | # or by passing the --help flag to this script. 265 | # We now keep distinct sets of args, for a cleaner separation of concerns. 
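    # HfArgumentParser turns every dataclass field above into a command-line flag,
    # and can alternatively read all of them from a single JSON file. Two
    # illustrative invocations (paths and values are examples only):
    #
    #   python run_glue.py configs/ablations/base_model/roberta.json
    #   python run_glue.py --task_name mnli --config_name configs/roberta.json \
    #       --tokenizer_name roberta-base --num_instances 5 \
    #       --muxing_variant gaussian_hadamard --demuxing_variant index \
    #       --output_dir checkpoints/mnli_mux5 --do_train --do_eval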
266 | parser = HfArgumentParser( 267 | (ModelArguments, DataTrainingArguments, TrainingArguments) 268 | ) 269 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 270 | # If we pass only one argument to the script and it's the path to a json file, 271 | # let's parse it to get our arguments. 272 | model_args, data_args, training_args = parser.parse_json_file( 273 | json_file=os.path.abspath(sys.argv[1]) 274 | ) 275 | else: 276 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 277 | 278 | # Setup logging 279 | logging.basicConfig( 280 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 281 | datefmt="%m/%d/%Y %H:%M:%S", 282 | handlers=[logging.StreamHandler(sys.stdout)], 283 | ) 284 | logger.setLevel(logging.INFO) 285 | 286 | # Log on each process the small summary: 287 | logger.warning( 288 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 289 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 290 | ) 291 | # Set the verbosity to info of the Transformers logger (on main process only): 292 | transformers.utils.logging.set_verbosity_info() 293 | transformers.utils.logging.enable_default_handler() 294 | transformers.utils.logging.enable_explicit_format() 295 | logger.info(f"Training/evaluation parameters {training_args}") 296 | 297 | # Detecting last checkpoint. 298 | last_checkpoint = None 299 | if ( 300 | os.path.isdir(training_args.output_dir) 301 | and training_args.do_train 302 | and not training_args.overwrite_output_dir 303 | ): 304 | last_checkpoint = get_last_checkpoint_trainerstate_robust(training_args.output_dir) 305 | print("last checkpoint", last_checkpoint) 306 | 307 | # if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 308 | # raise ValueError( 309 | # f"Output directory ({training_args.output_dir}) already exists and is not empty. " 310 | # "Use --overwrite_output_dir to overcome." 311 | # ) 312 | 313 | 314 | # Set seed before initializing model. 315 | set_seed(training_args.seed) 316 | 317 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 318 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 319 | # 320 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 321 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 322 | # label if at least two columns are provided. 323 | # 324 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 325 | # single column. You can easily tweak this behavior (see below) 326 | # 327 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 328 | # download the dataset. 329 | if data_args.task_name is not None: 330 | # Downloading and loading a dataset from the hub. 331 | datasets = load_dataset( 332 | "glue", data_args.task_name, cache_dir=model_args.cache_dir 333 | ) 334 | elif data_args.dataset_name is not None: 335 | # Downloading and loading a dataset from the hub. 336 | datasets = load_dataset( 337 | data_args.dataset_name, 338 | data_args.dataset_config_name, 339 | cache_dir=model_args.cache_dir, 340 | ) 341 | else: 342 | # Loading a dataset from your local files. 343 | # CSV/JSON training and evaluation files are needed. 
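        # Sketch of what this branch ends up building (file names are placeholders):
        #
        #   data_files = {"train": "my_train.csv", "validation": "my_dev.csv"}
        #   datasets = load_dataset("csv", data_files=data_files)
        #
        # Each split is expected to contain a "label" column plus one or two
        # sentence columns, as described above.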
344 | data_files = { 345 | "train": data_args.train_file, 346 | "validation": data_args.validation_file, 347 | } 348 | 349 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 350 | # when you use `do_predict` without specifying a GLUE benchmark task. 351 | if training_args.do_predict: 352 | if data_args.test_file is not None: 353 | train_extension = data_args.train_file.split(".")[-1] 354 | test_extension = data_args.test_file.split(".")[-1] 355 | assert ( 356 | test_extension == train_extension 357 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 358 | data_files["test"] = data_args.test_file 359 | else: 360 | raise ValueError( 361 | "Need either a GLUE task or a test file for `do_predict`." 362 | ) 363 | 364 | for key in data_files.keys(): 365 | logger.info(f"load a local file for {key}: {data_files[key]}") 366 | 367 | if data_args.train_file.endswith(".csv"): 368 | # Loading a dataset from local csv files 369 | datasets = load_dataset( 370 | "csv", data_files=data_files, cache_dir=model_args.cache_dir 371 | ) 372 | else: 373 | # Loading a dataset from local json files 374 | datasets = load_dataset( 375 | "json", data_files=data_files, cache_dir=model_args.cache_dir 376 | ) 377 | # See more about loading any type of standard or custom dataset at 378 | # https://huggingface.co/docs/datasets/loading_datasets.html. 379 | 380 | # Labels 381 | if data_args.task_name is not None: 382 | is_regression = data_args.task_name == "stsb" 383 | if not is_regression: 384 | label_list = datasets["train"].features["label"].names 385 | num_labels = len(label_list) 386 | else: 387 | num_labels = 1 388 | else: 389 | if model_args.retrieval_pretraining: 390 | # add dummy value here 391 | num_labels = 2 392 | is_regression = False 393 | label_list = ["dummy1", "dummy2"] 394 | else: 395 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 396 | is_regression = datasets["train"].features["label"].dtype in [ 397 | "float32", 398 | "float64", 399 | ] 400 | if is_regression: 401 | num_labels = 1 402 | else: 403 | # A useful fast method: 404 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 405 | label_list = datasets["train"].unique("label") 406 | label_list.sort() # Let's sort it for determinism 407 | num_labels = len(label_list) 408 | # 409 | # Load pretrained model and tokenizer 410 | # 411 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 412 | # download model & vocab. 
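    # Once the multiplexing fields are attached a few lines below, the config holds
    # both the usual RoBERTa hyper-parameters and the DataMUX-specific ones.
    # Illustrative values, mirroring the ModelArguments defaults above (not a
    # prescription):
    #
    #   config.num_instances        = 5
    #   config.muxing_variant       = "gaussian_hadamard"
    #   config.demuxing_variant     = "index"
    #   config.retrieval_loss_coeff = 0.1
    #   config.task_loss_coeff      = 0.9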
413 | config = AutoConfig.from_pretrained( 414 | model_args.config_name 415 | if model_args.config_name 416 | else model_args.model_name_or_path, 417 | num_labels=num_labels, 418 | finetuning_task=data_args.task_name, 419 | cache_dir=model_args.cache_dir, 420 | revision=model_args.model_revision, 421 | use_auth_token=True if model_args.use_auth_token else None, 422 | ) 423 | tokenizer = AutoTokenizer.from_pretrained( 424 | model_args.tokenizer_name 425 | if model_args.tokenizer_name 426 | else model_args.model_name_or_path, 427 | cache_dir=model_args.cache_dir, 428 | use_fast=model_args.use_fast_tokenizer, 429 | revision=model_args.model_revision, 430 | use_auth_token=True if model_args.use_auth_token else None, 431 | ) 432 | 433 | config.num_instances = model_args.num_instances 434 | config.muxing_variant = model_args.muxing_variant 435 | config.demuxing_variant = model_args.demuxing_variant 436 | config.retrieval_percentage = model_args.retrieval_percentage 437 | config.gaussian_hadamard_norm = model_args.gaussian_hadamard_norm 438 | config.binary_hadamard_epsilon = model_args.binary_hadamard_epsilon 439 | config.retrieval_loss_coeff = model_args.retrieval_loss_coeff 440 | config.task_loss_coeff = model_args.task_loss_coeff 441 | config.learn_muxing = model_args.learn_muxing 442 | 443 | model_path_supplied = model_args.model_name_or_path is not None 444 | if model_args.should_mux: 445 | 446 | if model_path_supplied: 447 | model = RobertaSequenceClassificationMuxed.from_pretrained(model_args.model_name_or_path, config=config) 448 | else: 449 | model = RobertaSequenceClassificationMuxed(config=config) 450 | else: 451 | # non-multiplexed baseline 452 | if model_path_supplied: 453 | model = AutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config) 454 | else: 455 | model = AutoModelForSequenceClassification.from_config(config=config) 456 | 457 | if data_args.task_name is not None: 458 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 459 | else: 460 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 461 | non_label_column_names = [ 462 | name for name in datasets["train"].column_names if name != "label" 463 | ] 464 | if ( 465 | "sentence1" in non_label_column_names 466 | and "sentence2" in non_label_column_names 467 | ): 468 | sentence1_key, sentence2_key = "sentence1", "sentence2" 469 | else: 470 | if len(non_label_column_names) >= 2: 471 | sentence1_key, sentence2_key = non_label_column_names[:2] 472 | else: 473 | sentence1_key, sentence2_key = non_label_column_names[0], None 474 | 475 | # Padding strategy 476 | if data_args.pad_to_max_length: 477 | padding = "max_length" 478 | else: 479 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 480 | padding = False 481 | 482 | # Some models have set the order of the labels to use, so let's make sure we do use it. 483 | label_to_id = None 484 | if ( 485 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 486 | and data_args.task_name is not None 487 | and not is_regression 488 | ): 489 | # Some have all caps in their config, some don't. 
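        # Example of the mismatch handled here (values illustrative): a pretrained
        # checkpoint might ship config.label2id = {"ENTAILMENT": 0, "NEUTRAL": 1,
        # "CONTRADICTION": 2} while the dataset's label_list is lower-cased, so the
        # keys are lower-cased before comparing against label_list and deciding
        # whether to trust the model's own mapping.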
490 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 491 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 492 | label_to_id = { 493 | i: int(label_name_to_id[label_list[i]]) for i in range(num_labels) 494 | } 495 | else: 496 | logger.warning( 497 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 498 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 499 | "\nIgnoring the model labels as a result.", 500 | ) 501 | elif data_args.task_name is None and not is_regression: 502 | label_to_id = {v: i for i, v in enumerate(label_list)} 503 | 504 | if label_to_id is not None: 505 | model.config.label2id = label_to_id 506 | model.config.id2label = {id: label for label, id in config.label2id.items()} 507 | 508 | if data_args.max_seq_length > tokenizer.model_max_length: 509 | logger.warning( 510 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 511 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 512 | ) 513 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 514 | 515 | def preprocess_function(examples): 516 | # Tokenize the texts 517 | args = ( 518 | (examples[sentence1_key],) 519 | if sentence2_key is None 520 | else (examples[sentence1_key], examples[sentence2_key]) 521 | ) 522 | result = tokenizer( 523 | *args, padding=padding, max_length=max_seq_length, truncation=True 524 | ) 525 | 526 | # Map labels to IDs (not necessary for GLUE tasks) 527 | if label_to_id is not None and "label" in examples: 528 | result["label"] = [ 529 | (label_to_id[l] if l != -1 else -1) for l in examples["label"] 530 | ] 531 | return result 532 | 533 | if model_args.retrieval_pretraining: 534 | # process wikitext dataset 535 | column_names = datasets["train"].column_names 536 | 537 | text_column_name = "text" 538 | def tokenize_function(examples): 539 | examples["text"] = [line for line in examples["text"] if line is not None and len(line) > 0 and not line.isspace()] 540 | return tokenizer(examples[text_column_name], return_special_tokens_mask=True) 541 | 542 | tokenized_datasets = datasets.map( 543 | tokenize_function, 544 | batched=True, 545 | # num_proc=data_args.preprocessing_num_workers, 546 | remove_columns=column_names, 547 | load_from_cache_file=True, 548 | ) 549 | 550 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of 551 | # max_seq_length. 552 | def group_texts(examples): 553 | # Concatenate all texts. 554 | concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} 555 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 556 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 557 | # customize this part to your needs. 558 | total_length = (total_length // max_seq_length) * max_seq_length 559 | # Split by chunks of max_len. 
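            # Concrete example with made-up numbers: if the concatenated token
            # stream for one map batch has length 70,000 and max_seq_length is 128,
            # total_length becomes (70000 // 128) * 128 = 69,888, i.e. 546 chunks of
            # 128 tokens, and the trailing 112 tokens are dropped.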
560 | result = { 561 | k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] 562 | for k, t in concatenated_examples.items() 563 | } 564 | random_key = list(result.keys())[0] 565 | len_result = len(result[random_key]) 566 | result["label"] = [random.randint(0, num_labels-1) for _ in range(len_result)] 567 | return result 568 | 569 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a 570 | # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value 571 | # might be slower to preprocess. 572 | # 573 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: 574 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 575 | 576 | tokenized_datasets = tokenized_datasets.map( 577 | group_texts, 578 | batched=True, 579 | # num_proc=data_args.preprocessing_num_workers, 580 | load_from_cache_file=True, 581 | ) 582 | datasets = tokenized_datasets 583 | 584 | else: 585 | datasets = datasets.map( 586 | preprocess_function, 587 | batched=True, 588 | load_from_cache_file=not data_args.overwrite_cache, 589 | # desc="Running tokenizer on dataset", 590 | ) 591 | # if training_args.do_train: 592 | if "train" not in datasets: 593 | raise ValueError("--do_train requires a train dataset") 594 | train_dataset = datasets["train"] 595 | if data_args.max_train_samples is not None: 596 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 597 | 598 | # if training_args.do_eval: 599 | if "validation" not in datasets and "validation_matched" not in datasets: 600 | raise ValueError("--do_eval requires a validation dataset") 601 | eval_dataset = datasets[ 602 | "validation_matched" if data_args.task_name == "mnli" else "validation" 603 | ] 604 | if data_args.max_eval_samples is not None: 605 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 606 | 607 | if ( 608 | training_args.do_predict 609 | or data_args.task_name is not None 610 | or data_args.test_file is not None 611 | ): 612 | if "test" not in datasets and "test_matched" not in datasets: 613 | raise ValueError("--do_predict requires a test dataset") 614 | predict_dataset = datasets[ 615 | "test_matched" if data_args.task_name == "mnli" else "test" 616 | ] 617 | if data_args.max_predict_samples is not None: 618 | predict_dataset = predict_dataset.select( 619 | range(data_args.max_predict_samples) 620 | ) 621 | 622 | # Log a few random samples from the training set: 623 | if training_args.do_train: 624 | for index in random.sample(range(len(train_dataset)), 3): 625 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 626 | 627 | # Get the metric function 628 | if data_args.task_name is not None: 629 | metric = load_metric("glue", data_args.task_name) 630 | else: 631 | metric = load_metric("accuracy") 632 | 633 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 634 | # predictions and label_ids field) and has to return a dictionary string to float. 
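    # Minimal illustration of the contract (shapes and values are made up):
    #
    #   p = EvalPrediction(
    #       predictions=np.array([[0.2, 0.8], [0.9, 0.1]]),  # logits, shape (N, num_labels)
    #       label_ids=np.array([1, 1]),
    #   )
    #   compute_metrics(p)  # -> {"accuracy": 0.5} for a classification task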
635 | def compute_metrics(p: EvalPrediction): 636 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 637 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 638 | if data_args.task_name is not None: 639 | result = metric.compute(predictions=preds, references=p.label_ids) 640 | if len(result) > 1: 641 | result["combined_score"] = np.mean(list(result.values())).item() 642 | return result 643 | elif is_regression: 644 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 645 | else: 646 | return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} 647 | 648 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 649 | if data_args.pad_to_max_length: 650 | data_collator = default_data_collator 651 | elif training_args.fp16: 652 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 653 | else: 654 | data_collator = None 655 | 656 | trainer = MuxTrainer( 657 | model=model, 658 | args=training_args, 659 | train_dataset=train_dataset, 660 | eval_dataset=eval_dataset, 661 | compute_metrics=compute_metrics, 662 | tokenizer=tokenizer, 663 | data_collator=data_collator, 664 | ) 665 | # Training 666 | if training_args.do_train: 667 | logger.info("*** Train ***") 668 | 669 | checkpoint = None 670 | if last_checkpoint is not None: 671 | checkpoint = last_checkpoint 672 | print("checkpoint", checkpoint) 673 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 674 | metrics = train_result.metrics 675 | max_train_samples = ( 676 | data_args.max_train_samples 677 | if data_args.max_train_samples is not None 678 | else len(train_dataset) 679 | ) 680 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 681 | 682 | trainer.save_model() # Saves the tokenizer too for easy upload 683 | 684 | trainer.log_metrics("train", metrics) 685 | trainer.save_metrics("train", metrics) 686 | trainer.save_state() 687 | 688 | # Evaluation 689 | if training_args.do_eval: 690 | logger.info("*** Evaluate ***") 691 | # Loop to handle MNLI double evaluation (matched, mis-matched) 692 | tasks = [data_args.task_name] 693 | eval_datasets = [eval_dataset] 694 | # if data_args.task_name == "mnli": 695 | # tasks.append("mnli-mm") 696 | # eval_datasets.append(datasets["validation_mismatched"]) 697 | 698 | for eval_dataset, task in zip(eval_datasets, tasks): 699 | metrics = trainer.evaluate(eval_dataset=eval_dataset, speed_metrics=True) 700 | 701 | max_eval_samples = ( 702 | data_args.max_eval_samples 703 | if data_args.max_eval_samples is not None 704 | else len(eval_dataset) 705 | ) 706 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 707 | 708 | trainer.log_metrics("eval", metrics) 709 | trainer.save_metrics("eval", metrics) 710 | 711 | if training_args.do_predict: 712 | logger.info("*** Predict ***") 713 | 714 | # Loop to handle MNLI double evaluation (matched, mis-matched) 715 | tasks = [data_args.task_name] 716 | predict_datasets = [predict_dataset] 717 | if data_args.task_name == "mnli": 718 | tasks.append("mnli-mm") 719 | predict_datasets.append(datasets["test_mismatched"]) 720 | 721 | for predict_dataset, task in zip(predict_datasets, tasks): 722 | # Removing the `label` columns because it contains -1 and Trainer won't like that. 
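            # GLUE test splits ship placeholder labels of -1, hence the drop below.
            # Note that remove_columns_ mutates the dataset in place; newer releases
            # of the datasets library deprecate it in favour of the non-underscore
            # variant, e.g.
            #
            #   predict_dataset = predict_dataset.remove_columns("label")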
723 | predict_dataset.remove_columns_("label") 724 | predictions = trainer.predict( 725 | predict_dataset, metric_key_prefix="predict" 726 | ).predictions 727 | predictions = ( 728 | np.squeeze(predictions) 729 | if is_regression 730 | else np.argmax(predictions, axis=1) 731 | ) 732 | 733 | output_predict_file = os.path.join( 734 | training_args.output_dir, f"predict_results_{task}.txt" 735 | ) 736 | if trainer.is_world_process_zero(): 737 | with open(output_predict_file, "w") as writer: 738 | logger.info(f"***** Predict results {task} *****") 739 | writer.write("index\tprediction\n") 740 | for index, item in enumerate(predictions): 741 | if is_regression: 742 | writer.write(f"{index}\t{item:3.3f}\n") 743 | else: 744 | item = label_list[item] 745 | writer.write(f"{index}\t{item}\n") 746 | 747 | 748 | def _mp_fn(index): 749 | # For xla_spawn (TPUs) 750 | main() 751 | 752 | 753 | if __name__ == "__main__": 754 | main() -------------------------------------------------------------------------------- /run_glue.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # flag to run with slurm commands 4 | USE_SLURM=0 5 | 6 | # defaults 7 | 8 | NUM_INSTANCES=1 9 | DEMUXING="index" 10 | MUXING="gaussian_hadamard" 11 | CONFIG_NAME="configs/ablations/base_model/roberta.json" 12 | LEARNING_RATE=5e-5 13 | TASK_NAME="mnli" 14 | LEARN_MUXING=0 15 | CONTINUE_TRAIN=0 16 | DO_TRAIN=0 17 | DO_EVAL=0 18 | GRADIENT_ACCUMULATION=1 19 | # commmand line arguments 20 | #!/bin/bash 21 | 22 | show_help() { 23 | echo 'Usage run_glue.sh [OPTIONS]' 24 | echo 'options:' 25 | echo '-N --num_instances [2,5,10,20,40]' 26 | echo '-d --demuxing [index, mlp]' 27 | echo '-m --muxing [gaussian_hadamard, binary_hadamard, random_ortho]' 28 | echo '-s --setting [baseline, finetuning, retrieval_pretraining]' 29 | echo '--task [mnli, qnli, sst2, qqp]' 30 | echo '--config_name CONFIG_NAME' 31 | echo '--lr LR' 32 | echo '--batch_size BATCH_SIZE' 33 | echo '--model_path MODEL_PATH' 34 | echo '--learn_muxing' 35 | echo '--continue' 36 | echo '--do_train' 37 | echo '--do_eval' 38 | } 39 | 40 | die() { 41 | printf '%s\n' "$1" >&2 42 | exit 1 43 | } 44 | 45 | while :; do 46 | case $1 in 47 | -h|-\?|--help) 48 | show_help # Display a usage synopsis. 49 | exit 50 | ;; 51 | 52 | -N|--num_instances) # Takes an option argument; ensure it has been specified. 53 | if [ "$2" ]; then 54 | NUM_INSTANCES=$2 55 | shift # shift consumes $2 without treating it as another argument 56 | else 57 | die 'ERROR: "--num-instances" requires a non-empty option argument.' 58 | fi 59 | ;; 60 | 61 | -d|--demuxing) 62 | if [ "$2" ]; then 63 | DEMUXING=$2 64 | shift 65 | else 66 | die 'ERROR: "--demuxing" requires a non-empty option argument.' 67 | fi 68 | ;; 69 | 70 | -m|--muxing) 71 | if [ "$2" ]; then 72 | MUXING=$2 73 | shift 74 | else 75 | die 'ERROR: "--muxing" requires a non-empty option argument.' 76 | fi 77 | ;; 78 | 79 | -s|--setting) 80 | if [ "$2" ]; then 81 | SETTING=$2 82 | shift 83 | else 84 | die 'ERROR: "--setting" requires a non-empty option argument.' 85 | fi 86 | ;; 87 | 88 | --config_name) 89 | if [ "$2" ]; then 90 | CONFIG_NAME=$2 91 | shift 92 | else 93 | die 'ERROR: "--config_name" requires a non-empty option argument.' 94 | fi 95 | ;; 96 | 97 | --lr) 98 | if [ "$2" ]; then 99 | LEARNING_RATE=$2 100 | shift 101 | else 102 | die 'ERROR: "--lr" requires a non-empty option argument.' 
103 | fi 104 | ;; 105 | 106 | --batch_size) 107 | if [ "$2" ]; then 108 | BATCH_SIZE=$2 109 | shift 110 | else 111 | die 'ERROR: "--batch_size" requires a non-empty option argument.' 112 | fi 113 | ;; 114 | 115 | --task) 116 | if [ "$2" ]; then 117 | TASK_NAME=$2 118 | shift 119 | else 120 | die 'ERROR: "--task" requires a non-empty option argument.' 121 | fi 122 | ;; 123 | 124 | --gradient_accumulation) 125 | if [ "$2" ]; then 126 | GRADIENT_ACCUMULATION=$2 127 | shift 128 | else 129 | die 'ERROR: "--gradient_accumulation" requires a non-empty option argument.' 130 | fi 131 | ;; 132 | 133 | --model_path) 134 | if [ "$2" ]; then 135 | MODEL_PATH=$2 136 | shift 137 | else 138 | die 'ERROR: "--model_path" requires a non-empty option argument.' 139 | fi 140 | ;; 141 | 142 | --learn_muxing) 143 | LEARN_MUXING=1 144 | ;; 145 | 146 | --do_train) 147 | DO_TRAIN=1 148 | ;; 149 | 150 | --do_eval) 151 | DO_EVAL=1 152 | ;; 153 | 154 | --) # End of all options. 155 | shift 156 | break 157 | ;; 158 | -?*) 159 | die "ERROR: Unknown option : ${1}" 160 | ;; 161 | *) # Default case: No more options, so break out of the loop. 162 | break 163 | esac 164 | 165 | shift 166 | done 167 | 168 | # other miscelleneous params 169 | SAVE_STEPS=10000 170 | MAX_SEQ_LENGTH=128 171 | 172 | if [ "$SETTING" == "retrieval_pretraining" ]; then 173 | 174 | RANDOM_ENCODING_NORM=20 175 | RETRIEVAL_PERCENTAGE=1.0 176 | RETRIEVAL_PRETRAINING=1 177 | RETRIEVAL_LOSS_COEFF=1 178 | TASK_LOSS_COEFF=0 179 | SHOULD_MUX=1 180 | DATALOADER_DROP_LAST=1 181 | OUTPUT_DIR_BASE="checkpoints/retrieval_pretraining" 182 | 183 | # params diff 184 | DATASET_NAME="wikitext" 185 | DATASET_CONFIG_NAME="wikitext-103-raw-v1" 186 | CMD_DIFF="--dataset_name ${DATASET_NAME}\ 187 | --dataset_config_name ${DATASET_CONFIG_NAME} \ 188 | --evaluation_strategy steps \ 189 | --eval_steps 10000 \ 190 | --max_steps 500000 \ 191 | --save_steps 10000" 192 | 193 | elif [ "$SETTING" = "finetuning" ]; then 194 | 195 | RANDOM_ENCODING_NORM=1 196 | RETRIEVAL_PERCENTAGE=1.0 197 | RETRIEVAL_PRETRAINING=0 198 | RETRIEVAL_LOSS_COEFF=0.1 199 | TASK_LOSS_COEFF=0.9 200 | SHOULD_MUX=1 201 | DATALOADER_DROP_LAST=1 202 | OUTPUT_DIR_BASE="checkpoints/finetune" 203 | 204 | # params diff 205 | # add task name 206 | # save steps + save strategy + num epochs 207 | 208 | CMD_DIFF="--task_name ${TASK_NAME}\ 209 | --evaluation_strategy steps \ 210 | --eval_steps 10000 \ 211 | --max_steps 500000 \ 212 | --save_steps 10000 " 213 | 214 | elif [ "$SETTING" = "baseline" ]; then 215 | 216 | echo "Setting is baseline; sets --num-instances to 1." 217 | RANDOM_ENCODING_NORM=1 218 | RETRIEVAL_PERCENTAGE=1.0 219 | RETRIEVAL_PRETRAINING=0 220 | RETRIEVAL_LOSS_COEFF=0 221 | TASK_LOSS_COEFF=1 222 | SHOULD_MUX=0 223 | DATALOADER_DROP_LAST=0 224 | OUTPUT_DIR_BASE="checkpoints/baselines" 225 | NUM_INSTANCES=1 226 | # add task name 227 | # save steps + save strategy + num epochs 228 | CMD_DIFF="--task_name ${TASK_NAME}\ 229 | --evaluation_strategy epoch \ 230 | --num_train_epochs 10" 231 | else 232 | echo "setting (${SETTING}) not recognized or unset. run \"run_glue.sh -h\" for usage." 
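    # For reference, an illustrative invocation with a recognized setting
    # (values are examples, not a recommendation):
    #   ./run_glue.sh -N 5 -d index -m gaussian_hadamard -s finetuning \
    #       --task mnli --config_name configs/ablations/base_model/roberta.json \
    #       --do_train --do_eval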
233 | exit 0 234 | fi 235 | 236 | 237 | if [[ $LEARN_MUXING -ge 1 ]]; then 238 | OUTPUT_DIR=$OUTPUT_DIR_BASE/${TASK_NAME}_${MODEL_PATH}_${MUXING}_${DEMUXING}_${NUM_INSTANCES}_norm_${RANDOM_ENCODING_NORM}_rc_${RETRIEVAL_LOSS_COEFF}_lr${LEARNING_RATE}_tc_${TASK_LOSS_COEFF}_${CONFIG_NAME}_learntmuxing 239 | RUN_NAME=${TASK_NAME}_${MODEL_PATH}_${MUXING}_${DEMUXING}_${NUM_INSTANCES}_${RETRIEVAL_PERCENTAGE}_norm_${RANDOM_ENCODING_NORM}_rc_${RETRIEVAL_LOSS_COEFF}_lr${LEARNING_RATE}_tc_${TASK_LOSS_COEFF}_${CONFIG_NAME}_learnmuxing 240 | else 241 | OUTPUT_DIR=$OUTPUT_DIR_BASE/${TASK_NAME}_${MODEL_PATH}_${MUXING}_${DEMUXING}_${NUM_INSTANCES}_norm_${RANDOM_ENCODING_NORM}_rc_${RETRIEVAL_LOSS_COEFF}_lr${LEARNING_RATE}_tc_${TASK_LOSS_COEFF}_${CONFIG_NAME}_${GRADIENT_ACCUMULATION}_${RETRIEVAL_LOSS_VOCAB_SCALE} 242 | RUN_NAME=${TASK_NAME}_${MODEL_PATH}_${MUXING}_${DEMUXING}_${NUM_INSTANCES}_${RETRIEVAL_PERCENTAGE}_norm_${RANDOM_ENCODING_NORM}_rc_${RETRIEVAL_LOSS_COEFF}_lr${LEARNING_RATE}_tc_${TASK_LOSS_COEFF}_${CONFIG_NAME}_${GRADIENT_ACCUMULATION}_${RETRIEVAL_LOSS_VOCAB_SCALE} 243 | fi 244 | 245 | if [ -z "$BATCH_SIZE" ] # if BATCH_SIZE is not set manually 246 | then 247 | if [[ $NUM_INSTANCES -ge 40 ]] 248 | then 249 | BATCH_SIZE=16 250 | 251 | elif [[ $NUM_INSTANCES -ge 20 ]] 252 | then 253 | BATCH_SIZE=20 254 | elif [[ $NUM_INSTANCES -ge 2 ]] 255 | then 256 | BATCH_SIZE=24 257 | else 258 | BATCH_SIZE=32 259 | fi 260 | fi 261 | 262 | 263 | BATCH_SIZE=$(($BATCH_SIZE * NUM_INSTANCES)) 264 | 265 | CMD="python run_glue.py \ 266 | --tokenizer_name roberta-base \ 267 | --config_name ${CONFIG_NAME} \ 268 | --max_seq_length $MAX_SEQ_LENGTH \ 269 | --per_device_train_batch_size $BATCH_SIZE \ 270 | --per_device_eval_batch_size $BATCH_SIZE \ 271 | --learning_rate $LEARNING_RATE \ 272 | --output_dir $OUTPUT_DIR \ 273 | --run_name $RUN_NAME \ 274 | --logging_steps 100 \ 275 | --report_to wandb \ 276 | --dataloader_drop_last $DATALOADER_DROP_LAST \ 277 | --retrieval_percentage $RETRIEVAL_PERCENTAGE \ 278 | --retrieval_loss_coeff $RETRIEVAL_LOSS_COEFF \ 279 | --task_loss_coeff $TASK_LOSS_COEFF \ 280 | --retrieval_pretraining ${RETRIEVAL_PRETRAINING} \ 281 | --num_instances ${NUM_INSTANCES} \ 282 | --muxing_variant ${MUXING} \ 283 | --demuxing_variant ${DEMUXING} \ 284 | --should_mux ${SHOULD_MUX} \ 285 | --gaussian_hadamard_norm ${RANDOM_ENCODING_NORM} \ 286 | --learn_muxing ${LEARN_MUXING} \ 287 | --gradient_accumulation_steps ${GRADIENT_ACCUMULATION} \ 288 | --load_best_model_at_end 1 \ 289 | --metric_for_best_model eval_accuracy \ 290 | --save_total_limit 1 \ 291 | --dataloader_num_workers 8" 292 | 293 | if [ "$DO_TRAIN" -eq 1 ]; then 294 | CMD="${CMD} --do_train" 295 | fi 296 | if [ "$DO_EVAL" -eq 1 ]; then 297 | CMD="${CMD} --do_eval" 298 | fi 299 | 300 | if [ ! 
-z "$MODEL_PATH" ] # if MODEL PATH is set manually 301 | then 302 | CMD="${CMD} --model_name_or_path ${MODEL_PATH}" 303 | fi 304 | 305 | CMD=${CMD}" "${CMD_DIFF} 306 | 307 | if [[ $NUM_INSTANCES -ge 40 ]] 308 | then 309 | TIME="120:00:00" 310 | elif [[ $NUM_INSTANCES -ge 20 ]] 311 | then 312 | TIME="72:00:00" 313 | else 314 | TIME="30:00:00" 315 | fi 316 | 317 | echo "Running command with arguments:" 318 | echo $CMD 319 | echo $USE_SLURM 320 | if [[ $USE_SLURM = 1 ]]; then 321 | sbatch --time=$TIME --mem=32G --output=logs/%x-%j.out --job-name=${TASK_NAME}_${NUM_INSTANCES}_${MUXING}_${DEMUXING} --gres=gpu:rtx_3090:1 -A pnlp ./run_job.sh \ 322 | "$CMD" 323 | else 324 | ./run_job.sh "$CMD" 325 | fi 326 | -------------------------------------------------------------------------------- /run_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export WANDB_NOTES=$SLURM_JOB_ID 3 | for var in "$@" 4 | do 5 | $var 6 | done 7 | -------------------------------------------------------------------------------- /run_ner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Team All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for token classification. 18 | """ 19 | # You can also adapt this script on your own token classification task and datasets. Pointers for this are left as 20 | # comments. 21 | 22 | import logging 23 | import os 24 | import sys 25 | from dataclasses import dataclass, field 26 | from typing import Optional 27 | 28 | import numpy as np 29 | from datasets import ClassLabel, load_dataset, load_metric 30 | 31 | from transformers import ( 32 | AutoConfig, 33 | AutoModelForTokenClassification, 34 | AutoTokenizer, 35 | DataCollatorForTokenClassification, 36 | HfArgumentParser, 37 | PreTrainedTokenizerFast, 38 | TrainingArguments, 39 | set_seed, 40 | ) 41 | from transformers.trainer_utils import get_last_checkpoint 42 | from models.multiplexing import RobertaTokenClassificationMuxed 43 | from models.trainer import MuxTrainer 44 | import torch 45 | 46 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 47 | # check_min_version("4.13.0.dev0") 48 | 49 | # require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt") 50 | 51 | logger = logging.getLogger(__name__) 52 | 53 | @dataclass 54 | class DataTrainingArguments: 55 | """ 56 | Arguments pertaining to what data we are going to input our model for training and eval. 57 | Using `HfArgumentParser` we can turn this class 58 | into argparse arguments to be able to specify them on 59 | the command line. 
60 | """ 61 | 62 | task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."}) 63 | 64 | dataset_name: Optional[str] = field( 65 | default=None, 66 | metadata={"help": "The name of the dataset to use (via the datasets library)."}, 67 | ) 68 | dataset_config_name: Optional[str] = field( 69 | default=None, 70 | metadata={ 71 | "help": "The configuration name of the dataset to use (via the datasets library)." 72 | }, 73 | ) 74 | max_seq_length: int = field( 75 | default=128, 76 | metadata={ 77 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 78 | "than this will be truncated, sequences shorter will be padded." 79 | }, 80 | ) 81 | overwrite_cache: bool = field( 82 | default=False, 83 | metadata={"help": "Overwrite the cached preprocessed datasets or not."}, 84 | ) 85 | pad_to_max_length: bool = field( 86 | default=True, 87 | metadata={ 88 | "help": "Whether to pad all samples to `max_seq_length`. " 89 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 90 | }, 91 | ) 92 | max_train_samples: Optional[int] = field( 93 | default=None, 94 | metadata={ 95 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 96 | "value if set." 97 | }, 98 | ) 99 | max_eval_samples: Optional[int] = field( 100 | default=None, 101 | metadata={ 102 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 103 | "value if set." 104 | }, 105 | ) 106 | max_predict_samples: Optional[int] = field( 107 | default=None, 108 | metadata={ 109 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 110 | "value if set." 111 | }, 112 | ) 113 | train_file: Optional[str] = field( 114 | default=None, 115 | metadata={"help": "A csv or a json file containing the training data."}, 116 | ) 117 | validation_file: Optional[str] = field( 118 | default=None, 119 | metadata={"help": "A csv or a json file containing the validation data."}, 120 | ) 121 | test_file: Optional[str] = field( 122 | default=None, 123 | metadata={"help": "A csv or a json file containing the test data."}, 124 | ) 125 | text_column_name: Optional[str] = field( 126 | default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."} 127 | ) 128 | label_column_name: Optional[str] = field( 129 | default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."} 130 | ) 131 | 132 | preprocessing_num_workers: Optional[int] = field( 133 | default=None, 134 | metadata={"help": "The number of processes to use for the preprocessing."}, 135 | ) 136 | label_all_tokens: bool = field( 137 | default=False, 138 | metadata={ 139 | "help": "Whether to put the label for one word on all tokens of generated by that word or just on the " 140 | "one (in which case the other tokens will have a padding index)." 
141 | }, 142 | ) 143 | return_entity_level_metrics: bool = field( 144 | default=False, 145 | metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."}, 146 | ) 147 | 148 | def __post_init__(self): 149 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 150 | raise ValueError("Need either a dataset name or a training/validation file.") 151 | else: 152 | if self.train_file is not None: 153 | extension = self.train_file.split(".")[-1] 154 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 155 | if self.validation_file is not None: 156 | extension = self.validation_file.split(".")[-1] 157 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 158 | self.task_name = self.task_name.lower() 159 | 160 | @dataclass 161 | class ModelArguments: 162 | """ 163 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 164 | """ 165 | 166 | model_name_or_path: str = field( 167 | default=None, 168 | metadata={ 169 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 170 | } 171 | ) 172 | config_name: Optional[str] = field( 173 | default=None, 174 | metadata={ 175 | "help": "Pretrained config name or path if not the same as model_name" 176 | }, 177 | ) 178 | tokenizer_name: Optional[str] = field( 179 | default=None, 180 | metadata={ 181 | "help": "Pretrained tokenizer name or path if not the same as model_name" 182 | }, 183 | ) 184 | cache_dir: Optional[str] = field( 185 | default=None, 186 | metadata={ 187 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co" 188 | }, 189 | ) 190 | use_fast_tokenizer: bool = field( 191 | default=True, 192 | metadata={ 193 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 194 | }, 195 | ) 196 | model_revision: str = field( 197 | default="main", 198 | metadata={ 199 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 200 | }, 201 | ) 202 | use_auth_token: bool = field( 203 | default=False, 204 | metadata={ 205 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 206 | "with private models)." 207 | }, 208 | ) 209 | # multi instance arguments 210 | num_instances: Optional[int] = field( 211 | default=5, 212 | metadata={"help": "Number of instances i.e. 
N"}, 213 | ) 214 | muxing_variant: Optional[str] = field( 215 | default="gaussian_hadamard", 216 | metadata={"help": "muxing variant; choose from gaussian_hadamard or random_ortho or binary_hadamard"}, 217 | ) 218 | demuxing_variant: Optional[str] = field( 219 | default="index", 220 | metadata={"help": "demuxing variant, choose from 'index' or 'mlp'"}, 221 | ) 222 | should_mux: Optional[int] = field( 223 | default=1, 224 | metadata={"help": "whether to mux, turn off for non-multiplexed baselines"}, 225 | ) 226 | retrieval_percentage: Optional[float] = field( 227 | default=1.0, 228 | metadata={"help": "percentage of tokens to retrieve during inference"}, 229 | ) 230 | retrieval_pretraining: Optional[int] = field( 231 | default=0, 232 | metadata={"help": "Retrieval Pretraining"}, 233 | ) 234 | gaussian_hadamard_norm: Optional[float] = field( 235 | default=1, 236 | metadata={"help": "Norm of sentence embeddings if we use random projections"}, 237 | ) 238 | binary_hadamard_epsilon: Optional[float] = field( 239 | default=0, 240 | metadata={"help": "Percentage intersection among binary vectors, default is no intersection"}, 241 | ) 242 | retrieval_loss_coeff: Optional[float] = field( 243 | default=0.1, 244 | metadata={"help": "Coefficient for retrieval loss"}, 245 | ) 246 | task_loss_coeff: Optional[float] = field( 247 | default=0.9, 248 | metadata={"help": "Coefficient for task loss"}, 249 | ) 250 | learn_muxing: Optional[int] = field( 251 | default=0, 252 | metadata={"help": "whether instance embeddings are learnt or not"}, 253 | ) 254 | 255 | 256 | def main(): 257 | # See all possible arguments in src/transformers/training_args.py 258 | # or by passing the --help flag to this script. 259 | # We now keep distinct sets of args, for a cleaner separation of concerns. 260 | 261 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 262 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 263 | # If we pass only one argument to the script and it's the path to a json file, 264 | # let's parse it to get our arguments. 265 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 266 | else: 267 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 268 | 269 | # Setup logging 270 | logging.basicConfig( 271 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 272 | datefmt="%m/%d/%Y %H:%M:%S", 273 | handlers=[logging.StreamHandler(sys.stdout)], 274 | ) 275 | 276 | # log_level = training_args.get_process_log_level() 277 | # logger.setLevel(log_level) 278 | # datasets.utils.logging.set_verbosity(log_level) 279 | # transformers.utils.logging.set_verbosity(log_level) 280 | # transformers.utils.logging.enable_default_handler() 281 | # transformers.utils.logging.enable_explicit_format() 282 | 283 | # Log on each process the small summary: 284 | logger.warning( 285 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 286 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 287 | ) 288 | logger.info(f"Training/evaluation parameters {training_args}") 289 | 290 | # Detecting last checkpoint. 291 | last_checkpoint = None 292 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 293 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 294 | 295 | # Set seed before initializing model. 
296 | set_seed(training_args.seed) 297 | 298 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 299 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 300 | # (the dataset will be downloaded automatically from the datasets Hub). 301 | # 302 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 303 | # 'text' is found. You can easily tweak this behavior (see below). 304 | # 305 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 306 | # download the dataset. 307 | if data_args.dataset_name is not None: 308 | # Downloading and loading a dataset from the hub. 309 | raw_datasets = load_dataset( 310 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 311 | ) 312 | else: 313 | data_files = {} 314 | if data_args.train_file is not None: 315 | data_files["train"] = data_args.train_file 316 | if data_args.validation_file is not None: 317 | data_files["validation"] = data_args.validation_file 318 | if data_args.test_file is not None: 319 | data_files["test"] = data_args.test_file 320 | extension = data_args.train_file.split(".")[-1] 321 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 322 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 323 | # https://huggingface.co/docs/datasets/loading_datasets.html. 324 | 325 | if training_args.do_train: 326 | column_names = raw_datasets["train"].column_names 327 | features = raw_datasets["train"].features 328 | else: 329 | column_names = raw_datasets["validation"].column_names 330 | features = raw_datasets["validation"].features 331 | 332 | if data_args.text_column_name is not None: 333 | text_column_name = data_args.text_column_name 334 | elif "tokens" in column_names: 335 | text_column_name = "tokens" 336 | else: 337 | text_column_name = column_names[0] 338 | 339 | if data_args.label_column_name is not None: 340 | label_column_name = data_args.label_column_name 341 | elif f"{data_args.task_name}_tags" in column_names: 342 | label_column_name = f"{data_args.task_name}_tags" 343 | else: 344 | label_column_name = column_names[1] 345 | 346 | # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the 347 | # unique labels. 348 | def get_label_list(labels): 349 | unique_labels = set() 350 | for label in labels: 351 | unique_labels = unique_labels | set(label) 352 | label_list = list(unique_labels) 353 | label_list.sort() 354 | return label_list 355 | 356 | if isinstance(features[label_column_name].feature, ClassLabel): 357 | label_list = features[label_column_name].feature.names 358 | # No need to convert the labels since they are already ints. 
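# For illustration only (hypothetical values): with a ClassLabel feature whose names are ["O", "B-PER", "I-PER"],
# label_list is taken directly from the feature and label_to_id is the identity map {0: 0, 1: 1, 2: 2};
# the b_to_i_label table built below would then be [0, 2, 2], i.e. each "B-" tag points at the index of its
# "I-" counterpart, which is used to relabel word-internal tokens when label_all_tokens is set.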
359 | label_to_id = {i: i for i in range(len(label_list))} 360 | else: 361 | label_list = get_label_list(raw_datasets["train"][label_column_name]) 362 | label_to_id = {l: i for i, l in enumerate(label_list)} 363 | num_labels = len(label_list) 364 | 365 | # Map that sends B-Xxx label to its I-Xxx counterpart 366 | b_to_i_label = [] 367 | for idx, label in enumerate(label_list): 368 | if label.startswith("B-") and label.replace("B-", "I-") in label_list: 369 | b_to_i_label.append(label_list.index(label.replace("B-", "I-"))) 370 | else: 371 | b_to_i_label.append(idx) 372 | 373 | # Load pretrained model and tokenizer 374 | # 375 | # Distributed training: 376 | # The .from_pretrained methods guarantee that only one local process can concurrently 377 | # download model & vocab. 378 | 379 | 380 | config = AutoConfig.from_pretrained( 381 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 382 | num_labels=num_labels, 383 | label2id=label_to_id, 384 | id2label={i: l for l, i in label_to_id.items()}, 385 | finetuning_task=data_args.task_name, 386 | cache_dir=model_args.cache_dir, 387 | revision=model_args.model_revision, 388 | use_auth_token=True if model_args.use_auth_token else None, 389 | ) 390 | 391 | config.num_instances = model_args.num_instances 392 | config.muxing_variant = model_args.muxing_variant 393 | config.demuxing_variant = model_args.demuxing_variant 394 | config.retrieval_percentage = model_args.retrieval_percentage 395 | config.gaussian_hadamard_norm = model_args.gaussian_hadamard_norm 396 | config.binary_hadamard_epsilon = model_args.binary_hadamard_epsilon 397 | config.retrieval_loss_coeff = model_args.retrieval_loss_coeff 398 | config.task_loss_coeff = model_args.task_loss_coeff 399 | config.learn_muxing = model_args.learn_muxing 400 | 401 | tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path 402 | if config.model_type in {"gpt2", "roberta"}: 403 | tokenizer = AutoTokenizer.from_pretrained( 404 | tokenizer_name_or_path, 405 | cache_dir=model_args.cache_dir, 406 | use_fast=True, 407 | revision=model_args.model_revision, 408 | use_auth_token=True if model_args.use_auth_token else None, 409 | add_prefix_space=True, 410 | ) 411 | else: 412 | tokenizer = AutoTokenizer.from_pretrained( 413 | tokenizer_name_or_path, 414 | cache_dir=model_args.cache_dir, 415 | use_fast=True, 416 | revision=model_args.model_revision, 417 | use_auth_token=True if model_args.use_auth_token else None, 418 | ) 419 | 420 | 421 | model_path_supplied = model_args.model_name_or_path is not None 422 | if model_args.should_mux: 423 | 424 | if model_path_supplied: 425 | model = RobertaTokenClassificationMuxed.from_pretrained(model_args.model_name_or_path, config=config) 426 | else: 427 | model = RobertaTokenClassificationMuxed(config=config) 428 | else: 429 | # non-multiplexed baseline 430 | if model_path_supplied: 431 | model = AutoModelForTokenClassification.from_pretrained(model_args.model_name_or_path, config=config) 432 | else: 433 | model = AutoModelForTokenClassification.from_config(config)  # Auto classes must be built via from_config when no checkpoint is given 434 | 435 | # Tokenizer check: this script requires a fast tokenizer. 436 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 437 | raise ValueError( 438 | "This example script only works for models that have a fast tokenizer.
Checkout the big table of models " 439 | "at https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet this " 440 | "requirement" 441 | ) 442 | 443 | # Preprocessing the dataset 444 | # Padding strategy 445 | padding = "max_length" if data_args.pad_to_max_length else False 446 | 447 | # Tokenize all texts and align the labels with them. 448 | def tokenize_and_align_labels(examples): 449 | tokenized_inputs = tokenizer( 450 | examples[text_column_name], 451 | padding=padding, 452 | truncation=True, 453 | max_length=data_args.max_seq_length, 454 | # We use this argument because the texts in our dataset are lists of words (with a label for each word). 455 | is_split_into_words=True, 456 | ) 457 | labels = [] 458 | for i, label in enumerate(examples[label_column_name]): 459 | word_ids = tokenized_inputs.word_ids(batch_index=i) 460 | previous_word_idx = None 461 | label_ids = [] 462 | for word_idx in word_ids: 463 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically 464 | # ignored in the loss function. 465 | if word_idx is None: 466 | label_ids.append(-100) 467 | # We set the label for the first token of each word. 468 | elif word_idx != previous_word_idx: 469 | label_ids.append(label_to_id[label[word_idx]]) 470 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 471 | # the label_all_tokens flag. 472 | else: 473 | if data_args.label_all_tokens: 474 | label_ids.append(b_to_i_label[label_to_id[label[word_idx]]]) 475 | else: 476 | label_ids.append(-100) 477 | previous_word_idx = word_idx 478 | 479 | labels.append(label_ids) 480 | tokenized_inputs["labels"] = labels 481 | return tokenized_inputs 482 | 483 | if training_args.do_train: 484 | if "train" not in raw_datasets: 485 | raise ValueError("--do_train requires a train dataset") 486 | train_dataset = raw_datasets["train"] 487 | if data_args.max_train_samples is not None: 488 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 489 | 490 | train_dataset = train_dataset.map( 491 | tokenize_and_align_labels, 492 | batched=True, 493 | num_proc=data_args.preprocessing_num_workers, 494 | load_from_cache_file=not data_args.overwrite_cache, 495 | # desc="Running tokenizer on train dataset", 496 | ) 497 | 498 | if training_args.do_eval: 499 | if "validation" not in raw_datasets: 500 | raise ValueError("--do_eval requires a validation dataset") 501 | eval_dataset = raw_datasets["validation"] 502 | if data_args.max_eval_samples is not None: 503 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 504 | eval_dataset = eval_dataset.map( 505 | tokenize_and_align_labels, 506 | batched=True, 507 | num_proc=data_args.preprocessing_num_workers, 508 | load_from_cache_file=not data_args.overwrite_cache, 509 | # desc="Running tokenizer on validation dataset", 510 | ) 511 | 512 | if training_args.do_predict: 513 | if "test" not in raw_datasets: 514 | raise ValueError("--do_predict requires a test dataset") 515 | predict_dataset = raw_datasets["test"] 516 | if data_args.max_predict_samples is not None: 517 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 518 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 519 | predict_dataset = predict_dataset.map( 520 | tokenize_and_align_labels, 521 | batched=True, 522 | num_proc=data_args.preprocessing_num_workers, 523 | load_from_cache_file=not data_args.overwrite_cache, 524 | # 
desc="Running tokenizer on prediction dataset", 525 | ) 526 | 527 | # Data collator 528 | data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) 529 | 530 | # Metrics 531 | metric = load_metric("seqeval") 532 | 533 | def compute_metrics(p): 534 | predictions, labels = p 535 | predictions = np.argmax(predictions, axis=2) 536 | 537 | # Remove ignored index (special tokens) 538 | true_predictions = [ 539 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 540 | for prediction, label in zip(predictions, labels) 541 | ] 542 | true_labels = [ 543 | [label_list[l] for (p, l) in zip(prediction, label) if l != -100] 544 | for prediction, label in zip(predictions, labels) 545 | ] 546 | 547 | results = metric.compute(predictions=true_predictions, references=true_labels) 548 | if data_args.return_entity_level_metrics: 549 | # Unpack nested dictionaries 550 | final_results = {} 551 | for key, value in results.items(): 552 | if isinstance(value, dict): 553 | for n, v in value.items(): 554 | final_results[f"{key}_{n}"] = v 555 | else: 556 | final_results[key] = value 557 | return final_results 558 | else: 559 | return { 560 | "precision": results["overall_precision"], 561 | "recall": results["overall_recall"], 562 | "f1": results["overall_f1"], 563 | "accuracy": results["overall_accuracy"], 564 | } 565 | 566 | # Initialize our Trainer 567 | trainer = MuxTrainer( 568 | model=model, 569 | args=training_args, 570 | train_dataset=train_dataset if training_args.do_train else None, 571 | eval_dataset=eval_dataset if training_args.do_eval else None, 572 | tokenizer=tokenizer, 573 | data_collator=data_collator, 574 | compute_metrics=compute_metrics, 575 | ) 576 | 577 | # Training 578 | if training_args.do_train: 579 | logger.info("*** Train ***") 580 | checkpoint = None 581 | if last_checkpoint is not None: 582 | checkpoint = last_checkpoint 583 | 584 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 585 | metrics = train_result.metrics 586 | trainer.save_model() # Saves the tokenizer too for easy upload 587 | 588 | max_train_samples = ( 589 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 590 | ) 591 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 592 | 593 | trainer.log_metrics("train", metrics) 594 | trainer.save_metrics("train", metrics) 595 | trainer.save_state() 596 | 597 | # Evaluation 598 | if training_args.do_eval: 599 | logger.info("*** Evaluate ***") 600 | 601 | metrics = trainer.evaluate() 602 | 603 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 604 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 605 | 606 | trainer.log_metrics("eval", metrics) 607 | trainer.save_metrics("eval", metrics) 608 | 609 | # Predict 610 | if training_args.do_predict: 611 | logger.info("*** Predict ***") 612 | 613 | predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict") 614 | predictions = np.argmax(predictions, axis=2) 615 | 616 | # Remove ignored index (special tokens) 617 | true_predictions = [ 618 | [label_list[p] for (p, l) in zip(prediction, label) if l != -100] 619 | for prediction, label in zip(predictions, labels) 620 | ] 621 | 622 | trainer.log_metrics("predict", metrics) 623 | trainer.save_metrics("predict", metrics) 624 | 625 | # Save predictions 626 | output_predictions_file = os.path.join(training_args.output_dir, 
"predictions.txt") 627 | if trainer.is_world_process_zero(): 628 | with open(output_predictions_file, "w") as writer: 629 | for prediction in true_predictions: 630 | writer.write(" ".join(prediction) + "\n") 631 | 632 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "token-classification"} 633 | if data_args.dataset_name is not None: 634 | kwargs["dataset_tags"] = data_args.dataset_name 635 | if data_args.dataset_config_name is not None: 636 | kwargs["dataset_args"] = data_args.dataset_config_name 637 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 638 | else: 639 | kwargs["dataset"] = data_args.dataset_name 640 | 641 | 642 | 643 | def _mp_fn(index): 644 | # For xla_spawn (TPUs) 645 | main() 646 | 647 | 648 | if __name__ == "__main__": 649 | main() -------------------------------------------------------------------------------- /run_ner.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # flag to run with slurm commands 4 | USE_SLURM=0 5 | 6 | # defaults 7 | 8 | NUM_INSTANCES=1 9 | DEMUXING="index" 10 | MUXING="gaussian_hadamard" 11 | CONFIG_NAME="configs/ablations/base_model/roberta.json" 12 | LEARNING_RATE=5e-5 13 | TASK_NAME="mnli" 14 | LEARN_MUXING=0 15 | CONTINUE_TRAIN=0 16 | DO_TRAIN=0 17 | DO_EVAL=0 18 | # commmand line arguments 19 | #!/bin/bash 20 | 21 | show_help() { 22 | echo 'Usage run_glue.sh [OPTIONS]' 23 | echo 'options:' 24 | echo '-N --num_instances [2,5,10,20,40]' 25 | echo '-d --demuxing [index, mlp]' 26 | echo '-m --muxing [gaussian_hadamard, binary_hadamard, random_ortho]' 27 | echo '-s --setting [baseline, finetuning, retrieval_pretraining]' 28 | echo '--task [mnli, qnli, sst2, qqp]' 29 | echo '--config_name CONFIG_NAME' 30 | echo '--lr LR' 31 | echo '--batch_size BATCH_SIZE' 32 | echo '--model_path MODEL_PATH' 33 | echo '--learn_muxing' 34 | echo '--continue' 35 | echo '--do_train' 36 | echo '--do_eval' 37 | } 38 | 39 | die() { 40 | printf '%s\n' "$1" >&2 41 | exit 1 42 | } 43 | 44 | while :; do 45 | case $1 in 46 | -h|-\?|--help) 47 | show_help # Display a usage synopsis. 48 | exit 49 | ;; 50 | 51 | -N|--num-instances) # Takes an option argument; ensure it has been specified. 52 | if [ "$2" ]; then 53 | NUM_INSTANCES=$2 54 | shift # shift consumes $2 without treating it as another argument 55 | else 56 | die 'ERROR: "--num-instances" requires a non-empty option argument.' 57 | fi 58 | ;; 59 | 60 | -d|--demuxing) 61 | if [ "$2" ]; then 62 | DEMUXING=$2 63 | shift 64 | else 65 | die 'ERROR: "--demuxing" requires a non-empty option argument.' 66 | fi 67 | ;; 68 | 69 | -m|--muxing) 70 | if [ "$2" ]; then 71 | MUXING=$2 72 | shift 73 | else 74 | die 'ERROR: "--muxing" requires a non-empty option argument.' 75 | fi 76 | ;; 77 | 78 | -s|--setting) 79 | if [ "$2" ]; then 80 | SETTING=$2 81 | shift 82 | else 83 | die 'ERROR: "--setting" requires a non-empty option argument.' 84 | fi 85 | ;; 86 | 87 | --config_name) 88 | if [ "$2" ]; then 89 | CONFIG_NAME=$2 90 | shift 91 | else 92 | die 'ERROR: "--config_name" requires a non-empty option argument.' 93 | fi 94 | ;; 95 | 96 | --lr) 97 | if [ "$2" ]; then 98 | LEARNING_RATE=$2 99 | shift 100 | else 101 | die 'ERROR: "--lr" requires a non-empty option argument.' 102 | fi 103 | ;; 104 | 105 | --batch_size) 106 | if [ "$2" ]; then 107 | BATCH_SIZE=$2 108 | shift 109 | else 110 | die 'ERROR: "--batch_size" requires a non-empty option argument.' 
111 | fi 112 | ;; 113 | 114 | --task) 115 | if [ "$2" ]; then 116 | TASK_NAME=$2 117 | shift 118 | else 119 | die 'ERROR: "--task" requires a non-empty option argument.' 120 | fi 121 | ;; 122 | 123 | --model_path) 124 | if [ "$2" ]; then 125 | MODEL_PATH=$2 126 | shift 127 | else 128 | die 'ERROR: "--model_path" requires a non-empty option argument.' 129 | fi 130 | ;; 131 | 132 | --learn_muxing) 133 | LEARN_MUXING=1 134 | ;; 135 | 136 | --do_train) 137 | DO_TRAIN=1 138 | ;; 139 | 140 | --do_eval) 141 | DO_EVAL=1 142 | ;; 143 | 144 | --) # End of all options. 145 | shift 146 | break 147 | ;; 148 | -?*) 149 | die "ERROR: Unknown option : ${1}" 150 | ;; 151 | *) # Default case: No more options, so break out of the loop. 152 | break 153 | esac 154 | 155 | shift 156 | done 157 | 158 | 159 | 160 | declare -A task2datasetmap 161 | task2datasetmap[ner]="conll2003" 162 | DATASET=${task2datasetmap[$TASK_NAME]} 163 | # other miscellaneous params 164 | SAVE_STEPS=10000 165 | MAX_SEQ_LENGTH=128 166 | 167 | if [ "$SETTING" == "retrieval_pretraining" ]; then 168 | 169 | RANDOM_ENCODING_NORM=20 170 | RETRIEVAL_PERCENTAGE=1.0 171 | RETRIEVAL_PRETRAINING=1 172 | RETRIEVAL_LOSS_COEFF=1 173 | TASK_LOSS_COEFF=0 174 | SHOULD_MUX=1 175 | DATALOADER_DROP_LAST=1 176 | OUTPUT_DIR_BASE="checkpoints/retrieval_pretraining" 177 | 178 | # params diff 179 | DATASET_NAME="wikitext" 180 | DATASET_CONFIG_NAME="wikitext-103-raw-v1" 181 | CMD_DIFF="--dataset_name ${DATASET_NAME}\ 182 | --dataset_config_name ${DATASET_CONFIG_NAME} \ 183 | --evaluation_strategy steps \ 184 | --eval_steps 10000 \ 185 | --max_steps 500000 \ 186 | --save_steps 10000" 187 | 188 | elif [ "$SETTING" = "finetuning" ]; then 189 | 190 | RANDOM_ENCODING_NORM=1 191 | RETRIEVAL_PERCENTAGE=1.0 192 | RETRIEVAL_PRETRAINING=0 193 | RETRIEVAL_LOSS_COEFF=0.1 194 | TASK_LOSS_COEFF=0.9 195 | SHOULD_MUX=1 196 | DATALOADER_DROP_LAST=1 197 | OUTPUT_DIR_BASE="checkpoints/finetune" 198 | 199 | # add task name 200 | # save steps + save strategy + num epochs 201 | 202 | CMD_DIFF="--task_name ${TASK_NAME}\ 203 | --dataset_name $DATASET \ 204 | --evaluation_strategy steps \ 205 | --eval_steps 10000 \ 206 | --max_steps 500000 \ 207 | --save_steps 10000 " 208 | 209 | elif [ "$SETTING" = "baseline" ]; then 210 | 211 | echo "Setting is baseline; sets --num-instances to 1." 212 | RANDOM_ENCODING_NORM=1 213 | RETRIEVAL_PERCENTAGE=1.0 214 | RETRIEVAL_PRETRAINING=0 215 | RETRIEVAL_LOSS_COEFF=0 216 | TASK_LOSS_COEFF=1 217 | SHOULD_MUX=0 218 | DATALOADER_DROP_LAST=0 219 | OUTPUT_DIR_BASE="checkpoints/baselines" 220 | NUM_INSTANCES=1 221 | # add task name 222 | # save steps + save strategy + num epochs 223 | CMD_DIFF="--task_name ${TASK_NAME}\ 224 | --dataset_name $DATASET \ 225 | --evaluation_strategy epoch \ 226 | --num_train_epochs 10" 227 | else 228 | echo "setting (${SETTING}) not recognized or unset. run \"run_ner.sh -h\" for usage."
229 | exit 0 230 | fi 231 | 232 | if [[ $LEARN_MUXING -ge 1 ]]; then 233 | OUTPUT_DIR=$OUTPUT_DIR_BASE/${TASK_NAME}_${MODEL_PATH}_${MUXING}_${DEMUXING}_${NUM_INSTANCES}_norm_${RANDOM_ENCODING_NORM}_rc_${RETRIEVAL_LOSS_COEFF}_lr${LEARNING_RATE}_tc_${TASK_LOSS_COEFF}_${CONFIG_PATH}_learntmuxing 234 | RUN_NAME=${TASK_NAME}_${MODEL_PATH}_${MUXING}_${DEMUXING}_${NUM_INSTANCES}_${RETRIEVAL_PERCENTAGE}_norm_${RANDOM_ENCODING_NORM}_rc_${RETRIEVAL_LOSS_COEFF}_lr${LEARNING_RATE}_tc_${TASK_LOSS_COEFF}_${CONFIG_PATH}_learnmuxing 235 | else 236 | OUTPUT_DIR=$OUTPUT_DIR_BASE/${TASK_NAME}_${MODEL_PATH}_${MUXING}_${DEMUXING}_${NUM_INSTANCES}_norm_${RANDOM_ENCODING_NORM}_rc_${RETRIEVAL_LOSS_COEFF}_lr${LEARNING_RATE}_tc_${TASK_LOSS_COEFF}_${CONFIG_PATH} 237 | RUN_NAME=${TASK_NAME}_${MODEL_PATH}_${MUXING}_${DEMUXING}_${NUM_INSTANCES}_${RETRIEVAL_PERCENTAGE}_norm_${RANDOM_ENCODING_NORM}_rc_${RETRIEVAL_LOSS_COEFF}_lr${LEARNING_RATE}_tc_${TASK_LOSS_COEFF}_${CONFIG_PATH} 238 | fi 239 | 240 | if [ -z "$BATCH_SIZE" ] # if BATCH_SIZE is not set manually 241 | then 242 | if [[ $NUM_INSTANCES -ge 40 ]] 243 | then 244 | BATCH_SIZE=16 245 | 246 | elif [[ $NUM_INSTANCES -ge 20 ]] 247 | then 248 | BATCH_SIZE=20 249 | elif [[ $NUM_INSTANCES -ge 2 ]] 250 | then 251 | BATCH_SIZE=24 252 | else 253 | BATCH_SIZE=32 254 | fi 255 | fi 256 | 257 | BATCH_SIZE=$(($BATCH_SIZE * NUM_INSTANCES)) 258 | 259 | CMD="python run_ner.py \ 260 | --tokenizer_name roberta-base \ 261 | --config_name ${CONFIG_NAME} \ 262 | --max_seq_length $MAX_SEQ_LENGTH \ 263 | --per_device_train_batch_size $BATCH_SIZE \ 264 | --per_device_eval_batch_size $BATCH_SIZE \ 265 | --learning_rate $LEARNING_RATE \ 266 | --output_dir $OUTPUT_DIR \ 267 | --run_name $RUN_NAME \ 268 | --logging_steps 100 \ 269 | --report_to wandb \ 270 | --dataloader_drop_last $DATALOADER_DROP_LAST \ 271 | --retrieval_percentage $RETRIEVAL_PERCENTAGE \ 272 | --retrieval_loss_coeff $RETRIEVAL_LOSS_COEFF \ 273 | --task_loss_coeff $TASK_LOSS_COEFF \ 274 | --retrieval_pretraining ${RETRIEVAL_PRETRAINING} \ 275 | --num_instances ${NUM_INSTANCES} \ 276 | --muxing_variant ${MUXING} \ 277 | --demuxing_variant ${DEMUXING} \ 278 | --should_mux ${SHOULD_MUX} \ 279 | --gaussian_hadamard_norm ${RANDOM_ENCODING_NORM} \ 280 | --learn_muxing ${LEARN_MUXING} \ 281 | --load_best_model_at_end 1 \ 282 | --metric_for_best_model eval_f1 \ 283 | --save_total_limit 1" 284 | if [ "$DO_TRAIN" -eq 1 ]; then 285 | CMD="${CMD} --do_train" 286 | fi 287 | if [ "$DO_EVAL" -eq 1 ]; then 288 | CMD="${CMD} --do_eval" 289 | fi 290 | 291 | if [ ! -z "$MODEL_PATH" ] # if MODEL PATH is set manually 292 | then 293 | CMD="${CMD} --model_name_or_path ${MODEL_PATH}" 294 | fi 295 | 296 | CMD=${CMD}" "${CMD_DIFF} 297 | 298 | if [[ $NUM_INSTANCES -ge 40 ]] 299 | then 300 | TIME="120:00:00" 301 | elif [[ $NUM_INSTANCES -ge 20 ]] 302 | then 303 | TIME="72:00:00" 304 | else 305 | TIME="30:00:00" 306 | fi 307 | 308 | if [[ $USE_SLURM = 1 ]]; then 309 | sbatch --time=$TIME --mem=32G --output=logs/%x-%j.out --job-name=${TASK_NAME}_${NUM_INSTANCES}_${MUXING}_${DEMUXING} --gres=gpu:1 ./run_job.sh \ 310 | "$CMD" 311 | else 312 | ./run_job.sh "$CMD" 313 | fi 314 | --------------------------------------------------------------------------------
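For reference, a run_ner.sh invocation for the finetuning setting might look like the following sketch; the checkpoint path is a placeholder for a multiplexed checkpoint of your own, and the remaining values mirror the script's defaults:

./run_ner.sh -s finetuning --task ner -N 5 -d index -m gaussian_hadamard \
    --config_name configs/ablations/base_model/roberta.json --lr 5e-5 \
    --model_path checkpoints/retrieval_pretraining/<your_pretrained_checkpoint> \
    --do_train --do_eval

For the non-multiplexed baseline, pass -s baseline instead; the script then forces NUM_INSTANCES to 1 and disables muxing.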