├── .gitignore ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING-ARCHIVED.md ├── Dockerfile ├── LICENSE.txt ├── README.md ├── SECURITY.md ├── experiments ├── __init__.py ├── constants.py ├── data_utils.py ├── hans.py ├── hans_utils.py ├── influence_helpers.py ├── misc_utils.py ├── mnli.py ├── mnli_utils.py ├── s_test_speedup.py ├── visualization.py └── visualization_utils.py ├── figs └── main.png ├── influence_utils ├── __init__.py ├── faiss_utils.py ├── glue_utils.py ├── multiprocessing_utils.py ├── nn_influence_utils.py └── parallel.py ├── requirements.txt ├── run_experiments.py ├── run_glue.py └── scripts ├── run_Amazon.sh ├── run_HANS.sh ├── run_MNLI.20200913.sh ├── run_MNLI.sh └── run_MNLI_2label.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # Sublime 141 | *.sublime-project 142 | *.sublime-workspace 143 | *.pem -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 96 | It includes adaptions and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ 106 | -------------------------------------------------------------------------------- /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained; 4 | We are not accepting contributions or Pull Requests. 5 | 6 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax = docker/dockerfile:1.0-experimental 2 | FROM pytorch/pytorch:1.5.1-cuda10.1-cudnn7-runtime 3 | 4 | # working directory 5 | WORKDIR /workspace 6 | 7 | # --------------------------------------------- 8 | # Project-agnostic System Dependencies 9 | # --------------------------------------------- 10 | RUN \ 11 | # Install System Dependencies 12 | apt-get update && apt-get install -y --no-install-recommends \ 13 | build-essential \ 14 | cmake \ 15 | wget \ 16 | unzip \ 17 | psmisc \ 18 | vim \ 19 | git \ 20 | ssh \ 21 | curl \ 22 | lshw \ 23 | ubuntu-drivers-common \ 24 | ca-certificates \ 25 | libjpeg-dev \ 26 | libpng-dev && \ 27 | rm -rf /var/lib/apt/lists/* && \ 28 | # Install NVIDIA Driver 29 | # https://www.linuxbabe.com/ubuntu/install-nvidia-driver-ubuntu-18-04 30 | # ubuntu-drivers autoinstall && \ 31 | # https://serverfault.com/questions/227190/how-do-i-ask-apt-get-to-skip-any-interactive-post-install-configuration-steps 32 | # https://stackoverflow.com/questions/38165407/installing-lightdm-in-dockerfile-raises-interactive-keyboard-layout-menu 33 | # apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 34 | # nvidia-driver-440 && \ 35 | # rm -rf /var/lib/apt/lists/* && \ 36 | # Install NodeJS 37 | # https://github.com/nodesource/distributions/blob/master/README.md#deb 38 | curl -sL https://deb.nodesource.com/setup_12.x | bash - && \ 39 | apt-get install -y nodejs 40 | 41 | # --------------------------------------------- 42 | # Project-specific System Dependencies 43 | # --------------------------------------------- 44 | RUN \ 45 | # Install `graph_tool` 46 | # https://git.skewed.de/count0/graph-tool/-/wikis/installation-instructions#debian-ubuntu 47 | # and `cairo` https://cairographics.org/download/ 48 | echo "deb [ arch=amd64 ] https://downloads.skewed.de/apt bionic main" >> /etc/apt/sources.list && \ 49 | apt-key adv --keyserver keys.openpgp.org --recv-key 612DEFB798507F25 && \ 50 | apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 51 | python3-graph-tool \ 52 | libcairo2-dev && \ 53 | rm -rf /var/lib/apt/lists/* && \ 54 | # Link the directory to `graph_tool`, which is installed in a differet python path 55 | ln -s /usr/lib/python3/dist-packages/graph_tool/ /opt/conda/lib/python3.7/site-packages/graph_tool 56 | # Clone the Apex Module (this requires torch) 57 | # git clone https://github.com/NVIDIA/apex /workspace/apex && \ 58 | # cd /workspace/apex && \ 59 | # pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ && \ 60 | 61 | # --------------------------------------------- 62 | # Build Python depencies and utilize caching 63 | # --------------------------------------------- 64 | COPY ./fast-influence-functions/requirements.txt /workspace/fast-influence-functions/requirements.txt 65 | RUN pip install --no-cache-dir --upgrade pip && \ 66 | pip install --no-cache-dir -r /workspace/fast-influence-functions/requirements.txt && \ 67 | # for binding/linking path from host machines 68 | mkdir -p /nlp && \ 69 | mkdir -p /export/ && \ 70 | chmod -R 777 /export 71 | 72 | # upload everything 73 | COPY ./fast-influence-functions/ /workspace/fast-influence-functions/ 74 | 75 | # Set HOME 76 | ENV HOME="/workspace/fast-influence-functions" 77 | 78 | # --------------------------------------------- 79 | # Project-agnostic User-dependent Dependencies 80 | # --------------------------------------------- 81 | RUN \ 82 | # Install FZF the fuzzy finder 83 | git clone --depth 1 https://github.com/junegunn/fzf.git ~/.fzf && \ 84 | ~/.fzf/install --all && \ 85 | # Install Awesome vimrc 86 | git clone --depth=1 https://github.com/amix/vimrc.git ~/.vim_runtime && \ 87 | sh ~/.vim_runtime/install_awesome_vimrc.sh 88 | 89 | # Reset Entrypoint from Parent Images 90 | # https://stackoverflow.com/questions/40122152/how-to-remove-entrypoint-from-parent-image-on-dockerfile/40122750 91 | ENTRYPOINT [] 92 | 93 | # load bash 94 | CMD /bin/bash -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, Salesforce.com, Inc. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | 10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastIF: Scalable Influence Functions for Efficient Model Interpretation and Debugging 2 | 3 | [Link to the Paper](https://arxiv.org/abs/2012.15781) 4 | 5 | ![main](figs/main.png) 6 | 7 | # Requirements 8 | Please see `requirements.txt` and `Dockerfile` for detailed dependencies. The major ones include 9 | - `python 3.6 or later` (for type annotations and f-string) 10 | - `pytorch==1.5.1` 11 | - `transformers==3.0.2` 12 | 13 | # Setup 14 | ### Docker Setup 15 | To build the docker image, run the following script. 16 | 17 | ```bash 18 | DOCKER_BUILDKIT=1 docker build \ 19 | -t ${TAG} \ 20 | -f Dockerfile . 21 | ``` 22 | 23 | ### Data Setup 24 | 1. Download the data following the examples from [here](https://github.com/huggingface/transformers/tree/master/examples/text-classification) and [here](https://github.com/huggingface/transformers/tree/master/examples/adversarial). 25 | 2. Mount the data into `/export/home/Data/Glue` and `/export/home/Data/HANS` inside the image. 26 | 27 | # Experiments 28 | 1. To train the base models, please use `scripts/run_MNLI.sh` and `scripts/run_HANS.sh`. 29 | 2. To build FAISS indices, please see the function `create_FAISS_index` in `experiments/hans.py`. 30 | 3. Modify the paths in `experiments/constants.py` based on your setup. 31 | 4. To run the experiments, please follow the instructions in `run_experiments.py` where we have provided most of the default configurations/hyper-parameters. 32 | 33 | # Code Structure 34 | 35 | ### `experiments/` 36 | - This directory contains code that are used to conduct experiments. 37 | - However, the entry-point for experiments is `run_experiments.py`. 38 | 39 | ### `influence_utils/` 40 | This directory contains the core components of the influence functions. Most of the codes are designed to be independent of the experiments so could be adapted for others downstream needs. Two of the most important ones are: 41 | - `influence_utils/nn_influence_utils.py` contains the code for influence functions. 42 | - `influence_utils/parallel.py` contains the code for the parallel variant. Note that when running the parallel variant, make sure to turn off `wandb` (see [here](https://docs.wandb.ai/integrations/huggingface) for details) as the current codebase does not work well with `wandb` turned on. -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as can be, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. 8 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -------------------------------------------------------------------------------- /experiments/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | # Note that the paths used in `scripts/run_*.sh` are still 7 | # hard-coded to `/export/home/` 8 | WEIGHT_DECAY = 0.005 9 | 10 | MNLI_MODEL_PATH = None 11 | HANS_MODEL_PATH = None 12 | MNLI2_MODEL_PATH = None 13 | Amazon_MODEL_PATH = None 14 | MNLI_IMITATOR_MODEL_PATH = None 15 | 16 | # Trained and used in MNLI 17 | MNLI_FAISS_INDEX_PATH = None 18 | # Trained and used in HANS 19 | HANS_FAISS_INDEX_PATH = None 20 | # Trained and used in MNLI-2 21 | MNLI2_FAISS_INDEX_PATH = None 22 | # Trained on MNLI2 and used in HANS 23 | MNLI2_HANS_FAISS_INDEX_PATH = None 24 | # Trained on HANS and used in MNLI2 25 | HANS_MNLI2_FAISS_INDEX_PATH = None 26 | # Trained and used in Amazon 27 | Amazon_FAISS_INDEX_PATH = None 28 | # Trained on MNLI and used in ANLI 29 | MNLI_ANLI_FAISS_INDEX_PATH = None 30 | 31 | MNLI_TRAIN_INPUT_COLLECTIONS_PATH = None 32 | 33 | HANS_DATA_DIR = None 34 | GLUE_DATA_DIR = None 35 | ANLI_DATA_DIR = None 36 | Amazon_DATA_DIR = None 37 | 38 | MNLI_TRAIN_FILE_NAME = None 39 | MNLI_EVAL_MATCHED_FILE_NAME = None 40 | MNLI_EVAL_MISMATCHED_FILE_NAME = None 41 | HANS_TRAIN_FILE_NAME = None 42 | HANS_EVAL_FILE_NAME = None 43 | HANS_VALID_INDICES_FILE_NAME = None 44 | AMAZON_METADATA_ARRAY_FILE_NAME = None 45 | 46 | # Experiments specific 47 | MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR = None 48 | MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR2 = None 49 | MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR3 = None 50 | 51 | # Some useful default hparams for influence functions 52 | DEFAULT_INFLUENCE_HPARAMS = { 53 | # `train_on_task_name` 54 | "mnli": { 55 | # `eval_task_name` 56 | "mnli": { 57 | "damp": 5e-3, 58 | "scale": 1e4, 59 | "num_samples": 1000 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /experiments/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import os 7 | import time 8 | import torch 9 | import logging 10 | import pandas as pd 11 | from tqdm import trange 12 | from typing import Optional, Union, List, Dict 13 | 14 | from transformers import ( 15 | GlueDataset, 16 | GlueDataTrainingArguments, 17 | PreTrainedTokenizer, 18 | glue_convert_examples_to_features, 19 | InputExample, 20 | DataProcessor, 21 | # Used in label-flipping hacks 22 | RobertaTokenizer, 23 | RobertaTokenizerFast, 24 | XLMRobertaTokenizer, 25 | BartTokenizer, 26 | BartTokenizerFast) 27 | 28 | from transformers.data.datasets.glue import ( 29 | Split, 30 | FileLock) 31 | from transformers.data.processors.glue import ( 32 | MnliProcessor, 33 | MnliMismatchedProcessor) 34 | from transformers.data.metrics import simple_accuracy 35 | 36 | try: 37 | from wilds.datasets.amazon_dataset import AmazonDataset 38 | except ModuleNotFoundError: 39 | AmazonDataset = None 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | 44 | class CustomGlueDataset(GlueDataset): 45 | """Customized GlueData with changes: 46 | 47 | 1. Changed the `glue_processors` and `glue_output_modes` to customized ones. 48 | """ 49 | 50 | def __init__( 51 | self, 52 | args: GlueDataTrainingArguments, 53 | tokenizer: PreTrainedTokenizer, 54 | limit_length: Optional[int] = None, 55 | mode: Union[str, Split] = Split.train, 56 | cache_dir: Optional[str] = None, 57 | ): 58 | self.args = args 59 | self.processor = glue_processors[args.task_name]() 60 | self.output_mode = glue_output_modes[args.task_name] 61 | if isinstance(mode, str): 62 | try: 63 | mode = Split[mode] 64 | except KeyError: 65 | raise KeyError("mode is not a valid split name") 66 | # Load data features from cache or dataset file 67 | cached_features_file = os.path.join( 68 | cache_dir if cache_dir is not None else args.data_dir, 69 | "cached_{}_{}_{}_{}".format( 70 | mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name, 71 | ), 72 | ) 73 | label_list = self.processor.get_labels() 74 | if args.task_name in ["mnli", "mnli-mm", "mnli-2", "mnli-2-mm", "hans"] and tokenizer.__class__ in ( 75 | RobertaTokenizer, 76 | RobertaTokenizerFast, 77 | XLMRobertaTokenizer, 78 | BartTokenizer, 79 | BartTokenizerFast, 80 | ): 81 | # HACK(label indices are swapped in RoBERTa pretrained model) 82 | label_list[1], label_list[2] = label_list[2], label_list[1] 83 | self.label_list = label_list 84 | 85 | # Make sure only the first process in distributed training processes the dataset, 86 | # and the others will use the cache. 87 | lock_path = cached_features_file + ".lock" 88 | with FileLock(lock_path): 89 | 90 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 91 | start = time.time() 92 | self.features = torch.load(cached_features_file) 93 | logger.info( 94 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start 95 | ) 96 | else: 97 | logger.info(f"Creating features from dataset file at {args.data_dir}") 98 | 99 | if mode == Split.dev: 100 | examples = self.processor.get_dev_examples(args.data_dir) 101 | elif mode == Split.test: 102 | examples = self.processor.get_test_examples(args.data_dir) 103 | else: 104 | examples = self.processor.get_train_examples(args.data_dir) 105 | if limit_length is not None: 106 | examples = examples[:limit_length] 107 | self.features = glue_convert_examples_to_features( 108 | examples, 109 | tokenizer, 110 | max_length=args.max_seq_length, 111 | label_list=label_list, 112 | output_mode=self.output_mode, 113 | ) 114 | start = time.time() 115 | torch.save(self.features, cached_features_file) 116 | # ^ This seems to take a lot of time so I want to investigate why and how we can improve. 117 | logger.info( 118 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start 119 | ) 120 | 121 | 122 | class TwoLabelMnliProcessor(MnliProcessor): 123 | 124 | def get_labels(self) -> List[str]: 125 | """See base class.""" 126 | return ["non_entailment", "entailment"] 127 | 128 | def _create_examples(self, lines: List[List[str]], set_type: str) -> List[InputExample]: 129 | """Creates examples for the training, dev and test sets.""" 130 | examples = [] 131 | for (i, line) in enumerate(lines): 132 | if i == 0: 133 | continue 134 | guid = "%s-%s" % (set_type, line[0]) 135 | text_a = line[8] 136 | text_b = line[9] 137 | label = None if set_type.startswith("test") else self._preprocess_label(line[-1]) 138 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 139 | return examples 140 | 141 | def _preprocess_label(self, label: str) -> str: 142 | if label not in ["contradiction", "entailment", "neutral"]: 143 | raise ValueError(f"Label {label} not recognized.") 144 | 145 | if label in ["contradiction", "neutral"]: 146 | return "non_entailment" 147 | else: 148 | return "entailment" 149 | 150 | 151 | class TwoLabelMnliMismatchedProcessor(TwoLabelMnliProcessor): 152 | """Processor for the MultiNLI Mismatched data set (GLUE version).""" 153 | 154 | def get_dev_examples(self, data_dir: str) -> List[InputExample]: 155 | """See base class.""" 156 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")), "dev_mismatched") 157 | 158 | def get_test_examples(self, data_dir: str) -> List[InputExample]: 159 | """See base class.""" 160 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test_mismatched") 161 | 162 | 163 | class HansProcessor(DataProcessor): 164 | """Processor for the HANS data set.""" 165 | 166 | def get_train_examples(self, data_dir: str) -> List[InputExample]: 167 | """See base class.""" 168 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_train_set.txt")), "train") 169 | 170 | def get_dev_examples(self, data_dir: str) -> List[InputExample]: 171 | """See base class.""" 172 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_evaluation_set.txt")), "dev") 173 | 174 | def get_labels(self) -> List[str]: 175 | """See base class.""" 176 | return ["non_entailment", "entailment"] 177 | 178 | def _create_examples(self, lines: List[List[str]], set_type: str) -> List[InputExample]: 179 | """Creates examples for the training and dev sets.""" 180 | examples = [] 181 | for (i, line) in enumerate(lines): 182 | if i == 0: 183 | continue 184 | guid = "%s-%s" % (set_type, i) 185 | text_a = line[5] 186 | text_b = line[6] 187 | label = self._preprocess_label(line[0]) 188 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 189 | return examples 190 | 191 | def _preprocess_label(self, label: str) -> str: 192 | if label not in ["non-entailment", "entailment"]: 193 | raise ValueError(f"Label {label} not recognized.") 194 | 195 | if label in ["non-entailment"]: 196 | return "non_entailment" 197 | else: 198 | return "entailment" 199 | 200 | 201 | class WILDSAmazonProcessor(DataProcessor): 202 | """Processor for the Amazon data set (WILDS version).""" 203 | 204 | def get_train_examples(self, data_dir): 205 | """See base class.""" 206 | # Using `quotechar` since some rows have strings that span multiple lines 207 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "amazon.train.tsv"), quotechar='"'), "train") 208 | 209 | def get_dev_examples(self, data_dir): 210 | """See base class.""" 211 | # Using `quotechar` since some rows have strings that span multiple lines 212 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "amazon.val.tsv"), quotechar='"'), "dev") 213 | 214 | def get_test_examples(self, data_dir): 215 | """See base class.""" 216 | # Using `quotechar` since some rows have strings that span multiple lines 217 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "amazon.test.tsv"), quotechar='"'), "test") 218 | 219 | def get_labels(self): 220 | """See base class.""" 221 | return ["0", "1", "2", "3", "4"] 222 | 223 | def _create_examples(self, lines, set_type): 224 | """Creates examples for the training, dev and test sets.""" 225 | examples = [] 226 | for (i, line) in enumerate(lines): 227 | if i == 0: 228 | continue 229 | guid = "%s-%s" % (set_type, i) 230 | text_a = line[0] 231 | label = line[1] 232 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 233 | return examples 234 | 235 | 236 | class ANLIProcessor(DataProcessor): 237 | """Processor for the HANS data set.""" 238 | 239 | def get_train_examples(self, data_dir: str) -> List[InputExample]: 240 | """See base class.""" 241 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv"), quotechar='"'), "train") 242 | 243 | def get_dev_examples(self, data_dir: str) -> List[InputExample]: 244 | """See base class.""" 245 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "valid.tsv"), quotechar='"'), "dev") 246 | 247 | def get_test_examples(self, data_dir: str) -> List[InputExample]: 248 | """See base class.""" 249 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv"), quotechar='"'), "test") 250 | 251 | def get_labels(self) -> List[str]: 252 | """See base class.""" 253 | return ["contradiction", "entailment", "neutral"] 254 | 255 | def _create_examples(self, lines: List[List[str]], set_type: str) -> List[InputExample]: 256 | """Creates examples for the training and dev sets.""" 257 | examples = [] 258 | for (i, line) in enumerate(lines): 259 | if i == 0: 260 | continue 261 | guid = "%s-%s" % (set_type, i) 262 | text_a = line[1] 263 | text_b = line[2] 264 | label = self._preprocess_label(line[3]) 265 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 266 | return examples 267 | 268 | def _preprocess_label(self, label: str) -> str: 269 | label_map = { 270 | "e": "entailment", 271 | "n": "neutral", 272 | "c": "contradiction" 273 | } 274 | if label not in label_map.keys(): 275 | raise ValueError(f"Label {label} not recognized.") 276 | 277 | return label_map[label] 278 | 279 | 280 | def glue_compute_metrics(task_name: str, preds: List, labels: List) -> Dict[str, float]: 281 | assert len(preds) == len(labels) 282 | if task_name not in glue_processors.keys(): 283 | raise ValueError(f"Unrecognized {task_name}") 284 | 285 | return {"acc": simple_accuracy(preds, labels)} 286 | 287 | 288 | def write_amazon_dataset_to_disk(base_dir: str) -> None: 289 | dataset = AmazonDataset(download=False) 290 | for split in dataset.split_dict.keys(): 291 | datasubset = dataset.get_subset(split) 292 | file_name = os.path.join( 293 | base_dir, 294 | f"amazon.{split}.tsv") 295 | 296 | examples = [] 297 | for index in trange(len(datasubset)): 298 | examples.append({ 299 | "sentence": datasubset[index][0], 300 | "label": datasubset[index][1].item()}) 301 | 302 | pd.DataFrame(examples).to_csv( 303 | file_name, 304 | sep="\t", 305 | index=False) 306 | 307 | print(f"Wrote {file_name} to disk") 308 | 309 | 310 | glue_tasks_num_labels = { 311 | "mnli": 3, 312 | "mnli-2": 2, 313 | "hans": 2, 314 | "amazon": 5, 315 | "anli": 3, 316 | } 317 | 318 | glue_processors = { 319 | "mnli": MnliProcessor, 320 | "mnli-mm": MnliMismatchedProcessor, 321 | "mnli-2": TwoLabelMnliProcessor, 322 | "mnli-2-mm": TwoLabelMnliMismatchedProcessor, 323 | "hans": HansProcessor, 324 | "amazon": WILDSAmazonProcessor, 325 | "anli": ANLIProcessor, 326 | } 327 | 328 | glue_output_modes = { 329 | "mnli": "classification", 330 | "mnli-mm": "classification", 331 | "mnli-2": "classification", 332 | "mnli-2-mm": "classification", 333 | "hans": "classification", 334 | "amazon": "classification", 335 | "anli": "classification", 336 | } 337 | -------------------------------------------------------------------------------- /experiments/hans.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import torch 7 | import numpy as np 8 | import transformers 9 | from tqdm import tqdm 10 | from copy import deepcopy 11 | from collections import defaultdict 12 | from transformers import InputFeatures 13 | from transformers import default_data_collator 14 | from typing import Union, Dict, Any, List, Tuple, Optional 15 | 16 | from influence_utils import faiss_utils 17 | from influence_utils import nn_influence_utils 18 | from experiments import constants 19 | from experiments import misc_utils 20 | # from experiments import remote_utils 21 | from experiments import influence_helpers 22 | from experiments.hans_utils import HansHelper, SimpleHelper, AmazonHelper 23 | from transformers import TrainingArguments 24 | from experiments.data_utils import CustomGlueDataset 25 | 26 | DEFAULT_KNN_K = 1000 27 | DEFAULT_NUM_REPLICAS = 3 28 | EVAL_HEURISTICS_SAMPLE_BATCH_SIZE = 10 29 | EXPERIMENT_TYPES = ["most-helpful", "most-harmful", "random"] 30 | DEFAULT_HANS_EVAL_HEURISTICS = ["lexical_overlap", "subsequence", "constituent"] 31 | DEFAULT_ANLI_EVAL_HEURISTICS = ["null"] 32 | DEFAULT_Amazon_EVAL_HEURISTICS = ["null"] 33 | VERSION_2_NUM_DATAPOINTS_CHOICES = [EVAL_HEURISTICS_SAMPLE_BATCH_SIZE] 34 | VERSION_2_LEARNING_RATE_CHOICES = [1e-4] 35 | 36 | 37 | def main( 38 | trained_on_task_name: str, 39 | train_task_name: str, 40 | train_heuristic: str, 41 | num_replicas: Optional[int] = None, 42 | use_parallel: bool = True, 43 | version: Optional[str] = None, 44 | ) -> Dict[str, List[Dict[str, Any]]]: 45 | 46 | if trained_on_task_name not in ["mnli", "mnli-2", "amazon"]: 47 | raise ValueError 48 | 49 | if train_task_name not in ["mnli-2", "hans", "amazon", "anli"]: 50 | raise ValueError 51 | 52 | if num_replicas is None: 53 | num_replicas = DEFAULT_NUM_REPLICAS 54 | 55 | if version not in ["new-only-z", "new-only-ztest", "new-z-and-ztest"]: 56 | raise ValueError 57 | 58 | if trained_on_task_name in ["mnli-2"]: 59 | eval_heuristics = DEFAULT_HANS_EVAL_HEURISTICS 60 | task_tokenizer, task_model = misc_utils.create_tokenizer_and_model( 61 | constants.MNLI2_MODEL_PATH) 62 | 63 | (mnli_train_dataset, 64 | mnli_eval_dataset) = misc_utils.create_datasets( 65 | task_name="mnli-2", 66 | tokenizer=task_tokenizer) 67 | 68 | (hans_train_dataset, 69 | hans_eval_dataset) = misc_utils.create_datasets( 70 | task_name="hans", 71 | tokenizer=task_tokenizer) 72 | 73 | if train_task_name == "mnli-2": 74 | train_dataset = mnli_train_dataset 75 | 76 | if train_task_name == "hans": 77 | train_dataset = hans_train_dataset 78 | 79 | (s_test_damp, 80 | s_test_scale, 81 | s_test_num_samples) = influence_helpers.select_s_test_config( 82 | trained_on_task_name=trained_on_task_name, 83 | train_task_name=train_task_name, 84 | eval_task_name="hans", 85 | ) 86 | 87 | hans_helper = HansHelper( 88 | hans_train_dataset=hans_train_dataset, 89 | hans_eval_dataset=hans_eval_dataset) 90 | 91 | if trained_on_task_name in ["amazon"]: 92 | # This is not used, so used `null` as a placeholder 93 | eval_heuristics = DEFAULT_Amazon_EVAL_HEURISTICS 94 | task_tokenizer, task_model = misc_utils.create_tokenizer_and_model( 95 | constants.Amazon_MODEL_PATH) 96 | 97 | (amazon_train_dataset, 98 | amazon_eval_dataset, 99 | amazon_test_dataset) = misc_utils.create_datasets( 100 | task_name="amazon", 101 | tokenizer=task_tokenizer, 102 | create_test_dataset=True) 103 | 104 | # Fine-tune on training dataset 105 | train_dataset = amazon_train_dataset 106 | 107 | (s_test_damp, 108 | s_test_scale, 109 | s_test_num_samples) = influence_helpers.select_s_test_config( 110 | trained_on_task_name=trained_on_task_name, 111 | train_task_name=train_task_name, 112 | eval_task_name="amazon", 113 | ) 114 | 115 | hans_helper = AmazonHelper( 116 | train_dataset=amazon_train_dataset, 117 | eval_dataset=amazon_eval_dataset, 118 | test_dataset=amazon_test_dataset) 119 | 120 | if trained_on_task_name in ["mnli"]: 121 | # This is not used, so used `null` as a placeholder 122 | eval_heuristics = DEFAULT_ANLI_EVAL_HEURISTICS 123 | task_tokenizer, task_model = misc_utils.create_tokenizer_and_model( 124 | constants.MNLI_MODEL_PATH) 125 | 126 | (anli_train_dataset, 127 | anli_eval_dataset, 128 | anli_test_dataset) = misc_utils.create_datasets( 129 | task_name="anli", 130 | tokenizer=task_tokenizer, 131 | create_test_dataset=True) 132 | 133 | # Fine-tune on training dataset 134 | train_dataset = anli_train_dataset 135 | 136 | (s_test_damp, 137 | s_test_scale, 138 | s_test_num_samples) = influence_helpers.select_s_test_config( 139 | trained_on_task_name=trained_on_task_name, 140 | train_task_name=train_task_name, 141 | eval_task_name="anli", 142 | ) 143 | 144 | hans_helper = SimpleHelper( 145 | train_dataset=anli_train_dataset, 146 | eval_dataset=anli_eval_dataset, 147 | test_dataset=anli_test_dataset) 148 | 149 | # We will be running model trained on MNLI-2 150 | # but calculate influences on HANS dataset 151 | faiss_index = influence_helpers.load_faiss_index( 152 | trained_on_task_name=trained_on_task_name, 153 | train_task_name=train_task_name 154 | ) 155 | 156 | # Most of these arguments are placeholders 157 | # and are not really used at all, so ignore 158 | # the exact values of these. 159 | trainer = transformers.Trainer( 160 | model=task_model, 161 | args=TrainingArguments( 162 | output_dir="./tmp-output", 163 | per_device_train_batch_size=128, 164 | per_device_eval_batch_size=128, 165 | learning_rate=5e-5, 166 | logging_steps=100), 167 | ) 168 | 169 | output_collections: Dict[str, List] = defaultdict(list) 170 | 171 | if version == "old": 172 | raise ValueError("Deprecated") 173 | 174 | else: 175 | NUM_STEPS = 10 176 | num_total_experiments = ( 177 | len(EXPERIMENT_TYPES) * 178 | num_replicas * 179 | len(VERSION_2_NUM_DATAPOINTS_CHOICES) * 180 | len(VERSION_2_LEARNING_RATE_CHOICES) * 181 | NUM_STEPS 182 | ) 183 | 184 | with tqdm(total=num_total_experiments) as pbar: 185 | for experiment_type in EXPERIMENT_TYPES: 186 | for replica_index in range(num_replicas): 187 | for version_2_num_datapoints in VERSION_2_NUM_DATAPOINTS_CHOICES: 188 | for version_2_learning_rate in VERSION_2_LEARNING_RATE_CHOICES: 189 | 190 | # The model will be used for multiple 191 | # steps so `deepcopy` it here. 192 | _model = deepcopy(task_model) 193 | for step in range(NUM_STEPS): 194 | 195 | # Sample anchor data-points every step 196 | (hans_eval_heuristic_inputs, 197 | hans_eval_heuristic_raw_inputs) = hans_helper.sample_batch_of_heuristic( 198 | mode="eval", 199 | heuristic=train_heuristic, 200 | size=EVAL_HEURISTICS_SAMPLE_BATCH_SIZE, 201 | return_raw_data=True) 202 | 203 | misc_utils.move_inputs_to_device( 204 | inputs=hans_eval_heuristic_inputs, 205 | device=task_model.device) 206 | 207 | outputs_one_experiment, _model = one_experiment( 208 | use_parallel=use_parallel, 209 | eval_heuristics=eval_heuristics, 210 | experiment_type=experiment_type, 211 | hans_helper=hans_helper, 212 | train_dataset=train_dataset, 213 | task_model=_model, 214 | faiss_index=faiss_index, 215 | s_test_damp=s_test_damp, 216 | s_test_scale=s_test_scale, 217 | s_test_num_samples=s_test_num_samples, 218 | trainer=trainer, 219 | version=version, 220 | version_2_num_datapoints=version_2_num_datapoints, 221 | version_2_learning_rate=version_2_learning_rate, 222 | hans_eval_heuristic_inputs=hans_eval_heuristic_inputs, 223 | hans_eval_heuristic_raw_inputs=hans_eval_heuristic_raw_inputs, 224 | ) 225 | 226 | output_collections[ 227 | f"{experiment_type}-" 228 | f"{replica_index}-" 229 | f"{version_2_num_datapoints}-" 230 | f"{version_2_learning_rate}-" 231 | ].append(outputs_one_experiment) 232 | 233 | pbar.update(1) 234 | pbar.set_description(f"{experiment_type} #{replica_index}") 235 | 236 | torch.save( 237 | output_collections, 238 | f"hans-augmentation-{version}." 239 | f"{trained_on_task_name}." 240 | f"{train_task_name}." 241 | f"{train_heuristic}." 242 | f"{num_replicas}." 243 | f"{use_parallel}." 244 | f"{DEFAULT_KNN_K}.pth") 245 | 246 | return output_collections 247 | 248 | 249 | def one_experiment( 250 | use_parallel: bool, 251 | eval_heuristics: List[str], 252 | experiment_type: str, 253 | hans_helper: HansHelper, 254 | train_dataset: CustomGlueDataset, 255 | task_model: torch.nn.Module, 256 | faiss_index: faiss_utils.FAISSIndex, 257 | s_test_damp: float, 258 | s_test_scale: float, 259 | s_test_num_samples: int, 260 | trainer: transformers.Trainer, 261 | version: str, 262 | version_2_num_datapoints: Optional[int], 263 | version_2_learning_rate: Optional[float], 264 | hans_eval_heuristic_inputs: Dict[str, Any], 265 | hans_eval_heuristic_raw_inputs: List[InputFeatures], 266 | ) -> Tuple[Dict[str, Any], Optional[torch.nn.Module]]: 267 | if task_model.device.type != "cuda": 268 | raise ValueError("The model is supposed to be on CUDA") 269 | 270 | if version_2_num_datapoints is None: 271 | raise ValueError 272 | if version_2_learning_rate is None: 273 | raise ValueError 274 | 275 | if experiment_type in ["most-harmful", "most-helpful"]: 276 | 277 | influences = influence_helpers.compute_influences_simplified( 278 | k=DEFAULT_KNN_K, 279 | faiss_index=faiss_index, 280 | model=task_model, 281 | inputs=hans_eval_heuristic_inputs, 282 | train_dataset=train_dataset, 283 | use_parallel=use_parallel, 284 | s_test_damp=s_test_damp, 285 | s_test_scale=s_test_scale, 286 | s_test_num_samples=s_test_num_samples, 287 | device_ids=[1, 2, 3], 288 | precomputed_s_test=None, 289 | faiss_index_use_mean_features_as_query=True, 290 | ) 291 | helpful_indices, harmful_indices = misc_utils.get_helpful_harmful_indices_from_influences_dict( 292 | influences, n=version_2_num_datapoints) 293 | if experiment_type == "most-helpful": 294 | datapoint_indices = helpful_indices 295 | 296 | if experiment_type == "most-harmful": 297 | datapoint_indices = harmful_indices 298 | 299 | if experiment_type == "random": 300 | # s_test = None 301 | influences = None 302 | hans_eval_heuristic_inputs = None 303 | # Essentially shuffle the indices 304 | datapoint_indices = np.random.choice( 305 | len(train_dataset), 306 | size=len(train_dataset), 307 | replace=False) 308 | 309 | loss_collections = {} 310 | accuracy_collections = {} 311 | 312 | # num_datapoints = 1 313 | # learning_rate = 1e-4 314 | num_datapoints = version_2_num_datapoints 315 | learning_rate = version_2_learning_rate 316 | 317 | if version == "new-only-z": 318 | datapoints = [ 319 | train_dataset[index] 320 | for index in datapoint_indices[:num_datapoints]] 321 | 322 | if version == "new-only-ztest": 323 | datapoints = hans_eval_heuristic_raw_inputs 324 | 325 | if version == "new-z-and-ztest": 326 | datapoints = [ 327 | train_dataset[index] 328 | for index in datapoint_indices[:num_datapoints] 329 | ] + hans_eval_heuristic_raw_inputs 330 | 331 | batch = default_data_collator(datapoints) 332 | new_model, _ = pseudo_gradient_step( 333 | model=task_model, 334 | inputs=batch, 335 | learning_rate=learning_rate) 336 | 337 | for heuristic in eval_heuristics: 338 | new_model_loss, new_model_accuracy = evaluate_heuristic( 339 | hans_helper=hans_helper, 340 | heuristic=heuristic, 341 | trainer=trainer, 342 | model=new_model) 343 | 344 | loss_collections[heuristic] = new_model_loss 345 | accuracy_collections[heuristic] = new_model_accuracy 346 | # print(f"Finished {num_datapoints}-{learning_rate}") 347 | 348 | output_collections = { 349 | # "s_test": s_test, 350 | "influences": influences, 351 | "loss": loss_collections, 352 | "accuracy": accuracy_collections, 353 | "datapoint_indices": datapoint_indices, 354 | "learning_rate": learning_rate, 355 | "num_datapoints": num_datapoints, 356 | "hans_eval_heuristic_inputs": hans_eval_heuristic_inputs, 357 | } 358 | 359 | # Warning: Check again whether using this `new_model` is a good idea 360 | return output_collections, new_model 361 | 362 | 363 | def pseudo_gradient_step( 364 | model: torch.nn.Module, 365 | inputs: Dict[str, Union[torch.Tensor, Any]], 366 | learning_rate: float, 367 | precomputed_gradients_z: Optional[List[torch.FloatTensor]] = None 368 | ) -> Tuple[torch.nn.Module, List[torch.FloatTensor]]: 369 | 370 | params_filter = [ 371 | n for n, p in model.named_parameters() 372 | if not p.requires_grad] 373 | 374 | weight_decay_ignores = [ 375 | "bias", 376 | "LayerNorm.weight"] + [ 377 | n for n, p in model.named_parameters() 378 | if not p.requires_grad] 379 | 380 | params_to_freeze = [ 381 | "bert.embeddings.", 382 | "bert.encoder.layer.0.", 383 | "bert.encoder.layer.1.", 384 | "bert.encoder.layer.2.", 385 | "bert.encoder.layer.3.", 386 | "bert.encoder.layer.4.", 387 | "bert.encoder.layer.5.", 388 | "bert.encoder.layer.6.", 389 | "bert.encoder.layer.7.", 390 | "bert.encoder.layer.8.", 391 | "bert.encoder.layer.9.", 392 | ] 393 | 394 | if precomputed_gradients_z is not None: 395 | gradients_z = precomputed_gradients_z 396 | else: 397 | gradients_z = nn_influence_utils.compute_gradients( 398 | n_gpu=1, 399 | device=torch.device("cuda"), 400 | model=model, 401 | inputs=inputs, 402 | params_filter=params_filter, 403 | weight_decay=constants.WEIGHT_DECAY, 404 | weight_decay_ignores=weight_decay_ignores) 405 | 406 | new_model = deepcopy(model) 407 | params_to_update = [ 408 | p for name, p in new_model.named_parameters() 409 | if not any(pfreeze in name for pfreeze in params_to_freeze)] 410 | 411 | # They should refer to the same parameters 412 | if len(params_to_update) != len(gradients_z): 413 | raise ValueError 414 | 415 | with torch.no_grad(): 416 | [p.sub_(learning_rate * grad_z) for p, grad_z in 417 | zip(params_to_update, gradients_z)] 418 | 419 | return new_model, gradients_z 420 | 421 | 422 | def evaluate_heuristic( 423 | hans_helper: HansHelper, 424 | heuristic: str, 425 | trainer: transformers.Trainer, 426 | model: torch.nn.Module, 427 | ) -> Tuple[float, float]: 428 | 429 | _, batch_dataloader = hans_helper.get_dataset_and_dataloader_of_heuristic( 430 | mode="test", 431 | heuristic=heuristic, 432 | batch_size=1000, 433 | random=False) 434 | 435 | loss = 0. 436 | num_corrects = 0. 437 | num_examples = 0 438 | for index, inputs in enumerate(batch_dataloader): 439 | batch_size = inputs["labels"].shape[0] 440 | batch_preds, batch_label_ids, batch_mean_loss = misc_utils.predict( 441 | trainer=trainer, 442 | model=model, 443 | inputs=inputs) 444 | 445 | num_examples += batch_size 446 | loss += batch_mean_loss * batch_size 447 | num_corrects += (batch_preds.argmax(axis=-1) == batch_label_ids).sum() 448 | 449 | return loss / num_examples, num_corrects / num_examples 450 | 451 | 452 | def create_FAISS_index( 453 | train_task_name: str, 454 | trained_on_task_name: str, 455 | ) -> faiss_utils.FAISSIndex: 456 | if train_task_name not in ["mnli-2", "hans", "amazon", "anli"]: 457 | raise ValueError 458 | 459 | if trained_on_task_name not in ["mnli", "mnli-2", "hans", "amazon"]: 460 | raise ValueError 461 | 462 | if trained_on_task_name == "mnli": 463 | tokenizer, model = misc_utils.create_tokenizer_and_model( 464 | constants.MNLI_MODEL_PATH) 465 | 466 | if trained_on_task_name == "mnli-2": 467 | tokenizer, model = misc_utils.create_tokenizer_and_model( 468 | constants.MNLI2_MODEL_PATH) 469 | 470 | if trained_on_task_name == "hans": 471 | tokenizer, model = misc_utils.create_tokenizer_and_model( 472 | constants.HANS_MODEL_PATH) 473 | 474 | if trained_on_task_name == "amazon": 475 | tokenizer, model = misc_utils.create_tokenizer_and_model( 476 | constants.Amazon_MODEL_PATH) 477 | 478 | train_dataset, _ = misc_utils.create_datasets( 479 | task_name=train_task_name, 480 | tokenizer=tokenizer) 481 | 482 | faiss_index = faiss_utils.FAISSIndex(768, "Flat") 483 | 484 | model.cuda() 485 | device = model.device 486 | train_batch_data_loader = misc_utils.get_dataloader( 487 | dataset=train_dataset, 488 | batch_size=128, 489 | random=False) 490 | 491 | for inputs in tqdm(train_batch_data_loader): 492 | for k, v in inputs.items(): 493 | inputs[k] = v.to(device) 494 | features = misc_utils.compute_BERT_CLS_feature(model, **inputs) 495 | features = features.cpu().detach().numpy() 496 | faiss_index.add(features) 497 | 498 | return faiss_index 499 | -------------------------------------------------------------------------------- /experiments/hans_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import torch 7 | import numpy as np 8 | import pandas as pd 9 | from collections import defaultdict 10 | 11 | from experiments import constants 12 | from experiments import data_utils 13 | from experiments import misc_utils 14 | from transformers import default_data_collator 15 | from typing import List, Union, Iterable, Dict, Any, Tuple, Optional 16 | 17 | try: 18 | from wilds.datasets.amazon_dataset import AmazonDataset 19 | except ModuleNotFoundError: 20 | AmazonDataset = None 21 | 22 | 23 | class SubsetDataset(torch.utils.data.Dataset): 24 | def __init__(self, 25 | dataset: data_utils.CustomGlueDataset, 26 | indices: Union[np.ndarray, List[int]]) -> None: 27 | 28 | super(SubsetDataset, self).__init__() 29 | self.wrapped_dataset = dataset 30 | self.indices = indices 31 | 32 | def __getitem__(self, index) -> Dict[str, Union[torch.Tensor, Any]]: 33 | mapped_index = self.indices[index] 34 | return self.wrapped_dataset[mapped_index] 35 | 36 | def __len__(self) -> int: 37 | return len(self.indices) 38 | 39 | 40 | class HansHelper(object): 41 | def __init__( 42 | self, 43 | hans_train_dataset: Optional[data_utils.CustomGlueDataset] = None, 44 | hans_eval_dataset: Optional[data_utils.CustomGlueDataset] = None) -> None: 45 | 46 | # This file includes both validation and test 47 | combined_hans_eval_df = pd.read_csv(constants.HANS_EVAL_FILE_NAME, sep="\t") 48 | # This is a list of indices that should be mapped to validation dataset 49 | valid_indices = torch.load(constants.HANS_VALID_INDICES_FILE_NAME) 50 | # https://stackoverflow.com/questions/28256761/select-pandas-rows-by-excluding-index-number 51 | valid_selector = combined_hans_eval_df.index.isin(valid_indices) 52 | 53 | self._hans_train_df = pd.read_csv(constants.HANS_TRAIN_FILE_NAME, sep="\t") 54 | self._hans_eval_df = combined_hans_eval_df[valid_selector] 55 | self._hans_test_df = combined_hans_eval_df[~valid_selector] 56 | self._hans_train_dataset = hans_train_dataset 57 | self._hans_eval_dataset = hans_eval_dataset 58 | 59 | def get_indices_of_heuristic( 60 | self, 61 | mode: str, 62 | heuristic: str) -> List[int]: 63 | 64 | if mode not in ["train", "eval", "test"]: 65 | raise ValueError 66 | 67 | if heuristic not in ["lexical_overlap", "subsequence", "constituent"]: 68 | raise ValueError 69 | 70 | if mode == "train": 71 | df = self._hans_train_df 72 | if mode == "eval": 73 | df = self._hans_eval_df 74 | if mode == "test": 75 | df = self._hans_test_df 76 | 77 | indices_of_heuristic = df[df.heuristic == heuristic].index 78 | return indices_of_heuristic.tolist() 79 | 80 | def sample_batch_of_heuristic( 81 | self, 82 | mode: str, 83 | heuristic: str, 84 | size: int, 85 | return_raw_data: bool = False) -> np.ndarray: 86 | 87 | if mode not in ["train", "eval", "test"]: 88 | raise ValueError 89 | 90 | if mode == "train": 91 | dataset = self._hans_train_dataset 92 | else: 93 | dataset = self._hans_eval_dataset 94 | 95 | if dataset is None: 96 | raise ValueError("`dataset` is None") 97 | 98 | indices = self.get_indices_of_heuristic( 99 | mode=mode, heuristic=heuristic) 100 | 101 | sampled_indices = np.random.choice( 102 | indices, size=size, replace=False) 103 | 104 | sampled_data = [dataset[index] for index in sampled_indices] 105 | batched_data = default_data_collator(sampled_data) 106 | if return_raw_data is False: 107 | return batched_data 108 | 109 | return batched_data, sampled_data 110 | 111 | def get_dataset_and_dataloader_of_heuristic( 112 | self, 113 | mode: str, 114 | heuristic: str, 115 | batch_size: int, 116 | random: bool) -> Tuple[SubsetDataset, 117 | torch.utils.data.DataLoader]: 118 | 119 | if mode not in ["train", "eval", "test"]: 120 | raise ValueError 121 | 122 | if mode == "train": 123 | dataset = self._hans_train_dataset 124 | else: 125 | dataset = self._hans_eval_dataset 126 | 127 | if dataset is None: 128 | raise ValueError("`dataset` is None") 129 | 130 | indices = self.get_indices_of_heuristic( 131 | mode=mode, heuristic=heuristic) 132 | 133 | heuristic_dataset = SubsetDataset(dataset=dataset, indices=indices) 134 | heuristic_dataloader = misc_utils.get_dataloader( 135 | dataset=heuristic_dataset, 136 | batch_size=batch_size, 137 | random=random) 138 | 139 | return heuristic_dataset, heuristic_dataloader 140 | 141 | 142 | class SimpleHelper(object): 143 | def __init__( 144 | self, 145 | train_dataset: Optional[data_utils.CustomGlueDataset] = None, 146 | eval_dataset: Optional[data_utils.CustomGlueDataset] = None, 147 | test_dataset: Optional[data_utils.CustomGlueDataset] = None) -> None: 148 | 149 | self._train_dataset = train_dataset 150 | self._eval_dataset = eval_dataset 151 | self._test_dataset = test_dataset 152 | 153 | def sample_batch_of_heuristic( 154 | self, 155 | mode: str, 156 | heuristic: str, 157 | size: int, 158 | return_raw_data: bool = False) -> np.ndarray: 159 | 160 | if mode not in ["train", "eval", "test"]: 161 | raise ValueError 162 | 163 | if heuristic not in ["null"]: 164 | raise ValueError 165 | 166 | if mode == "train": 167 | dataset = self._train_dataset 168 | if mode == "eval": 169 | dataset = self._eval_dataset 170 | if mode == "test": 171 | dataset = self._test_dataset 172 | 173 | if dataset is None: 174 | raise ValueError("`dataset` is None") 175 | 176 | sampled_indices = np.random.choice( 177 | len(dataset), size=size, replace=False) 178 | 179 | sampled_data = [dataset[index] for index in sampled_indices] 180 | batched_data = default_data_collator(sampled_data) 181 | if return_raw_data is False: 182 | return batched_data 183 | 184 | return batched_data, sampled_data 185 | 186 | def get_dataset_and_dataloader_of_heuristic( 187 | self, 188 | mode: str, 189 | heuristic: str, 190 | batch_size: int, 191 | random: bool) -> Tuple[SubsetDataset, 192 | torch.utils.data.DataLoader]: 193 | 194 | if mode not in ["train", "eval", "test"]: 195 | raise ValueError 196 | 197 | if heuristic not in ["null"]: 198 | raise ValueError 199 | 200 | if mode == "train": 201 | dataset = self._train_dataset 202 | if mode == "eval": 203 | dataset = self._eval_dataset 204 | if mode == "test": 205 | dataset = self._test_dataset 206 | 207 | if dataset is None: 208 | raise ValueError("`dataset` is None") 209 | 210 | heuristic_dataloader = misc_utils.get_dataloader( 211 | dataset=dataset, 212 | batch_size=batch_size, 213 | random=random) 214 | 215 | return dataset, heuristic_dataloader 216 | 217 | 218 | class AmazonHelper(SimpleHelper): 219 | def __init__(self, *args, **kwargs) -> None: 220 | super().__init__(*args, **kwargs) 221 | # This is a dictionary that maps split-key 222 | # into the corresponding metadata array. 223 | metadata_arrays_dict = torch.load( 224 | constants.AMAZON_METADATA_ARRAY_FILE_NAME) 225 | # The first column refers to `user` 226 | valid_user_ids = metadata_arrays_dict["val"][:, 0].numpy() 227 | # Create a mapping from `user_id` to the 228 | # data-indices of this `user_id` 229 | valid_user_to_index_map = defaultdict(list) 230 | for index, user_id in enumerate(valid_user_ids): 231 | valid_user_to_index_map[user_id].append(index) 232 | 233 | self._valid_user_to_index_map = valid_user_to_index_map 234 | 235 | def generate_sampled_indices(self, mode: str, size: int) -> List[int]: 236 | if mode not in ["eval"]: 237 | raise ValueError 238 | 239 | # For now, we only sample one index per `user_id` 240 | if size > len(self._valid_user_to_index_map.keys()): 241 | raise ValueError 242 | 243 | sampled_indices = [] 244 | # First sample the `user_id` 245 | sampled_user_ids = np.random.choice( 246 | list(self._valid_user_to_index_map.keys()), 247 | size=size, 248 | replace=False) 249 | 250 | # Then sample data indices from the sampled `user_id` 251 | for sampled_user_id in sampled_user_ids: 252 | _sampled_indices = np.random.choice( 253 | self._valid_user_to_index_map[sampled_user_id], 254 | size=1, 255 | replace=False) 256 | sampled_indices.extend(_sampled_indices) 257 | 258 | return sampled_indices 259 | 260 | def sample_batch_of_heuristic( 261 | self, 262 | mode: str, 263 | heuristic: str, 264 | size: int, 265 | return_raw_data: bool = False) -> np.ndarray: 266 | 267 | if mode not in ["train", "eval", "test"]: 268 | raise ValueError 269 | 270 | if heuristic not in ["null"]: 271 | raise ValueError 272 | 273 | if mode == "train": 274 | dataset = self._train_dataset 275 | if mode == "eval": 276 | dataset = self._eval_dataset 277 | if mode == "test": 278 | dataset = self._test_dataset 279 | 280 | if dataset is None: 281 | raise ValueError("`dataset` is None") 282 | 283 | sampled_indices = self.generate_sampled_indices(mode=mode, size=size) 284 | sampled_data = [dataset[index] for index in sampled_indices] 285 | batched_data = default_data_collator(sampled_data) 286 | if return_raw_data is False: 287 | return batched_data 288 | 289 | return batched_data, sampled_data 290 | 291 | 292 | def save_amazon_metadata(file_name: str) -> None: 293 | dataset = AmazonDataset( 294 | root_dir=constants.Amazon_DATA_DIR, 295 | download=False) 296 | 297 | metadata_arrays_dict = {} 298 | for key in dataset.split_dict.keys(): 299 | datasubset = dataset.get_subset(key) 300 | metadata_arrays_dict[key] = datasubset.metadata_array 301 | print(f"{key:<10}: {len(datasubset):<10} " 302 | f"{datasubset.metadata_array.shape} " 303 | f"{metadata_arrays_dict[key].shape}") 304 | 305 | # print(torch.__version__) 306 | torch.save(metadata_arrays_dict, file_name) 307 | -------------------------------------------------------------------------------- /experiments/influence_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import torch 7 | from influence_utils import faiss_utils 8 | from typing import List, Dict, Tuple, Optional, Union, Any 9 | 10 | from experiments import constants 11 | from experiments import misc_utils 12 | from influence_utils import parallel 13 | from influence_utils import nn_influence_utils 14 | 15 | 16 | def load_faiss_index( 17 | trained_on_task_name: str, 18 | train_task_name: str, 19 | ) -> faiss_utils.FAISSIndex: 20 | 21 | if trained_on_task_name not in ["mnli", "mnli-2", "hans", "amazon"]: 22 | raise ValueError 23 | 24 | if train_task_name not in ["mnli", "mnli-2", "hans", "amazon", "anli"]: 25 | raise ValueError 26 | 27 | if trained_on_task_name == "mnli" and train_task_name == "mnli": 28 | faiss_index = faiss_utils.FAISSIndex(768, "Flat") 29 | faiss_index.load(constants.MNLI_FAISS_INDEX_PATH) 30 | 31 | elif trained_on_task_name == "mnli-2" and train_task_name == "mnli-2": 32 | faiss_index = faiss_utils.FAISSIndex(768, "Flat") 33 | faiss_index.load(constants.MNLI2_FAISS_INDEX_PATH) 34 | 35 | elif trained_on_task_name == "hans" and train_task_name == "hans": 36 | faiss_index = faiss_utils.FAISSIndex(768, "Flat") 37 | faiss_index.load(constants.HANS_FAISS_INDEX_PATH) 38 | 39 | elif trained_on_task_name == "mnli-2" and train_task_name == "hans": 40 | faiss_index = faiss_utils.FAISSIndex(768, "Flat") 41 | faiss_index.load(constants.MNLI2_HANS_FAISS_INDEX_PATH) 42 | 43 | elif trained_on_task_name == "hans" and train_task_name == "mnli-2": 44 | faiss_index = faiss_utils.FAISSIndex(768, "Flat") 45 | faiss_index.load(constants.HANS_MNLI2_FAISS_INDEX_PATH) 46 | 47 | elif trained_on_task_name == "amazon" and train_task_name == "amazon": 48 | faiss_index = faiss_utils.FAISSIndex(768, "Flat") 49 | faiss_index.load(constants.Amazon_FAISS_INDEX_PATH) 50 | 51 | elif trained_on_task_name == "mnli" and train_task_name == "anli": 52 | faiss_index = faiss_utils.FAISSIndex(768, "Flat") 53 | faiss_index.load(constants.MNLI_ANLI_FAISS_INDEX_PATH) 54 | 55 | else: 56 | faiss_index = None 57 | 58 | return faiss_index 59 | 60 | 61 | def select_s_test_config( 62 | trained_on_task_name: str, 63 | train_task_name: str, 64 | eval_task_name: str, 65 | ) -> Tuple[float, float, int]: 66 | 67 | if trained_on_task_name != train_task_name: 68 | # Only this setting is supported for now 69 | # basically, the config for this combination 70 | # of `trained_on_task_name` and `eval_task_name` 71 | # would be fine, so not raising issues here for now. 72 | if not ( 73 | all([trained_on_task_name == "mnli-2", 74 | train_task_name == "hans", 75 | eval_task_name == "hans"]) or 76 | all([trained_on_task_name == "mnli", 77 | train_task_name == "anli", 78 | eval_task_name == "anli"]) 79 | ): 80 | raise ValueError("Unsupported as of now") 81 | 82 | if trained_on_task_name not in ["mnli", "mnli-2", "hans", "amazon"]: 83 | raise ValueError 84 | 85 | if eval_task_name not in ["mnli", "mnli-2", "hans", "amazon", "anli"]: 86 | raise ValueError 87 | 88 | # Other settings are not supported as of now 89 | if trained_on_task_name == "mnli" and eval_task_name == "mnli": 90 | s_test_damp = 5e-3 91 | s_test_scale = 1e4 92 | s_test_num_samples = 1000 93 | 94 | elif trained_on_task_name == "mnli-2" and eval_task_name == "mnli-2": 95 | s_test_damp = 5e-3 96 | s_test_scale = 1e4 97 | s_test_num_samples = 1000 98 | 99 | elif trained_on_task_name == "hans" and eval_task_name == "hans": 100 | s_test_damp = 5e-3 101 | s_test_scale = 1e6 102 | s_test_num_samples = 2000 103 | 104 | elif trained_on_task_name == "mnli-2" and eval_task_name == "hans": 105 | s_test_damp = 5e-3 106 | s_test_scale = 1e6 107 | s_test_num_samples = 1000 108 | 109 | elif trained_on_task_name == "hans" and eval_task_name == "mnli-2": 110 | s_test_damp = 5e-3 111 | s_test_scale = 1e6 112 | s_test_num_samples = 2000 113 | 114 | elif trained_on_task_name == "amazon" and eval_task_name == "amazon": 115 | s_test_damp = 5e-3 116 | s_test_scale = 1e4 117 | s_test_num_samples = 1000 118 | 119 | elif trained_on_task_name == "mnli" and eval_task_name == "anli": 120 | if train_task_name != "anli": 121 | raise NotImplementedError 122 | 123 | s_test_damp = 5e-3 124 | s_test_scale = 1e6 125 | s_test_num_samples = 1500 126 | 127 | else: 128 | raise ValueError 129 | 130 | return s_test_damp, s_test_scale, s_test_num_samples 131 | 132 | 133 | def compute_influences_simplified( 134 | k: int, 135 | faiss_index: faiss_utils.FAISSIndex, 136 | model: torch.nn.Module, 137 | inputs: Dict[str, torch.Tensor], 138 | train_dataset: torch.utils.data.DataLoader, 139 | use_parallel: bool, 140 | s_test_damp: float, 141 | s_test_scale: float, 142 | s_test_num_samples: int, 143 | device_ids: Optional[List[int]] = None, 144 | precomputed_s_test: Optional[List[torch.FloatTensor]] = None, 145 | faiss_index_use_mean_features_as_query: bool = False, 146 | ) -> Dict[int, float]: 147 | 148 | # Make sure indices are sorted according to distances 149 | # KNN_distances[( 150 | # KNN_indices.squeeze(axis=0)[ 151 | # np.argsort(KNN_distances.squeeze(axis=0)) 152 | # ] != KNN_indices)] 153 | 154 | params_filter = [ 155 | n for n, p in model.named_parameters() 156 | if not p.requires_grad] 157 | 158 | weight_decay_ignores = [ 159 | "bias", 160 | "LayerNorm.weight"] + [ 161 | n for n, p in model.named_parameters() 162 | if not p.requires_grad] 163 | 164 | if faiss_index is not None: 165 | features = misc_utils.compute_BERT_CLS_feature(model, **inputs) 166 | features = features.cpu().detach().numpy() 167 | 168 | if faiss_index_use_mean_features_as_query is True: 169 | # We use the mean embedding as the final query here 170 | features = features.mean(axis=0, keepdims=True) 171 | 172 | KNN_distances, KNN_indices = faiss_index.search( 173 | k=k, queries=features) 174 | else: 175 | KNN_indices = None 176 | 177 | if not use_parallel: 178 | model.cuda() 179 | batch_train_data_loader = misc_utils.get_dataloader( 180 | train_dataset, 181 | batch_size=1, 182 | random=True) 183 | 184 | instance_train_data_loader = misc_utils.get_dataloader( 185 | train_dataset, 186 | batch_size=1, 187 | random=False) 188 | 189 | influences, _, _ = nn_influence_utils.compute_influences( 190 | n_gpu=1, 191 | device=torch.device("cuda"), 192 | batch_train_data_loader=batch_train_data_loader, 193 | instance_train_data_loader=instance_train_data_loader, 194 | model=model, 195 | test_inputs=inputs, 196 | params_filter=params_filter, 197 | weight_decay=constants.WEIGHT_DECAY, 198 | weight_decay_ignores=weight_decay_ignores, 199 | s_test_damp=s_test_damp, 200 | s_test_scale=s_test_scale, 201 | s_test_num_samples=s_test_num_samples, 202 | train_indices_to_include=KNN_indices, 203 | precomputed_s_test=precomputed_s_test) 204 | else: 205 | if device_ids is None: 206 | raise ValueError("`device_ids` cannot be None") 207 | 208 | influences, _ = parallel.compute_influences_parallel( 209 | # Avoid clash with main process 210 | device_ids=device_ids, 211 | train_dataset=train_dataset, 212 | batch_size=1, 213 | model=model, 214 | test_inputs=inputs, 215 | params_filter=params_filter, 216 | weight_decay=constants.WEIGHT_DECAY, 217 | weight_decay_ignores=weight_decay_ignores, 218 | s_test_damp=s_test_damp, 219 | s_test_scale=s_test_scale, 220 | s_test_num_samples=s_test_num_samples, 221 | train_indices_to_include=KNN_indices, 222 | return_s_test=False, 223 | debug=False) 224 | 225 | return influences 226 | -------------------------------------------------------------------------------- /experiments/misc_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import os 7 | import torch 8 | import numpy as np 9 | # from tqdm import tqdm 10 | from torch.utils.data.dataloader import DataLoader 11 | from torch.utils.data.sampler import SequentialSampler, RandomSampler 12 | from typing import Tuple, Optional, Union, Any, Dict, List, Callable 13 | from transformers import ( 14 | AutoModelForSequenceClassification, 15 | AutoTokenizer, 16 | BertTokenizer, 17 | BertForSequenceClassification, 18 | GlueDataTrainingArguments, 19 | Trainer, 20 | DataCollator, 21 | default_data_collator) 22 | 23 | from influence_utils import glue_utils 24 | from experiments import constants 25 | from experiments.data_utils import CustomGlueDataset 26 | 27 | 28 | def sort_dict_keys_by_vals(d: Dict[int, float]) -> List[int]: 29 | sorted_items = sorted(list(d.items()), 30 | key=lambda pair: pair[1]) 31 | return [pair[0] for pair in sorted_items] 32 | 33 | 34 | def sort_dict_keys_by_vals_with_conditions( 35 | d: Dict[int, float], 36 | condition_func: Callable[[Tuple[int, float]], bool] 37 | ) -> List[int]: 38 | 39 | sorted_items = sorted(list(d.items()), 40 | key=lambda pair: pair[1]) 41 | return [pair[0] for pair in sorted_items 42 | if condition_func(pair)] 43 | 44 | 45 | def get_helpful_harmful_indices_from_influences_dict( 46 | d: Dict[int, float], 47 | n: Optional[int] = None, 48 | ) -> Tuple[List[int], List[int]]: 49 | 50 | helpful_indices = sort_dict_keys_by_vals_with_conditions( 51 | d, condition_func=lambda k_v: k_v[1] < 0.0) 52 | harmful_indices = sort_dict_keys_by_vals_with_conditions( 53 | d, condition_func=lambda k_v: k_v[1] > 0.0)[::-1] 54 | 55 | if n is not None: 56 | if len(helpful_indices) < n: 57 | raise ValueError( 58 | f"`helpful_indices` have only " 59 | f"{len(helpful_indices)} elememts " 60 | f"whereas {n} is needed") 61 | 62 | if len(harmful_indices) < n: 63 | raise ValueError( 64 | f"`harmful_indices` have only " 65 | f"{len(harmful_indices)} elememts " 66 | f"whereas {n} is needed") 67 | 68 | helpful_indices = helpful_indices[:n] 69 | harmful_indices = harmful_indices[:n] 70 | 71 | return helpful_indices, harmful_indices 72 | 73 | 74 | def compute_BERT_CLS_feature( 75 | model, 76 | input_ids=None, 77 | attention_mask=None, 78 | token_type_ids=None, 79 | position_ids=None, 80 | head_mask=None, 81 | inputs_embeds=None, 82 | labels=None, 83 | output_attentions=None, 84 | output_hidden_states=None, 85 | ) -> torch.FloatTensor: 86 | r""" 87 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): 88 | Labels for computing the sequence classification/regression loss. 89 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`. 90 | If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 91 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 92 | """ 93 | if model.training is True: 94 | raise ValueError 95 | 96 | outputs = model.bert( 97 | input_ids, 98 | attention_mask=attention_mask, 99 | token_type_ids=token_type_ids, 100 | position_ids=position_ids, 101 | head_mask=head_mask, 102 | inputs_embeds=inputs_embeds, 103 | output_attentions=output_attentions, 104 | output_hidden_states=output_hidden_states, 105 | ) 106 | 107 | pooled_output = outputs[1] 108 | 109 | return model.dropout(pooled_output) 110 | 111 | 112 | def create_tokenizer_and_model( 113 | model_name_or_path: str, 114 | freeze_parameters: bool = True 115 | ) -> Tuple[BertTokenizer, BertForSequenceClassification]: 116 | if model_name_or_path is None: 117 | raise ValueError 118 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 119 | model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path) 120 | 121 | model.eval() 122 | if freeze_parameters is True: 123 | glue_utils.freeze_BERT_parameters(model) 124 | 125 | return tokenizer, model 126 | 127 | 128 | def create_datasets( 129 | task_name: str, 130 | tokenizer: BertTokenizer, 131 | data_dir: Optional[str] = None, 132 | create_test_dataset: bool = False, 133 | ) -> Union[Tuple[CustomGlueDataset, CustomGlueDataset], 134 | Tuple[CustomGlueDataset, CustomGlueDataset, CustomGlueDataset]]: 135 | if task_name not in ["mnli", "mnli-2", "hans", "amazon", "anli"]: 136 | raise ValueError(f"Unrecognized task {task_name}") 137 | 138 | if data_dir is None: 139 | if task_name in ["mnli", "mnli-2"]: 140 | data_dir = constants.GLUE_DATA_DIR 141 | if task_name in ["hans"]: 142 | data_dir = constants.HANS_DATA_DIR 143 | if task_name in ["amazon"]: 144 | data_dir = constants.Amazon_DATA_DIR 145 | if task_name in ["anli"]: 146 | data_dir = constants.ANLI_DATA_DIR 147 | 148 | data_args = GlueDataTrainingArguments( 149 | task_name=task_name, 150 | data_dir=data_dir, 151 | max_seq_length=128) 152 | 153 | train_dataset = CustomGlueDataset( 154 | args=data_args, 155 | tokenizer=tokenizer, 156 | mode="train") 157 | 158 | eval_dataset = CustomGlueDataset( 159 | args=data_args, 160 | tokenizer=tokenizer, 161 | mode="dev") 162 | 163 | if create_test_dataset is False: 164 | return train_dataset, eval_dataset 165 | else: 166 | test_dataset = CustomGlueDataset( 167 | args=data_args, 168 | tokenizer=tokenizer, 169 | mode="test") 170 | 171 | return train_dataset, eval_dataset, test_dataset 172 | 173 | 174 | def predict(trainer: Trainer, 175 | model: torch.nn.Module, 176 | inputs: Dict[str, Union[torch.Tensor, Any]], 177 | ) -> Tuple[np.ndarray, np.ndarray, Optional[float]]: 178 | 179 | if trainer.args.past_index >= 0: 180 | raise ValueError 181 | 182 | has_labels = any( 183 | inputs.get(k) is not None for k in 184 | ["labels", "lm_labels", "masked_lm_labels"]) 185 | 186 | for k, v in inputs.items(): 187 | if isinstance(v, torch.Tensor): 188 | inputs[k] = v.to(trainer.args.device) 189 | 190 | step_eval_loss = None 191 | with torch.no_grad(): 192 | outputs = model(**inputs) 193 | if has_labels: 194 | step_eval_loss, logits = outputs[:2] 195 | else: 196 | logits = outputs[0] 197 | 198 | preds = logits.detach() 199 | preds = preds.cpu().numpy() 200 | if inputs.get("labels") is not None: 201 | label_ids = inputs["labels"].detach() 202 | label_ids = label_ids.cpu().numpy() 203 | 204 | if step_eval_loss is not None: 205 | step_eval_loss = step_eval_loss.mean().item() 206 | 207 | return preds, label_ids, step_eval_loss 208 | 209 | 210 | def get_dataloader(dataset: CustomGlueDataset, 211 | batch_size: int, 212 | random: bool = False, 213 | data_collator: Optional[DataCollator] = None 214 | ) -> DataLoader: 215 | if data_collator is None: 216 | data_collator = default_data_collator 217 | 218 | if random is True: 219 | sampler = RandomSampler(dataset) 220 | else: 221 | sampler = SequentialSampler(dataset) 222 | 223 | data_loader = DataLoader( 224 | dataset, 225 | sampler=sampler, 226 | batch_size=batch_size, 227 | collate_fn=data_collator, 228 | ) 229 | 230 | return data_loader 231 | 232 | 233 | def remove_file_if_exists(file_name: str) -> None: 234 | if os.path.exists(file_name): 235 | os.remove(file_name) 236 | else: 237 | print("The file does not exist") 238 | 239 | 240 | def is_prediction_correct( 241 | trainer: Trainer, 242 | model: torch.nn.Module, 243 | inputs: Dict[str, Union[torch.Tensor, Any]]) -> bool: 244 | 245 | preds, label_ids, step_eval_loss = predict( 246 | trainer=trainer, 247 | model=model, 248 | inputs=inputs) 249 | 250 | if preds.shape[0] != 1: 251 | raise ValueError("This function only works on instances.") 252 | 253 | return bool((preds.argmax(axis=-1) == label_ids).all()) 254 | 255 | 256 | def move_inputs_to_device( 257 | inputs: Dict[str, Any], 258 | device: torch.device 259 | ) -> None: 260 | for k, v in inputs.items(): 261 | if isinstance(v, torch.Tensor): 262 | inputs[k] = v.to(device) 263 | -------------------------------------------------------------------------------- /experiments/mnli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import os 7 | import time 8 | import torch 9 | import subprocess 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import transformers 13 | from tqdm import tqdm 14 | from glob import glob 15 | from copy import deepcopy 16 | from contexttimer import Timer 17 | from collections import defaultdict 18 | from transformers import TrainingArguments 19 | from transformers import default_data_collator 20 | from typing import List, Dict, Tuple, Optional, Union, Any 21 | 22 | from experiments import constants 23 | from experiments import mnli_utils 24 | from experiments import misc_utils 25 | from experiments import remote_utils 26 | from experiments import influence_helpers 27 | from experiments import hans 28 | from influence_utils import nn_influence_utils 29 | from experiments.data_utils import ( 30 | glue_output_modes, 31 | glue_compute_metrics) 32 | 33 | MNLI_TRAINING_SCRIPT_NAME = "scripts/run_MNLI.20200913.sh" 34 | NUM_DATAPOINTS_TO_REMOVE_CHOICES = [1, 5, 25] 35 | 36 | CORRECT_INDICES = sorted([ 37 | # e.g., `KNN-recall.only-correct.50.0.pth.g0301.ll.unc.edu` 38 | int(f.split("/")[-1].split(".")[3]) 39 | for f in glob(os.path.join( 40 | constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR, 41 | "*only-correct*") 42 | ) 43 | ]) 44 | INCORRECT_INDICES = sorted([ 45 | # e.g., `KNN-recall.only-correct.50.0.pth.g0301.ll.unc.edu` 46 | int(f.split("/")[-1].split(".")[3]) 47 | for f in glob(os.path.join( 48 | constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR, 49 | "*only-incorrect*") 50 | ) 51 | ]) 52 | 53 | 54 | def run_retraining_main( 55 | mode: str, 56 | num_examples_to_test: int): 57 | 58 | if mode not in ["full", "KNN-1000", "KNN-10000", "random"]: 59 | raise ValueError(f"Unrecognized `mode` {mode}") 60 | 61 | for example_relative_index in range(num_examples_to_test): 62 | for correct_mode in ["correct", "incorrect"]: 63 | if correct_mode == "correct": 64 | example_index = CORRECT_INDICES[example_relative_index] 65 | if correct_mode == "incorrect": 66 | example_index = INCORRECT_INDICES[example_relative_index] 67 | 68 | if mode in ["full"]: 69 | # Load file from local or sync from remote 70 | file_name = os.path.join( 71 | constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR, 72 | f"KNN-recall.only-{correct_mode}.50.{example_index}" 73 | f".pth.g0301.ll.unc.edu") 74 | 75 | influences_dict = torch.load(file_name) 76 | if example_index != influences_dict["test_index"]: 77 | raise ValueError 78 | 79 | if (correct_mode == "correct" and 80 | influences_dict["correct"] is not True or 81 | correct_mode == "incorrect" and 82 | influences_dict["correct"] is True): 83 | raise ValueError 84 | 85 | helpful_indices, harmful_indices = ( 86 | misc_utils.get_helpful_harmful_indices_from_influences_dict( 87 | influences_dict["influences"])) 88 | 89 | indices_dict = { 90 | "helpful": helpful_indices, 91 | "harmful": harmful_indices} 92 | 93 | if mode in ["KNN-1000", "KNN-10000"]: 94 | if mode == "KNN-1000": 95 | kNN_k = 1000 96 | if mode == "KNN-10000": 97 | kNN_k = 10000 98 | 99 | file_name = os.path.join( 100 | constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR2, 101 | f"visualization" 102 | f".only-{correct_mode}" 103 | f".5.mnli-mnli-None-mnli" 104 | f".{kNN_k}.True.pth.g0306.ll.unc.edu") 105 | 106 | influences_dict = torch.load(file_name)[example_relative_index] 107 | if example_index != influences_dict["index"]: 108 | raise ValueError 109 | 110 | helpful_indices, harmful_indices = ( 111 | misc_utils.get_helpful_harmful_indices_from_influences_dict( 112 | influences_dict["influences"])) 113 | 114 | indices_dict = { 115 | "helpful": helpful_indices, 116 | "harmful": harmful_indices} 117 | 118 | if mode == "random": 119 | # Get indices corresponding to each label 120 | label_to_indices = mnli_utils.get_label_to_indices_map() 121 | np.random.shuffle(label_to_indices["neutral"]) 122 | np.random.shuffle(label_to_indices["entailment"]) 123 | np.random.shuffle(label_to_indices["contradiction"]) 124 | indices_dict = { 125 | "neutral": label_to_indices["neutral"], 126 | "entailment": label_to_indices["entailment"], 127 | "contradiction": label_to_indices["contradiction"], 128 | } 129 | 130 | for tag, indices in indices_dict.items(): 131 | for num_data_points_to_remove in NUM_DATAPOINTS_TO_REMOVE_CHOICES: 132 | if len(indices) < num_data_points_to_remove: 133 | raise ValueError(f"`indices` have only {len(indices)} elememts " 134 | f"whereas {num_data_points_to_remove} is needed") 135 | 136 | run_one_retraining( 137 | indices=indices[:num_data_points_to_remove], 138 | dir_name=( 139 | f"./retraining-remove-" 140 | f"{example_index}-" 141 | f"{correct_mode}-" 142 | f"{mode}-" 143 | f"{tag}-" 144 | f"{num_data_points_to_remove}")) 145 | 146 | 147 | def run_one_retraining( 148 | indices: List[int], 149 | dir_name: str, 150 | ) -> None: 151 | mnli_utils.create_one_set_of_data_for_retraining( 152 | dir_name=dir_name, 153 | indices_to_remove=indices) 154 | output_dir = os.path.join(dir_name, "output_dir") 155 | subprocess.check_call([ 156 | "bash", 157 | MNLI_TRAINING_SCRIPT_NAME, 158 | dir_name, output_dir 159 | ]) 160 | # client = remote_utils.ScpClient() 161 | # client.scp_file_to_remote( 162 | # local_file_name=dir_name, 163 | # remote_file_name=os.path.join( 164 | # constants.REMOTE_DEFAULT_REMOTE_BASE_DIR, 165 | # f"{dir_name}.{client.host_name}"), 166 | # # This is a folder 167 | # recursive=True) 168 | 169 | 170 | def run_full_influence_functions( 171 | mode: str, 172 | num_examples_to_test: int, 173 | s_test_num_samples: int = 1000 174 | ) -> Dict[int, Dict[str, Any]]: 175 | 176 | if mode not in ["only-correct", "only-incorrect"]: 177 | raise ValueError(f"Unrecognized mode {mode}") 178 | 179 | tokenizer, model = misc_utils.create_tokenizer_and_model( 180 | constants.MNLI_MODEL_PATH) 181 | 182 | (mnli_train_dataset, 183 | mnli_eval_dataset) = misc_utils.create_datasets( 184 | task_name="mnli", 185 | tokenizer=tokenizer) 186 | 187 | batch_train_data_loader = misc_utils.get_dataloader( 188 | mnli_train_dataset, 189 | batch_size=128, 190 | random=True) 191 | 192 | instance_train_data_loader = misc_utils.get_dataloader( 193 | mnli_train_dataset, 194 | batch_size=1, 195 | random=False) 196 | 197 | eval_instance_data_loader = misc_utils.get_dataloader( 198 | dataset=mnli_eval_dataset, 199 | batch_size=1, 200 | random=False) 201 | 202 | output_mode = glue_output_modes["mnli"] 203 | 204 | def build_compute_metrics_fn(task_name: str): 205 | def compute_metrics_fn(p): 206 | if output_mode == "classification": 207 | preds = np.argmax(p.predictions, axis=1) 208 | elif output_mode == "regression": 209 | preds = np.squeeze(p.predictions) 210 | return glue_compute_metrics(task_name, preds, p.label_ids) 211 | 212 | return compute_metrics_fn 213 | 214 | # Most of these arguments are placeholders 215 | # and are not really used at all, so ignore 216 | # the exact values of these. 217 | trainer = transformers.Trainer( 218 | model=model, 219 | args=TrainingArguments( 220 | output_dir="./tmp-output", 221 | per_device_train_batch_size=128, 222 | per_device_eval_batch_size=128, 223 | learning_rate=5e-5, 224 | logging_steps=100), 225 | data_collator=default_data_collator, 226 | train_dataset=mnli_train_dataset, 227 | eval_dataset=mnli_eval_dataset, 228 | compute_metrics=build_compute_metrics_fn("mnli"), 229 | ) 230 | 231 | params_filter = [ 232 | n for n, p in model.named_parameters() 233 | if not p.requires_grad] 234 | 235 | weight_decay_ignores = [ 236 | "bias", 237 | "LayerNorm.weight"] + [ 238 | n for n, p in model.named_parameters() 239 | if not p.requires_grad] 240 | 241 | model.cuda() 242 | num_examples_tested = 0 243 | outputs_collections = {} 244 | for test_index, test_inputs in enumerate(eval_instance_data_loader): 245 | if num_examples_tested >= num_examples_to_test: 246 | break 247 | 248 | # Skip when we only want cases of correction prediction but the 249 | # prediction is incorrect, or vice versa 250 | prediction_is_correct = misc_utils.is_prediction_correct( 251 | trainer=trainer, 252 | model=model, 253 | inputs=test_inputs) 254 | 255 | if mode == "only-correct" and prediction_is_correct is False: 256 | continue 257 | 258 | if mode == "only-incorrect" and prediction_is_correct is True: 259 | continue 260 | 261 | with Timer() as timer: 262 | influences, _, s_test = nn_influence_utils.compute_influences( 263 | n_gpu=1, 264 | device=torch.device("cuda"), 265 | batch_train_data_loader=batch_train_data_loader, 266 | instance_train_data_loader=instance_train_data_loader, 267 | model=model, 268 | test_inputs=test_inputs, 269 | params_filter=params_filter, 270 | weight_decay=constants.WEIGHT_DECAY, 271 | weight_decay_ignores=weight_decay_ignores, 272 | s_test_damp=5e-3, 273 | s_test_scale=1e4, 274 | s_test_num_samples=s_test_num_samples, 275 | train_indices_to_include=None, 276 | s_test_iterations=1, 277 | precomputed_s_test=None) 278 | 279 | outputs = { 280 | "test_index": test_index, 281 | "influences": influences, 282 | "s_test": s_test, 283 | "time": timer.elapsed, 284 | "correct": prediction_is_correct, 285 | } 286 | num_examples_tested += 1 287 | outputs_collections[test_index] = outputs 288 | 289 | remote_utils.save_and_mirror_scp_to_remote( 290 | object_to_save=outputs, 291 | file_name=f"KNN-recall.{mode}.{num_examples_to_test}.{test_index}.pth") 292 | print(f"Status: #{test_index} | {num_examples_tested} / {num_examples_to_test}") 293 | 294 | return outputs_collections 295 | 296 | 297 | def imitator_main(mode: str, num_examples_to_test: int) -> List[Dict[str, Any]]: 298 | if mode not in ["only-correct", "only-incorrect"]: 299 | raise ValueError(f"Unrecognized mode {mode}") 300 | 301 | task_tokenizer, task_model = misc_utils.create_tokenizer_and_model( 302 | constants.MNLI_MODEL_PATH) 303 | 304 | imitator_tokenizer, imitator_model = misc_utils.create_tokenizer_and_model( 305 | constants.MNLI_IMITATOR_MODEL_PATH) 306 | 307 | (mnli_train_dataset, 308 | mnli_eval_dataset) = misc_utils.create_datasets( 309 | task_name="mnli", 310 | tokenizer=task_tokenizer) 311 | 312 | task_model.cuda() 313 | imitator_model.cuda() 314 | if task_model.training is True or imitator_model.training is True: 315 | raise ValueError("One of the model is in training mode") 316 | print(task_model.device, imitator_model.device) 317 | 318 | # Most of these arguments are placeholders 319 | # and are not really used at all, so ignore 320 | # the exact values of these. 321 | trainer = transformers.Trainer( 322 | model=task_model, 323 | args=TrainingArguments( 324 | output_dir="./tmp-output", 325 | per_device_train_batch_size=128, 326 | per_device_eval_batch_size=128, 327 | learning_rate=5e-5, 328 | logging_steps=100), 329 | ) 330 | 331 | eval_instance_data_loader = misc_utils.get_dataloader( 332 | mnli_eval_dataset, 333 | batch_size=1, 334 | data_collator=default_data_collator) 335 | 336 | train_inputs_collections = torch.load( 337 | constants.MNLI_TRAIN_INPUT_COLLECTIONS_PATH) 338 | 339 | inputs_by_label: Dict[str, List[int]] = defaultdict(list) 340 | for i in range(len(train_inputs_collections)): 341 | label = mnli_train_dataset.label_list[ 342 | train_inputs_collections[i]["labels"]] 343 | inputs_by_label[label].append(i) 344 | 345 | outputs_collections = [] 346 | for i, test_inputs in enumerate(eval_instance_data_loader): 347 | if mode == "only-correct" and i not in CORRECT_INDICES[:num_examples_to_test]: 348 | continue 349 | if mode == "only-incorrect" and i not in INCORRECT_INDICES[:num_examples_to_test]: 350 | continue 351 | 352 | start_time = time.time() 353 | for using_ground_truth in [True, False]: 354 | outputs = run_one_imitator_experiment( 355 | task_model=task_model, 356 | imitator_model=imitator_model, 357 | test_inputs=test_inputs, 358 | trainer=trainer, 359 | train_dataset=mnli_train_dataset, 360 | train_inputs_collections=train_inputs_collections, 361 | inputs_by_label=inputs_by_label, 362 | finetune_using_ground_truth_label=using_ground_truth) 363 | outputs["index"] = i 364 | outputs_collections.append(outputs) 365 | 366 | end_time = time.time() 367 | print(f"#{len(outputs_collections)}/{len(outputs_collections)}: " 368 | f"Elapsed {(end_time - start_time) / 60:.2f}") 369 | 370 | torch.save( 371 | outputs_collections, 372 | f"imiator_experiments.{mode}.{num_examples_to_test}.pt") 373 | 374 | return outputs_collections 375 | 376 | 377 | def run_one_imitator_experiment( 378 | task_model: torch.nn.Module, 379 | imitator_model: torch.nn.Module, 380 | test_inputs, 381 | trainer: transformers.Trainer, 382 | train_dataset: torch.utils.data.Dataset, 383 | train_inputs_collections: List, 384 | inputs_by_label: Dict[str, List[int]], 385 | sample_size: int = 10, 386 | num_nearest_neighbors: int = 10000, 387 | finetune_using_ground_truth_label: bool = False 388 | ) -> Dict[str, Any]: 389 | 390 | imitator_test_inputs = _make_imitator_inputs( 391 | trainer=trainer, task_model=task_model, inputs=test_inputs) 392 | # if labels[0] != logits.argmax(axis=1)[0]: 393 | # break 394 | faiss_index = influence_helpers.load_faiss_index( 395 | trained_on_task_name="mnli", 396 | train_task_name="mnli") 397 | 398 | s_test_damp, s_test_scale, s_test_num_samples = influence_helpers.select_s_test_config( 399 | trained_on_task_name="mnli", 400 | train_task_name="mnli", 401 | eval_task_name="mnli") 402 | 403 | influences = influence_helpers.compute_influences_simplified( 404 | k=num_nearest_neighbors, 405 | faiss_index=faiss_index, 406 | model=task_model, 407 | inputs=test_inputs, 408 | train_dataset=train_dataset, 409 | use_parallel=False, 410 | s_test_damp=s_test_damp, 411 | s_test_scale=s_test_scale, 412 | s_test_num_samples=s_test_num_samples) 413 | 414 | data_indices = ( 415 | np.random.choice(inputs_by_label["neutral"], 416 | size=sample_size, 417 | replace=False).tolist() + # noqa 418 | np.random.choice(inputs_by_label["entailment"], 419 | size=sample_size, 420 | replace=False).tolist() + # noqa 421 | np.random.choice(inputs_by_label["contradiction"], 422 | size=sample_size, 423 | replace=False).tolist() + # noqa 424 | misc_utils.sort_dict_keys_by_vals(influences)[:sample_size] + # noqa 425 | misc_utils.sort_dict_keys_by_vals(influences)[-sample_size:] 426 | ) 427 | 428 | data_tags = ( 429 | ["random-neutral" for _ in range(sample_size)] + # noqa 430 | ["random-entailment" for _ in range(sample_size)] + # noqa 431 | ["random-contradiction" for _ in range(sample_size)] + # noqa 432 | ["most-negative-influential" for _ in range(sample_size)] + # noqa 433 | ["most-positive-influential" for _ in range(sample_size)] 434 | ) 435 | 436 | learning_rates = np.logspace(-5, -2.5, 50) 437 | losses = compute_new_imitator_losses( 438 | trainer=trainer, 439 | tags=data_tags, 440 | indices=data_indices, 441 | task_model=task_model, 442 | imitator_model=imitator_model, 443 | learning_rates=learning_rates, 444 | imitator_test_inputs=imitator_test_inputs, 445 | train_inputs_collections=train_inputs_collections, 446 | finetune_using_ground_truth_label=finetune_using_ground_truth_label) 447 | 448 | return { 449 | "losses": losses, 450 | "influences": influences, 451 | "test_inputs": test_inputs, 452 | "learning_rates": learning_rates, 453 | "imitator_test_inputs": imitator_test_inputs, 454 | "finetune_using_ground_truth_label": finetune_using_ground_truth_label, 455 | } 456 | 457 | 458 | def compute_new_imitator_losses( 459 | indices: List[int], 460 | tags: List[str], 461 | task_model: torch.nn.Module, 462 | imitator_model: torch.nn.Module, 463 | trainer: transformers.Trainer, 464 | learning_rates: Union[np.ndarray, List[float]], 465 | imitator_test_inputs: Dict[str, torch.Tensor], 466 | train_inputs_collections: List[Dict[str, torch.Tensor]], 467 | finetune_using_ground_truth_label: bool = False, 468 | ) -> Dict[str, List[List[float]]]: 469 | 470 | params_filter = [ 471 | n for n, p in imitator_model.named_parameters() 472 | if not p.requires_grad] 473 | 474 | weight_decay_ignores = [ 475 | "bias", 476 | "LayerNorm.weight"] + [ 477 | n for n, p in imitator_model.named_parameters() 478 | if not p.requires_grad] 479 | 480 | losses = defaultdict(list) 481 | for index, tag in zip(tqdm(indices), tags): 482 | if finetune_using_ground_truth_label is True: 483 | imitator_train_inputs = train_inputs_collections[index] 484 | else: 485 | imitator_train_inputs = _make_imitator_inputs( 486 | trainer=trainer, 487 | task_model=task_model, 488 | inputs=train_inputs_collections[index]) 489 | 490 | _losses = [] 491 | gradients_z = None 492 | for lr in learning_rates: 493 | # Re-use `gradients_z` 494 | new_imitator_model, gradients_z = hans.pseudo_gradient_step( 495 | model=imitator_model, 496 | inputs=imitator_train_inputs, 497 | learning_rate=lr, 498 | params_filter=params_filter, 499 | weight_decay_ignores=weight_decay_ignores, 500 | precomputed_gradients_z=gradients_z) 501 | _, _, imitator_loss = misc_utils.predict( 502 | trainer=trainer, 503 | model=new_imitator_model, 504 | inputs=imitator_test_inputs) 505 | _losses.append(imitator_loss) 506 | 507 | losses[tag].append(_losses) 508 | 509 | return losses 510 | 511 | 512 | def _make_imitator_inputs( 513 | trainer: transformers.Trainer, 514 | task_model: torch.nn.Module, 515 | inputs: Dict[str, torch.Tensor], 516 | ) -> Dict[str, torch.Tensor]: 517 | logits, _, _ = misc_utils.predict( 518 | trainer=trainer, model=task_model, inputs=inputs) 519 | imitator_inputs = deepcopy(inputs) 520 | imitator_inputs["labels"] = torch.tensor(logits.argmax(axis=1)) 521 | return imitator_inputs 522 | 523 | 524 | def plot_Xs_and_Ys_dict( 525 | Xs: List[float], 526 | Ys_dict: Dict[str, List[List[float]]] 527 | ) -> None: 528 | # plt.rcParams["figure.figsize"] = (10, 10) 529 | color_map = { 530 | "random-neutral": "grey", 531 | "random-entailment": "salmon", 532 | "random-contradiction": "skyblue", 533 | "most-positive-influential": "darkred", 534 | "most-negative-influential": "steelblue"} 535 | 536 | legends = [] 537 | for tag in Ys_dict.keys(): 538 | if tag not in color_map.keys(): 539 | raise ValueError 540 | 541 | legends.append(tag) 542 | color = color_map[tag] 543 | data = np.array(Ys_dict[tag]) 544 | is_random_data_point = "random" in tag 545 | 546 | if data.shape[0] != 1: 547 | data_mean = data.mean(axis=0) 548 | data_max = data.max(axis=0) 549 | data_min = data.min(axis=0) 550 | plt.plot(Xs, data_mean, 551 | color=color, 552 | linestyle=("--" if is_random_data_point else None)) 553 | 554 | plt.fill_between(Xs, data_max, data_min, 555 | alpha=0.1, 556 | color=color) 557 | else: 558 | plt.plot(Xs, data[0, ...], color=color) 559 | 560 | plt.xscale("log") 561 | plt.yscale("log") 562 | plt.xlabel("learning rate", fontsize=30) 563 | plt.ylabel("Loss", fontsize=30) 564 | plt.legend(legends, fontsize=15) 565 | plt.title("Loss of the Imitator Model", fontsize=30) 566 | # plt.savefig("./20200719-fig1.pdf") 567 | -------------------------------------------------------------------------------- /experiments/mnli_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import os 7 | import torch 8 | import shutil 9 | import pandas as pd 10 | from transformers import ( 11 | BertTokenizer, 12 | InputFeatures, 13 | default_data_collator) 14 | from typing import Tuple, Optional, Union, List, Dict 15 | from experiments import constants 16 | 17 | 18 | def decode_one_example( 19 | tokenizer: BertTokenizer, 20 | label_list: List[str], 21 | inputs: Dict[str, torch.Tensor], 22 | logits: Optional[torch.FloatTensor] = None 23 | ) -> Union[Tuple[str, str], Tuple[str, str, str]]: 24 | 25 | if inputs["input_ids"].shape[0] != 1: 26 | raise ValueError 27 | 28 | X = tokenizer.decode(inputs["input_ids"][0]) 29 | Y = label_list[inputs["labels"].item()] 30 | if logits is not None: 31 | _Y_hat = logits.argmax(dim=-1).item() 32 | Y_hat = label_list[_Y_hat] 33 | return X, Y, Y_hat 34 | else: 35 | return X, Y 36 | 37 | 38 | def visualize(tokenizer: BertTokenizer, 39 | label_list: List[str], 40 | inputs: Dict[str, torch.Tensor],) -> None: 41 | X, Y = decode_one_example( 42 | tokenizer=tokenizer, 43 | label_list=label_list, 44 | inputs=inputs, 45 | logits=None) 46 | premise, hypothesis = X.split("[CLS]")[1].split("[SEP]")[:2] 47 | print(f"\tP: {premise.strip()}\n\tH: {hypothesis.strip()}\n\tL: {Y}") 48 | 49 | 50 | def get_data_from_features_or_inputs( 51 | tokenizer: BertTokenizer, 52 | label_list: List[str], 53 | feature: Optional[InputFeatures] = None, 54 | inputs: Optional[Dict[str, torch.Tensor]] = None, 55 | ) -> Tuple[str, str, str]: 56 | 57 | if feature is not None and inputs is None: 58 | inputs = default_data_collator([feature]) 59 | 60 | elif feature is None and inputs is not None: 61 | pass 62 | 63 | elif feature is None and inputs is None: 64 | raise ValueError 65 | 66 | elif feature is not None and inputs is not None: 67 | raise ValueError 68 | 69 | X, Y = decode_one_example( 70 | tokenizer=tokenizer, 71 | label_list=label_list, 72 | inputs=inputs, 73 | logits=None) 74 | premise, hypothesis = X.split("[CLS]")[1].split("[SEP]")[:2] 75 | return premise.strip(), hypothesis.strip(), Y 76 | 77 | 78 | def create_one_set_of_data_for_retraining( 79 | dir_name: str, 80 | indices_to_remove: List[int], 81 | ) -> None: 82 | """Create the training data and evaluation data 83 | 84 | 1. Load the training data, remove lines based in inputs, and write. 85 | 2. Copy the evaluation data into the same directory. 86 | 87 | """ 88 | with open(constants.MNLI_TRAIN_FILE_NAME) as f: 89 | lines = f.readlines() 90 | 91 | if not os.path.isdir(dir_name): 92 | os.makedirs(dir_name) 93 | else: 94 | raise ValueError 95 | 96 | with open(os.path.join(dir_name, "train.tsv"), "w") as f: 97 | lines_to_write = [ 98 | # "-1" because of the header line 99 | l for i, l in enumerate(lines) 100 | if i - 1 not in indices_to_remove] 101 | 102 | f.write("".join(lines_to_write)) 103 | print(f"Wrote {len(lines_to_write)} to {dir_name}") 104 | 105 | shutil.copyfile(constants.MNLI_EVAL_MATCHED_FILE_NAME, 106 | os.path.join(dir_name, "dev_matched.tsv")) 107 | shutil.copyfile(constants.MNLI_EVAL_MISMATCHED_FILE_NAME, 108 | os.path.join(dir_name, "dev_mismatched.tsv")) 109 | 110 | 111 | def get_label_to_indices_map() -> Dict[str, List[int]]: 112 | with open(constants.MNLI_TRAIN_FILE_NAME) as f: 113 | lines = f.readlines() 114 | 115 | data_frame = pd.DataFrame( 116 | [line.strip().split("\t") for line in lines[1:]], 117 | columns=lines[0].strip().split("\t")) 118 | 119 | return { 120 | "contradiction": ( 121 | data_frame[data_frame.gold_label == "contradiction"].index).tolist(), 122 | "entailment": ( 123 | data_frame[data_frame.gold_label == "entailment"].index).tolist(), 124 | "neutral": ( 125 | data_frame[data_frame.gold_label == "neutral"].index).tolist(), 126 | } 127 | 128 | 129 | def get_label_to_indices_map_2() -> Dict[str, List[int]]: 130 | """Slower, deprecated""" 131 | contradiction_indices = [] 132 | entailment_indices = [] 133 | neutral_indices = [] 134 | train_inputs_collections = torch.load(constants.MNLI_TRAIN_INPUT_COLLECTIONS_PATH) 135 | for index, train_inputs in enumerate(train_inputs_collections): 136 | if train_inputs["labels"].item() == 0: 137 | contradiction_indices.append(index) 138 | if train_inputs["labels"].item() == 1: 139 | entailment_indices.append(index) 140 | if train_inputs["labels"].item() == 2: 141 | neutral_indices.append(index) 142 | 143 | return { 144 | "contradiction": contradiction_indices, 145 | "entailment": entailment_indices, 146 | "neutral": neutral_indices, 147 | } 148 | -------------------------------------------------------------------------------- /experiments/s_test_speedup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import sys 7 | import torch 8 | import transformers 9 | import numpy as np 10 | from contexttimer import Timer 11 | from typing import List, Dict, Any 12 | from transformers import GlueDataset 13 | from transformers import TrainingArguments 14 | from transformers import default_data_collator 15 | 16 | from influence_utils import parallel 17 | from influence_utils import faiss_utils 18 | from influence_utils import nn_influence_utils 19 | from influence_utils.nn_influence_utils import compute_s_test 20 | from experiments import constants 21 | from experiments import misc_utils 22 | from experiments import remote_utils 23 | from experiments.data_utils import ( 24 | glue_output_modes, 25 | glue_compute_metrics) 26 | 27 | 28 | def one_experiment( 29 | model: torch.nn.Module, 30 | train_dataset: GlueDataset, 31 | test_inputs: Dict[str, torch.Tensor], 32 | batch_size: int, 33 | random: bool, 34 | n_gpu: int, 35 | device: torch.device, 36 | damp: float, 37 | scale: float, 38 | num_samples: int, 39 | ) -> List[torch.Tensor]: 40 | 41 | params_filter = [ 42 | n for n, p in model.named_parameters() 43 | if not p.requires_grad] 44 | 45 | weight_decay_ignores = [ 46 | "bias", 47 | "LayerNorm.weight"] + [ 48 | n for n, p in model.named_parameters() 49 | if not p.requires_grad] 50 | 51 | # Make sure each dataloader is re-initialized 52 | batch_train_data_loader = misc_utils.get_dataloader( 53 | dataset=train_dataset, 54 | batch_size=batch_size, 55 | random=random) 56 | 57 | s_test = compute_s_test( 58 | n_gpu=n_gpu, 59 | device=device, 60 | model=model, 61 | test_inputs=test_inputs, 62 | train_data_loaders=[batch_train_data_loader], 63 | params_filter=params_filter, 64 | weight_decay=constants.WEIGHT_DECAY, 65 | weight_decay_ignores=weight_decay_ignores, 66 | damp=damp, 67 | scale=scale, 68 | num_samples=num_samples) 69 | 70 | return [X.cpu() for X in s_test] 71 | 72 | 73 | def main( 74 | mode: str, 75 | num_examples_to_test: int = 5, 76 | num_repetitions: int = 4, 77 | ) -> List[Dict[str, Any]]: 78 | 79 | if mode not in ["only-correct", "only-incorrect"]: 80 | raise ValueError(f"Unrecognized mode {mode}") 81 | 82 | task_tokenizer, task_model = misc_utils.create_tokenizer_and_model( 83 | constants.MNLI_MODEL_PATH) 84 | train_dataset, eval_dataset = misc_utils.create_datasets( 85 | task_name="mnli", 86 | tokenizer=task_tokenizer) 87 | eval_instance_data_loader = misc_utils.get_dataloader( 88 | dataset=eval_dataset, 89 | batch_size=1, 90 | random=False) 91 | 92 | output_mode = glue_output_modes["mnli"] 93 | 94 | def build_compute_metrics_fn(task_name: str): 95 | def compute_metrics_fn(p): 96 | if output_mode == "classification": 97 | preds = np.argmax(p.predictions, axis=1) 98 | elif output_mode == "regression": 99 | preds = np.squeeze(p.predictions) 100 | return glue_compute_metrics(task_name, preds, p.label_ids) 101 | 102 | return compute_metrics_fn 103 | 104 | # Most of these arguments are placeholders 105 | # and are not really used at all, so ignore 106 | # the exact values of these. 107 | trainer = transformers.Trainer( 108 | model=task_model, 109 | args=TrainingArguments( 110 | output_dir="./tmp-output", 111 | per_device_train_batch_size=128, 112 | per_device_eval_batch_size=128, 113 | learning_rate=5e-5, 114 | logging_steps=100), 115 | data_collator=default_data_collator, 116 | train_dataset=train_dataset, 117 | eval_dataset=eval_dataset, 118 | compute_metrics=build_compute_metrics_fn("mnli"), 119 | ) 120 | 121 | task_model.cuda() 122 | num_examples_tested = 0 123 | output_collections = [] 124 | for test_index, test_inputs in enumerate(eval_instance_data_loader): 125 | if num_examples_tested >= num_examples_to_test: 126 | break 127 | 128 | # Skip when we only want cases of correction prediction but the 129 | # prediction is incorrect, or vice versa 130 | prediction_is_correct = misc_utils.is_prediction_correct( 131 | trainer=trainer, 132 | model=task_model, 133 | inputs=test_inputs) 134 | 135 | if mode == "only-correct" and prediction_is_correct is False: 136 | continue 137 | 138 | if mode == "only-incorrect" and prediction_is_correct is True: 139 | continue 140 | 141 | for k, v in test_inputs.items(): 142 | if isinstance(v, torch.Tensor): 143 | test_inputs[k] = v.to(torch.device("cuda")) 144 | 145 | # with batch-size 128, 1500 iterations is enough 146 | for num_samples in range(700, 1300 + 1, 100): # 7 choices 147 | for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: # 8 choices 148 | for repetition in range(num_repetitions): 149 | print(f"Running #{test_index} " 150 | f"N={num_samples} " 151 | f"B={batch_size} " 152 | f"R={repetition} takes ...", end=" ") 153 | with Timer() as timer: 154 | s_test = one_experiment( 155 | model=task_model, 156 | train_dataset=train_dataset, 157 | test_inputs=test_inputs, 158 | batch_size=batch_size, 159 | random=True, 160 | n_gpu=1, 161 | device=torch.device("cuda"), 162 | damp=constants.DEFAULT_INFLUENCE_HPARAMS["mnli"]["mnli"]["damp"], 163 | scale=constants.DEFAULT_INFLUENCE_HPARAMS["mnli"]["mnli"]["scale"], 164 | num_samples=num_samples) 165 | time_elapsed = timer.elapsed 166 | print(f"{time_elapsed:.2f} seconds") 167 | 168 | outputs = { 169 | "test_index": test_index, 170 | "num_samples": num_samples, 171 | "batch_size": batch_size, 172 | "repetition": repetition, 173 | "s_test": s_test, 174 | "time_elapsed": time_elapsed, 175 | "correct": prediction_is_correct, 176 | } 177 | output_collections.append(outputs) 178 | remote_utils.save_and_mirror_scp_to_remote( 179 | object_to_save=outputs, 180 | file_name=f"stest.{mode}.{num_examples_to_test}." 181 | f"{test_index}.{num_samples}." 182 | f"{batch_size}.{repetition}.pth") 183 | 184 | num_examples_tested += 1 185 | 186 | return output_collections 187 | -------------------------------------------------------------------------------- /experiments/visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from tqdm import tqdm, trange 9 | import matplotlib.pyplot as plt 10 | from scipy.optimize import minimize 11 | from matplotlib.axes._subplots import Subplot 12 | # from graph_tool.draw import graph_draw 13 | # from joblib import Parallel, delayed 14 | 15 | from typing import List, Dict, Tuple, Optional, Union, Callable, Any 16 | from experiments.visualization_utils import ( 17 | get_circle_coordinates, 18 | get_within_circle_constraint, 19 | # distance_to_points_on_circle, 20 | # distance_to_points_within_circle, 21 | distance_to_points_within_circle_vectorized) 22 | from experiments import constants 23 | from experiments import misc_utils 24 | from experiments import remote_utils 25 | from experiments import influence_helpers 26 | from experiments.hans_utils import HansHelper 27 | from transformers import Trainer, TrainingArguments 28 | 29 | 30 | try: 31 | import graph_tool as gt 32 | gt_Graph_t = gt.Graph 33 | except ModuleNotFoundError: 34 | # We do not need `graph_tool` unless 35 | # visualization is to be created 36 | gt = None 37 | gt_Graph_t = "gt.Graph" 38 | 39 | DEFAULT_KNN_K = 1000 40 | DEFAULT_TRAIN_VERTEX_COLOR = 0 41 | DEFAULT_TRAIN_VERTEX_RADIUS = 2 42 | DEFAULT_EVAL_VERTEX_COLORS_BASE = 2 43 | DEFAULT_EVAL_VERTEX_RADIUS = 3 44 | DEFAULT_HELPFUL_EDGE_COLOR = 0 45 | DEFAULT_HARMFUL_EDGE_COLOR = 1 46 | DEFAULT_TRAIN_VERTEX_SIZE = 3 47 | 48 | 49 | def main( 50 | mode: str, 51 | train_task_name: str, 52 | eval_task_name: str, 53 | num_eval_to_collect: int, 54 | use_parallel: bool = True, 55 | kNN_k: Optional[int] = None, 56 | hans_heuristic: Optional[str] = None, 57 | trained_on_task_name: Optional[str] = None, 58 | ) -> List[Dict[str, Union[int, Dict[int, float]]]]: 59 | 60 | if train_task_name not in ["mnli", "mnli-2", "hans"]: 61 | raise ValueError 62 | 63 | if eval_task_name not in ["mnli", "mnli-2", "hans"]: 64 | raise ValueError 65 | 66 | if trained_on_task_name is None: 67 | # The task the model was trained on 68 | # can be different from `train_task_name` 69 | # which is used to determine on which the 70 | # influence values will be computed. 71 | trained_on_task_name = train_task_name 72 | 73 | if trained_on_task_name not in ["mnli", "mnli-2", "hans"]: 74 | raise ValueError 75 | 76 | if mode not in ["only-correct", "only-incorrect"]: 77 | raise ValueError(f"Unrecognized mode {mode}") 78 | 79 | if kNN_k is None: 80 | kNN_k = DEFAULT_KNN_K 81 | 82 | # `trained_on_task_name` determines the model to load 83 | if trained_on_task_name in ["mnli"]: 84 | tokenizer, model = misc_utils.create_tokenizer_and_model( 85 | constants.MNLI_MODEL_PATH) 86 | 87 | if trained_on_task_name in ["mnli-2"]: 88 | tokenizer, model = misc_utils.create_tokenizer_and_model( 89 | constants.MNLI2_MODEL_PATH) 90 | 91 | if trained_on_task_name in ["hans"]: 92 | tokenizer, model = misc_utils.create_tokenizer_and_model( 93 | constants.HANS_MODEL_PATH) 94 | 95 | train_dataset, _ = misc_utils.create_datasets( 96 | task_name=train_task_name, 97 | tokenizer=tokenizer) 98 | 99 | _, eval_dataset = misc_utils.create_datasets( 100 | task_name=eval_task_name, 101 | tokenizer=tokenizer) 102 | 103 | faiss_index = influence_helpers.load_faiss_index( 104 | trained_on_task_name=trained_on_task_name, 105 | train_task_name=train_task_name) 106 | 107 | trainer = Trainer( 108 | model=model, 109 | args=TrainingArguments( 110 | output_dir="./tmp-output", 111 | per_device_train_batch_size=128, 112 | per_device_eval_batch_size=128, 113 | learning_rate=5e-5, 114 | logging_steps=100), 115 | train_dataset=train_dataset, 116 | eval_dataset=eval_dataset, 117 | ) 118 | 119 | if eval_task_name in ["mnli", "mnli-2"]: 120 | eval_instance_data_loader = misc_utils.get_dataloader( 121 | dataset=eval_dataset, 122 | batch_size=1, 123 | random=False) 124 | 125 | if eval_task_name in ["hans"]: 126 | if hans_heuristic is None: 127 | raise ValueError("`hans_heuristic` cannot be None for now") 128 | 129 | hans_helper = HansHelper( 130 | hans_train_dataset=None, 131 | hans_eval_dataset=eval_dataset) 132 | 133 | _, eval_instance_data_loader = hans_helper.get_dataset_and_dataloader_of_heuristic( 134 | mode="eval", 135 | heuristic=hans_heuristic, 136 | batch_size=1, 137 | random=False) 138 | 139 | # Data-points where the model got wrong 140 | correct_input_collections = [] 141 | incorrect_input_collections = [] 142 | for index, test_inputs in enumerate(eval_instance_data_loader): 143 | logits, labels, step_eval_loss = misc_utils.predict( 144 | trainer=trainer, 145 | model=model, 146 | inputs=test_inputs) 147 | if logits.argmax(axis=-1).item() != labels.item(): 148 | incorrect_input_collections.append((index, test_inputs)) 149 | else: 150 | correct_input_collections.append((index, test_inputs)) 151 | 152 | if mode == "only-incorrect": 153 | input_collections = incorrect_input_collections 154 | else: 155 | input_collections = correct_input_collections 156 | 157 | # Other settings are not supported as of now 158 | (s_test_damp, 159 | s_test_scale, 160 | s_test_num_samples) = influence_helpers.select_s_test_config( 161 | trained_on_task_name=trained_on_task_name, 162 | train_task_name=train_task_name, 163 | eval_task_name=eval_task_name) 164 | 165 | influences_collections = [] 166 | for index, inputs in input_collections[:num_eval_to_collect]: 167 | print(f"#{index}") 168 | influences = influence_helpers.compute_influences_simplified( 169 | k=kNN_k, 170 | faiss_index=faiss_index, 171 | model=model, 172 | inputs=inputs, 173 | train_dataset=train_dataset, 174 | use_parallel=use_parallel, 175 | s_test_damp=s_test_damp, 176 | s_test_scale=s_test_scale, 177 | s_test_num_samples=s_test_num_samples, 178 | device_ids=[0, 1, 2, 3], 179 | precomputed_s_test=None) 180 | 181 | influences_collections.append({ 182 | "index": index, 183 | "influences": influences, 184 | }) 185 | 186 | remote_utils.save_and_mirror_scp_to_remote( 187 | object_to_save=influences_collections, 188 | file_name=( 189 | f"visualization" 190 | f".{mode}.{num_eval_to_collect}" 191 | f".{train_task_name}-{eval_task_name}" 192 | f"-{hans_heuristic}-{trained_on_task_name}" 193 | f".{kNN_k}.{use_parallel}.pth")) 194 | 195 | return influences_collections 196 | 197 | 198 | def run_experiments(option: str) -> List[List[Dict[int, float]]]: 199 | if option == "mnli2_and_hans": 200 | mnli2_influences = main( 201 | mode="only-incorrect", 202 | train_task_name="mnli-2", 203 | eval_task_name="mnli-2", 204 | num_eval_to_collect=100) 205 | 206 | hans_influences = main( 207 | mode="only-incorrect", 208 | train_task_name="mnli-2", 209 | eval_task_name="hans", 210 | num_eval_to_collect=100) 211 | 212 | return [mnli2_influences, hans_influences] 213 | 214 | if option == "mnli_and_hans_heuristic": 215 | hans_influences_collections = [] 216 | for hans_heuristic in ["lexical_overlap", "subsequence", "constituent"]: 217 | hans_influences = main( 218 | mode="only-incorrect", 219 | train_task_name="mnli-2", 220 | eval_task_name="hans", 221 | num_eval_to_collect=100, 222 | hans_heuristic=hans_heuristic) 223 | 224 | hans_influences_collections.append(hans_influences) 225 | 226 | return hans_influences_collections 227 | 228 | if option == "hans_and_hans_heuristic": 229 | hans_influences_collections = [] 230 | for hans_heuristic in ["lexical_overlap", "subsequence", "constituent"]: 231 | hans_influences = main( 232 | mode="only-incorrect", 233 | train_task_name="hans", 234 | eval_task_name="hans", 235 | num_eval_to_collect=100, 236 | hans_heuristic=hans_heuristic) 237 | 238 | hans_influences_collections.append(hans_influences) 239 | 240 | return hans_influences_collections 241 | 242 | raise ValueError(f"Unrecognized `option` {option}") 243 | 244 | 245 | def get_datapoints_map( 246 | influences_collections: List[Dict[int, float]] 247 | ) -> Tuple[List[int], Dict[int, int]]: 248 | possible_datapoints = [] 249 | for influences in influences_collections: 250 | possible_datapoints.extend(list(influences.keys())) 251 | 252 | possible_datapoints = sorted(set(possible_datapoints)) 253 | datapoints_map = dict((v, k) for k, v in enumerate(possible_datapoints)) 254 | return possible_datapoints, datapoints_map 255 | 256 | 257 | def get_graph( 258 | influences_collections_list: List[List[Dict[int, float]]], 259 | train_vertex_color_map_fn: Optional[Callable[[int], int]] = None, 260 | train_vertex_radius_map_fn: Optional[Callable[[int], int]] = None, 261 | eval_vertex_radius: Optional[int] = None, 262 | eval_vertex_color_base: Optional[int] = None, 263 | ) -> gt_Graph_t: 264 | 265 | if train_vertex_color_map_fn is None: 266 | def train_vertex_color_map_fn(index: int) -> int: 267 | return DEFAULT_TRAIN_VERTEX_COLOR 268 | 269 | if train_vertex_radius_map_fn is None: 270 | def train_vertex_radius_map_fn(index: int) -> int: 271 | return DEFAULT_TRAIN_VERTEX_RADIUS 272 | 273 | if eval_vertex_radius is None: 274 | eval_vertex_radius = DEFAULT_EVAL_VERTEX_RADIUS 275 | 276 | if eval_vertex_color_base is None: 277 | eval_vertex_color_base = DEFAULT_EVAL_VERTEX_COLORS_BASE 278 | 279 | if train_vertex_color_map_fn is None: 280 | raise ValueError 281 | 282 | if train_vertex_radius_map_fn is None: 283 | raise ValueError 284 | 285 | NUM_INFLUENCE_COLLECTIONS = len(influences_collections_list) 286 | influences_collections_list_flatten = [] 287 | for influences_collections in influences_collections_list: 288 | # Assume they all have the same lengths 289 | if len(influences_collections_list[0][0]) != len(influences_collections[0]): 290 | raise ValueError 291 | influences_collections_list_flatten.extend(influences_collections) 292 | 293 | # Note they share the same training dataset 294 | possible_datapoints, datapoints_map = get_datapoints_map( 295 | influences_collections=influences_collections_list_flatten) 296 | 297 | g = gt.Graph(directed=True) 298 | # Edge properties 299 | e_colors = g.new_edge_property("int") 300 | e_weights = g.new_edge_property("double") 301 | e_signed_influences = g.new_edge_property("double") 302 | e_unsigned_influences = g.new_edge_property("double") 303 | # Vertex properties 304 | v_sizes = g.new_vertex_property("int") 305 | v_colors = g.new_vertex_property("int") 306 | v_radius = g.new_vertex_property("int") 307 | v_data_indices = g.new_vertex_property("string") 308 | v_positions = g.new_vertex_property("vector") 309 | v_positive_positions = g.new_vertex_property("vector") 310 | v_negative_positions = g.new_vertex_property("vector") 311 | 312 | train_vertices = [] 313 | eval_vertices_collections = [] 314 | 315 | # Add train vertices 316 | for datapoint_index in trange(len(possible_datapoints)): 317 | v = g.add_vertex() 318 | v_sizes[v] = DEFAULT_TRAIN_VERTEX_SIZE 319 | v_colors[v] = train_vertex_color_map_fn( 320 | possible_datapoints[datapoint_index]) 321 | v_radius[v] = train_vertex_radius_map_fn( 322 | possible_datapoints[datapoint_index]) 323 | v_data_indices[v] = f"train-{possible_datapoints[datapoint_index]}" 324 | train_vertices.append(v) 325 | 326 | # Add eval vertices 327 | for i, influences_collections in enumerate(influences_collections_list): 328 | 329 | eval_vertices = [] 330 | for datapoint_index in trange(len(influences_collections)): 331 | v = g.add_vertex() 332 | v_sizes[v] = 10 333 | v_colors[v] = eval_vertex_color_base + i 334 | v_radius[v] = eval_vertex_radius 335 | v_data_indices[v] = f"eval-{i}-{datapoint_index}" 336 | 337 | base_degree = (360 / NUM_INFLUENCE_COLLECTIONS) * i 338 | fine_degree = (360 / NUM_INFLUENCE_COLLECTIONS / len(influences_collections)) * datapoint_index 339 | x_y_coordinate = get_circle_coordinates( 340 | r=eval_vertex_radius, 341 | degree=base_degree + fine_degree) 342 | position = np.random.normal(x_y_coordinate, 0.1) 343 | v_positions[v] = position 344 | v_positive_positions[v] = position 345 | v_negative_positions[v] = position 346 | eval_vertices.append(v) 347 | 348 | eval_vertices_collections.append(eval_vertices) 349 | 350 | # Add edges 351 | def add_edges(influences_collections: List[Dict[int, float]], 352 | eval_vertices: List[gt.Vertex]) -> None: 353 | for eval_index, influences in enumerate(tqdm(influences_collections)): 354 | for train_index, train_influence in influences.items(): 355 | # Negative influence is helpful (when the prediction is wrong) 356 | if train_influence < 0.0: 357 | train_vertex = train_vertices[datapoints_map[train_index]] 358 | eval_vertex = eval_vertices[eval_index] 359 | e = g.add_edge(train_vertex, eval_vertex) 360 | e_colors[e] = DEFAULT_HELPFUL_EDGE_COLOR 361 | e_weights[e] = np.abs(train_influence) 362 | e_signed_influences[e] = train_influence 363 | e_unsigned_influences[e] = np.abs(train_influence) 364 | else: 365 | train_vertex = train_vertices[datapoints_map[train_index]] 366 | eval_vertex = eval_vertices[eval_index] 367 | e = g.add_edge(train_vertex, eval_vertex) 368 | e_colors[e] = DEFAULT_HARMFUL_EDGE_COLOR 369 | e_weights[e] = np.abs(train_influence) 370 | e_signed_influences[e] = train_influence 371 | e_unsigned_influences[e] = np.abs(train_influence) 372 | 373 | for i, influences_collections in enumerate(influences_collections_list): 374 | add_edges(influences_collections, eval_vertices_collections[i]) 375 | 376 | def _calculate_position(train_vertex: gt.Vertex) -> None: 377 | """Determine X-axis and Y-axis 378 | - We use X-axis to determine the divergence 379 | - We use Y-axis to determine the helpfulness/harmfulness 380 | """ 381 | # Two types of targets 382 | # two types of connections 383 | _positive_points = [] 384 | _negative_points = [] 385 | _positive_influences = [] 386 | _negative_influences = [] 387 | for e in train_vertex.all_edges(): 388 | target = e.target() 389 | if e_signed_influences[e] > 0: 390 | _positive_points.append(v_positions[target]) 391 | _positive_influences.append(e_unsigned_influences[e]) 392 | else: 393 | _negative_points.append(v_positions[target]) 394 | _negative_influences.append(e_unsigned_influences[e]) 395 | 396 | # `minimize` might fail using `np.sqrt(2)` for some reasons :\ 397 | bound = 1.4 * v_radius[train_vertex] 398 | constraints = ({ 399 | "type": "ineq", 400 | "fun": get_within_circle_constraint(v_radius[train_vertex]) 401 | }) 402 | 403 | if len(_positive_influences) == 0: 404 | _positive_xval = 0.0 405 | _positive_yval = 0.0 406 | else: 407 | _positive_points_stacked = np.stack(_positive_points, axis=0) 408 | _positive_influences_stacked = np.stack(_positive_influences, axis=0) 409 | _positive_optimize_result = minimize( 410 | distance_to_points_within_circle_vectorized, 411 | x0=(0, 0), 412 | constraints=constraints, 413 | bounds=((-bound, bound), (-bound, bound)), 414 | args=(_positive_influences_stacked, 415 | _positive_points_stacked)) 416 | _positive_xval, _positive_yval = _positive_optimize_result.x 417 | 418 | if len(_negative_influences) == 0: 419 | _negative_xval = 0.0 420 | _negative_yval = 0.0 421 | else: 422 | _negative_points_stacked = np.stack(_negative_points, axis=0) 423 | _negative_influences_stacked = np.stack(_negative_influences, axis=0) 424 | _negative_optimize_result = minimize( 425 | distance_to_points_within_circle_vectorized, 426 | x0=(0, 0), 427 | constraints=constraints, 428 | bounds=((-bound, bound), (-bound, bound)), 429 | args=(_negative_influences_stacked, 430 | _negative_points_stacked)) 431 | _negative_xval, _negative_yval = _negative_optimize_result.x 432 | 433 | _positive_xval = np.random.normal(_positive_xval, 0.01) 434 | _negative_xval = np.random.normal(_negative_xval, 0.01) 435 | _positive_yval = np.random.normal(_positive_yval, 0.01) 436 | _negative_yval = np.random.normal(_negative_yval, 0.01) 437 | v_positive_positions[train_vertex] = np.array([_positive_xval, _positive_yval]) 438 | v_negative_positions[train_vertex] = np.array([_negative_xval, _negative_yval]) 439 | v_positions[train_vertex] = np.array([(_positive_xval + _negative_xval) / 2, 440 | (_positive_yval + _negative_yval) / 2]) 441 | 442 | # Run them in parallel 443 | # Parallel(n_jobs=-1)( 444 | # delayed(_calculate_position)(train_vertex) 445 | # for train_vertex in tqdm(train_vertices)) 446 | for train_vertex in tqdm(train_vertices): 447 | _calculate_position(train_vertex) 448 | 449 | # Assign Edge properties 450 | g.edge_properties["colors"] = e_colors 451 | g.edge_properties["weights"] = e_weights 452 | g.edge_properties["signed_influences"] = e_signed_influences 453 | g.edge_properties["unsigned_influences"] = e_unsigned_influences 454 | # Assign Vertex properties 455 | g.vertex_properties["sizes"] = v_sizes 456 | g.vertex_properties["colors"] = v_colors 457 | g.vertex_properties["radius"] = v_radius 458 | g.vertex_properties["data_indices"] = v_data_indices 459 | g.vertex_properties["positions"] = v_positions 460 | g.vertex_properties["positive_positions"] = v_positive_positions 461 | g.vertex_properties["negative_positions"] = v_negative_positions 462 | 463 | return g, { 464 | "train_vertices": train_vertices, 465 | "eval_vertices_collections": eval_vertices_collections 466 | } 467 | 468 | 469 | def get_recall_plot(model, example, faiss_index, full_influences_dict): 470 | # plt.rcParams["figure.figsize"] = [20, 5] 471 | recall_num_neighbors = [10, 100, 1000] 472 | num_neighbors = [10, 100, 1000, 10000, 50000, 100000] 473 | names = ["Most Helpful", 474 | "Most Harmful", 475 | "Most Influencetial", 476 | "Least Influential"] 477 | 478 | features = misc_utils.compute_BERT_CLS_feature(model, **example) 479 | features = features.cpu().detach().numpy() 480 | if list(full_influences_dict.keys()) != list(range(len(full_influences_dict))): 481 | raise ValueError 482 | 483 | full_influences = [] 484 | for key in sorted(full_influences_dict): 485 | full_influences.append(full_influences_dict[key]) 486 | 487 | sorted_indices_small_to_large = np.argsort(full_influences) 488 | sorted_indices_large_to_small = np.argsort(full_influences)[::-1] 489 | sorted_indices_abs_large_to_small = np.argsort(np.abs(full_influences))[::-1] 490 | sorted_indices_abs_small_to_large = np.argsort(np.abs(full_influences)) 491 | 492 | fig, axes = plt.subplots(1, 4, sharex=True, sharey=True) 493 | recalls_collections = {} 494 | for i, (name, sorted_indices) in enumerate(zip( 495 | names, 496 | [sorted_indices_small_to_large, 497 | sorted_indices_large_to_small, 498 | sorted_indices_abs_large_to_small, 499 | sorted_indices_abs_small_to_large])): 500 | 501 | recalls_collection = [] 502 | for recall_k in tqdm(recall_num_neighbors): 503 | recalls = [] 504 | influential = sorted_indices[:recall_k] 505 | influential_set = set(influential.tolist()) 506 | for k in num_neighbors: 507 | distances, indices = faiss_index.search(k=k, queries=features) 508 | indices_set = set(indices.squeeze(axis=0).tolist()) 509 | recall = len(influential_set & indices_set) / len(influential_set) 510 | recalls.append(recall) 511 | 512 | recalls_collection.append(recalls) 513 | axes[i].plot(num_neighbors, recalls, 514 | linestyle="--", marker="o", 515 | label=f"recall@{recall_k}") 516 | 517 | axes[i].legend() 518 | axes[i].set_title(name) 519 | axes[i].set_xscale("log") 520 | axes[i].set_ylabel("Recall") 521 | axes[i].set_xlabel("Number of Nearest Neighbors") 522 | recalls_collections[name] = recalls_collection 523 | 524 | return recalls_collections 525 | 526 | 527 | def plot_Xs_and_Ys_dict( 528 | axis: Subplot, 529 | Xs: List[float], 530 | Ys_dict: Dict[str, List[List[float]]], 531 | title: str, 532 | xlabel: str, 533 | ylabel: str, 534 | xscale_log: bool = True, 535 | yscale_log: bool = True, 536 | output_file_name: Optional[str] = None, 537 | ) -> None: 538 | 539 | color_map = { 540 | "helpful-1": "lightskyblue", 541 | "helpful-10": "deepskyblue", 542 | "helpful-100": "dodgerblue", 543 | "harmful-1": "lightcoral", 544 | "harmful-10": "salmon", 545 | "harmful-100": "red", 546 | "random-1": "darkgrey", 547 | "random-10": "dimgrey", 548 | "random-100": "black", 549 | } 550 | 551 | for tag in Ys_dict.keys(): 552 | if tag not in color_map.keys(): 553 | raise ValueError 554 | 555 | color = color_map[tag] 556 | data = np.array(Ys_dict[tag]) 557 | is_random_data_point = "random" in tag 558 | # `data` should be [n, m] 559 | # where `n` is the number of independent trials 560 | # and `m` is the number of experiments within each trial 561 | if len(data.shape) != 2: 562 | raise ValueError(f"`data` should be an 2d array, {data.shape}") 563 | 564 | if data.shape[0] != 1: 565 | # i.e., it has multiple trials 566 | data_mean = data.mean(axis=0) 567 | data_max = data.max(axis=0) 568 | data_min = data.min(axis=0) 569 | # data_std = data.std(axis=0) 570 | axis.plot( 571 | Xs, 572 | data_mean, 573 | color=color, 574 | label=tag, 575 | linestyle=("--" if is_random_data_point else None)) 576 | 577 | axis.fill_between( 578 | Xs, 579 | data_max, 580 | data_min, 581 | alpha=0.25, 582 | color=color) 583 | else: 584 | # i.e., only one trial 585 | axis.plot( 586 | Xs, 587 | data[0, ...], 588 | color=color) 589 | 590 | if xscale_log is True: 591 | axis.set_xscale("log") 592 | 593 | if yscale_log is True: 594 | axis.set_yscale("log") 595 | 596 | axis.set_xlabel(xlabel, fontsize=30) 597 | axis.set_ylabel(ylabel, fontsize=30) 598 | axis.set_title(title, fontsize=30) 599 | axis.legend(fontsize=15) 600 | 601 | if output_file_name is not None: 602 | plt.savefig(output_file_name) 603 | 604 | 605 | def collect_edges_from_graph( 606 | g: gt.Graph, 607 | vertex_color_to_slice_map: Dict[int, str] 608 | ) -> Tuple[pd.DataFrame, List[Dict[str, Any]]]: 609 | 610 | edge_collections = [] 611 | edge_to_color_map = g.vertex_properties["colors"] 612 | edge_to_data_index_map = g.vertex_properties["data_indices"] 613 | for edge in tqdm(g.edges()): 614 | source_vertex = edge.source() 615 | target_vertex = edge.target() 616 | source_vertex_color = edge_to_color_map[source_vertex] 617 | target_vertex_color = edge_to_color_map[target_vertex] 618 | source_vertex_data_index = edge_to_data_index_map[source_vertex] 619 | target_vertex_data_index = edge_to_data_index_map[target_vertex] 620 | 621 | # source vertex should be training data 622 | if not source_vertex_data_index.startswith("train"): 623 | raise ValueError 624 | 625 | # target vertex should be evaluation data 626 | if not target_vertex_data_index.startswith("eval"): 627 | raise ValueError 628 | 629 | # train vertex should have this color 630 | if source_vertex_color != DEFAULT_TRAIN_VERTEX_COLOR: 631 | raise ValueError 632 | 633 | # eval vertex should not have this color 634 | if target_vertex_color == DEFAULT_TRAIN_VERTEX_COLOR: 635 | raise ValueError 636 | 637 | edge_collection = { 638 | "edge": edge, 639 | "target_slice": vertex_color_to_slice_map[target_vertex_color], 640 | "source_vertex_data_index": source_vertex_data_index, 641 | "target_vertex_data_index": target_vertex_data_index, 642 | } 643 | 644 | for property_name, property_map in g.edge_properties.items(): 645 | if property_name in edge_collection.keys(): 646 | raise ValueError(f"Duplicate key {property_name}") 647 | edge_collection[property_name] = property_map[edge] 648 | 649 | edge_collections.append(edge_collection) 650 | 651 | return pd.DataFrame(edge_collections), edge_collections 652 | -------------------------------------------------------------------------------- /experiments/visualization_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from typing import List, Callable, Dict, Tuple 9 | 10 | 11 | def get_circle_coordinates(r: float, degree: float): 12 | if degree < 0 or degree > 360: 13 | raise ValueError 14 | 15 | radian = (degree / 360) * 2 * np.pi 16 | x = r * np.sin(radian) 17 | y = r * np.cos(radian) 18 | return x, y 19 | 20 | 21 | def distance_to_points_on_circle( 22 | x: float, r: float, 23 | weights: List[float], 24 | points: List[List[float]]) -> float: 25 | # x^2 + y^2 = r^2 26 | y = np.sqrt(np.square(r) - np.square(x)) 27 | 28 | weighted_distances = 0.0 29 | for weight, point in zip(weights, points): 30 | _x, _y = point 31 | distance = np.sqrt((np.square(x - _x) + np.square(y - _y))) 32 | weighted_distances += weight * distance 33 | 34 | return weighted_distances 35 | 36 | 37 | def distance_to_points_within_circle( 38 | x_y: List[float], 39 | weights: List[float], 40 | points: List[List[float]]) -> float: 41 | if len(x_y) != 2: 42 | raise ValueError(f"Invalid `x_y` {x_y}") 43 | 44 | x, y = x_y 45 | weighted_distances = 0.0 46 | for weight, point in zip(weights, points): 47 | _x, _y = point 48 | distance = np.sqrt((np.square(x - _x) + np.square(y - _y))) 49 | weighted_distances += weight * distance 50 | 51 | return weighted_distances 52 | 53 | 54 | def distance_to_points_within_circle_vectorized( 55 | x_y: List[float], 56 | weights: np.ndarray, 57 | points: np.ndarray) -> float: 58 | if len(x_y) != 2: 59 | raise ValueError(f"Invalid `x_y` {x_y}") 60 | if len(weights.shape) != 1: 61 | raise ValueError(f"Invalid `weights` shape {weights.shape}") 62 | if len(points.shape) != 2 or points.shape[1] != 2: 63 | raise ValueError(f"Invalid `points` shape {points.shape}") 64 | if weights.shape[0] != points.shape[0]: 65 | raise ValueError(f"Incompatible shapes {weights.shape} {points.shape}") 66 | 67 | point = np.array(x_y) 68 | distance = np.sqrt(np.square(point - points).sum(axis=-1)) 69 | weighted_distances = distance * weights 70 | return weighted_distances.sum() 71 | 72 | 73 | def get_within_circle_constraint(r: float) -> Callable[[List[float]], float]: 74 | 75 | # Inequality constraint must be non-negative 76 | def _constraint(x_y: List[float]) -> float: 77 | x, y = x_y 78 | return np.square(r) - np.square(x) - np.square(y) 79 | 80 | return _constraint 81 | 82 | 83 | def plot_influences_distribution( 84 | influences_collections: List[Dict[int, float]], 85 | label: str, 86 | hist_xrange: Tuple[float, float]) -> None: 87 | 88 | influences: List[float] = [] 89 | for L in influences_collections: 90 | influences.extend(L.values()) 91 | plt.hist(influences, label=label, range=hist_xrange) 92 | -------------------------------------------------------------------------------- /figs/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/fast-influence-functions/40740c49b27f472c8323343a15490ce2d04eae05/figs/main.png -------------------------------------------------------------------------------- /influence_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause -------------------------------------------------------------------------------- /influence_utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import faiss 7 | import numpy as np 8 | from typing import Optional, Tuple 9 | 10 | 11 | class FAISSIndex(object): 12 | def __init__(self, 13 | d: Optional[int] = None, 14 | description: Optional[str] = "Flat", 15 | index: Optional[faiss.Index] = None) -> None: 16 | if index is None: 17 | index = faiss.index_factory(d, description) 18 | self._index = index 19 | 20 | def search(self, 21 | k: int, 22 | key: Optional[int] = None, 23 | query: Optional[np.ndarray] = None, 24 | queries: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]: 25 | """Search Nearest Neighbor 26 | 27 | Args: 28 | k [int]: number of nearest neighbors to use 29 | key [int, Optional]: Exclusive with `query` and `queries`, 30 | the key used to do the search if exists 31 | in the `index` already. 32 | query [int, Array ]: Exclusive with `key` and `queries`, 33 | the query used to do the search. 34 | queries [int, Array ]: Exclusive with `key` and `query`, 35 | the queries used to do the search. 36 | 37 | Returns: 38 | distance [Array ]: Distance of the return indices. 39 | indices [Array ]: Return indices. 40 | 41 | """ 42 | num_inputs_provided = sum([ 43 | key is not None, 44 | query is not None, 45 | queries is not None 46 | ]) 47 | if num_inputs_provided != 1: 48 | raise ValueError 49 | 50 | # if `query` is not None then we are good 51 | if key is not None: 52 | query = self.get(key) 53 | query_expanded = np.expand_dims(query, axis=0) 54 | 55 | if query is not None: 56 | query_expanded = np.expand_dims(query, axis=0) 57 | 58 | if queries is not None: 59 | query_expanded = queries 60 | 61 | if query_expanded.ndim != 2: 62 | raise ValueError 63 | 64 | return self._index.search(query_expanded, k) 65 | 66 | def add(self, vectors: np.ndarray) -> None: 67 | self._index.add(vectors) 68 | 69 | def get(self, key: int) -> np.ndarray: 70 | """Returns Array """ 71 | return self._index.reconstruct(key=key) 72 | 73 | def get_n(self, key_0: int, key_i: int) -> np.ndarray: 74 | """Returns Array """ 75 | return self._index.reconstruct_n( 76 | n0=key_0, ni=key_i) 77 | 78 | def save(self, file_name) -> None: 79 | faiss.write_index(self._index, file_name) 80 | 81 | def load(self, file_name) -> None: 82 | self._index = faiss.read_index(file_name) 83 | 84 | def __len__(self): 85 | return self._index.ntotal 86 | -------------------------------------------------------------------------------- /influence_utils/glue_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | from transformers import BertForSequenceClassification 7 | 8 | 9 | def freeze_BERT_parameters(model: BertForSequenceClassification, verbose: bool = True) -> None: 10 | # https://github.com/huggingface/transformers/issues/400 11 | if not isinstance(model, BertForSequenceClassification): 12 | raise TypeError 13 | 14 | # Table 3 in https://arxiv.org/pdf/1911.03090.pdf 15 | params_to_freeze = [ 16 | "bert.embeddings.", 17 | "bert.encoder.layer.0.", 18 | "bert.encoder.layer.1.", 19 | "bert.encoder.layer.2.", 20 | "bert.encoder.layer.3.", 21 | "bert.encoder.layer.4.", 22 | "bert.encoder.layer.5.", 23 | "bert.encoder.layer.6.", 24 | "bert.encoder.layer.7.", 25 | "bert.encoder.layer.8.", 26 | "bert.encoder.layer.9.", 27 | ] 28 | for name, param in model.named_parameters(): 29 | # if "classifier" not in name: # classifier layer 30 | # param.requires_grad = False 31 | 32 | if any(pfreeze in name for pfreeze in params_to_freeze): 33 | param.requires_grad = False 34 | 35 | if verbose is True: 36 | num_trainable_params = sum([ 37 | p.numel() for n, p in model.named_parameters() 38 | if p.requires_grad]) 39 | trainable_param_names = [ 40 | n for n, p in model.named_parameters() 41 | if p.requires_grad] 42 | print(f"Params Trainable: {num_trainable_params}\n\t" + 43 | f"\n\t".join(trainable_param_names)) 44 | -------------------------------------------------------------------------------- /influence_utils/multiprocessing_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | """Light Modification of the following file (v1.5.1) 7 | 8 | https://github.com/pytorch/pytorch/blob/v1.5.1/torch/multiprocessing/spawn.py 9 | 10 | The major change is to allow device-specific arguments instead of the same 11 | arguments applied to all processes. The motivation is to reduce sending 12 | over unnecessary data, which could increase the spawning overhead. 13 | """ 14 | 15 | import torch.multiprocessing as torch_mp 16 | from torch.multiprocessing.spawn import ( 17 | _python_version_check, _wrap, multiprocessing, warnings) 18 | 19 | 20 | # Note: [start_processes] 21 | # mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a 22 | # more generalized API than mp.spawn. Currently we only document mp.spawn as it's the 23 | # CUDA compatible start_method. However, in environments like Ipython notebooks, 'fork' 24 | # works better than 'spawn'. Every helper function we created for mp.spawn is indeed 25 | # general enough, and backends like XLA can reuse them in Colab notebooks as well. 26 | # Currently we only add this API first, we can consider adding it to documentation as 27 | # needed in the future. 28 | def start_processes(fn, list_of_args, nprocs=1, join=True, daemon=False, start_method='spawn'): 29 | _python_version_check() 30 | mp = multiprocessing.get_context(start_method) 31 | error_queues = [] 32 | processes = [] 33 | for i in range(nprocs): 34 | error_queue = mp.SimpleQueue() 35 | process = mp.Process( 36 | target=_wrap, 37 | args=(fn, i, list_of_args[i], error_queue), 38 | daemon=daemon, 39 | ) 40 | process.start() 41 | error_queues.append(error_queue) 42 | processes.append(process) 43 | 44 | context = torch_mp.ProcessContext(processes, error_queues) 45 | if not join: 46 | return context 47 | 48 | # Loop on join until it returns True or raises an exception. 49 | while not context.join(): 50 | pass 51 | 52 | 53 | def spawn(fn, list_of_args, nprocs=1, join=True, daemon=False, start_method='spawn'): 54 | r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``. 55 | If one of the processes exits with a non-zero exit status, the 56 | remaining processes are killed and an exception is raised with the 57 | cause of termination. In the case an exception was caught in the 58 | child process, it is forwarded and its traceback is included in 59 | the exception raised in the parent process. 60 | Arguments: 61 | fn (function): Function is called as the entrypoint of the 62 | spawned process. This function must be defined at the top 63 | level of a module so it can be pickled and spawned. This 64 | is a requirement imposed by multiprocessing. 65 | The function is called as ``fn(i, *args)``, where ``i`` is 66 | the process index and ``args`` is the passed through tuple 67 | of arguments. 68 | args (tuple): Arguments passed to ``fn``. 69 | nprocs (int): Number of processes to spawn. 70 | join (bool): Perform a blocking join on all processes. 71 | daemon (bool): The spawned processes' daemon flag. If set to True, 72 | daemonic processes will be created. 73 | start_method (string): (deprecated) this method will always use ``spawn`` 74 | as the start method. To use a different start method 75 | use ``start_processes()``. 76 | Returns: 77 | None if ``join`` is ``True``, 78 | :class:`~ProcessContext` if ``join`` is ``False`` 79 | """ 80 | if start_method != 'spawn': 81 | msg = ('This method only supports start_method=spawn (got: %s).\n' 82 | 'To use a different start_method use:\n\t\t' 83 | ' torch.multiprocessing.start_process(...)' % start_method) 84 | warnings.warn(msg) 85 | return start_processes(fn, list_of_args, nprocs, join, daemon, start_method='spawn') 86 | -------------------------------------------------------------------------------- /influence_utils/nn_influence_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import torch 7 | import numpy as np 8 | from tqdm import tqdm 9 | from transformers import PreTrainedTokenizer 10 | from typing import Dict, List, Union, Optional, Tuple, Iterator, Any 11 | 12 | 13 | def count_parameters(model: torch.nn.Module) -> int: 14 | return sum(p.numel() for p in model.parameters()) 15 | 16 | 17 | def convert_ids_to_string( 18 | tokenizer: PreTrainedTokenizer, 19 | ids: torch.LongTensor) -> str: 20 | tokens = tokenizer.convert_ids_to_tokens(ids) 21 | return tokenizer.convert_tokens_to_string(tokens) 22 | 23 | 24 | def get_loss_with_weight_decay( 25 | device: torch.device, 26 | n_gpu: int, 27 | model: torch.nn.Module, 28 | inputs: Dict[str, torch.Tensor], 29 | weight_decay: Optional[float], 30 | weight_decay_ignores: Optional[List[str]]) -> float: 31 | 32 | # model.train() 33 | for k, v in inputs.items(): 34 | inputs[k] = v.to(device) 35 | 36 | outputs = model(**inputs) 37 | # model outputs are always tuple in transformers (see doc) 38 | loss = outputs[0] 39 | 40 | if n_gpu > 1: 41 | # mean() to average on multi-gpu parallel training 42 | loss = loss.mean() 43 | 44 | # In PyTorch, weight-decay loss and gradients are calculated in 45 | # optimizers rather in nn.Module, so we have to manually specify 46 | # this for the loss here. 47 | if weight_decay is not None: 48 | no_decay = ( 49 | weight_decay_ignores 50 | if weight_decay_ignores 51 | is not None else []) 52 | 53 | weight_decay_loss = torch.cat([ 54 | p.square().view(-1) 55 | for n, p in model.named_parameters() 56 | if not any(nd in n for nd in no_decay) 57 | ]).sum() * weight_decay 58 | loss = loss + weight_decay_loss 59 | 60 | return loss 61 | 62 | 63 | def compute_gradients( 64 | device: torch.device, 65 | n_gpu: int, 66 | model: torch.nn.Module, 67 | inputs: Dict[str, torch.Tensor], 68 | params_filter: Optional[List[str]], 69 | weight_decay: Optional[float], 70 | weight_decay_ignores: Optional[List[str]] 71 | ) -> List[torch.FloatTensor]: 72 | 73 | if params_filter is None: 74 | params_filter = [] 75 | 76 | model.zero_grad() 77 | loss = get_loss_with_weight_decay( 78 | device=device, n_gpu=n_gpu, 79 | model=model, inputs=inputs, 80 | weight_decay=weight_decay, 81 | weight_decay_ignores=weight_decay_ignores) 82 | 83 | return torch.autograd.grad( 84 | outputs=loss, 85 | inputs=[ 86 | param for name, param 87 | in model.named_parameters() 88 | if name not in params_filter], 89 | create_graph=True) 90 | 91 | 92 | def compute_hessian_vector_products( 93 | device: torch.device, 94 | n_gpu: int, 95 | model: torch.nn.Module, 96 | inputs: Dict[str, torch.Tensor], 97 | vectors: torch.FloatTensor, 98 | params_filter: Optional[List[str]], 99 | weight_decay: Optional[float], 100 | weight_decay_ignores: Optional[List[str]] 101 | ) -> List[torch.FloatTensor]: 102 | 103 | if params_filter is None: 104 | params_filter = [] 105 | 106 | model.zero_grad() 107 | loss = get_loss_with_weight_decay( 108 | model=model, n_gpu=n_gpu, 109 | device=device, inputs=inputs, 110 | weight_decay=weight_decay, 111 | weight_decay_ignores=weight_decay_ignores) 112 | 113 | grad_tuple = torch.autograd.grad( 114 | outputs=loss, 115 | inputs=[ 116 | param for name, param 117 | in model.named_parameters() 118 | if name not in params_filter], 119 | create_graph=True) 120 | 121 | model.zero_grad() 122 | grad_grad_tuple = torch.autograd.grad( 123 | outputs=grad_tuple, 124 | inputs=[ 125 | param for name, param 126 | in model.named_parameters() 127 | if name not in params_filter], 128 | grad_outputs=vectors, 129 | only_inputs=True 130 | ) 131 | 132 | return grad_grad_tuple 133 | 134 | 135 | def compute_s_test( 136 | n_gpu: int, 137 | device: torch.device, 138 | model: torch.nn.Module, 139 | test_inputs: Dict[str, torch.Tensor], 140 | train_data_loaders: List[torch.utils.data.DataLoader], 141 | params_filter: Optional[List[str]], 142 | weight_decay: Optional[float], 143 | weight_decay_ignores: Optional[List[str]], 144 | damp: float, 145 | scale: float, 146 | num_samples: Optional[int] = None, 147 | verbose: bool = True, 148 | ) -> List[torch.FloatTensor]: 149 | 150 | v = compute_gradients( 151 | model=model, 152 | n_gpu=n_gpu, 153 | device=device, 154 | inputs=test_inputs, 155 | params_filter=params_filter, 156 | weight_decay=weight_decay, 157 | weight_decay_ignores=weight_decay_ignores) 158 | 159 | # Technically, it's hv^-1 160 | last_estimate = list(v).copy() 161 | cumulative_num_samples = 0 162 | with tqdm(total=num_samples) as pbar: 163 | for data_loader in train_data_loaders: 164 | for i, inputs in enumerate(data_loader): 165 | this_estimate = compute_hessian_vector_products( 166 | model=model, 167 | n_gpu=n_gpu, 168 | device=device, 169 | vectors=last_estimate, 170 | inputs=inputs, 171 | params_filter=params_filter, 172 | weight_decay=weight_decay, 173 | weight_decay_ignores=weight_decay_ignores) 174 | # Recursively caclulate h_estimate 175 | # https://github.com/dedeswim/pytorch_influence_functions/blob/master/pytorch_influence_functions/influence_functions/hvp_grad.py#L118 176 | with torch.no_grad(): 177 | new_estimate = [ 178 | a + (1 - damp) * b - c / scale 179 | for a, b, c in zip(v, last_estimate, this_estimate) 180 | ] 181 | 182 | pbar.update(1) 183 | if verbose is True: 184 | new_estimate_norm = new_estimate[0].norm().item() 185 | last_estimate_norm = last_estimate[0].norm().item() 186 | estimate_norm_diff = new_estimate_norm - last_estimate_norm 187 | pbar.set_description(f"{new_estimate_norm:.2f} | {estimate_norm_diff:.2f}") 188 | 189 | cumulative_num_samples += 1 190 | last_estimate = new_estimate 191 | if num_samples is not None and i > num_samples: 192 | break 193 | 194 | # References: 195 | # https://github.com/kohpangwei/influence-release/blob/master/influence/genericNeuralNet.py#L475 196 | # Do this for each iteration of estimation 197 | # Since we use one estimation, we put this at the end 198 | inverse_hvp = [X / scale for X in last_estimate] 199 | 200 | # Sanity check 201 | # Note that in parallel settings, we should have `num_samples` 202 | # whereas in sequential settings we would have `num_samples + 2`. 203 | # This is caused by some loose stop condition. In parallel settings, 204 | # We only allocate `num_samples` data to reduce communication overhead. 205 | # Should probably make this more consistent sometime. 206 | if cumulative_num_samples not in [num_samples, num_samples + 2]: 207 | raise ValueError(f"cumulative_num_samples={cumulative_num_samples} f" 208 | f"but num_samples={num_samples}: Untested Territory") 209 | 210 | return inverse_hvp 211 | 212 | 213 | def compute_grad_zs( 214 | n_gpu: int, 215 | device: torch.device, 216 | model: torch.nn.Module, 217 | data_loader: torch.utils.data.DataLoader, 218 | params_filter: Optional[List[str]] = None, 219 | weight_decay: Optional[float] = None, 220 | weight_decay_ignores: Optional[List[str]] = None, 221 | ) -> List[List[torch.FloatTensor]]: 222 | 223 | if weight_decay_ignores is None: 224 | weight_decay_ignores = [ 225 | "bias", 226 | "LayerNorm.weight"] 227 | 228 | grad_zs = [] 229 | for inputs in data_loader: 230 | grad_z = compute_gradients( 231 | n_gpu=n_gpu, device=device, 232 | model=model, inputs=inputs, 233 | params_filter=params_filter, 234 | weight_decay=weight_decay, 235 | weight_decay_ignores=weight_decay_ignores) 236 | with torch.no_grad(): 237 | grad_zs.append([X.cpu() for X in grad_z]) 238 | 239 | return grad_zs 240 | 241 | 242 | def compute_influences( 243 | n_gpu: int, 244 | device: torch.device, 245 | model: torch.nn.Module, 246 | test_inputs: Dict[str, torch.Tensor], 247 | batch_train_data_loader: torch.utils.data.DataLoader, 248 | instance_train_data_loader: torch.utils.data.DataLoader, 249 | params_filter: Optional[List[str]] = None, 250 | weight_decay: Optional[float] = None, 251 | weight_decay_ignores: Optional[List[str]] = None, 252 | s_test_damp: float = 3e-5, 253 | s_test_scale: float = 1e4, 254 | s_test_num_samples: Optional[int] = None, 255 | s_test_iterations: int = 1, 256 | precomputed_s_test: Optional[List[torch.FloatTensor]] = None, 257 | train_indices_to_include: Optional[Union[np.ndarray, List[int]]] = None, 258 | ) -> Tuple[Dict[int, float], Dict[int, Dict], List[torch.FloatTensor]]: 259 | 260 | if s_test_iterations < 1: 261 | raise ValueError("`s_test_iterations` must >= 1") 262 | 263 | if weight_decay_ignores is None: 264 | # https://github.com/huggingface/transformers/blob/v3.0.2/src/transformers/trainer.py#L325 265 | weight_decay_ignores = [ 266 | "bias", 267 | "LayerNorm.weight"] 268 | 269 | if precomputed_s_test is not None: 270 | s_test = precomputed_s_test 271 | else: 272 | s_test = None 273 | for _ in range(s_test_iterations): 274 | _s_test = compute_s_test( 275 | n_gpu=n_gpu, 276 | device=device, 277 | model=model, 278 | test_inputs=test_inputs, 279 | train_data_loaders=[batch_train_data_loader], 280 | params_filter=params_filter, 281 | weight_decay=weight_decay, 282 | weight_decay_ignores=weight_decay_ignores, 283 | damp=s_test_damp, 284 | scale=s_test_scale, 285 | num_samples=s_test_num_samples) 286 | 287 | # Sum the values across runs 288 | if s_test is None: 289 | s_test = _s_test 290 | else: 291 | s_test = [ 292 | a + b for a, b in zip(s_test, _s_test) 293 | ] 294 | # Do the averaging 295 | s_test = [a / s_test_iterations for a in s_test] 296 | 297 | influences = {} 298 | train_inputs_collections = {} 299 | for index, train_inputs in enumerate(tqdm(instance_train_data_loader)): 300 | 301 | # Skip indices when a subset is specified to be included 302 | if (train_indices_to_include is not None) and ( 303 | index not in train_indices_to_include): 304 | continue 305 | 306 | grad_z = compute_gradients( 307 | n_gpu=n_gpu, 308 | device=device, 309 | model=model, 310 | inputs=train_inputs, 311 | params_filter=params_filter, 312 | weight_decay=weight_decay, 313 | weight_decay_ignores=weight_decay_ignores) 314 | 315 | with torch.no_grad(): 316 | influence = [ 317 | - torch.sum(x * y) 318 | for x, y in zip(grad_z, s_test)] 319 | 320 | influences[index] = sum(influence).item() 321 | train_inputs_collections[index] = train_inputs 322 | 323 | return influences, train_inputs_collections, s_test 324 | -------------------------------------------------------------------------------- /influence_utils/parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import os 7 | import sys 8 | import torch 9 | import tempfile 10 | import numpy as np 11 | import torch.distributed as dist 12 | # import torch.multiprocessing as mp 13 | from tqdm import tqdm 14 | from copy import deepcopy 15 | from transformers import GlueDataset 16 | from typing import Dict, List, Any, Tuple, Union, Optional 17 | 18 | from experiments import misc_utils 19 | from influence_utils import nn_influence_utils 20 | from influence_utils import multiprocessing_utils as custom_mp 21 | 22 | 23 | def _compute_s_test( 24 | rank: int, 25 | model: torch.nn.Module, 26 | dataloaders: torch.utils.data.DataLoader, 27 | n_gpu: int, 28 | devices: List[torch.device], 29 | test_inputs: Dict[str, torch.Tensor], 30 | params_filter: Optional[List[str]] = None, 31 | weight_decay: Optional[float] = None, 32 | weight_decay_ignores: Optional[List[str]] = None, 33 | s_test_damp: float = 3e-5, 34 | s_test_scale: float = 1e4, 35 | s_test_num_samples: Optional[int] = None, 36 | ) -> List[torch.Tensor]: 37 | 38 | s_test = nn_influence_utils.compute_s_test( 39 | n_gpu=n_gpu, 40 | device=devices[rank], 41 | model=model, 42 | test_inputs=test_inputs, 43 | train_data_loaders=[dataloaders], 44 | params_filter=params_filter, 45 | weight_decay=weight_decay, 46 | weight_decay_ignores=weight_decay_ignores, 47 | damp=s_test_damp, 48 | scale=s_test_scale, 49 | num_samples=s_test_num_samples) 50 | 51 | # Gather `s_test` computed in other processes and 52 | # aggregate them via averaging. 53 | # print(flatten_and_concat(s_test).norm()) 54 | world_size = float(dist.get_world_size()) 55 | for index in range(len(s_test)): 56 | dist.all_reduce(s_test[index], op=dist.ReduceOp.SUM) 57 | s_test[index] = s_test[index] / world_size 58 | # print(flatten_and_concat(s_test).norm()) 59 | return s_test 60 | 61 | 62 | def _compute_influences( 63 | rank: int, 64 | model: torch.nn.Module, 65 | s_test: List[torch.Tensor], 66 | scattered_inputs: List[Any], 67 | scattered_indices: List[int], 68 | params_filter: Optional[List[str]] = None, 69 | weight_decay: Optional[float] = None, 70 | weight_decay_ignores: Optional[List[str]] = None, 71 | ) -> Dict[int, float]: 72 | 73 | wrapped_model = InfluenceHelper( 74 | mode="list", 75 | n_gpu=1, 76 | model=model, 77 | progress_bar=True, 78 | params_filter=params_filter, 79 | weight_decay=weight_decay, 80 | weight_decay_ignores=weight_decay_ignores) 81 | 82 | influences_list = wrapped_model( 83 | Xs=scattered_inputs, 84 | s_test=s_test) 85 | 86 | influences = {} 87 | for i, index in enumerate(scattered_indices): 88 | # Save just the values not the Tensor to 89 | # speed up saving/loading time 90 | influences[index] = influences_list[i].item() 91 | 92 | return influences 93 | 94 | 95 | def compute_s_test_and_influence( 96 | rank: int, 97 | file_name: str, 98 | model: torch.nn.Module, 99 | dataloaders: torch.utils.data.DataLoader, 100 | scattered_inputs: List[Any], 101 | scattered_indices: List[int], 102 | n_gpu: int, 103 | devices: List[torch.device], 104 | test_inputs: Dict[str, torch.Tensor], 105 | params_filter: Optional[List[str]] = None, 106 | weight_decay: Optional[float] = None, 107 | weight_decay_ignores: Optional[List[str]] = None, 108 | s_test_damp: float = 3e-5, 109 | s_test_scale: float = 1e4, 110 | s_test_num_samples: Optional[int] = None, 111 | return_s_test: bool = False, 112 | log_stdin_and_stdout: bool = True, 113 | ) -> Tuple[Dict[int, float], List[torch.FloatTensor]]: 114 | 115 | # Initialize 116 | # https://pytorch.org/tutorials/intermediate/dist_tuto.html 117 | os.environ["MASTER_ADDR"] = "127.0.0.1" 118 | os.environ["MASTER_PORT"] = "29500" 119 | dist.init_process_group("gloo", rank=rank, world_size=len(devices)) 120 | 121 | if log_stdin_and_stdout is True: 122 | logdir = "./logs" 123 | if not os.path.isdir(logdir): 124 | os.mkdir(logdir) 125 | # https://stackoverflow.com/questions/1501651/log-output-of-multiprocessing-process 126 | sys.stdout = open(os.path.join(logdir, f"mp.{os.getpid()}.out"), "a") 127 | sys.stderr = open(os.path.join(logdir, f"mp.{os.getpid()}.err"), "a") 128 | 129 | # Approx. 4-5sec for moving model to a specified GPU 130 | model.to(devices[rank]) 131 | s_test = _compute_s_test( 132 | rank=rank, 133 | model=model, 134 | dataloaders=dataloaders, 135 | n_gpu=n_gpu, 136 | devices=devices, 137 | test_inputs=test_inputs, 138 | params_filter=params_filter, 139 | weight_decay=weight_decay, 140 | weight_decay_ignores=weight_decay_ignores, 141 | s_test_damp=s_test_damp, 142 | s_test_scale=s_test_scale, 143 | s_test_num_samples=s_test_num_samples) 144 | 145 | influences = _compute_influences( 146 | rank=rank, 147 | model=model, 148 | s_test=s_test, 149 | scattered_inputs=scattered_inputs, 150 | scattered_indices=scattered_indices, 151 | params_filter=params_filter, 152 | weight_decay=weight_decay, 153 | weight_decay_ignores=weight_decay_ignores) 154 | 155 | if log_stdin_and_stdout is True: 156 | # https://stackoverflow.com/questions/14245227/python-reset-stdout-to-normal-after-previously-redirecting-it-to-a-file 157 | sys.stdout = sys.__stdout__ 158 | sys.stderr = sys.__stderr__ 159 | 160 | # Save outputs, normally we do not need 161 | # need `s_test` and saving it takes extra time, 162 | # but sometimes we need it for diagnostics. 163 | if return_s_test is True: 164 | torch.save({ 165 | "influences": influences, 166 | "s_test": s_test}, 167 | file_name) 168 | 169 | else: 170 | torch.save({ 171 | "influences": influences}, 172 | file_name) 173 | 174 | # Always return both, though in multiprocessing 175 | # this does not matter except making type annotation 176 | # cleaner, which is also important :) 177 | return influences, s_test 178 | 179 | 180 | def compute_influences_parallel( 181 | device_ids: List[int], 182 | train_dataset: GlueDataset, 183 | batch_size: int, 184 | model: torch.nn.Module, 185 | test_inputs: Dict[str, torch.Tensor], 186 | params_filter: Optional[List[str]] = None, 187 | weight_decay: Optional[float] = None, 188 | weight_decay_ignores: Optional[List[str]] = None, 189 | s_test_damp: float = 3e-5, 190 | s_test_scale: float = 1e4, 191 | s_test_num_samples: Optional[int] = None, 192 | random: bool = True, 193 | debug: bool = False, 194 | return_s_test: bool = False, 195 | train_indices_to_include: Optional[Union[np.ndarray, List[int]]] = None, 196 | ) -> Tuple[Dict[int, float], Optional[List[torch.FloatTensor]]]: 197 | 198 | if s_test_num_samples is None: 199 | raise ValueError("`s_test_num_samples` cannot be None") 200 | 201 | # Passing the smaller subset of training data to child 202 | # processes can significantly reduce the overhead of 203 | # spawning new child processes. 204 | dataloders = prepare_small_dataloaders( 205 | dataset=train_dataset, 206 | random=random, 207 | batch_size=batch_size, 208 | num_datasets=len(device_ids), 209 | num_examples_per_dataset=s_test_num_samples) 210 | 211 | scattered_inputs, scattered_indices = prepare_scattered_inputs_and_indices( 212 | dataset=train_dataset, 213 | device_ids=device_ids, 214 | indices_to_include=train_indices_to_include) 215 | 216 | devices = [torch.device(f"cuda:{device_id}") for device_id in device_ids] 217 | tmpfiles = [tempfile.NamedTemporaryFile() for _ in range(len(device_ids))] 218 | process_args = [( 219 | tmpfiles[process_index].name, 220 | model, 221 | dataloders[process_index], 222 | scattered_inputs[process_index], 223 | scattered_indices[process_index], 224 | 1, # n_gpu 225 | devices, 226 | test_inputs, 227 | params_filter, 228 | weight_decay, 229 | weight_decay_ignores, 230 | s_test_damp, 231 | s_test_scale, 232 | s_test_num_samples, 233 | return_s_test, 234 | True if debug is False else False, # log_stdin_and_stdout 235 | ) for process_index in range(len(device_ids))] 236 | 237 | if debug is False: 238 | try: 239 | custom_mp.spawn( 240 | compute_s_test_and_influence, 241 | list_of_args=process_args, 242 | nprocs=len(device_ids), 243 | join=True) 244 | 245 | influences: Dict[int, float] = {} 246 | for tmpfile in tmpfiles: 247 | outputs_dict = torch.load(tmpfile.name) 248 | for key, val in outputs_dict["influences"].items(): 249 | if key in influences.keys(): 250 | raise ValueError 251 | influences[key] = val 252 | 253 | # Note that `s_test` is the same across all processes 254 | # after they end because of the syncronization, so we 255 | # just need to load one of them, and we pick the last one 256 | s_test = outputs_dict.get("s_test", None) 257 | 258 | finally: 259 | for tmpfile in tmpfiles: 260 | tmpfile.close() 261 | 262 | return influences, s_test 263 | 264 | else: 265 | random_rank = np.random.choice(len(device_ids)) 266 | print(f"Using random rank {random_rank}") 267 | return compute_s_test_and_influence( 268 | random_rank, *process_args[random_rank]) 269 | 270 | 271 | class SimpleDataset(torch.utils.data.Dataset): 272 | """Simple Dataset class where examples are fetched by index""" 273 | 274 | def __init__(self, examples: List[Any]) -> None: 275 | self.examples = examples 276 | 277 | def __len__(self) -> int: 278 | return len(self.examples) 279 | 280 | def __getitem__(self, index: int) -> Any: 281 | return self.examples[index] 282 | 283 | 284 | def prepare_small_dataloaders( 285 | dataset: torch.utils.data.Dataset, 286 | random: bool, 287 | batch_size: int, 288 | num_datasets: int, 289 | num_examples_per_dataset: int) -> List[SimpleDataset]: 290 | """Only pass to child processes the data we will really use""" 291 | 292 | examples = [] 293 | total_num_examples = batch_size * num_datasets * num_examples_per_dataset 294 | 295 | if random is True: 296 | indices = np.random.choice( 297 | len(dataset), 298 | size=total_num_examples, 299 | # Sample without replacement 300 | replace=False) 301 | else: 302 | indices = list(range(total_num_examples)) 303 | 304 | for index in indices: 305 | example = dataset[index] 306 | examples.append(example) 307 | 308 | dataloaders = [] 309 | for i in range(num_datasets): 310 | start_index = i * batch_size * num_examples_per_dataset 311 | end_index = (i + 1) * batch_size * num_examples_per_dataset 312 | new_dataset = SimpleDataset(examples[start_index: end_index]) 313 | dataloader = misc_utils.get_dataloader( 314 | dataset=new_dataset, 315 | batch_size=batch_size, 316 | # The random here doesn't matter? 317 | random=random) 318 | dataloaders.append(dataloader) 319 | 320 | return dataloaders 321 | 322 | 323 | def prepare_scattered_inputs_and_indices( 324 | device_ids: List[int], 325 | dataset: torch.utils.data.Dataset, 326 | indices_to_include: Optional[List[int]] = None, 327 | ) -> Tuple[List[List[Any]], List[List[int]]]: 328 | """Scatter the data into devices""" 329 | 330 | indices_list = [] 331 | # inputs_collections = {} 332 | inputs_collections_list = [] 333 | instance_dataloader = misc_utils.get_dataloader( 334 | dataset=dataset, batch_size=1) 335 | for index, train_inputs in enumerate(tqdm(instance_dataloader)): 336 | 337 | # Skip indices when a subset is specified to be included 338 | if (indices_to_include is not None) and ( 339 | index not in indices_to_include): 340 | continue 341 | 342 | indices_list.append(index) 343 | # inputs_collections[index] = train_inputs 344 | inputs_collections_list.append(train_inputs) 345 | 346 | scattered_inputs, scattered_indices = scatter_inputs_and_indices( 347 | Xs=inputs_collections_list, 348 | indices=indices_list, device_ids=device_ids) 349 | 350 | return scattered_inputs, scattered_indices 351 | 352 | 353 | def flatten_and_concat(Xs: List[torch.Tensor]) -> torch.Tensor: 354 | return torch.cat([X.flatten() for X in Xs], dim=0) 355 | 356 | 357 | def scatter_inputs_and_indices( 358 | Xs: List[Any], 359 | indices: List[int], 360 | device_ids: List[int] 361 | ) -> Tuple[List[List[Any]], List[List[int]]]: 362 | """Scatter `Xs` across devices""" 363 | copied_Xs = deepcopy(Xs) 364 | copied_indices = deepcopy(indices) 365 | devices = [torch.device(f"cuda:{i}") for i in device_ids] 366 | 367 | def _map_to_device(X: Any, device: torch.device): 368 | for k, v in X.items(): 369 | if isinstance(v, torch.Tensor): 370 | X[k] = v.to(device) 371 | 372 | return X 373 | 374 | scattered_Xs: List[List[Any]] = [[] for _ in range(len(device_ids))] 375 | scattered_indices: List[List[int]] = [[] for _ in range(len(device_ids))] 376 | boundary = np.ceil(len(copied_Xs) / len(device_ids)) 377 | for i, (X, index) in enumerate(zip(copied_Xs, copied_indices)): 378 | device_index = int(i // boundary) 379 | device = devices[device_index] 380 | scattered_Xs[device_index].append( 381 | _map_to_device(X, device)) 382 | scattered_indices[device_index].append(index) 383 | 384 | return scattered_Xs, scattered_indices 385 | 386 | 387 | class InfluenceHelper(torch.nn.Module): 388 | """Helper Module for computing influence values""" 389 | 390 | def __init__(self, 391 | mode: str, 392 | n_gpu: int, 393 | model: torch.nn.Module, 394 | progress_bar: bool = False, 395 | params_filter: Optional[List[str]] = None, 396 | weight_decay: Optional[float] = None, 397 | weight_decay_ignores: Optional[List[str]] = None): 398 | 399 | super(InfluenceHelper, self).__init__() 400 | 401 | if mode not in ["list", "instance"]: 402 | raise ValueError 403 | 404 | if weight_decay_ignores is None: 405 | # https://github.com/huggingface/transformers/blob/v3.0.2/src/transformers/trainer.py#L325 406 | weight_decay_ignores = [ 407 | "bias", 408 | "LayerNorm.weight"] 409 | 410 | self.model = model 411 | self._mode = mode 412 | self._n_gpu = n_gpu 413 | self._progress_bar = progress_bar 414 | self._params_filter = params_filter 415 | self._weight_decay = weight_decay 416 | self._weight_decay_ignores = weight_decay_ignores 417 | 418 | def _compute_influence( 419 | self, 420 | device: torch.device, 421 | X: Dict[str, torch.Tensor], 422 | s_test: List[torch.FloatTensor], 423 | ) -> torch.Tensor: 424 | 425 | grad_z = nn_influence_utils.compute_gradients( 426 | n_gpu=self._n_gpu, 427 | device=device, 428 | model=self.model, 429 | inputs=X, 430 | params_filter=self._params_filter, 431 | weight_decay=self._weight_decay, 432 | weight_decay_ignores=self._weight_decay_ignores) 433 | 434 | with torch.no_grad(): 435 | influence = [ 436 | - torch.sum(x * y) 437 | for x, y in zip(grad_z, s_test)] 438 | 439 | return sum(influence) 440 | 441 | def forward(self, 442 | Xs: Union[Dict[str, torch.Tensor], 443 | List[Dict[str, torch.Tensor]]], 444 | s_test: List[torch.FloatTensor] 445 | ) -> torch.FloatTensor: 446 | 447 | if self._mode in ["instance"]: 448 | # `Xs` has single instance 449 | if not isinstance(Xs, dict): 450 | raise TypeError(f"`Xs` should be a dictionary but {type(Xs)}") 451 | 452 | device = Xs["labels"].device 453 | new_s_test = [x.to(device) for x in s_test] 454 | return self._compute_influence( 455 | device=device, X=Xs, s_test=new_s_test) 456 | 457 | else: 458 | # `Xs` is a list of instances 459 | if not isinstance(Xs, list): 460 | raise TypeError(f"`Xs` should be a list but {type(Xs)}") 461 | 462 | influences = [] 463 | device = Xs[0]["labels"].device 464 | new_s_test = [x.to(device) for x in s_test] 465 | if self._progress_bar is True: 466 | Xs = tqdm(Xs) 467 | 468 | influences = [ 469 | self._compute_influence( 470 | device=device, X=X, # noqa 471 | s_test=new_s_test) 472 | for X in Xs] 473 | 474 | return torch.stack(influences, dim=0) 475 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyterlab 2 | transformers==3.0.2 3 | wandb 4 | scikit-learn 5 | matplotlib 6 | faiss-gpu 7 | overrides 8 | contexttimer 9 | pandas 10 | yagmail[all] 11 | scp 12 | streamlit 13 | pycairo -------------------------------------------------------------------------------- /run_experiments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | import sys 7 | from typing import Optional, Dict 8 | 9 | from experiments import mnli 10 | from experiments import hans 11 | from experiments import s_test_speedup 12 | from experiments import remote_utils 13 | from experiments import visualization 14 | 15 | 16 | USE_PARALLEL = True 17 | NUM_KNN_RECALL_EXPERIMENTS = 50 18 | NUM_RETRAINING_EXPERIMENTS = 3 19 | NUM_STEST_EXPERIMENTS = 10 20 | NUM_VISUALIZATION_EXPERIMENTS = 100 21 | NUM_IMITATOR_EXPERIMENTS = 10 22 | 23 | 24 | def KNN_recall_experiments( 25 | mode: str, 26 | num_experiments: Optional[int] = None 27 | ) -> None: 28 | """Experiments to Check The Influence Recall of KNN""" 29 | print("RUNNING `KNN_recall_experiments`") 30 | 31 | if num_experiments is None: 32 | num_experiments = NUM_KNN_RECALL_EXPERIMENTS 33 | 34 | # (a) when the prediction is correct, and (b) incorrect 35 | mnli.run_full_influence_functions( 36 | mode=mode, 37 | num_examples_to_test=num_experiments) 38 | 39 | 40 | def s_test_speed_quality_tradeoff_experiments( 41 | mode: str, 42 | num_experiments: Optional[int] = None 43 | ) -> None: 44 | """Experiments to Check The Speed/Quality Trade-off of `s_test` estimation""" 45 | print("RUNNING `s_test_speed_quality_tradeoff_experiments`") 46 | 47 | if num_experiments is None: 48 | num_experiments = NUM_STEST_EXPERIMENTS 49 | 50 | # (a) when the prediction is correct, and (b) incorrect 51 | s_test_speedup.main( 52 | mode=mode, 53 | num_examples_to_test=num_experiments) 54 | 55 | 56 | def MNLI_retraining_experiments( 57 | mode: str, 58 | num_experiments: Optional[int] = None 59 | ) -> None: 60 | print("RUNNING `MNLI_retraining_experiments`") 61 | 62 | if num_experiments is None: 63 | num_experiments = NUM_RETRAINING_EXPERIMENTS 64 | 65 | mnli.run_retraining_main( 66 | mode=mode, 67 | num_examples_to_test=num_experiments) 68 | 69 | 70 | def visualization_experiments( 71 | num_experiments: Optional[int] = None 72 | ) -> None: 73 | """Experiments for Visualizing Effects""" 74 | print("RUNNING `visualization_experiments`") 75 | 76 | if num_experiments is None: 77 | num_experiments = NUM_VISUALIZATION_EXPERIMENTS 78 | 79 | for heuristic in hans.DEFAULT_HANS_EVAL_HEURISTICS: 80 | visualization.main( 81 | train_task_name="hans", 82 | eval_task_name="hans", 83 | num_eval_to_collect=num_experiments, 84 | use_parallel=USE_PARALLEL, 85 | hans_heuristic=heuristic, 86 | trained_on_task_name="hans") 87 | 88 | visualization.main( 89 | train_task_name="hans", 90 | eval_task_name="mnli-2", 91 | num_eval_to_collect=num_experiments, 92 | use_parallel=USE_PARALLEL, 93 | hans_heuristic=None, 94 | trained_on_task_name="hans") 95 | 96 | 97 | def prepare_data_for_retraining( 98 | num_eval_to_collect: int, 99 | ) -> None: 100 | for mode in ["only-correct", "only-incorrect"]: 101 | for kNN_k in [1000, 10000]: 102 | visualization.main( 103 | mode=mode, 104 | train_task_name="mnli", 105 | eval_task_name="mnli", 106 | num_eval_to_collect=num_eval_to_collect, 107 | use_parallel=True, 108 | kNN_k=kNN_k) 109 | 110 | 111 | def hans_augmentation_experiments( 112 | num_replicas: Optional[int] = None 113 | ) -> None: 114 | print("RUNNING `hans_augmentation_experiments`") 115 | # We will use the all the `train_heuristic` here, as we did in 116 | # `eval_heuristics`. So looping over the `DEFAULT_HANS_EVAL_HEURISTICS` 117 | for train_task_name in ["mnli-2", "hans"]: 118 | for train_heuristic in hans.DEFAULT_HANS_EVAL_HEURISTICS: 119 | for version in ["new-only-z", "new-only-ztest"]: 120 | hans.main( 121 | trained_on_task_name="mnli-2", 122 | train_task_name=train_task_name, 123 | train_heuristic=train_heuristic, 124 | num_replicas=num_replicas, 125 | use_parallel=USE_PARALLEL, 126 | version=version) 127 | 128 | 129 | def amazon_augmentation_experiments( 130 | num_replicas: Optional[int] = None 131 | ) -> None: 132 | print("RUNNING `amazon_augmentation_experiments`") 133 | for train_task_name in ["amazon"]: 134 | for train_heuristic in hans.DEFAULT_Amazon_EVAL_HEURISTICS: 135 | for version in ["new-only-z", "new-only-ztest"]: 136 | hans.main( 137 | trained_on_task_name="amazon", 138 | train_task_name=train_task_name, 139 | train_heuristic=train_heuristic, 140 | num_replicas=num_replicas, 141 | use_parallel=USE_PARALLEL, 142 | version=version) 143 | 144 | 145 | def anli_augmentation_experiments( 146 | num_replicas: Optional[int] = None 147 | ) -> None: 148 | print("RUNNING `anli_augmentation_experiments`") 149 | for train_task_name in ["anli"]: 150 | for train_heuristic in hans.DEFAULT_ANLI_EVAL_HEURISTICS: 151 | for version in ["new-only-z", "new-only-ztest"]: 152 | hans.main( 153 | trained_on_task_name="mnli", 154 | train_task_name=train_task_name, 155 | train_heuristic=train_heuristic, 156 | num_replicas=num_replicas, 157 | use_parallel=USE_PARALLEL, 158 | version=version) 159 | 160 | 161 | def imitator_experiments( 162 | num_experiments: Optional[int] = None 163 | ) -> None: 164 | print("RUNNING `imitator_experiments`") 165 | 166 | if num_experiments is None: 167 | num_experiments = NUM_IMITATOR_EXPERIMENTS 168 | 169 | mnli.imitator_main( 170 | mode="only-correct", 171 | num_examples_to_test=num_experiments) 172 | 173 | mnli.imitator_main( 174 | mode="only-incorrect", 175 | num_examples_to_test=num_experiments) 176 | 177 | 178 | if __name__ == "__main__": 179 | # Make sure the environment is properly setup 180 | # remote_utils.setup_and_verify_environment() 181 | 182 | experiment_name = sys.argv[1] 183 | if experiment_name == "knn-recall-correct": 184 | KNN_recall_experiments( 185 | mode="only-correct") 186 | if experiment_name == "knn-recall-incorrect": 187 | KNN_recall_experiments( 188 | mode="only-incorrect") 189 | 190 | if experiment_name == "s-test-correct": 191 | s_test_speed_quality_tradeoff_experiments( 192 | mode="only-correct") 193 | if experiment_name == "s-test-incorrect": 194 | s_test_speed_quality_tradeoff_experiments( 195 | mode="only-incorrect") 196 | 197 | if experiment_name == "retraining-full": 198 | MNLI_retraining_experiments( 199 | mode="full") 200 | 201 | if experiment_name == "retraining-random": 202 | MNLI_retraining_experiments( 203 | mode="random") 204 | 205 | if experiment_name == "retraining-KNN-1000": 206 | MNLI_retraining_experiments( 207 | mode="KNN-1000") 208 | 209 | if experiment_name == "retraining-KNN-10000": 210 | MNLI_retraining_experiments( 211 | mode="KNN-10000") 212 | 213 | if experiment_name == "hans-augmentation": 214 | hans_augmentation_experiments() 215 | 216 | if experiment_name == "amazon-augmentation": 217 | amazon_augmentation_experiments() 218 | 219 | if experiment_name == "anli-augmentation": 220 | anli_augmentation_experiments() 221 | 222 | if experiment_name == "imitator": 223 | imitator_experiments() 224 | 225 | # raise ValueError(f"Unknown Experiment Name: {experiment_name}") 226 | -------------------------------------------------------------------------------- /run_glue.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Copyright (c) 2020, salesforce.com, inc. 18 | # All rights reserved. 19 | # SPDX-License-Identifier: BSD-3-Clause 20 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 21 | 22 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa).""" 23 | 24 | 25 | import dataclasses 26 | import logging 27 | import os 28 | import sys 29 | from dataclasses import dataclass, field 30 | from typing import Callable, Dict, Optional 31 | 32 | import numpy as np 33 | 34 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction 35 | from transformers import GlueDataTrainingArguments as DataTrainingArguments 36 | from transformers import ( 37 | HfArgumentParser, 38 | Trainer, 39 | TrainingArguments, 40 | # glue_compute_metrics, 41 | # glue_output_modes, 42 | # glue_tasks_num_labels, 43 | set_seed, 44 | ) 45 | 46 | from influence_utils import glue_utils 47 | from experiments.data_utils import ( 48 | CustomGlueDataset, 49 | glue_compute_metrics, 50 | glue_output_modes, 51 | glue_tasks_num_labels) 52 | 53 | 54 | logger = logging.getLogger(__name__) 55 | 56 | 57 | @dataclass 58 | class ModelArguments: 59 | """ 60 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 61 | """ 62 | 63 | model_name_or_path: str = field( 64 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 65 | ) 66 | config_name: Optional[str] = field( 67 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 68 | ) 69 | tokenizer_name: Optional[str] = field( 70 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 71 | ) 72 | cache_dir: Optional[str] = field( 73 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 74 | ) 75 | 76 | 77 | def main(): 78 | # See all possible arguments in src/transformers/training_args.py 79 | # or by passing the --help flag to this script. 80 | # We now keep distinct sets of args, for a cleaner separation of concerns. 81 | 82 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 83 | 84 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 85 | # If we pass only one argument to the script and it's the path to a json file, 86 | # let's parse it to get our arguments. 87 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 88 | else: 89 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 90 | 91 | if ( 92 | os.path.exists(training_args.output_dir) 93 | and os.listdir(training_args.output_dir) 94 | and training_args.do_train 95 | and not training_args.overwrite_output_dir 96 | ): 97 | raise ValueError( 98 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 99 | ) 100 | 101 | # Setup logging 102 | logging.basicConfig( 103 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 104 | datefmt="%m/%d/%Y %H:%M:%S", 105 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 106 | ) 107 | logger.warning( 108 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 109 | training_args.local_rank, 110 | training_args.device, 111 | training_args.n_gpu, 112 | bool(training_args.local_rank != -1), 113 | training_args.fp16, 114 | ) 115 | logger.info("Training/evaluation parameters %s", training_args) 116 | 117 | # Set seed 118 | set_seed(training_args.seed) 119 | 120 | try: 121 | num_labels = glue_tasks_num_labels[data_args.task_name] 122 | output_mode = glue_output_modes[data_args.task_name] 123 | except KeyError: 124 | raise ValueError("Task not found: %s" % (data_args.task_name)) 125 | 126 | # Load pretrained model and tokenizer 127 | # 128 | # Distributed training: 129 | # The .from_pretrained methods guarantee that only one local process can concurrently 130 | # download model & vocab. 131 | 132 | config = AutoConfig.from_pretrained( 133 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 134 | num_labels=num_labels, 135 | finetuning_task=data_args.task_name, 136 | cache_dir=model_args.cache_dir, 137 | ) 138 | tokenizer = AutoTokenizer.from_pretrained( 139 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 140 | cache_dir=model_args.cache_dir, 141 | ) 142 | model = AutoModelForSequenceClassification.from_pretrained( 143 | model_args.model_name_or_path, 144 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 145 | config=config, 146 | cache_dir=model_args.cache_dir, 147 | ) 148 | 149 | # Get datasets 150 | train_dataset = ( 151 | CustomGlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) 152 | if training_args.do_train 153 | else None 154 | ) 155 | eval_dataset = ( 156 | CustomGlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir) 157 | if training_args.do_eval 158 | else None 159 | ) 160 | test_dataset = ( 161 | CustomGlueDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir) 162 | if training_args.do_predict 163 | else None 164 | ) 165 | 166 | def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: 167 | def compute_metrics_fn(p: EvalPrediction): 168 | if output_mode == "classification": 169 | preds = np.argmax(p.predictions, axis=1) 170 | elif output_mode == "regression": 171 | preds = np.squeeze(p.predictions) 172 | return glue_compute_metrics(task_name, preds, p.label_ids) 173 | 174 | return compute_metrics_fn 175 | 176 | # Freeze parameters except the classifier layer 177 | # NOTE: Even though weight decay is set w.r.t. all parameters 178 | # this is fine because the gradients of frozen parameters will 179 | # be zero, and they will not be updated. 180 | glue_utils.freeze_BERT_parameters(model) 181 | 182 | # Initialize our Trainer 183 | trainer = Trainer( 184 | model=model, 185 | args=training_args, 186 | train_dataset=train_dataset, 187 | eval_dataset=eval_dataset, 188 | compute_metrics=build_compute_metrics_fn(data_args.task_name), 189 | ) 190 | 191 | # Training 192 | if training_args.do_train: 193 | trainer.train( 194 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 195 | ) 196 | trainer.save_model() 197 | # For convenience, we also re-save the tokenizer to the same directory, 198 | # so that you can share your model easily on huggingface.co/models =) 199 | if trainer.is_world_master(): 200 | tokenizer.save_pretrained(training_args.output_dir) 201 | 202 | # Evaluation 203 | eval_results = {} 204 | if training_args.do_eval: 205 | logger.info("*** Evaluate ***") 206 | 207 | # Loop to handle MNLI double evaluation (matched, mis-matched) 208 | eval_datasets = [eval_dataset] 209 | if data_args.task_name == "mnli": 210 | mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") 211 | eval_datasets.append( 212 | CustomGlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir) 213 | ) 214 | 215 | if data_args.task_name == "mnli-2": 216 | mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-2-mm") 217 | eval_datasets.append( 218 | CustomGlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir) 219 | ) 220 | 221 | for eval_dataset in eval_datasets: 222 | trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name) 223 | eval_result = trainer.evaluate(eval_dataset=eval_dataset) 224 | 225 | output_eval_file = os.path.join( 226 | training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" 227 | ) 228 | if trainer.is_world_master(): 229 | with open(output_eval_file, "w") as writer: 230 | logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) 231 | for key, value in eval_result.items(): 232 | logger.info(" %s = %s", key, value) 233 | writer.write("%s = %s\n" % (key, value)) 234 | 235 | eval_results.update(eval_result) 236 | 237 | if training_args.do_predict: 238 | logging.info("*** Test ***") 239 | test_datasets = [test_dataset] 240 | if data_args.task_name == "mnli": 241 | mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm") 242 | test_datasets.append( 243 | CustomGlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir) 244 | ) 245 | 246 | if data_args.task_name == "mnli-2": 247 | mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-2-mm") 248 | test_datasets.append( 249 | CustomGlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir) 250 | ) 251 | 252 | for test_dataset in test_datasets: 253 | predictions = trainer.predict(test_dataset=test_dataset).predictions 254 | if output_mode == "classification": 255 | predictions = np.argmax(predictions, axis=1) 256 | 257 | output_test_file = os.path.join( 258 | training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt" 259 | ) 260 | if trainer.is_world_master(): 261 | with open(output_test_file, "w") as writer: 262 | logger.info("***** Test results {} *****".format(test_dataset.args.task_name)) 263 | writer.write("index\tprediction\n") 264 | for index, item in enumerate(predictions): 265 | if output_mode == "regression": 266 | writer.write("%d\t%3.3f\n" % (index, item)) 267 | else: 268 | item = test_dataset.get_labels()[item] 269 | writer.write("%d\t%s\n" % (index, item)) 270 | return eval_results 271 | 272 | 273 | def _mp_fn(index): 274 | # For xla_spawn (TPUs) 275 | main() 276 | 277 | 278 | if __name__ == "__main__": 279 | main() -------------------------------------------------------------------------------- /scripts/run_Amazon.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | export BASE_DIR_DIR=/export/home/hguo/Data/WILDS 7 | 8 | python -m torch.distributed.launch \ 9 | --nproc_per_node 4 run_glue.py \ 10 | --model_name_or_path bert-base-cased \ 11 | --task_name amazon \ 12 | --do_train \ 13 | --do_eval \ 14 | --data_dir $BASE_DIR_DIR/Amazon/ \ 15 | --max_seq_length 128 \ 16 | --per_device_train_batch_size 128 \ 17 | --learning_rate 2e-5 \ 18 | --num_train_epochs 10.0 \ 19 | --output_dir output_dir \ 20 | --weight_decay 0.005 \ 21 | --save_steps 5000 \ 22 | --logging_steps 100 \ 23 | --save_total_limit 1 24 | -------------------------------------------------------------------------------- /scripts/run_HANS.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | export BASE_DIR_DIR=/export/home/Data 7 | 8 | python -m torch.distributed.launch \ 9 | --nproc_per_node 4 run_glue.py \ 10 | --model_name_or_path bert-base-cased \ 11 | --task_name hans \ 12 | --do_train \ 13 | --do_eval \ 14 | --data_dir $BASE_DIR_DIR/HANS/ \ 15 | --max_seq_length 128 \ 16 | --per_device_train_batch_size 128 \ 17 | --learning_rate 2e-5 \ 18 | --num_train_epochs 10.0 \ 19 | --output_dir output_dir \ 20 | --weight_decay 0.005 \ 21 | --save_steps 5000 \ 22 | --logging_steps 100 \ 23 | --save_total_limit 1 24 | -------------------------------------------------------------------------------- /scripts/run_MNLI.20200913.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | # https://stackoverflow.com/questions/2107945/how-to-loop-over-directories-in-linux 7 | data_dir=$1 8 | output_dir=$2 9 | 10 | echo "Using Data From ${data_dir}, Saving to ${output_dir}" 11 | 12 | python -m torch.distributed.launch \ 13 | --nproc_per_node 4 run_glue.py \ 14 | --model_name_or_path bert-base-cased \ 15 | --task_name mnli \ 16 | --do_train \ 17 | --do_eval \ 18 | --data_dir ${data_dir} \ 19 | --max_seq_length 128 \ 20 | --per_device_train_batch_size 128 \ 21 | --learning_rate 2e-5 \ 22 | --num_train_epochs 10.0 \ 23 | --output_dir ${output_dir} \ 24 | --weight_decay 0.005 \ 25 | --save_steps 5000 \ 26 | --logging_steps 100 \ 27 | --save_total_limit 1 -------------------------------------------------------------------------------- /scripts/run_MNLI.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | export GLUE_DIR=/export/home/Data/Glue 7 | 8 | python -m torch.distributed.launch \ 9 | --nproc_per_node 4 run_glue.py \ 10 | --model_name_or_path bert-base-cased \ 11 | --task_name mnli \ 12 | --do_train \ 13 | --do_eval \ 14 | --data_dir $GLUE_DIR/MNLI/ \ 15 | --max_seq_length 128 \ 16 | --per_device_train_batch_size 128 \ 17 | --learning_rate 2e-5 \ 18 | --num_train_epochs 10.0 \ 19 | --output_dir output_dir \ 20 | --weight_decay 0.005 \ 21 | --save_steps 5000 \ 22 | --logging_steps 100 \ 23 | --save_total_limit 1 24 | -------------------------------------------------------------------------------- /scripts/run_MNLI_2label.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, salesforce.com, inc. 2 | # All rights reserved. 3 | # SPDX-License-Identifier: BSD-3-Clause 4 | # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 5 | 6 | export GLUE_DIR=/export/home/Data/Glue 7 | 8 | python -m torch.distributed.launch \ 9 | --nproc_per_node 4 run_glue.py \ 10 | --model_name_or_path bert-base-cased \ 11 | --task_name mnli-2 \ 12 | --do_train \ 13 | --do_eval \ 14 | --data_dir $GLUE_DIR/MNLI/ \ 15 | --max_seq_length 128 \ 16 | --per_device_train_batch_size 128 \ 17 | --learning_rate 2e-5 \ 18 | --num_train_epochs 10.0 \ 19 | --output_dir output_dir \ 20 | --weight_decay 0.005 \ 21 | --save_steps 5000 \ 22 | --logging_steps 100 \ 23 | --save_total_limit 1 24 | --------------------------------------------------------------------------------