├── .gitignore ├── LICENSE ├── README.md ├── allennlp-requirements.txt ├── config.py ├── data ├── README.md └── get_bert_embeddings │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── bert_config.json │ ├── create_pretraining_data.py │ ├── extract_features.py │ ├── modeling.py │ ├── optimization.py │ ├── pretrain_on_vcr.py │ ├── requirements.txt │ ├── tokenization.py │ ├── vcr_loader.py │ └── vocab.txt ├── dataloaders ├── __init__.py ├── bert_field.py ├── box_utils.py ├── cocoontology.json ├── mask_utils.py └── vcr.py ├── models ├── README.md ├── __init__.py ├── eval_for_leaderboard.py ├── eval_q2ar.py ├── multiatt │ ├── __init__.py │ ├── default.json │ ├── model.py │ ├── no_class_embs.json │ ├── no_obj.json │ └── no_reasoning.json └── train.py └── utils ├── __init__.py ├── detector.py └── pytorch_misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Rowan Zellers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # From Recognition to Cognition: Visual Commonsense Reasoning (cvpr 2019 oral) 2 | 3 | This repository contains data and PyTorch code for the paper [From Recognition to Cognition: Visual Commonsense Reasoning (arxiv)](https://visualcommonsense.com). For more info, check out the project page at [visualcommonsense.com](https://visualcommonsense.com). For updates, or to ask for help, [check out and join our google group!!](https://groups.google.com/forum/#!forum/visualcommonsense/join) 4 | 5 | ![visualization](https://i.imgur.com/5XTaEkx.png "Visualization") 6 | 7 | This repo should be ready to replicate my results from the paper. If you have any issues with getting it set up though, please file a github issue. Still, the paper is just an arxiv version, so there might be more updates in the future. I'm super excited about VCR but it should be viewed as knowledge that's still in the making :) 8 | 9 | # Background as to the Recognition to Cognition model 10 | 11 | This repository is for the new task of Visual Commonsense Reasoning. A model is given an image, objects, a question, and four answer choices. The model has to decide which answer choice is correct. Then, it's given four rationale choices, and it has to decide which of those is the best rationale that explains *why its answer is right*. 12 | 13 | In particular, I have code and checkpoints for the Recognition to Cognition (R2C) model, as discussed in the paper [VCR paper](https://arxiv.org/abs/1811.10830). Here's a diagram that explains what's going on: 14 | 15 | ![modelfig](https://i.imgur.com/SNyz40p.png "Model figure") 16 | 17 | We'll treat going from Q->A and QA->R as two separate tasks: in each, the model is given a 'query' (question, or question+answer) and 'response choices' (answer, or rationale). Essentially, we'll use BERT and detection regions to *ground* the words in the query, then *contextualize* the query with the response. We'll perform several steps of *reasoning* on top of a representation consisting of the response choice in question, the attended query, and the attended detection regions. See the paper for more details. 18 | 19 | ## What this repo has / doesn't have 20 | I have code and checkpoints for replicating my R2C results. You might find the dataloader useful (in dataloaders/vcr.py), as it handles loading the data in a nice way using the allennlp library. You can submit to the [leaderboard](https://visualcommonsense.com/leaderboard/) using my script in `models/eval_for_leaderboard.py` 21 | 22 | You can train a model using `models/train.py`. This also has code to obtain model predictions. Use `models/eval_q2ar.py` to get validation results combining Q->A and QA->R components. 23 | 24 | # Setting up and using the repo 25 | 26 | 1. Get the dataset. Follow the steps in `data/README.md`. 
This includes the steps to get the pretrained BERT embeddings. **Note (as of Jan 23rd)** you'll need to re-download the test embeddings if you downloaded them before, as there was a bug in the version I had uploaded (essentially the 'anonymized' code didn't condition on the right context). 27 | 28 | 2. Install cuda 9.0 if it's not available already. You might want to follow this [this guide](https://medium.com/@zhanwenchen/install-cuda-9-2-and-cudnn-7-1-for-tensorflow-pytorch-gpu-on-ubuntu-16-04-1822ab4b2421) but using cuda 9.0. I use the following commands (my OS is ubuntu 16.04): 29 | ``` 30 | wget https://developer.nvidia.com/compute/cuda/9.0/Prod/local_installers/cuda_9.0.176_384.81_linux-run 31 | chmod +x cuda_9.0.176_384.81_linux-run 32 | ./cuda_9.0.176_384.81_linux-run --extract=$HOME 33 | sudo ./cuda-linux.9.0.176-22781540.run 34 | sudo ln -s /usr/local/cuda-9.0/ /usr/local/cuda 35 | export LD_LIBRARY_PATH=/usr/local/cuda-9.0/ 36 | ``` 37 | 38 | 3. Install anaconda if it's not available already, and create a new environment. You need to install a few things, namely, pytorch 1.0, torchvision (*from the layers branch, which has ROI pooling*), and allennlp. 39 | 40 | ``` 41 | wget https://repo.anaconda.com/archive/Anaconda3-5.2.0-Linux-x86_64.sh 42 | conda update -n base -c defaults conda 43 | conda create --name r2c python=3.6 44 | source activate r2c 45 | 46 | conda install numpy pyyaml setuptools cmake cffi tqdm pyyaml scipy ipython mkl mkl-include cython typing h5py pandas nltk spacy numpydoc scikit-learn jpeg 47 | 48 | conda install pytorch cudatoolkit=9.0 -c pytorch 49 | pip install git+git://github.com/pytorch/vision.git@24577864e92b72f7066e1ed16e978e873e19d13d 50 | 51 | pip install -r allennlp-requirements.txt 52 | pip install --no-deps allennlp==0.8.0 53 | python -m spacy download en_core_web_sm 54 | 55 | 56 | # this one is optional but it should help make things faster 57 | pip uninstall pillow && CC="cc -mavx2" pip install -U --force-reinstall pillow-simd 58 | ``` 59 | 60 | 4. If you don't want to download from scratch, then download my checkpoint. 61 | 62 | ``` 63 | wget https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/flagship_answer/best.th -P models/saves/flagship_answer/ 64 | wget https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/flagship_rationale/best.th -P models/saves/flagship_rationale/ 65 | ``` 66 | 67 | 5. That's it! Now to set up the environment, run `source activate r2c && export PYTHONPATH=/home/rowan/code/r2c` (or wherever you have this directory). 68 | 69 | ## help 70 | 71 | Feel free to open an issue if you encounter trouble getting it to work! [Or, post in the google group](https://groups.google.com/forum/#!forum/visualcommonsense/join). 72 | 73 | ### Bibtex 74 | 75 | ``` 76 | @inproceedings{zellers2019vcr, 77 | author = {Zellers, Rowan and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin}, 78 | title = {From Recognition to Cognition: Visual Commonsense Reasoning}, 79 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 80 | month = {June}, 81 | year = {2019} 82 | } 83 | ``` 84 | -------------------------------------------------------------------------------- /allennlp-requirements.txt: -------------------------------------------------------------------------------- 1 | # Library dependencies for the python code. You need to install these with 2 | # `pip install -r requirements.txt` before you can run this. 3 | # NOTE: all essential packages must be placed under a section named 'ESSENTIAL ...' 
4 | # so that the script `./scripts/check_requirements_and_setup.py` can find them. 5 | 6 | #### ESSENTIAL LIBRARIES FOR MAIN FUNCTIONALITY #### 7 | 8 | # Parameter parsing (but not on Windows). 9 | jsonnet==0.10.0 ; sys.platform != 'win32' 10 | 11 | # Adds an @overrides decorator for better documentation and error checking when using subclasses. 12 | overrides 13 | 14 | # Used by some old code. We moved away from it because it's too slow, but some old code still 15 | # imports this. 16 | nltk 17 | 18 | # Pin msgpack because the newer version introduces an incompatibility with spaCy 19 | # Get rid of this if we ever unpin spacy 20 | msgpack>=0.5.6,<0.6.0 21 | 22 | # Mainly used for the faster tokenizer. 23 | spacy>=2.0,<2.1 24 | 25 | # Used by span prediction models. 26 | numpy 27 | 28 | # Used for reading configuration info out of numpy-style docstrings. 29 | numpydoc==0.8.0 30 | 31 | # Used in coreference resolution evaluation metrics. 32 | scipy 33 | scikit-learn 34 | 35 | # Write logs for training visualisation with the Tensorboard application 36 | # Install the Tensorboard application separately (part of tensorflow) to view them. 37 | tensorboardX==1.2 38 | 39 | # Required by torch.utils.ffi 40 | cffi==1.11.5 41 | 42 | # aws commandline tools for running on Docker remotely. 43 | # second requirement is to get botocore < 1.11, to avoid the below bug 44 | awscli>=1.11.91 45 | 46 | # Accessing files from S3 directly. 47 | boto3 48 | 49 | # REST interface for models 50 | flask==1.0.2 51 | flask-cors==3.0.7 52 | gevent==1.3.6 53 | 54 | # Used by semantic parsing code to strip diacritics from unicode strings. 55 | unidecode 56 | 57 | # Used by semantic parsing code to parse SQL 58 | parsimonious==0.8.0 59 | 60 | # Used by semantic parsing code to format and postprocess SQL 61 | sqlparse==0.2.4 62 | 63 | # For text normalization 64 | ftfy 65 | 66 | # To use the BERT model 67 | pytorch-pretrained-bert==0.3.0 68 | 69 | #### ESSENTIAL LIBRARIES USED IN SCRIPTS #### 70 | 71 | # Plot graphs for learning rate finder 72 | matplotlib==2.2.3 73 | 74 | # Used for downloading datasets over HTTP 75 | requests>=2.18 76 | 77 | # progress bars in data cleaning scripts 78 | tqdm>=4.19 79 | 80 | # In SQuAD eval script, we use this to see if we likely have some tokenization problem. 81 | editdistance 82 | 83 | # For pretrained model weights 84 | h5py 85 | 86 | # For timezone utilities 87 | pytz==2017.3 88 | 89 | # Reads Universal Dependencies files. 90 | conllu==0.11 91 | 92 | #### ESSENTIAL TESTING-RELATED PACKAGES #### 93 | 94 | # We'll use pytest to run our tests; this isn't really necessary to run the code, but it is to run 95 | # the tests. With this here, you can run the tests with `py.test` from the base directory. 96 | pytest 97 | 98 | # Allows marking tests as flaky, to be rerun if they fail 99 | flaky 100 | 101 | # Required to mock out `requests` calls 102 | responses>=0.7 103 | 104 | # For mocking s3. 105 | moto==1.3.4 106 | 107 | #### TESTING-RELATED PACKAGES #### 108 | 109 | # Checks style, syntax, and other useful errors. 110 | pylint==1.8.1 111 | 112 | # Tutorial notebooks 113 | # see: https://github.com/jupyter/jupyter/issues/370 for ipykernel 114 | ipykernel<5.0.0 115 | jupyter 116 | 117 | # Static type checking 118 | mypy==0.521 119 | 120 | # Allows generation of coverage reports with pytest. 
121 | pytest-cov 122 | 123 | # Allows codecov to generate coverage reports 124 | coverage 125 | codecov 126 | 127 | # Required to run sanic tests 128 | aiohttp 129 | 130 | #### DOC-RELATED PACKAGES #### 131 | 132 | # Builds our documentation. 133 | sphinx==1.5.3 134 | 135 | # Watches the documentation directory and rebuilds on changes. 136 | sphinx-autobuild 137 | 138 | # doc theme 139 | sphinx_rtd_theme 140 | 141 | # Only used to convert our readme to reStructuredText on Pypi. 142 | pypandoc 143 | 144 | # Pypi uploads 145 | twine==1.11.0 146 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | USE_IMAGENET_PRETRAINED = True # otherwise use detectron, but that doesnt seem to work?!? 3 | 4 | # Change these to match where your annotations and images are 5 | VCR_IMAGES_DIR = os.path.join(os.path.dirname(__file__), 'data', 'vcr1images') 6 | VCR_ANNOTS_DIR = os.path.join(os.path.dirname(__file__), 'data') 7 | 8 | if not os.path.exists(VCR_IMAGES_DIR): 9 | raise ValueError("Update config.py with where you saved VCR images to.") -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | Obtain the dataset by visiting [visualcommonsense.com/download.html](https://visualcommonsense.com/download.html). 4 | - Extract the images somewhere. I put them in a different directory, `/home/rowan/datasets2/vcr1/vcr1images` and added a symlink in this (`data`): `ln -s /home/rowan/datasets2/vcr1/vcr1images` 5 | - Put `train.jsonl`, `val.jsonl`, and `test.jsonl` in here (`data`). 6 | 7 | You can also put the dataset somewhere else, you'll just need to update `config.py` (in the main directory) accordingly. 8 | ``` 9 | unzip vcr1annots.zip 10 | ``` 11 | 12 | # Precomputed representations 13 | Running R2c requires computed bert representations in this folder. Warning: these files are quite large. You have two options to generate these: 14 | 15 | 1. (recommended) download them from : 16 | * `https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/bert_da_answer_train.h5` 17 | * `https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/bert_da_rationale_train.h5` 18 | * `https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/bert_da_answer_val.h5` 19 | * `https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/bert_da_rationale_val.h5` 20 | * `https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/bert_da_answer_test.h5` 21 | * `https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/bert_da_rationale_test.h5` 22 | 2. You can use the script in the folder `get_bert_embeddings` to precompute BERT representations for all sentences. If you want my finetuned checkpoint, it's available below. (note that you *don't* need this checkpoint if you want to just use the embeddings I shared above.) 
23 | * `https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/bert-pretrain/model.ckpt-53230.data-00000-of-00001` 24 | * `https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/bert-pretrain/model.ckpt-53230.index` 25 | * `https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/bert-pretrain/model.ckpt-53230.meta` -------------------------------------------------------------------------------- /data/get_bert_embeddings/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /data/get_bert_embeddings/README.md: -------------------------------------------------------------------------------- 1 | # Extracting BERT representations 2 | 3 | Replicating my results with R2C requires precomputing BERT representations of the dataset. These representations are really expensive to precompute, so doing so saves a lot of time. 
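Once the `.h5` files are in `data/` (either downloaded or computed with the steps below), a minimal sketch of how one of them can be inspected with `h5py` looks like this; the file name and the group/dataset layout are assumed from `extract_features.py` in this folder, so treat it as an illustration rather than the exact access pattern used by the dataloader:

```
import h5py
import numpy as np

# Sketch: inspect the precomputed Q->A embeddings for the first question.
# Layout assumed from extract_features.py: one group per question, holding
# 'ctx_answer{i}' (grounded query) and 'answer_answer{i}' (answer choice)
# float16 arrays of shape (num_tokens, 768) for each of the four choices.
with h5py.File('bert_da_answer_train.h5', 'r') as h5:
    group = h5['0']
    for i in range(4):
        ctx = np.asarray(group[f'ctx_answer{i}'])
        ans = np.asarray(group[f'answer_answer{i}'])
        print(i, ctx.shape, ctx.dtype, ans.shape, ans.dtype)
```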
4 | 5 | You can download them here: 6 | 7 | 8 | ## Extracting BERT representations yourself 9 | 10 | If you want to do so yourself, create a condaenv with tensorflow 1.11 installed. Here we'll call it `bert`. You can use 11 | the following command to compute BERT representations of `train.jsonl` in the `data/` folder: 12 | 13 | ``` 14 | source activate bert 15 | export LD_LIBRARY_PATH=/usr/local/cuda-8.0/lib64 16 | export PYTHONPATH=/user/home/rowan/code/r2c 17 | export CUDA_VISIBLE_DEVICES=0 18 | 19 | python extract_features.py --name bert --split=train 20 | ``` 21 | 22 | ## Domain adaptation 23 | 24 | In my early experiments, I found domain adaptation to be important with BERT, mainly because VCR is quite different than books/wikipedia style-wise. So, for the results in the paper, I performed domain adaptation as follows: 25 | 26 | First, I used the following script to create `pretrainingdata.tfrecord`: 27 | ``` 28 | source activate bert 29 | export LD_LIBRARY_PATH=/usr/local/cuda-8.0/lib64 30 | export PYTHONPATH=/user/home/rowan/code/r2c 31 | 32 | python create_pretraining_data.py 33 | ``` 34 | 35 | Then, I trained BERT on that, using 36 | 37 | ``` 38 | export CUDA_VISIBLE_DEVICES=0 39 | 40 | python pretrain_on_vcr.py --do_train 41 | ``` 42 | 43 | This creates a folder called `bert-pretrain`. Now, extract the features as follows. 44 | 45 | ``` 46 | python extract_features.py --name bert_da --init_checkpoint bert-pretrain/model.ckpt-53230 --split=train 47 | ``` 48 | 49 | -------------------------------------------------------------------------------- /data/get_bert_embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /data/get_bert_embeddings/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "hidden_act": "gelu", 4 | "hidden_dropout_prob": 0.1, 5 | "hidden_size": 768, 6 | "initializer_range": 0.02, 7 | "intermediate_size": 3072, 8 | "max_position_embeddings": 512, 9 | "num_attention_heads": 12, 10 | "num_hidden_layers": 12, 11 | "type_vocab_size": 2, 12 | "vocab_size": 30522 13 | } 14 | -------------------------------------------------------------------------------- /data/get_bert_embeddings/create_pretraining_data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Create masked LM/next sentence masked_lm TF examples for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import os 23 | import random 24 | 25 | import tensorflow as tf 26 | 27 | from data.get_bert_embeddings import tokenization 28 | from data.get_bert_embeddings.vcr_loader import data_iter, convert_examples_to_features 29 | 30 | flags = tf.flags 31 | 32 | FLAGS = flags.FLAGS 33 | 34 | 35 | flags.DEFINE_string("split", 'train', "The split to use") 36 | 37 | flags.DEFINE_bool( 38 | "do_lower_case", True, 39 | "Whether to lower case the input text. Should be True for uncased " 40 | "models and False for cased models.") 41 | 42 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.") 43 | 44 | flags.DEFINE_integer("max_predictions_per_seq", 20, 45 | "Maximum number of masked LM predictions per sequence.") 46 | 47 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.") 48 | 49 | flags.DEFINE_integer( 50 | "dupe_factor", 1, 51 | "Number of times to duplicate the input data (with different masks).") 52 | 53 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.") 54 | 55 | mypath = os.getcwd() 56 | vocab_file = os.path.join(mypath, 'vocab.txt') 57 | assert os.path.exists(vocab_file) 58 | 59 | output_file = f'pretrainingdata.tfrecord' 60 | 61 | class TrainingInstance(object): 62 | """A single training instance (sentence pair).""" 63 | 64 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, 65 | is_random_next): 66 | self.tokens = tokens 67 | self.segment_ids = segment_ids 68 | self.is_random_next = is_random_next 69 | self.masked_lm_positions = masked_lm_positions 70 | self.masked_lm_labels = masked_lm_labels 71 | 72 | def __str__(self): 73 | s = "" 74 | s += "tokens: %s\n" % (" ".join( 75 | [tokenization.printable_text(x) for x in self.tokens])) 76 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids])) 77 | s += "is_random_next: %s\n" % self.is_random_next 78 | s += "masked_lm_positions: %s\n" % (" ".join( 79 | [str(x) for x in self.masked_lm_positions])) 80 | s += "masked_lm_labels: %s\n" % (" ".join( 81 | [tokenization.printable_text(x) for x in self.masked_lm_labels])) 82 | s += "\n" 83 | return s 84 | 85 | def __repr__(self): 86 | return self.__str__() 87 | 88 | 89 | def write_instance_to_example_files(instances, tokenizer, max_seq_length, 90 | max_predictions_per_seq, output_files): 91 | """Create TF example files from `TrainingInstance`s.""" 92 | writers = [] 93 | for output_file in output_files: 94 | writers.append(tf.python_io.TFRecordWriter(output_file)) 95 | 96 | writer_index = 0 97 | 98 | total_written = 0 99 | for (inst_index, instance) in enumerate(instances): 100 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens) 101 | input_mask = [1] * len(input_ids) 102 | segment_ids = list(instance.segment_ids) 103 | assert len(input_ids) <= max_seq_length 104 | 105 | while len(input_ids) < max_seq_length: 106 | 
input_ids.append(0) 107 | input_mask.append(0) 108 | segment_ids.append(0) 109 | 110 | assert len(input_ids) == max_seq_length 111 | assert len(input_mask) == max_seq_length 112 | assert len(segment_ids) == max_seq_length 113 | 114 | masked_lm_positions = list(instance.masked_lm_positions) 115 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels) 116 | masked_lm_weights = [1.0] * len(masked_lm_ids) 117 | 118 | while len(masked_lm_positions) < max_predictions_per_seq: 119 | masked_lm_positions.append(0) 120 | masked_lm_ids.append(0) 121 | masked_lm_weights.append(0.0) 122 | 123 | next_sentence_label = 1 if instance.is_random_next else 0 124 | 125 | features = collections.OrderedDict() 126 | features["input_ids"] = create_int_feature(input_ids) 127 | features["input_mask"] = create_int_feature(input_mask) 128 | features["segment_ids"] = create_int_feature(segment_ids) 129 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions) 130 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids) 131 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights) 132 | features["next_sentence_labels"] = create_int_feature([next_sentence_label]) 133 | 134 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 135 | 136 | writers[writer_index].write(tf_example.SerializeToString()) 137 | writer_index = (writer_index + 1) % len(writers) 138 | 139 | total_written += 1 140 | 141 | if inst_index < 20: 142 | tf.logging.info("*** Example ***") 143 | tf.logging.info("tokens: %s" % " ".join( 144 | [tokenization.printable_text(x) for x in instance.tokens])) 145 | 146 | for feature_name in features.keys(): 147 | feature = features[feature_name] 148 | values = [] 149 | if feature.int64_list.value: 150 | values = feature.int64_list.value 151 | elif feature.float_list.value: 152 | values = feature.float_list.value 153 | tf.logging.info( 154 | "%s: %s" % (feature_name, " ".join([str(x) for x in values]))) 155 | 156 | for writer in writers: 157 | writer.close() 158 | 159 | tf.logging.info("Wrote %d total instances", total_written) 160 | 161 | 162 | def create_int_feature(values): 163 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 164 | return feature 165 | 166 | 167 | def create_float_feature(values): 168 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 169 | return feature 170 | 171 | 172 | def create_training_instances(tokenizer, rng): 173 | print("Iterating through the data", flush=True) 174 | input_examples = [] 175 | for x in data_iter(f'../{FLAGS.split}.jsonl', tokenizer=tokenizer, max_seq_length=FLAGS.max_seq_length, 176 | endingonly=False): 177 | input_examples.append(x[0]) 178 | 179 | print("Converting data to features", flush=True) 180 | input_features = convert_examples_to_features(input_examples, 181 | seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 182 | 183 | vocab_words = list(tokenizer.vocab.keys()) 184 | training_instances = [] 185 | for i, feature_example in enumerate(input_features): 186 | for j in range(FLAGS.dupe_factor): 187 | (tokens, masked_lm_positions, 188 | masked_lm_labels) = create_masked_lm_predictions( 189 | feature_example.tokens, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq, vocab_words, rng) 190 | 191 | training_instances.append(TrainingInstance( 192 | tokens=tokens, 193 | segment_ids=feature_example.input_type_ids[:len(tokens)], 194 | is_random_next=not feature_example.is_correct, 195 | masked_lm_positions=masked_lm_positions, 
196 | masked_lm_labels=masked_lm_labels)) 197 | 198 | rng.shuffle(training_instances) 199 | return training_instances 200 | 201 | 202 | def create_masked_lm_predictions(tokens, masked_lm_prob, 203 | max_predictions_per_seq, vocab_words, rng): 204 | """Creates the predictions for the masked LM objective.""" 205 | 206 | cand_indexes = [] 207 | for (i, token) in enumerate(tokens): 208 | if token == "[CLS]" or token == "[SEP]": 209 | continue 210 | cand_indexes.append(i) 211 | 212 | rng.shuffle(cand_indexes) 213 | 214 | output_tokens = list(tokens) 215 | 216 | masked_lm = collections.namedtuple("masked_lm", ["index", "label"]) # pylint: disable=invalid-name 217 | 218 | num_to_predict = min(max_predictions_per_seq, 219 | max(1, int(round(len(tokens) * masked_lm_prob)))) 220 | 221 | masked_lms = [] 222 | covered_indexes = set() 223 | for index in cand_indexes: 224 | if len(masked_lms) >= num_to_predict: 225 | break 226 | if index in covered_indexes: 227 | continue 228 | covered_indexes.add(index) 229 | 230 | masked_token = None 231 | # 80% of the time, replace with [MASK] 232 | if rng.random() < 0.8: 233 | masked_token = "[MASK]" 234 | else: 235 | # 10% of the time, keep original 236 | if rng.random() < 0.5: 237 | masked_token = tokens[index] 238 | # 10% of the time, replace with random word 239 | else: 240 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)] 241 | 242 | output_tokens[index] = masked_token 243 | 244 | masked_lms.append(masked_lm(index=index, label=tokens[index])) 245 | 246 | masked_lms = sorted(masked_lms, key=lambda x: x.index) 247 | 248 | masked_lm_positions = [] 249 | masked_lm_labels = [] 250 | for p in masked_lms: 251 | masked_lm_positions.append(p.index) 252 | masked_lm_labels.append(p.label) 253 | 254 | return (output_tokens, masked_lm_positions, masked_lm_labels) 255 | 256 | 257 | tf.logging.set_verbosity(tf.logging.INFO) 258 | 259 | tokenizer = tokenization.FullTokenizer( 260 | vocab_file=vocab_file, do_lower_case=FLAGS.do_lower_case) 261 | 262 | rng = random.Random(FLAGS.random_seed) 263 | instances = create_training_instances(tokenizer, rng) 264 | 265 | output_files = [output_file] 266 | tf.logging.info("*** Writing to output files ***") 267 | for output_file in output_files: 268 | tf.logging.info(" %s", output_file) 269 | 270 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length, 271 | FLAGS.max_predictions_per_seq, output_files) 272 | -------------------------------------------------------------------------------- /data/get_bert_embeddings/extract_features.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Extract pre-computed feature vectors from BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import requests 23 | import zipfile 24 | import numpy as np 25 | 26 | from data.get_bert_embeddings import modeling 27 | from data.get_bert_embeddings import tokenization 28 | import tensorflow as tf 29 | import h5py 30 | from tqdm import tqdm 31 | from data.get_bert_embeddings.vcr_loader import data_iter, data_iter_test, convert_examples_to_features, input_fn_builder 32 | 33 | flags = tf.flags 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | flags.DEFINE_string("name", 'bert', "The name to use") 38 | 39 | flags.DEFINE_string("split", 'train', "The split to use") 40 | 41 | flags.DEFINE_string("layers", "-2", "") 42 | 43 | flags.DEFINE_integer( 44 | "max_seq_length", 128, 45 | "The maximum total input sequence length after WordPiece tokenization. " 46 | "Sequences longer than this will be truncated, and sequences shorter " 47 | "than this will be padded.") 48 | 49 | flags.DEFINE_string( 50 | "init_checkpoint", 'uncased_L-12_H-768_A-12/bert_model.ckpt', 51 | "Initial checkpoint (usually from a pre-trained BERT model).") 52 | 53 | flags.DEFINE_bool( 54 | "do_lower_case", True, 55 | "Whether to lower case the input text. Should be True for uncased " 56 | "models and False for cased models.") 57 | 58 | flags.DEFINE_integer("batch_size", 32, "Batch size for predictions.") 59 | 60 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 61 | 62 | flags.DEFINE_bool("endingonly", False, "Only use the ending") 63 | 64 | flags.DEFINE_string("master", None, 65 | "If using a TPU, the address of the master.") 66 | 67 | flags.DEFINE_integer( 68 | "num_tpu_cores", 8, 69 | "Only used if `use_tpu` is True. Total number of TPU cores to use.") 70 | 71 | flags.DEFINE_bool( 72 | "use_one_hot_embeddings", False, 73 | "If True, tf.one_hot will be used for embedding lookups, otherwise " 74 | "tf.nn.embedding_lookup will be used. 
On TPUs, this should be True " 75 | "since it is much faster.") 76 | 77 | #### 78 | 79 | if not os.path.exists('uncased_L-12_H-768_A-12'): 80 | response = requests.get('https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip', 81 | stream=True) 82 | with open('uncased_L-12_H-768_A-12.zip', "wb") as handle: 83 | for chunk in response.iter_content(chunk_size=512): 84 | if chunk: # filter out keep-alive new chunks 85 | handle.write(chunk) 86 | with zipfile.ZipFile('uncased_L-12_H-768_A-12.zip') as zf: 87 | zf.extractall() 88 | 89 | print("BERT HAS BEEN DOWNLOADED") 90 | mypath = os.getcwd() 91 | bert_config_file = os.path.join(mypath, 'uncased_L-12_H-768_A-12', 'bert_config.json') 92 | vocab_file = os.path.join(mypath, 'uncased_L-12_H-768_A-12', 'vocab.txt') 93 | # init_checkpoint = os.path.join(mypath, 'uncased_L-12_H-768_A-12', 'bert_model.ckpt') 94 | bert_config = modeling.BertConfig.from_json_file(bert_config_file) 95 | 96 | 97 | def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu, 98 | use_one_hot_embeddings): 99 | """Returns `model_fn` closure for TPUEstimator.""" 100 | 101 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 102 | """The `model_fn` for TPUEstimator.""" 103 | 104 | unique_ids = features["unique_ids"] 105 | input_ids = features["input_ids"] 106 | input_mask = features["input_mask"] 107 | input_type_ids = features["input_type_ids"] 108 | 109 | model = modeling.BertModel( 110 | config=bert_config, 111 | is_training=False, 112 | input_ids=input_ids, 113 | input_mask=input_mask, 114 | token_type_ids=input_type_ids, 115 | use_one_hot_embeddings=use_one_hot_embeddings) 116 | 117 | if mode != tf.estimator.ModeKeys.PREDICT: 118 | raise ValueError("Only PREDICT modes are supported: %s" % (mode)) 119 | 120 | tvars = tf.trainable_variables() 121 | scaffold_fn = None 122 | (assignment_map, _) = modeling.get_assignment_map_from_checkpoint( 123 | tvars, init_checkpoint) 124 | if use_tpu: 125 | 126 | def tpu_scaffold(): 127 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 128 | return tf.train.Scaffold() 129 | 130 | scaffold_fn = tpu_scaffold 131 | else: 132 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 133 | 134 | all_layers = model.get_all_encoder_layers() 135 | 136 | predictions = { 137 | "unique_id": unique_ids, 138 | } 139 | 140 | for (i, layer_index) in enumerate(layer_indexes): 141 | predictions["layer_output_%d" % i] = all_layers[layer_index] 142 | 143 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 144 | mode=mode, predictions=predictions, scaffold_fn=scaffold_fn) 145 | return output_spec 146 | 147 | return model_fn 148 | 149 | 150 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 151 | """Truncates a sequence pair in place to the maximum length.""" 152 | 153 | # This is a simple heuristic which will always truncate the longer sequence 154 | # one token at a time. This makes more sense than truncating an equal percent 155 | # of tokens from each, since if one sequence is very short then each token 156 | # that's truncated likely contains more information than a longer sequence. 
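  # For example, if max_length were 125 and tokens_a/tokens_b started at 90 and 50
  # tokens (140 total), 15 tokens would be popped from tokens_a, leaving 75 + 50 = 125.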
157 | while True: 158 | total_length = len(tokens_a) + len(tokens_b) 159 | if total_length <= max_length: 160 | break 161 | if len(tokens_a) > len(tokens_b): 162 | tokens_a.pop() 163 | else: 164 | tokens_b.pop() 165 | 166 | 167 | tf.logging.set_verbosity(tf.logging.INFO) 168 | 169 | layer_indexes = [int(x) for x in FLAGS.layers.split(",")] 170 | 171 | tokenizer = tokenization.FullTokenizer( 172 | vocab_file=vocab_file, do_lower_case=FLAGS.do_lower_case) 173 | ######################################## 174 | 175 | data_iter_ = data_iter if FLAGS.split != 'test' else data_iter_test 176 | examples = [x for x in data_iter_(f'../{FLAGS.split}.jsonl', 177 | tokenizer=tokenizer, 178 | max_seq_length=FLAGS.max_seq_length, 179 | endingonly=FLAGS.endingonly)] 180 | features = convert_examples_to_features( 181 | examples=[x[0] for x in examples], seq_length=FLAGS.max_seq_length, tokenizer=tokenizer) 182 | unique_id_to_ind = {} 183 | for i, feature in enumerate(features): 184 | unique_id_to_ind[feature.unique_id] = i 185 | 186 | ############################ Tensorflow boilerplate 187 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 188 | run_config = tf.contrib.tpu.RunConfig( 189 | master=FLAGS.master, 190 | tpu_config=tf.contrib.tpu.TPUConfig( 191 | num_shards=FLAGS.num_tpu_cores, 192 | per_host_input_for_training=is_per_host)) 193 | 194 | model_fn = model_fn_builder( 195 | bert_config=bert_config, 196 | init_checkpoint=FLAGS.init_checkpoint, 197 | layer_indexes=layer_indexes, 198 | use_tpu=FLAGS.use_tpu, 199 | use_one_hot_embeddings=FLAGS.use_one_hot_embeddings) 200 | 201 | # If TPU is not available, this will fall back to normal Estimator on CPU 202 | # or GPU. 203 | estimator = tf.contrib.tpu.TPUEstimator( 204 | use_tpu=FLAGS.use_tpu, 205 | model_fn=model_fn, 206 | config=run_config, 207 | predict_batch_size=FLAGS.batch_size) 208 | 209 | input_fn = input_fn_builder( 210 | features=features, seq_length=FLAGS.max_seq_length) 211 | 212 | output_h5_qa = h5py.File(f'../{FLAGS.name}_answer_{FLAGS.split}.h5', 'w') 213 | output_h5_qar = h5py.File(f'../{FLAGS.name}_rationale_{FLAGS.split}.h5', 'w') 214 | 215 | if FLAGS.split != 'test': 216 | subgroup_names = [ 217 | 'answer0', 218 | 'answer1', 219 | 'answer2', 220 | 'answer3', 221 | 'rationale0', 222 | 'rationale1', 223 | 'rationale2', 224 | 'rationale3', 225 | ] 226 | else: 227 | subgroup_names = [ 228 | 'answer0', 229 | 'answer1', 230 | 'answer2', 231 | 'answer3'] + [f'rationale{x}{y}' for x in range(4) for y in range(4)] 232 | 233 | for i in range(len(examples) // len(subgroup_names)): 234 | output_h5_qa.create_group(f'{i}') 235 | output_h5_qar.create_group(f'{i}') 236 | 237 | 238 | def alignment_gather(alignment, layer): 239 | reverse_alignment = [[i for i, x in enumerate(alignment) if x == j] for j in range(max(alignment) + 1)] 240 | output_embs = np.zeros((max(alignment) + 1, layer.shape[1]), dtype=np.float16) 241 | 242 | # Make sure everything is covered 243 | uncovered = np.zeros(max(alignment) + 1, dtype=np.bool) 244 | 245 | for j, trgs in enumerate(reverse_alignment): 246 | if len(trgs) == 0: 247 | uncovered[j] = True 248 | else: 249 | output_embs[j] = np.mean(layer[trgs], 0).astype(np.float16) 250 | things_to_fill = np.where(uncovered[:j])[0] 251 | if things_to_fill.shape[0] != 0: 252 | output_embs[things_to_fill] = output_embs[j] 253 | uncovered[:j] = False 254 | return output_embs 255 | 256 | 257 | for result in tqdm(estimator.predict(input_fn, yield_single_examples=True)): 258 | ind = unique_id_to_ind[int(result["unique_id"])] 259 
| 260 | text, ctx_alignment, choice_alignment = examples[ind] 261 | # just one layer for now 262 | layer = result['layer_output_0'] 263 | ex2use = ind//len(subgroup_names) 264 | subgroup_name = subgroup_names[ind % len(subgroup_names)] 265 | 266 | group2use = (output_h5_qa if subgroup_name.startswith('answer') else output_h5_qar)[f'{ex2use}'] 267 | alignment_ctx = [-1] + ctx_alignment 268 | 269 | if FLAGS.endingonly: 270 | # just a single span here 271 | group2use.create_dataset(f'answer_{subgroup_name}', data=alignment_gather(alignment_ctx, layer)) 272 | else: 273 | alignment_answer = [-1] + [-1 for i in range(len(ctx_alignment))] + [-1] + choice_alignment 274 | group2use.create_dataset(f'ctx_{subgroup_name}', data=alignment_gather(alignment_ctx, layer)) 275 | group2use.create_dataset(f'answer_{subgroup_name}', data=alignment_gather(alignment_answer, layer)) 276 | -------------------------------------------------------------------------------- /data/get_bert_embeddings/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Functions and classes related to optimization (weight updates).""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import re 22 | import tensorflow as tf 23 | 24 | 25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu): 26 | """Creates an optimizer training op.""" 27 | global_step = tf.train.get_or_create_global_step() 28 | 29 | learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) 30 | 31 | # Implements linear decay of the learning rate. 32 | learning_rate = tf.train.polynomial_decay( 33 | learning_rate, 34 | global_step, 35 | num_train_steps, 36 | end_learning_rate=0.0, 37 | power=1.0, 38 | cycle=False) 39 | 40 | # Implements linear warmup. I.e., if global_step < num_warmup_steps, the 41 | # learning rate will be `global_step/num_warmup_steps * init_lr`. 42 | if num_warmup_steps: 43 | global_steps_int = tf.cast(global_step, tf.int32) 44 | warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) 45 | 46 | global_steps_float = tf.cast(global_steps_int, tf.float32) 47 | warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) 48 | 49 | warmup_percent_done = global_steps_float / warmup_steps_float 50 | warmup_learning_rate = init_lr * warmup_percent_done 51 | 52 | is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) 53 | learning_rate = ( 54 | (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) 55 | 56 | # It is recommended that you use this optimizer for fine tuning, since this 57 | # is how the model was trained (note that the Adam m/v variables are NOT 58 | # loaded from init_checkpoint.) 
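  # AdamWeightDecayOptimizer (defined below) applies decoupled, AdamW-style weight
  # decay directly to the parameters rather than through the loss; LayerNorm and
  # bias parameters are excluded from decay via exclude_from_weight_decay.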
59 | optimizer = AdamWeightDecayOptimizer( 60 | learning_rate=learning_rate, 61 | weight_decay_rate=0.01, 62 | beta_1=0.9, 63 | beta_2=0.999, 64 | epsilon=1e-6, 65 | exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) 66 | 67 | if use_tpu: 68 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 69 | 70 | tvars = tf.trainable_variables() 71 | grads = tf.gradients(loss, tvars) 72 | 73 | # This is how the model was pre-trained. 74 | (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) 75 | 76 | train_op = optimizer.apply_gradients( 77 | zip(grads, tvars), global_step=global_step) 78 | 79 | new_global_step = global_step + 1 80 | train_op = tf.group(train_op, [global_step.assign(new_global_step)]) 81 | return train_op 82 | 83 | 84 | class AdamWeightDecayOptimizer(tf.train.Optimizer): 85 | """A basic Adam optimizer that includes "correct" L2 weight decay.""" 86 | 87 | def __init__(self, 88 | learning_rate, 89 | weight_decay_rate=0.0, 90 | beta_1=0.9, 91 | beta_2=0.999, 92 | epsilon=1e-6, 93 | exclude_from_weight_decay=None, 94 | name="AdamWeightDecayOptimizer"): 95 | """Constructs a AdamWeightDecayOptimizer.""" 96 | super(AdamWeightDecayOptimizer, self).__init__(False, name) 97 | 98 | self.learning_rate = learning_rate 99 | self.weight_decay_rate = weight_decay_rate 100 | self.beta_1 = beta_1 101 | self.beta_2 = beta_2 102 | self.epsilon = epsilon 103 | self.exclude_from_weight_decay = exclude_from_weight_decay 104 | 105 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 106 | """See base class.""" 107 | assignments = [] 108 | for (grad, param) in grads_and_vars: 109 | if grad is None or param is None: 110 | continue 111 | 112 | param_name = self._get_variable_name(param.name) 113 | 114 | m = tf.get_variable( 115 | name=param_name + "/adam_m", 116 | shape=param.shape.as_list(), 117 | dtype=tf.float32, 118 | trainable=False, 119 | initializer=tf.zeros_initializer()) 120 | v = tf.get_variable( 121 | name=param_name + "/adam_v", 122 | shape=param.shape.as_list(), 123 | dtype=tf.float32, 124 | trainable=False, 125 | initializer=tf.zeros_initializer()) 126 | 127 | # Standard Adam update. 128 | next_m = ( 129 | tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad)) 130 | next_v = ( 131 | tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2, 132 | tf.square(grad))) 133 | 134 | update = next_m / (tf.sqrt(next_v) + self.epsilon) 135 | 136 | # Just adding the square of the weights to the loss function is *not* 137 | # the correct way of using L2 regularization/weight decay with Adam, 138 | # since that will interact with the m and v parameters in strange ways. 139 | # 140 | # Instead we want ot decay the weights in a manner that doesn't interact 141 | # with the m/v parameters. This is equivalent to adding the square 142 | # of the weights to the loss with plain (non-momentum) SGD. 
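      # The resulting update is:
      #   param <- param - learning_rate * (m / (sqrt(v) + epsilon) + weight_decay_rate * param)
      # (note that, unlike standard Adam, no bias correction is applied to m and v here).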
143 | if self._do_use_weight_decay(param_name): 144 | update += self.weight_decay_rate * param 145 | 146 | update_with_lr = self.learning_rate * update 147 | 148 | next_param = param - update_with_lr 149 | 150 | assignments.extend( 151 | [param.assign(next_param), 152 | m.assign(next_m), 153 | v.assign(next_v)]) 154 | return tf.group(*assignments, name=name) 155 | 156 | def _do_use_weight_decay(self, param_name): 157 | """Whether to use L2 weight decay for `param_name`.""" 158 | if not self.weight_decay_rate: 159 | return False 160 | if self.exclude_from_weight_decay: 161 | for r in self.exclude_from_weight_decay: 162 | if re.search(r, param_name) is not None: 163 | return False 164 | return True 165 | 166 | def _get_variable_name(self, param_name): 167 | """Get the variable name from the tensor name.""" 168 | m = re.match("^(.*):\\d+$", param_name) 169 | if m is not None: 170 | param_name = m.group(1) 171 | return param_name 172 | -------------------------------------------------------------------------------- /data/get_bert_embeddings/pretrain_on_vcr.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Run masked LM/next sentence masked_lm pre-training for BERT.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | from data.get_bert_embeddings import modeling 23 | from data.get_bert_embeddings import optimization 24 | import tensorflow as tf 25 | import zipfile 26 | import requests 27 | 28 | flags = tf.flags 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | ## Required parameters 33 | flags.DEFINE_string( 34 | "input_file", 'pretrainingdata.tfrecord', 35 | "Input TF example files (can be a glob or comma separated).") 36 | 37 | flags.DEFINE_string( 38 | "output_dir", 'bert-pretrained', 39 | "The output directory where the model checkpoints will be written.") 40 | 41 | flags.DEFINE_integer( 42 | "max_seq_length", 128, 43 | "The maximum total input sequence length after WordPiece tokenization. " 44 | "Sequences longer than this will be truncated, and sequences shorter " 45 | "than this will be padded. Must match data generation.") 46 | 47 | flags.DEFINE_integer( 48 | "max_predictions_per_seq", 20, 49 | "Maximum number of masked LM predictions per sequence. 
" 50 | "Must match data generation.") 51 | 52 | flags.DEFINE_bool("do_train", False, "Whether to run training.") 53 | 54 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") 55 | 56 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") 57 | 58 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.") 59 | 60 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") 61 | 62 | # One epoch53230.75 63 | flags.DEFINE_integer("num_train_steps", 53230, "Number of training steps.") 64 | 65 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.") 66 | 67 | flags.DEFINE_integer("save_checkpoints_steps", 1000, 68 | "How often to save the model checkpoint.") 69 | 70 | flags.DEFINE_integer("iterations_per_loop", 1000, 71 | "How many steps to make in each estimator call.") 72 | 73 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.") 74 | 75 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") 76 | 77 | tf.flags.DEFINE_string( 78 | "tpu_name", None, 79 | "The Cloud TPU to use for training. This should be either the name " 80 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " 81 | "url.") 82 | 83 | tf.flags.DEFINE_string( 84 | "tpu_zone", None, 85 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 86 | "specified, we will attempt to automatically detect the GCE project from " 87 | "metadata.") 88 | 89 | tf.flags.DEFINE_string( 90 | "gcp_project", None, 91 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 92 | "specified, we will attempt to automatically detect the GCE project from " 93 | "metadata.") 94 | 95 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") 96 | 97 | flags.DEFINE_integer( 98 | "num_tpu_cores", 8, 99 | "Only used if `use_tpu` is True. 
Total number of TPU cores to use.") 100 | 101 | 102 | if not os.path.exists('uncased_L-12_H-768_A-12'): 103 | response = requests.get('https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip', 104 | stream=True) 105 | with open('uncased_L-12_H-768_A-12.zip', "wb") as handle: 106 | for chunk in response.iter_content(chunk_size=512): 107 | if chunk: # filter out keep-alive new chunks 108 | handle.write(chunk) 109 | with zipfile.ZipFile('uncased_L-12_H-768_A-12.zip') as zf: 110 | zf.extractall() 111 | 112 | print("BERT HAS BEEN DOWNLOADED") 113 | mypath = os.getcwd() 114 | bert_config_file = os.path.join(mypath, 'uncased_L-12_H-768_A-12', 'bert_config.json') 115 | vocab_file = os.path.join(mypath, 'uncased_L-12_H-768_A-12', 'vocab.txt') 116 | init_checkpoint = os.path.join(mypath, 'uncased_L-12_H-768_A-12', 'bert_model.ckpt') 117 | bert_config = modeling.BertConfig.from_json_file(bert_config_file) 118 | 119 | 120 | def model_fn_builder(bert_config, init_checkpoint, learning_rate, 121 | num_train_steps, num_warmup_steps, use_tpu, 122 | use_one_hot_embeddings): 123 | """Returns `model_fn` closure for TPUEstimator.""" 124 | 125 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument 126 | """The `model_fn` for TPUEstimator.""" 127 | 128 | tf.logging.info("*** Features ***") 129 | for name in sorted(features.keys()): 130 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) 131 | 132 | input_ids = features["input_ids"] 133 | input_mask = features["input_mask"] 134 | segment_ids = features["segment_ids"] 135 | masked_lm_positions = features["masked_lm_positions"] 136 | masked_lm_ids = features["masked_lm_ids"] 137 | masked_lm_weights = features["masked_lm_weights"] 138 | next_sentence_labels = features["next_sentence_labels"] 139 | 140 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 141 | 142 | model = modeling.BertModel( 143 | config=bert_config, 144 | is_training=is_training, 145 | input_ids=input_ids, 146 | input_mask=input_mask, 147 | token_type_ids=segment_ids, 148 | use_one_hot_embeddings=use_one_hot_embeddings) 149 | 150 | (masked_lm_loss, 151 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output( 152 | bert_config, model.get_sequence_output(), model.get_embedding_table(), 153 | masked_lm_positions, masked_lm_ids, masked_lm_weights) 154 | 155 | (next_sentence_loss, next_sentence_example_loss, 156 | next_sentence_log_probs) = get_next_sentence_output( 157 | bert_config, model.get_pooled_output(), next_sentence_labels) 158 | 159 | total_loss = masked_lm_loss + next_sentence_loss 160 | 161 | tvars = tf.trainable_variables() 162 | 163 | initialized_variable_names = {} 164 | scaffold_fn = None 165 | if init_checkpoint: 166 | (assignment_map, initialized_variable_names 167 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) 168 | if use_tpu: 169 | 170 | def tpu_scaffold(): 171 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 172 | return tf.train.Scaffold() 173 | 174 | scaffold_fn = tpu_scaffold 175 | else: 176 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 177 | 178 | tf.logging.info("**** Trainable Variables ****") 179 | for var in tvars: 180 | init_string = "" 181 | if var.name in initialized_variable_names: 182 | init_string = ", *INIT_FROM_CKPT*" 183 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, 184 | init_string) 185 | 186 | output_spec = None 187 | if mode == tf.estimator.ModeKeys.TRAIN: 188 | train_op = 
optimization.create_optimizer( 189 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) 190 | 191 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 192 | mode=mode, 193 | loss=total_loss, 194 | train_op=train_op, 195 | scaffold_fn=scaffold_fn) 196 | elif mode == tf.estimator.ModeKeys.EVAL: 197 | 198 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 199 | masked_lm_weights, next_sentence_example_loss, 200 | next_sentence_log_probs, next_sentence_labels): 201 | """Computes the loss and accuracy of the model.""" 202 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs, 203 | [-1, masked_lm_log_probs.shape[-1]]) 204 | masked_lm_predictions = tf.argmax( 205 | masked_lm_log_probs, axis=-1, output_type=tf.int32) 206 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1]) 207 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1]) 208 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1]) 209 | masked_lm_accuracy = tf.metrics.accuracy( 210 | labels=masked_lm_ids, 211 | predictions=masked_lm_predictions, 212 | weights=masked_lm_weights) 213 | masked_lm_mean_loss = tf.metrics.mean( 214 | values=masked_lm_example_loss, weights=masked_lm_weights) 215 | 216 | next_sentence_log_probs = tf.reshape( 217 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]]) 218 | next_sentence_predictions = tf.argmax( 219 | next_sentence_log_probs, axis=-1, output_type=tf.int32) 220 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1]) 221 | next_sentence_accuracy = tf.metrics.accuracy( 222 | labels=next_sentence_labels, predictions=next_sentence_predictions) 223 | next_sentence_mean_loss = tf.metrics.mean( 224 | values=next_sentence_example_loss) 225 | 226 | return { 227 | "masked_lm_accuracy": masked_lm_accuracy, 228 | "masked_lm_loss": masked_lm_mean_loss, 229 | "next_sentence_accuracy": next_sentence_accuracy, 230 | "next_sentence_loss": next_sentence_mean_loss, 231 | } 232 | 233 | eval_metrics = (metric_fn, [ 234 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, 235 | masked_lm_weights, next_sentence_example_loss, 236 | next_sentence_log_probs, next_sentence_labels 237 | ]) 238 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 239 | mode=mode, 240 | loss=total_loss, 241 | eval_metrics=eval_metrics, 242 | scaffold_fn=scaffold_fn) 243 | else: 244 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode)) 245 | 246 | return output_spec 247 | 248 | return model_fn 249 | 250 | 251 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions, 252 | label_ids, label_weights): 253 | """Get loss and log probs for the masked LM.""" 254 | input_tensor = gather_indexes(input_tensor, positions) 255 | 256 | with tf.variable_scope("cls/predictions"): 257 | # We apply one more non-linear transformation before the output layer. 258 | # This matrix is not used after pre-training. 259 | with tf.variable_scope("transform"): 260 | input_tensor = tf.layers.dense( 261 | input_tensor, 262 | units=bert_config.hidden_size, 263 | activation=modeling.get_activation(bert_config.hidden_act), 264 | kernel_initializer=modeling.create_initializer( 265 | bert_config.initializer_range)) 266 | input_tensor = modeling.layer_norm(input_tensor) 267 | 268 | # The output weights are the same as the input embeddings, but there is 269 | # an output-only bias for each token. 
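# --- Illustrative shape note (not part of the original file): `output_weights` is the
# word-embedding table passed in as model.get_embedding_table(), of shape
# [vocab_size, hidden_size], so
#   logits = input_tensor [num_masked_positions, hidden_size] @ output_weights^T
#          -> [num_masked_positions, vocab_size]
# i.e. the masked-LM softmax reuses the input embeddings; `output_bias` ([vocab_size])
# is the only extra parameter in this final projection (the transform layer above has
# its own weights).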
270 | output_bias = tf.get_variable( 271 | "output_bias", 272 | shape=[bert_config.vocab_size], 273 | initializer=tf.zeros_initializer()) 274 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 275 | logits = tf.nn.bias_add(logits, output_bias) 276 | log_probs = tf.nn.log_softmax(logits, axis=-1) 277 | 278 | label_ids = tf.reshape(label_ids, [-1]) 279 | label_weights = tf.reshape(label_weights, [-1]) 280 | 281 | one_hot_labels = tf.one_hot( 282 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32) 283 | 284 | # The `positions` tensor might be zero-padded (if the sequence is too 285 | # short to have the maximum number of predictions). The `label_weights` 286 | # tensor has a value of 1.0 for every real prediction and 0.0 for the 287 | # padding predictions. 288 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) 289 | numerator = tf.reduce_sum(label_weights * per_example_loss) 290 | denominator = tf.reduce_sum(label_weights) + 1e-5 291 | loss = numerator / denominator 292 | 293 | return (loss, per_example_loss, log_probs) 294 | 295 | 296 | def get_next_sentence_output(bert_config, input_tensor, labels): 297 | """Get loss and log probs for the next sentence prediction.""" 298 | 299 | # Simple binary classification. Note that 0 is "next sentence" and 1 is 300 | # "random sentence". This weight matrix is not used after pre-training. 301 | with tf.variable_scope("cls/seq_relationship"): 302 | output_weights = tf.get_variable( 303 | "output_weights", 304 | shape=[2, bert_config.hidden_size], 305 | initializer=modeling.create_initializer(bert_config.initializer_range)) 306 | output_bias = tf.get_variable( 307 | "output_bias", shape=[2], initializer=tf.zeros_initializer()) 308 | 309 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True) 310 | logits = tf.nn.bias_add(logits, output_bias) 311 | log_probs = tf.nn.log_softmax(logits, axis=-1) 312 | labels = tf.reshape(labels, [-1]) 313 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32) 314 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 315 | loss = tf.reduce_mean(per_example_loss) 316 | return (loss, per_example_loss, log_probs) 317 | 318 | 319 | def gather_indexes(sequence_tensor, positions): 320 | """Gathers the vectors at the specific positions over a minibatch.""" 321 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3) 322 | batch_size = sequence_shape[0] 323 | seq_length = sequence_shape[1] 324 | width = sequence_shape[2] 325 | 326 | flat_offsets = tf.reshape( 327 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1]) 328 | flat_positions = tf.reshape(positions + flat_offsets, [-1]) 329 | flat_sequence_tensor = tf.reshape(sequence_tensor, 330 | [batch_size * seq_length, width]) 331 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions) 332 | return output_tensor 333 | 334 | 335 | def input_fn_builder(input_files, 336 | max_seq_length, 337 | max_predictions_per_seq, 338 | is_training, 339 | num_cpu_threads=4): 340 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 341 | 342 | def input_fn(params): 343 | """The actual input function.""" 344 | batch_size = params["batch_size"] 345 | 346 | name_to_features = { 347 | "input_ids": 348 | tf.FixedLenFeature([max_seq_length], tf.int64), 349 | "input_mask": 350 | tf.FixedLenFeature([max_seq_length], tf.int64), 351 | "segment_ids": 352 | tf.FixedLenFeature([max_seq_length], tf.int64), 353 | "masked_lm_positions": 354 | 
tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 355 | "masked_lm_ids": 356 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64), 357 | "masked_lm_weights": 358 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32), 359 | "next_sentence_labels": 360 | tf.FixedLenFeature([1], tf.int64), 361 | } 362 | 363 | # For training, we want a lot of parallel reading and shuffling. 364 | # For eval, we want no shuffling and parallel reading doesn't matter. 365 | if is_training: 366 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)) 367 | d = d.repeat() 368 | d = d.shuffle(buffer_size=len(input_files)) 369 | 370 | # `cycle_length` is the number of parallel files that get read. 371 | cycle_length = min(num_cpu_threads, len(input_files)) 372 | 373 | # `sloppy` mode means that the interleaving is not exact. This adds 374 | # even more randomness to the training pipeline. 375 | d = d.apply( 376 | tf.contrib.data.parallel_interleave( 377 | tf.data.TFRecordDataset, 378 | sloppy=is_training, 379 | cycle_length=cycle_length)) 380 | d = d.shuffle(buffer_size=100) 381 | else: 382 | d = tf.data.TFRecordDataset(input_files) 383 | # Since we evaluate for a fixed number of steps we don't want to encounter 384 | # out-of-range exceptions. 385 | d = d.repeat() 386 | 387 | # We must `drop_remainder` on training because the TPU requires fixed 388 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU 389 | # and we *don't* want to drop the remainder, otherwise we wont cover 390 | # every sample. 391 | d = d.apply( 392 | tf.contrib.data.map_and_batch( 393 | lambda record: _decode_record(record, name_to_features), 394 | batch_size=batch_size, 395 | num_parallel_batches=num_cpu_threads, 396 | drop_remainder=True)) 397 | return d 398 | 399 | return input_fn 400 | 401 | 402 | def _decode_record(record, name_to_features): 403 | """Decodes a record to a TensorFlow example.""" 404 | example = tf.parse_single_example(record, name_to_features) 405 | 406 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32. 407 | # So cast all int64 to int32. 
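# (Illustrative note, not part of the original file: e.g. example["input_ids"] comes
# out of tf.parse_single_example as an int64 tensor of shape [max_seq_length]; the
# loop below rewrites it as the equivalent int32 tensor, while float features such as
# "masked_lm_weights" pass through unchanged.)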
408 | for name in list(example.keys()): 409 | t = example[name] 410 | if t.dtype == tf.int64: 411 | t = tf.to_int32(t) 412 | example[name] = t 413 | 414 | return example 415 | 416 | 417 | tf.logging.set_verbosity(tf.logging.INFO) 418 | 419 | if not FLAGS.do_train and not FLAGS.do_eval: 420 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 421 | 422 | tf.gfile.MakeDirs(FLAGS.output_dir) 423 | 424 | tpu_cluster_resolver = None 425 | if FLAGS.use_tpu and FLAGS.tpu_name: 426 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 427 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 428 | 429 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 430 | run_config = tf.contrib.tpu.RunConfig( 431 | cluster=tpu_cluster_resolver, 432 | master=FLAGS.master, 433 | model_dir=FLAGS.output_dir, 434 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 435 | tpu_config=tf.contrib.tpu.TPUConfig( 436 | iterations_per_loop=FLAGS.iterations_per_loop, 437 | num_shards=FLAGS.num_tpu_cores, 438 | per_host_input_for_training=is_per_host)) 439 | 440 | 441 | model_fn = model_fn_builder( 442 | bert_config=bert_config, 443 | init_checkpoint=init_checkpoint, 444 | learning_rate=FLAGS.learning_rate, 445 | num_train_steps=FLAGS.num_train_steps, 446 | num_warmup_steps=FLAGS.num_warmup_steps, 447 | use_tpu=FLAGS.use_tpu, 448 | use_one_hot_embeddings=FLAGS.use_tpu) 449 | 450 | # If TPU is not available, this will fall back to normal Estimator on CPU 451 | # or GPU. 452 | estimator = tf.contrib.tpu.TPUEstimator( 453 | use_tpu=FLAGS.use_tpu, 454 | model_fn=model_fn, 455 | config=run_config, 456 | train_batch_size=FLAGS.train_batch_size, 457 | eval_batch_size=FLAGS.eval_batch_size) 458 | 459 | if FLAGS.do_train: 460 | tf.logging.info("***** Running training *****") 461 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) 462 | train_input_fn = input_fn_builder( 463 | input_files=[FLAGS.input_file], 464 | max_seq_length=FLAGS.max_seq_length, 465 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 466 | is_training=True) 467 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps) 468 | 469 | if FLAGS.do_eval: 470 | tf.logging.info("***** Running evaluation *****") 471 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) 472 | 473 | eval_input_fn = input_fn_builder( 474 | input_files=[FLAGS.input_file], 475 | max_seq_length=FLAGS.max_seq_length, 476 | max_predictions_per_seq=FLAGS.max_predictions_per_seq, 477 | is_training=False) 478 | 479 | result = estimator.evaluate( 480 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps) 481 | 482 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") 483 | with tf.gfile.GFile(output_eval_file, "w") as writer: 484 | tf.logging.info("***** Eval results *****") 485 | for key in sorted(result.keys()): 486 | tf.logging.info(" %s = %s", key, str(result[key])) 487 | writer.write("%s = %s\n" % (key, str(result[key]))) -------------------------------------------------------------------------------- /data/get_bert_embeddings/requirements.txt: -------------------------------------------------------------------------------- 1 | # tensorflow >= 1.11.0 # CPU Version of TensorFlow. 2 | tensorflow-gpu >= 1.11.0 # GPU version of TensorFlow. 
3 | -------------------------------------------------------------------------------- /data/get_bert_embeddings/tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_unicode(text, errors="ignore"): 28 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 29 | if isinstance(text, six.text_type): 30 | return text 31 | elif isinstance(text, six.binary_type): 32 | return text.decode("utf-8", errors) 33 | else: 34 | raise ValueError("Unsupported string type: %s" % (type(text))) 35 | 36 | 37 | def printable_text(text): 38 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 39 | 40 | # These functions want `str` for both Python2 and Python3, but in one case 41 | # it's a Unicode string and in the other it's a byte string. 42 | if six.PY3: 43 | if isinstance(text, str): 44 | return text 45 | elif isinstance(text, bytes): 46 | return text.decode("utf-8", "ignore") 47 | else: 48 | raise ValueError("Unsupported string type: %s" % (type(text))) 49 | elif six.PY2: 50 | if isinstance(text, str): 51 | return text 52 | elif isinstance(text, unicode): 53 | return text.encode("utf-8") 54 | else: 55 | raise ValueError("Unsupported string type: %s" % (type(text))) 56 | else: 57 | raise ValueError("Not running on Python2 or Python 3?") 58 | 59 | 60 | def load_vocab(vocab_file): 61 | """Loads a vocabulary file into a dictionary.""" 62 | vocab = collections.OrderedDict() 63 | index = 0 64 | with tf.gfile.GFile(vocab_file, "r") as reader: 65 | while True: 66 | token = convert_to_unicode(reader.readline()) 67 | if not token: 68 | break 69 | token = token.strip() 70 | vocab[token] = index 71 | index += 1 72 | return vocab 73 | 74 | 75 | def convert_tokens_to_ids(vocab, tokens): 76 | """Converts a sequence of tokens into ids using the vocab.""" 77 | ids = [] 78 | for token in tokens: 79 | ids.append(vocab[token]) 80 | return ids 81 | 82 | 83 | def whitespace_tokenize(text): 84 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 85 | text = text.strip() 86 | if not text: 87 | return [] 88 | tokens = text.split() 89 | return tokens 90 | 91 | 92 | class FullTokenizer(object): 93 | """Runs end-to-end tokenziation.""" 94 | 95 | def __init__(self, vocab_file, do_lower_case=True): 96 | self.vocab = load_vocab(vocab_file) 97 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 98 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 99 | 100 | def tokenize(self, text): 101 | split_tokens = [] 102 | for token in 
self.basic_tokenizer.tokenize(text): 103 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 104 | split_tokens.append(sub_token) 105 | 106 | return split_tokens 107 | 108 | def convert_tokens_to_ids(self, tokens): 109 | return convert_tokens_to_ids(self.vocab, tokens) 110 | 111 | 112 | class BasicTokenizer(object): 113 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 114 | 115 | def __init__(self, do_lower_case=True): 116 | """Constructs a BasicTokenizer. 117 | 118 | Args: 119 | do_lower_case: Whether to lower case the input. 120 | """ 121 | self.do_lower_case = do_lower_case 122 | 123 | def tokenize(self, text): 124 | """Tokenizes a piece of text.""" 125 | text = convert_to_unicode(text) 126 | text = self._clean_text(text) 127 | orig_tokens = whitespace_tokenize(text) 128 | split_tokens = [] 129 | for token in orig_tokens: 130 | if self.do_lower_case: 131 | token = token.lower() 132 | token = self._run_strip_accents(token) 133 | split_tokens.extend(self._run_split_on_punc(token)) 134 | 135 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 136 | return output_tokens 137 | 138 | def _run_strip_accents(self, text): 139 | """Strips accents from a piece of text.""" 140 | text = unicodedata.normalize("NFD", text) 141 | output = [] 142 | for char in text: 143 | cat = unicodedata.category(char) 144 | if cat == "Mn": 145 | continue 146 | output.append(char) 147 | return "".join(output) 148 | 149 | def _run_split_on_punc(self, text): 150 | """Splits punctuation on a piece of text.""" 151 | chars = list(text) 152 | i = 0 153 | start_new_word = True 154 | output = [] 155 | while i < len(chars): 156 | char = chars[i] 157 | if _is_punctuation(char): 158 | output.append([char]) 159 | start_new_word = True 160 | else: 161 | if start_new_word: 162 | output.append([]) 163 | start_new_word = False 164 | output[-1].append(char) 165 | i += 1 166 | 167 | return ["".join(x) for x in output] 168 | 169 | def _clean_text(self, text): 170 | """Performs invalid character removal and whitespace cleanup on text.""" 171 | output = [] 172 | for char in text: 173 | cp = ord(char) 174 | if cp == 0 or cp == 0xfffd or _is_control(char): 175 | continue 176 | if _is_whitespace(char): 177 | output.append(" ") 178 | else: 179 | output.append(char) 180 | return "".join(output) 181 | 182 | 183 | class WordpieceTokenizer(object): 184 | """Runs WordPiece tokenziation.""" 185 | 186 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 187 | self.vocab = vocab 188 | self.unk_token = unk_token 189 | self.max_input_chars_per_word = max_input_chars_per_word 190 | 191 | def tokenize(self, text): 192 | """Tokenizes a piece of text into its word pieces. 193 | 194 | This uses a greedy longest-match-first algorithm to perform tokenization 195 | using the given vocabulary. 196 | 197 | For example: 198 | input = "unaffable" 199 | output = ["un", "##aff", "##able"] 200 | 201 | Args: 202 | text: A single token or whitespace separated tokens. This should have 203 | already been passed through `BasicTokenizer. 204 | 205 | Returns: 206 | A list of wordpiece tokens. 
207 | """ 208 | 209 | text = convert_to_unicode(text) 210 | 211 | output_tokens = [] 212 | for token in whitespace_tokenize(text): 213 | chars = list(token) 214 | if len(chars) > self.max_input_chars_per_word: 215 | output_tokens.append(self.unk_token) 216 | continue 217 | 218 | is_bad = False 219 | start = 0 220 | sub_tokens = [] 221 | while start < len(chars): 222 | end = len(chars) 223 | cur_substr = None 224 | while start < end: 225 | substr = "".join(chars[start:end]) 226 | if start > 0: 227 | substr = "##" + substr 228 | if substr in self.vocab: 229 | cur_substr = substr 230 | break 231 | end -= 1 232 | if cur_substr is None: 233 | is_bad = True 234 | break 235 | sub_tokens.append(cur_substr) 236 | start = end 237 | 238 | if is_bad: 239 | output_tokens.append(self.unk_token) 240 | else: 241 | output_tokens.extend(sub_tokens) 242 | return output_tokens 243 | 244 | 245 | def _is_whitespace(char): 246 | """Checks whether `chars` is a whitespace character.""" 247 | # \t, \n, and \r are technically contorl characters but we treat them 248 | # as whitespace since they are generally considered as such. 249 | if char == " " or char == "\t" or char == "\n" or char == "\r": 250 | return True 251 | cat = unicodedata.category(char) 252 | if cat == "Zs": 253 | return True 254 | return False 255 | 256 | 257 | def _is_control(char): 258 | """Checks whether `chars` is a control character.""" 259 | # These are technically control characters but we count them as whitespace 260 | # characters. 261 | if char == "\t" or char == "\n" or char == "\r": 262 | return False 263 | cat = unicodedata.category(char) 264 | if cat.startswith("C"): 265 | return True 266 | return False 267 | 268 | 269 | def _is_punctuation(char): 270 | """Checks whether `chars` is a punctuation character.""" 271 | cp = ord(char) 272 | # We treat all non-letter/number ASCII as punctuation. 273 | # Characters such as "^", "$", and "`" are not in the Unicode 274 | # Punctuation class but we treat them as punctuation anyways, for 275 | # consistency. 
276 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 277 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 278 | return True 279 | cat = unicodedata.category(char) 280 | if cat.startswith("P"): 281 | return True 282 | return False 283 | -------------------------------------------------------------------------------- /data/get_bert_embeddings/vcr_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | 4 | import tensorflow as tf 5 | from tqdm import tqdm 6 | 7 | 8 | class InputExample(object): 9 | def __init__(self, unique_id, text_a, text_b, is_correct=None): 10 | self.unique_id = unique_id 11 | self.text_a = text_a 12 | self.text_b = text_b 13 | self.is_correct = is_correct 14 | 15 | 16 | class InputFeatures(object): 17 | """A single set of features of data.""" 18 | 19 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids, is_correct): 20 | self.unique_id = unique_id 21 | self.tokens = tokens 22 | self.input_ids = input_ids 23 | self.input_mask = input_mask 24 | self.input_type_ids = input_type_ids 25 | self.is_correct = is_correct 26 | 27 | 28 | def input_fn_builder(features, seq_length): 29 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 30 | 31 | all_unique_ids = [] 32 | all_input_ids = [] 33 | all_input_mask = [] 34 | all_input_type_ids = [] 35 | 36 | for feature in features: 37 | all_unique_ids.append(feature.unique_id) 38 | all_input_ids.append(feature.input_ids) 39 | all_input_mask.append(feature.input_mask) 40 | all_input_type_ids.append(feature.input_type_ids) 41 | 42 | def input_fn(params): 43 | """The actual input function.""" 44 | batch_size = params["batch_size"] 45 | 46 | num_examples = len(features) 47 | 48 | # This is for demo purposes and does NOT scale to large data sets. We do 49 | # not use Dataset.from_generator() because that uses tf.py_func which is 50 | # not TPU compatible. The right way to load data is with TFRecordReader. 51 | d = tf.data.Dataset.from_tensor_slices({ 52 | "unique_ids": 53 | tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32), 54 | "input_ids": 55 | tf.constant( 56 | all_input_ids, shape=[num_examples, seq_length], 57 | dtype=tf.int32), 58 | "input_mask": 59 | tf.constant( 60 | all_input_mask, 61 | shape=[num_examples, seq_length], 62 | dtype=tf.int32), 63 | "input_type_ids": 64 | tf.constant( 65 | all_input_type_ids, 66 | shape=[num_examples, seq_length], 67 | dtype=tf.int32), 68 | }) 69 | 70 | d = d.batch(batch_size=batch_size, drop_remainder=False) 71 | return d 72 | 73 | return input_fn 74 | 75 | 76 | GENDER_NEUTRAL_NAMES = ['Casey', 'Riley', 'Jessie', 'Jackie', 'Avery', 'Jaime', 'Peyton', 'Kerry', 'Jody', 'Kendall', 77 | 'Peyton', 'Skyler', 'Frankie', 'Pat', 'Quinn', 78 | ] 79 | 80 | 81 | def convert_examples_to_features(examples, seq_length, tokenizer): 82 | """Loads a data file into a list of `InputBatch`s.""" 83 | 84 | features = [] 85 | for (ex_index, example) in enumerate(examples): 86 | # note, this is different because weve already tokenized 87 | tokens_a = example.text_a 88 | 89 | # tokens_b = example.text_b 90 | 91 | tokens_b = None 92 | if example.text_b: 93 | tokens_b = example.text_b 94 | 95 | # The convention in BERT is: 96 | # (a) For sequence pairs: 97 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 98 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 99 | # (b) For single sequences: 100 | # tokens: [CLS] the dog is hairy . 
[SEP] 101 | # type_ids: 0 0 0 0 0 0 0 102 | # 103 | # Where "type_ids" are used to indicate whether this is the first 104 | # sequence or the second sequence. The embedding vectors for `type=0` and 105 | # `type=1` were learned during pre-training and are added to the wordpiece 106 | # embedding vector (and position vector). This is not *strictly* necessary 107 | # since the [SEP] token unambiguously separates the sequences, but it makes 108 | # it easier for the model to learn the concept of sequences. 109 | # 110 | # For classification tasks, the first vector (corresponding to [CLS]) is 111 | # used as as the "sentence vector". Note that this only makes sense because 112 | # the entire model is fine-tuned. 113 | tokens = [] 114 | input_type_ids = [] 115 | tokens.append("[CLS]") 116 | input_type_ids.append(0) 117 | for token in tokens_a: 118 | tokens.append(token) 119 | input_type_ids.append(0) 120 | tokens.append("[SEP]") 121 | input_type_ids.append(0) 122 | 123 | if tokens_b: 124 | for token in tokens_b: 125 | tokens.append(token) 126 | input_type_ids.append(1) 127 | tokens.append("[SEP]") 128 | input_type_ids.append(1) 129 | 130 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 131 | 132 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 133 | # tokens are attended to. 134 | input_mask = [1] * len(input_ids) 135 | 136 | # Zero-pad up to the sequence length. 137 | while len(input_ids) < seq_length: 138 | input_ids.append(0) 139 | input_mask.append(0) 140 | input_type_ids.append(0) 141 | 142 | assert len(input_ids) == seq_length 143 | assert len(input_mask) == seq_length 144 | assert len(input_type_ids) == seq_length 145 | 146 | if ex_index < 5: 147 | tf.logging.info("*** Example ***") 148 | tf.logging.info("unique_id: %s" % (example.unique_id)) 149 | tf.logging.info("tokens: %s" % " ".join([str(x) for x in tokens])) 150 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 151 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 152 | tf.logging.info( 153 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 154 | 155 | features.append( 156 | InputFeatures( 157 | unique_id=example.unique_id, 158 | tokens=tokens, 159 | input_ids=input_ids, 160 | input_mask=input_mask, 161 | input_type_ids=input_type_ids, 162 | is_correct=example.is_correct)) 163 | return features 164 | 165 | 166 | ################################################################################################## 167 | 168 | def _fix_tokenization(tokenized_sent, obj_to_type, det_hist=None): 169 | if det_hist is None: 170 | det_hist = {} 171 | else: 172 | det_hist = {k: v for k, v in det_hist.items()} 173 | 174 | obj2count = defaultdict(int) 175 | # Comment this in if you want to preserve stuff from the earlier rounds. 
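# (Illustrative worked example of this function as a whole, not part of the original
# file, using the sample annotation shown in dataloaders/vcr.py with
# objects = ['person', 'person', 'person', 'car']:
#   _fix_tokenization(["Does", [2], "feel", "comfortable", "?"], objects)
#   returns (["Does", "Casey", "feel", "comfortable", "?"], {2: "Casey"})
# i.e. a detection tag is replaced by a gender-neutral name when it refers to a
# person, or by its object type otherwise, and det_hist carries the mapping forward
# so later sentences reuse the same name for the same detection.)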
176 | for v in det_hist.values(): 177 | obj2count[v.split('_')[0]] += 1 178 | 179 | new_tokenization = [] 180 | for i, tok in enumerate(tokenized_sent): 181 | if isinstance(tok, list): 182 | for int_name in tok: 183 | if int_name not in det_hist: 184 | if obj_to_type[int_name] == 'person': 185 | det_hist[int_name] = GENDER_NEUTRAL_NAMES[obj2count['person'] % len(GENDER_NEUTRAL_NAMES)] 186 | else: 187 | det_hist[int_name] = obj_to_type[int_name] 188 | obj2count[obj_to_type[int_name]] += 1 189 | new_tokenization.append(det_hist[int_name]) 190 | else: 191 | new_tokenization.append(tok) 192 | return new_tokenization, det_hist 193 | 194 | 195 | def fix_item(item, answer_label=None, rationales=True): 196 | if rationales: 197 | assert answer_label is not None 198 | ctx = item['question'] + item['answer_choices'][answer_label] 199 | else: 200 | ctx = item['question'] 201 | 202 | q_tok, hist = _fix_tokenization(ctx, item['objects']) 203 | choices = item['rationale_choices'] if rationales else item['answer_choices'] 204 | a_toks = [_fix_tokenization(choice, obj_to_type=item['objects'], det_hist=hist)[0] for choice in choices] 205 | return q_tok, a_toks 206 | 207 | 208 | def retokenize_with_alignment(span, tokenizer): 209 | tokens = [] 210 | alignment = [] 211 | for i, tok in enumerate(span): 212 | for token in tokenizer.basic_tokenizer.tokenize(tok): 213 | for sub_token in tokenizer.wordpiece_tokenizer.tokenize(token): 214 | tokens.append(sub_token) 215 | alignment.append(i) 216 | return tokens, alignment 217 | 218 | 219 | def process_ctx_ans_for_bert(ctx_raw, ans_raw, tokenizer, counter, endingonly, max_seq_length, is_correct): 220 | """ 221 | Processes a Q/A pair for BERT 222 | :param ctx_raw: 223 | :param ans_raw: 224 | :param tokenizer: 225 | :param counter: 226 | :param endingonly: 227 | :param max_seq_length: 228 | :param is_correct: 229 | :return: 230 | """ 231 | context = retokenize_with_alignment(ctx_raw, tokenizer) 232 | answer = retokenize_with_alignment(ans_raw, tokenizer) 233 | 234 | if endingonly: 235 | take_away_from_ctx = len(answer[0]) - max_seq_length + 2 236 | if take_away_from_ctx > 0: 237 | answer = (answer[0][take_away_from_ctx:], 238 | answer[1][take_away_from_ctx:]) 239 | 240 | return InputExample(unique_id=counter, text_a=answer[0], text_b=None, 241 | is_correct=is_correct), answer[1], None 242 | 243 | len_total = len(context[0]) + len(answer[0]) + 3 244 | if len_total > max_seq_length: 245 | take_away_from_ctx = min((len_total - max_seq_length + 1) // 2, max(len(context) - 32, 0)) 246 | take_away_from_answer = len_total - max_seq_length + take_away_from_ctx 247 | context = (context[0][take_away_from_ctx:], 248 | context[1][take_away_from_ctx:]) 249 | answer = (answer[0][take_away_from_answer:], 250 | answer[1][take_away_from_answer:]) 251 | 252 | print("FOR Q{} A {}\nLTotal was {} so take away {} from ctx and {} from answer".format( 253 | ' '.join(context[0]), ' '.join(answer[0]), len_total, take_away_from_ctx, 254 | take_away_from_answer), flush=True) 255 | assert len(context[0]) + len(answer[0]) + 3 <= max_seq_length 256 | 257 | return InputExample(unique_id=counter, 258 | text_a=context[0], 259 | text_b=answer[0], is_correct=is_correct), context[1], answer[1] 260 | 261 | 262 | def data_iter(data_fn, tokenizer, max_seq_length, endingonly): 263 | counter = 0 264 | with open(data_fn, 'r') as f: 265 | for line_no, line in enumerate(tqdm(f)): 266 | item = json.loads(line) 267 | q_tokens, a_tokens = fix_item(item, rationales=False) 268 | qa_tokens, r_tokens = fix_item(item, 
answer_label=item['answer_label'], rationales=True) 269 | 270 | for (name, ctx, answers) in (('qa', q_tokens, a_tokens), ('qar', qa_tokens, r_tokens)): 271 | for i in range(4): 272 | is_correct = item['answer_label' if name == 'qa' else 'rationale_label'] == i 273 | 274 | yield process_ctx_ans_for_bert(ctx, answers[i], tokenizer, counter=counter, endingonly=endingonly, 275 | max_seq_length=max_seq_length, is_correct=is_correct) 276 | counter += 1 277 | 278 | 279 | def data_iter_test(data_fn, tokenizer, max_seq_length, endingonly): 280 | """ Essentially this needs to be a bit separate from data_iter because we don't know which answer is correct.""" 281 | counter = 0 282 | with open(data_fn, 'r') as f: 283 | for line_no, line in enumerate(tqdm(f)): 284 | item = json.loads(line) 285 | q_tokens, a_tokens = fix_item(item, rationales=False) 286 | 287 | # First yield the answers 288 | for i in range(4): 289 | yield process_ctx_ans_for_bert(q_tokens, a_tokens[i], tokenizer, counter=counter, endingonly=endingonly, 290 | max_seq_length=max_seq_length, is_correct=False) 291 | counter += 1 292 | 293 | for i in range(4): 294 | qa_tokens, r_tokens = fix_item(item, answer_label=i, rationales=True) 295 | for r_token in r_tokens: 296 | yield process_ctx_ans_for_bert(qa_tokens, r_token, tokenizer, counter=counter, 297 | endingonly=endingonly, 298 | max_seq_length=max_seq_length, is_correct=False) 299 | counter += 1 300 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rowanz/r2c/77813d9e335711759c25df79c348a7c2a8275d72/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/bert_field.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | import textwrap 3 | 4 | from overrides import overrides 5 | from spacy.tokens import Token as SpacyToken 6 | import torch 7 | 8 | from allennlp.common.checks import ConfigurationError 9 | from allennlp.data.fields.sequence_field import SequenceField 10 | from allennlp.data.tokenizers.token import Token 11 | from allennlp.data.token_indexers.token_indexer import TokenIndexer, TokenType 12 | from allennlp.data.vocabulary import Vocabulary 13 | from allennlp.nn import util 14 | import numpy 15 | TokenList = List[TokenType] # pylint: disable=invalid-name 16 | 17 | 18 | # This will work for anything really 19 | class BertField(SequenceField[Dict[str, torch.Tensor]]): 20 | """ 21 | A class representing an array, which could have arbitrary dimensions. 22 | A batch of these arrays are padded to the max dimension length in the batch 23 | for each dimension. 
24 | """ 25 | def __init__(self, tokens: List[Token], embs: numpy.ndarray, padding_value: int = 0, 26 | token_indexers=None) -> None: 27 | self.tokens = tokens 28 | self.embs = embs 29 | self.padding_value = padding_value 30 | 31 | if len(self.tokens) != self.embs.shape[0]: 32 | raise ValueError("The tokens you passed into the BERTField, {} " 33 | "aren't the same size as the embeddings of shape {}".format(self.tokens, self.embs.shape)) 34 | assert len(self.tokens) == self.embs.shape[0] 35 | 36 | @overrides 37 | def sequence_length(self) -> int: 38 | return len(self.tokens) 39 | 40 | 41 | @overrides 42 | def get_padding_lengths(self) -> Dict[str, int]: 43 | return {'num_tokens': self.sequence_length()} 44 | 45 | @overrides 46 | def as_tensor(self, padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]: 47 | num_tokens = padding_lengths['num_tokens'] 48 | 49 | new_arr = numpy.ones((num_tokens, self.embs.shape[1]), 50 | dtype=numpy.float32) * self.padding_value 51 | new_arr[:self.sequence_length()] = self.embs 52 | 53 | tensor = torch.from_numpy(new_arr) 54 | return {'bert': tensor} 55 | 56 | @overrides 57 | def empty_field(self): 58 | return BertField([], numpy.array([], dtype="float32"),padding_value=self.padding_value) 59 | 60 | @overrides 61 | def batch_tensors(self, tensor_list: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: 62 | # pylint: disable=no-self-use 63 | # This is creating a dict of {token_indexer_key: batch_tensor} for each token indexer used 64 | # to index this field. 65 | return util.batch_tensor_dicts(tensor_list) 66 | 67 | 68 | def __str__(self) -> str: 69 | return f"BertField: {self.tokens} and {self.embs.shape}." 70 | -------------------------------------------------------------------------------- /dataloaders/box_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import scipy 5 | import warnings 6 | from torchvision.datasets.folder import default_loader 7 | from torchvision.transforms import functional 8 | from config import USE_IMAGENET_PRETRAINED 9 | 10 | 11 | ##### Image 12 | def load_image(img_fn): 13 | """Load the specified image and return a [H,W,3] Numpy array. 14 | """ 15 | return default_loader(img_fn) 16 | # # Load image 17 | # image = skimage.io.imread(img_fn) 18 | # # If grayscale. Convert to RGB for consistency. 19 | # if image.ndim != 3: 20 | # image = skimage.color.gray2rgb(image) 21 | # # If has an alpha channel, remove it for consistency 22 | # if image.shape[-1] == 4: 23 | # image = image[..., :3] 24 | # return image 25 | 26 | 27 | # Let's do 16x9 28 | # Two common resolutions: 16x9 and 16/6 -> go to 16x8 as that's simple 29 | # let's say width is 576. for neural motifs it was 576*576 pixels so 331776. here we have 2x*x = 331776-> 408 base 30 | # so the best thing that's divisible by 4 is 384. that's 31 | def resize_image(image, desired_width=768, desired_height=384, random_pad=False): 32 | """Resizes an image keeping the aspect ratio mostly unchanged. 33 | 34 | Returns: 35 | image: the resized image 36 | window: (x1, y1, x2, y2). If max_dim is provided, padding might 37 | be inserted in the returned image. If so, this window is the 38 | coordinates of the image part of the full image (excluding 39 | the padding). The x2, y2 pixels are not included. 
40 | scale: The scale factor used to resize the image 41 | padding: Padding added to the image [left, top, right, bottom] 42 | """ 43 | # Default window (x1, y1, x2, y2) and default scale == 1. 44 | w, h = image.size 45 | 46 | width_scale = desired_width / w 47 | height_scale = desired_height / h 48 | scale = min(width_scale, height_scale) 49 | 50 | # Resize image using bilinear interpolation 51 | if scale != 1: 52 | image = functional.resize(image, (round(h * scale), round(w * scale))) 53 | w, h = image.size 54 | y_pad = desired_height - h 55 | x_pad = desired_width - w 56 | top_pad = random.randint(0, y_pad) if random_pad else y_pad // 2 57 | left_pad = random.randint(0, x_pad) if random_pad else x_pad // 2 58 | 59 | padding = (left_pad, top_pad, x_pad - left_pad, y_pad - top_pad) 60 | assert all([x >= 0 for x in padding]) 61 | image = functional.pad(image, padding) 62 | window = [left_pad, top_pad, w + left_pad, h + top_pad] 63 | 64 | return image, window, scale, padding 65 | 66 | 67 | if USE_IMAGENET_PRETRAINED: 68 | def to_tensor_and_normalize(image): 69 | return functional.normalize(functional.to_tensor(image), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 70 | else: 71 | # For COCO pretrained 72 | def to_tensor_and_normalize(image): 73 | tensor255 = functional.to_tensor(image) * 255 74 | return functional.normalize(tensor255, mean=(102.9801, 115.9465, 122.7717), std=(1, 1, 1)) -------------------------------------------------------------------------------- /dataloaders/cocoontology.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": { 3 | "name": "person", 4 | "supercategory": "person" 5 | }, 6 | "2": { 7 | "name": "bicycle", 8 | "supercategory": "vehicle" 9 | }, 10 | "3": { 11 | "name": "car", 12 | "supercategory": "vehicle" 13 | }, 14 | "4": { 15 | "name": "motorcycle", 16 | "supercategory": "vehicle" 17 | }, 18 | "5": { 19 | "name": "airplane", 20 | "supercategory": "vehicle" 21 | }, 22 | "6": { 23 | "name": "bus", 24 | "supercategory": "vehicle" 25 | }, 26 | "7": { 27 | "name": "train", 28 | "supercategory": "vehicle" 29 | }, 30 | "8": { 31 | "name": "truck", 32 | "supercategory": "vehicle" 33 | }, 34 | "9": { 35 | "name": "boat", 36 | "supercategory": "vehicle" 37 | }, 38 | "10": { 39 | "name": "trafficlight", 40 | "supercategory": "furniture" 41 | }, 42 | "11": { 43 | "name": "firehydrant", 44 | "supercategory": "furniture" 45 | }, 46 | "13": { 47 | "name": "stopsign", 48 | "supercategory": "furniture" 49 | }, 50 | "14": { 51 | "name": "parkingmeter", 52 | "supercategory": "furniture" 53 | }, 54 | "15": { 55 | "name": "bench", 56 | "supercategory": "furniture" 57 | }, 58 | "16": { 59 | "name": "bird", 60 | "supercategory": "animal" 61 | }, 62 | "17": { 63 | "name": "cat", 64 | "supercategory": "animal" 65 | }, 66 | "18": { 67 | "name": "dog", 68 | "supercategory": "animal" 69 | }, 70 | "19": { 71 | "name": "horse", 72 | "supercategory": "animal" 73 | }, 74 | "20": { 75 | "name": "sheep", 76 | "supercategory": "animal" 77 | }, 78 | "21": { 79 | "name": "cow", 80 | "supercategory": "animal" 81 | }, 82 | "22": { 83 | "name": "elephant", 84 | "supercategory": "animal" 85 | }, 86 | "23": { 87 | "name": "bear", 88 | "supercategory": "animal" 89 | }, 90 | "24": { 91 | "name": "zebra", 92 | "supercategory": "animal" 93 | }, 94 | "25": { 95 | "name": "giraffe", 96 | "supercategory": "animal" 97 | }, 98 | "27": { 99 | "name": "backpack", 100 | "supercategory": "accessory" 101 | }, 102 | "28": { 103 | "name": "umbrella", 104 | 
"supercategory": "accessory" 105 | }, 106 | "31": { 107 | "name": "handbag", 108 | "supercategory": "accessory" 109 | }, 110 | "32": { 111 | "name": "tie", 112 | "supercategory": "accessory" 113 | }, 114 | "33": { 115 | "name": "suitcase", 116 | "supercategory": "accessory" 117 | }, 118 | "34": { 119 | "name": "frisbee", 120 | "supercategory": "object" 121 | }, 122 | "35": { 123 | "name": "skis", 124 | "supercategory": "object" 125 | }, 126 | "36": { 127 | "name": "snowboard", 128 | "supercategory": "object" 129 | }, 130 | "37": { 131 | "name": "sportsball", 132 | "supercategory": "object" 133 | }, 134 | "38": { 135 | "name": "kite", 136 | "supercategory": "object" 137 | }, 138 | "39": { 139 | "name": "baseballbat", 140 | "supercategory": "object" 141 | }, 142 | "40": { 143 | "name": "baseballglove", 144 | "supercategory": "object" 145 | }, 146 | "41": { 147 | "name": "skateboard", 148 | "supercategory": "object" 149 | }, 150 | "42": { 151 | "name": "surfboard", 152 | "supercategory": "object" 153 | }, 154 | "43": { 155 | "name": "tennisracket", 156 | "supercategory": "object" 157 | }, 158 | "44": { 159 | "name": "bottle", 160 | "supercategory": "object" 161 | }, 162 | "46": { 163 | "name": "wineglass", 164 | "supercategory": "object" 165 | }, 166 | "47": { 167 | "name": "cup", 168 | "supercategory": "object" 169 | }, 170 | "48": { 171 | "name": "fork", 172 | "supercategory": "object" 173 | }, 174 | "49": { 175 | "name": "knife", 176 | "supercategory": "object" 177 | }, 178 | "50": { 179 | "name": "spoon", 180 | "supercategory": "object" 181 | }, 182 | "51": { 183 | "name": "bowl", 184 | "supercategory": "object" 185 | }, 186 | "52": { 187 | "name": "banana", 188 | "supercategory": "food" 189 | }, 190 | "53": { 191 | "name": "apple", 192 | "supercategory": "food" 193 | }, 194 | "54": { 195 | "name": "sandwich", 196 | "supercategory": "food" 197 | }, 198 | "55": { 199 | "name": "orange", 200 | "supercategory": "food" 201 | }, 202 | "56": { 203 | "name": "broccoli", 204 | "supercategory": "food" 205 | }, 206 | "57": { 207 | "name": "carrot", 208 | "supercategory": "food" 209 | }, 210 | "58": { 211 | "name": "hotdog", 212 | "supercategory": "food" 213 | }, 214 | "59": { 215 | "name": "pizza", 216 | "supercategory": "food" 217 | }, 218 | "60": { 219 | "name": "donut", 220 | "supercategory": "food" 221 | }, 222 | "61": { 223 | "name": "cake", 224 | "supercategory": "food" 225 | }, 226 | "62": { 227 | "name": "chair", 228 | "supercategory": "furniture" 229 | }, 230 | "63": { 231 | "name": "couch", 232 | "supercategory": "furniture" 233 | }, 234 | "64": { 235 | "name": "pottedplant", 236 | "supercategory": "furniture" 237 | }, 238 | "65": { 239 | "name": "bed", 240 | "supercategory": "furniture" 241 | }, 242 | "67": { 243 | "name": "diningtable", 244 | "supercategory": "furniture" 245 | }, 246 | "70": { 247 | "name": "toilet", 248 | "supercategory": "furniture" 249 | }, 250 | "72": { 251 | "name": "tv", 252 | "supercategory": "object" 253 | }, 254 | "73": { 255 | "name": "laptop", 256 | "supercategory": "object" 257 | }, 258 | "74": { 259 | "name": "mouse", 260 | "supercategory": "object" 261 | }, 262 | "75": { 263 | "name": "remote", 264 | "supercategory": "object" 265 | }, 266 | "76": { 267 | "name": "keyboard", 268 | "supercategory": "object" 269 | }, 270 | "77": { 271 | "name": "cellphone", 272 | "supercategory": "object" 273 | }, 274 | "78": { 275 | "name": "microwave", 276 | "supercategory": "object" 277 | }, 278 | "79": { 279 | "name": "oven", 280 | "supercategory": "object" 281 | }, 282 | 
"80": { 283 | "name": "toaster", 284 | "supercategory": "object" 285 | }, 286 | "81": { 287 | "name": "sink", 288 | "supercategory": "object" 289 | }, 290 | "82": { 291 | "name": "refrigerator", 292 | "supercategory": "object" 293 | }, 294 | "84": { 295 | "name": "book", 296 | "supercategory": "object" 297 | }, 298 | "85": { 299 | "name": "clock", 300 | "supercategory": "object" 301 | }, 302 | "86": { 303 | "name": "vase", 304 | "supercategory": "object" 305 | }, 306 | "87": { 307 | "name": "scissors", 308 | "supercategory": "object" 309 | }, 310 | "88": { 311 | "name": "teddybear", 312 | "supercategory": "object" 313 | }, 314 | "89": { 315 | "name": "hairdrier", 316 | "supercategory": "object" 317 | }, 318 | "90": { 319 | "name": "toothbrush", 320 | "supercategory": "object" 321 | } 322 | } -------------------------------------------------------------------------------- /dataloaders/mask_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | from matplotlib import path 4 | matplotlib.use('agg') 5 | 6 | 7 | def _spaced_points(low, high,n): 8 | """ We want n points between low and high, but we don't want them to touch either side""" 9 | padding = (high-low)/(n*2) 10 | return np.linspace(low + padding, high-padding, num=n) 11 | 12 | def make_mask(mask_size, box, polygons_list): 13 | """ 14 | Mask size: int about how big mask will be 15 | box: [x1, y1, x2, y2, conf.] 16 | polygons_list: List of polygons that go inside the box 17 | """ 18 | mask = np.zeros((mask_size, mask_size), dtype=np.bool) 19 | 20 | xy = np.meshgrid(_spaced_points(box[0], box[2], n=mask_size), 21 | _spaced_points(box[1], box[3], n=mask_size)) 22 | xy_flat = np.stack(xy, 2).reshape((-1, 2)) 23 | 24 | for polygon in polygons_list: 25 | polygon_path = path.Path(polygon) 26 | mask |= polygon_path.contains_points(xy_flat).reshape((mask_size, mask_size)) 27 | return mask.astype(np.float32) 28 | # 29 | #from matplotlib import pyplot as plt 30 | # 31 | # 32 | #with open('XdtbL0dP0X0@44.json', 'r') as f: 33 | # metadata = json.load(f) 34 | #from time import time 35 | #s = time() 36 | #for i in range(100): 37 | # mask = make_mask(14, metadata['boxes'][3], metadata['segms'][3]) 38 | #print("Elapsed {:3f}s".format(time()-s)) 39 | #plt.imshow(mask) -------------------------------------------------------------------------------- /dataloaders/vcr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloaders for VCR 3 | """ 4 | import json 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | from allennlp.data.dataset import Batch 10 | from allennlp.data.fields import TextField, ListField, LabelField, SequenceLabelField, ArrayField, MetadataField 11 | from allennlp.data.instance import Instance 12 | from allennlp.data.token_indexers import ELMoTokenCharactersIndexer 13 | from allennlp.data.tokenizers import Token 14 | from allennlp.data.vocabulary import Vocabulary 15 | from allennlp.nn.util import get_text_field_mask 16 | from torch.utils.data import Dataset 17 | from dataloaders.box_utils import load_image, resize_image, to_tensor_and_normalize 18 | from dataloaders.mask_utils import make_mask 19 | from dataloaders.bert_field import BertField 20 | import h5py 21 | from copy import deepcopy 22 | from config import VCR_IMAGES_DIR, VCR_ANNOTS_DIR 23 | 24 | GENDER_NEUTRAL_NAMES = ['Casey', 'Riley', 'Jessie', 'Jackie', 'Avery', 'Jaime', 'Peyton', 'Kerry', 'Jody', 'Kendall', 25 | 'Peyton', 'Skyler', 
'Frankie', 'Pat', 'Quinn'] 26 | 27 | 28 | # Here's an example jsonl 29 | # { 30 | # "movie": "3015_CHARLIE_ST_CLOUD", 31 | # "objects": ["person", "person", "person", "car"], 32 | # "interesting_scores": [0], 33 | # "answer_likelihood": "possible", 34 | # "img_fn": "lsmdc_3015_CHARLIE_ST_CLOUD/3015_CHARLIE_ST_CLOUD_00.23.57.935-00.24.00.783@0.jpg", 35 | # "metadata_fn": "lsmdc_3015_CHARLIE_ST_CLOUD/3015_CHARLIE_ST_CLOUD_00.23.57.935-00.24.00.783@0.json", 36 | # "answer_orig": "No she does not", 37 | # "question_orig": "Does 3 feel comfortable?", 38 | # "rationale_orig": "She is standing with her arms crossed and looks disturbed", 39 | # "question": ["Does", [2], "feel", "comfortable", "?"], 40 | # "answer_match_iter": [3, 0, 2, 1], 41 | # "answer_sources": [3287, 0, 10184, 2260], 42 | # "answer_choices": [ 43 | # ["Yes", "because", "the", "person", "sitting", "next", "to", "her", "is", "smiling", "."], 44 | # ["No", "she", "does", "not", "."], 45 | # ["Yes", ",", "she", "is", "wearing", "something", "with", "thin", "straps", "."], 46 | # ["Yes", ",", "she", "is", "cold", "."]], 47 | # "answer_label": 1, 48 | # "rationale_choices": [ 49 | # ["There", "is", "snow", "on", "the", "ground", ",", "and", 50 | # "she", "is", "wearing", "a", "coat", "and", "hate", "."], 51 | # ["She", "is", "standing", "with", "her", "arms", "crossed", "and", "looks", "disturbed", "."], 52 | # ["She", "is", "sitting", "very", "rigidly", "and", "tensely", "on", "the", "edge", "of", "the", 53 | # "bed", ".", "her", "posture", "is", "not", "relaxed", "and", "her", "face", "looks", "serious", "."], 54 | # [[2], "is", "laying", "in", "bed", "but", "not", "sleeping", ".", 55 | # "she", "looks", "sad", "and", "is", "curled", "into", "a", "ball", "."]], 56 | # "rationale_sources": [1921, 0, 9750, 25743], 57 | # "rationale_match_iter": [3, 0, 2, 1], 58 | # "rationale_label": 1, 59 | # "img_id": "train-0", 60 | # "question_number": 0, 61 | # "annot_id": "train-0", 62 | # "match_fold": "train-0", 63 | # "match_index": 0, 64 | # } 65 | 66 | def _fix_tokenization(tokenized_sent, bert_embs, old_det_to_new_ind, obj_to_type, token_indexers, pad_ind=-1): 67 | """ 68 | Turn a detection list into what we want: some text, as well as some tags. 69 | :param tokenized_sent: Tokenized sentence with detections collapsed to a list. 70 | :param old_det_to_new_ind: Mapping of the old ID -> new ID (which will be used as the tag) 71 | :param obj_to_type: [person, person, pottedplant] indexed by the old labels 72 | :return: tokenized sentence 73 | """ 74 | 75 | new_tokenization_with_tags = [] 76 | for tok in tokenized_sent: 77 | if isinstance(tok, list): 78 | for int_name in tok: 79 | obj_type = obj_to_type[int_name] 80 | new_ind = old_det_to_new_ind[int_name] 81 | if new_ind < 0: 82 | raise ValueError("Oh no, the new index is negative! that means it's invalid. 
{} {}".format( 83 | tokenized_sent, old_det_to_new_ind 84 | )) 85 | text_to_use = GENDER_NEUTRAL_NAMES[ 86 | new_ind % len(GENDER_NEUTRAL_NAMES)] if obj_type == 'person' else obj_type 87 | new_tokenization_with_tags.append((text_to_use, new_ind)) 88 | else: 89 | new_tokenization_with_tags.append((tok, pad_ind)) 90 | 91 | text_field = BertField([Token(x[0]) for x in new_tokenization_with_tags], 92 | bert_embs, 93 | padding_value=0) 94 | tags = SequenceLabelField([x[1] for x in new_tokenization_with_tags], text_field) 95 | return text_field, tags 96 | 97 | 98 | class VCR(Dataset): 99 | def __init__(self, split, mode, only_use_relevant_dets=True, add_image_as_a_box=True, embs_to_load='bert_da', 100 | conditioned_answer_choice=0): 101 | """ 102 | 103 | :param split: train, val, or test 104 | :param mode: answer or rationale 105 | :param only_use_relevant_dets: True, if we will only use the detections mentioned in the question and answer. 106 | False, if we should use all detections. 107 | :param add_image_as_a_box: True to add the image in as an additional 'detection'. It'll go first in the list 108 | of objects. 109 | :param embs_to_load: Which precomputed embeddings to load. 110 | :param conditioned_answer_choice: If you're in test mode, the answer labels aren't provided, which could be 111 | a problem for the QA->R task. Pass in 'conditioned_answer_choice=i' 112 | to always condition on the i-th answer. 113 | """ 114 | self.split = split 115 | self.mode = mode 116 | self.only_use_relevant_dets = only_use_relevant_dets 117 | print("Only relevant dets" if only_use_relevant_dets else "Using all detections", flush=True) 118 | 119 | self.add_image_as_a_box = add_image_as_a_box 120 | self.conditioned_answer_choice = conditioned_answer_choice 121 | 122 | with open(os.path.join(VCR_ANNOTS_DIR, '{}.jsonl'.format(split)), 'r') as f: 123 | self.items = [json.loads(s) for s in f] 124 | 125 | if split not in ('test', 'train', 'val'): 126 | raise ValueError("Mode must be in test, train, or val. Supplied {}".format(mode)) 127 | 128 | if mode not in ('answer', 'rationale'): 129 | raise ValueError("split must be answer or rationale") 130 | 131 | self.token_indexers = {'elmo': ELMoTokenCharactersIndexer()} 132 | self.vocab = Vocabulary() 133 | 134 | with open(os.path.join(os.path.dirname(VCR_ANNOTS_DIR), 'dataloaders', 'cocoontology.json'), 'r') as f: 135 | coco = json.load(f) 136 | self.coco_objects = ['__background__'] + [x['name'] for k, x in sorted(coco.items(), key=lambda x: int(x[0]))] 137 | self.coco_obj_to_ind = {o: i for i, o in enumerate(self.coco_objects)} 138 | 139 | self.embs_to_load = embs_to_load 140 | self.h5fn = os.path.join(VCR_ANNOTS_DIR, f'{self.embs_to_load}_{self.mode}_{self.split}.h5') 141 | print("Loading embeddings from {}".format(self.h5fn), flush=True) 142 | 143 | @property 144 | def is_train(self): 145 | return self.split == 'train' 146 | 147 | @classmethod 148 | def splits(cls, **kwargs): 149 | """ Helper method to generate splits of the dataset""" 150 | kwargs_copy = {x: y for x, y in kwargs.items()} 151 | if 'mode' not in kwargs: 152 | kwargs_copy['mode'] = 'answer' 153 | train = cls(split='train', **kwargs_copy) 154 | val = cls(split='val', **kwargs_copy) 155 | test = cls(split='test', **kwargs_copy) 156 | return train, val, test 157 | 158 | @classmethod 159 | def eval_splits(cls, **kwargs): 160 | """ Helper method to generate splits of the dataset. 
Use this for testing, because it will 161 | condition on everything.""" 162 | for forbidden_key in ['mode', 'split', 'conditioned_answer_choice']: 163 | if forbidden_key in kwargs: 164 | raise ValueError(f"don't supply {forbidden_key} to eval_splits()") 165 | 166 | stuff_to_return = [cls(split='test', mode='answer', **kwargs)] + [ 167 | cls(split='test', mode='rationale', conditioned_answer_choice=i, **kwargs) for i in range(4)] 168 | return tuple(stuff_to_return) 169 | 170 | def __len__(self): 171 | return len(self.items) 172 | 173 | def _get_dets_to_use(self, item): 174 | """ 175 | We might want to use fewer detectiosn so lets do so. 176 | :param item: 177 | :param question: 178 | :param answer_choices: 179 | :return: 180 | """ 181 | # Load questions and answers 182 | question = item['question'] 183 | answer_choices = item['{}_choices'.format(self.mode)] 184 | 185 | if self.only_use_relevant_dets: 186 | dets2use = np.zeros(len(item['objects']), dtype=bool) 187 | people = np.array([x == 'person' for x in item['objects']], dtype=bool) 188 | for sent in answer_choices + [question]: 189 | for possibly_det_list in sent: 190 | if isinstance(possibly_det_list, list): 191 | for tag in possibly_det_list: 192 | if tag >= 0 and tag < len(item['objects']): # sanity check 193 | dets2use[tag] = True 194 | elif possibly_det_list.lower() in ('everyone', 'everyones'): 195 | dets2use |= people 196 | if not dets2use.any(): 197 | dets2use |= people 198 | else: 199 | dets2use = np.ones(len(item['objects']), dtype=bool) 200 | 201 | # we will use these detections 202 | dets2use = np.where(dets2use)[0] 203 | 204 | old_det_to_new_ind = np.zeros(len(item['objects']), dtype=np.int32) - 1 205 | old_det_to_new_ind[dets2use] = np.arange(dets2use.shape[0], dtype=np.int32) 206 | 207 | # If we add the image as an extra box then the 0th will be the image. 208 | if self.add_image_as_a_box: 209 | old_det_to_new_ind[dets2use] += 1 210 | old_det_to_new_ind = old_det_to_new_ind.tolist() 211 | return dets2use, old_det_to_new_ind 212 | 213 | def __getitem__(self, index): 214 | # if self.split == 'test': 215 | # raise ValueError("blind test mode not supported quite yet") 216 | item = deepcopy(self.items[index]) 217 | 218 | ################################################################### 219 | # Load questions and answers 220 | if self.mode == 'rationale': 221 | conditioned_label = item['answer_label'] if self.split != 'test' else self.conditioned_answer_choice 222 | item['question'] += item['answer_choices'][conditioned_label] 223 | 224 | answer_choices = item['{}_choices'.format(self.mode)] 225 | dets2use, old_det_to_new_ind = self._get_dets_to_use(item) 226 | 227 | ################################################################### 228 | # Load in BERT. We'll get contextual representations of the context and the answer choices 229 | # grp_items = {k: np.array(v, dtype=np.float16) for k, v in self.get_h5_group(index).items()} 230 | with h5py.File(self.h5fn, 'r') as h5: 231 | grp_items = {k: np.array(v, dtype=np.float16) for k, v in h5[str(index)].items()} 232 | 233 | # Essentially we need to condition on the right answer choice here, if we're doing QA->R. 
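# (At train/val time the gold answer_label has already been appended to the question in the block above, so an explicit conditioning choice only matters for the blind test split.)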
We will always 234 | # condition on the `conditioned_answer_choice.` 235 | condition_key = self.conditioned_answer_choice if self.split == "test" and self.mode == "rationale" else "" 236 | 237 | instance_dict = {} 238 | if 'endingonly' not in self.embs_to_load: 239 | questions_tokenized, question_tags = zip(*[_fix_tokenization( 240 | item['question'], 241 | grp_items[f'ctx_{self.mode}{condition_key}{i}'], 242 | old_det_to_new_ind, 243 | item['objects'], 244 | token_indexers=self.token_indexers, 245 | pad_ind=0 if self.add_image_as_a_box else -1 246 | ) for i in range(4)]) 247 | instance_dict['question'] = ListField(questions_tokenized) 248 | instance_dict['question_tags'] = ListField(question_tags) 249 | 250 | answers_tokenized, answer_tags = zip(*[_fix_tokenization( 251 | answer, 252 | grp_items[f'answer_{self.mode}{condition_key}{i}'], 253 | old_det_to_new_ind, 254 | item['objects'], 255 | token_indexers=self.token_indexers, 256 | pad_ind=0 if self.add_image_as_a_box else -1 257 | ) for i, answer in enumerate(answer_choices)]) 258 | 259 | instance_dict['answers'] = ListField(answers_tokenized) 260 | instance_dict['answer_tags'] = ListField(answer_tags) 261 | if self.split != 'test': 262 | instance_dict['label'] = LabelField(item['{}_label'.format(self.mode)], skip_indexing=True) 263 | instance_dict['metadata'] = MetadataField({'annot_id': item['annot_id'], 'ind': index, 'movie': item['movie'], 264 | 'img_fn': item['img_fn'], 265 | 'question_number': item['question_number']}) 266 | 267 | ################################################################### 268 | # Load image now and rescale it. Might have to subtract the mean and whatnot here too. 269 | image = load_image(os.path.join(VCR_IMAGES_DIR, item['img_fn'])) 270 | image, window, img_scale, padding = resize_image(image, random_pad=self.is_train) 271 | image = to_tensor_and_normalize(image) 272 | c, h, w = image.shape 273 | 274 | ################################################################### 275 | # Load boxes. 276 | with open(os.path.join(VCR_IMAGES_DIR, item['metadata_fn']), 'r') as f: 277 | metadata = json.load(f) 278 | 279 | # [nobj, 14, 14] 280 | segms = np.stack([make_mask(mask_size=14, box=metadata['boxes'][i], polygons_list=metadata['segms'][i]) 281 | for i in dets2use]) 282 | 283 | # Chop off the final dimension, that's the confidence 284 | boxes = np.array(metadata['boxes'])[dets2use, :-1] 285 | # Possibly rescale them if necessary 286 | boxes *= img_scale 287 | boxes[:, :2] += np.array(padding[:2])[None] 288 | boxes[:, 2:] += np.array(padding[:2])[None] 289 | obj_labels = [self.coco_obj_to_ind[item['objects'][i]] for i in dets2use.tolist()] 290 | if self.add_image_as_a_box: 291 | boxes = np.row_stack((window, boxes)) 292 | segms = np.concatenate((np.ones((1, 14, 14), dtype=np.float32), segms), 0) 293 | obj_labels = [self.coco_obj_to_ind['__background__']] + obj_labels 294 | 295 | instance_dict['segms'] = ArrayField(segms, padding_value=0) 296 | instance_dict['objects'] = ListField([LabelField(x, skip_indexing=True) for x in obj_labels]) 297 | 298 | if not np.all((boxes[:, 0] >= 0.) & (boxes[:, 0] < boxes[:, 2])): 299 | import ipdb 300 | ipdb.set_trace() 301 | assert np.all((boxes[:, 1] >= 0.) 
& (boxes[:, 1] < boxes[:, 3])) 302 | assert np.all((boxes[:, 2] <= w)) 303 | assert np.all((boxes[:, 3] <= h)) 304 | instance_dict['boxes'] = ArrayField(boxes, padding_value=-1) 305 | 306 | instance = Instance(instance_dict) 307 | instance.index_fields(self.vocab) 308 | return image, instance 309 | 310 | 311 | def collate_fn(data, to_gpu=False): 312 | """Creates mini-batch tensors 313 | """ 314 | images, instances = zip(*data) 315 | images = torch.stack(images, 0) 316 | batch = Batch(instances) 317 | td = batch.as_tensor_dict() 318 | if 'question' in td: 319 | td['question_mask'] = get_text_field_mask(td['question'], num_wrapping_dims=1) 320 | td['question_tags'][td['question_mask'] == 0] = -2 # Padding 321 | 322 | td['answer_mask'] = get_text_field_mask(td['answers'], num_wrapping_dims=1) 323 | td['answer_tags'][td['answer_mask'] == 0] = -2 324 | 325 | td['box_mask'] = torch.all(td['boxes'] >= 0, -1).long() 326 | td['images'] = images 327 | 328 | # Deprecated 329 | # if to_gpu: 330 | # for k in td: 331 | # if k != 'metadata': 332 | # td[k] = {k2: v.cuda(non_blocking=True) for k2, v in td[k].items()} if isinstance(td[k], dict) else td[k].cuda( 333 | # non_blocking=True) 334 | 335 | # # No nested dicts 336 | # for k in sorted(td.keys()): 337 | # if isinstance(td[k], dict): 338 | # for k2 in sorted(td[k].keys()): 339 | # td['{}_{}'.format(k, k2)] = td[k].pop(k2) 340 | # td.pop(k) 341 | 342 | return td 343 | 344 | 345 | class VCRLoader(torch.utils.data.DataLoader): 346 | """ 347 | Iterates through the data, filtering out None, 348 | but also loads everything as a (cuda) variable 349 | """ 350 | 351 | @classmethod 352 | def from_dataset(cls, data, batch_size=3, num_workers=6, num_gpus=3, **kwargs): 353 | loader = cls( 354 | dataset=data, 355 | batch_size=batch_size * num_gpus, 356 | shuffle=data.is_train, 357 | num_workers=num_workers, 358 | collate_fn=lambda x: collate_fn(x, to_gpu=False), 359 | drop_last=data.is_train, 360 | pin_memory=False, 361 | **kwargs, 362 | ) 363 | return loader 364 | 365 | # You could use this for debugging maybe 366 | # if __name__ == '__main__': 367 | # train, val, test = VCR.splits() 368 | # for i in range(len(train)): 369 | # res = train[i] 370 | # print("done with {}".format(i)) 371 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # models 2 | 3 | This folder is for `r2c` models. They broadly follow the allennlp configuration format. If you want r2c, you'll want to look at `multiatt`. 4 | 5 | ## Replicating validation results 6 | Here's how you can replicate my val results. Run the command(s) below. First, you might want to make your GPUs available. 
When I ran these experiments I used 7 | 8 | `source activate r2c && export LD_LIBRARY_PATH=/usr/local/cuda-9.0/ && export PYTHONPATH=/home/rowan/code/r2c && export CUDA_VISIBLE_DEVICES=0,1,2` 9 | 10 | - For question answering, run: 11 | ``` 12 | python train.py -params multiatt/default.json -folder saves/flagship_answer 13 | ``` 14 | 15 | - for Answer justification, run 16 | ``` 17 | python train.py -params multiatt/default.json -folder saves/flagship_rationale -rationale 18 | ``` 19 | 20 | You can combine the validation predictions using 21 | `python eval_q2ar.py -answer_preds saves/flagship_answer/valpreds.npy -rationale_preds saves/flagship_rationale/valpreds.npy` 22 | 23 | ## Submitting to the leaderboard 24 | 25 | VCR features a [leaderboard](https://visualcommonsense.com/leaderboard/) where you can submit your answers on the test set. Submitting to the leaderboard is easy! You'll need to submit something like [the example submission CSV file](https://s3-us-west-2.amazonaws.com/ai2-rowanz/r2c/example-submission.csv). You can use the `eval_for_leaderboard.py` script, which formats everything in the right way. 26 | 27 | Essentially, your submission has to have the following columns: 28 | 29 | ``` 30 | annot_id,answer_0,answer_1,answer_2,answer_3,rationale_conditioned_on_a0_0,rationale_conditioned_on_a0_1,rationale_conditioned_on_a0_2,rationale_conditioned_on_a0_3,rationale_conditioned_on_a1_0,rationale_conditioned_on_a1_1,rationale_conditioned_on_a1_2,rationale_conditioned_on_a1_3,rationale_conditioned_on_a2_0,rationale_conditioned_on_a2_1,rationale_conditioned_on_a2_2,rationale_conditioned_on_a2_3,rationale_conditioned_on_a3_0,rationale_conditioned_on_a3_1,rationale_conditioned_on_a3_2,rationale_conditioned_on_a3_3 31 | ``` 32 | 33 | To evaluate, I'll first take the argmax over the answer choices, then take the argmax over your rationale choices (conditioned on the right answers). 34 | These give two sets of predictions, which can be used to compute Q->A and QA->R accuracy. For Q->AR accuracy, we take a bitwise AND between the hits of the QA and QAR columns. In other words, to get a question right, you have to get the answer AND the rationale right. -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | from models.multiatt import model 3 | 4 | # You can add more models in this folder. 
like 5 | # from models.no_question import model 6 | # from models.no_vision_at_all import model 7 | # from models.old_model import model 8 | # from models.bottom_up_top_down import model 9 | # from models.revisiting_vqa_baseline import model 10 | # from models.mlb import model -------------------------------------------------------------------------------- /models/eval_for_leaderboard.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation script for the leaderboard 3 | """ 4 | import argparse 5 | import logging 6 | import multiprocessing 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | from allennlp.common.params import Params 12 | from allennlp.nn.util import device_mapping 13 | from torch.nn import DataParallel 14 | from torch.nn.modules import BatchNorm2d 15 | 16 | from dataloaders.vcr import VCR, VCRLoader 17 | from utils.pytorch_misc import time_batch 18 | 19 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', level=logging.DEBUG) 20 | 21 | # This is needed to make the imports work 22 | from allennlp.models import Model 23 | import models 24 | 25 | ################################# 26 | ################################# 27 | ######## Data loading stuff 28 | ################################# 29 | ################################# 30 | 31 | parser = argparse.ArgumentParser(description='train') 32 | parser.add_argument( 33 | '-params', 34 | dest='params', 35 | default='multiatt/default.json', 36 | help='Params location', 37 | type=str, 38 | ) 39 | parser.add_argument( 40 | '-answer_ckpt', 41 | dest='answer_ckpt', 42 | default='saves/flagship_answer/best.th', 43 | help='Answer checkpoint', 44 | type=str, 45 | ) 46 | parser.add_argument( 47 | '-rationale_ckpt', 48 | dest='rationale_ckpt', 49 | default='saves/flagship_rationale/best.th', 50 | help='Rationale checkpoint', 51 | type=str, 52 | ) 53 | parser.add_argument( 54 | '-output', 55 | dest='output', 56 | default='submission.csv', 57 | help='Output CSV file to save the predictions to', 58 | type=str, 59 | ) 60 | 61 | args = parser.parse_args() 62 | params = Params.from_file(args.params) 63 | 64 | NUM_GPUS = torch.cuda.device_count() 65 | NUM_CPUS = multiprocessing.cpu_count() 66 | if NUM_GPUS == 0: 67 | raise ValueError("you need gpus!") 68 | 69 | 70 | def _to_gpu(td): 71 | if NUM_GPUS > 1: 72 | return td 73 | for k in td: 74 | td[k] = {k2: v.cuda(async=True) for k2, v in td[k].items()} if isinstance(td[k], dict) else td[k].cuda( 75 | async=True) 76 | return td 77 | 78 | 79 | num_workers = (4 * NUM_GPUS if NUM_CPUS == 32 else 2 * NUM_GPUS) - 1 80 | print(f"Using {num_workers} workers out of {NUM_CPUS} possible", flush=True) 81 | loader_params = {'batch_size': 96 // NUM_GPUS, 'num_gpus': NUM_GPUS, 'num_workers': num_workers} 82 | 83 | vcr_modes = VCR.eval_splits(embs_to_load=params['dataset_reader'].get('embs', 'bert_da'), 84 | only_use_relevant_dets=params['dataset_reader'].get('only_use_relevant_dets', True)) 85 | probs_grp = [] 86 | ids_grp = [] 87 | for (vcr_dataset, mode_long) in zip(vcr_modes, ['answer'] + [f'rationale_{i}' for i in range(4)]): 88 | mode = mode_long.split('_')[0] 89 | 90 | test_loader = VCRLoader.from_dataset(vcr_dataset, **loader_params) 91 | 92 | # Load the params again because allennlp will delete them... ugh. 
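# Note: getattr(args, f'{mode}_ckpt') means the answer checkpoint is loaded once and the
# rationale checkpoint is re-loaded four times, once per conditioned answer choice produced
# by VCR.eval_splits() above.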
93 | params = Params.from_file(args.params) 94 | print("Loading {} for {}".format(params['model'].get('type', 'WTF?'), mode), flush=True) 95 | model = Model.from_params(vocab=vcr_dataset.vocab, params=params['model']) 96 | for submodule in model.detector.backbone.modules(): 97 | if isinstance(submodule, BatchNorm2d): 98 | submodule.track_running_stats = False 99 | 100 | model_state = torch.load(getattr(args, f'{mode}_ckpt'), map_location=device_mapping(-1)) 101 | model.load_state_dict(model_state) 102 | 103 | model = DataParallel(model).cuda() if NUM_GPUS > 1 else model.cuda() 104 | model.eval() 105 | 106 | test_probs = [] 107 | test_ids = [] 108 | for b, (time_per_batch, batch) in enumerate(time_batch(test_loader)): 109 | with torch.no_grad(): 110 | batch = _to_gpu(batch) 111 | output_dict = model(**batch) 112 | test_probs.append(output_dict['label_probs'].detach().cpu().numpy()) 113 | test_ids += [m['annot_id'] for m in batch['metadata']] 114 | if (b > 0) and (b % 10 == 0): 115 | print("Completed {}/{} batches in {:.3f}s".format(b, len(test_loader), time_per_batch * 10), flush=True) 116 | 117 | probs_grp.append(np.concatenate(test_probs, 0)) 118 | ids_grp.append(test_ids) 119 | 120 | ################################################################################ 121 | # This is the part you'll care about if you want to submit to the leaderboard! 122 | ################################################################################ 123 | 124 | # Double check the IDs are in the same order for everything 125 | assert [x == ids_grp[0] for x in ids_grp] 126 | 127 | probs_grp = np.stack(probs_grp, 1) 128 | # essentially probs_grp is a [num_ex, 5, 4] array of probabilities. The 5 'groups' are 129 | # [answer, rationale_conditioned_on_a0, rationale_conditioned_on_a1, 130 | # rationale_conditioned_on_a2, rationale_conditioned_on_a3]. 131 | # We will flatten this to a CSV file so it's easy to submit. 132 | group_names = ['answer'] + [f'rationale_conditioned_on_a{i}' for i in range(4)] 133 | probs_df = pd.DataFrame(data=probs_grp.reshape((-1, 20)), 134 | columns=[f'{group_name}_{i}' for group_name in group_names for i in range(4)]) 135 | probs_df['annot_id'] = ids_grp[0] 136 | probs_df = probs_df.set_index('annot_id', drop=True) 137 | probs_df.to_csv(args.output) 138 | -------------------------------------------------------------------------------- /models/eval_q2ar.py: -------------------------------------------------------------------------------- 1 | """ 2 | You can use this script to evaluate prediction files (valpreds.npy). Essentially this is needed if you want to, say, 3 | combine answer and rationale predictions. 4 | """ 5 | 6 | import numpy as np 7 | import json 8 | import os 9 | from config import VCR_ANNOTS_DIR 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser(description='Evaluate question -> answer and rationale') 13 | parser.add_argument( 14 | '-answer_preds', 15 | dest='answer_preds', 16 | default='saves/flagship_answer/valpreds.npy', 17 | help='Location of question->answer predictions', 18 | type=str, 19 | ) 20 | parser.add_argument( 21 | '-rationale_preds', 22 | dest='rationale_preds', 23 | default='saves/flagship_rationale/valpreds.npy', 24 | help='Location of question+answer->rationale predictions', 25 | type=str, 26 | ) 27 | parser.add_argument( 28 | '-split', 29 | dest='split', 30 | default='val', 31 | help='Split you\'re using. 
Probably you want val.', 32 | type=str, 33 | ) 34 | 35 | args = parser.parse_args() 36 | 37 | answer_preds = np.load(args.answer_preds) 38 | rationale_preds = np.load(args.rationale_preds) 39 | 40 | rationale_labels = [] 41 | answer_labels = [] 42 | 43 | with open(os.path.join(VCR_ANNOTS_DIR, '{}.jsonl'.format(args.split)), 'r') as f: 44 | for l in f: 45 | item = json.loads(l) 46 | answer_labels.append(item['answer_label']) 47 | rationale_labels.append(item['rationale_label']) 48 | 49 | answer_labels = np.array(answer_labels) 50 | rationale_labels = np.array(rationale_labels) 51 | 52 | # Sanity checks 53 | assert answer_preds.shape[0] == answer_labels.size 54 | assert rationale_preds.shape[0] == answer_labels.size 55 | assert answer_preds.shape[1] == 4 56 | assert rationale_preds.shape[1] == 4 57 | 58 | answer_hits = answer_preds.argmax(1) == answer_labels 59 | rationale_hits = rationale_preds.argmax(1) == rationale_labels 60 | joint_hits = answer_hits & rationale_hits 61 | 62 | print("Answer acc: {:.3f}".format(np.mean(answer_hits)), flush=True) 63 | print("Rationale acc: {:.3f}".format(np.mean(rationale_hits)), flush=True) 64 | print("Joint acc: {:.3f}".format(np.mean(answer_hits & rationale_hits)), flush=True) 65 | -------------------------------------------------------------------------------- /models/multiatt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rowanz/r2c/77813d9e335711759c25df79c348a7c2a8275d72/models/multiatt/__init__.py -------------------------------------------------------------------------------- /models/multiatt/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "vswag" 4 | }, 5 | "model": { 6 | "type": "MultiHopAttentionQA", 7 | "span_encoder": { 8 | "type": "lstm", 9 | "input_size": 1280, 10 | "hidden_size": 256, 11 | "num_layers": 1, 12 | "bidirectional": true 13 | }, 14 | "reasoning_encoder": { 15 | "type": "lstm", 16 | "input_size": 1536, 17 | "hidden_size": 256, 18 | "num_layers": 2, 19 | "bidirectional": true 20 | }, 21 | "hidden_dim_maxpool": 1024, 22 | "input_dropout": 0.3, 23 | "pool_question": true, 24 | "pool_answer": true, 25 | "initializer": [ 26 | [".*final_mlp.*weight", {"type": "xavier_uniform"}], 27 | [".*final_mlp.*bias", {"type": "zero"}], 28 | [".*weight_ih.*", {"type": "xavier_uniform"}], 29 | [".*weight_hh.*", {"type": "orthogonal"}], 30 | [".*bias_ih.*", {"type": "zero"}], 31 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}]] 32 | }, 33 | "trainer": { 34 | "optimizer": { 35 | "type": "adam", 36 | "lr": 0.0002, 37 | "weight_decay": 0.0001 38 | }, 39 | "validation_metric": "+accuracy", 40 | "num_serialized_models_to_keep": 2, 41 | "num_epochs": 20, 42 | "grad_norm": 1.0, 43 | "patience": 3, 44 | "cuda_device": 0, 45 | "learning_rate_scheduler": { 46 | "type": "reduce_on_plateau", 47 | "factor": 0.5, 48 | "mode": "max", 49 | "patience": 1, 50 | "verbose": true, 51 | "cooldown": 2 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /models/multiatt/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Let's get the relationships yo 3 | """ 4 | 5 | from typing import Dict, List, Any 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn.parallel 10 | from allennlp.data.vocabulary import Vocabulary 11 | from allennlp.models.model import Model 12 | from 
allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder, FeedForward, InputVariationalDropout, TimeDistributed 13 | from allennlp.training.metrics import CategoricalAccuracy 14 | from allennlp.modules.matrix_attention import BilinearMatrixAttention 15 | from utils.detector import SimpleDetector 16 | from allennlp.nn.util import masked_softmax, weighted_sum, replace_masked_values 17 | from allennlp.nn import InitializerApplicator 18 | 19 | @Model.register("MultiHopAttentionQA") 20 | class AttentionQA(Model): 21 | def __init__(self, 22 | vocab: Vocabulary, 23 | span_encoder: Seq2SeqEncoder, 24 | reasoning_encoder: Seq2SeqEncoder, 25 | input_dropout: float = 0.3, 26 | hidden_dim_maxpool: int = 1024, 27 | class_embs: bool=True, 28 | reasoning_use_obj: bool=True, 29 | reasoning_use_answer: bool=True, 30 | reasoning_use_question: bool=True, 31 | pool_reasoning: bool = True, 32 | pool_answer: bool = True, 33 | pool_question: bool = False, 34 | initializer: InitializerApplicator = InitializerApplicator(), 35 | ): 36 | super(AttentionQA, self).__init__(vocab) 37 | 38 | self.detector = SimpleDetector(pretrained=True, average_pool=True, semantic=class_embs, final_dim=512) 39 | ################################################################################################### 40 | 41 | self.rnn_input_dropout = TimeDistributed(InputVariationalDropout(input_dropout)) if input_dropout > 0 else None 42 | 43 | self.span_encoder = TimeDistributed(span_encoder) 44 | self.reasoning_encoder = TimeDistributed(reasoning_encoder) 45 | 46 | self.span_attention = BilinearMatrixAttention( 47 | matrix_1_dim=span_encoder.get_output_dim(), 48 | matrix_2_dim=span_encoder.get_output_dim(), 49 | ) 50 | 51 | self.obj_attention = BilinearMatrixAttention( 52 | matrix_1_dim=span_encoder.get_output_dim(), 53 | matrix_2_dim=self.detector.final_dim, 54 | ) 55 | 56 | self.reasoning_use_obj = reasoning_use_obj 57 | self.reasoning_use_answer = reasoning_use_answer 58 | self.reasoning_use_question = reasoning_use_question 59 | self.pool_reasoning = pool_reasoning 60 | self.pool_answer = pool_answer 61 | self.pool_question = pool_question 62 | dim = sum([d for d, to_pool in [(reasoning_encoder.get_output_dim(), self.pool_reasoning), 63 | (span_encoder.get_output_dim(), self.pool_answer), 64 | (span_encoder.get_output_dim(), self.pool_question)] if to_pool]) 65 | 66 | self.final_mlp = torch.nn.Sequential( 67 | torch.nn.Dropout(input_dropout, inplace=False), 68 | torch.nn.Linear(dim, hidden_dim_maxpool), 69 | torch.nn.ReLU(inplace=True), 70 | torch.nn.Dropout(input_dropout, inplace=False), 71 | torch.nn.Linear(hidden_dim_maxpool, 1), 72 | ) 73 | self._accuracy = CategoricalAccuracy() 74 | self._loss = torch.nn.CrossEntropyLoss() 75 | initializer(self) 76 | 77 | def _collect_obj_reps(self, span_tags, object_reps): 78 | """ 79 | Collect span-level object representations 80 | :param span_tags: [batch_size, ..leading_dims.., L] 81 | :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] 82 | :return: 83 | """ 84 | span_tags_fixed = torch.clamp(span_tags, min=0) # In case there were masked values here 85 | row_id = span_tags_fixed.new_zeros(span_tags_fixed.shape) 86 | row_id_broadcaster = torch.arange(0, row_id.shape[0], step=1, device=row_id.device)[:, None] 87 | 88 | # Add extra diminsions to the row broadcaster so it matches row_id 89 | leading_dims = len(span_tags.shape) - 2 90 | for i in range(leading_dims): 91 | row_id_broadcaster = row_id_broadcaster[..., None] 92 | row_id += row_id_broadcaster 93 | return 
object_reps[row_id.view(-1), span_tags_fixed.view(-1)].view(*span_tags_fixed.shape, -1) 94 | 95 | def embed_span(self, span, span_tags, span_mask, object_reps): 96 | """ 97 | :param span: Thing that will get embed and turned into [batch_size, ..leading_dims.., L, word_dim] 98 | :param span_tags: [batch_size, ..leading_dims.., L] 99 | :param object_reps: [batch_size, max_num_objs_per_batch, obj_dim] 100 | :param span_mask: [batch_size, ..leading_dims.., span_mask 101 | :return: 102 | """ 103 | retrieved_feats = self._collect_obj_reps(span_tags, object_reps) 104 | 105 | span_rep = torch.cat((span['bert'], retrieved_feats), -1) 106 | # add recurrent dropout here 107 | if self.rnn_input_dropout: 108 | span_rep = self.rnn_input_dropout(span_rep) 109 | 110 | return self.span_encoder(span_rep, span_mask), retrieved_feats 111 | 112 | def forward(self, 113 | images: torch.Tensor, 114 | objects: torch.LongTensor, 115 | segms: torch.Tensor, 116 | boxes: torch.Tensor, 117 | box_mask: torch.LongTensor, 118 | question: Dict[str, torch.Tensor], 119 | question_tags: torch.LongTensor, 120 | question_mask: torch.LongTensor, 121 | answers: Dict[str, torch.Tensor], 122 | answer_tags: torch.LongTensor, 123 | answer_mask: torch.LongTensor, 124 | metadata: List[Dict[str, Any]] = None, 125 | label: torch.LongTensor = None) -> Dict[str, torch.Tensor]: 126 | """ 127 | :param images: [batch_size, 3, im_height, im_width] 128 | :param objects: [batch_size, max_num_objects] Padded objects 129 | :param boxes: [batch_size, max_num_objects, 4] Padded boxes 130 | :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK 131 | :param question: AllenNLP representation of the question. [batch_size, num_answers, seq_length] 132 | :param question_tags: A detection label for each item in the Q [batch_size, num_answers, seq_length] 133 | :param question_mask: Mask for the Q [batch_size, num_answers, seq_length] 134 | :param answers: AllenNLP representation of the answer. [batch_size, num_answers, seq_length] 135 | :param answer_tags: A detection label for each item in the A [batch_size, num_answers, seq_length] 136 | :param answer_mask: Mask for the As [batch_size, num_answers, seq_length] 137 | :param metadata: Ignore, this is about which dataset item we're on 138 | :param label: Optional, which item is valid 139 | :return: shit 140 | """ 141 | # Trim off boxes that are too long. this is an issue b/c dataparallel, it'll pad more zeros that are 142 | # not needed 143 | max_len = int(box_mask.sum(1).max().item()) 144 | objects = objects[:, :max_len] 145 | box_mask = box_mask[:, :max_len] 146 | boxes = boxes[:, :max_len] 147 | segms = segms[:, :max_len] 148 | 149 | for tag_type, the_tags in (('question', question_tags), ('answer', answer_tags)): 150 | if int(the_tags.max()) > max_len: 151 | raise ValueError("Oh no! {}_tags has maximum of {} but objects is of dim {}. 
Values are\n{}".format( 152 | tag_type, int(the_tags.max()), objects.shape, the_tags 153 | )) 154 | 155 | obj_reps = self.detector(images=images, boxes=boxes, box_mask=box_mask, classes=objects, segms=segms) 156 | 157 | # Now get the question representations 158 | q_rep, q_obj_reps = self.embed_span(question, question_tags, question_mask, obj_reps['obj_reps']) 159 | a_rep, a_obj_reps = self.embed_span(answers, answer_tags, answer_mask, obj_reps['obj_reps']) 160 | 161 | #################################### 162 | # Perform Q by A attention 163 | # [batch_size, 4, question_length, answer_length] 164 | qa_similarity = self.span_attention( 165 | q_rep.view(q_rep.shape[0] * q_rep.shape[1], q_rep.shape[2], q_rep.shape[3]), 166 | a_rep.view(a_rep.shape[0] * a_rep.shape[1], a_rep.shape[2], a_rep.shape[3]), 167 | ).view(a_rep.shape[0], a_rep.shape[1], q_rep.shape[2], a_rep.shape[2]) 168 | qa_attention_weights = masked_softmax(qa_similarity, question_mask[..., None], dim=2) 169 | attended_q = torch.einsum('bnqa,bnqd->bnad', (qa_attention_weights, q_rep)) 170 | 171 | # Have a second attention over the objects, do A by Objs 172 | # [batch_size, 4, answer_length, num_objs] 173 | atoo_similarity = self.obj_attention(a_rep.view(a_rep.shape[0], a_rep.shape[1] * a_rep.shape[2], -1), 174 | obj_reps['obj_reps']).view(a_rep.shape[0], a_rep.shape[1], 175 | a_rep.shape[2], obj_reps['obj_reps'].shape[1]) 176 | atoo_attention_weights = masked_softmax(atoo_similarity, box_mask[:,None,None]) 177 | attended_o = torch.einsum('bnao,bod->bnad', (atoo_attention_weights, obj_reps['obj_reps'])) 178 | 179 | 180 | reasoning_inp = torch.cat([x for x, to_pool in [(a_rep, self.reasoning_use_answer), 181 | (attended_o, self.reasoning_use_obj), 182 | (attended_q, self.reasoning_use_question)] 183 | if to_pool], -1) 184 | 185 | if self.rnn_input_dropout is not None: 186 | reasoning_inp = self.rnn_input_dropout(reasoning_inp) 187 | reasoning_output = self.reasoning_encoder(reasoning_inp, answer_mask) 188 | 189 | 190 | ########################################### 191 | things_to_pool = torch.cat([x for x, to_pool in [(reasoning_output, self.pool_reasoning), 192 | (a_rep, self.pool_answer), 193 | (attended_q, self.pool_question)] if to_pool], -1) 194 | 195 | pooled_rep = replace_masked_values(things_to_pool,answer_mask[...,None], -1e7).max(2)[0] 196 | logits = self.final_mlp(pooled_rep).squeeze(2) 197 | 198 | ########################################### 199 | 200 | class_probabilities = F.softmax(logits, dim=-1) 201 | 202 | output_dict = {"label_logits": logits, "label_probs": class_probabilities, 203 | 'cnn_regularization_loss': obj_reps['cnn_regularization_loss'], 204 | # Uncomment to visualize attention, if you want 205 | # 'qa_attention_weights': qa_attention_weights, 206 | # 'atoo_attention_weights': atoo_attention_weights, 207 | } 208 | if label is not None: 209 | loss = self._loss(logits, label.long().view(-1)) 210 | self._accuracy(logits, label) 211 | output_dict["loss"] = loss[None] 212 | 213 | return output_dict 214 | 215 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 216 | return {'accuracy': self._accuracy.get_metric(reset)} 217 | -------------------------------------------------------------------------------- /models/multiatt/no_class_embs.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "vswag" 4 | }, 5 | "model": { 6 | "type": "MultiHopAttentionQA", 7 | "span_encoder": { 8 | "type": "lstm", 9 | "input_size": 1280, 10 | 
"hidden_size": 256, 11 | "num_layers": 1, 12 | "bidirectional": true 13 | }, 14 | "reasoning_encoder": { 15 | "type": "lstm", 16 | "input_size": 1536, 17 | "hidden_size": 256, 18 | "num_layers": 2, 19 | "bidirectional": true 20 | }, 21 | "hidden_dim_maxpool": 1024, 22 | "input_dropout": 0.3, 23 | "pool_question": true, 24 | "pool_answer": true, 25 | "class_embs": false, 26 | "initializer": [ 27 | [".*final_mlp.*weight", {"type": "xavier_uniform"}], 28 | [".*final_mlp.*bias", {"type": "zero"}], 29 | [".*weight_ih.*", {"type": "xavier_uniform"}], 30 | [".*weight_hh.*", {"type": "orthogonal"}], 31 | [".*bias_ih.*", {"type": "zero"}], 32 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}]] 33 | }, 34 | "trainer": { 35 | "optimizer": { 36 | "type": "adam", 37 | "lr": 0.0002, 38 | "weight_decay": 0.0001 39 | }, 40 | "validation_metric": "+accuracy", 41 | "num_serialized_models_to_keep": 2, 42 | "num_epochs": 20, 43 | "grad_norm": 1.0, 44 | "patience": 3, 45 | "cuda_device": 0, 46 | "learning_rate_scheduler": { 47 | "type": "reduce_on_plateau", 48 | "factor": 0.5, 49 | "mode": "max", 50 | "patience": 1, 51 | "verbose": true, 52 | "cooldown": 2 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /models/multiatt/no_obj.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "vswag" 4 | }, 5 | "model": { 6 | "type": "MultiHopAttentionQA", 7 | "span_encoder": { 8 | "type": "lstm", 9 | "input_size": 1280, 10 | "hidden_size": 256, 11 | "num_layers": 1, 12 | "bidirectional": true 13 | }, 14 | "reasoning_encoder": { 15 | "type": "lstm", 16 | "input_size": 1024, 17 | "hidden_size": 256, 18 | "num_layers": 2, 19 | "bidirectional": true 20 | }, 21 | "hidden_dim_maxpool": 1024, 22 | "input_dropout": 0.3, 23 | "pool_question": true, 24 | "pool_answer": true, 25 | "reasoning_use_obj": false, 26 | "initializer": [ 27 | [".*final_mlp.*weight", {"type": "xavier_uniform"}], 28 | [".*final_mlp.*bias", {"type": "zero"}], 29 | [".*weight_ih.*", {"type": "xavier_uniform"}], 30 | [".*weight_hh.*", {"type": "orthogonal"}], 31 | [".*bias_ih.*", {"type": "zero"}], 32 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}]] 33 | }, 34 | "trainer": { 35 | "optimizer": { 36 | "type": "adam", 37 | "lr": 0.0002, 38 | "weight_decay": 0.0001 39 | }, 40 | "validation_metric": "+accuracy", 41 | "num_serialized_models_to_keep": 2, 42 | "num_epochs": 20, 43 | "grad_norm": 1.0, 44 | "patience": 3, 45 | "cuda_device": 0, 46 | "learning_rate_scheduler": { 47 | "type": "reduce_on_plateau", 48 | "factor": 0.5, 49 | "mode": "max", 50 | "patience": 1, 51 | "verbose": true, 52 | "cooldown": 2 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /models/multiatt/no_reasoning.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "vswag" 4 | }, 5 | "model": { 6 | "type": "MultiHopAttentionQA", 7 | "span_encoder": { 8 | "type": "lstm", 9 | "input_size": 1280, 10 | "hidden_size": 256, 11 | "num_layers": 1, 12 | "bidirectional": true 13 | }, 14 | "reasoning_encoder": { 15 | "type": "pass_through", 16 | "input_dim": 1536, 17 | }, 18 | "hidden_dim_maxpool": 1024, 19 | "input_dropout": 0.3, 20 | "pool_question": true, 21 | "pool_answer": true, 22 | "pool_reasoning": false, 23 | "initializer": [ 24 | [".*final_mlp.*weight", {"type": "xavier_uniform"}], 25 | [".*final_mlp.*bias", {"type": "zero"}], 26 | 
[".*weight_ih.*", {"type": "xavier_uniform"}], 27 | [".*weight_hh.*", {"type": "orthogonal"}], 28 | [".*bias_ih.*", {"type": "zero"}], 29 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}]] 30 | }, 31 | "trainer": { 32 | "optimizer": { 33 | "type": "adam", 34 | "lr": 0.0002, 35 | "weight_decay": 0.0001 36 | }, 37 | "validation_metric": "+accuracy", 38 | "num_serialized_models_to_keep": 2, 39 | "num_epochs": 20, 40 | "grad_norm": 1.0, 41 | "patience": 3, 42 | "cuda_device": 0, 43 | "learning_rate_scheduler": { 44 | "type": "reduce_on_plateau", 45 | "factor": 0.5, 46 | "mode": "max", 47 | "patience": 1, 48 | "verbose": true, 49 | "cooldown": 2 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /models/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script. Should be pretty adaptable to whatever. 3 | """ 4 | import argparse 5 | import os 6 | import shutil 7 | 8 | import multiprocessing 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | from allennlp.common.params import Params 13 | from allennlp.training.learning_rate_schedulers import LearningRateScheduler 14 | from allennlp.training.optimizers import Optimizer 15 | from torch.nn import DataParallel 16 | from torch.nn.modules import BatchNorm2d 17 | from tqdm import tqdm 18 | 19 | from dataloaders.vcr import VCR, VCRLoader 20 | from utils.pytorch_misc import time_batch, save_checkpoint, clip_grad_norm, \ 21 | restore_checkpoint, print_para, restore_best_checkpoint 22 | 23 | import logging 24 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', level=logging.DEBUG) 25 | 26 | # This is needed to make the imports work 27 | from allennlp.models import Model 28 | import models 29 | 30 | ################################# 31 | ################################# 32 | ######## Data loading stuff 33 | ################################# 34 | ################################# 35 | 36 | parser = argparse.ArgumentParser(description='train') 37 | parser.add_argument( 38 | '-params', 39 | dest='params', 40 | help='Params location', 41 | type=str, 42 | ) 43 | parser.add_argument( 44 | '-rationale', 45 | action="store_true", 46 | help='use rationale', 47 | ) 48 | parser.add_argument( 49 | '-folder', 50 | dest='folder', 51 | help='folder location', 52 | type=str, 53 | ) 54 | parser.add_argument( 55 | '-no_tqdm', 56 | dest='no_tqdm', 57 | action='store_true', 58 | ) 59 | 60 | args = parser.parse_args() 61 | 62 | params = Params.from_file(args.params) 63 | train, val, test = VCR.splits(mode='rationale' if args.rationale else 'answer', 64 | embs_to_load=params['dataset_reader'].get('embs', 'bert_da'), 65 | only_use_relevant_dets=params['dataset_reader'].get('only_use_relevant_dets', True)) 66 | NUM_GPUS = torch.cuda.device_count() 67 | NUM_CPUS = multiprocessing.cpu_count() 68 | if NUM_GPUS == 0: 69 | raise ValueError("you need gpus!") 70 | 71 | def _to_gpu(td): 72 | if NUM_GPUS > 1: 73 | return td 74 | for k in td: 75 | if k != 'metadata': 76 | td[k] = {k2: v.cuda(non_blocking=True) for k2, v in td[k].items()} if isinstance(td[k], dict) else td[k].cuda( 77 | non_blocking=True) 78 | return td 79 | num_workers = (4 * NUM_GPUS if NUM_CPUS == 32 else 2*NUM_GPUS)-1 80 | print(f"Using {num_workers} workers out of {NUM_CPUS} possible", flush=True) 81 | loader_params = {'batch_size': 96 // NUM_GPUS, 'num_gpus':NUM_GPUS, 'num_workers':num_workers} 82 | train_loader = VCRLoader.from_dataset(train, 
**loader_params) 83 | val_loader = VCRLoader.from_dataset(val, **loader_params) 84 | test_loader = VCRLoader.from_dataset(test, **loader_params) 85 | 86 | ARGS_RESET_EVERY = 100 87 | print("Loading {} for {}".format(params['model'].get('type', 'WTF?'), 'rationales' if args.rationale else 'answer'), flush=True) 88 | model = Model.from_params(vocab=train.vocab, params=params['model']) 89 | for submodule in model.detector.backbone.modules(): 90 | if isinstance(submodule, BatchNorm2d): 91 | submodule.track_running_stats = False 92 | for p in submodule.parameters(): 93 | p.requires_grad = False 94 | 95 | model = DataParallel(model).cuda() if NUM_GPUS > 1 else model.cuda() 96 | optimizer = Optimizer.from_params([x for x in model.named_parameters() if x[1].requires_grad], 97 | params['trainer']['optimizer']) 98 | 99 | lr_scheduler_params = params['trainer'].pop("learning_rate_scheduler", None) 100 | scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params) if lr_scheduler_params else None 101 | 102 | if os.path.exists(args.folder): 103 | print("Found folder! restoring", flush=True) 104 | start_epoch, val_metric_per_epoch = restore_checkpoint(model, optimizer, serialization_dir=args.folder, 105 | learning_rate_scheduler=scheduler) 106 | else: 107 | print("Making directories") 108 | os.makedirs(args.folder, exist_ok=True) 109 | start_epoch, val_metric_per_epoch = 0, [] 110 | shutil.copy2(args.params, args.folder) 111 | 112 | param_shapes = print_para(model) 113 | num_batches = 0 114 | for epoch_num in range(start_epoch, params['trainer']['num_epochs'] + start_epoch): 115 | train_results = [] 116 | norms = [] 117 | model.train() 118 | for b, (time_per_batch, batch) in enumerate(time_batch(train_loader if args.no_tqdm else tqdm(train_loader), reset_every=ARGS_RESET_EVERY)): 119 | batch = _to_gpu(batch) 120 | optimizer.zero_grad() 121 | output_dict = model(**batch) 122 | loss = output_dict['loss'].mean() + output_dict['cnn_regularization_loss'].mean() 123 | loss.backward() 124 | 125 | num_batches += 1 126 | if scheduler: 127 | scheduler.step_batch(num_batches) 128 | 129 | norms.append( 130 | clip_grad_norm(model.named_parameters(), max_norm=params['trainer']['grad_norm'], clip=True, verbose=False) 131 | ) 132 | optimizer.step() 133 | 134 | train_results.append(pd.Series({'loss': output_dict['loss'].mean().item(), 135 | 'crl': output_dict['cnn_regularization_loss'].mean().item(), 136 | 'accuracy': (model.module if NUM_GPUS > 1 else model).get_metrics( 137 | reset=(b % ARGS_RESET_EVERY) == 0)[ 138 | 'accuracy'], 139 | 'sec_per_batch': time_per_batch, 140 | 'hr_per_epoch': len(train_loader) * time_per_batch / 3600, 141 | })) 142 | if b % ARGS_RESET_EVERY == 0 and b > 0: 143 | norms_df = pd.DataFrame(pd.DataFrame(norms[-ARGS_RESET_EVERY:]).mean(), columns=['norm']).join( 144 | param_shapes[['shape', 'size']]).sort_values('norm', ascending=False) 145 | 146 | print("e{:2d}b{:5d}/{:5d}. 
norms: \n{}\nsumm:\n{}\n~~~~~~~~~~~~~~~~~~\n".format( 147 | epoch_num, b, len(train_loader), 148 | norms_df.to_string(formatters={'norm': '{:.2f}'.format}), 149 | pd.DataFrame(train_results[-ARGS_RESET_EVERY:]).mean(), 150 | ), flush=True) 151 | 152 | print("---\nTRAIN EPOCH {:2d}:\n{}\n----".format(epoch_num, pd.DataFrame(train_results).mean())) 153 | val_probs = [] 154 | val_labels = [] 155 | val_loss_sum = 0.0 156 | model.eval() 157 | for b, (time_per_batch, batch) in enumerate(time_batch(val_loader)): 158 | with torch.no_grad(): 159 | batch = _to_gpu(batch) 160 | output_dict = model(**batch) 161 | val_probs.append(output_dict['label_probs'].detach().cpu().numpy()) 162 | val_labels.append(batch['label'].detach().cpu().numpy()) 163 | val_loss_sum += output_dict['loss'].mean().item() * batch['label'].shape[0] 164 | val_labels = np.concatenate(val_labels, 0) 165 | val_probs = np.concatenate(val_probs, 0) 166 | val_loss_avg = val_loss_sum / val_labels.shape[0] 167 | 168 | val_metric_per_epoch.append(float(np.mean(val_labels == val_probs.argmax(1)))) 169 | if scheduler: 170 | scheduler.step(val_metric_per_epoch[-1], epoch_num) 171 | 172 | print("Val epoch {} has acc {:.3f} and loss {:.3f}".format(epoch_num, val_metric_per_epoch[-1], val_loss_avg), 173 | flush=True) 174 | if int(np.argmax(val_metric_per_epoch)) < (len(val_metric_per_epoch) - 1 - params['trainer']['patience']): 175 | print("Stopping at epoch {:2d}".format(epoch_num)) 176 | break 177 | save_checkpoint(model, optimizer, args.folder, epoch_num, val_metric_per_epoch, 178 | is_best=int(np.argmax(val_metric_per_epoch)) == (len(val_metric_per_epoch) - 1)) 179 | 180 | print("STOPPING. now running the best model on the validation set", flush=True) 181 | # Load best 182 | restore_best_checkpoint(model, args.folder) 183 | model.eval() 184 | val_probs = [] 185 | val_labels = [] 186 | for b, (time_per_batch, batch) in enumerate(time_batch(val_loader)): 187 | with torch.no_grad(): 188 | batch = _to_gpu(batch) 189 | output_dict = model(**batch) 190 | val_probs.append(output_dict['label_probs'].detach().cpu().numpy()) 191 | val_labels.append(batch['label'].detach().cpu().numpy()) 192 | val_labels = np.concatenate(val_labels, 0) 193 | val_probs = np.concatenate(val_probs, 0) 194 | acc = float(np.mean(val_labels == val_probs.argmax(1))) 195 | print("Final val accuracy is {:.3f}".format(acc)) 196 | np.save(os.path.join(args.folder, f'valpreds.npy'), val_probs) 197 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rowanz/r2c/77813d9e335711759c25df79c348a7c2a8275d72/utils/__init__.py -------------------------------------------------------------------------------- /utils/detector.py: -------------------------------------------------------------------------------- 1 | """ 2 | ok so I lied. 
it's not a detector, it's the resnet backbone 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | from torchvision.models import resnet 9 | 10 | from utils.pytorch_misc import Flattener 11 | from torchvision.layers import ROIAlign 12 | import torch.utils.model_zoo as model_zoo 13 | from config import USE_IMAGENET_PRETRAINED 14 | from utils.pytorch_misc import pad_sequence 15 | from torch.nn import functional as F 16 | 17 | 18 | def _load_resnet(pretrained=True): 19 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 20 | backbone = resnet.resnet50(pretrained=False) 21 | if pretrained: 22 | backbone.load_state_dict(model_zoo.load_url( 23 | 'https://s3.us-west-2.amazonaws.com/ai2-rowanz/resnet50-e13db6895d81.th')) 24 | for i in range(2, 4): 25 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 26 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 27 | return backbone 28 | 29 | 30 | def _load_resnet_imagenet(pretrained=True): 31 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 32 | backbone = resnet.resnet50(pretrained=pretrained) 33 | for i in range(2, 4): 34 | getattr(backbone, 'layer%d' % i)[0].conv1.stride = (2, 2) 35 | getattr(backbone, 'layer%d' % i)[0].conv2.stride = (1, 1) 36 | # use stride 1 for the last conv4 layer (same as tf-faster-rcnn) 37 | backbone.layer4[0].conv2.stride = (1, 1) 38 | backbone.layer4[0].downsample[0].stride = (1, 1) 39 | 40 | # # Make batchnorm more sensible 41 | # for submodule in backbone.modules(): 42 | # if isinstance(submodule, torch.nn.BatchNorm2d): 43 | # submodule.momentum = 0.01 44 | 45 | return backbone 46 | 47 | 48 | class SimpleDetector(nn.Module): 49 | def __init__(self, pretrained=True, average_pool=True, semantic=True, final_dim=1024): 50 | """ 51 | :param average_pool: whether or not to average pool the representations 52 | :param pretrained: Whether we need to load from scratch 53 | :param semantic: Whether or not we want to introduce the mask and the class label early on (default Yes) 54 | """ 55 | super(SimpleDetector, self).__init__() 56 | # huge thx to https://github.com/ruotianluo/pytorch-faster-rcnn/blob/master/lib/nets/resnet_v1.py 57 | backbone = _load_resnet_imagenet(pretrained=pretrained) if USE_IMAGENET_PRETRAINED else _load_resnet( 58 | pretrained=pretrained) 59 | 60 | self.backbone = nn.Sequential( 61 | backbone.conv1, 62 | backbone.bn1, 63 | backbone.relu, 64 | backbone.maxpool, 65 | backbone.layer1, 66 | backbone.layer2, 67 | backbone.layer3, 68 | # backbone.layer4 69 | ) 70 | self.roi_align = ROIAlign((7, 7) if USE_IMAGENET_PRETRAINED else (14, 14), 71 | spatial_scale=1 / 16, sampling_ratio=0) 72 | 73 | if semantic: 74 | self.mask_dims = 32 75 | self.object_embed = torch.nn.Embedding(num_embeddings=81, embedding_dim=128) 76 | self.mask_upsample = torch.nn.Conv2d(1, self.mask_dims, kernel_size=3, 77 | stride=2 if USE_IMAGENET_PRETRAINED else 1, 78 | padding=1, bias=True) 79 | else: 80 | self.object_embed = None 81 | self.mask_upsample = None 82 | 83 | after_roi_align = [backbone.layer4] 84 | self.final_dim = final_dim 85 | if average_pool: 86 | after_roi_align += [nn.AvgPool2d(7, stride=1), Flattener()] 87 | 88 | self.after_roi_align = torch.nn.Sequential(*after_roi_align) 89 | 90 | self.obj_downsample = torch.nn.Sequential( 91 | torch.nn.Dropout(p=0.1), 92 | torch.nn.Linear(2048 + (128 if semantic else 0), final_dim), 93 | torch.nn.ReLU(inplace=True), 94 | ) 95 | 
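# Auxiliary object-classification head over the 81 COCO categories (80 classes plus
# __background__); its cross-entropy on the ROI features is returned from forward() as
# 'cnn_regularization_loss'.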
self.regularizing_predictor = torch.nn.Linear(2048, 81) 96 | 97 | def forward(self, 98 | images: torch.Tensor, 99 | boxes: torch.Tensor, 100 | box_mask: torch.LongTensor, 101 | classes: torch.Tensor = None, 102 | segms: torch.Tensor = None, 103 | ): 104 | """ 105 | :param images: [batch_size, 3, im_height, im_width] 106 | :param boxes: [batch_size, max_num_objects, 4] Padded boxes 107 | :param box_mask: [batch_size, max_num_objects] Mask for whether or not each box is OK 108 | :return: object reps [batch_size, max_num_objects, dim] 109 | """ 110 | # [batch_size, 2048, im_height // 32, im_width // 32 111 | img_feats = self.backbone(images) 112 | box_inds = box_mask.nonzero() 113 | assert box_inds.shape[0] > 0 114 | rois = torch.cat(( 115 | box_inds[:, 0, None].type(boxes.dtype), 116 | boxes[box_inds[:, 0], box_inds[:, 1]], 117 | ), 1) 118 | 119 | # Object class and segmentation representations 120 | roi_align_res = self.roi_align(img_feats, rois) 121 | if self.mask_upsample is not None: 122 | assert segms is not None 123 | segms_indexed = segms[box_inds[:, 0], None, box_inds[:, 1]] - 0.5 124 | roi_align_res[:, :self.mask_dims] += self.mask_upsample(segms_indexed) 125 | 126 | 127 | post_roialign = self.after_roi_align(roi_align_res) 128 | 129 | # Add some regularization, encouraging the model to keep giving decent enough predictions 130 | obj_logits = self.regularizing_predictor(post_roialign) 131 | obj_labels = classes[box_inds[:, 0], box_inds[:, 1]] 132 | cnn_regularization = F.cross_entropy(obj_logits, obj_labels, size_average=True)[None] 133 | 134 | feats_to_downsample = post_roialign if self.object_embed is None else torch.cat((post_roialign, self.object_embed(obj_labels)), -1) 135 | roi_aligned_feats = self.obj_downsample(feats_to_downsample) 136 | 137 | # Reshape into a padded sequence - this is expensive and annoying but easier to implement and debug... 
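# pad_sequence (utils/pytorch_misc.py) scatters the flat [total_num_boxes, final_dim] features
# back into a zero-padded [batch_size, max_num_objects, final_dim] tensor using the per-image
# box counts from box_mask.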
138 | obj_reps = pad_sequence(roi_aligned_feats, box_mask.sum(1).tolist()) 139 | return { 140 | 'obj_reps_raw': post_roialign, 141 | 'obj_reps': obj_reps, 142 | 'obj_logits': obj_logits, 143 | 'obj_labels': obj_labels, 144 | 'cnn_regularization_loss': cnn_regularization 145 | } 146 | -------------------------------------------------------------------------------- /utils/pytorch_misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Question relevance model 3 | """ 4 | 5 | # Make stuff 6 | import os 7 | import re 8 | import shutil 9 | import time 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 15 | from allennlp.nn.util import device_mapping 16 | from allennlp.training.trainer import move_optimizer_to_cuda 17 | from torch.nn import DataParallel 18 | 19 | 20 | def time_batch(gen, reset_every=100): 21 | """ 22 | Gets timing info for a batch 23 | :param gen: 24 | :param reset_every: How often we'll reset 25 | :return: 26 | """ 27 | start = time.time() 28 | start_t = 0 29 | for i, item in enumerate(gen): 30 | time_per_batch = (time.time() - start) / (i + 1 - start_t) 31 | yield time_per_batch, item 32 | if i % reset_every == 0: 33 | start = time.time() 34 | start_t = i 35 | 36 | 37 | class Flattener(torch.nn.Module): 38 | def __init__(self): 39 | """ 40 | Flattens last 3 dimensions to make it only batch size, -1 41 | """ 42 | super(Flattener, self).__init__() 43 | 44 | def forward(self, x): 45 | return x.view(x.size(0), -1) 46 | 47 | 48 | def pad_sequence(sequence, lengths): 49 | """ 50 | :param sequence: [\sum b, .....] sequence 51 | :param lengths: [b1, b2, b3...] that sum to \sum b 52 | :return: [len(lengths), maxlen(b), .....] tensor 53 | """ 54 | output = sequence.new_zeros(len(lengths), max(lengths), *sequence.shape[1:]) 55 | start = 0 56 | for i, diff in enumerate(lengths): 57 | if diff > 0: 58 | output[i, :diff] = sequence[start:(start + diff)] 59 | start += diff 60 | return output 61 | 62 | 63 | def extra_leading_dim_in_sequence(f, x, mask): 64 | return f(x.view(-1, *x.shape[2:]), mask.view(-1, mask.shape[2])).view(*x.shape[:3], -1) 65 | 66 | 67 | def clip_grad_norm(named_parameters, max_norm, clip=True, verbose=False): 68 | """Clips gradient norm of an iterable of parameters. 69 | 70 | The norm is computed over all gradients together, as if they were 71 | concatenated into a single vector. Gradients are modified in-place. 72 | 73 | Arguments: 74 | parameters (Iterable[Variable]): an iterable of Variables that will have 75 | gradients normalized 76 | max_norm (float or int): max norm of the gradients 77 | 78 | Returns: 79 | Total norm of the parameters (viewed as a single vector). 80 | """ 81 | max_norm = float(max_norm) 82 | parameters = [(n, p) for n, p in named_parameters if p.grad is not None] 83 | total_norm = 0 84 | param_to_norm = {} 85 | param_to_shape = {} 86 | for n, p in parameters: 87 | param_norm = p.grad.data.norm(2) 88 | total_norm += param_norm ** 2 89 | param_to_norm[n] = param_norm 90 | param_to_shape[n] = tuple(p.size()) 91 | if np.isnan(param_norm.item()): 92 | raise ValueError("the param {} was null.".format(n)) 93 | 94 | total_norm = total_norm ** (1. 
67 | def clip_grad_norm(named_parameters, max_norm, clip=True, verbose=False):
68 |     """Clips gradient norm of an iterable of parameters.
69 |
70 |     The norm is computed over all gradients together, as if they were
71 |     concatenated into a single vector. Gradients are modified in-place.
72 |
73 |     Arguments:
74 |         named_parameters (iterable of (name, Parameter) pairs): parameters that
75 |             will have their gradients normalized
76 |         max_norm (float or int): max norm of the gradients
77 |
78 |     Returns:
79 |         A pandas Series mapping each parameter name to its gradient norm.
80 |     """
81 |     max_norm = float(max_norm)
82 |     parameters = [(n, p) for n, p in named_parameters if p.grad is not None]
83 |     total_norm = 0
84 |     param_to_norm = {}
85 |     param_to_shape = {}
86 |     for n, p in parameters:
87 |         param_norm = p.grad.data.norm(2)
88 |         total_norm += param_norm ** 2
89 |         param_to_norm[n] = param_norm
90 |         param_to_shape[n] = tuple(p.size())
91 |         if np.isnan(param_norm.item()):
92 |             raise ValueError("the param {} was null.".format(n))
93 |
94 |     total_norm = total_norm ** (1. / 2)
95 |     clip_coef = max_norm / (total_norm + 1e-6)
96 |     if clip_coef.item() < 1 and clip:
97 |         for n, p in parameters:
98 |             p.grad.data.mul_(clip_coef)
99 |
100 |     if verbose:
101 |         print('---Total norm {:.3f} clip coef {:.3f}-----------------'.format(total_norm, clip_coef))
102 |         for name, norm in sorted(param_to_norm.items(), key=lambda x: -x[1]):
103 |             print("{:<60s}: {:.3f}, ({}: {})".format(name, norm, np.prod(param_to_shape[name]), param_to_shape[name]))
104 |         print('-------------------------------', flush=True)
105 |
106 |     return pd.Series({name: norm.item() for name, norm in param_to_norm.items()})
107 |
108 |
109 | def find_latest_checkpoint(serialization_dir):
110 |     """
111 |     Return the location of the latest model and training state files.
112 |     If there isn't a valid checkpoint then return None.
113 |     """
114 |     have_checkpoint = (serialization_dir is not None and
115 |                        any("model_state_epoch_" in x for x in os.listdir(serialization_dir)))
116 |
117 |     if not have_checkpoint:
118 |         return None
119 |
120 |     serialization_files = os.listdir(serialization_dir)
121 |     model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x]
122 |     # Get the last checkpoint file. Epochs are specified as either an
123 |     # int (for end of epoch files) or with epoch and timestamp for
124 |     # within epoch checkpoints, e.g. 5.2018-02-02-15-33-42
125 |     found_epochs = [
126 |         # pylint: disable=anomalous-backslash-in-string
127 |         re.search("model_state_epoch_([0-9\.\-]+)\.th", x).group(1)
128 |         for x in model_checkpoints
129 |     ]
130 |     int_epochs = []
131 |     for epoch in found_epochs:
132 |         pieces = epoch.split('.')
133 |         if len(pieces) == 1:
134 |             # Just a single epoch without timestamp
135 |             int_epochs.append([int(pieces[0]), 0])
136 |         else:
137 |             # has a timestamp
138 |             int_epochs.append([int(pieces[0]), pieces[1]])
139 |     last_epoch = sorted(int_epochs, reverse=True)[0]
140 |     if last_epoch[1] == 0:
141 |         epoch_to_load = str(last_epoch[0])
142 |     else:
143 |         epoch_to_load = '{0}.{1}'.format(last_epoch[0], last_epoch[1])
144 |
145 |     model_path = os.path.join(serialization_dir,
146 |                               "model_state_epoch_{}.th".format(epoch_to_load))
147 |     training_state_path = os.path.join(serialization_dir,
148 |                                        "training_state_epoch_{}.th".format(epoch_to_load))
149 |     return model_path, training_state_path
150 |
151 |
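A minimal sketch of clip_grad_norm inside a training step; the toy model and max_norm value are chosen only for illustration. It rescales gradients in place when their combined norm exceeds max_norm and returns per-parameter norms for logging.

import torch

model = torch.nn.Linear(4, 2)
loss = model(torch.randn(3, 4)).sum()
loss.backward()

# Returns a pandas Series of per-parameter gradient norms; gradients are
# modified in place whenever the total 2-norm exceeds max_norm.
norms = clip_grad_norm(model.named_parameters(), max_norm=1.0, verbose=False)
print(norms)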
152 | def save_checkpoint(model, optimizer, serialization_dir, epoch, val_metric_per_epoch, is_best=None,
153 |                     learning_rate_scheduler=None) -> None:
154 |     """
155 |     Saves a checkpoint of the model to serialization_dir.
156 |     Is a no-op if serialization_dir is None.
157 |     Parameters
158 |     ----------
159 |     epoch : Union[int, str], required.
160 |         The epoch of training. If the checkpoint is saved in the middle
161 |         of an epoch, the parameter is a string with the epoch and timestamp.
162 |     is_best: bool, optional (default = None)
163 |         A flag which causes the model weights at the given epoch to
164 |         be copied to a "best.th" file. The value of this flag should
165 |         be based on some validation metric computed by your model.
166 |     """
167 |     if serialization_dir is not None:
168 |         model_path = os.path.join(serialization_dir, "model_state_epoch_{}.th".format(epoch))
169 |         model_state = model.module.state_dict() if isinstance(model, DataParallel) else model.state_dict()
170 |         torch.save(model_state, model_path)
171 |
172 |         training_state = {'epoch': epoch,
173 |                           'val_metric_per_epoch': val_metric_per_epoch,
174 |                           'optimizer': optimizer.state_dict()
175 |                           }
176 |         if learning_rate_scheduler is not None:
177 |             training_state["learning_rate_scheduler"] = \
178 |                 learning_rate_scheduler.lr_scheduler.state_dict()
179 |         training_path = os.path.join(serialization_dir,
180 |                                      "training_state_epoch_{}.th".format(epoch))
181 |         torch.save(training_state, training_path)
182 |         if is_best:
183 |             print("Best validation performance so far. Copying weights to '{}/best.th'.".format(serialization_dir))
184 |             shutil.copyfile(model_path, os.path.join(serialization_dir, "best.th"))
185 |
186 |
187 | def restore_best_checkpoint(model, serialization_dir):
188 |     fn = os.path.join(serialization_dir, 'best.th')
189 |     assert os.path.exists(fn)
190 |     model_state = torch.load(fn, map_location=device_mapping(-1))
191 |     if isinstance(model, DataParallel):
192 |         model.module.load_state_dict(model_state)
193 |     else:
194 |         model.load_state_dict(model_state)
195 |
196 |
197 | def restore_checkpoint(model, optimizer, serialization_dir, learning_rate_scheduler=None):
198 |     """
199 |     Restores a model from a serialization_dir to the last saved checkpoint.
200 |     This includes an epoch count and optimizer state, which is serialized separately
201 |     from model parameters. This function should only be used to continue training -
202 |     if you wish to load a model for inference/load parts of a model into a new
203 |     computation graph, you should use the native PyTorch functions:
204 |     `` model.load_state_dict(torch.load("/path/to/model/weights.th"))``
205 |     If ``serialization_dir`` does not exist or does not contain any checkpointed weights,
206 |     this function will do nothing and return (0, []).
207 |     Returns
208 |     -------
209 |     epoch: int
210 |         The epoch at which to resume training, which should be one after the epoch
211 |         in the saved training state.
212 |     """
213 |     latest_checkpoint = find_latest_checkpoint(serialization_dir)
214 |
215 |     if latest_checkpoint is None:
216 |         # No checkpoint to restore, start at 0
217 |         return 0, []
218 |
219 |     model_path, training_state_path = latest_checkpoint
220 |
221 |     # Load the parameters onto CPU, then transfer to GPU.
222 |     # This avoids potential OOM on GPU for large models that
223 |     # load parameters onto GPU then make a new GPU copy into the parameter
224 |     # buffer. The GPU transfer happens implicitly in load_state_dict.
225 |     model_state = torch.load(model_path, map_location=device_mapping(-1))
226 |     training_state = torch.load(training_state_path, map_location=device_mapping(-1))
227 |     if isinstance(model, DataParallel):
228 |         model.module.load_state_dict(model_state)
229 |     else:
230 |         model.load_state_dict(model_state)
231 |
232 |     # Restoring the optimizer state is historically the most fragile part of resuming.
233 |     optimizer.load_state_dict(training_state["optimizer"])
234 |
235 |     if learning_rate_scheduler is not None and "learning_rate_scheduler" in training_state:
236 |         learning_rate_scheduler.lr_scheduler.load_state_dict(
237 |             training_state["learning_rate_scheduler"])
238 |     move_optimizer_to_cuda(optimizer)
239 |
240 |     # We didn't used to save `validation_metric_per_epoch`, so we can't assume
241 |     # that it's part of the trainer state. If it's not there, an empty list is all
242 |     # we can do.
243 |     if "val_metric_per_epoch" not in training_state:
244 |         print("trainer state `val_metric_per_epoch` not found, using empty list")
245 |         val_metric_per_epoch = []
246 |     else:
247 |         val_metric_per_epoch = training_state["val_metric_per_epoch"]
248 |
249 |     if isinstance(training_state["epoch"], int):
250 |         epoch_to_return = training_state["epoch"] + 1
251 |     else:
252 |         epoch_to_return = int(training_state["epoch"].split('.')[0]) + 1
253 |     return epoch_to_return, val_metric_per_epoch
254 |
255 |
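A rough sketch of how the checkpoint helpers above fit together in a training loop. The directory, epoch count, and validation metric below are placeholders, not the project's actual training configuration, and the model/optimizer are stand-ins.

import os
import torch

model = torch.nn.Linear(4, 2)                       # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
serialization_dir = '/tmp/example_run'              # placeholder directory
os.makedirs(serialization_dir, exist_ok=True)

# Resume from the latest checkpoint if one exists, otherwise start at epoch 0.
start_epoch, val_metric_per_epoch = restore_checkpoint(model, optimizer, serialization_dir)

for epoch in range(start_epoch, 3):
    # ... train and validate for this epoch ...
    val_metric_per_epoch.append(0.0)                # placeholder validation metric
    is_best = val_metric_per_epoch[-1] >= max(val_metric_per_epoch)
    save_checkpoint(model, optimizer, serialization_dir, epoch, val_metric_per_epoch, is_best=is_best)

# Reload the weights that were copied to best.th.
restore_best_checkpoint(model, serialization_dir)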
256 | def detokenize(array, vocab):
257 |     """
258 |     Given an array of ints, we'll turn this into a string or a list of strings.
259 |     :param array: possibly multidimensional numpy array of token indices
260 |     :param vocab: allennlp Vocabulary used to map indices back to tokens
261 |     """
262 |     if array.ndim > 1:
263 |         return [detokenize(x, vocab) for x in array]
264 |     tokenized = [vocab.get_token_from_index(v) for v in array]
265 |     return ' '.join([x for x in tokenized if x not in (vocab._padding_token, START_SYMBOL, END_SYMBOL)])
266 |
267 |
268 | def print_para(model):
269 |     """
270 |     Prints parameters of a model
271 |     :param model: the model whose parameters will be printed
272 |     :return: DataFrame of parameter names, shapes, and sizes
273 |     """
274 |     st = {}
275 |     total_params = 0
276 |     total_params_training = 0
277 |     for p_name, p in model.named_parameters():
278 |         # if not ('bias' in p_name.split('.')[-1] or 'bn' in p_name.split('.')[-1]):
279 |         st[p_name] = ([str(x) for x in p.size()], np.prod(p.size()), p.requires_grad)
280 |         total_params += np.prod(p.size())
281 |         if p.requires_grad:
282 |             total_params_training += np.prod(p.size())
283 |     pd.set_option('display.max_columns', None)
284 |     shapes_df = pd.DataFrame([(p_name, '[{}]'.format(','.join(size)), prod, p_req_grad)
285 |                               for p_name, (size, prod, p_req_grad) in sorted(st.items(), key=lambda x: -x[1][1])],
286 |                              columns=['name', 'shape', 'size', 'requires_grad']).set_index('name')
287 |
288 |     print('\n {:.1f}M total parameters. {:.1f}M training \n ----- \n {} \n ----'.format(total_params / 1000000.0,
289 |                                                                                          total_params_training / 1000000.0,
290 |                                                                                          shapes_df.to_string()),
291 |           flush=True)
292 |     return shapes_df
293 |
294 |
295 | def batch_index_iterator(len_l, batch_size, skip_end=True):
296 |     """
297 |     Provides indices that iterate over a list
298 |     :param len_l: int representing size of thing that we will
299 |         iterate over
300 |     :param batch_size: size of each batch
301 |     :param skip_end: if true, skip the final partial batch
302 |     :return: A generator that returns (start, end) tuples
303 |         as it goes through all batches
304 |     """
305 |     iterate_until = len_l
306 |     if skip_end:
307 |         iterate_until = (len_l // batch_size) * batch_size
308 |
309 |     for b_start in range(0, iterate_until, batch_size):
310 |         yield (b_start, min(b_start + batch_size, len_l))
311 |
312 |
313 | def batch_iterator(seq, batch_size, skip_end=True):
314 |     for b_start, b_end in batch_index_iterator(len(seq), batch_size, skip_end=skip_end):
315 |         yield seq[b_start:b_end]
316 |
--------------------------------------------------------------------------------
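A short usage sketch of the batching helpers above, using a plain Python list as the sequence:

data = list(range(10))

print(list(batch_index_iterator(len(data), batch_size=4)))                  # [(0, 4), (4, 8)]
print(list(batch_index_iterator(len(data), batch_size=4, skip_end=False)))  # [(0, 4), (4, 8), (8, 10)]
print(list(batch_iterator(data, batch_size=4, skip_end=False)))             # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]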