├── .gitignore ├── README.md ├── configs ├── config.yaml ├── data │ ├── amazon.yaml │ ├── cose.yaml │ ├── default.yaml │ ├── esnli.yaml │ ├── irony.yaml │ ├── movies.yaml │ ├── multirc.yaml │ ├── olid.yaml │ ├── sst.yaml │ ├── stf.yaml │ └── yelp.yaml ├── experiment │ ├── fixed_lm_expl.yaml │ └── iter_lm.yaml ├── hsearch │ └── lr.yaml ├── hydra │ └── default.yaml ├── logger │ └── neptune.yaml ├── model │ ├── a2r.yaml │ ├── expl_reg.yaml │ ├── fresh.yaml │ ├── lm.yaml │ ├── optimizer │ │ ├── adamw.yaml │ │ └── hf_adamw.yaml │ └── scheduler │ │ ├── fixed.yaml │ │ └── linear_with_warmup.yaml ├── setup │ └── a100.yaml ├── trainer │ └── defaults.yaml └── training │ ├── base.yaml │ ├── evaluate.yaml │ └── finetune.yaml ├── main.py ├── requirements.txt ├── scripts └── build_dataset.py └── src ├── __init__.py ├── data ├── __init__.py └── data.py ├── model ├── __init__.py ├── base_model.py ├── lm.py └── mlp.py ├── run.py └── utils ├── __init__.py ├── callbacks.py ├── conf.py ├── data.py ├── eraser ├── data_utils.py ├── metrics.py └── utils.py ├── expl.py ├── logging.py ├── losses.py ├── metrics.py ├── misc.py └── optim.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UNIREX: A Unified Learning Framework for Language Model Rationale Extraction (ICML 2022) 2 | 3 | This is the official PyTorch repo for [UNIREX](https://arxiv.org/abs/2112.08802), a learning framework for jointly optimizing rationale extractors w.r.t. faithfulness, plausibility, and task performance. 4 | 5 | ``` 6 | UNIREX: A Unified Learning Framework for Language Model Rationale Extraction 7 | Aaron Chan, Maziar Sanjabi, Lambert Mathias, Liang Tan, Shaoliang Nie, Xiaochang Peng, Xiang Ren, Hamed Firooz 8 | ICML 2022 9 | ``` 10 | 11 | The majority of the UNIREX project is licensed under CC-BY-NC. However, some portions of the project are available under separate license terms, as indicated below: 12 | - The [ERASER benchmark](https://github.com/jayded/eraserbenchmark) is licensed under the Apache License 2.0. 13 | 14 | If UNIREX is helpful for your research, please consider citing our ICML paper: 15 | 16 | ``` 17 | @inproceedings{chan2022unirex, 18 | title={Unirex: A unified learning framework for language model rationale extraction}, 19 | author={Chan, Aaron and Sanjabi, Maziar and Mathias, Lambert and Tan, Liang and Nie, Shaoliang and Peng, Xiaochang and Ren, Xiang and Firooz, Hamed}, 20 | booktitle={International Conference on Machine Learning}, 21 | pages={2867--2889}, 22 | year={2022}, 23 | organization={PMLR} 24 | } 25 | ``` 26 | 27 | ## Basics 28 | 29 | ### Neptune 30 | Before running the code, you need to complete the following steps: 31 | 1. Create a [Neptune](https://neptune.ai/) account and project. 32 | 2. 
Edit the [project name](https://github.com/aarzchan/UNIREX/blob/main/configs/logger/neptune.yaml#L12), [local username](https://github.com/aarzchan/UNIREX/blob/main/src/utils/logging.py#L11), and [Neptune API token](https://github.com/aarzchan/UNIREX/blob/main/src/utils/logging.py#L11) fields in the code. 33 | 34 | 35 | ### Multirun 36 | Use Hydra's multirun flag (`-m`) to run a grid search over different configs. 37 | ``` 38 | python main.py -m \ 39 | data=sst,stf \ 40 | seed=0,1,2,3,4,5 \ 41 | ``` 42 | 43 | ### Evaluate checkpoint 44 | This command evaluates a checkpoint on the train, dev, and test sets. 45 | ``` 46 | python main.py \ 47 | training=evaluate \ 48 | training.ckpt_path=/path/to/ckpt \ 49 | training.eval_splits=train,dev,test \ 50 | ``` 51 | 52 | ### Finetune checkpoint 53 | ``` 54 | python main.py \ 55 | training=finetune \ 56 | training.ckpt_path=/path/to/ckpt \ 57 | ``` 58 | 59 | ### Offline Mode 60 | In offline mode, results are not logged to Neptune. 61 | ``` 62 | python main.py logger.offline=True 63 | ``` 64 | 65 | ### Debug Mode 66 | In debug mode, results are not logged to Neptune, and we only train/evaluate for a limited number of batches and/or epochs. 67 | ``` 68 | python main.py debug=True 69 | ``` 70 | 71 | ### Hydra Working Directory 72 | 73 | Hydra will change the working directory to the path specified in `configs/hydra/default.yaml`. Therefore, if you save a file to the path `'./file.txt'`, it will actually be saved somewhere like `logs/runs/xxxx/file.txt`. This is helpful when you want to version control your saved files, but not if you want to save to a global directory. There are two ways to get the "actual" working directory: 74 | 75 | 1. Call `hydra.utils.get_original_cwd()`, e.g., `os.path.join(hydra.utils.get_original_cwd(), 'file.txt')`. 76 | 2. Use `cfg.work_dir`. To use this in a config, you can write something like `"${data_dir}/${.dataset}/${model.arch}/"`. 77 | 78 | 79 | ### Config Key 80 | 81 | - `work_dir`: the current working directory (where `src/` is) 82 | 83 | - `data_dir`: where the data folder is 84 | 85 | - `log_dir`: where the log folder is (runs & multiruns) 86 | 87 | - `root_dir`: where the saved ckpt & hydra config are 88 | 89 | 90 | --- 91 | 92 | 93 | ## Example Commands 94 | 95 | Here, we assume the following: 96 | - The `data_dir` is `../data`, which means `data_dir=${work_dir}/../data`. 97 | - The dataset is `sst`. 98 | 99 | ### 1. Build dataset 100 | The commands below build the pre-processed datasets, which are saved as pickle files. The model architecture is specified so that the correct tokenizer is used for pre-processing. 101 | 102 | ``` 103 | python scripts/build_dataset.py --data_dir ../data \ 104 | --dataset sst --arch google/bigbird-roberta-base --split train 105 | 106 | python scripts/build_dataset.py --data_dir ../data \ 107 | --dataset sst --arch google/bigbird-roberta-base --split dev 108 | 109 | python scripts/build_dataset.py --data_dir ../data \ 110 | --dataset sst --arch google/bigbird-roberta-base --split test 111 | 112 | ``` 113 | 114 | If the dataset is very large, you have the option to subsample part of the dataset for smaller-scale experiments. For example, the command below builds a train set with only 1000 train examples (sampled with seed 0). 115 | ``` 116 | python scripts/build_dataset.py --data_dir ../data \ 117 | --dataset sst --arch google/bigbird-roberta-base --split train \ 118 | --num_train 1000 --num_train_seed 0 119 | ``` 120 |
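After building a split, it can be useful to sanity-check the saved pickles before training. The sketch below is not part of this repo; it simply mirrors the directory layout (`{data_dir}/{dataset}/{arch}/{split}/{key}.pkl`) and a subset of the `dataset_dict` keys used in `scripts/build_dataset.py`, assuming the default (non-subsampled) build.

```python
# Minimal sanity check for the pickles written by scripts/build_dataset.py (sketch, not repo code).
import os
from pickle5 import pickle  # same pickle backport the build script uses

data_path = os.path.join('../data', 'sst', 'google/bigbird-roberta-base', 'train')
keys = ['item_idx', 'input_ids', 'attention_mask', 'rationale', 'has_rationale', 'label']

# Load each per-key pickle and check that all keys cover the same number of examples.
dataset = {}
for key in keys:
    with open(os.path.join(data_path, f'{key}.pkl'), 'rb') as f:
        dataset[key] = pickle.load(f)

num_examples = len(dataset['label'])
assert all(len(v) == num_examples for v in dataset.values())
print(f"{num_examples} examples; first example has {len(dataset['input_ids'][0])} token ids")
```

121 | ### 2.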
Train Task LM 122 | 123 | The command below is the most basic way to run `main.py` and will train the Task LM without any explanation regularization (`model=lm`). 124 | 125 | However, since all models need to be evaluated w.r.t. explainability metrics, we need to specify an attribution algorithm for computing post-hoc explanations. This is done by setting `model.explainer_type=attr_algo` to specify that we are using an attribution algorithm based explainer, `model.attr_algo` to specify the attribution algorithm, and `model.attr_pooling` to specify the attribution pooler. 126 | ``` 127 | python main.py -m \ 128 | data=sst \ 129 | model=lm \ 130 | model.explainer_type=attr_algo \ 131 | model.attr_algo=input-x-gradient \ 132 | model.attr_pooling=sum \ 133 | model.optimizer.lr=2e-5 \ 134 | setup.train_batch_size=32 \ 135 | setup.accumulate_grad_batches=1 \ 136 | setup.eff_train_batch_size=32 \ 137 | setup.eval_batch_size=32 \ 138 | setup.num_workers=3 \ 139 | seed=0,1,2 140 | ``` 141 | 142 | By default, checkpoints will not be saved, so you need to set `save_checkpoint=True` if you want to save the best checkpoint. 143 | ``` 144 | python main.py -m \ 145 | save_checkpoint=True \ 146 | data=sst \ 147 | model=lm \ 148 | model.explainer_type=attr_algo \ 149 | model.attr_algo=input-x-gradient \ 150 | model.attr_pooling=sum \ 151 | model.optimizer.lr=2e-5 \ 152 | setup.train_batch_size=32 \ 153 | setup.accumulate_grad_batches=1 \ 154 | setup.eff_train_batch_size=32 \ 155 | setup.eval_batch_size=32 \ 156 | setup.num_workers=3 \ 157 | seed=0,1,2 158 | ``` 159 | 160 | ### 3. Train Task LM with explanation regularization 161 | This repo implements a number of different methods for training the Task LM with explanation regularization. These methods aim to improve rationale faithfulness, plausibility, or both. Below are commands for running each method. 162 | 163 | 164 | **Task LM + SGT** 165 | ``` 166 | python main.py -m \ 167 | save_checkpoint=True \ 168 | data=sst \ 169 | model=expl_reg \ 170 | model.explainer_type=attr_algo \ 171 | model.attr_algo=input-x-gradient \ 172 | model.attr_pooling=sum \ 173 | model.task_wt=1.0 \ 174 | model.comp_wt=0.0 \ 175 | model.suff_wt=0.5 \ 176 | model.suff_criterion=kldiv \ 177 | model.plaus_wt=0.0 \ 178 | model.optimizer.lr=2e-5 \ 179 | setup.train_batch_size=32 \ 180 | setup.accumulate_grad_batches=1 \ 181 | setup.eff_train_batch_size=32 \ 182 | setup.eval_batch_size=32 \ 183 | setup.num_workers=3 \ 184 | seed=0,1,2 185 | ``` 186 | 187 | **Task LM + FRESH** 188 | 189 | Train Task LM, saving the best checkpoint. 190 | 191 | ``` 192 | python main.py -m \ 193 | save_checkpoint=True \ 194 | data=sst \ 195 | model=lm \ 196 | model.explainer_type=attr_algo \ 197 | model.attr_algo=input-x-gradient \ 198 | model.attr_pooling=sum \ 199 | model.optimizer.lr=2e-5 \ 200 | setup.train_batch_size=32 \ 201 | setup.accumulate_grad_batches=1 \ 202 | setup.eff_train_batch_size=32 \ 203 | setup.eval_batch_size=32 \ 204 | setup.num_workers=3 \ 205 | seed=0 206 | ``` 207 | 208 | Load Task LM checkpoint, then save its attributions to file. 209 | Here, we assume the checkpoint path is based on a Neptune experiment ID with the form `ER-XXXX`, where `ER` is the Neptune project key. 
210 | ``` 211 | python main.py -m \ 212 | logger.offline=True \ 213 | training=evaluate \ 214 | training.ckpt_path="'ER-XXXX/checkpoints/epoch=X-step=XXXX.ckpt'" \ 215 | training.eval_splits="'train,dev,test'" \ 216 | data=sst \ 217 | model=lm \ 218 | model.explainer_type=attr_algo \ 219 | model.attr_algo=input-x-gradient \ 220 | model.attr_pooling=sum \ 221 | model.save_outputs=True \ 222 | model.exp_id="ER-XXXX" \ 223 | setup.train_batch_size=32 \ 224 | setup.accumulate_grad_batches=1 \ 225 | setup.eff_train_batch_size=32 \ 226 | setup.eval_batch_size=32 \ 227 | setup.num_workers=3 228 | ``` 229 | 230 | Load attributions, then retrain Task LM, using as input only the top-k% tokens w.r.t. the attributions. 231 | ``` 232 | python main.py -m \ 233 | save_checkpoint=True \ 234 | data=sst \ 235 | data.fresh_exp_id=ER-XXXX \ 236 | data.fresh_attr_algo=input-x-gradient \ 237 | data.fresh_topk=10 \ 238 | model=fresh \ 239 | model.optimizer.lr=2e-5 \ 240 | setup.train_batch_size=32 \ 241 | setup.accumulate_grad_batches=1 \ 242 | setup.eff_train_batch_size=32 \ 243 | setup.eval_batch_size=32 \ 244 | setup.num_workers=3 \ 245 | seed=0 246 | ``` 247 | 248 | **Task LM + DLM (plaus)** 249 | 250 | Use linear layer as DLM head. 251 | ``` 252 | python main.py -m \ 253 | save_checkpoint=True \ 254 | data=sst \ 255 | model=expl_reg \ 256 | model.explainer_type=lm \ 257 | model.expl_head_type=linear \ 258 | model.task_wt=1.0 \ 259 | model.comp_wt=0.0 \ 260 | model.suff_wt=0.0 \ 261 | model.plaus_wt=0.5 \ 262 | model.optimizer.lr=2e-5 \ 263 | setup.train_batch_size=32 \ 264 | setup.accumulate_grad_batches=1 \ 265 | setup.eff_train_batch_size=32 \ 266 | setup.eval_batch_size=32 \ 267 | setup.num_workers=3 \ 268 | seed=0,1,2 269 | ``` 270 | 271 | Use MLP as DLM head. 272 | ``` 273 | python main.py -m \ 274 | save_checkpoint=True \ 275 | data=sst \ 276 | model=expl_reg \ 277 | model.explainer_type=lm \ 278 | model.expl_head_type=mlp \ 279 | model.expl_head_mlp_hidden_dim=2048 \ 280 | model.expl_head_mlp_hidden_layers=2 \ 281 | model.task_wt=1.0 \ 282 | model.comp_wt=0.0 \ 283 | model.suff_wt=0.0 \ 284 | model.plaus_wt=0.5 \ 285 | model.optimizer.lr=2e-5 \ 286 | setup.train_batch_size=32 \ 287 | setup.accumulate_grad_batches=1 \ 288 | setup.eff_train_batch_size=32 \ 289 | setup.eval_batch_size=32 \ 290 | setup.num_workers=3 \ 291 | seed=0,1,2 292 | ``` 293 | 294 | **Task LM + SLM (plaus)** 295 | 296 | Use linear layer as SLM head. 297 | ``` 298 | python main.py -m \ 299 | save_checkpoint=True \ 300 | data=sst \ 301 | model=expl_reg \ 302 | model.explainer_type=self_lm \ 303 | model.expl_head_type=linear \ 304 | model.task_wt=1.0 \ 305 | model.comp_wt=0.0 \ 306 | model.suff_wt=0.0 \ 307 | model.plaus_wt=0.5 \ 308 | model.optimizer.lr=2e-5 \ 309 | setup.train_batch_size=32 \ 310 | setup.accumulate_grad_batches=1 \ 311 | setup.eff_train_batch_size=32 \ 312 | setup.eval_batch_size=32 \ 313 | setup.num_workers=3 \ 314 | seed=0,1,2 315 | ``` 316 | 317 | Use MLP as SLM head. 
318 | ``` 319 | python main.py -m \ 320 | save_checkpoint=True \ 321 | data=sst \ 322 | model=expl_reg \ 323 | model.explainer_type=self_lm \ 324 | model.expl_head_type=mlp \ 325 | model.expl_head_mlp_hidden_dim=2048 \ 326 | model.expl_head_mlp_hidden_layers=2 \ 327 | model.task_wt=1.0 \ 328 | model.comp_wt=0.0 \ 329 | model.suff_wt=0.0 \ 330 | model.plaus_wt=0.5 \ 331 | model.optimizer.lr=2e-5 \ 332 | setup.train_batch_size=32 \ 333 | setup.accumulate_grad_batches=1 \ 334 | setup.eff_train_batch_size=32 \ 335 | setup.eval_batch_size=32 \ 336 | setup.num_workers=3 \ 337 | seed=0,1,2 338 | ``` 339 | 340 | **Task LM + AA-Sum (comp/suff)** 341 | 342 | Using Input*Grad attribution algorithm. 343 | ``` 344 | python main.py -m \ 345 | save_checkpoint=True \ 346 | data=sst \ 347 | model=expl_reg \ 348 | model.explainer_type=attr_algo \ 349 | model.attr_algo=input-x-gradient \ 350 | model.attr_pooling=sum \ 351 | model.task_wt=1.0 \ 352 | model.comp_wt=0.5 \ 353 | model.suff_wt=0.5 \ 354 | model.plaus_wt=0.0 \ 355 | model.optimizer.lr=2e-5 \ 356 | setup.train_batch_size=32 \ 357 | setup.eval_batch_size=32 \ 358 | setup.num_workers=3 \ 359 | seed=0,1,2 360 | ``` 361 | 362 | Using simple baseline for attribution algorithm. 363 | ``` 364 | python main.py -m \ 365 | save_checkpoint=True \ 366 | data=sst \ 367 | model=expl_reg \ 368 | model.explainer_type=attr_algo \ 369 | model.attr_algo={gold, inv, rand} \ 370 | model.attr_pooling=sum \ 371 | model.task_wt=1.0 \ 372 | model.comp_wt=0.5 \ 373 | model.suff_wt=0.5 \ 374 | model.plaus_wt=0.0 \ 375 | model.optimizer.lr=2e-5 \ 376 | setup.train_batch_size=32 \ 377 | setup.accumulate_grad_batches=1 \ 378 | setup.eff_train_batch_size=32 \ 379 | setup.eval_batch_size=32 \ 380 | setup.num_workers=3 \ 381 | seed=0,1,2 382 | ``` 383 | 384 | 385 | **Task LM + AA-Sum (comp/suff/plaus)** 386 | ``` 387 | python main.py -m \ 388 | save_checkpoint=True \ 389 | data=sst \ 390 | model=expl_reg \ 391 | model.explainer_type=attr_algo \ 392 | model.attr_algo=input-x-gradient \ 393 | model.attr_pooling=sum \ 394 | model.task_wt=1.0 \ 395 | model.comp_wt=0.5 \ 396 | model.suff_wt=0.5 \ 397 | model.plaus_wt=0.5 \ 398 | model.optimizer.lr=2e-5 \ 399 | setup.train_batch_size=32 \ 400 | setup.accumulate_grad_batches=1 \ 401 | setup.eff_train_batch_size=32 \ 402 | setup.eval_batch_size=32 \ 403 | setup.num_workers=3 \ 404 | seed=0,1,2 405 | ``` 406 | 407 | ### **Task LM + AA-MLP (comp/suff/plaus)** 408 | ``` 409 | python main.py -m \ 410 | save_checkpoint=True \ 411 | data=sst \ 412 | model=expl_reg \ 413 | model.explainer_type=attr_algo \ 414 | model.attr_algo=input-x-gradient \ 415 | model.attr_pooling=mlp \ 416 | model.attr_mlp_hidden_dim=2048 \ 417 | model.attr_mlp_hidden_layers=2 \ 418 | model.task_wt=1.0 \ 419 | model.comp_wt=0.5 \ 420 | model.suff_wt=0.5 \ 421 | model.plaus_wt=0.5 \ 422 | model.optimizer.lr=2e-5 \ 423 | setup.train_batch_size=32 \ 424 | setup.accumulate_grad_batches=1 \ 425 | setup.eff_train_batch_size=32 \ 426 | setup.eval_batch_size=32 \ 427 | setup.num_workers=3 \ 428 | seed=0,1,2 429 | ``` 430 | 431 | **Task LM + DLM (comp/suff/plaus)** 432 | ``` 433 | python main.py -m \ 434 | save_checkpoint=True \ 435 | data=sst \ 436 | model=expl_reg \ 437 | model.explainer_type=lm \ 438 | model.expl_head_type=linear \ 439 | model.task_wt=1.0 \ 440 | model.comp_wt=0.5 \ 441 | model.suff_wt=0.5 \ 442 | model.plaus_wt=0.5 \ 443 | model.optimizer.lr=2e-5 \ 444 | setup.train_batch_size=32 \ 445 | setup.accumulate_grad_batches=1 \ 446 | setup.eff_train_batch_size=32 \ 
447 | setup.eval_batch_size=32 \ 448 | setup.num_workers=3 \ 449 | seed=0,1,2 450 | ``` 451 | 452 | Use MLP as DLM head. 453 | ``` 454 | python main.py -m \ 455 | save_checkpoint=True \ 456 | data=sst \ 457 | model=expl_reg \ 458 | model.explainer_type=lm \ 459 | model.expl_head_type=mlp \ 460 | model.expl_head_mlp_hidden_dim=2048 \ 461 | model.expl_head_mlp_hidden_layers=2 \ 462 | model.task_wt=1.0 \ 463 | model.comp_wt=0.5 \ 464 | model.suff_wt=0.5 \ 465 | model.plaus_wt=0.5 \ 466 | model.optimizer.lr=2e-5 \ 467 | setup.train_batch_size=32 \ 468 | setup.accumulate_grad_batches=1 \ 469 | setup.eff_train_batch_size=32 \ 470 | setup.eval_batch_size=32 \ 471 | setup.num_workers=3 \ 472 | seed=0,1,2 473 | ``` 474 | 475 | **Task LM + SLM (comp/suff/plaus)** 476 | ``` 477 | python main.py -m \ 478 | save_checkpoint=True \ 479 | data=sst \ 480 | model=expl_reg \ 481 | model.explainer_type=self_lm \ 482 | model.expl_head_type=linear \ 483 | model.task_wt=1.0 \ 484 | model.comp_wt=0.5 \ 485 | model.suff_wt=0.5 \ 486 | model.plaus_wt=0.5 \ 487 | model.optimizer.lr=2e-5 \ 488 | setup.train_batch_size=32 \ 489 | setup.accumulate_grad_batches=1 \ 490 | setup.eff_train_batch_size=32 \ 491 | setup.eval_batch_size=32 \ 492 | setup.num_workers=3 \ 493 | seed=0,1,2 494 | ``` 495 | 496 | Use MLP as SLM head. 497 | ``` 498 | python main.py -m \ 499 | save_checkpoint=True \ 500 | data=sst \ 501 | model=expl_reg \ 502 | model.explainer_type=self_lm \ 503 | model.expl_head_type=mlp \ 504 | model.expl_head_mlp_hidden_dim=2048 \ 505 | model.expl_head_mlp_hidden_layers=2 \ 506 | model.task_wt=1.0 \ 507 | model.comp_wt=0.5 \ 508 | model.suff_wt=0.5 \ 509 | model.plaus_wt=0.5 \ 510 | model.optimizer.lr=2e-5 \ 511 | setup.train_batch_size=32 \ 512 | setup.accumulate_grad_batches=1 \ 513 | setup.eff_train_batch_size=32 \ 514 | setup.eval_batch_size=32 \ 515 | setup.num_workers=3 \ 516 | seed=0,1,2 517 | ``` 518 | 519 | **Task LM + L2E** 520 | ``` 521 | python main.py -m \ 522 | save_checkpoint=True \ 523 | data=sst \ 524 | data.l2e_exp_id=ER-XXXX \ 525 | data.l2e_attr_algo=integrated-gradients \ 526 | model=expl_reg \ 527 | model.explainer_type=lm \ 528 | model.expl_head_type=linear \ 529 | model.task_wt=1.0 \ 530 | model.plaus_wt=0.5 \ 531 | model.l2e=True \ 532 | model.l2e_wt=0.5 \ 533 | model.optimizer.lr=2e-5 \ 534 | setup.train_batch_size=32 \ 535 | setup.accumulate_grad_batches=1 \ 536 | setup.eff_train_batch_size=32 \ 537 | setup.eval_batch_size=32 \ 538 | setup.num_workers=3 \ 539 | seed=0,1,2 540 | ``` 541 | 542 | **Task LM + A2R** 543 | ``` 544 | python main.py -m \ 545 | save_checkpoint=True \ 546 | data=sst \ 547 | model=a2r \ 548 | model.task_wt=1.0 \ 549 | model.plaus_wt=0.5 \ 550 | model.a2r_wt=0.5 \ 551 | model.optimizer.lr=2e-5 \ 552 | setup.train_batch_size=32 \ 553 | setup.accumulate_grad_batches=1 \ 554 | setup.eff_train_batch_size=32 \ 555 | setup.eval_batch_size=32 \ 556 | setup.num_workers=3 \ 557 | seed=0,1,2 558 | ``` 559 | 560 | **Measure Explainer Runtime** 561 | ``` 562 | python main.py -m \ 563 | training=evaluate \ 564 | training.ckpt_path="'save/ER-XXX/checkpoints/epoch=X-step=X.ckpt'" \ 565 | data=sst \ 566 | model=lm \ 567 | model.explainer_type=attr_algo \ 568 | model.attr_algo=integrated-gradients \ 569 | model.ig_steps=3 \ 570 | model.return_convergence_delta=True \ 571 | model.internal_batch_size=3 \ 572 | model.attr_pooling=sum \ 573 | model.measure_attrs_runtime=True \ 574 | setup.train_batch_size=1 \ 575 | setup.eff_train_batch_size=1 \ 576 | setup.eval_batch_size=1 \ 577 | 
setup.num_workers=3 \ 578 | seed=0,1,2 579 | ``` 580 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # default use fixed-lm without expl 2 | defaults: 3 | - model: lm 4 | - data: default 5 | - logger: neptune 6 | - setup: a100 7 | - hydra: default 8 | - training: base 9 | - trainer: defaults 10 | - override /hydra/job_logging: colorlog 11 | - override /hydra/hydra_logging: colorlog 12 | 13 | seed: 0 14 | debug: False 15 | 16 | work_dir: ${hydra:runtime.cwd} 17 | data_dir: '${work_dir}/../data' 18 | log_dir: '${work_dir}/../logs' 19 | save_dir: '${work_dir}/../save' 20 | 21 | save_checkpoint: False 22 | save_rand_checkpoint: False 23 | early_stopping: True 24 | -------------------------------------------------------------------------------- /configs/data/amazon.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: amazon 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/data/cose.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: cose 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: null 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/data/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: ??? 
4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/data/esnli.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: esnli 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/data/irony.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: irony 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/data/movies.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: movies 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/data/multirc.yaml: 
-------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: multirc 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/data/olid.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: olid 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/data/sst.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: sst 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/data/stf.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: stf 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 0.1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null 
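The `num_train` / `num_dev` / `num_test` fields in the data configs above correspond to the subsampling options of `scripts/build_dataset.py` (e.g., `--num_train 1000 --num_train_seed 0`), whose `save_dataset` writes subsampled splits as `{key}_{num_samples}_{seed}.pkl`. Assuming `src.data.data.DataModule` resolves those filenames from these fields (an assumption; the loader lives in `src/data/data.py`), a run on the 1000-example SST train subset built earlier might look like:
```
python main.py \
  data=sst \
  data.num_train=1000 \
  data.num_train_seed=0 \
  model=lm \
  model.explainer_type=attr_algo \
  model.attr_algo=input-x-gradient \
  model.attr_pooling=sum \
  seed=0
```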
-------------------------------------------------------------------------------- /configs/data/yelp.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.data.data.DataModule 2 | 3 | dataset: yelp 4 | 5 | num_workers: ${setup.num_workers} 6 | data_path: "${data_dir}/${.dataset}/${model.arch}/" 7 | train_batch_size: ${setup.train_batch_size} 8 | eval_batch_size: ${setup.eval_batch_size} 9 | eff_train_batch_size: ${setup.eff_train_batch_size} 10 | mode: 'max' 11 | 12 | num_train: null 13 | num_dev: null 14 | num_test: null 15 | num_train_seed: 0 16 | num_dev_seed: 0 17 | num_test_seed: 0 18 | 19 | pct_train_rationales: null 20 | pct_train_rationales_seed: 0 21 | train_rationales_batch_factor: 2.0 22 | 23 | neg_weight: 1 24 | 25 | fresh_exp_id: null 26 | fresh_attr_algo: null 27 | fresh_topk: null 28 | 29 | l2e_exp_id: null 30 | l2e_attr_algo: null -------------------------------------------------------------------------------- /configs/experiment/fixed_lm_expl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: lm 5 | - override /dataset: default 6 | 7 | 8 | model: 9 | attn_reg: True 10 | attn_reg_layers: last 11 | attr_algo: layer-integrated-gradients 12 | 13 | # should provide one 14 | dataset: 15 | dataset: ??? 16 | expl_path: ??? -------------------------------------------------------------------------------- /configs/experiment/iter_lm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: iterative 5 | - override /dataset: default 6 | 7 | 8 | model: 9 | attn_reg: True 10 | attn_reg_layers: last 11 | attr_algo: layer-integrated-gradients 12 | 13 | # should provide one 14 | dataset: 15 | dataset: ??? 16 | expl_path: ??? 
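`configs/hsearch/lr.yaml` below wires Hydra's Optuna sweeper to search over the learning rate and weight decay, maximizing `tune_metric` (`dev_acc_epoch`, presumably the value returned by `run()` in `main.py`). Since `hsearch` is not in the primary defaults list of `configs/config.yaml`, it has to be added on the command line; a hypothetical invocation (the `+hsearch=lr` syntax is an assumption, not documented in this repo) could be:
```
python main.py -m \
  +hsearch=lr \
  data=sst \
  model=lm \
  model.explainer_type=attr_algo \
  model.attr_algo=input-x-gradient \
  model.attr_pooling=sum
```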
-------------------------------------------------------------------------------- /configs/hsearch/lr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /hydra/sweeper: optuna 5 | 6 | tune_metric: "dev_acc_epoch" 7 | 8 | hydra: 9 | sweeper: 10 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 11 | storage: null 12 | study_name: null 13 | n_jobs: 1 14 | 15 | # 'minimize' or 'maximize' the objective 16 | direction: maximize 17 | 18 | # number of experiments that will be executed 19 | n_trials: 10 20 | 21 | # choose Optuna hyperparameter sampler 22 | # learn more here: https://optuna.readthedocs.io/en/stable/reference/samplers.html 23 | sampler: 24 | _target_: optuna.samplers.TPESampler 25 | seed: 12345 26 | consider_prior: true 27 | prior_weight: 1.0 28 | consider_magic_clip: true 29 | consider_endpoints: false 30 | n_startup_trials: 10 31 | n_ei_candidates: 24 32 | multivariate: false 33 | warn_independent_sampling: true 34 | 35 | # define range of hyperparameters 36 | search_space: 37 | model.optimizer.lr: 38 | type: float 39 | low: 0.00001 40 | high: 0.008 41 | log: true 42 | model.optimizer.weight_decay: 43 | type: float 44 | low: 1E-4 45 | high: 1E-2 46 | log: true -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | run: 2 | dir: ${log_dir}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 3 | sweep: 4 | # dir: ${log_dir}/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | dir: ../logs/multiruns/${now:%Y-%m-%d}/${now:%H-%M-%S} -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.utils.logging.get_neptune_logger 2 | 3 | logger: neptune 4 | 5 | log_db: manual_runs 6 | tag_attrs: 7 | - ${data.dataset} 8 | - ${model.model} 9 | - ${model.arch} 10 | name: test 11 | offline: False 12 | project_name: your-project-name -------------------------------------------------------------------------------- /configs/model/a2r.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.lm.LanguageModel 2 | 3 | defaults: 4 | - optimizer: hf_adamw 5 | - scheduler: linear_with_warmup 6 | 7 | model: lm 8 | arch: google/bigbird-roberta-base 9 | dataset: ${data.dataset} 10 | 11 | num_freeze_layers: 0 12 | freeze_epochs: -1 13 | 14 | expl_reg: True 15 | train_topk: [1, 5, 10, 20, 50] 16 | eval_topk: [1, 5, 10, 20, 50] 17 | expl_reg_freq: 1 18 | task_wt: 1.0 19 | 20 | comp_criterion: null 21 | comp_margin: null 22 | comp_target: False 23 | comp_wt: 0.0 24 | 25 | suff_criterion: null 26 | suff_margin: null 27 | suff_target: False 28 | suff_wt: 0.0 29 | 30 | log_odds: False 31 | log_odds_target: False 32 | 33 | plaus_criterion: bce 34 | plaus_margin: 0.1 35 | plaus_wt: 0.0 36 | 37 | explainer_type: lm 38 | expl_head_type: linear 39 | expl_head_mlp_hidden_dim: null 40 | expl_head_mlp_hidden_layers: null 41 | expl_head_mlp_dropout: 0.0 42 | expl_head_mlp_layernorm: False 43 | attr_algo: null 44 | attr_pooling: null 45 | attr_mlp_hidden_dim: null 46 | attr_mlp_hidden_layers: null 47 | attr_mlp_dropout: null 48 | attr_mlp_layernorm: null 49 | ig_steps: null 50 | internal_batch_size: null 51 | return_convergence_delta: False 52 | gradshap_n_samples: null 53 | gradshap_stdevs: 
null 54 | 55 | fresh: False 56 | fresh_extractor: null 57 | 58 | l2e: False 59 | l2e_wt: 0.0 60 | l2e_criterion: null 61 | l2e_classes: null 62 | 63 | a2r: True 64 | a2r_wt: 0.0 65 | a2r_criterion: jsd 66 | a2r_task_out: sum 67 | 68 | save_outputs: False 69 | exp_id: null 70 | 71 | measure_attrs_runtime: False -------------------------------------------------------------------------------- /configs/model/expl_reg.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.lm.LanguageModel 2 | 3 | defaults: 4 | - optimizer: hf_adamw 5 | - scheduler: linear_with_warmup 6 | 7 | model: lm 8 | arch: google/bigbird-roberta-base 9 | dataset: ${data.dataset} 10 | 11 | num_freeze_layers: 0 12 | freeze_epochs: -1 13 | 14 | expl_reg: True 15 | train_topk: [1, 5, 10, 20, 50] 16 | eval_topk: [1, 5, 10, 20, 50] 17 | expl_reg_freq: 1 18 | task_wt: 1.0 19 | 20 | comp_criterion: margin 21 | comp_margin: 1.0 22 | comp_target: False 23 | comp_wt: 0.0 24 | 25 | suff_criterion: margin 26 | suff_margin: 0.1 27 | suff_target: False 28 | suff_wt: 0.0 29 | 30 | log_odds: False 31 | log_odds_target: False 32 | 33 | plaus_criterion: bce 34 | plaus_margin: 0.1 35 | plaus_wt: 0.0 36 | 37 | explainer_type: self_lm 38 | expl_head_type: linear 39 | expl_head_mlp_hidden_dim: null 40 | expl_head_mlp_hidden_layers: null 41 | expl_head_mlp_dropout: 0.0 42 | expl_head_mlp_layernorm: False 43 | attr_algo: null 44 | attr_pooling: null 45 | attr_mlp_hidden_dim: null 46 | attr_mlp_hidden_layers: null 47 | attr_mlp_dropout: 0.0 48 | attr_mlp_layernorm: False 49 | ig_steps: 3 50 | internal_batch_size: null 51 | return_convergence_delta: False 52 | gradshap_n_samples: null 53 | gradshap_stdevs: null 54 | 55 | fresh: False 56 | fresh_extractor: null 57 | 58 | l2e: False 59 | l2e_wt: 0.0 60 | l2e_criterion: ce 61 | l2e_classes: 5 62 | 63 | a2r: False 64 | a2r_wt: 0.0 65 | a2r_criterion: null 66 | a2r_task_out: null 67 | 68 | save_outputs: False 69 | exp_id: null 70 | 71 | measure_attrs_runtime: False -------------------------------------------------------------------------------- /configs/model/fresh.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.lm.LanguageModel 2 | 3 | defaults: 4 | - optimizer: hf_adamw 5 | - scheduler: linear_with_warmup 6 | 7 | model: lm 8 | arch: google/bigbird-roberta-base 9 | dataset: ${data.dataset} 10 | 11 | num_freeze_layers: 0 12 | freeze_epochs: -1 13 | 14 | expl_reg: False 15 | train_topk: [100] 16 | eval_topk: [100] 17 | expl_reg_freq: 1e100 18 | task_wt: null 19 | 20 | comp_criterion: null 21 | comp_margin: null 22 | comp_target: False 23 | comp_wt: null 24 | 25 | suff_criterion: null 26 | suff_margin: null 27 | suff_target: False 28 | suff_wt: null 29 | 30 | log_odds: False 31 | log_odds_target: False 32 | 33 | plaus_criterion: null 34 | plaus_margin: null 35 | plaus_wt: null 36 | 37 | explainer_type: null 38 | expl_head_type: null 39 | expl_head_mlp_hidden_dim: null 40 | expl_head_mlp_hidden_layers: null 41 | expl_head_mlp_dropout: null 42 | expl_head_mlp_layernorm: null 43 | attr_algo: null 44 | attr_pooling: null 45 | attr_mlp_hidden_dim: null 46 | attr_mlp_hidden_layers: null 47 | attr_mlp_dropout: null 48 | attr_mlp_layernorm: null 49 | ig_steps: null 50 | internal_batch_size: null 51 | return_convergence_delta: False 52 | gradshap_n_samples: null 53 | gradshap_stdevs: null 54 | 55 | fresh: True 56 | fresh_extractor: oracle 57 | 58 | save_outputs: False 59 | exp_id: null 60 | 61 | 
measure_attrs_runtime: False -------------------------------------------------------------------------------- /configs/model/lm.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.model.lm.LanguageModel 2 | 3 | defaults: 4 | - optimizer: hf_adamw 5 | - scheduler: linear_with_warmup 6 | 7 | model: lm 8 | arch: google/bigbird-roberta-base 9 | dataset: ${data.dataset} 10 | 11 | num_freeze_layers: 0 12 | freeze_epochs: -1 13 | 14 | expl_reg: False 15 | train_topk: [1, 5, 10, 20, 50] 16 | eval_topk: [1, 5, 10, 20, 50] 17 | expl_reg_freq: 1e100 18 | task_wt: null 19 | 20 | comp_criterion: null 21 | comp_margin: null 22 | comp_target: False 23 | comp_wt: null 24 | 25 | suff_criterion: null 26 | suff_margin: null 27 | suff_target: False 28 | suff_wt: null 29 | 30 | log_odds: False 31 | log_odds_target: False 32 | 33 | plaus_criterion: null 34 | plaus_margin: null 35 | plaus_wt: null 36 | 37 | explainer_type: null 38 | expl_head_type: null 39 | expl_head_mlp_hidden_dim: null 40 | expl_head_mlp_hidden_layers: null 41 | expl_head_mlp_dropout: null 42 | expl_head_mlp_layernorm: null 43 | attr_algo: null 44 | attr_pooling: null 45 | attr_mlp_hidden_dim: null 46 | attr_mlp_hidden_layers: null 47 | attr_mlp_dropout: null 48 | attr_mlp_layernorm: null 49 | ig_steps: 3 50 | internal_batch_size: null 51 | return_convergence_delta: False 52 | gradshap_n_samples: null 53 | gradshap_stdevs: null 54 | 55 | fresh: False 56 | fresh_extractor: null 57 | 58 | l2e: False 59 | l2e_wt: null 60 | l2e_criterion: null 61 | l2e_classes: 5 62 | 63 | a2r: False 64 | a2r_wt: null 65 | a2r_criterion: null 66 | a2r_task_out: null 67 | 68 | save_outputs: False 69 | exp_id: null 70 | 71 | measure_attrs_runtime: False -------------------------------------------------------------------------------- /configs/model/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | 3 | lr: 1e-5 4 | betas: [ 0.9, 0.98 ] 5 | eps: 1e-8 6 | weight_decay: 0.0 7 | amsgrad: False -------------------------------------------------------------------------------- /configs/model/optimizer/hf_adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.AdamW 2 | 3 | lr: 1e-5 4 | betas: [ 0.9, 0.98 ] 5 | eps: 1e-8 6 | weight_decay: 0.0 7 | correct_bias: False -------------------------------------------------------------------------------- /configs/model/scheduler/fixed.yaml: -------------------------------------------------------------------------------- 1 | lr_scheduler: fixed 2 | warmup_updates: 0.0 3 | -------------------------------------------------------------------------------- /configs/model/scheduler/linear_with_warmup.yaml: -------------------------------------------------------------------------------- 1 | lr_scheduler: linear_with_warmup 2 | warmup_updates: 0.1 3 | -------------------------------------------------------------------------------- /configs/setup/a100.yaml: -------------------------------------------------------------------------------- 1 | train_batch_size: 1 2 | eval_batch_size: 1 3 | accumulate_grad_batches: 1 4 | eff_train_batch_size: 1 5 | num_workers: 0 6 | precision: 16 7 | -------------------------------------------------------------------------------- /configs/trainer/defaults.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all 
trainer parameters 4 | checkpoint_callback: True 5 | default_root_dir: null 6 | gradient_clip_val: 0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | 11 | gpus: -1 12 | auto_select_gpus: True 13 | tpu_cores: null 14 | log_gpu_memory: null 15 | progress_bar_refresh_rate: 1 16 | overfit_batches: 0.0 17 | track_grad_norm: -1 18 | check_val_every_n_epoch: 1 19 | fast_dev_run: False 20 | accumulate_grad_batches: ${setup.accumulate_grad_batches} 21 | 22 | max_epochs: 10 23 | min_epochs: 1 24 | max_steps: null 25 | min_steps: null 26 | limit_train_batches: 1.0 27 | limit_val_batches: 1.0 28 | limit_test_batches: 1.0 29 | val_check_interval: 1.0 30 | flush_logs_every_n_steps: 1000 31 | log_every_n_steps: 1000 32 | accelerator: null 33 | sync_batchnorm: False 34 | precision: ${setup.precision} 35 | weights_summary: "top" 36 | weights_save_path: null 37 | 38 | num_sanity_val_steps: 0 39 | truncated_bptt_steps: null 40 | resume_from_checkpoint: null 41 | profiler: null 42 | benchmark: True 43 | deterministic: True 44 | reload_dataloaders_every_epoch: False 45 | auto_lr_find: False 46 | replace_sampler_ddp: True 47 | terminate_on_nan: False 48 | auto_scale_batch_size: False 49 | prepare_data_per_node: True 50 | plugins: null 51 | amp_backend: "native" 52 | amp_level: "O2" 53 | move_metrics_to_cpu: False 54 | -------------------------------------------------------------------------------- /configs/training/base.yaml: -------------------------------------------------------------------------------- 1 | evaluate_ckpt: False 2 | finetune_ckpt: False 3 | # shouldn't have ckpt unless fine tune ckpt 4 | ckpt_path: null 5 | eval_splits: 'all' # comma separated (no space). E.g., train,dev,test 6 | train_shuffle: True 7 | patience: 5 -------------------------------------------------------------------------------- /configs/training/evaluate.yaml: -------------------------------------------------------------------------------- 1 | evaluate_ckpt: True 2 | finetune_ckpt: False 3 | # must provide one 4 | ckpt_path: null 5 | eval_splits: 'test' # comma separated (no space). E.g., train,dev,test 6 | train_shuffle: False 7 | patience: null -------------------------------------------------------------------------------- /configs/training/finetune.yaml: -------------------------------------------------------------------------------- 1 | evaluate_ckpt: False 2 | finetune_ckpt: True 3 | # must provide one 4 | ckpt_path: null 5 | eval_splits: 'all' # comma separated (no space). 
E.g., train,dev,test 6 | train_shuffle: True 7 | patience: 5 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | 6 | 7 | @hydra.main(config_path="configs", config_name="config") 8 | def main(cfg: DictConfig): 9 | # import here for faster auto completion 10 | from src.utils.conf import touch 11 | from src.run import run 12 | 13 | # additional set field by condition 14 | # assert no missing etc 15 | touch(cfg) 16 | 17 | start_time = time.time() 18 | metric = run(cfg) 19 | print( 20 | f'Time Taken for experiment {cfg.logger.neptune_exp_id}: {(time.time() - start_time) / 3600}h') 21 | 22 | return metric 23 | 24 | 25 | if __name__ == '__main__': 26 | main() 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # python==3.8.11 (conda create -n hitl-expl-reg python=3.8.11) 2 | # torch==1.7.1 (conda install pytorch==1.7.1 torchvision torchaudio cudatoolkit=10.1 -c pytorch) 3 | transformers 4 | datasets 5 | pytorch-lightning==1.4.8 6 | hydra-core==1.1.0 7 | omegaconf==2.1.1 8 | numpy>=1.17 9 | jsonlines==1.2.0 10 | tqdm==4.48.2 11 | captum==0.4.0 12 | pickle5==0.0.11 13 | neptune-client==0.9.16 14 | rich==10.6.0 15 | hydra_colorlog==1.1.0 16 | hydra-optuna-sweeper==1.1.0 17 | sentencepiece==0.1.96 18 | scikit-learn==1.0.2 19 | torchmetrics==0.4.0 -------------------------------------------------------------------------------- /scripts/build_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse, json, math, os, sys, random, logging 2 | from collections import defaultdict as ddict, Counter 3 | from itertools import chain 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from pickle5 import pickle 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import datasets 12 | from transformers import AutoTokenizer 13 | 14 | sys.path.append(os.path.join(sys.path[0], '..')) 15 | from src.utils.data import dataset_info, eraser_datasets, data_keys 16 | from src.utils.eraser.utils import annotations_from_jsonl, load_documents 17 | from src.utils.eraser.data_utils import ( 18 | bert_tokenize_doc, 19 | bert_intern_doc, 20 | bert_intern_annotation, 21 | annotations_to_evidence_identification, 22 | annotations_to_evidence_token_identification, 23 | ) 24 | 25 | logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s') 26 | logger = logging.getLogger(__name__) 27 | 28 | def set_random_seed(seed): 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | 35 | def update_dataset_dict( 36 | idx, dataset_dict, input_ids, rationale, max_length, actual_max_length, tokenizer, interned_annotations, classes 37 | ): 38 | input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id] 39 | rationale = [0] + rationale + [0] 40 | assert len(input_ids) == len(rationale) 41 | num_tokens = len(input_ids) 42 | if num_tokens > actual_max_length: 43 | actual_max_length = num_tokens 44 | 45 | num_pad_tokens = max_length - num_tokens 46 | assert num_pad_tokens >= 0 47 | 48 | input_ids += [tokenizer.pad_token_id] * num_pad_tokens 49 | attention_mask = [1] * num_tokens + [0] * 
num_pad_tokens 50 | rationale += [0] * num_pad_tokens 51 | 52 | inv_rationale = [1.0-x for x in rationale] 53 | rand_rationale = list(np.random.randn(max_length)) 54 | 55 | has_rationale = int(sum(rationale) > 0) 56 | if has_rationale == 0: 57 | raise ValueError('empty rationale') 58 | 59 | label = classes.index(interned_annotations[idx].classification) 60 | 61 | dataset_dict['item_idx'].append(idx) 62 | dataset_dict['input_ids'].append(input_ids) 63 | dataset_dict['attention_mask'].append(attention_mask) 64 | dataset_dict['rationale'].append(rationale) 65 | dataset_dict['inv_rationale'].append(inv_rationale) 66 | dataset_dict['rand_rationale'].append(rand_rationale) 67 | dataset_dict['has_rationale'].append(has_rationale) 68 | dataset_dict['label'].append(label) 69 | 70 | return dataset_dict, actual_max_length 71 | 72 | def align_rationale_with_tokens(input_ids, raw_tokens, raw_rationale, tokenizer): 73 | tokens = tokenizer.convert_ids_to_tokens(input_ids) 74 | rationale = [] 75 | j = 0 76 | cur_token = tokens[j] 77 | 78 | for i in range(len(raw_tokens)): 79 | cur_raw_token = raw_tokens[i] 80 | cur_raw_rationale = raw_rationale[i] 81 | cur_reconstructed_raw_token = '' 82 | 83 | while len(cur_raw_token) > 0: 84 | for char in cur_token: 85 | if char == cur_raw_token[0]: 86 | cur_raw_token = cur_raw_token[1:] 87 | cur_reconstructed_raw_token += char 88 | 89 | rationale.append(cur_raw_rationale) 90 | j += 1 91 | cur_token = tokens[j] if j < len(tokens) else None 92 | 93 | assert cur_reconstructed_raw_token == raw_tokens[i] 94 | 95 | return rationale 96 | 97 | def stratified_sampling(data, num_samples): 98 | num_instances = len(data) 99 | assert num_samples < num_instances 100 | 101 | counter_dict = Counter(data) 102 | unique_vals = list(counter_dict.keys()) 103 | val_counts = list(counter_dict.values()) 104 | num_unique_vals = len(unique_vals) 105 | assert num_unique_vals > 1 106 | 107 | num_stratified_samples = [int(c*num_samples/num_instances) for c in val_counts] 108 | assert sum(num_stratified_samples) <= num_samples 109 | if sum(num_stratified_samples) < num_samples: 110 | delta = num_samples - sum(num_stratified_samples) 111 | delta_samples = np.random.choice(range(num_unique_vals), replace=True, size=delta) 112 | for val in delta_samples: 113 | num_stratified_samples[unique_vals.index(val)] += 1 114 | assert sum(num_stratified_samples) == num_samples 115 | 116 | sampled_indices = [] 117 | for i, val in enumerate(unique_vals): 118 | candidates = np.where(data == val)[0] 119 | sampled_indices += list(np.random.choice(candidates, replace=False, size=num_stratified_samples[i])) 120 | random.shuffle(sampled_indices) 121 | 122 | return sampled_indices 123 | 124 | def sample_dataset(data_path, dataset_dict, split, num_samples, seed): 125 | sampled_split_filename = f'{split}_split_{num_samples}_{seed}.pkl' 126 | if os.path.exists(os.path.join(data_path, sampled_split_filename)): 127 | with open(os.path.join(data_path, sampled_split_filename), 'rb') as f: 128 | sampled_split = pickle.load(f) 129 | else: 130 | sampled_split = stratified_sampling(dataset_dict['label'], num_samples) 131 | with open(os.path.join(data_path, sampled_split_filename), 'wb') as f: 132 | pickle.dump(sampled_split, f) 133 | 134 | for key in data_keys: 135 | dataset_dict[key] = sampled_split if key == 'item_idx' else [dataset_dict[key][i] for i in sampled_split] 136 | 137 | return dataset_dict 138 | 139 | def sample_rationale_indices(data_path, dataset_dict, num_examples, pct_train_rationales, seed): 140 | # Sample indices 
for train examples with gold rationales 141 | sampled_indices_filename = f'rationale_indices_{pct_train_rationales}_{seed}.pkl' 142 | if not os.path.exists(os.path.join(data_path, sampled_indices_filename)): 143 | num_train_rationales = int(math.ceil(num_examples * pct_train_rationales / 100)) 144 | sampled_indices = list(np.random.choice(dataset_dict['item_idx'], size=num_train_rationales, replace=False)) 145 | with open(os.path.join(data_path, sampled_indices_filename), 'wb') as f: 146 | pickle.dump(sampled_indices, f) 147 | 148 | def load_dataset(data_path): 149 | dataset_dict = ddict(list) 150 | for key in tqdm(data_keys, desc=f'Loading {args.split} dataset'): 151 | with open(os.path.join(data_path, f'{key}.pkl'), 'rb') as f: 152 | dataset_dict[key] = pickle.load(f) 153 | return dataset_dict 154 | 155 | def save_dataset(data_path, dataset_dict, split, num_samples, seed): 156 | for key in tqdm(data_keys, desc=f'Saving {split} dataset'): 157 | filename = f'{key}.pkl' if num_samples is None else f'{key}_{num_samples}_{seed}.pkl' 158 | with open(os.path.join(data_path, filename), 'wb') as f: 159 | pickle.dump(dataset_dict[key], f) 160 | 161 | def main(args): 162 | set_random_seed(args.seed) 163 | 164 | assert args.split is not None and args.arch is not None 165 | assert args.num_samples is None or args.num_samples >= 1 166 | 167 | split, num_examples = dataset_info[args.dataset][args.split] 168 | if args.num_samples is not None: 169 | assert args.num_samples < num_examples 170 | 171 | num_classes = dataset_info[args.dataset]['num_classes'] 172 | max_length = dataset_info[args.dataset]['max_length'][args.arch] 173 | num_special_tokens = dataset_info[args.dataset]['num_special_tokens'] 174 | tokenizer = AutoTokenizer.from_pretrained(args.arch) 175 | data_path = os.path.join(args.data_dir, args.dataset, args.arch, args.split) 176 | classes = dataset_info[args.dataset]['classes'] 177 | if not os.path.exists(data_path): 178 | os.makedirs(data_path) 179 | 180 | if args.dataset in eraser_datasets: 181 | eraser_path = os.path.join(args.data_dir, 'eraser', args.dataset) 182 | documents_path = os.path.join(args.data_dir, args.dataset, args.arch, 'documents.pkl') 183 | documents = load_documents(eraser_path) 184 | logger.info(f'Loaded {len(documents)} documents') 185 | if os.path.exists(documents_path): 186 | logger.info(f'Loading processed documents from {documents_path}') 187 | (interned_documents, interned_document_token_slices) = torch.load(documents_path) 188 | logger.info(f'Loaded {len(interned_documents)} processed documents') 189 | else: 190 | logger.info(f'Processing documents') 191 | special_token_map = { 192 | 'SEP': [tokenizer.sep_token_id], 193 | '[SEP]': [tokenizer.sep_token_id], 194 | '[sep]': [tokenizer.sep_token_id], 195 | 'UNK': [tokenizer.unk_token_id], 196 | '[UNK]': [tokenizer.unk_token_id], 197 | '[unk]': [tokenizer.unk_token_id], 198 | 'PAD': [tokenizer.unk_token_id], 199 | '[PAD]': [tokenizer.unk_token_id], 200 | '[pad]': [tokenizer.unk_token_id], 201 | } 202 | interned_documents = {} 203 | interned_document_token_slices = {} 204 | for d, doc in tqdm(documents.items(), desc='Processing documents'): 205 | tokenized, w_slices = bert_tokenize_doc(doc, tokenizer, special_token_map=special_token_map) 206 | interned_documents[d] = bert_intern_doc(tokenized, tokenizer, special_token_map=special_token_map) 207 | interned_document_token_slices[d] = w_slices 208 | logger.info(f'Saving processed documents to {documents_path}') 209 | torch.save((interned_documents, 
interned_document_token_slices), documents_path) 210 | sys.exit() 211 | 212 | annotations_path = os.path.join(eraser_path, f'{split}.jsonl') 213 | annotations = annotations_from_jsonl(annotations_path) 214 | interned_annotations = bert_intern_annotation(annotations, tokenizer) 215 | if args.dataset in ['cose', 'esnli', 'movies']: 216 | evidence_data = annotations_to_evidence_token_identification(annotations, documents, interned_documents, interned_document_token_slices) 217 | elif args.dataset in ['fever', 'multirc']: 218 | evidence_data = annotations_to_evidence_identification(annotations, interned_documents) 219 | assert len(evidence_data) == num_examples 220 | 221 | missing_data_keys = [x for x in data_keys if not os.path.exists(os.path.join(data_path, f'{x}.pkl'))] 222 | if args.num_samples is None and missing_data_keys: 223 | dataset_dict = ddict(list) 224 | actual_max_length = 0 225 | if args.dataset in eraser_datasets: 226 | if args.dataset not in ['cose', 'esnli', 'fever', 'movies', 'multirc']: 227 | raise NotImplementedError 228 | 229 | if args.dataset == 'cose': 230 | q_marker = tokenizer('Q:', add_special_tokens=False)['input_ids'] 231 | a_marker = tokenizer('A:', add_special_tokens=False)['input_ids'] 232 | for idx, (instance_id, instance_evidence) in tqdm(enumerate(evidence_data.items()), desc=f'Building {args.split} dataset', total=num_examples): 233 | instance_docs = ddict(dict) 234 | assert len(instance_evidence) == 1 235 | doc = interned_documents[instance_id] 236 | evidence_sentences = instance_evidence[instance_id] 237 | 238 | question = list(chain.from_iterable(doc)) 239 | question_rationale = list(chain.from_iterable([x.kls for x in evidence_sentences])) 240 | answers = evidence_sentences[0].query.split(' [sep] ') 241 | answer_ids = [tokenizer(x, add_special_tokens=False)['input_ids'] for x in answers] 242 | 243 | input_ids, attention_mask, rationale, inv_rationale, rand_rationale, has_rationale = [], [], [], [], [], [] 244 | for answer in answer_ids: 245 | cur_input_ids = [tokenizer.cls_token_id] + q_marker + question + [tokenizer.sep_token_id] + a_marker + answer + [tokenizer.sep_token_id] 246 | 247 | num_tokens = len(cur_input_ids) 248 | if num_tokens > actual_max_length: 249 | actual_max_length = num_tokens 250 | num_pad_tokens = max_length - num_tokens 251 | assert num_pad_tokens >= 0 252 | 253 | cur_input_ids += [tokenizer.pad_token_id] * num_pad_tokens 254 | input_ids.append(cur_input_ids) 255 | 256 | cur_attention_mask = [1] * num_tokens + [0] * num_pad_tokens 257 | attention_mask.append(cur_attention_mask) 258 | 259 | cur_rationale = [0] + [0]*len(q_marker) + question_rationale + [0] + [0]*len(a_marker) + [0]*len(answer) + [0] 260 | cur_rationale += [0] * num_pad_tokens 261 | assert len(cur_input_ids) == len(cur_rationale) 262 | rationale.append(cur_rationale) 263 | 264 | inv_rationale.append([1.0-x for x in cur_rationale]) 265 | rand_rationale.append(list(np.random.randn(max_length))) 266 | 267 | cur_has_rationale = int(sum(cur_rationale) > 0) 268 | if cur_has_rationale == 0: 269 | raise ValueError('empty rationale') 270 | has_rationale.append(cur_has_rationale) 271 | 272 | label = classes.index(interned_annotations[idx].classification) 273 | 274 | dataset_dict['item_idx'].append(idx) 275 | dataset_dict['input_ids'].append(input_ids) 276 | dataset_dict['attention_mask'].append(attention_mask) 277 | dataset_dict['rationale'].append(rationale) 278 | dataset_dict['inv_rationale'].append(inv_rationale) 279 | dataset_dict['rand_rationale'].append(rand_rationale) 
280 | dataset_dict['has_rationale'].append(has_rationale) 281 | dataset_dict['label'].append(label) 282 | 283 | elif args.dataset == 'esnli': 284 | for idx, (instance_id, instance_evidence) in tqdm(enumerate(evidence_data.items()), desc=f'Building {args.split} dataset', total=num_examples): 285 | instance_docs = ddict(dict) 286 | assert len(instance_evidence) in [1, 2] 287 | for doc_type in ['premise', 'hypothesis']: 288 | doc_id = f'{instance_id}_{doc_type}' 289 | doc = instance_evidence[doc_id] 290 | if doc: 291 | instance_docs[doc_type]['text'] = doc[0][5] 292 | instance_docs[doc_type]['rationale'] = list(doc[0][0]) 293 | else: 294 | instance_docs[doc_type]['text'] = interned_documents[doc_id][0] 295 | instance_docs[doc_type]['rationale'] = [0] * len(interned_documents[doc_id][0]) 296 | 297 | input_ids = instance_docs['premise']['text'] + [tokenizer.sep_token_id] + instance_docs['hypothesis']['text'] 298 | assert all([x != tokenizer.unk_token_id for x in input_ids]) 299 | rationale = instance_docs['premise']['rationale'] + [0] + instance_docs['hypothesis']['rationale'] 300 | dataset_dict, actual_max_length = update_dataset_dict(idx, dataset_dict, input_ids, rationale, max_length, actual_max_length, tokenizer, interned_annotations, classes) 301 | 302 | elif args.dataset == 'movies': 303 | for idx, (instance_id, instance_evidence) in tqdm(enumerate(evidence_data.items()), desc=f'Building {args.split} dataset', total=num_examples): 304 | instance_docs = ddict(dict) 305 | assert len(instance_evidence) == 1 306 | doc_id = list(instance_evidence.keys())[0] 307 | doc = interned_documents[doc_id] 308 | evidence_sentences = instance_evidence[doc_id] 309 | 310 | input_ids = list(chain.from_iterable(doc)) 311 | rationale = list(chain.from_iterable([x.kls for x in evidence_sentences])) 312 | 313 | input_length = min(len(input_ids), max_length - num_special_tokens) 314 | if sum(rationale[:input_length]) > 0: 315 | input_ids = input_ids[:input_length] 316 | rationale = rationale[:input_length] 317 | else: 318 | input_ids = input_ids[-input_length:] 319 | rationale = rationale[-input_length:] 320 | 321 | dataset_dict, actual_max_length = update_dataset_dict(idx, dataset_dict, input_ids, rationale, max_length, actual_max_length, tokenizer, interned_annotations, classes) 322 | 323 | elif args.dataset == 'multirc': 324 | for idx, (instance_id, instance_evidence) in tqdm(enumerate(evidence_data.items()), desc=f'Building {args.split} dataset', total=num_examples): 325 | instance_docs = ddict(dict) 326 | assert len(instance_evidence) == 1 327 | doc_id = list(instance_evidence.keys())[0] 328 | doc = interned_documents[doc_id] 329 | evidence_sentences = [x for x in instance_evidence[doc_id] if x.kls == 1] 330 | evidence_indices = [x.index for x in evidence_sentences] 331 | 332 | evidence_ids = list(chain.from_iterable(doc)) 333 | query_ids = tokenizer(instance_evidence[doc_id][0].query, add_special_tokens=False)['input_ids'] 334 | rationale = [] 335 | for i, sentence in enumerate(doc): 336 | if i in evidence_indices: 337 | rationale += [1] * len(sentence) 338 | else: 339 | rationale += [0] * len(sentence) 340 | 341 | input_ids = evidence_ids + [tokenizer.sep_token_id] + query_ids 342 | rationale += [0] * (len(query_ids)+1) 343 | dataset_dict, actual_max_length = update_dataset_dict(idx, dataset_dict, input_ids, rationale, max_length, actual_max_length, tokenizer, interned_annotations, classes) 344 | 345 | else: 346 | raise NotImplementedError 347 | 348 | elif args.dataset == 'sst': 349 | sst_json_path = 
open(os.path.join(args.data_dir, args.dataset, f'sst_{split}.json')) 350 | dataset = json.load(sst_json_path) 351 | for idx in tqdm(range(num_examples), desc=f'Building {args.split} dataset'): 352 | instance = dataset[idx] 353 | 354 | text = instance['text'] 355 | raw_tokens = text.split() 356 | raw_rationale = [1.0 if x >= 0.5 else 0.0 for x in instance['rationale']] 357 | assert len(raw_tokens) >= len(raw_rationale) 358 | if len(raw_tokens) > len(raw_rationale): # rationale missing last token for train instances [3631, 4767, 6020] 359 | diff = len(raw_tokens) - len(raw_rationale) # diff always equals 1 360 | raw_rationale = raw_rationale + diff*[0.0] 361 | assert len(raw_tokens) == len(raw_rationale) 362 | assert len(raw_tokens) <= max_length 363 | 364 | input_ids = tokenizer(text, add_special_tokens=False)['input_ids'] 365 | rationale = align_rationale_with_tokens(input_ids, raw_tokens, raw_rationale, tokenizer) 366 | assert len(input_ids) == len(rationale) 367 | 368 | num_tokens = len(input_ids) + num_special_tokens 369 | if num_tokens > actual_max_length: 370 | actual_max_length = num_tokens 371 | num_pad_tokens = max_length - num_tokens 372 | assert num_pad_tokens >= 0 373 | 374 | input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id] + num_pad_tokens*[tokenizer.pad_token_id] 375 | assert all([x != tokenizer.unk_token_id for x in input_ids]) 376 | assert len(input_ids) == max_length 377 | 378 | attn_mask = num_tokens*[1] + num_pad_tokens*[0] 379 | assert len(attn_mask) == max_length 380 | 381 | rationale = [0.0] + rationale + [0.0] + num_pad_tokens*[0.0] 382 | assert sum(rationale) > 0 383 | assert len(rationale) == max_length 384 | 385 | inv_rationale = [1.0-x for x in rationale] 386 | rand_rationale = list(np.random.randn(max_length)) 387 | 388 | label = classes.index(instance['classification']) 389 | 390 | dataset_dict['item_idx'].append(idx) 391 | dataset_dict['input_ids'].append(input_ids) 392 | dataset_dict['attention_mask'].append(attn_mask) 393 | dataset_dict['rationale'].append(rationale) 394 | dataset_dict['inv_rationale'].append(inv_rationale) 395 | dataset_dict['rand_rationale'].append(rand_rationale) 396 | dataset_dict['has_rationale'].append(1) 397 | dataset_dict['label'].append(label) 398 | 399 | elif args.dataset == 'stf': 400 | stf_raw_path = os.path.join(args.data_dir, args.dataset, 'stf_raw', f'{split}.tsv') 401 | dataset = pd.read_csv(stf_raw_path, sep='\t') 402 | for row in tqdm(dataset.itertuples(index=True), desc=f'Building {args.split} dataset'): 403 | idx, _, text, label = row 404 | tokens = tokenizer(text, padding='max_length', max_length=max_length, truncation=True) 405 | 406 | input_ids_ = tokenizer(text, add_special_tokens=False)['input_ids'] 407 | num_tokens = len(input_ids_) + num_special_tokens 408 | if num_tokens > actual_max_length: 409 | actual_max_length = num_tokens 410 | 411 | rand_rationale = list(np.random.randn(max_length)) 412 | 413 | dataset_dict['item_idx'].append(idx) 414 | dataset_dict['input_ids'].append(tokens['input_ids']) 415 | dataset_dict['attention_mask'].append(tokens['attention_mask']) 416 | dataset_dict['rationale'].append(None) 417 | dataset_dict['rand_rationale'].append(rand_rationale) 418 | dataset_dict['has_rationale'].append(0) 419 | dataset_dict['label'].append(label) 420 | 421 | elif args.dataset == 'yelp': 422 | split_ = 'train' if split != 'test' else 'test' 423 | dataset = datasets.load_dataset('yelp_polarity')[split_] 424 | 425 | if split in ['train', 'test']: 426 | start_idx = 0 427 | elif split 
== 'dev': 428 | start_idx = dataset_info[args.dataset][split_][1] 429 | 430 | for idx in tqdm(range(start_idx, start_idx+num_examples), desc=f'Building {args.split} dataset'): 431 | text = dataset[idx]['text'] 432 | tokens = tokenizer( 433 | text, 434 | padding='max_length', 435 | max_length=max_length, 436 | truncation=True 437 | ) 438 | 439 | input_ids_ = tokenizer(text, add_special_tokens=False)['input_ids'] 440 | num_tokens = len(input_ids_) + num_special_tokens 441 | if num_tokens > actual_max_length: 442 | actual_max_length = num_tokens 443 | 444 | rand_rationale = list(np.random.randn(max_length)) 445 | 446 | dataset_dict['item_idx'].append(idx-start_idx) 447 | dataset_dict['input_ids'].append(tokens['input_ids']) 448 | dataset_dict['attention_mask'].append(tokens['attention_mask']) 449 | dataset_dict['rationale'].append(None) 450 | dataset_dict['rand_rationale'].append(rand_rationale) 451 | dataset_dict['has_rationale'].append(0) 452 | dataset_dict['label'].append(dataset[idx]['label']) 453 | 454 | elif args.dataset == 'amazon': 455 | split_ = 'train' if split != 'test' else 'test' 456 | dataset = datasets.load_dataset('amazon_polarity')[split_] 457 | 458 | if split in ['train', 'test']: 459 | start_idx = 0 460 | elif split == 'dev': 461 | start_idx = dataset_info[args.dataset][split_][1] 462 | 463 | for idx in tqdm(range(start_idx, start_idx+num_examples), desc=f'Building {args.split} dataset'): 464 | text = f'{dataset[idx]["title"]} {tokenizer.sep_token} {dataset[idx]["content"]}' 465 | tokens = tokenizer( 466 | text, 467 | padding='max_length', 468 | max_length=max_length, 469 | truncation=True 470 | ) 471 | 472 | input_ids_ = tokenizer(text, add_special_tokens=False)['input_ids'] 473 | num_tokens = len(input_ids_) + num_special_tokens 474 | if num_tokens > actual_max_length: 475 | actual_max_length = num_tokens 476 | 477 | rand_rationale = list(np.random.randn(max_length)) 478 | 479 | dataset_dict['item_idx'].append(idx-start_idx) 480 | dataset_dict['input_ids'].append(tokens['input_ids']) 481 | dataset_dict['attention_mask'].append(tokens['attention_mask']) 482 | dataset_dict['rationale'].append(None) 483 | dataset_dict['rand_rationale'].append(rand_rationale) 484 | dataset_dict['has_rationale'].append(0) 485 | dataset_dict['label'].append(dataset[idx]['label']) 486 | 487 | elif args.dataset == 'olid': 488 | dataset = datasets.load_dataset('tweet_eval', 'offensive')[split] 489 | for idx in tqdm(range(num_examples), desc=f'Building {args.split} dataset'): 490 | text = dataset[idx]['text'] 491 | 492 | input_ids_ = tokenizer(text, add_special_tokens=False)['input_ids'] 493 | num_tokens = len(input_ids_) + num_special_tokens 494 | if num_tokens > actual_max_length: 495 | actual_max_length = num_tokens 496 | 497 | rand_rationale = list(np.random.randn(max_length)) 498 | 499 | tokens = tokenizer(text, padding='max_length', max_length=max_length, truncation=True) 500 | dataset_dict['item_idx'].append(idx) 501 | dataset_dict['input_ids'].append(tokens['input_ids']) 502 | dataset_dict['attention_mask'].append(tokens['attention_mask']) 503 | dataset_dict['rationale'].append(None) 504 | dataset_dict['rand_rationale'].append(rand_rationale) 505 | dataset_dict['has_rationale'].append(0) 506 | dataset_dict['label'].append(dataset[idx]['label']) 507 | 508 | elif args.dataset == 'irony': 509 | dataset = datasets.load_dataset('tweet_eval', 'irony')[split] 510 | for idx in tqdm(range(num_examples), desc=f'Building {args.split} dataset'): 511 | text = dataset[idx]['text'] 512 | 513 | input_ids_ 
= tokenizer(text, add_special_tokens=False)['input_ids'] 514 | num_tokens = len(input_ids_) + num_special_tokens 515 | if num_tokens > actual_max_length: 516 | actual_max_length = num_tokens 517 | 518 | rand_rationale = list(np.random.randn(max_length)) 519 | 520 | tokens = tokenizer(text, padding='max_length', max_length=max_length, truncation=True) 521 | dataset_dict['item_idx'].append(idx) 522 | dataset_dict['input_ids'].append(tokens['input_ids']) 523 | dataset_dict['attention_mask'].append(tokens['attention_mask']) 524 | dataset_dict['rationale'].append(None) 525 | dataset_dict['rand_rationale'].append(rand_rationale) 526 | dataset_dict['has_rationale'].append(0) 527 | dataset_dict['label'].append(dataset[idx]['label']) 528 | 529 | else: 530 | raise NotImplementedError 531 | 532 | print(f'Actual max length: {actual_max_length}') 533 | 534 | else: 535 | dataset_dict = load_dataset(data_path) 536 | 537 | if args.num_samples is not None: 538 | assert all([os.path.exists(os.path.join(data_path, f'{x}.pkl')) for x in data_keys]) 539 | dataset_dict = sample_dataset(data_path, dataset_dict, args.split, args.num_samples, args.seed) 540 | 541 | if args.pct_train_rationales is not None: 542 | assert args.split == 'train' 543 | sample_rationale_indices(data_path, dataset_dict, num_examples, args.pct_train_rationales, args.seed) 544 | sys.exit() 545 | 546 | save_dataset(data_path, dataset_dict, args.split, args.num_samples, args.seed) 547 | 548 | 549 | if __name__ == '__main__': 550 | parser = argparse.ArgumentParser(description='Dataset preprocessing') 551 | parser.add_argument('--data_dir', type=str, default='../data/', help='Root directory for datasets') 552 | parser.add_argument('--dataset', type=str, 553 | choices=['cose', 'esnli', 'movies', 'multirc', 'sst', 'amazon', 'yelp', 'stf', 'olid', 'irony']) 554 | parser.add_argument('--arch', type=str, default='google/bigbird-roberta-base', choices=['google/bigbird-roberta-base', 'bert-base-uncased']) 555 | parser.add_argument('--split', type=str, help='Dataset split', choices=['train', 'dev', 'test']) 556 | parser.add_argument('--num_samples', type=int, default=None, help='Number of examples to sample. None means all available examples are used.') 557 | parser.add_argument('--pct_train_rationales', type=float, default=None, help='Percentage of train examples to provide gold rationales for. 
None means all available train examples are used.') 558 | parser.add_argument('--seed', type=int, default=0, help='Random seed') 559 | args = parser.parse_args() 560 | main(args) 561 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/UNIREX/60149a9c945376069b70fb3b845e6a20c11534ad/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/UNIREX/60149a9c945376069b70fb3b845e6a20c11534ad/src/data/__init__.py -------------------------------------------------------------------------------- /src/data/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | from pathlib import Path 4 | from typing import Optional 5 | from copy import deepcopy 6 | from itertools import chain 7 | 8 | import numpy as np 9 | import pickle5 as pickle 10 | from hydra.utils import get_original_cwd 11 | import pytorch_lightning as pl 12 | import torch 13 | from torch.utils.data import DataLoader, Dataset 14 | from tqdm import tqdm 15 | 16 | from src.utils.data import dataset_info, data_keys 17 | 18 | 19 | class DataModule(pl.LightningDataModule): 20 | 21 | def __init__(self, 22 | dataset: str, data_path: str, mode: str, 23 | train_batch_size: int = 1, eval_batch_size: int = 1, eff_train_batch_size: int = 1, num_workers: int = 0, 24 | num_train: int = None, num_dev: int = None, num_test: int = None, 25 | num_train_seed: int = None, num_dev_seed: int = None, num_test_seed: int = None, 26 | pct_train_rationales: float = None, pct_train_rationales_seed: int = None, train_rationales_batch_factor: float = None, 27 | neg_weight: float = 1, 28 | attr_algo: str = None, 29 | fresh_exp_id: str = None, fresh_attr_algo: str = None, fresh_topk: int = None, 30 | fresh_extractor = None, min_val: float = -1e10, 31 | l2e_exp_id: str = None, l2e_attr_algo: str = None, l2e_num_classes: int = 5, 32 | train_shuffle: bool = False, 33 | ): 34 | super().__init__() 35 | 36 | self.dataset = dataset 37 | self.data_path = data_path # ${data_dir}/${.dataset}/${model.arch}/ 38 | 39 | self.train_batch_size = train_batch_size 40 | self.eval_batch_size = eval_batch_size 41 | self.eff_train_batch_size = eff_train_batch_size 42 | self.num_workers = num_workers 43 | 44 | self.num_samples = {'train': num_train, 'dev': num_dev, 'test': num_test} 45 | self.num_samples_seed = {'train': num_train_seed, 'dev': num_dev_seed, 'test': num_test_seed} 46 | self.pct_train_rationales = pct_train_rationales 47 | self.pct_train_rationales_seed = pct_train_rationales_seed 48 | self.train_rationales_batch_factor = train_rationales_batch_factor 49 | 50 | self.attr_algo = attr_algo 51 | 52 | self.fresh_exp_id = fresh_exp_id 53 | self.fresh_attr_algo = fresh_attr_algo 54 | self.fresh_topk = fresh_topk 55 | if fresh_topk is not None: 56 | assert 0 < fresh_topk < 100 57 | self.fresh_extractor = fresh_extractor 58 | self.min_val = min_val 59 | 60 | self.l2e_exp_id = l2e_exp_id 61 | self.l2e_attr_algo = l2e_attr_algo 62 | self.l2e_num_classes = l2e_num_classes 63 | 64 | self.train_shuffle = train_shuffle 65 | 66 | def load_dataset(self, split): 67 | dataset = {} 68 | data_path = os.path.join(self.data_path, split) 69 | assert 
Path(data_path).exists() 70 | 71 | for key in tqdm(data_keys, desc=f'Loading {split} set'): 72 | if ( 73 | (key in ['inv_rationale', 'rand_rationale'] and self.attr_algo not in ['inv', 'rand']) 74 | or (key == 'rationale' and self.dataset in ['amazon', 'yelp', 'stf', 'olid', 'irony']) 75 | or (key == 'rationale_indices' and (split != 'train' or self.pct_train_rationales is None)) 76 | ): 77 | continue 78 | elif key == 'rationale_indices' and split == 'train' and self.pct_train_rationales is not None: 79 | filename = f'{key}_{self.pct_train_rationales}_{self.pct_train_rationales_seed}.pkl' 80 | elif self.num_samples[split] is not None: 81 | filename = f'{key}_{self.num_samples[split]}_{self.num_samples_seed[split]}.pkl' 82 | else: 83 | filename = f'{key}.pkl' 84 | 85 | with open(os.path.join(data_path, filename), 'rb') as f: 86 | dataset[key] = pickle.load(f) 87 | 88 | if self.fresh_exp_id is not None: 89 | assert self.fresh_attr_algo in ['integrated-gradients', 'input-x-gradient', 'lm-gold'] 90 | assert self.fresh_extractor is not None 91 | assert self.fresh_topk is not None 92 | 93 | fresh_rationale_path = f'{get_original_cwd()}/../save/{self.fresh_exp_id}/model_outputs/{self.dataset}/{split}_attrs.pkl' 94 | print(f'Using {fresh_rationale_path} for split: {split}\n') 95 | assert Path(fresh_rationale_path).exists() 96 | attrs = pickle.load(open(fresh_rationale_path, 'rb')) 97 | if self.dataset == 'cose': 98 | num_examples = dataset_info[self.dataset][split][1] 99 | num_classes = dataset_info[self.dataset]['num_classes'] 100 | attrs = attrs.reshape(num_examples, num_classes, -1) 101 | 102 | fresh_rationales = [] 103 | for i, attr in enumerate(attrs): 104 | attn_mask = torch.LongTensor(dataset['attention_mask'][i]) 105 | if self.dataset == 'cose': 106 | fresh_rationale = [] 107 | for j, a in enumerate(attn_mask): 108 | num_tokens = a.sum() - 1 # don't include CLS token when computing num_tokens 109 | cur_attr = attr[j] + (1 - a) * self.min_val # ignore pad tokens when computing topk_indices 110 | cur_attr[0] = self.min_val # don't include CLS token when computing topk_indices 111 | k = max(1, int(self.fresh_topk / 100 * num_tokens)) 112 | cur_fresh_rationale = torch.zeros_like(a).long() 113 | topk_indices = torch.argsort(cur_attr[:num_tokens], descending=True)[:k] 114 | cur_fresh_rationale[topk_indices] = 1 115 | cur_fresh_rationale[0] = 1 # always treat CLS token as positive token 116 | fresh_rationale.append(list(cur_fresh_rationale.numpy())) 117 | else: 118 | num_tokens = attn_mask.sum() - 1 # don't include CLS token when computing num_tokens 119 | cur_attr = attr + (1 - attn_mask) * self.min_val # ignore pad tokens when computing topk_indices 120 | cur_attr[0] = self.min_val # don't include CLS token when computing topk_indices 121 | k = max(1, int(self.fresh_topk / 100 * num_tokens)) 122 | fresh_rationale = torch.zeros_like(attn_mask).long() 123 | topk_indices = torch.argsort(cur_attr[:num_tokens], descending=True)[:k] 124 | fresh_rationale[topk_indices] = 1 125 | fresh_rationale[0] = 1 # always treat CLS token as positive token 126 | fresh_rationale = list(fresh_rationale.numpy()) 127 | 128 | fresh_rationales.append(fresh_rationale) 129 | 130 | if self.fresh_extractor == 'oracle': 131 | dataset['fresh_rationale'] = fresh_rationales 132 | else: 133 | raise NotImplementedError 134 | 135 | elif self.l2e_exp_id is not None: 136 | assert self.l2e_attr_algo in ['integrated-gradients'] 137 | l2e_rationale_path = 
f'{get_original_cwd()}/../save/{self.l2e_exp_id}/model_outputs/{self.dataset}/{split}_attrs.pkl' 138 | print(f'Using {l2e_rationale_path} for split: {split}\n') 139 | assert Path(l2e_rationale_path).exists() 140 | attrs = pickle.load(open(l2e_rationale_path, 'rb')) 141 | if self.dataset == 'cose': 142 | num_examples = dataset_info[self.dataset][split][1] 143 | num_classes = dataset_info[self.dataset]['num_classes'] 144 | attrs = attrs.reshape(num_examples, num_classes, -1) 145 | l2e_rationales = [list(attr.numpy()) for attr in attrs] 146 | dataset['l2e_rationale'] = l2e_rationales 147 | 148 | if split == 'train' and self.pct_train_rationales is not None: 149 | dataset_ = deepcopy(dataset) 150 | dataset_keys = dataset_.keys() 151 | rationale_indices = dataset_['rationale_indices'] 152 | dataset, train_rationales_dataset = {}, {} 153 | for key in dataset_keys: 154 | if key != 'rationale_indices': 155 | dataset[key] = [x for i, x in enumerate(dataset_[key]) if i not in rationale_indices] 156 | train_rationales_dataset[key] = [x for i, x in enumerate(dataset_[key]) if i in rationale_indices] 157 | assert sorted(rationale_indices) == train_rationales_dataset['item_idx'] 158 | else: 159 | train_rationales_dataset = None 160 | 161 | return dataset, train_rationales_dataset 162 | 163 | def setup(self, splits=['all']): 164 | self.data = {} 165 | splits = ['train', 'dev', 'test'] if splits == ['all'] else splits 166 | for split in splits: 167 | dataset, train_rationales_dataset = self.load_dataset(split) 168 | self.data[split] = TextClassificationDataset(dataset, split, train_rationales_dataset, self.train_batch_size, self.l2e_num_classes) 169 | 170 | def train_dataloader(self): 171 | if self.pct_train_rationales is not None: 172 | assert self.train_batch_size >= 2 173 | assert self.train_rationales_batch_factor > 1 174 | batch_size = self.train_batch_size - int(max(1, self.train_batch_size / self.train_rationales_batch_factor)) 175 | else: 176 | batch_size = self.train_batch_size 177 | 178 | return DataLoader( 179 | self.data['train'], 180 | batch_size=batch_size, 181 | num_workers=self.num_workers, 182 | collate_fn=self.data['train'].collater, 183 | shuffle=self.train_shuffle, 184 | pin_memory=True 185 | ) 186 | 187 | def val_dataloader(self, test=False): 188 | if test: 189 | return DataLoader( 190 | self.data['dev'], 191 | batch_size=self.eval_batch_size, 192 | num_workers=self.num_workers, 193 | collate_fn=self.data['dev'].collater, 194 | pin_memory=True 195 | ) 196 | 197 | return [ 198 | DataLoader( 199 | self.data[eval_split], 200 | batch_size=self.eval_batch_size, 201 | num_workers=self.num_workers, 202 | collate_fn=self.data[eval_split].collater, 203 | pin_memory=True) 204 | 205 | for eval_split in ['dev', 'test'] 206 | ] 207 | 208 | def test_dataloader(self): 209 | return DataLoader( 210 | self.data['test'], 211 | batch_size=self.eval_batch_size, 212 | num_workers=self.num_workers, 213 | collate_fn=self.data['test'].collater, 214 | pin_memory=True 215 | ) 216 | 217 | 218 | class TextClassificationDataset(Dataset): 219 | def __init__(self, dataset, split, train_rationales_dataset=None, train_batch_size=None, l2e_num_classes=None): 220 | self.data = dataset 221 | self.split = split 222 | self.train_rationales_dataset = train_rationales_dataset 223 | self.train_batch_size = train_batch_size 224 | assert not (split != 'train' and train_rationales_dataset is not None) 225 | if train_rationales_dataset is not None: 226 | self.len_train_rationales_dataset = 
len(train_rationales_dataset['item_idx']) 227 | self.l2e_num_classes = l2e_num_classes 228 | 229 | def __len__(self): 230 | return len(self.data['item_idx']) 231 | 232 | def __getitem__(self, idx): 233 | item_idx = torch.LongTensor([self.data['item_idx'][idx]]) 234 | input_ids = torch.LongTensor(self.data['input_ids'][idx]) 235 | attention_mask = torch.LongTensor(self.data['attention_mask'][idx]) 236 | rationale = torch.FloatTensor(self.data['rationale'][idx]) if self.data.get('rationale') else None 237 | has_rationale = torch.LongTensor([self.data['has_rationale'][idx]]) 238 | if self.train_rationales_dataset is not None: 239 | has_rationale *= 0 240 | label = torch.LongTensor([self.data['label'][idx]]) 241 | inv_rationale = torch.FloatTensor(self.data['inv_rationale'][idx]) if self.data.get('inv_rationale') else None 242 | rand_rationale = torch.FloatTensor(self.data['rand_rationale'][idx]) if self.data.get('rand_rationale') else None 243 | fresh_rationale = torch.LongTensor(self.data['fresh_rationale'][idx]) if self.data.get('fresh_rationale') else None 244 | if self.data.get('l2e_rationale'): 245 | assert self.l2e_num_classes == 5 246 | l2e_rationale = self.discretize_l2e_rationale(self.data['l2e_rationale'][idx]) 247 | else: 248 | l2e_rationale = None 249 | 250 | return ( 251 | item_idx, input_ids, attention_mask, rationale, has_rationale, label, inv_rationale, rand_rationale, fresh_rationale, l2e_rationale 252 | ) 253 | 254 | def discretize_l2e_rationale(self, l2e_rationale): 255 | l2e_rationale = list(np.array(l2e_rationale).flatten()) 256 | 257 | all_pos = [x for x in l2e_rationale if x > 0] 258 | all_neg = [x for x in l2e_rationale if x < 0] 259 | mean_pos = sum(all_pos) / len(all_pos) if len(all_pos) > 0 else 0.0 260 | mean_neg = sum(all_neg) / len(all_neg) if len(all_neg) > 0 else 0.0 261 | 262 | l2e_rationale = torch.LongTensor([ 263 | 0 if x < mean_neg 264 | else 1 if mean_neg <= x < 0.0 265 | else 2 if x == 0.0 266 | else 3 if 0.0 < x <= mean_pos 267 | else 4 268 | for x in l2e_rationale]) 269 | 270 | return l2e_rationale 271 | 272 | def sample_train_rationale_indices(self, num_samples): 273 | return list(np.random.choice(self.len_train_rationales_dataset, size=num_samples, replace=False)) 274 | 275 | def get_train_rationale_item(self, idx): 276 | item_idx = torch.LongTensor([self.train_rationales_dataset['item_idx'][idx]]) 277 | input_ids = torch.LongTensor(self.train_rationales_dataset['input_ids'][idx]) 278 | attention_mask = torch.LongTensor(self.train_rationales_dataset['attention_mask'][idx]) 279 | rationale = torch.FloatTensor(self.train_rationales_dataset['rationale'][idx]) 280 | has_rationale = torch.LongTensor([self.train_rationales_dataset['has_rationale'][idx]]) 281 | label = torch.LongTensor([self.train_rationales_dataset['label'][idx]]) 282 | inv_rationale = torch.FloatTensor(self.train_rationales_dataset['inv_rationale'][idx]) if self.train_rationales_dataset.get('inv_rationale') else None 283 | rand_rationale = torch.FloatTensor(self.train_rationales_dataset['rand_rationale'][idx]) if self.train_rationales_dataset.get('rand_rationale') else None 284 | fresh_rationale = torch.LongTensor(self.train_rationales_dataset['fresh_rationale'][idx]) if self.train_rationales_dataset.get('fresh_rationale') else None 285 | if self.train_rationales_dataset.get('l2e_rationale'): 286 | assert self.l2e_num_classes == 5 287 | l2e_rationale = self.discretize_l2e_rationale(self.train_rationales_dataset['l2e_rationale'][idx]) 288 | else: 289 | l2e_rationale = None 290 | 291 | 
return ( 292 | item_idx, input_ids, attention_mask, rationale, has_rationale, label, inv_rationale, rand_rationale, fresh_rationale, l2e_rationale 293 | ) 294 | 295 | def collater(self, items): 296 | batch_size = len(items) 297 | if self.train_rationales_dataset is not None: 298 | num_train_rationale_indices = int(max(1, self.train_batch_size - batch_size)) 299 | train_rationale_indices = self.sample_train_rationale_indices(num_train_rationale_indices) 300 | for idx in train_rationale_indices: 301 | items.append(self.get_train_rationale_item(idx)) 302 | 303 | batch = { 304 | 'item_idx': torch.cat([x[0] for x in items]), 305 | 'input_ids': torch.stack([x[1] for x in items], dim=0), 306 | 'attention_mask': torch.stack([x[2] for x in items], dim=0), 307 | 'rationale': torch.stack([x[3] for x in items], dim=0) if self.data.get('rationale') else None, 308 | 'has_rationale': torch.cat([x[4] for x in items]), 309 | 'label': torch.cat([x[5] for x in items]), 310 | 'inv_rationale': torch.stack([x[6] for x in items], dim=0) if self.data.get('inv_rationale') else None, 311 | 'rand_rationale': torch.stack([x[7] for x in items], dim=0) if self.data.get('rand_rationale') else None, 312 | 'fresh_rationale': torch.stack([x[8] for x in items], dim=0) if self.data.get('fresh_rationale') else None, 313 | 'l2e_rationale': torch.stack([x[9] for x in items], dim=0) if self.data.get('l2e_rationale') else None, 314 | 'split': self.split, # when evaluate_ckpt=true, split always test 315 | } 316 | 317 | return batch -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/UNIREX/60149a9c945376069b70fb3b845e6a20c11534ad/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/base_model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytorch_lightning as pl 4 | 5 | 6 | class BaseModel(pl.LightningModule): 7 | def __init__(self): 8 | super().__init__() 9 | # update in `setup` 10 | self.total_steps = None 11 | 12 | def forward(self, **kwargs): 13 | raise NotImplementedError 14 | 15 | def calc_loss(self, preds, targets): 16 | raise NotImplementedError 17 | 18 | def calc_acc(self, preds, targets): 19 | raise NotImplementedError 20 | 21 | def run_step(self, batch, split): 22 | raise NotImplementedError 23 | 24 | def aggregate_epoch(self, outputs, split): 25 | raise NotImplementedError 26 | 27 | def training_step(self, batch, batch_idx): 28 | # # freeze encoder for initial few epochs based on p.freeze_epochs 29 | # if self.current_epoch < self.freeze_epochs: 30 | # freeze_net(self.text_encoder) 31 | # else: 32 | # unfreeze_net(self.text_encoder) 33 | 34 | return self.run_step(batch, 'train', batch_idx) 35 | 36 | def training_epoch_end(self, outputs): 37 | self.aggregate_epoch(outputs, 'train') 38 | 39 | def validation_step(self, batch, batch_idx, dataset_idx): 40 | assert dataset_idx in [0, 1] 41 | eval_splits = {0: 'dev', 1: 'test'} 42 | return self.run_step(batch, eval_splits[dataset_idx], batch_idx) 43 | 44 | def validation_epoch_end(self, outputs): 45 | self.aggregate_epoch(outputs, 'dev') 46 | 47 | def test_step(self, batch, batch_idx): 48 | return self.run_step(batch, 'test', batch_idx) 49 | 50 | def test_epoch_end(self, outputs): 51 | self.aggregate_epoch(outputs, 'test') 52 | 53 | def setup(self, 
stage: Optional[str] = None): 54 | """calculate total steps""" 55 | if stage == 'fit': 56 | # Get train dataloader 57 | train_loader = self.trainer.datamodule.train_dataloader() 58 | ngpus = self.trainer.num_gpus 59 | 60 | # Calculate total steps 61 | eff_train_batch_size = (self.trainer.datamodule.train_batch_size * 62 | max(1, ngpus) * self.trainer.accumulate_grad_batches) 63 | assert eff_train_batch_size == self.trainer.datamodule.eff_train_batch_size 64 | self.total_steps = int( 65 | (len(train_loader.dataset) // eff_train_batch_size) * float(self.trainer.max_epochs)) 66 | 67 | def configure_optimizers(self): 68 | raise NotImplementedError 69 | -------------------------------------------------------------------------------- /src/model/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def MLP_factory(layer_sizes, dropout=False, layernorm=False): 4 | modules = nn.ModuleList() 5 | unpacked_sizes = [] 6 | for block in layer_sizes: 7 | unpacked_sizes.extend([block[0]] * block[1]) 8 | 9 | for k in range(len(unpacked_sizes)-1): 10 | if layernorm: 11 | modules.append(nn.LayerNorm(unpacked_sizes[k])) 12 | modules.append(nn.Linear(unpacked_sizes[k], unpacked_sizes[k+1])) 13 | if k < len(unpacked_sizes)-2: 14 | modules.append(nn.ReLU()) 15 | if dropout is not False: 16 | modules.append(nn.Dropout(dropout)) 17 | mlp = nn.Sequential(*modules) 18 | return mlp -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import os, shutil 2 | from typing import Tuple, Optional 3 | 4 | import torch 5 | import pytorch_lightning as pl 6 | from hydra.utils import instantiate 7 | from omegaconf import open_dict, DictConfig 8 | from pytorch_lightning.callbacks import ( 9 | ModelCheckpoint, EarlyStopping 10 | ) 11 | from transformers import AutoTokenizer 12 | 13 | from src.utils.data import dataset_info, monitor_dict 14 | from src.utils.logging import get_logger 15 | from src.utils.callbacks import BestPerformance 16 | from src.utils.expl import attr_algos, baseline_required 17 | 18 | 19 | def get_callbacks(cfg: DictConfig): 20 | 21 | monitor = monitor_dict[cfg.data.dataset] 22 | mode = cfg.data.mode 23 | callbacks = [ 24 | BestPerformance(monitor=monitor, mode=mode) 25 | ] 26 | 27 | if cfg.save_checkpoint: 28 | callbacks.append( 29 | ModelCheckpoint( 30 | monitor=monitor, 31 | dirpath=os.path.join(cfg.save_dir, 'checkpoints'), 32 | save_top_k=1, 33 | mode=mode, 34 | verbose=True, 35 | save_last=False, 36 | save_weights_only=True, 37 | ) 38 | ) 39 | 40 | if cfg.early_stopping: 41 | callbacks.append( 42 | EarlyStopping( 43 | monitor=monitor, 44 | min_delta=0.00, 45 | patience=cfg.training.patience, 46 | verbose=False, 47 | mode=mode 48 | ) 49 | ) 50 | 51 | return callbacks 52 | 53 | 54 | logger = get_logger(__name__) 55 | 56 | 57 | def build(cfg) -> Tuple[pl.LightningDataModule, pl.LightningModule, pl.Trainer]: 58 | dm = instantiate( 59 | cfg.data, 60 | attr_algo=cfg.model.attr_algo, 61 | fresh_extractor=cfg.model.fresh_extractor, 62 | train_shuffle=cfg.training.train_shuffle, 63 | ) 64 | dm.setup(splits=cfg.training.eval_splits.split(",")) 65 | 66 | logger.info(f'load {cfg.data.dataset} <{cfg.data._target_}>') 67 | 68 | model = instantiate( 69 | cfg.model, num_classes=dataset_info[cfg.data.dataset]['num_classes'], 70 | neg_weight=cfg.data.neg_weight, 71 | _recursive_=False 72 | ) 73 | logger.info(f'load {cfg.model.arch} 
<{cfg.model._target_}>') 74 | 75 | run_logger = instantiate(cfg.logger, cfg=cfg, _recursive_=False) 76 | 77 | with open_dict(cfg): 78 | if cfg.debug or cfg.logger.offline: 79 | exp_dir = cfg.logger.name 80 | cfg.logger.neptune_exp_id = cfg.logger.name 81 | else: 82 | if cfg.logger.logger == "neptune": 83 | exp_dir = run_logger.experiment_id 84 | cfg.logger.neptune_exp_id = run_logger.experiment_id 85 | else: 86 | raise NotImplementedError 87 | cfg.save_dir = os.path.join(cfg.save_dir, exp_dir) 88 | os.makedirs(cfg.save_dir, exist_ok=True) 89 | 90 | # copy hydra configs 91 | shutil.copytree( 92 | os.path.join(os.getcwd(), ".hydra"), 93 | os.path.join(cfg.save_dir, "hydra") 94 | ) 95 | 96 | logger.info(f"saving to {cfg.save_dir}") 97 | 98 | trainer = instantiate( 99 | cfg.trainer, 100 | callbacks=get_callbacks(cfg), 101 | checkpoint_callback=cfg.save_checkpoint, 102 | logger=run_logger, 103 | _convert_="all", 104 | ) 105 | 106 | return dm, model, trainer 107 | 108 | 109 | def restore_config_params(model, cfg: DictConfig): 110 | for key, val in cfg.model.items(): 111 | setattr(model, key, val) 112 | 113 | if cfg.model.save_outputs: 114 | assert cfg.model.exp_id in cfg.training.ckpt_path 115 | 116 | if cfg.model.explainer_type == 'attr_algo' and model.attr_algo in attr_algos.keys(): 117 | model.attr_func = attr_algos[model.attr_algo](model) 118 | model.tokenizer = AutoTokenizer.from_pretrained(cfg.model.arch) 119 | model.baseline_required = baseline_required[model.attr_algo] 120 | model.word_emb_layer = model.task_encoder.embeddings.word_embeddings 121 | model.attr_dict['baseline_required'] = model.baseline_required 122 | if model.attr_algo == 'integrated-gradients': 123 | model.attr_dict['ig_steps'] = getattr(model, 'ig_steps') 124 | model.attr_dict['internal_batch_size'] = getattr(model, 'internal_batch_size') 125 | model.attr_dict['return_convergence_delta'] = getattr(model, 'return_convergence_delta') 126 | elif model.attr_algo == 'gradient-shap': 127 | model.attr_dict['gradshap_n_samples'] = getattr(model, 'gradshap_n_samples') 128 | model.attr_dict['gradshap_stdevs'] = getattr(model, 'gradshap_stdevs') 129 | model.attr_dict['attr_func'] = model.attr_func 130 | model.attr_dict['tokenizer'] = model.tokenizer 131 | 132 | logger.info('Restored params from model config.') 133 | 134 | return model 135 | 136 | 137 | def run(cfg: DictConfig) -> Optional[float]: 138 | pl.seed_everything(cfg.seed) 139 | dm, model, trainer = build(cfg) 140 | pl.seed_everything(cfg.seed) 141 | 142 | if cfg.save_rand_checkpoint: 143 | ckpt_path = os.path.join(cfg.save_dir, 'checkpoints', 'rand.ckpt') 144 | logger.info(f"Saving randomly initialized model to {ckpt_path}") 145 | trainer.model = model 146 | trainer.save_checkpoint(ckpt_path) 147 | elif not cfg.training.evaluate_ckpt: 148 | # either train from scratch, or resume training from ckpt 149 | if cfg.training.finetune_ckpt: 150 | assert cfg.training.ckpt_path 151 | save_dir = '/'.join(cfg.save_dir.split('/')[:-1]) 152 | ckpt_path = os.path.join(save_dir, cfg.training.ckpt_path) 153 | model = model.load_from_checkpoint(ckpt_path, strict=False) 154 | model = restore_config_params(model, cfg) 155 | logger.info(f"Loaded checkpoint (for fine-tuning) from {ckpt_path}") 156 | 157 | trainer.fit(model=model, datamodule=dm) 158 | 159 | if getattr(cfg, "tune_metric", None): 160 | metric = trainer.callback_metrics[cfg.tune_metric].detach() 161 | logger.info(f"best metric {metric}") 162 | return metric 163 | else: 164 | # evaluate the pretrained model on the provided splits 
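        # Illustrative usage (an assumption, not taken from this repo's docs): checkpoint
        # evaluation would typically be launched with Hydra overrides such as
        #     python main.py training.evaluate_ckpt=true training.eval_splits=test \
        #         training.ckpt_path=<exp_id>/checkpoints/<ckpt_file>.ckpt
        # The override names mirror the cfg.training.* keys read below; the path layout is
        # only a placeholder.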
165 | assert cfg.training.ckpt_path 166 | save_dir = '/'.join(cfg.save_dir.split('/')[:-2]) 167 | ckpt_path = os.path.join(save_dir, cfg.training.ckpt_path) 168 | model = model.load_from_checkpoint(ckpt_path, strict=False) 169 | logger.info(f"Loaded checkpoint for evaluation from {cfg.training.ckpt_path}") 170 | model = restore_config_params(model, cfg) 171 | print('Evaluating loaded model checkpoint...') 172 | for split in cfg.training.eval_splits.split(','): 173 | print(f'Evaluating on split: {split}') 174 | if split == 'train': 175 | loader = dm.train_dataloader() 176 | elif split == 'dev': 177 | loader = dm.val_dataloader(test=True) 178 | elif split == 'test': 179 | loader = dm.test_dataloader() 180 | 181 | trainer.test(model=model, dataloaders=loader) 182 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/UNIREX/60149a9c945376069b70fb3b845e6a20c11534ad/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning.callbacks import Callback 3 | 4 | class BestPerformance(Callback): 5 | 6 | def __init__(self, monitor, mode): 7 | super().__init__() 8 | 9 | self.monitor = monitor 10 | assert monitor.split('_')[0] == 'dev' 11 | self.test_monitor = '_'.join(['test'] + monitor.split('_')[1:]) 12 | 13 | self.mode = mode 14 | assert mode in ['max', 'min'] 15 | 16 | def set_best_expl_metric(self, trainer, pl_module, metric): 17 | assert metric in ['comp', 'suff', 'log_odds', 'csd', 'plaus'] 18 | for split in ['dev', 'test']: 19 | if metric == 'plaus': 20 | pl_module.best_metrics[f'{split}_best_{metric}_auprc'] = trainer.callback_metrics[f'{split}_{metric}_auprc_metric_epoch'] 21 | pl_module.best_metrics[f'{split}_best_{metric}_token_f1'] = trainer.callback_metrics[f'{split}_{metric}_token_f1_metric_epoch'] 22 | else: 23 | pl_module.best_metrics[f'{split}_best_{metric}_aopc'] = trainer.callback_metrics[f'{split}_{metric}_aopc_metric_epoch'] 24 | for k in pl_module.topk[split]: 25 | pl_module.best_metrics[f'{split}_best_{metric}_{k}'] = trainer.callback_metrics[f'{split}_{metric}_{k}_metric_epoch'] 26 | return pl_module 27 | 28 | def log_best_expl_metric(self, pl_module, metric): 29 | assert metric in ['comp', 'suff', 'log_odds', 'csd', 'plaus'] 30 | for split in ['dev', 'test']: 31 | if metric == 'plaus': 32 | pl_module.log(f'{split}_best_{metric}_auprc', pl_module.best_metrics[f'{split}_best_{metric}_auprc'], prog_bar=True, sync_dist=True) 33 | pl_module.log(f'{split}_best_{metric}_token_f1', pl_module.best_metrics[f'{split}_best_{metric}_token_f1'], prog_bar=True, sync_dist=True) 34 | else: 35 | pl_module.log(f'{split}_best_{metric}_aopc', pl_module.best_metrics[f'{split}_best_{metric}_aopc'], prog_bar=True, sync_dist=True) 36 | for k in pl_module.topk[split]: 37 | pl_module.log(f'{split}_best_{metric}_{k}', pl_module.best_metrics[f'{split}_best_{metric}_{k}'], prog_bar=True, sync_dist=True) 38 | 39 | def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 40 | if self.mode == 'max': 41 | if pl_module.best_metrics['dev_best_perf'] == None: 42 | assert pl_module.best_metrics['test_best_perf'] == None 43 | pl_module.best_metrics['dev_best_perf'] = -float('inf') 44 | 
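            # When the monitored dev metric improves, snapshot the paired test metric and this
            # epoch's explanation metrics (comp, suff, plaus, plus log_odds/csd when enabled),
            # so every best_* value logged below comes from the same epoch as dev_best_perf.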
45 | if trainer.callback_metrics[self.monitor] > pl_module.best_metrics['dev_best_perf']: 46 | pl_module.best_metrics['dev_best_perf'] = trainer.callback_metrics[self.monitor] 47 | pl_module.best_metrics['test_best_perf'] = trainer.callback_metrics[self.test_monitor] 48 | pl_module.best_metrics['best_epoch'] = trainer.current_epoch 49 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'comp') 50 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'suff') 51 | if pl_module.log_odds: 52 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'log_odds') 53 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'csd') 54 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'plaus') 55 | 56 | else: 57 | if pl_module.best_metrics['dev_best_perf'] == None: 58 | assert pl_module.best_metrics['test_best_perf'] == None 59 | pl_module.best_metrics['dev_best_perf'] = float('inf') 60 | 61 | if trainer.callback_metrics[self.monitor] < pl_module.best_metrics['dev_best_perf']: 62 | pl_module.best_metrics['dev_best_perf'] = trainer.callback_metrics[self.monitor] 63 | pl_module.best_metrics['test_best_perf'] = trainer.callback_metrics[self.test_monitor] 64 | pl_module.best_metrics['best_epoch'] = trainer.current_epoch 65 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'comp') 66 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'suff') 67 | if pl_module.log_odds: 68 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'log_odds') 69 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'csd') 70 | pl_module = self.set_best_expl_metric(trainer, pl_module, 'plaus') 71 | 72 | pl_module.log('dev_best_perf', pl_module.best_metrics['dev_best_perf'], prog_bar=True, sync_dist=True) 73 | pl_module.log('test_best_perf', pl_module.best_metrics['test_best_perf'], prog_bar=True, sync_dist=True) 74 | pl_module.log('best_epoch', pl_module.best_metrics['best_epoch'], prog_bar=True, sync_dist=True) 75 | self.log_best_expl_metric(pl_module, 'comp') 76 | self.log_best_expl_metric(pl_module, 'suff') 77 | if pl_module.log_odds: 78 | self.log_best_expl_metric(pl_module, 'log_odds') 79 | self.log_best_expl_metric(pl_module, 'csd') 80 | self.log_best_expl_metric(pl_module, 'plaus') 81 | -------------------------------------------------------------------------------- /src/utils/conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | handle changes to the hydra config 3 | """ 4 | import time 5 | import uuid 6 | from pathlib import Path 7 | 8 | import rich.syntax 9 | import rich.tree 10 | from omegaconf import DictConfig, ListConfig, OmegaConf 11 | from pytorch_lightning.utilities import rank_zero_only 12 | 13 | 14 | def fail_on_missing(cfg: DictConfig) -> None: 15 | if isinstance(cfg, ListConfig): 16 | for x in cfg: 17 | fail_on_missing(x) 18 | elif isinstance(cfg, DictConfig): 19 | for _, v in cfg.items(): 20 | fail_on_missing(v) 21 | 22 | 23 | def pretty_print( 24 | cfg: DictConfig, 25 | fields=( 26 | "dataset", 27 | "model", 28 | "logger", 29 | "trainer", 30 | "setup", 31 | "training", 32 | ) 33 | ): 34 | style = "dim" 35 | tree = rich.tree.Tree(":gear: CONFIG", style=style, guide_style=style) 36 | 37 | for field in fields: 38 | branch = tree.add(field, style=style, guide_style=style) 39 | 40 | config_section = cfg.get(field) 41 | branch_content = str(config_section) 42 | if isinstance(config_section, DictConfig): 43 | branch_content = OmegaConf.to_yaml(config_section, resolve=True) 44 | 45 | 
branch.add(rich.syntax.Syntax(branch_content, "yaml")) 46 | 47 | # others defined in root 48 | others = tree.add("others", style=style, guide_style=style) 49 | for var, val in OmegaConf.to_container(cfg, resolve=True).items(): 50 | if not var.startswith("_") and var not in fields: 51 | others.add(f"{var}: {val}") 52 | 53 | rich.print(tree) 54 | 55 | 56 | @rank_zero_only 57 | def touch(cfg: DictConfig) -> None: 58 | # sanity check 59 | assert Path(cfg.data.data_path).exists(), f"datapath {cfg.data.data_path} not exist" 60 | 61 | if cfg.training.finetune_ckpt: 62 | assert cfg.training.ckpt_path 63 | if cfg.training.evaluate_ckpt: 64 | assert cfg.training.ckpt_path 65 | assert cfg.training.eval_splits != "all" 66 | 67 | cfg.logger.name = f'{cfg.model.model}_{cfg.data.dataset}_{cfg.model.arch}_{time.strftime("%d_%m_%Y")}_{str(uuid.uuid4())[: 8]}' 68 | 69 | if cfg.debug: 70 | # for DEBUG purposes only 71 | cfg.trainer.limit_train_batches = 1 72 | cfg.trainer.limit_val_batches = 1 73 | cfg.trainer.limit_test_batches = 1 74 | cfg.trainer.max_epochs = 1 75 | # for DEBUG purposes only 76 | 77 | fail_on_missing(cfg) 78 | pretty_print(cfg) 79 | -------------------------------------------------------------------------------- /src/utils/data.py: -------------------------------------------------------------------------------- 1 | dataset_info = { 2 | 'amazon': { 3 | 'train': ['train', 10000], 4 | 'dev': ['dev', 2000], 5 | 'test': ['test', 2000], 6 | 'num_classes': 2, 7 | 'classes': ['neg', 'pos'], 8 | 'max_length': { 9 | 'bert-base-uncased': 256, 10 | 'google/bigbird-roberta-base': 256, 11 | }, 12 | 'num_special_tokens': 2, 13 | }, 14 | 'cose': { 15 | 'train': ['train', 8752], 16 | 'dev': ['val', 1086], 17 | 'test': ['test', 1079], 18 | 'num_classes': 5, 19 | 'classes': ['A', 'B', 'C', 'D', 'E'], 20 | 'max_length': { 21 | 'bert-base-uncased': 512, 22 | 'google/bigbird-roberta-base': 77, 23 | }, 24 | 'num_special_tokens': None, 25 | }, 26 | 'esnli': { 27 | 'train': ['train', 549309], 28 | 'dev': ['val', 9823], 29 | 'test': ['test', 9807], 30 | 'num_classes': 3, 31 | 'classes': ['entailment', 'neutral', 'contradiction'], 32 | 'max_length': { 33 | 'bert-base-uncased': 125, 34 | 'google/bigbird-roberta-base': 125, 35 | }, 36 | 'num_special_tokens': 3, 37 | }, 38 | 'movies': { 39 | 'train': ['train', 1599], 40 | 'dev': ['val', 200], 41 | 'test': ['test', 200], 42 | 'num_classes': 2, 43 | 'classes': ['NEG', 'POS'], 44 | 'max_length': { 45 | 'bert-base-uncased': 512, 46 | 'google/bigbird-roberta-base': 1024, 47 | }, 48 | 'num_special_tokens': 2, 49 | }, 50 | 'multirc': { 51 | 'train': ['train', 24029], 52 | 'dev': ['val', 3214], 53 | 'test': ['test', 4848], 54 | 'num_classes': 2, 55 | 'classes': ['False', 'True'], 56 | 'max_length': { 57 | 'bert-base-uncased': 512, 58 | 'google/bigbird-roberta-base': 748, 59 | }, 60 | 'num_special_tokens': 3, 61 | }, 62 | 'sst': { 63 | 'train': ['train', 6920], 64 | 'dev': ['dev', 872], 65 | 'test': ['test', 1821], 66 | 'num_classes': 2, 67 | 'classes': ['neg', 'pos'], 68 | 'max_length': { 69 | 'bert-base-uncased': 58, 70 | 'google/bigbird-roberta-base': 67, 71 | }, 72 | 'num_special_tokens': 2, 73 | }, 74 | 'stf': { 75 | 'train': ['train', 7896], 76 | 'dev': ['dev', 978], 77 | 'test': ['test', 1998], 78 | 'num_classes': 2, 79 | 'classes': ['not_hate', 'hate'], 80 | 'max_length': { 81 | 'bert-base-uncased': 128, 82 | 'google/bigbird-roberta-base': 128, 83 | }, 84 | 'num_special_tokens': 2, 85 | }, 86 | 'yelp': { 87 | 'train': ['train', 10000], 88 | 'dev': ['dev', 2000], 89 | 
'test': ['test', 2000], 90 | 'num_classes': 2, 91 | 'classes': ['neg', 'pos'], 92 | 'max_length': { 93 | 'bert-base-uncased': 512, 94 | 'google/bigbird-roberta-base': 512, 95 | }, 96 | 'num_special_tokens': 2, 97 | }, 98 | 'olid': { 99 | 'train': ['train', 11916], 100 | 'dev': ['validation', 1324], 101 | 'test': ['test', 860], 102 | 'num_classes': 2, 103 | 'classes': ['not_offensive', 'offensive'], 104 | 'max_length': { 105 | 'bert-base-uncased': 128, 106 | 'google/bigbird-roberta-base': 128, 107 | }, 108 | 'num_special_tokens': 2, 109 | }, 110 | 'irony': { 111 | 'train': ['train', 2862], 112 | 'dev': ['validation', 955], 113 | 'test': ['test', 784], 114 | 'num_classes': 2, 115 | 'classes': ['not_irony', 'irony'], 116 | 'max_length': { 117 | 'bert-base-uncased': 128, 118 | 'google/bigbird-roberta-base': 128, 119 | }, 120 | 'num_special_tokens': 2, 121 | }, 122 | } 123 | 124 | eraser_datasets = ['cose', 'esnli', 'movies', 'multirc'] 125 | 126 | monitor_dict = { 127 | 'cose': 'dev_acc_metric_epoch', 128 | 'esnli': 'dev_macro_f1_metric_epoch', 129 | 'movies': 'dev_macro_f1_metric_epoch', 130 | 'multirc': 'dev_macro_f1_metric_epoch', 131 | 'sst': 'dev_acc_metric_epoch', 132 | 'amazon': 'dev_acc_metric_epoch', 133 | 'yelp': 'dev_acc_metric_epoch', 134 | 'stf': 'dev_binary_f1_metric_epoch', 135 | 'olid': 'dev_macro_f1_metric_epoch', 136 | 'irony': 'dev_binary_f1_metric_epoch', 137 | } 138 | 139 | data_keys = ['item_idx', 'input_ids', 'attention_mask', 'rationale', 'inv_rationale', 'rand_rationale', 'has_rationale', 'label', 'rationale_indices'] -------------------------------------------------------------------------------- /src/utils/eraser/data_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple, List, Dict, Any 3 | from collections import namedtuple, defaultdict 4 | from itertools import chain 5 | from tokenizers import TextInputSequence 6 | 7 | from src.utils.eraser.utils import Annotation, Evidence, annotations_from_jsonl, load_documents 8 | 9 | SentenceEvidence = namedtuple('SentenceEvidence', 'kls ann_id query docid index sentence') 10 | 11 | logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s') 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def bert_tokenize_doc(doc: List[List[str]], tokenizer, special_token_map) -> Tuple[List[List[str]], List[List[Tuple[int, int]]]]: 16 | """ Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words""" 17 | sents = [] 18 | sent_token_spans = [] 19 | for sent in doc: 20 | tokens = [] 21 | spans = [] 22 | start = 0 23 | for w in sent: 24 | if w in special_token_map: 25 | tokens.append(w) 26 | else: 27 | tokens.extend(tokenizer.tokenize(w)) 28 | end = len(tokens) 29 | spans.append((start, end)) 30 | start = end 31 | sents.append(tokens) 32 | sent_token_spans.append(spans) 33 | return sents, sent_token_spans 34 | 35 | 36 | def bert_intern_doc(doc: List[List[str]], tokenizer, special_token_map) -> List[List[int]]: 37 | # return [list(chain.from_iterable(special_token_map.get(w, tokenizer.encode(w)) for w in s)) for s in doc] 38 | return [[special_token_map.get(w, tokenizer.convert_tokens_to_ids(w)) for w in s] for s in doc] 39 | 40 | 41 | def bert_intern_annotation(annotations: List[Annotation], tokenizer): 42 | ret = [] 43 | for ann in annotations: 44 | ev_groups = [] 45 | for ev_group in ann.evidences: 46 | evs = [] 47 | for ev in ev_group: 48 | text = 
list(chain.from_iterable(tokenizer.tokenize(w) for w in ev.text.split())) 49 | if len(text) == 0: 50 | continue 51 | # text = tokenizer.encode(text, add_special_tokens=False) 52 | text = tokenizer.convert_tokens_to_ids(text) 53 | evs.append(Evidence(text=tuple(text), 54 | docid=ev.docid, 55 | start_token=ev.start_token, 56 | end_token=ev.end_token, 57 | start_sentence=ev.start_sentence, 58 | end_sentence=ev.end_sentence)) 59 | ev_groups.append(tuple(evs)) 60 | query = list(chain.from_iterable(tokenizer.tokenize(w) for w in ann.query.split())) 61 | if len(query) > 0: 62 | # query = tokenizer.encode(query, add_special_tokens=False) 63 | query = tokenizer.convert_tokens_to_ids(query) 64 | else: 65 | query = [] 66 | ret.append(Annotation(annotation_id=ann.annotation_id, 67 | query=tuple(query), 68 | evidences=frozenset(ev_groups), 69 | classification=ann.classification, 70 | query_type=ann.query_type)) 71 | return ret 72 | 73 | 74 | def annotations_to_evidence_identification(annotations: List[Annotation], 75 | documents: Dict[str, List[List[Any]]] 76 | ) -> Dict[str, Dict[str, List[SentenceEvidence]]]: 77 | """Converts Corpus-Level annotations to Sentence Level relevance judgments. 78 | 79 | As this module is about a pipelined approach for evidence identification, 80 | inputs to both an evidence identifier and evidence classifier need to be to 81 | be on a sentence level, this module converts data to be that form. 82 | 83 | The return type is of the form 84 | annotation id -> docid -> [sentence level annotations] 85 | """ 86 | ret = defaultdict(dict) # annotation id -> docid -> sentences 87 | for ann in annotations: 88 | ann_id = ann.annotation_id 89 | for ev_group in ann.evidences: 90 | for ev in ev_group: 91 | if len(ev.text) == 0: 92 | continue 93 | if ev.docid not in ret[ann_id]: 94 | ret[ann.annotation_id][ev.docid] = [] 95 | # populate the document with "not evidence"; to be filled in later 96 | for index, sent in enumerate(documents[ev.docid]): 97 | ret[ann.annotation_id][ev.docid].append(SentenceEvidence( 98 | kls=0, 99 | query=ann.query, 100 | ann_id=ann.annotation_id, 101 | docid=ev.docid, 102 | index=index, 103 | sentence=sent)) 104 | # define the evidence sections of the document 105 | for s in range(ev.start_sentence, ev.end_sentence): 106 | ret[ann.annotation_id][ev.docid][s] = SentenceEvidence( 107 | kls=1, 108 | ann_id=ann.annotation_id, 109 | query=ann.query, 110 | docid=ev.docid, 111 | index=ret[ann.annotation_id][ev.docid][s].index, 112 | sentence=ret[ann.annotation_id][ev.docid][s].sentence) 113 | return ret 114 | 115 | 116 | def annotations_to_evidence_token_identification(annotations: List[Annotation], 117 | source_documents: Dict[str, List[List[str]]], 118 | interned_documents: Dict[str, List[List[int]]], 119 | token_mapping: Dict[str, List[List[Tuple[int, int]]]] 120 | ) -> Dict[str, Dict[str, List[SentenceEvidence]]]: 121 | # TODO document 122 | # TODO should we simplify to use only source text? 
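    # For each annotation and document: assign a 0/1 class to every wordpiece (1 if it falls
    # inside an evidence span, 0 otherwise), using token_mapping to translate source-token
    # offsets into wordpiece offsets, then re-pack the labels sentence by sentence as
    # SentenceEvidence tuples whose kls field is the per-wordpiece tuple.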
123 | ret = defaultdict(lambda: defaultdict(list)) # annotation id -> docid -> sentences 124 | positive_tokens = 0 125 | negative_tokens = 0 126 | for ann in annotations: 127 | annid = ann.annotation_id 128 | docids = set(ev.docid for ev in chain.from_iterable(ann.evidences)) 129 | sentence_offsets = defaultdict(list) # docid -> [(start, end)] 130 | classes = defaultdict(list) # docid -> [token is yea or nay] 131 | for docid in docids: 132 | start = 0 133 | assert len(source_documents[docid]) == len(interned_documents[docid]) 134 | for whole_token_sent, wordpiece_sent in zip(source_documents[docid], interned_documents[docid]): 135 | classes[docid].extend([0 for _ in wordpiece_sent]) 136 | end = start + len(wordpiece_sent) 137 | sentence_offsets[docid].append((start, end)) 138 | start = end 139 | for ev in chain.from_iterable(ann.evidences): 140 | if len(ev.text) == 0: 141 | continue 142 | flat_token_map = list(chain.from_iterable(token_mapping[ev.docid])) 143 | if ev.start_token != -1: 144 | #start, end = token_mapping[ev.docid][ev.start_token][0], token_mapping[ev.docid][ev.end_token][1] 145 | start, end = flat_token_map[ev.start_token][0], flat_token_map[ev.end_token - 1][1] 146 | else: 147 | start = flat_token_map[sentence_offsets[ev.start_sentence][0]][0] 148 | end = flat_token_map[sentence_offsets[ev.end_sentence - 1][1]][1] 149 | for i in range(start, end): 150 | classes[ev.docid][i] = 1 151 | for docid, offsets in sentence_offsets.items(): 152 | token_assignments = classes[docid] 153 | positive_tokens += sum(token_assignments) 154 | negative_tokens += len(token_assignments) - sum(token_assignments) 155 | for s, (start, end) in enumerate(offsets): 156 | sent = interned_documents[docid][s] 157 | ret[annid][docid].append(SentenceEvidence(kls=tuple(token_assignments[start:end]), 158 | query=ann.query, 159 | ann_id=ann.annotation_id, 160 | docid=docid, 161 | index=s, 162 | sentence=sent)) 163 | logging.info(f"Have {positive_tokens} positive wordpiece tokens, {negative_tokens} negative wordpiece tokens") 164 | return ret -------------------------------------------------------------------------------- /src/utils/eraser/metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import pprint 6 | 7 | from collections import Counter, defaultdict, namedtuple 8 | from dataclasses import dataclass 9 | from itertools import chain 10 | from typing import Any, Callable, Dict, List, Set, Tuple 11 | 12 | import numpy as np 13 | import torch 14 | 15 | from scipy.stats import entropy 16 | from sklearn.metrics import accuracy_score, auc, average_precision_score, classification_report, precision_recall_curve, roc_auc_score 17 | 18 | from rationale_benchmark.utils import ( 19 | Annotation, 20 | Evidence, 21 | annotations_from_jsonl, 22 | load_jsonl, 23 | load_documents, 24 | load_flattened_documents 25 | ) 26 | 27 | logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s') 28 | 29 | # start_token is inclusive, end_token is exclusive 30 | @dataclass(eq=True, frozen=True) 31 | class Rationale: 32 | ann_id: str 33 | docid: str 34 | start_token: int 35 | end_token: int 36 | 37 | def to_token_level(self) -> List['Rationale']: 38 | ret = [] 39 | for t in range(self.start_token, self.end_token): 40 | ret.append(Rationale(self.ann_id, self.docid, t, t+1)) 41 | return ret 42 | 43 | @classmethod 44 | def from_annotation(cls, ann: Annotation) -> List['Rationale']: 45 | 
ret = [] 46 | for ev_group in ann.evidences: 47 | for ev in ev_group: 48 | ret.append(Rationale(ann.annotation_id, ev.docid, ev.start_token, ev.end_token)) 49 | return ret 50 | 51 | @classmethod 52 | def from_instance(cls, inst: dict) -> List['Rationale']: 53 | ret = [] 54 | for rat in inst['rationales']: 55 | for pred in rat.get('hard_rationale_predictions', []): 56 | ret.append(Rationale(inst['annotation_id'], rat['docid'], pred['start_token'], pred['end_token'])) 57 | return ret 58 | 59 | @dataclass(eq=True, frozen=True) 60 | class PositionScoredDocument: 61 | ann_id: str 62 | docid: str 63 | scores: Tuple[float] 64 | truths: Tuple[bool] 65 | 66 | @classmethod 67 | def from_results(cls, instances: List[dict], annotations: List[Annotation], docs: Dict[str, List[Any]], use_tokens: bool=True) -> List['PositionScoredDocument']: 68 | """Creates a paired list of annotation ids/docids/predictions/truth values""" 69 | key_to_annotation = dict() 70 | for ann in annotations: 71 | for ev in chain.from_iterable(ann.evidences): 72 | key = (ann.annotation_id, ev.docid) 73 | if key not in key_to_annotation: 74 | key_to_annotation[key] = [False for _ in docs[ev.docid]] 75 | if use_tokens: 76 | start, end = ev.start_token, ev.end_token 77 | else: 78 | start, end = ev.start_sentence, ev.end_sentence 79 | for t in range(start, end): 80 | key_to_annotation[key][t] = True 81 | ret = [] 82 | if use_tokens: 83 | field = 'soft_rationale_predictions' 84 | else: 85 | field = 'soft_sentence_predictions' 86 | for inst in instances: 87 | for rat in inst['rationales']: 88 | docid = rat['docid'] 89 | scores = rat[field] 90 | key = (inst['annotation_id'], docid) 91 | assert len(scores) == len(docs[docid]) 92 | if key in key_to_annotation : 93 | assert len(scores) == len(key_to_annotation[key]) 94 | else : 95 | #In case model makes a prediction on docuemnt(s) for which ground truth evidence is not present 96 | key_to_annotation[key] = [False for _ in docs[docid]] 97 | ret.append(PositionScoredDocument(inst['annotation_id'], docid, tuple(scores), tuple(key_to_annotation[key]))) 98 | return ret 99 | 100 | def _f1(_p, _r): 101 | if _p == 0 or _r == 0: 102 | return 0 103 | return 2 * _p * _r / (_p + _r) 104 | 105 | def _keyed_rationale_from_list(rats: List[Rationale]) -> Dict[Tuple[str, str], Rationale]: 106 | ret = defaultdict(set) 107 | for r in rats: 108 | ret[(r.ann_id, r.docid)].add(r) 109 | return ret 110 | 111 | def partial_match_score(truth: List[Rationale], pred: List[Rationale], thresholds: List[float]) -> List[Dict[str, Any]]: 112 | """Computes a partial match F1 113 | 114 | Computes an instance-level (annotation) micro- and macro-averaged F1 score. 115 | True Positives are computed by using intersection-over-union and 116 | thresholding the resulting intersection-over-union fraction. 117 | 118 | Micro-average results are computed by ignoring instance level distinctions 119 | in the TP calculation (and recall, and precision, and finally the F1 of 120 | those numbers). Macro-average results are computed first by measuring 121 | instance (annotation + document) precisions and recalls, averaging those, 122 | and finally computing an F1 of the resulting average. 
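    For example (illustrative numbers): a predicted span over tokens [2, 6) scored against a gold
    span [3, 7) has intersection size 3 and union size 5, giving an IOU of 0.6, so it counts as a
    true positive at threshold 0.5 but not at threshold 0.75.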
123 | """ 124 | 125 | ann_to_rat = _keyed_rationale_from_list(truth) 126 | pred_to_rat = _keyed_rationale_from_list(pred) 127 | 128 | num_classifications = {k:len(v) for k,v in pred_to_rat.items()} 129 | num_truth = {k:len(v) for k,v in ann_to_rat.items()} 130 | ious = defaultdict(dict) 131 | for k in set(ann_to_rat.keys()) | set(pred_to_rat.keys()): 132 | for p in pred_to_rat.get(k, []): 133 | best_iou = 0.0 134 | for t in ann_to_rat.get(k, []): 135 | num = len(set(range(p.start_token, p.end_token)) & set(range(t.start_token, t.end_token))) 136 | denom = len(set(range(p.start_token, p.end_token)) | set(range(t.start_token, t.end_token))) 137 | iou = 0 if denom == 0 else num / denom 138 | if iou > best_iou: 139 | best_iou = iou 140 | ious[k][p] = best_iou 141 | scores = [] 142 | for threshold in thresholds: 143 | threshold_tps = dict() 144 | for k, vs in ious.items(): 145 | threshold_tps[k] = sum(int(x >= threshold) for x in vs.values()) 146 | micro_r = sum(threshold_tps.values()) / sum(num_truth.values()) if sum(num_truth.values()) > 0 else 0 147 | micro_p = sum(threshold_tps.values()) / sum(num_classifications.values()) if sum(num_classifications.values()) > 0 else 0 148 | micro_f1 = _f1(micro_r, micro_p) 149 | macro_rs = list(threshold_tps.get(k, 0.0) / n if n > 0 else 0 for k, n in num_truth.items()) 150 | macro_ps = list(threshold_tps.get(k, 0.0) / n if n > 0 else 0 for k, n in num_classifications.items()) 151 | macro_r = sum(macro_rs) / len(macro_rs) if len(macro_rs) > 0 else 0 152 | macro_p = sum(macro_ps) / len(macro_ps) if len(macro_ps) > 0 else 0 153 | macro_f1 = _f1(macro_r, macro_p) 154 | scores.append({'threshold': threshold, 155 | 'micro': { 156 | 'p': micro_p, 157 | 'r': micro_r, 158 | 'f1': micro_f1 159 | }, 160 | 'macro': { 161 | 'p': macro_p, 162 | 'r': macro_r, 163 | 'f1': macro_f1 164 | }, 165 | }) 166 | return scores 167 | 168 | def score_hard_rationale_predictions(truth: List[Rationale], pred: List[Rationale]) -> Dict[str, Dict[str, float]]: 169 | """Computes instance (annotation)-level micro/macro averaged F1s""" 170 | scores = dict() 171 | truth = set(truth) 172 | pred = set(pred) 173 | micro_prec = len(truth & pred) / len(pred) 174 | micro_rec = len(truth & pred) / len(truth) 175 | micro_f1 = _f1(micro_prec, micro_rec) 176 | 177 | scores['instance_micro'] = { 178 | 'p': micro_prec, 179 | 'r': micro_rec, 180 | 'f1': micro_f1, 181 | } 182 | 183 | ann_to_rat = _keyed_rationale_from_list(truth) 184 | pred_to_rat = _keyed_rationale_from_list(pred) 185 | instances_to_scores = dict() 186 | for k in set(ann_to_rat.keys()) | (pred_to_rat.keys()): 187 | if len(pred_to_rat.get(k, set())) > 0: 188 | instance_prec = len(ann_to_rat.get(k, set()) & pred_to_rat.get(k, set())) / len(pred_to_rat[k]) 189 | else: 190 | instance_prec = 0 191 | if len(ann_to_rat.get(k, set())) > 0: 192 | instance_rec = len(ann_to_rat.get(k, set()) & pred_to_rat.get(k, set())) / len(ann_to_rat[k]) 193 | else: 194 | instance_rec = 0 195 | instance_f1 = _f1(instance_prec, instance_rec) 196 | instances_to_scores[k] = { 197 | 'p': instance_prec, 198 | 'r': instance_rec, 199 | 'f1': instance_f1, 200 | } 201 | # these are calculated as sklearn would 202 | macro_prec = sum(instance['p'] for instance in instances_to_scores.values()) / len(instances_to_scores) 203 | macro_rec = sum(instance['r'] for instance in instances_to_scores.values()) / len(instances_to_scores) 204 | macro_f1 = sum(instance['f1'] for instance in instances_to_scores.values()) / len(instances_to_scores) 205 | scores['instance_macro'] = { 206 
| 'p': macro_prec,
207 | 'r': macro_rec,
208 | 'f1': macro_f1,
209 | }
210 | return scores
211 | 
212 | def _auprc(truth: Dict[Any, List[bool]], preds: Dict[Any, List[float]]) -> float:
213 | if len(preds) == 0:
214 | return 0.0
215 | assert len(truth.keys() & preds.keys()) == len(truth.keys())
216 | aucs = []
217 | for k, true in truth.items():
218 | pred = preds[k]
219 | true = [int(t) for t in true]
220 | precision, recall, _ = precision_recall_curve(true, pred)
221 | aucs.append(auc(recall, precision))
222 | return np.average(aucs)
223 | 
224 | def _score_aggregator(truth: Dict[Any, List[bool]], preds: Dict[Any, List[float]], score_function: Callable[[List[float], List[float]], float], discard_single_class_answers: bool) -> float:
225 | if len(preds) == 0:
226 | return 0.0
227 | assert len(truth.keys() & preds.keys()) == len(truth.keys())
228 | scores = []
229 | for k, true in truth.items():
230 | pred = preds[k]
231 | if (all(true) or all(not x for x in true)) and discard_single_class_answers:
232 | continue
233 | true = [int(t) for t in true]
234 | scores.append(score_function(true, pred))
235 | return np.average(scores)
236 | 
237 | def score_soft_tokens(paired_scores: List[PositionScoredDocument]) -> Dict[str, float]:
238 | truth = {(ps.ann_id, ps.docid): ps.truths for ps in paired_scores}
239 | pred = {(ps.ann_id, ps.docid): ps.scores for ps in paired_scores}
240 | auprc_score = _auprc(truth, pred)
241 | ap = _score_aggregator(truth, pred, average_precision_score, True)
242 | roc_auc = _score_aggregator(truth, pred, roc_auc_score, True)
243 | 
244 | return {
245 | 'auprc': auprc_score,
246 | 'average_precision': ap,
247 | 'roc_auc_score': roc_auc,
248 | }
249 | 
250 | def _instances_aopc(instances: List[dict], thresholds: List[float], key: str) -> Tuple[float, List[float]]:
251 | dataset_scores = []
252 | for inst in instances:
253 | kls = inst['classification']
254 | beta_0 = inst['classification_scores'][kls]
255 | instance_scores = []
256 | for score in filter(lambda x : x['threshold'] in thresholds, sorted(inst['thresholded_scores'], key=lambda x: x['threshold'])):
257 | beta_k = score[key][kls]
258 | delta = beta_0 - beta_k
259 | instance_scores.append(delta)
260 | assert len(instance_scores) == len(thresholds)
261 | dataset_scores.append(instance_scores)
262 | dataset_scores = np.array(dataset_scores)
263 | # a careful reading of Samek, et al. "Evaluating the Visualization of What a Deep Neural Network Has Learned"
264 | # and some algebra will show the reader that we can average in any of several ways and get the same result:
265 | # over a flattened array, within an instance and then between instances, or over instances (by position) and
266 | # then across them.
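    # In symbols (illustrative): AOPC = (1 / (N * |T|)) * sum_i sum_k (beta_0^(i) - beta_k^(i)),
    # i.e., np.average over the full N x |T| matrix of deltas computed above.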
267 | final_score = np.average(dataset_scores) 268 | position_scores = np.average(dataset_scores, axis=0).tolist() 269 | 270 | return final_score, position_scores 271 | 272 | def compute_aopc_scores(instances: List[dict], aopc_thresholds: List[float]): 273 | if aopc_thresholds is None : 274 | aopc_thresholds = sorted(set(chain.from_iterable([x['threshold'] for x in y['thresholded_scores']] for y in instances))) 275 | aopc_comprehensiveness_score, aopc_comprehensiveness_points = _instances_aopc(instances, aopc_thresholds, 'comprehensiveness_classification_scores') 276 | aopc_sufficiency_score, aopc_sufficiency_points = _instances_aopc(instances, aopc_thresholds, 'sufficiency_classification_scores') 277 | return aopc_thresholds, aopc_comprehensiveness_score, aopc_comprehensiveness_points, aopc_sufficiency_score, aopc_sufficiency_points 278 | 279 | def score_classifications(instances: List[dict], annotations: List[Annotation], docs: Dict[str, List[str]], aopc_thresholds: List[float]) -> Dict[str, float]: 280 | def compute_kl(cls_scores_, faith_scores_): 281 | keys = list(cls_scores_.keys()) 282 | cls_scores_ = [cls_scores_[k] for k in keys] 283 | faith_scores_ = [faith_scores_[k] for k in keys] 284 | return entropy(faith_scores_, cls_scores_) 285 | labels = list(set(x.classification for x in annotations)) 286 | label_to_int = {l:i for i,l in enumerate(labels)} 287 | key_to_instances = {inst['annotation_id']:inst for inst in instances} 288 | truth = [] 289 | predicted = [] 290 | for ann in annotations: 291 | truth.append(label_to_int[ann.classification]) 292 | inst = key_to_instances[ann.annotation_id] 293 | predicted.append(label_to_int[inst['classification']]) 294 | classification_scores = classification_report(truth, predicted, output_dict=True, target_names=labels, digits=3) 295 | accuracy = accuracy_score(truth, predicted) 296 | if 'comprehensiveness_classification_scores' in instances[0]: 297 | comprehensiveness_scores = [x['classification_scores'][x['classification']] - x['comprehensiveness_classification_scores'][x['classification']] for x in instances] 298 | comprehensiveness_score = np.average(comprehensiveness_scores) 299 | else : 300 | comprehensiveness_score = None 301 | comprehensiveness_scores = None 302 | 303 | if 'sufficiency_classification_scores' in instances[0]: 304 | sufficiency_scores = [x['classification_scores'][x['classification']] - x['sufficiency_classification_scores'][x['classification']] for x in instances] 305 | sufficiency_score = np.average(sufficiency_scores) 306 | else : 307 | sufficiency_score = None 308 | sufficiency_scores = None 309 | 310 | if 'comprehensiveness_classification_scores' in instances[0]: 311 | comprehensiveness_entropies = [entropy(list(x['classification_scores'].values())) - entropy(list(x['comprehensiveness_classification_scores'].values())) for x in instances] 312 | comprehensiveness_entropy = np.average(comprehensiveness_entropies) 313 | comprehensiveness_kl = np.average(list(compute_kl(x['classification_scores'], x['comprehensiveness_classification_scores']) for x in instances)) 314 | else: 315 | comprehensiveness_entropies = None 316 | comprehensiveness_kl = None 317 | comprehensiveness_entropy = None 318 | 319 | if 'sufficiency_classification_scores' in instances[0]: 320 | sufficiency_entropies = [entropy(list(x['classification_scores'].values())) - entropy(list(x['sufficiency_classification_scores'].values())) for x in instances] 321 | sufficiency_entropy = np.average(sufficiency_entropies) 322 | sufficiency_kl = 
np.average(list(compute_kl(x['classification_scores'], x['sufficiency_classification_scores']) for x in instances))
323 | else:
324 | sufficiency_entropies = None
325 | sufficiency_kl = None
326 | sufficiency_entropy = None
327 | 
328 | if 'thresholded_scores' in instances[0]:
329 | aopc_thresholds, aopc_comprehensiveness_score, aopc_comprehensiveness_points, aopc_sufficiency_score, aopc_sufficiency_points = compute_aopc_scores(instances, aopc_thresholds)
330 | else:
331 | aopc_thresholds, aopc_comprehensiveness_score, aopc_comprehensiveness_points, aopc_sufficiency_score, aopc_sufficiency_points = None, None, None, None, None
332 | if 'tokens_to_flip' in instances[0]:
333 | token_percentages = []
334 | for ann in annotations:
335 | # in practice, this is of size 1 for everything except e-snli
336 | docids = set(ev.docid for ev in chain.from_iterable(ann.evidences))
337 | inst = key_to_instances[ann.annotation_id]
338 | tokens = inst['tokens_to_flip']
339 | doc_lengths = sum(len(docs[d]) for d in docids)
340 | token_percentages.append(tokens / doc_lengths)
341 | token_percentages = np.average(token_percentages)
342 | else:
343 | token_percentages = None
344 | 
345 | return {
346 | 'accuracy': accuracy,
347 | 'prf': classification_scores,
348 | 'comprehensiveness': comprehensiveness_score,
349 | 'sufficiency': sufficiency_score,
350 | 'comprehensiveness_entropy': comprehensiveness_entropy,
351 | 'comprehensiveness_kl': comprehensiveness_kl,
352 | 'sufficiency_entropy': sufficiency_entropy,
353 | 'sufficiency_kl': sufficiency_kl,
354 | 'aopc_thresholds': aopc_thresholds,
355 | 'comprehensiveness_aopc': aopc_comprehensiveness_score,
356 | 'comprehensiveness_aopc_points': aopc_comprehensiveness_points,
357 | 'sufficiency_aopc': aopc_sufficiency_score,
358 | 'sufficiency_aopc_points': aopc_sufficiency_points,
359 | }
360 | 
361 | def verify_instance(instance: dict, docs: Dict[str, list], thresholds: Set[float]):
362 | error = False
363 | docids = []
364 | # verify the internal structure of these instances is correct:
365 | # * hard predictions are present
366 | # * start and end tokens are valid
367 | # * soft rationale predictions, if present, must have the same document length
368 | 
369 | for rat in instance['rationales']:
370 | docid = rat['docid']
371 | if docid not in docs:
372 | error = True
373 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} could not be found as a preprocessed document! Gave up on additional processing.')
374 | continue
375 | doc_length = len(docs[docid])
376 | for h1 in rat.get('hard_rationale_predictions', []):
377 | # verify that each token is valid
378 | # verify that no annotations overlap
379 | for h2 in rat.get('hard_rationale_predictions', []):
380 | if h1 == h2:
381 | continue
382 | if len(set(range(h1['start_token'], h1['end_token'])) & set(range(h2['start_token'], h2['end_token']))) > 0:
383 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} {h1} and {h2} overlap!')
384 | error = True
385 | if h1['start_token'] > doc_length:
386 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}')
387 | error = True
388 | if h1['end_token'] > doc_length:
389 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} received an impossible tokenspan: {h1} for a document of length {doc_length}')
390 | error = True
391 | # length check for soft rationale
392 | # note that either flattened_documents or sentence-broken documents must be passed in depending on result
393 | soft_rationale_predictions = rat.get('soft_rationale_predictions', [])
394 | if len(soft_rationale_predictions) > 0 and len(soft_rationale_predictions) != doc_length:
395 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} expected classifications for {doc_length} tokens but have them for {len(soft_rationale_predictions)} tokens instead!')
396 | error = True
397 | 
398 | # count that one appears per-document
399 | docids = Counter(docids)
400 | for docid, count in docids.items():
401 | if count > 1:
402 | error = True
403 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, docid={docid} appears {count} times, but may only appear once!')
404 | 
405 | classification = instance.get('classification', '')
406 | if not isinstance(classification, str):
407 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, classification field {classification} is not a string!')
408 | error = True
409 | classification_scores = instance.get('classification_scores', dict())
410 | if not isinstance(classification_scores, dict):
411 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, classification_scores field {classification_scores} is not a dict!')
412 | error = True
413 | comprehensiveness_classification_scores = instance.get('comprehensiveness_classification_scores', dict())
414 | if not isinstance(comprehensiveness_classification_scores, dict):
415 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, comprehensiveness_classification_scores field {comprehensiveness_classification_scores} is not a dict!')
416 | error = True
417 | sufficiency_classification_scores = instance.get('sufficiency_classification_scores', dict())
418 | if not isinstance(sufficiency_classification_scores, dict):
419 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, sufficiency_classification_scores field {sufficiency_classification_scores} is not a dict!')
420 | error = True
421 | if ('classification' in instance) != ('classification_scores' in instance):
422 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, when providing a classification, you must also provide classification scores!')
423 | error = True
424 | if ('comprehensiveness_classification_scores' in instance) and not ('classification' in instance):
425 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, when providing a comprehensiveness_classification_score, you must also provide a classification!')
426 | error = True
427 | if ('sufficiency_classification_scores' in instance) and not ('classification_scores' in instance):
428 | logging.info(f'Error! For instance annotation={instance["annotation_id"]}, when providing a sufficiency_classification_score, you must also provide a classification score!')
429 | error = True
430 | if 'thresholded_scores' in instance:
431 | instance_thresholds = set(x['threshold'] for x in instance['thresholded_scores'])
432 | if instance_thresholds != thresholds:
433 | error = True
434 | logging.info(f'Error: {instance["thresholded_scores"]} has thresholds that differ from previous thresholds: {thresholds}')
435 | if 'comprehensiveness_classification_scores' not in instance\
436 | or 'sufficiency_classification_scores' not in instance\
437 | or 'classification' not in instance\
438 | or 'classification_scores' not in instance:
439 | error = True
440 | logging.info(f'Error: {instance} must have comprehensiveness_classification_scores, sufficiency_classification_scores, classification, and classification_scores defined when including thresholded scores')
441 | if not all('sufficiency_classification_scores' in x for x in instance['thresholded_scores']):
442 | error = True
443 | logging.info(f'Error: {instance} must have sufficiency_classification_scores for every threshold')
444 | if not all('comprehensiveness_classification_scores' in x for x in instance['thresholded_scores']):
445 | error = True
446 | logging.info(f'Error: {instance} must have comprehensiveness_classification_scores for every threshold')
447 | return error
448 | 
449 | def verify_instances(instances: List[dict], docs: Dict[str, list]):
450 | annotation_ids = list(x['annotation_id'] for x in instances)
451 | key_counter = Counter(annotation_ids)
452 | multi_occurrence_annotation_ids = list(filter(lambda kv: kv[1] > 1, key_counter.items()))
453 | error = False
454 | if len(multi_occurrence_annotation_ids) > 0:
455 | error = True
456 | logging.info(f'Error in instances: {len(multi_occurrence_annotation_ids)} appear multiple times in the annotations file: {multi_occurrence_annotation_ids}')
457 | failed_validation = set()
458 | instances_with_classification = list()
459 | instances_with_soft_rationale_predictions = list()
460 | instances_with_soft_sentence_predictions = list()
461 | instances_with_comprehensiveness_classifications = list()
462 | instances_with_sufficiency_classifications = list()
463 | instances_with_thresholded_scores = list()
464 | if 'thresholded_scores' in instances[0]:
465 | thresholds = set(x['threshold'] for x in instances[0]['thresholded_scores'])
466 | else:
467 | thresholds = None
468 | for instance in instances:
469 | instance_error = verify_instance(instance, docs, thresholds)
470 | if instance_error:
471 | error = True
472 | failed_validation.add(instance['annotation_id'])
473 | if instance.get('classification', None) != None:
474 | instances_with_classification.append(instance)
475 | if instance.get('comprehensiveness_classification_scores', None) != None:
476 | instances_with_comprehensiveness_classifications.append(instance)
477 | if instance.get('sufficiency_classification_scores', None) != None:
478 | instances_with_sufficiency_classifications.append(instance)
479 | has_soft_rationales = []
480 | has_soft_sentences = []
481 | for rat in instance['rationales']:
482 | if rat.get('soft_rationale_predictions', None) != None:
483 | has_soft_rationales.append(rat)
484 | if rat.get('soft_sentence_predictions', None) != None:
485 | has_soft_sentences.append(rat)
486 | if len(has_soft_rationales) > 0:
487 | instances_with_soft_rationale_predictions.append(instance)
488 | if len(has_soft_rationales) != len(instance['rationales']):
489 | error = True
490 | logging.info(f'Error: instance {instance["annotation_id"]} has soft rationales for some but not all reported documents!')
491 | if len(has_soft_sentences) > 0:
492 | instances_with_soft_sentence_predictions.append(instance)
493 | if len(has_soft_sentences) != len(instance['rationales']):
494 | error = True
495 | logging.info(f'Error: instance {instance["annotation_id"]} has soft sentences for some but not all reported documents!')
496 | if 'thresholded_scores' in instance:
497 | instances_with_thresholded_scores.append(instance)
498 | logging.info(f'Error in instances: {len(failed_validation)} instances fail validation: {failed_validation}')
499 | if len(instances_with_classification) != 0 and len(instances_with_classification) != len(instances):
500 | logging.info(f'Either all {len(instances)} must have a classification or none may, instead {len(instances_with_classification)} do!')
501 | error = True
502 | if len(instances_with_soft_sentence_predictions) != 0 and len(instances_with_soft_sentence_predictions) != len(instances):
503 | logging.info(f'Either all {len(instances)} must have a sentence prediction or none may, instead {len(instances_with_soft_sentence_predictions)} do!')
504 | error = True
505 | if len(instances_with_soft_rationale_predictions) != 0 and len(instances_with_soft_rationale_predictions) != len(instances):
506 | logging.info(f'Either all {len(instances)} must have a soft rationale prediction or none may, instead {len(instances_with_soft_rationale_predictions)} do!')
507 | error = True
508 | if len(instances_with_comprehensiveness_classifications) != 0 and len(instances_with_comprehensiveness_classifications) != len(instances):
509 | error = True
510 | logging.info(f'Either all {len(instances)} must have a comprehensiveness classification or none may, instead {len(instances_with_comprehensiveness_classifications)} do!')
511 | if len(instances_with_sufficiency_classifications) != 0 and len(instances_with_sufficiency_classifications) != len(instances):
512 | error = True
513 | logging.info(f'Either all {len(instances)} must have a sufficiency classification or none may, instead {len(instances_with_sufficiency_classifications)} do!')
514 | if len(instances_with_thresholded_scores) != 0 and len(instances_with_thresholded_scores) != len(instances):
515 | error = True
516 | logging.info(f'Either all {len(instances)} must have thresholded scores or none may, instead {len(instances_with_thresholded_scores)} do!')
517 | if error:
518 | raise ValueError('Some instances are invalid, please fix your formatting and try again')
519 | 
520 | def _has_hard_predictions(results: List[dict]) -> bool:
521 | # assumes that we have run "verification" over the inputs
522 | return 'rationales' in results[0]\
523 | and len(results[0]['rationales']) > 0\
524 | and 'hard_rationale_predictions' in results[0]['rationales'][0]\
525 | and results[0]['rationales'][0]['hard_rationale_predictions'] is not None\
526 | and len(results[0]['rationales'][0]['hard_rationale_predictions']) > 0
527 | 
528 | def _has_soft_predictions(results: List[dict]) -> bool:
529 | # assumes that we have run "verification" over the inputs
530 | return 'rationales' in results[0] and len(results[0]['rationales']) > 0 and 'soft_rationale_predictions' in results[0]['rationales'][0] and results[0]['rationales'][0]['soft_rationale_predictions'] is not None
531 | 
532 | def _has_soft_sentence_predictions(results: List[dict]) -> bool:
533 | # assumes that we have run "verification" over
the inputs 534 | return 'rationales' in results[0] and len(results[0]['rationales']) > 0 and 'soft_sentence_predictions' in results[0]['rationales'][0] and results[0]['rationales'][0]['soft_sentence_predictions'] is not None 535 | 536 | def _has_classifications(results: List[dict]) -> bool: 537 | # assumes that we have run "verification" over the inputs 538 | return 'classification' in results[0] and results[0]['classification'] is not None 539 | 540 | def main(): 541 | parser = argparse.ArgumentParser(description="""Computes rationale and final class classification scores""", formatter_class=argparse.RawTextHelpFormatter) 542 | parser.add_argument('--data_dir', dest='data_dir', required=True, help='Which directory contains a {train,val,test}.jsonl file?') 543 | parser.add_argument('--split', dest='split', required=True, help='Which of {train,val,test} are we scoring on?') 544 | parser.add_argument('--strict', dest='strict', required=False, action='store_true', default=False, help='Do we perform strict scoring?') 545 | parser.add_argument('--results', dest='results', required=True, help="""Results File 546 | Contents are expected to be jsonl of: 547 | { 548 | "annotation_id": str, required 549 | # these classifications *must not* overlap 550 | "rationales": List[ 551 | { 552 | "docid": str, required 553 | "hard_rationale_predictions": List[{ 554 | "start_token": int, inclusive, required 555 | "end_token": int, exclusive, required 556 | }], optional, 557 | # token level classifications, a value must be provided per-token 558 | # in an ideal world, these correspond to the hard-decoding above. 559 | "soft_rationale_predictions": List[float], optional. 560 | # sentence level classifications, a value must be provided for every 561 | # sentence in each document, or not at all 562 | "soft_sentence_predictions": List[float], optional. 563 | } 564 | ], 565 | # the classification the model made for the overall classification task 566 | "classification": str, optional 567 | # A probability distribution output by the model. We require this to be normalized. 568 | "classification_scores": Dict[str, float], optional 569 | # The next two fields are measures for how faithful your model is (the 570 | # rationales it predicts are in some sense causal of the prediction), and 571 | # how sufficient they are. We approximate a measure for comprehensiveness by 572 | # asking that you remove the top k%% of tokens from your documents, 573 | # running your models again, and reporting the score distribution in the 574 | # "comprehensiveness_classification_scores" field. 575 | # We approximate a measure of sufficiency by asking exactly the converse 576 | # - that you provide model distributions on the removed k%% tokens. 577 | # 'k' is determined by human rationales, and is documented in our paper. 578 | # You should determine which of these tokens to remove based on some kind 579 | # of information about your model: gradient based, attention based, other 580 | # interpretability measures, etc. 581 | # scores per class having removed k%% of the data, where k is determined by human comprehensive rationales 582 | "comprehensiveness_classification_scores": Dict[str, float], optional 583 | # scores per class having access to only k%% of the data, where k is determined by human comprehensive rationales 584 | "sufficiency_classification_scores": Dict[str, float], optional 585 | # the number of tokens required to flip the prediction - see "Is Attention Interpretable" by Serrano and Smith. 
586 | "tokens_to_flip": int, optional 587 | "thresholded_scores": List[{ 588 | "threshold": float, required, 589 | "comprehensiveness_classification_scores": like "classification_scores" 590 | "sufficiency_classification_scores": like "classification_scores" 591 | }], optional. if present, then "classification" and "classification_scores" must be present 592 | } 593 | When providing one of the optional fields, it must be provided for *every* instance. 594 | The classification, classification_score, and comprehensiveness_classification_scores 595 | must together be present for every instance or absent for every instance. 596 | """) 597 | parser.add_argument('--iou_thresholds', dest='iou_thresholds', required=False, nargs='+', type=float, default=[0.5], help='''Thresholds for IOU scoring. 598 | 599 | These are used for "soft" or partial match scoring of rationale spans. 600 | A span is considered a match if the size of the intersection of the prediction 601 | and the annotation, divided by the union of the two spans, is larger than 602 | the IOU threshold. This score can be computed for arbitrary thresholds. 603 | ''') 604 | parser.add_argument('--score_file', dest='score_file', required=False, default=None, help='Where to write results?') 605 | parser.add_argument('--aopc_thresholds', nargs='+', required=False, type=float, default=[0.01, 0.05, 0.1, 0.2, 0.5], help='Thresholds for aopc Thresholds') 606 | args = parser.parse_args() 607 | results = load_jsonl(args.results) 608 | docids = set(chain.from_iterable([rat['docid'] for rat in res['rationales']] for res in results)) 609 | docs = load_flattened_documents(args.data_dir, docids) 610 | verify_instances(results, docs) 611 | # load truth 612 | annotations = annotations_from_jsonl(os.path.join(args.data_dir, args.split + '.jsonl')) 613 | docids |= set(chain.from_iterable((ev.docid for ev in chain.from_iterable(ann.evidences)) for ann in annotations)) 614 | 615 | has_final_predictions = _has_classifications(results) 616 | scores = dict() 617 | if args.strict: 618 | if not args.iou_thresholds: 619 | raise ValueError("--iou_thresholds must be provided when running strict scoring") 620 | if not has_final_predictions: 621 | raise ValueError("We must have a 'classification', 'classification_score', and 'comprehensiveness_classification_score' field in order to perform scoring!") 622 | # TODO think about offering a sentence level version of these scores. 
623 | if _has_hard_predictions(results): 624 | truth = list(chain.from_iterable(Rationale.from_annotation(ann) for ann in annotations)) 625 | pred = list(chain.from_iterable(Rationale.from_instance(inst) for inst in results)) 626 | if args.iou_thresholds is not None: 627 | iou_scores = partial_match_score(truth, pred, args.iou_thresholds) 628 | scores['iou_scores'] = iou_scores 629 | # NER style scoring 630 | rationale_level_prf = score_hard_rationale_predictions(truth, pred) 631 | scores['rationale_prf'] = rationale_level_prf 632 | token_level_truth = list(chain.from_iterable(rat.to_token_level() for rat in truth)) 633 | token_level_pred = list(chain.from_iterable(rat.to_token_level() for rat in pred)) 634 | token_level_prf = score_hard_rationale_predictions(token_level_truth, token_level_pred) 635 | scores['token_prf'] = token_level_prf 636 | else: 637 | logging.info("No hard predictions detected, skipping rationale scoring") 638 | 639 | if _has_soft_predictions(results): 640 | flattened_documents = load_flattened_documents(args.data_dir, docids) 641 | paired_scoring = PositionScoredDocument.from_results(results, annotations, flattened_documents, use_tokens=True) 642 | token_scores = score_soft_tokens(paired_scoring) 643 | scores['token_soft_metrics'] = token_scores 644 | else: 645 | logging.info("No soft predictions detected, skipping rationale scoring") 646 | 647 | if _has_soft_sentence_predictions(results): 648 | documents = load_documents(args.data_dir, docids) 649 | paired_scoring = PositionScoredDocument.from_results(results, annotations, documents, use_tokens=False) 650 | sentence_scores = score_soft_tokens(paired_scoring) 651 | scores['sentence_soft_metrics'] = sentence_scores 652 | else: 653 | logging.info("No sentence level predictions detected, skipping sentence-level diagnostic") 654 | 655 | if has_final_predictions: 656 | flattened_documents = load_flattened_documents(args.data_dir, docids) 657 | class_results = score_classifications(results, annotations, flattened_documents, args.aopc_thresholds) 658 | scores['classification_scores'] = class_results 659 | else: 660 | logging.info("No classification scores detected, skipping classification") 661 | 662 | pprint.pprint(scores) 663 | 664 | if args.score_file: 665 | with open(args.score_file, 'w') as of: 666 | json.dump(scores, of, indent=4, sort_keys=True) 667 | 668 | if __name__ == '__main__': 669 | main() 670 | -------------------------------------------------------------------------------- /src/utils/eraser/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from dataclasses import dataclass, asdict, is_dataclass 5 | from itertools import chain 6 | from typing import Dict, List, Set, Tuple, Union, FrozenSet 7 | 8 | 9 | @dataclass(eq=True, frozen=True) 10 | class Evidence: 11 | """ 12 | (docid, start_token, end_token) form the only official Evidence; sentence level annotations are for convenience. 
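    For example (illustrative values): Evidence(text=('not', 'good'), docid='doc1', start_token=5,
    end_token=7) marks tokens [5, 7) of document 'doc1' as evidence.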
13 | Args: 14 | text: Some representation of the evidence text 15 | docid: Some identifier for the document 16 | start_token: The canonical start token, inclusive 17 | end_token: The canonical end token, exclusive 18 | start_sentence: Best guess start sentence, inclusive 19 | end_sentence: Best guess end sentence, exclusive 20 | """ 21 | text: Union[str, Tuple[int], Tuple[str]] 22 | docid: str 23 | start_token: int = -1 24 | end_token: int = -1 25 | start_sentence: int = -1 26 | end_sentence: int = -1 27 | 28 | 29 | @dataclass(eq=True, frozen=True) 30 | class Annotation: 31 | """ 32 | Args: 33 | annotation_id: unique ID for this annotation element 34 | query: some representation of a query string 35 | evidences: a set of "evidence groups". 36 | Each evidence group is: 37 | * sufficient to respond to the query (or justify an answer) 38 | * composed of one or more Evidences 39 | * may have multiple documents in it (depending on the dataset) 40 | - e-snli has multiple documents 41 | - other datasets do not 42 | classification: str 43 | query_type: Optional str, additional information about the query 44 | docids: a set of docids in which one may find evidence. 45 | """ 46 | annotation_id: str 47 | query: Union[str, Tuple[int]] 48 | evidences: Union[Set[Tuple[Evidence]], FrozenSet[Tuple[Evidence]]] 49 | classification: str 50 | query_type: str = None 51 | docids: Set[str] = None 52 | 53 | def all_evidences(self) -> Tuple[Evidence]: 54 | return tuple(list(chain.from_iterable(self.evidences))) 55 | 56 | 57 | def annotations_to_jsonl(annotations, output_file): 58 | with open(output_file, 'w') as of: 59 | for ann in sorted(annotations, key=lambda x: x.annotation_id): 60 | as_json = _annotation_to_dict(ann) 61 | as_str = json.dumps(as_json, sort_keys=True) 62 | of.write(as_str) 63 | of.write('\n') 64 | 65 | 66 | def _annotation_to_dict(dc): 67 | # convenience method 68 | if is_dataclass(dc): 69 | d = asdict(dc) 70 | ret = dict() 71 | for k, v in d.items(): 72 | ret[k] = _annotation_to_dict(v) 73 | return ret 74 | elif isinstance(dc, dict): 75 | ret = dict() 76 | for k, v in dc.items(): 77 | k = _annotation_to_dict(k) 78 | v = _annotation_to_dict(v) 79 | ret[k] = v 80 | return ret 81 | elif isinstance(dc, str): 82 | return dc 83 | elif isinstance(dc, (set, frozenset, list, tuple)): 84 | ret = [] 85 | for x in dc: 86 | ret.append(_annotation_to_dict(x)) 87 | return tuple(ret) 88 | else: 89 | return dc 90 | 91 | 92 | def load_jsonl(fp: str) -> List[dict]: 93 | ret = [] 94 | with open(fp, 'r') as inf: 95 | for line in inf: 96 | content = json.loads(line) 97 | ret.append(content) 98 | return ret 99 | 100 | 101 | def write_jsonl(jsonl, output_file): 102 | with open(output_file, 'w') as of: 103 | for js in jsonl: 104 | as_str = json.dumps(js, sort_keys=True) 105 | of.write(as_str) 106 | of.write('\n') 107 | 108 | 109 | def annotations_from_jsonl(fp: str) -> List[Annotation]: 110 | ret = [] 111 | with open(fp, 'r') as inf: 112 | for line in inf: 113 | content = json.loads(line) 114 | ev_groups = [] 115 | for ev_group in content['evidences']: 116 | ev_group = tuple([Evidence(**ev) for ev in ev_group]) 117 | ev_groups.append(ev_group) 118 | content['evidences'] = frozenset(ev_groups) 119 | ret.append(Annotation(**content)) 120 | return ret 121 | 122 | 123 | def load_datasets(data_dir: str) -> Tuple[List[Annotation], List[Annotation], List[Annotation]]: 124 | """Loads a training, validation, and test dataset 125 | 126 | Each dataset is assumed to have been serialized by annotations_to_jsonl, 127 | that is it 
is a list of json-serialized Annotation instances. 128 | """ 129 | train_data = annotations_from_jsonl(os.path.join(data_dir, 'train.jsonl')) 130 | val_data = annotations_from_jsonl(os.path.join(data_dir, 'val.jsonl')) 131 | test_data = annotations_from_jsonl(os.path.join(data_dir, 'test.jsonl')) 132 | return train_data, val_data, test_data 133 | 134 | 135 | def load_documents(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]: 136 | """Loads a subset of available documents from disk. 137 | 138 | Each document is assumed to be serialized as newline ('\n') separated sentences. 139 | Each sentence is assumed to be space (' ') joined tokens. 140 | """ 141 | if os.path.exists(os.path.join(data_dir, 'docs.jsonl')): 142 | assert not os.path.exists(os.path.join(data_dir, 'docs')) 143 | return load_documents_from_file(data_dir, docids) 144 | 145 | docs_dir = os.path.join(data_dir, 'docs') 146 | res = dict() 147 | if docids is None: 148 | docids = sorted(os.listdir(docs_dir)) 149 | else: 150 | docids = sorted(set(str(d) for d in docids)) 151 | for d in docids: 152 | with open(os.path.join(docs_dir, d), 'r') as inf: 153 | lines = [l.strip() for l in inf.readlines()] 154 | lines = list(filter(lambda x: bool(len(x)), lines)) 155 | tokenized = [list(filter(lambda x: bool(len(x)), line.strip().split(' '))) for line in lines] 156 | res[d] = tokenized 157 | return res 158 | 159 | 160 | def load_flattened_documents(data_dir: str, docids: Set[str]) -> Dict[str, List[str]]: 161 | """Loads a subset of available documents from disk. 162 | 163 | Returns a tokenized version of the document. 164 | """ 165 | unflattened_docs = load_documents(data_dir, docids) 166 | flattened_docs = dict() 167 | for doc, unflattened in unflattened_docs.items(): 168 | flattened_docs[doc] = list(chain.from_iterable(unflattened)) 169 | return flattened_docs 170 | 171 | 172 | def intern_documents(documents: Dict[str, List[List[str]]], word_interner: Dict[str, int], unk_token: str): 173 | """ 174 | Replaces every word with its index in an embeddings file. 175 | 176 | If a word is not found, uses the unk_token instead 177 | """ 178 | ret = dict() 179 | unk = word_interner[unk_token] 180 | for docid, sentences in documents.items(): 181 | ret[docid] = [[word_interner.get(w, unk) for w in s] for s in sentences] 182 | return ret 183 | 184 | 185 | def intern_annotations(annotations: List[Annotation], word_interner: Dict[str, int], unk_token: str): 186 | ret = [] 187 | for ann in annotations: 188 | ev_groups = [] 189 | for ev_group in ann.evidences: 190 | evs = [] 191 | for ev in ev_group: 192 | evs.append(Evidence( 193 | text=tuple([word_interner.get(t, word_interner[unk_token]) for t in ev.text.split()]), 194 | docid=ev.docid, 195 | start_token=ev.start_token, 196 | end_token=ev.end_token, 197 | start_sentence=ev.start_sentence, 198 | end_sentence=ev.end_sentence)) 199 | ev_groups.append(tuple(evs)) 200 | ret.append(Annotation(annotation_id=ann.annotation_id, 201 | query=tuple([word_interner.get(t, word_interner[unk_token]) for t in ann.query.split()]), 202 | evidences=frozenset(ev_groups), 203 | classification=ann.classification, 204 | query_type=ann.query_type)) 205 | return ret 206 | 207 | 208 | def load_documents_from_file(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]: 209 | """Loads a subset of available documents from 'docs.jsonl' file on disk. 210 | 211 | Each document is assumed to be serialized as newline ('\n') separated sentences. 
212 | Each sentence is assumed to be space (' ') joined tokens. 213 | """ 214 | docs_file = os.path.join(data_dir, 'docs.jsonl') 215 | documents = load_jsonl(docs_file) 216 | documents = {doc['docid']: doc['document'] for doc in documents} 217 | res = dict() 218 | if docids is None: 219 | docids = sorted(list(documents.keys())) 220 | else: 221 | docids = sorted(set(str(d) for d in docids)) 222 | for d in docids: 223 | lines = documents[d].split('\n') 224 | tokenized = [line.strip().split(' ') for line in lines] 225 | res[d] = tokenized 226 | return res 227 | -------------------------------------------------------------------------------- /src/utils/expl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from captum.attr import IntegratedGradients, GradientShap, InputXGradient, Saliency, DeepLift 3 | 4 | attr_algos = { 5 | 'integrated-gradients' : IntegratedGradients, 6 | 'gradient-shap' : GradientShap, 7 | 'input-x-gradient': InputXGradient, 8 | 'saliency': Saliency, 9 | 'deep-lift': DeepLift, 10 | } 11 | 12 | baseline_required = { 13 | 'integrated-gradients' : True, 14 | 'gradient-shap': True, 15 | 'input-x-gradient': False, 16 | 'saliency': False, 17 | 'deep-lift': True, 18 | } 19 | 20 | 21 | def calc_expl(attrs, k, attn_mask, min_val=-1e10): 22 | num_tokens = torch.sum(attn_mask, dim=1) - 1 # don't include CLS token when computing num_tokens 23 | num_highlight_tokens = torch.round(num_tokens * k / 100) 24 | ones = torch.ones_like(num_highlight_tokens) 25 | num_highlight_tokens = torch.maximum(num_highlight_tokens, ones).long() 26 | 27 | attrs = attrs + (1 - attn_mask) * min_val # ignore pad tokens when computing sorted_attrs_indices 28 | attrs[:, 0] = min_val # don't include CLS token when computing sorted_attrs_indices 29 | sorted_attrs_indices = torch.argsort(attrs, dim=1, descending=True) 30 | 31 | expl = torch.zeros_like(attn_mask).long() 32 | for i in range(len(attrs)): 33 | salient_indices = sorted_attrs_indices[i][:num_highlight_tokens[i]] 34 | expl[i, salient_indices] = 1 35 | expl[:, 0] = 1 # always treat CLS token as positive token 36 | 37 | return expl -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | import getpass, logging, socket 2 | from typing import Any, List 3 | import torch 4 | from omegaconf.dictconfig import DictConfig 5 | from omegaconf.omegaconf import OmegaConf 6 | from pytorch_lightning.loggers import NeptuneLogger 7 | from src.utils.metrics import calc_preds, get_step_metrics, get_epoch_metrics 8 | 9 | API_LIST = { 10 | "neptune": { 11 | 'your-local-username': 'your-api-token', 12 | }, 13 | } 14 | 15 | 16 | def get_username(): 17 | return getpass.getuser() 18 | 19 | def flatten_cfg(cfg: Any) -> dict: 20 | if isinstance(cfg, dict): 21 | ret = {} 22 | for k, v in cfg.items(): 23 | flatten: dict = flatten_cfg(v) 24 | ret.update({ 25 | f"{k}/{f}" if f else k: fv 26 | for f, fv in flatten.items() 27 | }) 28 | return ret 29 | return {"": cfg} 30 | 31 | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: 32 | logger = logging.getLogger(name) 33 | logger.setLevel(level) 34 | return logger 35 | 36 | def get_neptune_logger( 37 | cfg: DictConfig, project_name: str, 38 | name: str, tag_attrs: List[str], log_db: str, 39 | offline: bool, logger: str, 40 | ): 41 | neptune_api_key = API_LIST["neptune"][get_username()] 42 | 43 | # flatten cfg 44 | args_dict = { 45 | 
**flatten_cfg(OmegaConf.to_object(cfg)), 46 | "hostname": socket.gethostname() 47 | } 48 | tags = tag_attrs 49 | if cfg.model.expl_reg: 50 | tags.append('expl_reg') 51 | 52 | tags.append(log_db) 53 | 54 | neptune_logger = NeptuneLogger( 55 | api_key=neptune_api_key, 56 | project_name=project_name, 57 | experiment_name=name, 58 | params=args_dict, 59 | tags=tags, 60 | offline_mode=offline, 61 | ) 62 | 63 | try: 64 | # for unknown reason, must access this field otherwise becomes None 65 | print(neptune_logger.experiment) 66 | except BaseException: 67 | pass 68 | 69 | return neptune_logger 70 | 71 | def log_data_to_neptune(model_class, data, data_name, data_type, suffix, split, ret_dict=None, topk=None, detach_data=True): 72 | if topk: 73 | for i, k in enumerate(topk): 74 | model_class.log(f'{split}_{data_name}_{k}_{data_type}_{suffix}', data[i].detach(), prog_bar=True, sync_dist=(split != 'train')) 75 | if ret_dict is not None: 76 | ret_dict[f'{data_name}_{k}_{data_type}'] = data[i].detach() if detach_data else data[i] 77 | else: 78 | data_key = 'loss' if f'{data_name}_{data_type}' == 'total_loss' else f'{data_name}_{data_type}' 79 | model_class.log(f'{split}_{data_key}_{suffix}', data.detach(), prog_bar=True, sync_dist=(split != 'train')) 80 | if ret_dict is not None: 81 | ret_dict[data_key] = data.detach() if detach_data else data 82 | 83 | return ret_dict 84 | 85 | def log_step_losses(model_class, loss_dict, ret_dict, do_expl_reg, split): 86 | ret_dict = log_data_to_neptune(model_class, loss_dict['loss'], 'total', 'loss', 'step', split, ret_dict, topk=None, detach_data=False) 87 | ret_dict = log_data_to_neptune(model_class, loss_dict['task_loss'], 'task', 'loss', 'step', split, ret_dict, topk=None) 88 | if do_expl_reg: 89 | ret_dict = log_data_to_neptune(model_class, loss_dict['expl_loss'], 'expl', 'loss', 'step', split, ret_dict, topk=None) 90 | if model_class.comp_wt > 0: 91 | ret_dict = log_data_to_neptune(model_class, loss_dict['comp_loss'], 'comp', 'loss', 'step', split, ret_dict, topk=None) 92 | ret_dict = log_data_to_neptune(model_class, loss_dict['comp_losses'], 'comp', 'loss', 'step', split, ret_dict, topk=model_class.topk[split]) 93 | if model_class.suff_wt > 0: 94 | ret_dict = log_data_to_neptune(model_class, loss_dict['suff_loss'], 'suff', 'loss', 'step', split, ret_dict, topk=None) 95 | ret_dict = log_data_to_neptune(model_class, loss_dict['suff_losses'], 'suff', 'loss', 'step', split, ret_dict, topk=model_class.topk[split]) 96 | if model_class.plaus_wt > 0 and loss_dict.get('plaus_loss'): 97 | ret_dict = log_data_to_neptune(model_class, loss_dict['plaus_loss'], 'plaus', 'loss', 'step', split, ret_dict, topk=None) 98 | if model_class.l2e and loss_dict.get('l2e_loss'): 99 | ret_dict = log_data_to_neptune(model_class, loss_dict['l2e_loss'], 'l2e', 'loss', 'step', split, ret_dict, topk=None) 100 | if model_class.a2r and loss_dict.get('a2r_loss'): 101 | ret_dict = log_data_to_neptune(model_class, loss_dict['a2r_loss'], 'a2r', 'loss', 'step', split, ret_dict, topk=None) 102 | 103 | return ret_dict 104 | 105 | def log_step_metrics(model_class, metric_dict, ret_dict, split): 106 | ret_dict = log_data_to_neptune(model_class, metric_dict['comp_aopc'], 'comp_aopc', 'metric', 'step', split, ret_dict, topk=None) 107 | ret_dict = log_data_to_neptune(model_class, metric_dict['comps'], 'comp', 'metric', 'step', split, ret_dict, topk=model_class.topk[split]) 108 | ret_dict = log_data_to_neptune(model_class, metric_dict['suff_aopc'], 'suff_aopc', 'metric', 'step', split, ret_dict, topk=None) 
109 | ret_dict = log_data_to_neptune(model_class, metric_dict['suffs'], 'suff', 'metric', 'step', split, ret_dict, topk=model_class.topk[split]) 110 | if model_class.log_odds: 111 | ret_dict = log_data_to_neptune(model_class, metric_dict['log_odds_aopc'], 'log_odds_aopc', 'metric', 'step', split, ret_dict, topk=None) 112 | ret_dict = log_data_to_neptune(model_class, metric_dict['log_odds'], 'log_odds', 'metric', 'step', split, ret_dict, topk=model_class.topk[split]) 113 | ret_dict = log_data_to_neptune(model_class, metric_dict['comp_aopc']-metric_dict['suff_aopc'], 'csd_aopc', 'metric', 'step', split, ret_dict, topk=None) 114 | ret_dict = log_data_to_neptune(model_class, metric_dict['comps']-metric_dict['suffs'], 'csd', 'metric', 'step', split, ret_dict, topk=model_class.topk[split]) 115 | if metric_dict.get('plaus_auprc'): 116 | ret_dict = log_data_to_neptune(model_class, metric_dict['plaus_auprc'], 'plaus_auprc', 'metric', 'step', split, ret_dict, topk=None) 117 | ret_dict = log_data_to_neptune(model_class, metric_dict['plaus_token_f1'], 'plaus_token_f1', 'metric', 'step', split, ret_dict, topk=None) 118 | return ret_dict 119 | 120 | def log_epoch_losses(model_class, outputs, split): 121 | loss = torch.stack([x['loss'] for x in outputs]).mean() 122 | task_loss = torch.stack([x['task_loss'] for x in outputs]).mean() 123 | log_data_to_neptune(model_class, loss, 'total', 'loss', 'epoch', split, ret_dict=None, topk=None) 124 | log_data_to_neptune(model_class, task_loss, 'task', 'loss', 'epoch', split, ret_dict=None, topk=None) 125 | 126 | if model_class.expl_reg: 127 | assert len([x.get('expl_loss') for x in outputs if x is not None]) > 0 128 | 129 | expl_loss = torch.stack([x.get('expl_loss') for x in outputs if x is not None]).mean() 130 | log_data_to_neptune(model_class, expl_loss, 'expl', 'loss', 'epoch', split, ret_dict=None, topk=None) 131 | 132 | if model_class.comp_wt > 0: 133 | comp_loss = torch.stack([x.get('comp_loss') for x in outputs if x.get('comp_loss') is not None]).mean() 134 | log_data_to_neptune(model_class, comp_loss, 'comp', 'loss', 'epoch', split, ret_dict=None, topk=None) 135 | comp_losses = torch.stack([torch.stack([x.get(f'comp_{k}_loss') for x in outputs if x.get(f'comp_{k}_loss') is not None]).mean() for k in model_class.topk[split]]) 136 | log_data_to_neptune(model_class, comp_losses, 'comp', 'loss', 'epoch', split, ret_dict=None, topk=model_class.topk[split]) 137 | if model_class.suff_wt > 0: 138 | suff_loss = torch.stack([x.get('suff_loss') for x in outputs if x.get('suff_loss') is not None]).mean() 139 | log_data_to_neptune(model_class, suff_loss, 'suff', 'loss', 'epoch', split, ret_dict=None, topk=None) 140 | suff_losses = torch.stack([torch.stack([x.get(f'suff_{k}_loss') for x in outputs if x.get(f'suff_{k}_loss') is not None]).mean() for k in model_class.topk[split]]) 141 | log_data_to_neptune(model_class, suff_losses, 'suff', 'loss', 'epoch', split, ret_dict=None, topk=model_class.topk[split]) 142 | if model_class.plaus_wt > 0 and outputs[0].get('plaus_loss'): 143 | plaus_loss = torch.stack([x.get('plaus_loss') for x in outputs if x.get('plaus_loss') is not None]).mean() 144 | log_data_to_neptune(model_class, plaus_loss, 'plaus', 'loss', 'epoch', split, ret_dict=None, topk=None) 145 | if model_class.l2e and outputs[0].get('l2e_loss'): 146 | l2e_loss = torch.stack([x.get('l2e_loss') for x in outputs if x.get('l2e_loss') is not None]).mean() 147 | log_data_to_neptune(model_class, l2e_loss, 'l2e', 'loss', 'epoch', split, ret_dict=None, topk=None) 148 | if 
model_class.a2r and outputs[0].get('a2r_loss'): 149 | a2r_loss = torch.stack([x.get('a2r_loss') for x in outputs if x.get('a2r_loss') is not None]).mean() 150 | log_data_to_neptune(model_class, a2r_loss, 'a2r', 'loss', 'epoch', split, ret_dict=None, topk=None) 151 | 152 | def log_epoch_metrics(model_class, outputs, split): 153 | logits = torch.cat([x['logits'] for x in outputs]) 154 | targets = torch.cat([x['targets'] for x in outputs]) 155 | preds = calc_preds(logits) 156 | 157 | perf_metrics = get_step_metrics(preds, targets, model_class.perf_metrics) 158 | perf_metrics = get_epoch_metrics(model_class.perf_metrics) 159 | 160 | log_data_to_neptune(model_class, perf_metrics['acc'], 'acc', 'metric', 'epoch', split, ret_dict=None, topk=None) 161 | log_data_to_neptune(model_class, perf_metrics['macro_f1'], 'macro_f1', 'metric', 'epoch', split, ret_dict=None, topk=None) 162 | log_data_to_neptune(model_class, perf_metrics['micro_f1'], 'micro_f1', 'metric', 'epoch', split, ret_dict=None, topk=None) 163 | if model_class.num_classes == 2: 164 | log_data_to_neptune(model_class, perf_metrics['binary_f1'], 'binary_f1', 'metric', 'epoch', split, ret_dict=None, topk=None) 165 | 166 | assert len([x.get('comp_aopc_metric') for x in outputs if x.get('comp_aopc_metric') is not None]) > 0 167 | comp_aopc = torch.stack([x.get('comp_aopc_metric') for x in outputs if x.get('comp_aopc_metric') is not None]).mean() 168 | comps = torch.stack([torch.stack([x.get(f'comp_{k}_metric') for x in outputs if x.get(f'comp_{k}_metric') is not None]).mean() for k in model_class.topk[split]]) 169 | log_data_to_neptune(model_class, comp_aopc, 'comp_aopc', 'metric', 'epoch', split, ret_dict=None, topk=None) 170 | log_data_to_neptune(model_class, comps, 'comp', 'metric', 'epoch', split, ret_dict=None, topk=model_class.topk[split]) 171 | 172 | assert len([x.get('suff_aopc_metric') for x in outputs if x.get('suff_aopc_metric') is not None]) > 0 173 | suff_aopc = torch.stack([x.get('suff_aopc_metric') for x in outputs if x.get('suff_aopc_metric') is not None]).mean() 174 | suffs = torch.stack([torch.stack([x.get(f'suff_{k}_metric') for x in outputs if x.get(f'suff_{k}_metric') is not None]).mean() for k in model_class.topk[split]]) 175 | log_data_to_neptune(model_class, suff_aopc, 'suff_aopc', 'metric', 'epoch', split, ret_dict=None, topk=None) 176 | log_data_to_neptune(model_class, suffs, 'suff', 'metric', 'epoch', split, ret_dict=None, topk=model_class.topk[split]) 177 | 178 | if model_class.log_odds: 179 | assert len([x.get('log_odds_aopc_metric') for x in outputs if x.get('log_odds_aopc_metric') is not None]) > 0 180 | log_odds_aopc = torch.stack([x.get('log_odds_aopc_metric') for x in outputs if x.get('log_odds_aopc_metric') is not None]).mean() 181 | log_odds = torch.stack([torch.stack([x.get(f'log_odds_{k}_metric') for x in outputs if x.get(f'log_odds_{k}_metric') is not None]).mean() for k in model_class.topk[split]]) 182 | log_data_to_neptune(model_class, log_odds_aopc, 'log_odds_aopc', 'metric', 'epoch', split, ret_dict=None, topk=None) 183 | log_data_to_neptune(model_class, log_odds, 'log_odds', 'metric', 'epoch', split, ret_dict=None, topk=model_class.topk[split]) 184 | 185 | csd_aopc = torch.stack([x.get('csd_aopc_metric') for x in outputs if x.get('csd_aopc_metric') is not None]).mean() 186 | csds = torch.stack([torch.stack([x.get(f'csd_{k}_metric') for x in outputs if x.get(f'csd_{k}_metric') is not None]).mean() for k in model_class.topk[split]]) 187 | log_data_to_neptune(model_class, csd_aopc, 'csd_aopc', 
'metric', 'epoch', split, ret_dict=None, topk=None) 188 | log_data_to_neptune(model_class, csds, 'csd', 'metric', 'epoch', split, ret_dict=None, topk=model_class.topk[split]) 189 | 190 | if outputs[0].get('plaus_auprc_metric'): 191 | assert len([x.get('plaus_auprc_metric') for x in outputs if x.get('plaus_auprc_metric') is not None]) > 0 192 | plaus_auprc = torch.stack([x.get('plaus_auprc_metric') for x in outputs if x.get('plaus_auprc_metric') is not None]).mean() 193 | log_data_to_neptune(model_class, plaus_auprc, 'plaus_auprc', 'metric', 'epoch', split, ret_dict=None, topk=None) 194 | 195 | assert len([x.get('plaus_token_f1_metric') for x in outputs if x.get('plaus_token_f1_metric') is not None]) > 0 196 | plaus_token_f1 = torch.stack([x.get('plaus_token_f1_metric') for x in outputs if x.get('plaus_token_f1_metric') is not None]).mean() 197 | log_data_to_neptune(model_class, plaus_token_f1, 'plaus_token_f1', 'metric', 'epoch', split, ret_dict=None, topk=None) 198 | 199 | if 'delta' in outputs[0].keys(): 200 | delta = torch.abs(torch.cat([x['delta'] for x in outputs])).mean() 201 | log_data_to_neptune(model_class, delta, 'convergence_delta', 'metric', 'epoch', split, ret_dict=None, topk=None) -------------------------------------------------------------------------------- /src/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def calc_task_loss(logits, targets, reduction='mean', class_weights=None): 6 | assert len(logits) == len(targets) 7 | return F.cross_entropy(logits, targets, weight=class_weights, reduction=reduction) 8 | 9 | def calc_comp_loss(comp_logits, comp_targets, task_losses, comp_criterion, topk, comp_margin=None): 10 | inv_expl_losses = calc_task_loss(comp_logits, comp_targets, reduction='none').reshape(len(topk), -1) 11 | if comp_criterion == 'diff': 12 | comp_losses = task_losses - inv_expl_losses 13 | elif comp_criterion == 'margin': 14 | assert comp_margin is not None 15 | comp_margins = comp_margin * torch.ones_like(inv_expl_losses) 16 | comp_losses = torch.maximum(-comp_margins, task_losses - inv_expl_losses) + comp_margins 17 | else: 18 | raise NotImplementedError 19 | 20 | assert not torch.any(torch.isnan(comp_losses)) 21 | return torch.mean(comp_losses, dim=1) 22 | 23 | def calc_suff_loss(suff_logits, suff_targets, task_losses, suff_criterion, topk, suff_margin=None, task_logits=None): 24 | if suff_criterion == 'kldiv': 25 | assert task_logits is not None 26 | batch_size = len(task_logits) 27 | task_distr = F.log_softmax(task_logits, dim=1).unsqueeze(0).expand(len(topk), -1, -1).reshape(len(topk) * batch_size, -1) 28 | suff_distr = F.softmax(suff_logits, dim=1) 29 | suff_losses = F.kl_div(task_distr, suff_distr, reduction='none').reshape(len(topk), -1) 30 | else: 31 | expl_losses = calc_task_loss(suff_logits, suff_targets, reduction='none').reshape(len(topk), -1) 32 | if suff_criterion == 'diff': 33 | suff_losses = expl_losses - task_losses 34 | elif suff_criterion == 'margin': 35 | suff_margins = suff_margin * torch.ones_like(expl_losses) 36 | suff_losses = torch.maximum(-suff_margins, expl_losses - task_losses) + suff_margins 37 | elif suff_criterion == 'mae': 38 | suff_losses = F.l1_loss(expl_losses, task_losses, reduction='none') 39 | elif suff_criterion == 'mse': 40 | suff_losses = F.mse_loss(expl_losses, task_losses, reduction='none') 41 | else: 42 | raise NotImplementedError 43 | 44 | assert not torch.any(torch.isnan(suff_losses)) 45 | return 
torch.mean(suff_losses, dim=1) 46 | 47 | def calc_plaus_loss(attrs, rationale, attn_mask, plaus_criterion, plaus_margin=None, has_rationale=None): 48 | if plaus_criterion == 'margin': 49 | raise NotImplementedError 50 | plaus_margins = attn_mask * plaus_margin 51 | inv_rationale = (1 - rationale) * attn_mask 52 | plaus_loss = (-rationale + inv_rationale) * attrs 53 | assert not torch.any(torch.isnan(plaus_loss)) 54 | plaus_loss = torch.maximum(-plaus_margins, plaus_loss) + plaus_margins 55 | plaus_loss = torch.sum(plaus_loss) / torch.sum(attn_mask) 56 | elif plaus_criterion == 'bce': 57 | assert has_rationale is not None 58 | max_length = attn_mask.shape[1] 59 | has_rationale_ = has_rationale.unsqueeze(1).repeat(1, max_length) * attn_mask 60 | rationale = rationale * has_rationale_ 61 | num_tokens = has_rationale_.sum() 62 | plaus_pos_wt = (num_tokens - rationale.sum()) / rationale.sum() 63 | plaus_loss = (F.binary_cross_entropy_with_logits(attrs, rationale, pos_weight=plaus_pos_wt, reduction='none') * has_rationale_).sum() 64 | if num_tokens > 0: 65 | plaus_loss /= num_tokens 66 | else: 67 | assert plaus_loss == 0 68 | assert not torch.any(torch.isnan(plaus_loss)) 69 | else: 70 | raise NotImplementedError 71 | 72 | assert not torch.isnan(plaus_loss) 73 | return plaus_loss 74 | 75 | def calc_l2e_loss(l2e_attrs, l2e_rationale, attn_mask, l2e_criterion): 76 | if l2e_criterion == 'ce': 77 | num_tokens = attn_mask.sum() 78 | num_classes = l2e_attrs.shape[2] 79 | l2e_attrs = l2e_attrs.reshape(-1, num_classes) 80 | l2e_loss = (F.cross_entropy(l2e_attrs, l2e_rationale.flatten(), reduction='none') * attn_mask.flatten()).sum() 81 | if num_tokens > 0: 82 | l2e_loss /= num_tokens 83 | else: 84 | assert l2e_loss == 0 85 | assert not torch.any(torch.isnan(l2e_loss)) 86 | else: 87 | raise NotImplementedError 88 | 89 | assert not torch.isnan(l2e_loss) 90 | return l2e_loss 91 | 92 | def calc_a2r_loss(logits, a2r_logits, a2r_criterion): 93 | if a2r_criterion == 'jsd': 94 | a2r_loss = js_div(logits, a2r_logits) 95 | else: 96 | raise NotImplementedError 97 | assert not torch.isnan(a2r_loss) 98 | return a2r_loss 99 | 100 | def js_div(logits_1, logits_2, reduction='batchmean'): 101 | probs_m = (F.softmax(logits_1, dim=1) + F.softmax(logits_2, dim=1)) / 2.0 102 | 103 | loss_1 = F.kl_div( 104 | F.log_softmax(logits_1, dim=1), 105 | probs_m, 106 | reduction=reduction 107 | ) 108 | 109 | loss_2 = F.kl_div( 110 | F.log_softmax(logits_2, dim=1), 111 | probs_m, 112 | reduction=reduction 113 | ) 114 | 115 | loss = 0.5 * (loss_1 + loss_2) 116 | 117 | return loss -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torchmetrics 5 | from sklearn.metrics import precision_recall_curve, auc, f1_score, average_precision_score 6 | 7 | 8 | def init_best_metrics(): 9 | return { 10 | 'best_epoch': 0, 11 | 'dev_best_perf': None, 12 | 'test_best_perf': None, 13 | } 14 | 15 | def init_perf_metrics(num_classes): 16 | perf_metrics = torch.nn.ModuleDict({ 17 | 'acc': torchmetrics.Accuracy(), 18 | 'macro_f1': torchmetrics.F1(num_classes=num_classes, average='macro'), 19 | 'micro_f1': torchmetrics.F1(num_classes=num_classes, average='micro'), 20 | }) 21 | 22 | assert num_classes >= 2 23 | if num_classes == 2: 24 | perf_metrics['binary_f1'] = torchmetrics.F1(num_classes=num_classes, average='micro', ignore_index=0) 25 | 26 
| return perf_metrics 27 | 28 | def calc_preds(logits): 29 | return torch.argmax(logits, dim=1) 30 | 31 | def calc_comp(logits, inv_expl_logits, targets=None, comp_target=False): 32 | assert not (comp_target and targets is None) 33 | preds = targets if comp_target else calc_preds(logits) 34 | 35 | probs = F.softmax(logits, dim=1) 36 | pred_probs = torch.gather(probs, dim=1, index=preds.unsqueeze(1)).flatten() 37 | 38 | inv_expl_probs = F.softmax(inv_expl_logits, dim=1) 39 | pred_inv_expl_probs = torch.gather(inv_expl_probs, dim=1, index=preds.unsqueeze(1)).flatten() 40 | 41 | return torch.mean(pred_probs - pred_inv_expl_probs) 42 | 43 | def calc_suff(logits, expl_logits, targets=None, suff_target=False): 44 | assert not (suff_target and targets is None) 45 | preds = targets if suff_target else calc_preds(logits) 46 | 47 | probs = F.softmax(logits, dim=1) 48 | pred_probs = torch.gather(probs, dim=1, index=preds.unsqueeze(1)).flatten() 49 | 50 | expl_probs = F.softmax(expl_logits, dim=1) 51 | pred_expl_probs = torch.gather(expl_probs, dim=1, index=preds.unsqueeze(1)).flatten() 52 | 53 | return torch.mean(pred_probs - pred_expl_probs) 54 | 55 | def calc_log_odds(logits, log_odds_logits, targets=None, log_odds_target=False): 56 | assert not (log_odds_target and targets is None) 57 | preds = targets if log_odds_target else calc_preds(logits) 58 | 59 | probs = -F.log_softmax(logits, dim=1) 60 | pred_probs = torch.gather(probs, dim=1, index=preds.unsqueeze(1)).flatten() 61 | 62 | log_odds_probs = -F.log_softmax(log_odds_logits, dim=1) 63 | pred_log_odds_probs = torch.gather(log_odds_probs, dim=1, index=preds.unsqueeze(1)).flatten() 64 | 65 | return torch.mean(pred_probs - pred_log_odds_probs) 66 | 67 | def calc_aopc(values): 68 | return torch.sum(values) / (len(values)+1) 69 | 70 | def calc_plaus(rationale, attrs, attn_mask, has_rationale, bin_thresh=0.0): 71 | batch_size = len(rationale) 72 | auprc_list, ap_list, token_f1_list = [], [], [] 73 | for i in range(batch_size): 74 | if has_rationale[i] == 1: 75 | num_tokens = attn_mask[i].sum() 76 | assert torch.sum(rationale[i][:num_tokens]) > 0 77 | 78 | rationale_ = rationale[i][:num_tokens].detach().cpu().numpy() 79 | attrs_ = attrs[i][:num_tokens].detach().cpu().numpy() 80 | bin_attrs_ = (attrs_ > bin_thresh).astype('float32') 81 | 82 | precision, recall, _ = precision_recall_curve( 83 | y_true=rationale_, 84 | probas_pred=attrs_, 85 | ) 86 | auprc_list.append(auc(recall, precision)) 87 | 88 | token_f1 = f1_score( 89 | y_true=rationale_, 90 | y_pred=bin_attrs_, 91 | average='macro', 92 | ) 93 | token_f1_list.append(token_f1) 94 | 95 | else: 96 | auprc_list.append(0.0) 97 | token_f1_list.append(0.0) 98 | ap_list.append(0.0) 99 | 100 | plaus_auprc = torch.tensor(np.mean(auprc_list)) 101 | plaus_token_f1 = torch.tensor(np.mean(token_f1_list)) 102 | 103 | return plaus_auprc, plaus_token_f1 104 | 105 | def get_step_metrics(preds, targets, metrics): 106 | res = {} 107 | for key, metric_fn in metrics.items(): 108 | res.update({key: metric_fn(preds, targets) * 100}) 109 | return res 110 | 111 | def get_epoch_metrics(metrics): 112 | res = {} 113 | for key, metric_fn in metrics.items(): 114 | res.update({key: metric_fn.compute() * 100}) 115 | metric_fn.reset() 116 | return res -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def clear_cache(): 5 | torch.cuda.empty_cache() 6 | 7 | def 
is_subsequence(a, b): 8 | # Check if list a is a subsequence of list b 9 | return any(a == b[i:i + len(a)] for i in range(len(b) - len(a) + 1)) -------------------------------------------------------------------------------- /src/utils/optim.py: -------------------------------------------------------------------------------- 1 | from transformers import get_scheduler 2 | 3 | no_decay = ['bias', 'LayerNorm.weight'] 4 | 5 | 6 | def setup_optimizer_params(model_dict, optimizer, explainer_type, attr_pooling=None, a2r=False): 7 | optimizer_parameters = [ 8 | { 9 | 'params': [p for n, p in model_dict['task_encoder'].named_parameters() if not any(nd in n for nd in no_decay)], 10 | 'weight_decay': optimizer.weight_decay, 11 | }, 12 | { 13 | 'params': [p for n, p in model_dict['task_encoder'].named_parameters() if any(nd in n for nd in no_decay)], 14 | 'weight_decay': 0.0, 15 | }, 16 | { 17 | 'params': [p for n, p in model_dict['task_head'].named_parameters() if not any(nd in n for nd in no_decay)], 18 | 'weight_decay': optimizer.weight_decay, 19 | }, 20 | { 21 | 'params': [p for n, p in model_dict['task_head'].named_parameters() if any(nd in n for nd in no_decay)], 22 | 'weight_decay': 0.0, 23 | }, 24 | ] 25 | 26 | if explainer_type == 'lm': 27 | optimizer_parameters += [ 28 | { 29 | 'params': [p for n, p in model_dict['expl_encoder'].named_parameters() if not any(nd in n for nd in no_decay)], 30 | 'weight_decay': optimizer.weight_decay, 31 | }, 32 | { 33 | 'params': [p for n, p in model_dict['expl_encoder'].named_parameters() if any(nd in n for nd in no_decay)], 34 | 'weight_decay': 0.0, 35 | }, 36 | { 37 | 'params': [p for n, p in model_dict['expl_head'].named_parameters() if not any(nd in n for nd in no_decay)], 38 | 'weight_decay': optimizer.weight_decay, 39 | }, 40 | { 41 | 'params': [p for n, p in model_dict['expl_head'].named_parameters() if any(nd in n for nd in no_decay)], 42 | 'weight_decay': 0.0, 43 | }, 44 | ] 45 | elif explainer_type == 'self_lm': 46 | optimizer_parameters += [ 47 | { 48 | 'params': [p for n, p in model_dict['expl_head'].named_parameters() if not any(nd in n for nd in no_decay)], 49 | 'weight_decay': optimizer.weight_decay, 50 | }, 51 | { 52 | 'params': [p for n, p in model_dict['expl_head'].named_parameters() if any(nd in n for nd in no_decay)], 53 | 'weight_decay': 0.0, 54 | }, 55 | ] 56 | elif explainer_type == 'attr_algo' and attr_pooling == 'mlp': 57 | optimizer_parameters += [ 58 | { 59 | 'params': [p for n, p in model_dict['attr_pooler'].named_parameters() if not any(nd in n for nd in no_decay)], 60 | 'weight_decay': optimizer.weight_decay, 61 | }, 62 | { 63 | 'params': [p for n, p in model_dict['attr_pooler'].named_parameters() if any(nd in n for nd in no_decay)], 64 | 'weight_decay': 0.0, 65 | }, 66 | ] 67 | 68 | if a2r: 69 | optimizer_parameters += [ 70 | { 71 | 'params': [p for n, p in model_dict['a2r_task_encoder'].named_parameters() if not any(nd in n for nd in no_decay)], 72 | 'weight_decay': optimizer.weight_decay, 73 | }, 74 | { 75 | 'params': [p for n, p in model_dict['a2r_task_encoder'].named_parameters() if any(nd in n for nd in no_decay)], 76 | 'weight_decay': 0.0, 77 | }, 78 | { 79 | 'params': [p for n, p in model_dict['a2r_task_head'].named_parameters() if not any(nd in n for nd in no_decay)], 80 | 'weight_decay': optimizer.weight_decay, 81 | }, 82 | { 83 | 'params': [p for n, p in model_dict['a2r_task_head'].named_parameters() if any(nd in n for nd in no_decay)], 84 | 'weight_decay': 0.0, 85 | }, 86 | ] 87 | 88 | return optimizer_parameters 89 |
90 | def setup_scheduler(scheduler, total_steps, optimizer): 91 | if scheduler.warmup_updates > 1.0: 92 | warmup_steps = int(scheduler.warmup_updates) 93 | else: 94 | warmup_steps = int(total_steps * 95 | scheduler.warmup_updates) 96 | print( 97 | f'\nTotal steps: {total_steps} with warmup steps: {warmup_steps}\n') 98 | 99 | scheduler = get_scheduler( 100 | "linear", optimizer=optimizer, 101 | num_warmup_steps=warmup_steps, num_training_steps=total_steps) 102 | 103 | scheduler = { 104 | 'scheduler': scheduler, 105 | 'interval': 'step', 106 | 'frequency': 1 107 | } 108 | return scheduler 109 | 110 | def freeze_net(module): 111 | for p in module.parameters(): 112 | p.requires_grad = False 113 | 114 | def unfreeze_net(module): 115 | for p in module.parameters(): 116 | p.requires_grad = True 117 | 118 | def freeze_layers(model, num_freeze_layers): 119 | if model.arch == 'google/bigbird-roberta-base': 120 | assert model.task_encoder is not None 121 | 122 | # Freeze task encoder's embedding layer 123 | for p in model.task_encoder.embeddings.parameters(): 124 | p.requires_grad = False 125 | 126 | # Freeze task encoder's encoder layers 127 | for i in range(num_freeze_layers): 128 | for p in model.task_encoder.encoder.layer[i].parameters(): 129 | p.requires_grad = False 130 | 131 | if model.expl_encoder is not None: 132 | # Freeze expl encoder's embedding layer 133 | for p in model.expl_encoder.embeddings.parameters(): 134 | p.requires_grad = False 135 | 136 | # Freeze expl encoder's encoder layers 137 | for i in range(num_freeze_layers): 138 | for p in model.expl_encoder.encoder.layer[i].parameters(): 139 | p.requires_grad = False 140 | 141 | else: 142 | raise NotImplementedError --------------------------------------------------------------------------------
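For reference, here is a minimal sketch of how `setup_optimizer_params` and `setup_scheduler` from `src/utils/optim.py` could be wired into a PyTorch Lightning-style `configure_optimizers`. The `model_dict` contents, config namespaces, learning rate, and step count below are illustrative assumptions; the actual values come from the Hydra configs under `configs/model/optimizer` and `configs/model/scheduler`.

```python
# Illustrative sketch only: model_dict keys, config values, and total_steps are
# assumptions for demonstration, not the repo's actual configuration.
from types import SimpleNamespace

import torch
from transformers import AutoModel

from src.utils.optim import setup_optimizer_params, setup_scheduler

# Hypothetical stand-ins for the Hydra optimizer/scheduler configs.
optimizer_cfg = SimpleNamespace(weight_decay=0.01, lr=2e-5)
scheduler_cfg = SimpleNamespace(warmup_updates=0.1)  # <= 1.0 means a fraction of total steps

# Minimal model_dict covering only the task encoder/head groups.
model_dict = {
    'task_encoder': AutoModel.from_pretrained('google/bigbird-roberta-base'),
    'task_head': torch.nn.Linear(768, 2),
}

# Build per-parameter-group weight-decay settings, then the optimizer.
param_groups = setup_optimizer_params(model_dict, optimizer_cfg, explainer_type='attr_algo')
optimizer = torch.optim.AdamW(param_groups, lr=optimizer_cfg.lr)

# Wrap a linear-warmup scheduler in the dict format expected by Lightning.
total_steps = 10000
scheduler = setup_scheduler(scheduler_cfg, total_steps, optimizer)
```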